diff --git a/client/src/main.rs b/client/src/main.rs index a5e5374..70ceefd 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -183,7 +183,6 @@ async fn main() -> anyhow::Result<()> { let liveness_tx = liveness_tx.clone(); join_set.spawn(async move { let _guard = liveness_tx; - let mut inner_set = JoinSet::new(); loop { if cancel_token.is_cancelled() { return; @@ -205,7 +204,7 @@ async fn main() -> anyhow::Result<()> { let local_manifest = local_manifest.clone(); let signal = signal.clone(); - inner_set.spawn(async move { + tokio::spawn(async move { loop { if client.get_client().await.is_err() { return; diff --git a/client/src/task.rs b/client/src/task.rs index d96c24e..760037c 100644 --- a/client/src/task.rs +++ b/client/src/task.rs @@ -121,11 +121,6 @@ pub async fn run_main( match r { Ok(Ok(())) => cnt.inc_success(), Ok(Err(e)) => { - if e.to_string() - .contains("Session did not reconnect within 15s") - { - return Err(anyhow!(e)); - } error!("{}", e); cnt.inc_failure() } @@ -163,12 +158,21 @@ pub async fn run_side( let n = p1.concurrent.unwrap_or(5); let semaphore = Arc::new(Semaphore::new(n)); let mut join_set = JoinSet::new(); + + let cancel_token = CancellationToken::new(); while let Some(task) = queue.try_pop() { + if cancel_token.is_cancelled() { + break; + } let permit = semaphore.clone().acquire_owned().await?; + if cancel_token.is_cancelled() { + break; + } let client = client.clone(); let id = id.clone(); let local_manifest = manifest.clone(); let p1 = p1.clone(); + let cancel_token = cancel_token.clone(); join_set.spawn(async move { let guard = task; @@ -185,6 +189,11 @@ pub async fn run_side( let mut retry_conn = client.get_client().await?; result = download(&mut retry_conn, &req, &p1).await; } + if let Err(e) = &result + && e.downcast_ref::().is_some() + { + cancel_token.cancel(); + } result?; local_manifest diff --git a/communicator/src/stream.rs b/communicator/src/stream.rs index 07032d4..c5e6bcf 100644 --- a/communicator/src/stream.rs +++ b/communicator/src/stream.rs @@ -95,6 +95,30 @@ pub enum TunnelEndpoint { Server(Arc), } +#[derive(Clone)] +pub enum WeakTunnelEndpoint { + Client(std::sync::Weak), + Server(std::sync::Weak), +} + +impl TunnelEndpoint { + pub fn downgrade(&self) -> WeakTunnelEndpoint { + match self { + Self::Client(c) => WeakTunnelEndpoint::Client(Arc::downgrade(c)), + Self::Server(s) => WeakTunnelEndpoint::Server(Arc::downgrade(s)), + } + } +} + +impl WeakTunnelEndpoint { + pub fn upgrade(&self) -> Option { + match self { + Self::Client(c) => c.upgrade().map(TunnelEndpoint::Client), + Self::Server(s) => s.upgrade().map(TunnelEndpoint::Server), + } + } +} + pub struct ClientManager { pub session_id: AtomicU64, pub current_client: Mutex>>, @@ -231,9 +255,9 @@ enum ResumeResult { pub struct TunnelListener { listener: TcpListener, config: ServerTunnelConfig, - pending_plain_sessions: Mutex>, + pending_plain_sessions: Mutex>, next_session_id: AtomicU64, - active_sessions: Mutex>, + active_sessions: Mutex>, } impl TunnelListener { @@ -243,7 +267,7 @@ impl TunnelListener { Ok(Self { listener, config, - pending_plain_sessions: Mutex::new(HashSet::new()), + pending_plain_sessions: Mutex::new(HashMap::new()), next_session_id: AtomicU64::new(1), active_sessions: Mutex::new(HashMap::new()), }) @@ -264,7 +288,9 @@ impl TunnelListener { let mut stream = match self.try_resume_plain_session(stream, peer_addr).await? { ResumeResult::NewSession(ep_raw, sid) => { let ep = wrap_raw_endpoint(sid, ep_raw, None); - self.active_sessions.lock().await.insert(sid, ep.clone()); + let mut sessions = self.active_sessions.lock().await; + sessions.retain(|_, weak_ep| weak_ep.upgrade().is_some()); + sessions.insert(sid, ep.downgrade()); return Ok(ep); } ResumeResult::ResumedExisting => continue, @@ -290,7 +316,9 @@ impl TunnelListener { let sid = self.next_session_id.fetch_add(1, Ordering::Relaxed); let ep = wrap_raw_endpoint(sid, ep_raw, None); - self.active_sessions.lock().await.insert(sid, ep.clone()); + let mut sessions = self.active_sessions.lock().await; + sessions.retain(|_, weak_ep| weak_ep.upgrade().is_some()); + sessions.insert(sid, ep.downgrade()); return Ok(ep); } } @@ -320,7 +348,8 @@ impl TunnelListener { let session_id = self.next_session_id.fetch_add(1, Ordering::Relaxed); { let mut pending = self.pending_plain_sessions.lock().await; - pending.insert(session_id); + pending.retain(|_, time| time.elapsed() < Duration::from_secs(30)); + pending.insert(session_id, std::time::Instant::now()); } tls_stream.write_all(TLS_BOOTSTRAP_MAGIC).await?; @@ -351,7 +380,7 @@ impl TunnelListener { let is_pending = { let mut pending = self.pending_plain_sessions.lock().await; - pending.remove(&session_id) + pending.remove(&session_id).is_some() }; if is_pending { @@ -366,7 +395,11 @@ impl TunnelListener { let ep_raw = upgrade_to_h2_raw(stream, self.config.identity) .await .map_err(|e| anyhow::anyhow!("{}", e))?; - update_endpoint(&ep, ep_raw).await; + update_endpoint( + &ep.upgrade().ok_or(anyhow!("Connection is cleared"))?, + ep_raw, + ) + .await; info!( "[{}] Successfully resumed existing session {}", peer_addr, session_id