//! Bundle synchronisation driver.
//!
//! `post_run` works out which bundles are missing or out of date relative to
//! the local `manifest.json` and enqueues them; `run_main` and `run_side` pull
//! tasks from the shared queue, download them concurrently, and record each
//! completed bundle in the manifest via `add_bundle`.

use crate::config::Profile;
use crate::http::{close, download, sync};
use crate::queue::SharedQueue;
use anyhow::anyhow;
use common::http::{CloseRequest, DownloadRequest};
use common::updater::DownloadTask;
use communicator::ClientManager;
use log::{error, info};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::sync::{RwLock, Semaphore};
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;

/// Opens a sync session for `profile` and enqueues every bundle that is
/// missing locally or whose hash differs from the local manifest.
///
/// Steps: create the profile's target directory, request the task list with
/// `sync`, load `<path>/manifest.json`, and keep only tasks whose
/// `bundle_hash` differs from the recorded hash (or that have no entry).
///
/// Returns `Ok(None)` after closing the session when nothing needs to be
/// downloaded. Otherwise resets `cnt`, replaces the contents of `queue` with
/// the outstanding tasks, and returns `Ok(Some(id))` with the session id.
///
/// # Errors
/// Propagates failures from creating the directory, obtaining a client
/// connection, the `sync` / `close` requests, and parsing the manifest file.
pub async fn post_run(
    client: Arc,
    profile: Arc<(String, Arc>)>,
    queue: SharedQueue,
    cnt: AtomicCounters,
) -> anyhow::Result> {
    info!("[{}]: Starting sync", profile.0);
    // Work on a cloned snapshot of the profile; the temporary read guard is
    // dropped at the end of this statement.
    let p1 = Arc::new(profile.1.read().await.clone());
    tokio::fs::create_dir_all(&p1.path).await?;
    let sync_resp = sync(&mut client.get_client().await?, &p1).await?;
    let id = sync_resp.id;
    // `5` is the auto-save interval (saves after every 5 `add_bundle` calls).
    // NOTE(review): `join` already returns a `PathBuf`; `.to_path_buf()` is redundant.
    let local_manifest = Arc::new(
        AutoSaveManifest::new(5, Path::new(&p1.path).join("manifest.json").to_path_buf()).await?,
    );
    // Clone the manifest so the read lock is released before filtering.
    let manifest_snapshot = { local_manifest.manifest.read().await.clone() };
    let all_cnt = sync_resp.tasks.len();
    // Keep tasks whose bundle is absent locally or recorded with a different hash.
    let tasks = sync_resp
        .tasks
        .into_iter()
        .filter(|task| {
            let bundle_name = &task.bundle_path;
            match manifest_snapshot.bundles.get(bundle_name) {
                Some(local_hash) => local_hash != &task.bundle_hash,
                None => true,
            }
        })
        .collect::>();
    info!(
        "[{}]: Collected {}/{} tasks",
        profile.0,
        tasks.len(),
        all_cnt
    );
    if tasks.is_empty() {
        info!("[{}]: No tasks to sync, skipping", profile.0);
        // Nothing to download: close the session before returning.
        let req = CloseRequest { id: id.clone() };
        close(&mut client.get_client().await?, &req).await?;
        return Ok(None);
    }
    // Start a fresh round: zero the counters and replace any stale queue contents.
    cnt.reset();
    queue.clear();
    queue.push_all(tasks);
    Ok(Some(id))
}

