1
0
mirror of https://github.com/Bluemangoo/sekai-unpacker.git synced 2026-05-06 20:44:47 +08:00
2026-04-24 12:50:40 +08:00

691 lines
23 KiB
Rust

use anyhow::anyhow;
use async_http_proxy::http_connect_tokio;
use bytes::Bytes;
use h2::{RecvStream, client, server};
use http::Request;
use log::{debug, error, info};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::error::Error;
use std::fs::File;
use std::io::BufReader;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{Mutex, Notify};
use tokio::time::{Duration, sleep, timeout};
use tokio_rustls::rustls;
use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName};
use tokio_rustls::{TlsAcceptor, TlsConnector};
use tokio_socks::tcp::Socks5Stream;
use url::Url;
/// Which side of the HTTP/2 connection this peer plays once the tunnel is up.
///
/// The role is configured independently of who dialed the TCP connection: both the
/// listener and the dialer carry an identity, and the handshake rejects a pair that
/// announced the same one.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Identity {
/// Runs `h2::server::handshake`; inbound streams are obtained via `accept`.
Server,
/// Runs `h2::client::handshake`; requests are sent through a `SendRequest` handle.
Client,
}
impl Identity {
pub fn as_u8(&self) -> u8 {
match self {
Identity::Server => b'S',
Identity::Client => b'C',
}
}
}
/// Settings for the dialing side of the tunnel (consumed by `connect_tunnel`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientTunnelConfig {
/// TLS server name. When set, the dialer first bootstraps over TLS to obtain a session
/// id and then resumes over plain HTTP/2; when `None`, it runs the token handshake
/// directly over plain TCP.
pub host: Option<String>,
/// `host:port` of the tunnel listener to dial.
pub url: String,
/// Shared secret sent right after the client hello.
pub token: String,
/// HTTP/2 role this side plays; must differ from the listener's identity.
pub identity: Identity,
/// Optional proxy URL, e.g. `socks5://host:port` or `http://host:port`.
pub proxy: Option<String>,
}
impl Default for ClientTunnelConfig {
    /// Local-development defaults: dial `127.0.0.1:3333` directly (no TLS, no proxy)
    /// as the HTTP/2 client, using a placeholder token.
    fn default() -> Self {
        Self {
            host: None,
            url: "127.0.0.1:3333".to_string(),
            token: "super_secret_magic_token".to_string(),
            identity: Identity::Client,
            proxy: None,
        }
    }
}

