From de7b63d366ffb453d59c52bc8642890bb579572a Mon Sep 17 00:00:00 2001
From: Bluemangoo
Date: Fri, 24 Apr 2026 00:32:59 +0800
Subject: [PATCH] multi-server

---
 client/src/main.rs   |  96 +++++++++++++++++++--
 client/src/queue.rs  | 119 +++++++++++++++++++++++++
 client/src/signal.rs |  86 ++++++++++++++++++
 client/src/task.rs   | 201 ++++++++++++++++++++++++++++++++++++++-----
 4 files changed, 475 insertions(+), 27 deletions(-)
 create mode 100644 client/src/queue.rs
 create mode 100644 client/src/signal.rs

diff --git a/client/src/main.rs b/client/src/main.rs
index 057a7a1..a5e5374 100644
--- a/client/src/main.rs
+++ b/client/src/main.rs
@@ -1,6 +1,9 @@
 use crate::config::{ClientConfig, Profile};
-use crate::task::run;
+use crate::queue::SharedQueue;
+use crate::signal::Signal;
+use crate::task::{AtomicCounters, AutoSaveManifest, post_run, run_main, run_side};
 use common::strings::REGION_NOT_FOUND;
+use common::updater::DownloadTask;
 use communicator::{ClientManager, Identity, TunnelEndpoint, TunnelListener, connect_tunnel};
 use futures_util::future::join_all;
 use lazy_static::lazy_static;
@@ -8,6 +11,7 @@ use log::{LevelFilter, error, info};
 use simplelog::{ColorChoice, Config, TermLogger, TerminalMode};
 use std::collections::{HashMap, VecDeque};
 use std::fs;
+use std::path::Path;
 use std::str::FromStr;
 use std::sync::Arc;
 use std::time::Duration;
@@ -19,6 +23,8 @@
 
 mod config;
 mod http;
+mod queue;
+mod signal;
 mod task;
 
 #[derive(StructOpt)]
@@ -126,6 +132,18 @@ async fn main() -> anyhow::Result<()> {
     for profile in profiles {
         let profile = Arc::new(profile.clone());
         let semaphore = Arc::new(Semaphore::new(1));
+        let tasks: SharedQueue<DownloadTask> = SharedQueue::new();
+        let cnt = AtomicCounters::new();
+        let local_manifest = Arc::new(
+            AutoSaveManifest::new(
+                5,
+                Path::new(&{ profile.1.read().await.path.clone() })
+                    .join("manifest.json")
+                    .to_path_buf(),
+            )
+            .await?,
+        );
+        let signal = Signal::new();
         let cancel_token = CancellationToken::new();
         let post_task = {
             async |profile: Arc<(String, Arc<RwLock<Profile>>)>,
@@ -158,6 +176,10 @@
             let semaphore = semaphore.clone();
             let cancel_token = cancel_token.clone();
             let profile = profile.clone();
+            let signal = signal.clone();
+            let tasks = tasks.clone();
+            let cnt = cnt.clone();
+            let local_manifest = local_manifest.clone();
             let liveness_tx = liveness_tx.clone();
             join_set.spawn(async move {
                 let _guard = liveness_tx;
@@ -168,6 +190,7 @@
                     }
                     let client = sender.recv();
                     if client.is_none() {
+                        tokio::time::sleep(Duration::from_millis(50)).await;
                         continue;
                    }
                    let client = client.unwrap();
@@ -177,6 +200,11 @@
                     let semaphore = semaphore.clone();
                     let cancel_token = cancel_token.clone();
                     let profile = profile.clone();
+                    let tasks = tasks.clone();
+                    let cnt = cnt.clone();
+                    let local_manifest = local_manifest.clone();
+                    let signal = signal.clone();
+
                     inner_set.spawn(async move {
                         loop {
                             if client.get_client().await.is_err() {
@@ -186,11 +214,38 @@
                                 return;
                             }
                             {
-                                let permit = semaphore.clone().acquire_owned().await.unwrap();
+                                let permit = loop {
+                                    tokio::select! {
+                                        awaitable = signal.subscribe() => {
+                                            let r = run_side(client.clone(), tasks.clone(), cnt.clone(), local_manifest.clone(), profile.clone()).await;
+                                            if let Err(e) = r {
+                                                error!("{}", e);
+                                            }
+                                            awaitable.wait().await;
+                                        }
+
+                                        res = semaphore.clone().acquire_owned() => {
+                                            let permit = res.expect("Semaphore closed");
+
+                                            break permit;
+                                        }
+                                    }
+                                };
                                 if cancel_token.is_cancelled() {
                                     return;
                                 }
-                                let result = run(client.clone(), profile.clone()).await;
+                                let sync_id = post_run(client.clone(), profile.clone(), tasks.clone(), cnt.clone()).await;
+                                let sig = signal.pick().await;
+                                let result = match sync_id {
+                                    Ok(Some(id)) => {
+                                        run_main(client.clone(), profile.clone(), id, tasks.clone(), cnt.clone(), local_manifest.clone()).await
+                                    }
+                                    Ok(None) => {
+                                        Ok(true)
+                                    }
+                                    Err(e) => Err(e),
+                                };
+                                drop(sig);
                                 match result {
                                     Ok(true) => {
                                         post_task(profile.clone(), permit, cancel_token.clone())
@@ -227,6 +282,10 @@
             let semaphore = semaphore.clone();
             let cancel_token = cancel_token.clone();
             let profile = profile.clone();
+            let tasks = tasks.clone();
+            let cnt = cnt.clone();
+            let local_manifest = local_manifest.clone();
+            let signal = signal.clone();
             info!("tcp client started for {}", client_conf.url);
             join_set.spawn(async move {
                 loop {
@@ -250,11 +309,38 @@
                         return;
                     }
                     {
-                        let permit = semaphore.clone().acquire_owned().await.unwrap();
+                        let permit = loop {
+                            tokio::select! {
+                                awaitable = signal.subscribe() => {
+                                    let r = run_side(client.clone(), tasks.clone(), cnt.clone(), local_manifest.clone(), profile.clone()).await;
+                                    if let Err(e) = r {
+                                        error!("{}", e);
+                                    }
+                                    awaitable.wait().await;
+                                }
+
+                                res = semaphore.clone().acquire_owned() => {
+                                    let permit = res.expect("Semaphore closed");
+
+                                    break permit;
+                                }
+                            }
+                        };
                         if cancel_token.is_cancelled() {
                             return;
                         }
-                        let result = run(client.clone(), profile.clone()).await;
+                        let sync_id = post_run(client.clone(), profile.clone(), tasks.clone(), cnt.clone()).await;
+                        let sig = signal.pick().await;
+                        let result = match sync_id {
+                            Ok(Some(id)) => {
+                                run_main(client.clone(), profile.clone(), id, tasks.clone(), cnt.clone(), local_manifest.clone()).await
+                            }
+                            Ok(None) => {
+                                Ok(true)
+                            }
+                            Err(e) => Err(e),
+                        };
+                        drop(sig);
                         match result {
                             Ok(true) => {
                                 post_task(profile.clone(), permit, cancel_token.clone())
diff --git a/client/src/queue.rs b/client/src/queue.rs
new file mode 100644
index 0000000..d7bbbb0
--- /dev/null
+++ b/client/src/queue.rs
@@ -0,0 +1,119 @@
+use std::collections::VecDeque;
+use std::sync::{Arc, Condvar, Mutex, atomic::{AtomicUsize, Ordering}};
+
+pub struct SharedQueue<T> {
+    inner: Arc<QueueInner<T>>,
+}
+
+struct QueueInner<T> {
+    data: Mutex<VecDeque<T>>,
+    // Used to block pop while the queue is empty
+    pop_cond: Condvar,
+    // Used to block until everything has been consumed
+    done_cond: Condvar,
+    // Count of in-flight tasks (queued + currently being processed)
+    pending: AtomicUsize,
+}
+
+/// Task guard: once it is dropped, consumption of the item is fully finished
+pub struct TaskGuard<T> {
+    pub item: T,
+    inner: Arc<QueueInner<T>>,
+}
+
+impl<T: Clone> Clone for TaskGuard<T> {
+    fn clone(&self) -> Self {
+        // Key point: every extra copy of the guard is one more "consumption" to wait for,
+        // so the global in-flight count must grow; otherwise pending would go negative or hit zero early
+        self.inner.pending.fetch_add(1, Ordering::SeqCst);
+
+        Self {
+            item: self.item.clone(),
+            inner: self.inner.clone(),
+        }
+    }
+}
+
+impl<T> Drop for TaskGuard<T> {
+    fn drop(&mut self) {
+        // 1. The task is done; decrement the count
+        let prev = self.inner.pending.fetch_sub(1, Ordering::SeqCst);
+
+        // 2. If the count just hit zero, the last task has been processed as well
+        if prev == 1 {
+            let _lock = self.inner.data.lock().unwrap();
+            self.inner.done_cond.notify_all();
+        }
+    }
+}
+
+impl<T> SharedQueue<T> {
+    pub fn new() -> Self {
+        Self {
+            inner: Arc::new(QueueInner {
+                data: Mutex::new(VecDeque::new()),
+                pop_cond: Condvar::new(),
+                done_cond: Condvar::new(),
+                pending: AtomicUsize::new(0),
+            }),
+        }
+    }
+
+    // pub fn push(&self, item: T) {
+    //     let mut queue = self.inner.data.lock().unwrap();
+    //     // Increment the in-flight count
+    //     self.inner.pending.fetch_add(1, Ordering::SeqCst);
+    //     queue.push_back(item);
+    //     self.inner.pop_cond.notify_one();
+    // }
+
+    pub fn push_all(&self, items: impl IntoIterator<Item = T>) {
+        let mut queue = self.inner.data.lock().unwrap();
+        let mut count = 0;
+        for item in items {
+            queue.push_back(item);
+            count += 1;
+        }
+        if count > 0 {
+            self.inner.pending.fetch_add(count, Ordering::SeqCst);
+            self.inner.pop_cond.notify_all();
+        }
+    }
+
+    // pub fn pop(&self) -> TaskGuard<T> {
+    //     let mut queue = self.inner.data.lock().unwrap();
+    //     while queue.is_empty() {
+    //         queue = self.inner.pop_cond.wait(queue).unwrap();
+    //     }
+    //     let item = queue.pop_front().unwrap();
+    //     TaskGuard {
+    //         item,
+    //         inner: self.inner.clone(),
+    //     }
+    // }
+
+    pub fn try_pop(&self) -> Option<TaskGuard<T>> {
+        let mut queue = self.inner.data.lock().unwrap();
+
+        queue.pop_front().map(|item| {
+            TaskGuard {
+                item,
+                inner: self.inner.clone(),
+            }
+        })
+    }
+
+    /// Block the current thread until every in-flight task has been processed (pending == 0)
+    pub fn wait_until_all_consumed(&self) {
+        let mut _queue_lock = self.inner.data.lock().unwrap();
+        while self.inner.pending.load(Ordering::SeqCst) > 0 {
+            _queue_lock = self.inner.done_cond.wait(_queue_lock).unwrap();
+        }
+    }
+}
+
+impl<T> Clone for SharedQueue<T> {
+    fn clone(&self) -> Self {
+        Self { inner: Arc::clone(&self.inner) }
+    }
+}
\ No newline at end of file
diff --git a/client/src/signal.rs b/client/src/signal.rs
new file mode 100644
index 0000000..9f75b5a
--- /dev/null
+++ b/client/src/signal.rs
@@ -0,0 +1,86 @@
+use std::sync::Arc;
+use tokio::sync::{watch, Mutex, OwnedMutexGuard};
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum Stage {
+    Idle,       // idle / waiting for the go signal
+    Processing, // the leader is doing the work
+}
+
+pub struct Signal {
+    inner: Arc<Inner>,
+}
+
+struct Inner {
+    stage_tx: watch::Sender<Stage>,
+    pick_lock: Arc<Mutex<()>>,
+}
+
+impl Signal {
+    pub fn new() -> Self {
+        let (stage_tx, _) = watch::channel(Stage::Idle);
+        Self {
+            inner: Arc::new(Inner {
+                stage_tx,
+                pick_lock: Arc::new(Mutex::new(())),
+            }),
+        }
+    }
+
+    pub async fn pick(&self) -> LeaderHandler {
+        let lock_handle = self.inner.pick_lock.clone();
+        let _owned_guard = lock_handle.lock_owned().await;
+
+        // Switch to the working state
+        let _ = self.inner.stage_tx.send(Stage::Processing);
+
+        LeaderHandler {
+            inner: self.inner.clone(),
+            _guard: _owned_guard,
+        }
+    }
+
+    pub async fn subscribe(&self) -> FollowerAwaitable {
+        let mut rx = self.inner.stage_tx.subscribe();
+        // If the stage is Idle, park until the leader switches it to Processing
+        while *rx.borrow() != Stage::Processing {
+            if rx.changed().await.is_err() {
+                break;
+            }
+        }
+        FollowerAwaitable { rx }
+    }
+}
+
+pub struct LeaderHandler {
+    inner: Arc<Inner>,
+    _guard: OwnedMutexGuard<()>,
+}
+
+impl Drop for LeaderHandler {
+    fn drop(&mut self) {
+        // The leader has been dropped; reset to Idle to allow the next round of contention
+        let _ = self.inner.stage_tx.send(Stage::Idle);
+    }
+}
+
+pub struct FollowerAwaitable {
+    rx: watch::Receiver<Stage>,
+}
+
+impl FollowerAwaitable {
+    pub async fn wait(mut self) {
+        // Wait for the stage to return to Idle (meaning the leader has been dropped)
+        while *self.rx.borrow() == Stage::Processing {
+            if self.rx.changed().await.is_err() {
+                break;
+            }
+        }
+    }
+}
+
+impl Clone for Signal {
+    fn clone(&self) -> Self {
+        Self { inner: self.inner.clone() }
+    }
+}
\ No newline at end of file
diff --git a/client/src/task.rs b/client/src/task.rs
index 85f8817..d96c24e 100644
--- a/client/src/task.rs
+++ b/client/src/task.rs
@@ -1,21 +1,26 @@
 use crate::config::Profile;
 use crate::http::{close, download, sync};
-use common::http::{CloseRequest, DownloadRequest};
-use communicator::ClientManager;
+use crate::queue::SharedQueue;
 use anyhow::anyhow;
+use common::http::{CloseRequest, DownloadRequest};
+use common::updater::DownloadTask;
+use communicator::ClientManager;
 use log::{error, info};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::path::{Path, PathBuf};
+use std::sync::Arc;
 use std::sync::atomic::{AtomicUsize, Ordering};
-use std::sync::{Arc};
 use tokio::sync::{RwLock, Semaphore};
 use tokio::task::JoinSet;
+use tokio_util::sync::CancellationToken;
 
-pub async fn run(
+pub async fn post_run(
     client: Arc<ClientManager>,
     profile: Arc<(String, Arc<RwLock<Profile>>)>,
-) -> anyhow::Result<bool> {
+    queue: SharedQueue<DownloadTask>,
+    cnt: AtomicCounters,
+) -> anyhow::Result<Option<String>> {
     info!("[{}]: Starting sync", profile.0);
     let p1 = Arc::new(profile.1.read().await.clone());
     tokio::fs::create_dir_all(&p1.path).await?;
@@ -47,20 +52,127 @@
         info!("[{}]: No tasks to sync, skipping", profile.0);
         let req = CloseRequest { id: id.clone() };
         close(&mut client.get_client().await?, &req).await?;
-        return Ok(true);
+        return Ok(None);
     }
+    cnt.reset();
+    queue.push_all(tasks);
+    Ok(Some(id))
+}
+
+pub async fn run_main(
+    client: Arc<ClientManager>,
+    profile: Arc<(String, Arc<RwLock<Profile>>)>,
+    id: String,
+    queue: SharedQueue<DownloadTask>,
+    cnt: AtomicCounters,
+    manifest: Arc<AutoSaveManifest>,
+) -> anyhow::Result<bool> {
+    info!("[{}]: Starting sync", profile.0);
+    let p1 = Arc::new(profile.1.read().await.clone());
     let n = p1.concurrent.unwrap_or(5);
     info!("[{}]: Start sync with {} thread", profile.0, n);
     let semaphore = Arc::new(Semaphore::new(n));
     let mut join_set = JoinSet::new();
-    for task in tasks {
+    let cancel_token = CancellationToken::new();
+    while let Some(task) = queue.try_pop() {
+        if cancel_token.is_cancelled() {
+            break;
+        }
         let permit = semaphore.clone().acquire_owned().await?;
         let client = client.clone();
         let id = id.clone();
-        let local_manifest = local_manifest.clone();
+        let local_manifest = manifest.clone();
+        let p1 = p1.clone();
+
+        let cancel_token = cancel_token.clone();
         join_set.spawn(async move {
+            if cancel_token.is_cancelled() {
+                return Ok::<(), anyhow::Error>(());
+            }
+            let guard = task;
+            let task = &guard.item;
+            let req = DownloadRequest {
+                id: id.clone(),
+                task: task.clone(),
+            };
+            let mut conn = client.get_client().await?;
+            let mut result = download(&mut conn, &req, &p1).await;
+            if let Err(e) = &result
+                && e.downcast_ref::().is_some()
+            {
+                let mut retry_conn = client.get_client().await?;
+                result = download(&mut retry_conn, &req, &p1).await;
+            }
+            if let Err(e) = &result
+                && e.downcast_ref::().is_some()
+            {
+                cancel_token.cancel();
+            }
+            result?;
+
+            local_manifest
+                .add_bundle(task.bundle_path.clone(), task.bundle_hash.clone())
+                .await?;
+            drop(permit);
+            Ok::<(), anyhow::Error>(())
+        });
+    }
+    while let Some(r) = join_set.join_next().await {
+        match r {
+            Ok(Ok(())) => cnt.inc_success(),
+            Ok(Err(e)) => {
+                if e.to_string()
+                    .contains("Session did not reconnect within 15s")
+                {
+                    return Err(anyhow!(e));
+                }
+                error!("{}", e);
+                cnt.inc_failure()
+            }
+            Err(e) => {
+                error!("{}", e);
+                cnt.inc_failure()
+            }
+        }
+    }
+    manifest.save().await?;
+    queue.wait_until_all_consumed();
+    info!(
+        "[{}]: Sync finished with {} succeeded, {} failed",
+        profile.0,
+        cnt.get_success(),
+        cnt.get_failure()
+    );
+    let req = CloseRequest { id: id.clone() };
+    close(&mut client.get_client().await?, &req).await?;
+
+    Ok(cnt.get_failure() == 0)
+}
+
+pub async fn run_side(
+    client: Arc<ClientManager>,
+    queue: SharedQueue<DownloadTask>,
+    cnt: AtomicCounters,
+    manifest: Arc<AutoSaveManifest>,
+    profile: Arc<(String, Arc<RwLock<Profile>>)>,
+) -> anyhow::Result<()> {
+    let p1 = Arc::new(profile.1.read().await.clone());
+    tokio::fs::create_dir_all(&p1.path).await?;
+    let sync_resp = sync(&mut client.get_client().await?, &p1).await?;
+    let id = sync_resp.id;
+    let n = p1.concurrent.unwrap_or(5);
+    let semaphore = Arc::new(Semaphore::new(n));
+    let mut join_set = JoinSet::new();
+    while let Some(task) = queue.try_pop() {
         let permit = semaphore.clone().acquire_owned().await?;
         let client = client.clone();
         let id = id.clone();
         let local_manifest = manifest.clone();
         let p1 = p1.clone();
         join_set.spawn(async move {
+            let guard = task;
+            let task = &guard.item;
             let req = DownloadRequest {
                 id: id.clone(),
                 task: task.clone(),
             };
@@ -77,41 +189,36 @@
             local_manifest
                 .add_bundle(task.bundle_path.clone(), task.bundle_hash.clone())
-                .await
-                ?;
+                .await?;
             drop(permit);
             Ok::<(), anyhow::Error>(())
         });
     }
-    let mut succeed = 0;
-    let mut failed = 0;
     while let Some(r) = join_set.join_next().await {
         match r {
             Ok(Ok(())) => {
-                succeed += 1;
+                cnt.inc_success();
             }
             Ok(Err(e)) => {
-                if e.to_string().contains("Session did not reconnect within 15s") {
+                if e.to_string()
+                    .contains("Session did not reconnect within 15s")
+                {
                     return Err(anyhow!(e));
                 }
                 error!("{}", e);
-                failed += 1;
+                cnt.inc_failure();
             }
             Err(e) => {
                 error!("{}", e);
-                failed += 1;
+                cnt.inc_failure();
             }
         }
     }
-    local_manifest.save().await?;
-    info!(
-        "[{}]: Sync finished with {} succeed, {} failed",
-        profile.0, succeed, failed
-    );
+    manifest.save().await?;
     let req = CloseRequest { id: id.clone() };
     close(&mut client.get_client().await?, &req).await?;
-    Ok(failed == 0)
+    Ok(())
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -166,3 +273,53 @@ impl AutoSaveManifest {
         Ok(())
     }
 }
+
+/// Dual atomic counters; cheap to Clone
+#[derive(Clone)]
+pub struct AtomicCounters {
+    inner: Arc<CountersInner>,
+}
+
+struct CountersInner {
+    success: AtomicUsize,
+    failure: AtomicUsize,
+}
+
+impl AtomicCounters {
+    pub fn new() -> Self {
+        Self {
+            inner: Arc::new(CountersInner {
+                success: AtomicUsize::new(0),
+                failure: AtomicUsize::new(0),
+            }),
+        }
+    }
+
+    pub fn inc_success(&self) {
+        self.inner.success.fetch_add(1, Ordering::Relaxed);
+    }
+
+    pub fn inc_failure(&self) {
+        self.inner.failure.fetch_add(1, Ordering::Relaxed);
+    }
+
+    pub fn get_success(&self) -> usize {
+        self.inner.success.load(Ordering::Relaxed)
+    }
+
+    pub fn get_failure(&self) -> usize {
+        self.inner.failure.load(Ordering::Relaxed)
+    }
+
+    // pub fn load(&self) -> (usize, usize) {
+    //     (
+    //         self.inner.success.load(Ordering::Relaxed),
+    //         self.inner.failure.load(Ordering::Relaxed),
+    //     )
+    // }
+
+    pub fn reset(&self) {
+        self.inner.success.store(0, Ordering::Relaxed);
+        self.inner.failure.store(0, Ordering::Relaxed);
+    }
+}
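
Note on the new queue: SharedQueue tracks completion, not just emptiness. `pending` rises in push_all and only falls when a popped TaskGuard is dropped, which is what lets run_main call wait_until_all_consumed() even while run_side instances on other connections still hold guards. A minimal sketch of that contract, outside the patch (the worker count and u32 payload are arbitrary stand-ins; only push_all, try_pop, the guard drop, and wait_until_all_consumed come from client/src/queue.rs):

    use crate::queue::SharedQueue;
    use std::thread;

    fn demo_queue() {
        let queue: SharedQueue<u32> = SharedQueue::new();
        // push_all bumps `pending` once per item before any worker can pop.
        queue.push_all(1..=8);

        let workers: Vec<_> = (0..3)
            .map(|_| {
                let q = queue.clone();
                thread::spawn(move || {
                    while let Some(guard) = q.try_pop() {
                        println!("processed {}", guard.item);
                    } // each guard drops here -> pending -= 1; done_cond fires at zero
                })
            })
            .collect();

        // Returns only once every guard has been dropped, not merely once the
        // queue is empty — items still being processed keep it blocked.
        queue.wait_until_all_consumed();
        for w in workers {
            w.join().unwrap();
        }
    }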
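
Note on the new signal: Signal is what keeps exactly one connection in run_main (the leader, serialized through pick_lock) while every other connection's tokio::select! falls into the subscribe() arm and helps drain the queue via run_side; wait() then parks the follower until the LeaderHandler drops and the stage returns to Idle. A reduced sketch of that handshake with one leader and one follower (demo_signal and the sleep are illustrative stand-ins for post_run/run_main):

    use crate::signal::Signal;
    use std::time::Duration;

    async fn demo_signal() {
        let signal = Signal::new();

        let follower = {
            let signal = signal.clone();
            tokio::spawn(async move {
                // Resolves once a leader has broadcast Stage::Processing.
                let awaitable = signal.subscribe().await;
                // (In main.rs this is where run_side drains the shared queue.)
                awaitable.wait().await; // resolves when the LeaderHandler drops
            })
        };

        let leader = signal.pick().await; // serialize leaders, broadcast Processing
        tokio::time::sleep(Duration::from_millis(100)).await; // stand-in for run_main
        drop(leader); // broadcast Idle; the follower's wait() returns

        follower.await.unwrap();
    }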