/// Downloads every task in `queue` for session `id`, with at most
/// `concurrent` (default 5) downloads in flight, and records each successful
/// download in `manifest` via `add_bundle`.
///
/// If a download fails with the error type tested by the first
/// `downcast_ref` check, it is retried once on a fresh connection. A failure
/// matching the second check cancels the run, so no further tasks are
/// started. Outcomes are tallied in `cnt`.
///
/// After all spawned tasks finish, the function:
/// - saves the manifest,
/// - calls `queue.wait_until_all_consumed()`,
/// - closes the session.
///
/// Returns `Ok(true)` when the failure counter is zero at the end.
///
/// # Errors
/// Propagates failures from acquiring a semaphore permit, saving the
/// manifest, obtaining a connection, and the `close` request. Individual task
/// failures are logged and counted rather than returned.
pub async fn run_main(
    client: Arc,
    profile: Arc<(String, Arc>)>,
    id: String,
    queue: SharedQueue,
    cnt: AtomicCounters,
    manifest: Arc,
) -> anyhow::Result {
    info!("[{}]: Starting sync", profile.0);
    let p1 = Arc::new(profile.1.read().await.clone());
    // NOTE(review): `concurrent = Some(0)` would create a semaphore with no
    // permits, and the first `acquire_owned` below would wait forever;
    // consider clamping to at least 1.
    let n = p1.concurrent.unwrap_or(5);
    // NOTE(review): this format string spans a line break, so the logged
    // message contains a newline; confirm that is intended.
    info!("[{}]: Start sync 
with {} thread", profile.0, n);
    let semaphore = Arc::new(Semaphore::new(n));
    let mut join_set = JoinSet::new();
    // Cancelled by a task on a fatal download error; stops scheduling new work.
    let cancel_token = CancellationToken::new();
    while let Some(task) = queue.try_pop() {
        // NOTE(review): breaking here drops the task that was just popped
        // without processing it.
        if cancel_token.is_cancelled() {
            break;
        }
        // Waits until fewer than `n` downloads are running. The owned permit
        // moves into the spawned task and is released when that task ends.
        let permit = semaphore.clone().acquire_owned().await?;
        let client = client.clone();
        let id = id.clone();
        let local_manifest = manifest.clone();
        let p1 = p1.clone();
        let cancel_token = cancel_token.clone();
        join_set.spawn(async move {
            // NOTE(review): a task skipped here returns `Ok(())`. The join loop
            // below then counts it via `cnt.inc_success()`, although nothing
            // was downloaded.
            if cancel_token.is_cancelled() {
                return Ok::<(), anyhow::Error>(());
            }
            // Hold the popped queue item until this task ends. Its drop
            // presumably does queue bookkeeping; confirm in `SharedQueue`.
            let guard = task;
            let task = &guard.item;
            let req = DownloadRequest {
                id: id.clone(),
                task: task.clone(),
            };
            let mut conn = client.get_client().await?;
            let mut result = download(&mut conn, &req, &p1).await;
            // Retry once on a fresh connection when the error matches the
            // first `downcast_ref` check.
            if let Err(e) = &result
                && e.downcast_ref::().is_some()
            {
                let mut retry_conn = client.get_client().await?;
                result = download(&mut retry_conn, &req, &p1).await;
            }
            // Errors matching this check cancel the run so the scheduling loop
            // stops starting new tasks.
            if let Err(e) = &result
                && e.downcast_ref::().is_some()
            {
                cancel_token.cancel();
            }
            result?;
            local_manifest
                .add_bundle(task.bundle_path.clone(), task.bundle_hash.clone())
                .await?;
            // Explicit release on success; early `?` returns also release the
            // permit because this task owns it.
            drop(permit);
            Ok::<(), anyhow::Error>(())
        });
    }
    // Collect outcomes; tasks that panicked or were aborted (`JoinError`)
    // count as failures too.
    while let Some(r) = join_set.join_next().await {
        match r {
            Ok(Ok(())) => cnt.inc_success(),
            Ok(Err(e)) => {
                error!("{}", e);
                cnt.inc_failure()
            }
            Err(e) => {
                error!("{}", e);
                cnt.inc_failure()
            }
        }
    }
    // Persist entries added since the last interval-triggered save.
    manifest.save().await?;
    // NOTE(review): this is called without `.await`. If it blocks (e.g. while
    // waiting for side workers to drain the queue), it stalls this executor
    // thread. Items left in the queue by the cancellation `break` above might
    // also never be consumed. Confirm its semantics in `SharedQueue`.
    queue.wait_until_all_consumed();
    info!(
        "[{}]: Sync finished with {} succeed, {} failed",
        profile.0,
        cnt.get_success(),
        cnt.get_failure()
    );
    let req = CloseRequest { id: id.clone() };
    close(&mut client.get_client().await?, &req).await?;
    Ok(cnt.get_failure() == 0)
}

/// Secondary worker: opens its own sync session and helps drain the shared
/// `queue`. It downloads tasks with the same concurrency, retry and
/// cancellation logic as `run_main` and records successes in `manifest`.
///
/// Only the session `id` is taken from this worker's `sync` response; its task
/// list is not used, since work is popped from `queue`.
///
/// # Errors
/// Returns early with an error when a task fails with a message containing
/// `"Session did not reconnect within 15s"`. Also propagates failures from:
/// - directory creation,
/// - the `sync` / `close` requests,
/// - obtaining a connection,
/// - acquiring a permit,
/// - saving the manifest.
///
/// Other task failures are logged and counted in `cnt`.
pub async fn run_side(
    client: Arc,
    queue: SharedQueue,
    cnt: AtomicCounters,
    manifest: Arc,
    profile: Arc<(String, Arc>)>,
) -> anyhow::Result<()> {
    let p1 = Arc::new(profile.1.read().await.clone());
    tokio::fs::create_dir_all(&p1.path).await?;
    let sync_resp = sync(&mut client.get_client().await?, &p1).await?;
    let id = sync_resp.id;
    // NOTE(review): as in `run_main`, a concurrency of 0 would block forever
    // on the first permit.
    let n = p1.concurrent.unwrap_or(5);
    let semaphore = Arc::new(Semaphore::new(n));
    let mut join_set = JoinSet::new();
    let cancel_token = CancellationToken::new();
    // NOTE(review): the scheduling loop and per-task body duplicate `run_main`
    // (minus the in-task cancellation check); consider a shared helper.
    while let Some(task) = queue.try_pop() {
        if cancel_token.is_cancelled() {
            break;
        }
        let permit = semaphore.clone().acquire_owned().await?;
        // Re-check after waiting for a permit, because a running task may have
        // cancelled meanwhile. Breaking here drops the popped task unprocessed.
        if cancel_token.is_cancelled() {
            break;
        }
        let client = client.clone();
        let id = id.clone();
        let local_manifest = manifest.clone();
        let p1 = p1.clone();
        let cancel_token = cancel_token.clone();
        join_set.spawn(async move {
            // Hold the popped queue item until this task ends.
            let guard = task;
            let task = &guard.item;
            let req = DownloadRequest {
                id: id.clone(),
                task: task.clone(),
            };
            let mut conn = client.get_client().await?;
            let mut result = download(&mut conn, &req, &p1).await;
            // Retry once on a fresh connection when the error matches the
            // first `downcast_ref` check.
            if let Err(e) = &result
                && e.downcast_ref::().is_some()
            {
                let mut retry_conn = client.get_client().await?;
                result = download(&mut retry_conn, &req, &p1).await;
            }
            // Errors matching this check stop the scheduling loop.
            if let Err(e) = &result
                && e.downcast_ref::().is_some()
            {
                cancel_token.cancel();
            }
            result?;
            local_manifest
                .add_bundle(task.bundle_path.clone(), task.bundle_hash.clone())
                .await?;
            drop(permit);
            Ok::<(), anyhow::Error>(())
        });
    }
    while let Some(r) = join_set.join_next().await {
        match r {
            Ok(Ok(())) => {
                cnt.inc_success();
            }
            Ok(Err(e)) => {
                // NOTE(review): matching on the message text is brittle. This
                // early return also drops `join_set`, which aborts in-flight
                // tasks. It skips the final `manifest.save()` and never closes
                // the session.
                if e.to_string()
                    .contains("Session did not reconnect within 15s")
                {
                    return Err(anyhow!(e));
                }
                error!("{}", e);
                cnt.inc_failure();
            }
            Err(e) => {
                error!("{}", e);
                cnt.inc_failure();
            }
        }
    }
    manifest.save().await?;
    let req = CloseRequest { id: id.clone() };
    close(&mut client.get_client().await?, &req).await?;
    Ok(())
}

/// On-disk manifest contents: bundle path -> bundle hash recorded after a
/// successful download.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Manifest {
    // A missing `bundles` key deserializes to an empty map, so "{}" is valid.
    #[serde(default)]
    bundles: HashMap,
}