// `ClientTunnelConfig` only holds owned strings and a `Copy` enum, so the compiler
// already derives `Send` for it; no `unsafe impl` is needed. This assertion fails to
// compile if a non-`Send` field is ever added, instead of silently promising safety.
const _: () = {
    const fn assert_send<T: Send>() {}
    assert_send::<ClientTunnelConfig>()
};
/// Paths to the PEM files the listener uses to accept TLS bootstrap connections.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SslConfig {
/// Path to a PEM file containing the certificate chain.
pub cert: String,
/// Path to a PEM file containing the private key.
pub key: String,
}
/// Settings for the listening side of the tunnel (consumed by `TunnelListener::bind`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerTunnelConfig {
/// Address the TCP listener binds to.
pub url: String,
/// TLS material; when `None`, incoming TLS bootstrap connections are rejected.
pub cert: Option<SslConfig>,
/// Shared secret every connecting peer must present.
pub token: String,
/// HTTP/2 role the listener plays; must differ from the connecting peer's identity.
pub identity: Identity,
}
impl Default for ServerTunnelConfig {
    /// Binds `127.0.0.1:3333` without TLS, as the HTTP/2 client, with a placeholder
    /// token.
    ///
    /// NOTE(review): `ClientTunnelConfig::default()` also uses `Identity::Client`, so two
    /// all-default peers are rejected by the handshake's identity-collision check.
    fn default() -> Self {
        let identity = Identity::Client;
        let token = String::from("super_secret_magic_token");
        let url = String::from("127.0.0.1:3333");
        Self {
            url,
            cert: None,
            token,
            identity,
        }
    }
}
/// Greeting the dialing side sends first, immediately followed by the raw token bytes.
static CLIENT_HELLO: &[u8] = b"Sekai unpacker client hello";
/// Greeting the listening side returns once the client hello and token were accepted.
static SERVER_HELLO: &[u8] = b"Sekai unpacker server hello";
/// Marker sent by the dialer inside TLS after the token handshake; the listener echoes
/// it and then sends the newly issued session id as 8 big-endian bytes.
static TLS_BOOTSTRAP_MAGIC: &[u8] = b"SPTLSP1";
/// Prefix of a plain-TCP resume connection: these 7 bytes, then the 8-byte big-endian
/// session id, after which the HTTP/2 handshake starts.
static RESUME_MAGIC: &[u8] = b"SPRESM1";
/// A freshly negotiated HTTP/2 connection, before it is wrapped in a session manager.
pub enum TunnelEndpointRaw {
/// Request handle of an HTTP/2 client; `upgrade_to_h2_raw` has already spawned the
/// matching connection driver as a background task.
Client(Box<client::SendRequest<Bytes>>),
/// HTTP/2 server connection; inbound streams are obtained via its `accept`.
Server(Box<server::Connection<TcpStream, Bytes>>),
}
/// A tunnel session handed to callers; clones share the same session manager.
#[derive(Clone)]
pub enum TunnelEndpoint {
/// This side is the HTTP/2 client.
Client(Arc<ClientManager>),
/// This side is the HTTP/2 server.
Server(Arc<ServerManager>),
}
/// Non-owning handle to a `TunnelEndpoint`. The listener stores these so it can hand a
/// resumed connection to a session without keeping that session alive by itself.
#[derive(Clone)]
pub enum WeakTunnelEndpoint {
Client(std::sync::Weak<ClientManager>),
Server(std::sync::Weak<ServerManager>),
}
impl TunnelEndpoint {
pub fn downgrade(&self) -> WeakTunnelEndpoint {
match self {
Self::Client(c) => WeakTunnelEndpoint::Client(Arc::downgrade(c)),
Self::Server(s) => WeakTunnelEndpoint::Server(Arc::downgrade(s)),
}
}
}
impl WeakTunnelEndpoint {
pub fn upgrade(&self) -> Option<TunnelEndpoint> {
match self {
Self::Client(c) => c.upgrade().map(TunnelEndpoint::Client),
Self::Server(s) => s.upgrade().map(TunnelEndpoint::Server),
}
}
}
/// Holds the current HTTP/2 client handle of a session and replaces it when the
/// underlying connection dies (see `get_client`).
pub struct ClientManager {
/// Session id used when resuming; the dialing side treats 0 as "no resumable session".
pub session_id: AtomicU64,
/// Current request handle; `None` while waiting for the peer to resume the session.
pub current_client: Mutex<Option<client::SendRequest<Bytes>>>,
/// Dial settings: `Some` on the dialing side, which reconnects by itself; `None` on the
/// listening side, which waits for the peer to resume the session.
pub config: Option<ClientTunnelConfig>,
/// Signalled by `update_endpoint` after it installs a resumed connection.
pub notify: Arc<Notify>,
}
impl ClientManager {
pub async fn get_client(&self) -> anyhow::Result<client::SendRequest<Bytes>> {
loop {
let cached = { self.current_client.lock().await.clone() };
if let Some(mut c) = cached {
if futures::future::poll_fn(|cx| c.poll_ready(cx))
.await
.is_ok()
{
return Ok(c);
}
debug!("Client physical connection dead, preparing to recover...");
}
let mut lock = self.current_client.lock().await;
if lock.is_some()
&& futures::future::poll_fn(|cx| lock.as_mut().unwrap().poll_ready(cx))
.await
.is_ok()
{
continue;
}
if let Some(config) = &self.config {
let mut retry_delay = Duration::from_secs(1);
let mut sid = self.session_id.load(Ordering::Relaxed);
let new_client = loop {
match do_client_reconnect(config, &mut sid).await {
Ok(TunnelEndpointRaw::Client(c)) => {
self.session_id.store(sid, Ordering::Relaxed);
break c;
}
Ok(_) => return Err(anyhow::anyhow!("Identity mismatch on reconnect")),
Err(e) => {
error!("Reconnect failed: {}. Retrying in {:?}...", e, retry_delay);
drop(e);
sleep(retry_delay).await;
retry_delay = std::cmp::min(retry_delay * 2, Duration::from_secs(30));
}
}
};
*lock = Some(*new_client.clone());
return Ok(*new_client);
} else {
*lock = None;
drop(lock);
match timeout(Duration::from_secs(15), self.notify.notified()).await {
Ok(_) => continue,
Err(_) => {
anyhow::bail!("Session did not reconnect within 15s. Throwing error!")
}
}
}
}
}
}
/// Holds the current HTTP/2 server connection of a session and replaces it when the
/// underlying connection dies (see `accept`).
pub struct ServerManager {
/// Session id used when resuming; the dialing side treats 0 as "no resumable session".
pub session_id: AtomicU64,
/// Current server connection; `None` while waiting for the peer to resume the session.
pub current_server: Mutex<Option<server::Connection<TcpStream, Bytes>>>,
/// Dial settings: `Some` on the dialing side, which reconnects by itself; `None` on the
/// listening side, which waits for the peer to resume the session.
pub config: Option<ClientTunnelConfig>,
/// Signalled by `update_endpoint` after it installs a resumed connection.
pub notify: Arc<Notify>,
}
impl ServerManager {
pub async fn accept(
&self,
) -> Option<Result<(Request<RecvStream>, server::SendResponse<Bytes>), h2::Error>> {
loop {
let mut conn_guard = self.current_server.lock().await;
if let Some(conn) = conn_guard.as_mut() {
if let Some(res) = conn.accept().await {
return Some(res);
}
*conn_guard = None;
log::warn!("Server physical connection dropped, waiting for recovery...");
}
drop(conn_guard);
if let Some(config) = &self.config {
let mut retry_delay = Duration::from_secs(1);
let mut sid = self.session_id.load(Ordering::Relaxed);
let new_server = loop {
match do_client_reconnect(config, &mut sid).await {
Ok(TunnelEndpointRaw::Server(s)) => {
self.session_id.store(sid, Ordering::Relaxed);
break s;
}
Ok(_) => {
error!("Identity mismatch on reconnect");
return None;
}
Err(e) => {
error!("Reconnect failed: {}. Retrying in {:?}...", e, retry_delay);
drop(e);
sleep(retry_delay).await;
retry_delay = std::cmp::min(retry_delay * 2, Duration::from_secs(30));
}
}
};
*self.current_server.lock().await = Some(*new_server);
continue;
} else {
match timeout(Duration::from_secs(15), self.notify.notified()).await {
Ok(_) => continue,
Err(_) => {
error!(
"Session did not reconnect within 15s. Throwing error (returning None)!"
);
return None;
}
}
}
}
}
}
/// Outcome of inspecting a newly accepted plain TCP connection for a resume prefix.
enum ResumeResult {
/// No resume prefix; the unconsumed stream is handed back for the token handshake.
NotResume(TcpStream),
/// First plain connection for a session id issued by a TLS bootstrap, already upgraded
/// to HTTP/2, together with that session id.
NewSession(TunnelEndpointRaw, u64),
/// The connection was installed as the new transport of an already active session.
ResumedExisting,
/// The session id matched neither a pending nor an active session; the connection is
/// dropped.
Invalid,
}
pub struct TunnelListener {
listener: TcpListener,
config: ServerTunnelConfig,
pending_plain_sessions: Mutex<HashMap<u64, std::time::Instant>>,
next_session_id: AtomicU64,
active_sessions: Mutex<HashMap<u64, WeakTunnelEndpoint>>,
}
impl TunnelListener {
pub async fn bind(config: ServerTunnelConfig) -> anyhow::Result<Self> {
let listener = TcpListener::bind(&config.url).await?;
info!("TCP tunnel listener bound to {}", &config.url);
Ok(Self {
listener,
config,
pending_plain_sessions: Mutex::new(HashMap::new()),
next_session_id: AtomicU64::new(1),
active_sessions: Mutex::new(HashMap::new()),
})
}
pub async fn accept(&self) -> anyhow::Result<TunnelEndpoint> {
loop {
let (stream, peer_addr) = self.listener.accept().await?;
debug!("[{}] Connected on tcp", peer_addr);
if is_tls_client_hello(&stream).await {
if let Err(e) = self.handle_tls_bootstrap(stream, peer_addr).await {
error!("[{}] TLS bootstrap failed: {}", peer_addr, e);
}
continue;
}
let mut stream = match self.try_resume_plain_session(stream, peer_addr).await? {
ResumeResult::NewSession(ep_raw, sid) => {
let ep = wrap_raw_endpoint(sid, ep_raw, None);
let mut sessions = self.active_sessions.lock().await;
sessions.retain(|_, weak_ep| weak_ep.upgrade().is_some());
sessions.insert(sid, ep.downgrade());
return Ok(ep);
}
ResumeResult::ResumedExisting => continue,
ResumeResult::Invalid => continue,
ResumeResult::NotResume(s) => s,
};
if perform_server_handshake(&mut stream, &self.config.token, self.config.identity)
.await
.is_err()
{
debug!("[{}] Plain handshake failed", peer_addr);
continue;
}
debug!(
"[{}] Plain handshake completed, upgrading to H2...",
peer_addr
);
let ep_raw = upgrade_to_h2_raw(stream, self.config.identity)
.await
.map_err(|e| anyhow::anyhow!("{}", e))?;
let sid = self.next_session_id.fetch_add(1, Ordering::Relaxed);
let ep = wrap_raw_endpoint(sid, ep_raw, None);
let mut sessions = self.active_sessions.lock().await;
sessions.retain(|_, weak_ep| weak_ep.upgrade().is_some());
sessions.insert(sid, ep.downgrade());
return Ok(ep);
}
}
async fn handle_tls_bootstrap(
&self,
stream: TcpStream,
peer_addr: std::net::SocketAddr,
) -> Result<(), Box<dyn Error>> {
let cert_cfg = self
.config
.cert
.as_ref()
.ok_or("TLS is not enabled on server")?;
let acceptor = TlsAcceptor::from(build_server_tls_config(cert_cfg)?);
let mut tls_stream = acceptor.accept(stream).await?;
perform_server_handshake(&mut tls_stream, &self.config.token, self.config.identity).await?;
let mut magic = vec![0u8; TLS_BOOTSTRAP_MAGIC.len()];
tls_stream.read_exact(&mut magic).await?;
if magic != TLS_BOOTSTRAP_MAGIC {
return Err("TLS bootstrap marker mismatch".into());
}
let session_id = self.next_session_id.fetch_add(1, Ordering::Relaxed);
{
let mut pending = self.pending_plain_sessions.lock().await;
pending.retain(|_, time| time.elapsed() < Duration::from_secs(30));
pending.insert(session_id, std::time::Instant::now());
}
tls_stream.write_all(TLS_BOOTSTRAP_MAGIC).await?;
tls_stream.write_all(&session_id.to_be_bytes()).await?;
debug!(
"[{}] TLS bootstrap done, issued plain-H2 session {}",
peer_addr, session_id
);
Ok(())
}
async fn try_resume_plain_session(
&self,
mut stream: TcpStream,
peer_addr: std::net::SocketAddr,
) -> anyhow::Result<ResumeResult> {
let mut peek_buf = vec![0u8; RESUME_MAGIC.len()];
let n = stream.peek(&mut peek_buf).await?;
if n < RESUME_MAGIC.len() || peek_buf != RESUME_MAGIC {
return Ok(ResumeResult::NotResume(stream));
}
let mut magic = vec![0u8; RESUME_MAGIC.len()];
stream.read_exact(&mut magic).await?;
let mut sid_buf = [0u8; 8];
stream.read_exact(&mut sid_buf).await?;
let session_id = u64::from_be_bytes(sid_buf);
let is_pending = {
let mut pending = self.pending_plain_sessions.lock().await;
pending.remove(&session_id).is_some()
};
if is_pending {
let ep_raw = upgrade_to_h2_raw(stream, self.config.identity)
.await
.map_err(|e| anyhow::anyhow!("{}", e))?;
return Ok(ResumeResult::NewSession(ep_raw, session_id));
}
let active = self.active_sessions.lock().await.get(&session_id).cloned();
if let Some(ep) = active {
let ep_raw = upgrade_to_h2_raw(stream, self.config.identity)
.await
.map_err(|e| anyhow::anyhow!("{}", e))?;
update_endpoint(
&ep.upgrade().ok_or(anyhow!("Connection is cleared"))?,
ep_raw,
)
.await;
info!(
"[{}] Successfully resumed existing session {}",
peer_addr, session_id
);
return Ok(ResumeResult::ResumedExisting);
}
error!("[{}] Invalid plain-H2 session {}", peer_addr, session_id);
Ok(ResumeResult::Invalid)
}
}
/// Peeks at the first bytes of `stream` without consuming them and reports whether they
/// look like a TLS handshake record. That means content type `0x16` followed by a record
/// version `0x03 0x01` through `0x03 0x04`.
///
/// A peek error, or fewer than three bytes available, counts as "not TLS".
async fn is_tls_client_hello(stream: &TcpStream) -> bool {
    const HANDSHAKE_RECORD: u8 = 0x16;
    const TLS_MAJOR_VERSION: u8 = 0x03;
    let mut header = [0u8; 3];
    let Ok(available) = stream.peek(&mut header).await else {
        return false;
    };
    available >= header.len()
        && header[0] == HANDSHAKE_RECORD
        && header[1] == TLS_MAJOR_VERSION
        && matches!(header[2], 1..=4)
}
/// Listening side of the token handshake, run over plain TCP or inside TLS.
///
/// Wire order:
/// 1. read `CLIENT_HELLO`;
/// 2. read `token.len()` token bytes;
/// 3. write `SERVER_HELLO`;
/// 4. read the peer's identity byte;
/// 5. write our own identity byte.
///
/// # Errors
/// I/O errors, a wrong hello, a wrong token, or a peer that announced the same
/// [`Identity`] as ours. In the last case our identity byte has already been sent.
async fn perform_server_handshake<S>(
    stream: &mut S,
    token: &str,
    identity: Identity,
) -> Result<(), Box<dyn Error>>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    let mut client_hello_buf = vec![0u8; CLIENT_HELLO.len()];
    stream.read_exact(&mut client_hello_buf).await?;
    if client_hello_buf != CLIENT_HELLO {
        return Err("Client Hello mismatch".into());
    }
    let mut token_buf = vec![0u8; token.len()];
    stream.read_exact(&mut token_buf).await?;
    // Compare raw bytes rather than a lossy UTF-8 decode, which maps distinct invalid
    // byte runs to the same U+FFFD. The comparison also does not stop at the first
    // mismatch, so response timing does not reveal how much of the secret was correct.
    if !bytes_eq_constant_time(&token_buf, token.as_bytes()) {
        return Err("Wrong token".into());
    }
    stream.write_all(SERVER_HELLO).await?;
    let mut peer_id_buf = [0u8; 1];
    stream.read_exact(&mut peer_id_buf).await?;
    stream.write_all(&[identity.as_u8()]).await?;
    if peer_id_buf[0] == identity.as_u8() {
        return Err("Identity collision with peer".into());
    }
    Ok(())
}

