use std::collections::VecDeque; use std::sync::{Arc, Condvar, Mutex, atomic::{AtomicUsize, Ordering}}; pub struct SharedQueue { inner: Arc>, } struct QueueInner { data: Mutex>, // 用于 pop 的阻塞 pop_cond: Condvar, // 用于“全部消费完”的阻塞 done_cond: Condvar, // 在途任务计数(队列中 + 正在处理中) pending: AtomicUsize, } /// 任务守卫:当它被释放时,说明消费彻底结束 pub struct TaskGuard { pub item: T, inner: Arc>, } impl Clone for TaskGuard { fn clone(&self) -> Self { // 关键:每多出一个 Guard 副本,就意味着多了一个需要等待的“消费行为” // 必须增加全局在途计数,否则会导致 pending 减成负数或提前归零 self.inner.pending.fetch_add(1, Ordering::SeqCst); Self { item: self.item.clone(), inner: self.inner.clone(), } } } impl Drop for TaskGuard { fn drop(&mut self) { // 1. 任务完成,计数减一 let prev = self.inner.pending.fetch_sub(1, Ordering::SeqCst); // 2. 如果减完后是 0,说明最后一项任务也处理完了 if prev == 1 { let _lock = self.inner.data.lock().unwrap(); self.inner.done_cond.notify_all(); } } } impl SharedQueue { pub fn new() -> Self { Self { inner: Arc::new(QueueInner { data: Mutex::new(VecDeque::new()), pop_cond: Condvar::new(), done_cond: Condvar::new(), pending: AtomicUsize::new(0), }), } } // pub fn push(&self, item: T) { // let mut queue = self.inner.data.lock().unwrap(); // // 增加在途计数 // self.inner.pending.fetch_add(1, Ordering::SeqCst); // queue.push_back(item); // self.inner.pop_cond.notify_one(); // } pub fn push_all(&self, items: impl IntoIterator) { let mut queue = self.inner.data.lock().unwrap(); let mut count = 0; for item in items { queue.push_back(item); count += 1; } if count > 0 { self.inner.pending.fetch_add(count, Ordering::SeqCst); self.inner.pop_cond.notify_all(); } } // pub fn pop(&self) -> TaskGuard { // let mut queue = self.inner.data.lock().unwrap(); // while queue.is_empty() { // queue = self.inner.pop_cond.wait(queue).unwrap(); // } // let item = queue.pop_front().unwrap(); // TaskGuard { // item, // inner: self.inner.clone(), // } // } pub fn try_pop(&self) -> Option> { let mut queue = self.inner.data.lock().unwrap(); queue.pop_front().map(|item| { TaskGuard { item, inner: self.inner.clone(), } }) } /// 阻塞当前线程,直到所有在途任务(pending == 0)全部处理完 pub fn wait_until_all_consumed(&self) { let mut _queue_lock = self.inner.data.lock().unwrap(); while self.inner.pending.load(Ordering::SeqCst) > 0 { _queue_lock = self.inner.done_cond.wait(_queue_lock).unwrap(); } } } impl Clone for SharedQueue { fn clone(&self) -> Self { Self { inner: Arc::clone(&self.inner) } } }