use anyhow::anyhow; use bytes::Bytes; use h2::{RecvStream, client, server}; use http::Request; use log::{debug, error, info}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::error::Error; use std::fs::File; use std::io::BufReader; use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{Mutex, Notify}; use tokio::time::{Duration, sleep, timeout}; use tokio_rustls::rustls; use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName}; use tokio_rustls::{TlsAcceptor, TlsConnector}; #[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum Identity { Server, Client, } impl Identity { pub fn as_u8(&self) -> u8 { match self { Identity::Server => b'S', Identity::Client => b'C', } } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ClientTunnelConfig { pub host: Option, pub url: String, pub token: String, pub identity: Identity, } impl Default for ClientTunnelConfig { fn default() -> Self { Self { host: None, url: "127.0.0.1:3333".to_string(), token: "super_secret_magic_token".to_string(), identity: Identity::Client, } } } unsafe impl Send for ClientTunnelConfig {} #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct SslConfig { pub cert: String, pub key: String, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ServerTunnelConfig { pub url: String, pub cert: Option, pub token: String, pub identity: Identity, } impl Default for ServerTunnelConfig { fn default() -> Self { Self { url: "127.0.0.1:3333".to_string(), cert: None, token: "super_secret_magic_token".to_string(), identity: Identity::Client, } } } static CLIENT_HELLO: &[u8] = b"Sekai unpacker client hello"; static SERVER_HELLO: &[u8] = b"Sekai unpacker server hello"; static TLS_BOOTSTRAP_MAGIC: &[u8] = b"SPTLSP1"; static RESUME_MAGIC: &[u8] = b"SPRESM1"; pub enum TunnelEndpointRaw { Client(Box>), Server(Box>), } #[derive(Clone)] pub enum TunnelEndpoint { Client(Arc), Server(Arc), } pub struct ClientManager { pub session_id: AtomicU64, pub current_client: Mutex>>, pub config: Option, pub notify: Arc, } impl ClientManager { pub async fn get_client(&self) -> anyhow::Result> { loop { let cached = { self.current_client.lock().await.clone() }; if let Some(mut c) = cached { if futures::future::poll_fn(|cx| c.poll_ready(cx)) .await .is_ok() { return Ok(c); } debug!("Client physical connection dead, preparing to recover..."); } let mut lock = self.current_client.lock().await; if lock.is_some() && futures::future::poll_fn(|cx| lock.as_mut().unwrap().poll_ready(cx)) .await .is_ok() { continue; } if let Some(config) = &self.config { let mut retry_delay = Duration::from_secs(1); let mut sid = self.session_id.load(Ordering::Relaxed); let new_client = loop { match do_client_reconnect(config, &mut sid).await { Ok(TunnelEndpointRaw::Client(c)) => { self.session_id.store(sid, Ordering::Relaxed); break c; } Ok(_) => return Err(anyhow::anyhow!("Identity mismatch on reconnect")), Err(e) => { error!("Reconnect failed: {}. Retrying in {:?}...", e, retry_delay); drop(e); sleep(retry_delay).await; retry_delay = std::cmp::min(retry_delay * 2, Duration::from_secs(30)); } } }; *lock = Some(*new_client.clone()); return Ok(*new_client); } else { *lock = None; drop(lock); match timeout(Duration::from_secs(15), self.notify.notified()).await { Ok(_) => continue, Err(_) => { anyhow::bail!("Session did not reconnect within 15s. Throwing error!") } } } } } } pub struct ServerManager { pub session_id: AtomicU64, pub current_server: Mutex>>, pub config: Option, pub notify: Arc, } impl ServerManager { pub async fn accept( &self, ) -> Option, server::SendResponse), h2::Error>> { loop { let mut conn_guard = self.current_server.lock().await; if let Some(conn) = conn_guard.as_mut() { if let Some(res) = conn.accept().await { return Some(res); } *conn_guard = None; log::warn!("Server physical connection dropped, waiting for recovery..."); } drop(conn_guard); if let Some(config) = &self.config { let mut retry_delay = Duration::from_secs(1); let mut sid = self.session_id.load(Ordering::Relaxed); let new_server = loop { match do_client_reconnect(config, &mut sid).await { Ok(TunnelEndpointRaw::Server(s)) => { self.session_id.store(sid, Ordering::Relaxed); break s; } Ok(_) => { error!("Identity mismatch on reconnect"); return None; } Err(e) => { error!("Reconnect failed: {}. Retrying in {:?}...", e, retry_delay); drop(e); sleep(retry_delay).await; retry_delay = std::cmp::min(retry_delay * 2, Duration::from_secs(30)); } } }; *self.current_server.lock().await = Some(*new_server); continue; } else { match timeout(Duration::from_secs(15), self.notify.notified()).await { Ok(_) => continue, Err(_) => { error!( "Session did not reconnect within 15s. Throwing error (returning None)!" ); return None; } } } } } } enum ResumeResult { NotResume(TcpStream), NewSession(TunnelEndpointRaw, u64), ResumedExisting, Invalid, } pub struct TunnelListener { listener: TcpListener, config: ServerTunnelConfig, pending_plain_sessions: Mutex>, next_session_id: AtomicU64, active_sessions: Mutex>, } impl TunnelListener { pub async fn bind(config: ServerTunnelConfig) -> anyhow::Result { let listener = TcpListener::bind(&config.url).await?; info!("TCP tunnel listener bound to {}", &config.url); Ok(Self { listener, config, pending_plain_sessions: Mutex::new(HashSet::new()), next_session_id: AtomicU64::new(1), active_sessions: Mutex::new(HashMap::new()), }) } pub async fn accept(&self) -> anyhow::Result { loop { let (stream, peer_addr) = self.listener.accept().await?; debug!("[{}] Connected on tcp", peer_addr); if is_tls_client_hello(&stream).await { if let Err(e) = self.handle_tls_bootstrap(stream, peer_addr).await { error!("[{}] TLS bootstrap failed: {}", peer_addr, e); } continue; } let mut stream = match self.try_resume_plain_session(stream, peer_addr).await? { ResumeResult::NewSession(ep_raw, sid) => { let ep = wrap_raw_endpoint(sid, ep_raw, None); self.active_sessions.lock().await.insert(sid, ep.clone()); return Ok(ep); } ResumeResult::ResumedExisting => continue, ResumeResult::Invalid => continue, ResumeResult::NotResume(s) => s, }; if perform_server_handshake(&mut stream, &self.config.token, self.config.identity) .await .is_err() { debug!("[{}] Plain handshake failed", peer_addr); continue; } debug!( "[{}] Plain handshake completed, upgrading to H2...", peer_addr ); let ep_raw = upgrade_to_h2_raw(stream, self.config.identity) .await .map_err(|e| anyhow::anyhow!("{}", e))?; let sid = self.next_session_id.fetch_add(1, Ordering::Relaxed); let ep = wrap_raw_endpoint(sid, ep_raw, None); self.active_sessions.lock().await.insert(sid, ep.clone()); return Ok(ep); } } async fn handle_tls_bootstrap( &self, stream: TcpStream, peer_addr: std::net::SocketAddr, ) -> Result<(), Box> { let cert_cfg = self .config .cert .as_ref() .ok_or("TLS is not enabled on server")?; let acceptor = TlsAcceptor::from(build_server_tls_config(cert_cfg)?); let mut tls_stream = acceptor.accept(stream).await?; perform_server_handshake(&mut tls_stream, &self.config.token, self.config.identity).await?; let mut magic = vec![0u8; TLS_BOOTSTRAP_MAGIC.len()]; tls_stream.read_exact(&mut magic).await?; if magic != TLS_BOOTSTRAP_MAGIC { return Err("TLS bootstrap marker mismatch".into()); } let session_id = self.next_session_id.fetch_add(1, Ordering::Relaxed); { let mut pending = self.pending_plain_sessions.lock().await; pending.insert(session_id); } tls_stream.write_all(TLS_BOOTSTRAP_MAGIC).await?; tls_stream.write_all(&session_id.to_be_bytes()).await?; debug!( "[{}] TLS bootstrap done, issued plain-H2 session {}", peer_addr, session_id ); Ok(()) } async fn try_resume_plain_session( &self, mut stream: TcpStream, peer_addr: std::net::SocketAddr, ) -> anyhow::Result { let mut peek_buf = vec![0u8; RESUME_MAGIC.len()]; let n = stream.peek(&mut peek_buf).await?; if n < RESUME_MAGIC.len() || peek_buf != RESUME_MAGIC { return Ok(ResumeResult::NotResume(stream)); } let mut magic = vec![0u8; RESUME_MAGIC.len()]; stream.read_exact(&mut magic).await?; let mut sid_buf = [0u8; 8]; stream.read_exact(&mut sid_buf).await?; let session_id = u64::from_be_bytes(sid_buf); let is_pending = { let mut pending = self.pending_plain_sessions.lock().await; pending.remove(&session_id) }; if is_pending { let ep_raw = upgrade_to_h2_raw(stream, self.config.identity) .await .map_err(|e| anyhow::anyhow!("{}", e))?; return Ok(ResumeResult::NewSession(ep_raw, session_id)); } let active = self.active_sessions.lock().await.get(&session_id).cloned(); if let Some(ep) = active { let ep_raw = upgrade_to_h2_raw(stream, self.config.identity) .await .map_err(|e| anyhow::anyhow!("{}", e))?; update_endpoint(&ep, ep_raw).await; info!( "[{}] Successfully resumed existing session {}", peer_addr, session_id ); return Ok(ResumeResult::ResumedExisting); } error!("[{}] Invalid plain-H2 session {}", peer_addr, session_id); Ok(ResumeResult::Invalid) } } async fn is_tls_client_hello(stream: &TcpStream) -> bool { let mut header = [0u8; 3]; match stream.peek(&mut header).await { Ok(n) if n >= 3 => header[0] == 0x16 && header[1] == 0x03 && (1..=4).contains(&header[2]), _ => false, } } async fn perform_server_handshake( stream: &mut S, token: &str, identity: Identity, ) -> Result<(), Box> where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, { let mut client_hello_buf = vec![0u8; CLIENT_HELLO.len()]; stream.read_exact(&mut client_hello_buf).await?; if client_hello_buf != CLIENT_HELLO { return Err("Client Hello mismatch".into()); } let mut token_buf = vec![0u8; token.len()]; stream.read_exact(&mut token_buf).await?; if String::from_utf8_lossy(&token_buf) != token { return Err("Wrong token".into()); } stream.write_all(SERVER_HELLO).await?; let mut peer_id_buf = [0u8; 1]; stream.read_exact(&mut peer_id_buf).await?; stream.write_all(&[identity.as_u8()]).await?; if peer_id_buf[0] == identity.as_u8() { return Err("Identity collision with peer".into()); } Ok(()) } async fn perform_client_handshake( stream: &mut S, token: &str, identity: Identity, ) -> anyhow::Result<()> where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, { stream.write_all(CLIENT_HELLO).await?; stream.write_all(token.as_bytes()).await?; let mut server_hello_buf = vec![0u8; SERVER_HELLO.len()]; stream.read_exact(&mut server_hello_buf).await?; if server_hello_buf != SERVER_HELLO { return Err(anyhow!("Server Hello mismatch")); } stream.write_all(&[identity.as_u8()]).await?; let mut peer_id_buf = [0u8; 1]; stream.read_exact(&mut peer_id_buf).await?; if peer_id_buf[0] == identity.as_u8() { return Err(anyhow!("Identity collision with server")); } Ok(()) } fn build_server_tls_config( cert_cfg: &SslConfig, ) -> Result, Box> { let cert_file = File::open(&cert_cfg.cert)?; let mut cert_reader = BufReader::new(cert_file); let cert_chain: Vec> = rustls_pemfile::certs(&mut cert_reader).collect::, _>>()?; if cert_chain.is_empty() { return Err("No certificate found in PEM".into()); } let key_file = File::open(&cert_cfg.key)?; let mut key_reader = BufReader::new(key_file); let private_key = rustls_pemfile::private_key(&mut key_reader)?.ok_or("No private key found in PEM")?; let server_config = rustls::ServerConfig::builder() .with_no_client_auth() .with_single_cert(cert_chain, private_key)?; Ok(Arc::new(server_config)) } fn build_client_tls_connector() -> TlsConnector { let root_store = rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); let config = rustls::ClientConfig::builder() .with_root_certificates(root_store) .with_no_client_auth(); TlsConnector::from(Arc::new(config)) } async fn upgrade_to_h2_raw( stream: TcpStream, identity: Identity, ) -> anyhow::Result { match identity { Identity::Client => { let (h2_client, h2_conn) = client::handshake(stream).await?; tokio::spawn(async move { if let Err(e) = h2_conn.await { debug!("H2 connection driver finished/error: {:?}", e); } }); Ok(TunnelEndpointRaw::Client(Box::new(h2_client))) } Identity::Server => { let h2_conn = server::handshake(stream).await?; Ok(TunnelEndpointRaw::Server(Box::new(h2_conn))) } } } fn wrap_raw_endpoint( sid: u64, raw: TunnelEndpointRaw, config: Option, ) -> TunnelEndpoint { match raw { TunnelEndpointRaw::Client(c) => TunnelEndpoint::Client(Arc::new(ClientManager { session_id: AtomicU64::new(sid), current_client: Mutex::new(Some(*c)), config, notify: Arc::new(Notify::new()), })), TunnelEndpointRaw::Server(s) => TunnelEndpoint::Server(Arc::new(ServerManager { session_id: AtomicU64::new(sid), current_server: Mutex::new(Some(*s)), config, notify: Arc::new(Notify::new()), })), } } async fn update_endpoint(ep: &TunnelEndpoint, ep_raw: TunnelEndpointRaw) { match (ep, ep_raw) { (TunnelEndpoint::Client(mgr), TunnelEndpointRaw::Client(c)) => { *mgr.current_client.lock().await = Some(*c); mgr.notify.notify_waiters(); } (TunnelEndpoint::Server(mgr), TunnelEndpointRaw::Server(s)) => { *mgr.current_server.lock().await = Some(*s); mgr.notify.notify_waiters(); } _ => error!("Identity mismatch during session resume!"), } } pub async fn connect_tunnel(config: ClientTunnelConfig) -> Result> { info!("Connecting to tunnel at {}", &config.url); let mut sid = 0; let ep_raw = do_client_reconnect(&config, &mut sid).await?; Ok(wrap_raw_endpoint(sid, ep_raw, Some(config))) } async fn do_client_reconnect( config: &ClientTunnelConfig, current_sid: &mut u64, ) -> anyhow::Result { if *current_sid != 0 { match resume_tunnel_client(config, *current_sid).await { Ok(ep) => { info!("Resumed existing session {}", current_sid); return Ok(ep); } Err(e) => { log::warn!("Resume failed: {}. Falling back to full bootstrap.", e); } } } if let Some(host) = config.host.clone() { let (sid, _) = bootstrap_tls_and_get_sid(config, &host).await?; *current_sid = sid; resume_tunnel_client(config, sid).await } else { let mut stream = TcpStream::connect(&config.url).await?; perform_client_handshake(&mut stream, &config.token, config.identity).await?; let raw = upgrade_to_h2_raw(stream, config.identity).await?; *current_sid = 0; Ok(raw) } } async fn bootstrap_tls_and_get_sid( config: &ClientTunnelConfig, host: &str, ) -> anyhow::Result<(u64, ())> { let connector = build_client_tls_connector(); let tcp = TcpStream::connect(&config.url).await?; let server_name = ServerName::try_from(host.to_string()) .map_err(|_| anyhow!("Invalid TLS host: {}", host))? .to_owned(); let mut tls_stream = connector .connect(server_name, tcp) .await .map_err(|e| anyhow!("TLS handshake failed (server may not support TLS): {}", e))?; perform_client_handshake(&mut tls_stream, &config.token, config.identity).await?; tls_stream.write_all(TLS_BOOTSTRAP_MAGIC).await?; let mut magic = vec![0u8; TLS_BOOTSTRAP_MAGIC.len()]; tls_stream.read_exact(&mut magic).await?; if magic != TLS_BOOTSTRAP_MAGIC { return Err(anyhow!("TLS bootstrap response mismatch")); } let mut sid_buf = [0u8; 8]; tls_stream.read_exact(&mut sid_buf).await?; Ok((u64::from_be_bytes(sid_buf), ())) } async fn resume_tunnel_client( config: &ClientTunnelConfig, session_id: u64, ) -> anyhow::Result { let mut plain_stream = TcpStream::connect(&config.url).await?; plain_stream.write_all(RESUME_MAGIC).await?; plain_stream.write_all(&session_id.to_be_bytes()).await?; upgrade_to_h2_raw(plain_stream, config.identity).await }