/// Byte-slice equality whose running time depends only on the lengths, not on the
/// position of the first differing byte.
fn bytes_eq_constant_time(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
}
/// Dialing side of the token handshake; the counterpart of `perform_server_handshake`.
///
/// Sends `CLIENT_HELLO` followed by the token and expects `SERVER_HELLO`. It then sends
/// its identity byte before reading the listener's.
///
/// # Errors
/// I/O errors, an unexpected server hello, or a listener that announced the same
/// [`Identity`].
async fn perform_client_handshake<S>(
    stream: &mut S,
    token: &str,
    identity: Identity,
) -> anyhow::Result<()>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    let own_id = identity.as_u8();

    for chunk in [CLIENT_HELLO, token.as_bytes()] {
        stream.write_all(chunk).await?;
    }

    let mut greeting = vec![0u8; SERVER_HELLO.len()];
    stream.read_exact(&mut greeting).await?;
    if greeting.as_slice() != SERVER_HELLO {
        return Err(anyhow!("Server Hello mismatch"));
    }

    stream.write_all(&[own_id]).await?;
    let mut remote_id = [0u8; 1];
    stream.read_exact(&mut remote_id).await?;
    if remote_id[0] == own_id {
        Err(anyhow!("Identity collision with server"))
    } else {
        Ok(())
    }
}
/// Loads the PEM certificate chain and private key named by `cert_cfg` into a rustls
/// server configuration without client authentication. The files are re-read on every
/// call.
///
/// # Errors
/// An unreadable file, malformed PEM, no certificate or no private key found, or a key
/// that rustls rejects for the chain.
fn build_server_tls_config(
    cert_cfg: &SslConfig,
) -> Result<Arc<rustls::ServerConfig>, Box<dyn Error>> {
    let cert_chain: Vec<CertificateDer<'static>> = {
        let mut pem = BufReader::new(File::open(&cert_cfg.cert)?);
        rustls_pemfile::certs(&mut pem).collect::<Result<Vec<_>, _>>()?
    };
    if cert_chain.is_empty() {
        return Err("No certificate found in PEM".into());
    }

    let mut key_pem = BufReader::new(File::open(&cert_cfg.key)?);
    let Some(private_key) = rustls_pemfile::private_key(&mut key_pem)? else {
        return Err("No private key found in PEM".into());
    };

    let config = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(cert_chain, private_key)?;
    Ok(Arc::new(config))
}
/// Builds a TLS connector that trusts the anchors bundled in `webpki_roots` and
/// presents no client certificate.
fn build_client_tls_connector() -> TlsConnector {
    let trust_anchors: rustls::RootCertStore =
        webpki_roots::TLS_SERVER_ROOTS.iter().cloned().collect();
    let client_config = rustls::ClientConfig::builder()
        .with_root_certificates(trust_anchors)
        .with_no_client_auth();
    TlsConnector::from(Arc::new(client_config))
}
/// Runs the HTTP/2 handshake on a prepared TCP stream in the role given by `identity`.
///
/// * [`Identity::Server`]: the h2 server connection itself is returned, and the caller
///   obtains streams from it with `accept`.
/// * [`Identity::Client`]: the connection driver is spawned onto the tokio runtime and
///   only the request handle is returned.
async fn upgrade_to_h2_raw(
    stream: TcpStream,
    identity: Identity,
) -> anyhow::Result<TunnelEndpointRaw> {
    match identity {
        Identity::Server => {
            let connection = server::handshake(stream).await?;
            Ok(TunnelEndpointRaw::Server(Box::new(connection)))
        }
        Identity::Client => {
            let (send_request, driver) = client::handshake(stream).await?;
            tokio::spawn(async move {
                let outcome = driver.await;
                if let Err(e) = outcome {
                    debug!("H2 connection driver finished/error: {:?}", e);
                }
            });
            Ok(TunnelEndpointRaw::Client(Box::new(send_request)))
        }
    }
}
/// Wraps a freshly negotiated connection in the matching session manager.
///
/// `sid` seeds the manager's session id. `config` is `Some` on the dialing side, which
/// lets the manager reconnect by itself, and `None` on the listening side.
fn wrap_raw_endpoint(
    sid: u64,
    raw: TunnelEndpointRaw,
    config: Option<ClientTunnelConfig>,
) -> TunnelEndpoint {
    let session_id = AtomicU64::new(sid);
    let notify = Arc::new(Notify::new());
    match raw {
        TunnelEndpointRaw::Client(handle) => {
            let manager = ClientManager {
                session_id,
                current_client: Mutex::new(Some(*handle)),
                config,
                notify,
            };
            TunnelEndpoint::Client(Arc::new(manager))
        }
        TunnelEndpointRaw::Server(connection) => {
            let manager = ServerManager {
                session_id,
                current_server: Mutex::new(Some(*connection)),
                config,
                notify,
            };
            TunnelEndpoint::Server(Arc::new(manager))
        }
    }
}
async fn update_endpoint(ep: &TunnelEndpoint, ep_raw: TunnelEndpointRaw) {
match (ep, ep_raw) {
(TunnelEndpoint::Client(mgr), TunnelEndpointRaw::Client(c)) => {
*mgr.current_client.lock().await = Some(*c);
mgr.notify.notify_waiters();
}
(TunnelEndpoint::Server(mgr), TunnelEndpointRaw::Server(s)) => {
*mgr.current_server.lock().await = Some(*s);
mgr.notify.notify_waiters();
}
_ => error!("Identity mismatch during session resume!"),
}
}
/// Dials the tunnel described by `config` and returns the resulting session.
///
/// The returned endpoint keeps `config`, so its manager can reconnect by itself when the
/// connection dies.
///
/// # Errors
/// Fails if this first connection attempt fails; it is not retried here.
pub async fn connect_tunnel(config: ClientTunnelConfig) -> Result<TunnelEndpoint, Box<dyn Error>> {
    info!("Connecting to tunnel at {}", &config.url);
    let mut session_id: u64 = 0;
    let raw = match do_client_reconnect(&config, &mut session_id).await {
        Ok(raw) => raw,
        Err(e) => return Err(e.into()),
    };
    Ok(wrap_raw_endpoint(session_id, raw, Some(config)))
}
pub async fn connect_with_auto_proxy(config: &ClientTunnelConfig) -> anyhow::Result<TcpStream> {
let Some(proxy) = &config.proxy else {
return Ok(TcpStream::connect(&config.url).await?);
};
let parsed_proxy = Url::parse(proxy)?;
match parsed_proxy.scheme() {
"socks5" => {
let host = parsed_proxy.host_str().unwrap();
let port = parsed_proxy.port().unwrap_or(1080);
let stream = Socks5Stream::connect((host, port), config.url.clone())
.await?
.into_inner();
Ok(stream)
}
"http" | "https" => {
let proxy_addr = format!(
"{}:{}",
parsed_proxy.host_str().unwrap(),
parsed_proxy.port().unwrap_or(80)
);
let mut stream = TcpStream::connect(proxy_addr).await?;
let target_url = Url::parse(&format!("tcp://{}", config.url))?;
http_connect_tokio(
&mut stream,
target_url.host_str().unwrap(),
target_url.port().unwrap_or(80),
)
.await?;
Ok(stream)
}
_ => Err(anyhow::anyhow!(
"Unsupported proxy scheme: {}",
parsed_proxy.scheme()
)),
}
}
/// Establishes a new raw connection for the dialing side and updates `current_sid`.
///
/// Attempts, in order:
/// 1. If `*current_sid != 0`, resume that session over plain TCP. Any failure is logged
///    and falls through.
/// 2. With `config.host` set, bootstrap over TLS to obtain a new session id, store it in
///    `*current_sid`, and resume with it.
/// 3. Otherwise, run the token handshake over plain TCP and set `*current_sid` to 0, so
///    this side will not try to resume the session.
async fn do_client_reconnect(
    config: &ClientTunnelConfig,
    current_sid: &mut u64,
) -> anyhow::Result<TunnelEndpointRaw> {
    let previous_sid = *current_sid;
    if previous_sid != 0 {
        match resume_tunnel_client(config, previous_sid).await {
            Ok(ep) => {
                info!("Resumed existing session {}", previous_sid);
                return Ok(ep);
            }
            Err(e) => log::warn!("Resume failed: {}. Falling back to full bootstrap.", e),
        }
    }

    match config.host.as_deref() {
        Some(host) => {
            let (fresh_sid, ()) = bootstrap_tls_and_get_sid(config, host).await?;
            *current_sid = fresh_sid;
            resume_tunnel_client(config, fresh_sid).await
        }
        None => {
            let mut stream = connect_with_auto_proxy(config).await?;
            perform_client_handshake(&mut stream, &config.token, config.identity).await?;
            let raw = upgrade_to_h2_raw(stream, config.identity).await?;
            *current_sid = 0;
            Ok(raw)
        }
    }
}
/// Performs the TLS bootstrap against the listener and returns the session id it issues.
///
/// Steps:
/// 1. Dial `config.url`, through the proxy if one is configured.
/// 2. Verify the server certificate for `host` against the `webpki_roots` anchors.
/// 3. Run the token handshake inside TLS and send `TLS_BOOTSTRAP_MAGIC`.
/// 4. Expect the magic echoed back, then read an 8-byte big-endian session id.
///
/// The second tuple element carries no information.
async fn bootstrap_tls_and_get_sid(
    config: &ClientTunnelConfig,
    host: &str,
) -> anyhow::Result<(u64, ())> {
    let connector = build_client_tls_connector();
    let tcp = connect_with_auto_proxy(config).await?;
    let server_name = match ServerName::try_from(host.to_string()) {
        Ok(name) => name.to_owned(),
        Err(_) => return Err(anyhow!("Invalid TLS host: {}", host)),
    };
    let mut tls = match connector.connect(server_name, tcp).await {
        Ok(stream) => stream,
        Err(e) => {
            return Err(anyhow!(
                "TLS handshake failed (server may not support TLS): {}",
                e
            ))
        }
    };

    perform_client_handshake(&mut tls, &config.token, config.identity).await?;
    tls.write_all(TLS_BOOTSTRAP_MAGIC).await?;

    let mut echoed = vec![0u8; TLS_BOOTSTRAP_MAGIC.len()];
    tls.read_exact(&mut echoed).await?;
    if echoed != TLS_BOOTSTRAP_MAGIC {
        return Err(anyhow!("TLS bootstrap response mismatch"));
    }

    let mut sid_bytes = [0u8; 8];
    tls.read_exact(&mut sid_bytes).await?;
    let session_id = u64::from_be_bytes(sid_bytes);
    Ok((session_id, ()))
}
/// Opens a plain TCP connection (through the proxy if configured), announces
/// `session_id` with the resume prefix, and runs the HTTP/2 handshake on it.
async fn resume_tunnel_client(
    config: &ClientTunnelConfig,
    session_id: u64,
) -> anyhow::Result<TunnelEndpointRaw> {
    let mut stream = connect_with_auto_proxy(config).await?;
    let mut prefix = Vec::with_capacity(RESUME_MAGIC.len() + 8);
    prefix.extend_from_slice(RESUME_MAGIC);
    prefix.extend_from_slice(&session_id.to_be_bytes());
    stream.write_all(&prefix).await?;
    upgrade_to_h2_raw(stream, config.identity).await
}