// SPDX-License-Identifier: MIT

use anyhow::{Context, Result};
use axum::extract::ws::{self, WebSocket};
use futures::{
    sink::SinkExt,
    stream::{SplitSink, SplitStream, StreamExt},
};
use tracing::{info, trace};

use crate::common::process::{InputStream, OutputChannel, OutputSink};

impl InputStream for &mut SplitStream<&mut WebSocket> {
    /// Reads the next chunk of uploaded data from the websocket.
    ///
    /// Binary frames carry the data and are returned as one chunk each.
    /// Ping/Pong control frames are skipped. They are keep-alive traffic, not
    /// part of the upload, and must not end the stream.
    ///
    /// Returns `None` (end of input) when:
    /// - the socket is closed (stream exhausted or a `Close` frame arrives),
    /// - reading fails,
    /// - the client sends the text message `END_OF_FILE`, or
    /// - any other unexpected message arrives.
    async fn next_chunk(&mut self) -> Option<Vec<u8>> {
        loop {
            match self.next().await {
                None => {
                    info!("Websocket closed");
                    return None;
                }
                Some(Err(e)) => {
                    info!("Error reading from websocket: {:?}", e);
                    return None;
                }
                // `Vec::from` avoids copying when the payload type can hand
                // over its buffer directly.
                Some(Ok(ws::Message::Binary(chunk))) => return Some(Vec::from(chunk)),
                Some(Ok(ws::Message::Text(text))) if text.as_str() == "END_OF_FILE" => {
                    trace!("End of file indicator");
                    return None;
                }
                // axum still yields control frames to the reader even though it
                // answers pings itself. Skip them instead of aborting the upload.
                Some(Ok(ws::Message::Ping(_) | ws::Message::Pong(_))) => {
                    trace!("Ignoring websocket ping/pong frame");
                }
                Some(Ok(ws::Message::Close(frame))) => {
                    info!("Websocket closed by peer: {:?}", frame);
                    return None;
                }
                Some(Ok(msg)) => {
                    info!("Unexpected message after file: {:?}", msg);
                    return None;
                }
            }
        }
    }
}

impl OutputSink for &mut SplitSink<&mut WebSocket, ws::Message> {
    /// Sends one line of process output to the client as a websocket text frame.
    ///
    /// The output channel is not encoded in the frame, so every channel's
    /// lines are sent the same way.
    ///
    /// # Errors
    /// Fails (with a "Websocket closed?" context) if the frame cannot be sent.
    async fn send_line(&mut self, _chan: OutputChannel, data: String) -> Result<()> {
        let frame = ws::Message::text(data);
        let outcome = self.send(frame).await;
        outcome.context("Websocket closed?")
    }

    /// Flushes any frames still buffered in the sink.
    ///
    /// # Errors
    /// Fails (with a "Websocket closed?" context) if the flush fails.
    async fn check_output(&mut self) -> Result<()> {
        let outcome = self.flush().await;
        outcome.context("Websocket closed?")
    }
}
