C++11 訊號量(semaphore)實現
阿新 • 發佈:2019-02-19
#pragma once

#include <algorithm>
#include <condition_variable>
#include <cstddef>
#include <exception>
#include <mutex>

// Counting semaphore (C++11) built on std::mutex + std::condition_variable,
// with an extra "closed" state used to release blocked waiters.
//
//  - wait()   blocks until the count is positive, then decrements it.
//  - post(n)  adds n to the count and wakes up to n blocked waiters.
//  - close()  wakes every waiter; any wait() that observes the closed state
//             throws closed_exception, even if permits are still available.
//  - open()   clears the closed state again (it does not notify anyone).
//
// Every public member function takes the internal mutex before touching
// state. The class is not copyable or movable (it owns a std::mutex).
class semaphore {
public:
    // Thrown by wait() when the semaphore is closed. Derived from
    // std::exception so that generic `catch (const std::exception&)`
    // handlers (e.g. at a thread entry point) also catch it.
    struct closed_exception : std::exception {
        const char* what() const noexcept override {
            return "semaphore closed";
        }
    };

public:
    // cnt: number of permits initially available.
    explicit semaphore(std::size_t cnt = 0)
        : m_cnt(cnt) {}

    // Leave the closed state. Waiters are not notified here.
    void open() {
        std::lock_guard<std::mutex> lock(m_mtx);
        m_opened = true;
    }

    // Enter the closed state and wake every waiter so it re-checks the
    // predicate in wait() and throws closed_exception.
    void close() {
        std::lock_guard<std::mutex> lock(m_mtx);
        m_opened = false;
        m_evt.notify_all();
    }

    // Acquire one permit, blocking while none is available.
    // Throws closed_exception whenever the closed state is observed (on entry
    // or after any wake-up). The closed check takes priority over available
    // permits, so a permit posted shortly before close() may never be
    // consumed.
    void wait() {
        std::unique_lock<std::mutex> lock(m_mtx);
        m_evt.wait(lock, [this] {
            if (!m_opened) {
                // Thrown with the mutex held: event_'s waiter counter is
                // restored by guard_ first, then `lock` releases the mutex
                // during unwinding.
                throw closed_exception();
            }
            return m_cnt > 0;
        });
        --m_cnt;
    }

    // Release n permits (default 1) and wake at most n blocked waiters.
    // The caller must keep the total count within std::size_t range.
    void post(std::size_t n = 1) {
        std::unique_lock<std::mutex> lock(m_mtx);
        m_cnt += n;
        m_evt.notify(lock, n);
    }

protected:
    // Increments the referenced counter on construction and decrements it on
    // destruction. Only used inside event_::wait, where m_mtx is held.
    class guard_ {
    public:
        explicit guard_(std::size_t& waiters) : waiters_(waiters) { ++waiters_; }
        ~guard_() { --waiters_; }
        guard_(const guard_&) = delete;
        guard_& operator=(const guard_&) = delete;

    private:
        std::size_t& waiters_;
    };

    // Condition variable plus a count of threads currently inside wait(),
    // so that notify(n) issues no more notify_one() calls than there are
    // waiters. Within this class, every member is invoked with m_mtx held.
    class event_ {
    public:
        void wait(std::unique_lock<std::mutex>& lck) {
            guard_ counted(m_waiters);
            m_cnd.wait(lck);
        }

        template <typename F>
        void wait(std::unique_lock<std::mutex>& lck, F f) {
            guard_ counted(m_waiters);
            m_cnd.wait(lck, f);
        }

        // The lock argument is not used; requiring it makes callers prove
        // they hold the mutex while m_waiters is read.
        void notify(std::unique_lock<std::mutex>& /*lck*/, std::size_t n = 1) {
            const std::size_t times = std::min(n, m_waiters);
            for (std::size_t i = 0; i < times; ++i) {
                m_cnd.notify_one();
            }
        }

        void notify_all() { m_cnd.notify_all(); }

    private:
        std::condition_variable m_cnd;
        std::size_t m_waiters{ 0 };
    };

private:
    std::mutex m_mtx;
    event_ m_evt;
    std::size_t m_cnt{ 0 };
    bool m_opened{ true };  // matches the state set up by the constructor
};
測試程式碼
#include <thread> #include <iostream> #include "semaphore.h" int main(int argc, char* argv[]) { semaphore sem; size_t cnt = 0; std::thread thds[2]; for (size_t i = 0; i < 2; i++) { thds[i] = std::move(std::thread(([&] { try { for (;;) { sem.wait(); std::cout << "thread:" << std::this_thread::get_id() << ", semaphore post: " << cnt++ << std::endl; } } catch (const semaphore::closed_exception&) { std::cout << "thread:" << std::this_thread::get_id() << ", semaphore closed" << std::endl; } }))); } for (size_t i = 0; i < 10; i++) { std::this_thread::sleep_for(std::chrono::seconds(1)); sem.post(); } sem.close(); for (size_t i = 0; i < 2; i++) { thds[i].join(); } return 1; }