/// Shared `Manifest` that writes itself to `storage_path` as JSON after every
/// `save_interval` calls to `add_bundle`, or on an explicit `save`.
pub struct AutoSaveManifest {
    // In-memory manifest, guarded for concurrent download tasks.
    manifest: Arc>,
    // Number of `add_bundle` calls so far; drives the auto-save cadence.
    counter: AtomicUsize,
    // Save after every `save_interval` insertions.
    save_interval: usize,
    // JSON file the manifest is loaded from and saved to.
    storage_path: PathBuf,
}

impl AutoSaveManifest {
    /// Loads the manifest stored at `path`, auto-saving every `interval`
    /// insertions.
    ///
    /// If the file cannot be read, an empty manifest (`"{}"`) is used instead.
    ///
    /// # Errors
    /// Returns an error if the file contents are not valid manifest JSON.
    ///
    /// NOTE(review): *any* read failure is silently treated as an empty
    /// manifest, not only "file not found" (e.g. a permission error too). The
    /// next save then overwrites the file. `unwrap_or` also allocates the
    /// fallback even on success; `unwrap_or_else` would avoid that.
    pub async fn new(interval: usize, path: PathBuf) -> anyhow::Result {
        Ok(Self {
            manifest: Arc::new(RwLock::new(serde_json::from_str(
                &tokio::fs::read_to_string(&path)
                    .await
                    .unwrap_or("{}".to_owned()),
            )?)),
            counter: AtomicUsize::new(0),
            save_interval: interval,
            storage_path: path,
        })
    }

