use crate::stream::ServerManager; use bytes::{BufMut, Bytes, BytesMut}; use futures_util::stream::{self}; use h2::server::SendResponse; use h2::{RecvStream, client}; use http::{Request, Response, StatusCode}; use log::{debug, error}; use serde::de::DeserializeOwned; use std::collections::HashMap; use std::io; use std::pin::Pin; use std::sync::Arc; use tokio_util::io::StreamReader; type BoxFuture = Pin + Send>>; type Handler = Arc< dyn Fn(Request, SendResponse) -> BoxFuture> + Send + Sync, >; pub struct Router { routes: HashMap, } impl Default for Router { fn default() -> Self { Self::new() } } impl Router { pub fn new() -> Self { Self { routes: HashMap::new(), } } pub fn add_route(&mut self, path: &str, handler: F) where F: Fn(Request, SendResponse) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { self.routes.insert( path.to_string(), Arc::new(move |req, res| Box::pin(handler(req, res))), ); } pub async fn dispatch( &self, req: Request, mut res: SendResponse, ) -> Result<(), h2::Error> { let path = req.uri().path(); debug!("Received request for path: {}", path); if let Some(handler) = self.routes.get(path) { handler(req, res).await } else { // 404 let response = Response::builder().status(404).body(()).unwrap(); res.send_response(response, true)?; Ok(()) } } } pub struct Server { router: Arc, } impl Server { pub fn new(router: Arc) -> Self { Self { router } } pub async fn on_conn(&self, server: Arc) -> anyhow::Result<()> { let router = self.router.clone(); while let Some(result) = server.accept().await { let (request, respond) = result?; let r = router.clone(); tokio::spawn(async move { if let Err(e) = r.dispatch(request, respond).await { error!("Handler error: {:?}", e); } }); } Ok(()) } } pub fn send(mut res: SendResponse, status: u16, content_type: &str, body: String) { let response = Response::builder() .status(status) .header("content-type", content_type) .body(()) .unwrap(); if let Ok(mut send_stream) = res.send_response(response, false) { let _ = send_stream.send_data(Bytes::from(body), true); } } pub fn send_error(mut res: SendResponse, error: anyhow::Error) { let response = Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .header("content-type", "text/plain") .body(()) .unwrap(); if let Ok(mut send_stream) = res.send_response(response, false) { let error_msg = format!("Internal Server Error: {}", error); let _ = send_stream.send_data(Bytes::from(error_msg), true); } } pub async fn json_from_request(req: &mut Request) -> anyhow::Result where T: DeserializeOwned, { let body_stream = req.body_mut(); let mut buf = BytesMut::new(); while let Some(chunk) = body_stream.data().await { let data = chunk?; let len = data.len(); buf.put(data); body_stream.flow_control().release_capacity(len)?; } if buf.is_empty() { return Err(anyhow::anyhow!("Request body is empty")); } let result = serde_json::from_slice(&buf)?; Ok(result) } pub async fn request( client: &mut client::SendRequest, path: &str, body: String, ) -> anyhow::Result> { let request = Request::builder() .method("POST") .uri("http://0.0.0.0".to_owned() + path) .header("content-type", "application/json") .body(())?; let (response, mut send_stream) = client.send_request(request, false)?; send_stream.send_data(Bytes::from(body), true)?; Ok(response.await?) } async fn bytes_from_response(response: &mut Response) -> anyhow::Result { let body_stream = response.body_mut(); let mut buf = BytesMut::new(); while let Some(chunk) = body_stream.data().await { let data = chunk?; let len = data.len(); buf.put(data); body_stream.flow_control().release_capacity(len)?; } Ok(buf.freeze()) } pub async fn text_from_response(response: &mut Response) -> anyhow::Result { let buf = bytes_from_response(response).await?; if buf.is_empty() { return Ok(String::new()); } Ok(String::from_utf8_lossy(&buf).to_string()) } pub async fn json_from_response(response: &mut Response) -> anyhow::Result where T: DeserializeOwned, { let buf = bytes_from_response(response).await?; if buf.is_empty() { return Err(anyhow::anyhow!("Request body is empty")); } let result = serde_json::from_slice(&buf)?; Ok(result) } pub fn response_to_async_read(res: Response) -> impl tokio::io::AsyncRead { let body = res.into_body(); let byte_stream = stream::unfold(body, |mut body| async move { match body.data().await { Some(Ok(bytes)) => { let len = bytes.len(); if let Err(e) = body.flow_control().release_capacity(len) { return Some((Err(io::Error::other(e)), body)); } Some((Ok::<_, io::Error>(bytes), body)) } Some(Err(e)) => Some((Err(io::Error::other(e)), body)), None => None, } }); StreamReader::new(byte_stream) }