    /// Records `key` -> `value`, replacing any previous entry, and saves to
    /// disk whenever the running insert count reaches a multiple of
    /// `save_interval`. Callers in this file pass bundle path and bundle hash.
    ///
    /// # Errors
    /// Propagates errors from an auto-triggered `save`.
    pub async fn add_bundle(&self, key: String, value: String) -> anyhow::Result<()> {
        // Scope the write guard so it is released before `save` takes a read lock.
        {
            let mut w = self.manifest.write().await;
            w.bundles.insert(key, value);
        }
        // `fetch_add` returns the previous value; +1 gives this call's ordinal.
        let current_count = self.counter.fetch_add(1, Ordering::SeqCst) + 1;
        // NOTE(review): with `save_interval == 0` this is never true for a
        // count >= 1, so auto-save is effectively disabled.
        if current_count.is_multiple_of(self.save_interval) {
            self.save().await?;
        }
        Ok(())
    }

    /// Serializes the current manifest to JSON and overwrites `storage_path`.
    ///
    /// NOTE(review): `tokio::fs::write` truncates and rewrites the file in
    /// place. Concurrent calls are possible: download tasks call `add_bundle`
    /// in parallel, and the `run_*` functions also call `save`. Each call
    /// writes through an independent handle, so a longer snapshot can leave
    /// trailing bytes after a shorter one, producing invalid JSON. A crash
    /// mid-write likewise leaves a truncated file that `new` then fails to
    /// parse. Consider serializing saves and writing to a temporary file
    /// followed by a rename.
    pub async fn save(&self) -> anyhow::Result<()> {
        // Serialize under the read lock, then release it before the file I/O.
        let data = {
            let r = self.manifest.read().await;
            serde_json::to_vec(&*r)?
        };
        tokio::fs::write(&self.storage_path, data).await?;
        Ok(())
    }
}

/// Cloneable pair of atomic counters (successes and failures); all clones
/// share the same underlying counts.
#[derive(Clone)]
pub struct AtomicCounters {
    inner: Arc,
}

/// Shared storage behind `AtomicCounters`.
struct CountersInner {
    success: AtomicUsize,
    failure: AtomicUsize,
}

// All operations use `Ordering::Relaxed`: each counter is a standalone tally
// and is not used to order other memory accesses.
impl AtomicCounters {
    /// Creates a fresh pair of counters, both starting at zero.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(CountersInner {
                success: AtomicUsize::new(0),
                failure: AtomicUsize::new(0),
            }),
        }
    }

    /// Adds one to the success count.
    pub fn inc_success(&self) {
        self.inner.success.fetch_add(1, Ordering::Relaxed);
    }

    /// Adds one to the failure count.
    pub fn inc_failure(&self) {
        self.inner.failure.fetch_add(1, Ordering::Relaxed);
    }

    /// Current success count.
    pub fn get_success(&self) -> usize {
        self.inner.success.load(Ordering::Relaxed)
    }

    /// Current failure count.
    pub fn get_failure(&self) -> usize {
        self.inner.failure.load(Ordering::Relaxed)
    }

    // pub fn load(&self) -> (usize, usize) {
    //     (
    //         self.inner.success.load(Ordering::Relaxed),
    //         self.inner.failure.load(Ordering::Relaxed),
    //     )
    // }

    /// Sets both counts back to zero. The two stores are separate, not one
    /// atomic operation.
    pub fn reset(&self) {
        self.inner.success.store(0, Ordering::Relaxed);
        self.inner.failure.store(0, Ordering::Relaxed);
    }
}