diff --git a/make/compiler_flags b/make/compiler_flags index 177bebca39b..f66d4687506 100644 --- a/make/compiler_flags +++ b/make/compiler_flags @@ -303,9 +303,6 @@ ifneq ($(OS), Windows_NT) LDFLAGS_TBB_RPATH ?= -Wl,-rpath,"$(TBB_LIB)" endif -LDFLAGS_TBB ?= -Wl,-L,"$(TBB_LIB)" $(LDFLAGS_TBB_DTAGS) $(LDFLAGS_TBB_RPATH) - -LDLIBS_TBB ?= -ltbb else @@ -342,7 +339,7 @@ endif # Sets up CXXFLAGS_THREADS to use threading ifdef STAN_THREADS - CXXFLAGS_THREADS ?= -DSTAN_THREADS + CXXFLAGS_THREADS ?= -DSTAN_THREADS -pthread endif ################################################################################ diff --git a/make/standalone b/make/standalone index 1dfc9203af1..ca6ffc5b4fd 100644 --- a/make/standalone +++ b/make/standalone @@ -22,7 +22,7 @@ MATH ?= $(realpath $(MATH_MAKE)..)/ # The sundials libraries are only needed for # programs using the stiff ode solver or the # algebra solver -MATH_LIBS ?= $(SUNDIALS_TARGETS) $(MPI_TARGETS) $(TBB_TARGETS) +MATH_LIBS ?= $(SUNDIALS_TARGETS) $(MPI_TARGETS) LDLIBS += $(MATH_LIBS) diff --git a/stan/math/opencl/concurrent_vector.hpp b/stan/math/opencl/concurrent_vector.hpp new file mode 100644 index 00000000000..1dcadd49152 --- /dev/null +++ b/stan/math/opencl/concurrent_vector.hpp @@ -0,0 +1,374 @@ +#ifndef STAN_MATH_OPENCL_CONCURRENT_VECTOR_HPP +#define STAN_MATH_OPENCL_CONCURRENT_VECTOR_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace math { +namespace internal { + +/** + * Segmented concurrent_vector. + * + * Properties: + * - concurrent emplace_back/push_back via atomic size counter + * - segmented storage => no moving elements during growth, stable addresses + * - segments allocated lazily; allocation uses CAS to avoid locks + * - movable (required for return-by-value usage such as vec_concat) + * + * Notes: + * - Intended for append-then-read patterns. + * - size_ increments before construction finishes. If readers iterate up to + * size() concurrently with writers, you need a constructed/published protocol. + * - clear()/destruction are NOT concurrent with pushes/reads. + */ +template +class concurrent_vector { + static_assert(BaseSegmentSize > 0, "BaseSegmentSize must be > 0"); + static_assert((BaseSegmentSize & (BaseSegmentSize - 1)) == 0, + "BaseSegmentSize must be a power of two"); + static_assert(MaxSegments > 0, "MaxSegments must be > 0"); + + public: + concurrent_vector() noexcept : size_(0) { + for (auto& p : segments_) + p.store(nullptr, std::memory_order_relaxed); + } + + concurrent_vector(const concurrent_vector& other) : concurrent_vector() { + copy_from_(other); + } + + concurrent_vector& operator=(const concurrent_vector& other) { + if (this != &other) { + clear(); + copy_from_(other); + } + return *this; + } + + // Movable (needed so Stan can return-by-value) + concurrent_vector(concurrent_vector&& other) noexcept : size_(0) { + for (auto& p : segments_) + p.store(nullptr, std::memory_order_relaxed); + move_from_(other); + } + + concurrent_vector& operator=(concurrent_vector&& other) noexcept { + if (this != &other) { + destroy_all_(); + for (auto& p : segments_) + p.store(nullptr, std::memory_order_relaxed); + size_.store(0, std::memory_order_relaxed); + move_from_(other); + } + return *this; + } + + ~concurrent_vector() noexcept { destroy_all_(); } + + std::size_t size() const noexcept { + return size_.load(std::memory_order_acquire); + } + + bool empty() const noexcept { return size() == 0; } + + // Not concurrent with pushes/reads. 
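+  // Usage sketch (illustration only; `ev` is a hypothetical cl::Event): the
+  // intended append-then-read pattern, with all writers finished before any
+  // reads begin.
+  //
+  //   internal::concurrent_vector<cl::Event> events;
+  //   events.reserve(128);                        // optional pre-allocation
+  //   std::size_t idx = events.emplace_back(ev);  // safe from many threads
+  //   // ... join/quiesce all writers, then read:
+  //   std::vector<cl::Event> snapshot(events.begin(), events.end());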
+ void clear() { + destroy_all_(); + size_.store(0, std::memory_order_release); + } + + // Pre-allocate enough segments to back indices [0, capacity-1]. + // Safe to call concurrently with emplace_back (may race allocating segments; + // losers free). + void reserve(std::size_t capacity) { + if (capacity == 0) + return; + const std::size_t last = capacity - 1; + const std::size_t last_seg = segment_index_(last); + if (last_seg >= MaxSegments) { + throw std::length_error( + "concurrent_vector::reserve: exceeds MaxSegments"); + } + for (std::size_t s = 0; s <= last_seg; ++s) { + ensure_segment_(s); + } + } + + template + std::size_t emplace_back(Args&&... args) { + const std::size_t idx = size_.fetch_add(1, std::memory_order_acq_rel); + T* seg = ensure_segment_for_index_(idx); + T* slot = seg + offset_in_segment_(idx); + ::new (static_cast(slot)) T(std::forward(args)...); + return idx; + } + + std::size_t push_back(const T& v) { return emplace_back(v); } + std::size_t push_back(T&& v) { return emplace_back(std::move(v)); } + + // Pointer helper (no bounds check). + T* data_at(std::size_t i) noexcept { + T* seg = segment_ptr_(segment_index_(i)); + return seg + offset_in_segment_(i); + } + const T* data_at(std::size_t i) const noexcept { + const T* seg = segment_ptr_(segment_index_(i)); + return seg + offset_in_segment_(i); + } + + // Bounds-checked access. + T& at(std::size_t i) { + if (i >= size()) + throw std::out_of_range("concurrent_vector::at"); + return *data_at(i); + } + const T& at(std::size_t i) const { + if (i >= size()) + throw std::out_of_range("concurrent_vector::at"); + return *data_at(i); + } + + // Unchecked access. + T& operator[](std::size_t i) noexcept { return *data_at(i); } + const T& operator[](std::size_t i) const noexcept { return *data_at(i); } + + // ------------------------- + // Iterators (InputIterator is enough for std::vector(first,last)) + // ------------------------- + + class iterator { + public: + using iterator_category = std::input_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = T*; + using reference = T&; + + iterator() : v_(nullptr), i_(0) {} + iterator(concurrent_vector* v, std::size_t i) : v_(v), i_(i) {} + + reference operator*() const { return (*v_)[i_]; } + pointer operator->() const { return &(*v_)[i_]; } + + iterator& operator++() { + ++i_; + return *this; + } + iterator operator++(int) { + iterator tmp = *this; + ++(*this); + return tmp; + } + + friend bool operator==(const iterator& a, const iterator& b) { + return a.v_ == b.v_ && a.i_ == b.i_; + } + friend bool operator!=(const iterator& a, const iterator& b) { + return !(a == b); + } + + private: + concurrent_vector* v_; + std::size_t i_; + }; + + class const_iterator { + public: + using iterator_category = std::input_iterator_tag; + using value_type = T; + using difference_type = std::ptrdiff_t; + using pointer = const T*; + using reference = const T&; + + const_iterator() : v_(nullptr), i_(0) {} + const_iterator(const concurrent_vector* v, std::size_t i) : v_(v), i_(i) {} + + reference operator*() const { return (*v_)[i_]; } + pointer operator->() const { return &(*v_)[i_]; } + + const_iterator& operator++() { + ++i_; + return *this; + } + const_iterator operator++(int) { + const_iterator tmp = *this; + ++(*this); + return tmp; + } + + friend bool operator==(const const_iterator& a, const const_iterator& b) { + return a.v_ == b.v_ && a.i_ == b.i_; + } + friend bool operator!=(const const_iterator& a, const const_iterator& b) { + return !(a == 
b); + } + + private: + const concurrent_vector* v_; + std::size_t i_; + }; + + iterator begin() noexcept { return iterator(this, 0); } + iterator end() noexcept { + return iterator(this, size()); + } // snapshot at call time + + const_iterator begin() const noexcept { return const_iterator(this, 0); } + const_iterator end() const noexcept { return const_iterator(this, size()); } + + const_iterator cbegin() const noexcept { return const_iterator(this, 0); } + const_iterator cend() const noexcept { return const_iterator(this, size()); } + + T& back() { + const std::size_t n = size(); + if (n == 0) + throw std::out_of_range("concurrent_vector::back on empty"); + return (*this)[n - 1]; + } + + const T& back() const { + const std::size_t n = size(); + if (n == 0) + throw std::out_of_range("concurrent_vector::back on empty"); + return (*this)[n - 1]; + } + // ------------------------- + + private: + // Segment k has size BaseSegmentSize * 2^k + static constexpr std::size_t segment_size_(std::size_t k) noexcept { + return BaseSegmentSize << k; + } + + // Prefix elements before segment k: Base * (2^k - 1) + static constexpr std::size_t segment_prefix_(std::size_t k) noexcept { + return BaseSegmentSize * ((std::size_t{1} << k) - 1); + } + + // Map global index -> segment index + // Let q = idx / Base. Then segment = floor(log2(q + 1)). + static std::size_t segment_index_(std::size_t idx) noexcept { + const std::size_t q = idx / BaseSegmentSize; + const std::size_t x = q + 1; + +#if defined(__GNUG__) || defined(__clang__) + if constexpr (sizeof(std::size_t) == 8) { + return 63u + - static_cast( + __builtin_clzll(static_cast(x))); + } else { + return 31u + - static_cast( + __builtin_clzl(static_cast(x))); + } +#else + std::size_t s = 0; + std::size_t t = x; + while (t >>= 1) + ++s; + return s; +#endif + } + + static std::size_t offset_in_segment_(std::size_t idx) noexcept { + const std::size_t s = segment_index_(idx); + return idx - segment_prefix_(s); + } + + T* segment_ptr_(std::size_t s) noexcept { + return static_cast(segments_[s].load(std::memory_order_acquire)); + } + const T* segment_ptr_(std::size_t s) const noexcept { + return static_cast(segments_[s].load(std::memory_order_acquire)); + } + + T* ensure_segment_(std::size_t s) { + T* seg = segment_ptr_(s); + if (seg) + return seg; + + const std::size_t n = segment_size_(s); + void* raw = ::operator new(sizeof(T) * n); + T* fresh = static_cast(raw); + + void* expected = nullptr; + if (!segments_[s].compare_exchange_strong(expected, fresh, + std::memory_order_release, + std::memory_order_acquire)) { + ::operator delete(raw); + seg = segment_ptr_(s); + assert(seg != nullptr); + return seg; + } + return fresh; + } + + T* ensure_segment_for_index_(std::size_t idx) { + const std::size_t s = segment_index_(idx); + if (s >= MaxSegments) { + throw std::length_error("concurrent_vector: exceeded MaxSegments"); + } + return ensure_segment_(s); + } + + void destroy_all_() noexcept { + const std::size_t n = size_.load(std::memory_order_acquire); + + // Assumes [0, n) constructed. 
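+    // (emplace_back bumps size_ before the element is constructed, so this
+    // assumption only holds once every writer has finished; see the
+    // class-level notes above.)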
+ for (std::size_t i = 0; i < n; ++i) { + data_at(i)->~T(); + } + + for (auto& a : segments_) { + void* p = a.exchange(nullptr, std::memory_order_acq_rel); + if (p) + ::operator delete(p); + } + } + + void move_from_(concurrent_vector& other) noexcept { + // Steal size + const std::size_t n = other.size_.exchange(0, std::memory_order_acq_rel); + size_.store(n, std::memory_order_release); + + // Steal segments + for (std::size_t s = 0; s < MaxSegments; ++s) { + void* p = other.segments_[s].exchange(nullptr, std::memory_order_acq_rel); + segments_[s].store(p, std::memory_order_release); + } + } + + void copy_from_(const concurrent_vector& other) { + const std::size_t n = other.size(); + if (n == 0) + return; + + reserve(n); + // Important: we want size_ to match, but we must construct elements. + // Use emplace_back so construction happens in this container. + for (std::size_t i = 0; i < n; ++i) { + emplace_back(other[i]); + } + } + + std::atomic size_; + std::array, MaxSegments> segments_; +}; + +} // namespace internal +} // namespace math +} // namespace stan + +#endif diff --git a/stan/math/opencl/kernel_cl.hpp b/stan/math/opencl/kernel_cl.hpp index 8f0b6a66e7d..3dda705d331 100644 --- a/stan/math/opencl/kernel_cl.hpp +++ b/stan/math/opencl/kernel_cl.hpp @@ -109,17 +109,19 @@ inline void assign_events(const cl::Event& new_event, CallArg& m, * @return A vector of OpenCL events. */ template * = nullptr> -inline tbb::concurrent_vector select_events(const T& m) { - return tbb::concurrent_vector{}; +inline stan::math::internal::concurrent_vector select_events( + const T& m) { + return stan::math::internal::concurrent_vector{}; } template * = nullptr, require_same_t* = nullptr> -inline const tbb::concurrent_vector& select_events(const K& m) { +inline const stan::math::internal::concurrent_vector& select_events( + const K& m) { return m.write_events(); } template * = nullptr, require_any_same_t* = nullptr> -inline tbb::concurrent_vector select_events(K& m) { +inline stan::math::internal::concurrent_vector select_events(K& m) { static_assert(!std::is_const::value, "Can not write to const matrix_cl!"); return m.read_write_events(); } diff --git a/stan/math/opencl/matrix_cl.hpp b/stan/math/opencl/matrix_cl.hpp index 4114c1199c8..3d3a4c6e436 100644 --- a/stan/math/opencl/matrix_cl.hpp +++ b/stan/math/opencl/matrix_cl.hpp @@ -12,7 +12,7 @@ #include #include #include -#include +#include #include #include #include @@ -51,8 +51,9 @@ class matrix_cl : public matrix_cl_base { int cols_{0}; // Number of columns. // Holds info on if matrix is a special type matrix_cl_view view_{matrix_cl_view::Entire}; - mutable tbb::concurrent_vector write_events_; // Tracks write jobs - mutable tbb::concurrent_vector read_events_; // Tracks reads + mutable internal::concurrent_vector + write_events_; // Tracks write jobs + mutable internal::concurrent_vector read_events_; // Tracks reads public: using Scalar = T; // Underlying type of the matrix @@ -100,7 +101,7 @@ class matrix_cl : public matrix_cl_base { * Get the events from the event stacks. * @return The write event stack. */ - inline const tbb::concurrent_vector& write_events() const { + inline const internal::concurrent_vector& write_events() const { return write_events_; } @@ -108,7 +109,7 @@ class matrix_cl : public matrix_cl_base { * Get the events from the event stacks. * @return The read/write event stack. 
*/ - inline const tbb::concurrent_vector& read_events() const { + inline const internal::concurrent_vector& read_events() const { return read_events_; } @@ -116,7 +117,8 @@ class matrix_cl : public matrix_cl_base { * Get the events from the event stacks. * @return The read/write event stack. */ - inline const tbb::concurrent_vector read_write_events() const { + inline const internal::concurrent_vector read_write_events() + const { return vec_concat(this->read_events(), this->write_events()); } diff --git a/stan/math/opencl/opencl_context.hpp b/stan/math/opencl/opencl_context.hpp index e2373df126d..b0c435e33e0 100644 --- a/stan/math/opencl/opencl_context.hpp +++ b/stan/math/opencl/opencl_context.hpp @@ -14,7 +14,7 @@ #include #include -#include +#include #include #include #include @@ -208,7 +208,7 @@ class opencl_context_base { * The API to access the methods and values in opencl_context_base */ class opencl_context { - tbb::concurrent_vector kernel_caches_; + internal::concurrent_vector kernel_caches_; public: opencl_context() = default; diff --git a/stan/math/prim/functor/reduce_sum_static.hpp b/stan/math/prim/functor/reduce_sum_static.hpp index 7ad8648aad8..29f70f193d3 100644 --- a/stan/math/prim/functor/reduce_sum_static.hpp +++ b/stan/math/prim/functor/reduce_sum_static.hpp @@ -4,9 +4,6 @@ #include #include #include -#include -#include -#include #include #include diff --git a/stan/math/rev/core.hpp b/stan/math/rev/core.hpp index 9b7646c7d04..0dd91e8d9a6 100644 --- a/stan/math/rev/core.hpp +++ b/stan/math/rev/core.hpp @@ -63,6 +63,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/rev/core/simple_thread_pool.hpp b/stan/math/rev/core/simple_thread_pool.hpp new file mode 100644 index 00000000000..16749ee1e5b --- /dev/null +++ b/stan/math/rev/core/simple_thread_pool.hpp @@ -0,0 +1,160 @@ +#ifndef STAN_MATH_REV_CORE_SIMPLE_THREAD_POOL_HPP +#define STAN_MATH_REV_CORE_SIMPLE_THREAD_POOL_HPP + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace stan { +namespace math { + +class SimpleThreadPool { + public: + static SimpleThreadPool& instance() { + static SimpleThreadPool pool; + return pool; + } + + SimpleThreadPool(const SimpleThreadPool&) = delete; + SimpleThreadPool& operator=(const SimpleThreadPool&) = delete; + + std::size_t thread_count() const noexcept { return workers_.size(); } + + template + auto submit(F&& f, Args&&... args) + -> std::future> { + using R = std::invoke_result_t; + + auto task_ptr = std::make_shared>( + std::bind(std::forward(f), std::forward(args)...)); + + enqueue_([task_ptr] { (*task_ptr)(); }); + return task_ptr->get_future(); + } + + template + void parallel_region(std::size_t n, F&& fn) { + if (n == 0) + return; + + // Avoid nested parallelism deadlocks/oversubscription. 
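+    // in_worker_ is a thread_local flag set by WorkerScope while a pooled
+    // task runs, so a nested parallel_region issued from inside a worker
+    // collapses to a serial fn(0) instead of queueing work the pool may
+    // never drain. Illustration only (`task` is hypothetical):
+    //   pool.submit([&] { pool.parallel_region(4, task); });  // runs task(0)
+    //   pool.parallel_region(4, task);  // non-worker caller: task(tid), tid < n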
+ if (in_worker_) { + fn(std::size_t{0}); + return; + } + + const std::size_t tc = thread_count(); + if (tc == 0) { + fn(std::size_t{0}); + return; + } + if (n > tc) + n = tc; + + using Fn = std::decay_t; + struct Shared { + std::atomic remaining; + std::mutex m; + std::condition_variable cv; + Fn fn; + Shared(std::size_t n_, Fn&& f_) : remaining(n_), fn(std::move(f_)) {} + }; + + auto shared = std::make_shared(n, Fn(std::forward(fn))); + + for (std::size_t tid = 0; tid < n; ++tid) { + enqueue_([shared, tid] { + shared->fn(tid); + if (shared->remaining.fetch_sub(1, std::memory_order_acq_rel) == 1) { + std::lock_guard lk(shared->m); + shared->cv.notify_one(); + } + }); + } + + std::unique_lock lk(shared->m); + shared->cv.wait(lk, [&] { + return shared->remaining.load(std::memory_order_acquire) == 0; + }); + } + + private: + SimpleThreadPool() : done_(false) { + unsigned hw = std::thread::hardware_concurrency(); + if (hw == 0) + hw = 2; + const unsigned num_threads = hw; + + workers_.reserve(num_threads); + for (unsigned i = 0; i < num_threads; ++i) { + workers_.emplace_back([this] { + // Per-worker AD tape (TLS) initialized once. + static thread_local ChainableStack ad_tape; + + for (;;) { + std::function task; + { + std::unique_lock lock(mtx_); + cv_.wait(lock, [&] { return done_ || !tasks_.empty(); }); + if (done_ && tasks_.empty()) + return; + task = std::move(tasks_.front()); + tasks_.pop(); + } + + WorkerScope scope; // sets in_worker_ for all tasks + task(); + } + }); + } + } + + ~SimpleThreadPool() { + { + std::lock_guard lock(mtx_); + done_ = true; + } + cv_.notify_all(); + for (auto& th : workers_) { + if (th.joinable()) + th.join(); + } + } + + void enqueue_(std::function task) { + { + std::lock_guard lock(mtx_); + tasks_.emplace(std::move(task)); + } + cv_.notify_one(); + } + + struct WorkerScope { + WorkerScope() : prev_(in_worker_) { in_worker_ = true; } + ~WorkerScope() { in_worker_ = prev_; } + bool prev_; + }; + + static inline thread_local bool in_worker_ = false; + + std::vector workers_; + std::queue> tasks_; + std::mutex mtx_; + std::condition_variable cv_; + bool done_; +}; + +} // namespace math +} // namespace stan + +#endif diff --git a/stan/math/rev/core/team_thread_pool.hpp b/stan/math/rev/core/team_thread_pool.hpp new file mode 100644 index 00000000000..90db011cb7e --- /dev/null +++ b/stan/math/rev/core/team_thread_pool.hpp @@ -0,0 +1,312 @@ +#ifndef STAN_MATH_REV_CORE_TEAM_THREAD_POOL_HPP +#define STAN_MATH_REV_CORE_TEAM_THREAD_POOL_HPP + +#include + +#include +#include +#include +#include // getenv, strtol +#include // exception_ptr +#include +#include +#include +#include +#include + +namespace stan { +namespace math { + +/** + * TeamThreadPool + * + * - Fixed set of worker threads created once. + * - Caller participates as logical tid=0. + * - Worker threads have stable logical tids 1..(cap-1). + * - parallel_region(n, fn): runs fn(tid) for tid in [0, n), exactly once each. + * + * Notes: + * - Nested parallel_region calls from a worker run serially to avoid deadlock. + * - Uses an epoch counter + condition_variable to wake workers per region. + * - Startup barrier ensures all workers are waiting before the first region + * launch. + */ +class TeamThreadPool { + public: + // Total participants INCLUDING caller (tid=0). Call before instance(). 
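+  // Illustration only (the value 4 is hypothetical): pick the team size
+  // before the singleton is constructed, then run a region.
+  //
+  //   stan::math::TeamThreadPool::set_num_threads(4);  // caller + 3 workers
+  //   auto& pool = stan::math::TeamThreadPool::instance();
+  //   pool.parallel_region(4, [](std::size_t tid) { /* work for tid */ });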
+ static void set_num_threads(std::size_t n) noexcept { + if (n < 1) + n = 1; + user_cap_().store(n, std::memory_order_release); + } + + static std::size_t get_num_threads() noexcept { + return user_cap_().load(std::memory_order_acquire); + } + + static TeamThreadPool& instance() { + static TeamThreadPool pool; + return pool; + } + + std::size_t worker_count() const noexcept { return workers_.size(); } + std::size_t team_size() const noexcept { return workers_.size() + 1; } + + template + void parallel_region(std::size_t n, F&& fn) { + if (n == 0) + return; + + // Prevent nested parallelism from deadlocking the pool. + if (in_worker_) { + fn(std::size_t{0}); + return; + } + + // Only one active region at a time (shared region state). + std::unique_lock region_lock(region_m_); + + const std::size_t max_team = team_size(); + if (max_team <= 1) { + fn(std::size_t{0}); + return; + } + if (n > max_team) + n = max_team; + if (n == 1) { + fn(std::size_t{0}); + return; + } + + using Fn = std::decay_t; + Fn fn_copy = std::forward(fn); + + // Exception propagation (first exception wins). + std::exception_ptr eptr = nullptr; + { + std::lock_guard lk(exc_m_); + exc_ptr_ = &eptr; + } + + // Publish region state BEFORE bumping epoch. + remaining_.store(n - 1, std::memory_order_release); // workers only + region_n_.store(n, std::memory_order_release); + region_ctx_.store(static_cast(&fn_copy), std::memory_order_release); + region_call_.store(&call_impl, std::memory_order_release); + + // Bump epoch to start the region, then wake workers. + const std::size_t new_epoch + = epoch_.fetch_add(1, std::memory_order_acq_rel) + 1; + + { + std::lock_guard lk(wake_m_); + // epoch_ already updated; the mutex pairs with the cv wait. + (void)new_epoch; + } + wake_cv_.notify_all(); + + // Caller participates as tid=0. + in_worker_ = true; + try { + fn_copy(0); + } catch (...) { + std::lock_guard lk(exc_m_); + if (eptr == nullptr) + eptr = std::current_exception(); + } + in_worker_ = false; + + // Wait for workers 1..n-1. + std::unique_lock lk(done_m_); + done_cv_.wait( + lk, [&] { return remaining_.load(std::memory_order_acquire) == 0; }); + + // Hygiene. + region_n_.store(0, std::memory_order_release); + + if (eptr) + std::rethrow_exception(eptr); + } + + private: + using call_fn_t = void (*)(void*, std::size_t); + + template + static void call_impl(void* ctx, std::size_t tid) { + (*static_cast(ctx))(tid); + } + + static std::atomic& user_cap_() { + static std::atomic cap{0}; // 0 => unset + return cap; + } + + static std::size_t env_num_threads_() noexcept { + const char* s = std::getenv("STAN_NUM_THREADS"); + if (!s || !*s) + return 0; + char* end = nullptr; + std::size_t v = static_cast(std::strtol(s, &end, 10)); + if (end == s || v <= 0) + return 0; + return v; + } + + static std::size_t configured_cap_(std::size_t hw) noexcept { + std::size_t cap = user_cap_().load(std::memory_order_acquire); + if (cap == 0) + cap = env_num_threads_(); + if (cap == 0) + cap = hw; + if (cap < 1) + cap = 1; + if (cap > hw) + cap = hw; + return cap; + } + + TeamThreadPool() + : stop_(false), + epoch_(0), + region_n_(0), + region_ctx_(nullptr), + region_call_(nullptr), + remaining_(0), + exc_ptr_(nullptr), + ready_count_(0) { + unsigned hw_u = std::thread::hardware_concurrency(); + if (hw_u == 0) + hw_u = 2; + const std::size_t hw = static_cast(hw_u); + + const std::size_t cap = configured_cap_(hw); + const std::size_t num_workers = (cap > 1) ? 
(cap - 1) : 0; + + workers_.reserve(num_workers); + for (std::size_t i = 0; i < num_workers; ++i) { + const std::size_t tid = i + 1; // workers are 1..num_workers + workers_.emplace_back([this, tid] { + // Per-worker AD tape initialized once. + static thread_local ChainableStack ad_tape; + (void)ad_tape; + + in_worker_ = true; + + // Startup barrier: ensure each worker has entered the wait loop once. + { + std::lock_guard lk(wake_m_); + ready_count_.fetch_add(1, std::memory_order_acq_rel); + } + ready_cv_.notify_one(); + + std::size_t seen_epoch = epoch_.load(std::memory_order_acquire); + + for (;;) { + // Wait for a new epoch (or stop). + { + std::unique_lock lk(wake_m_); + wake_cv_.wait(lk, [&] { + return stop_.load(std::memory_order_acquire) + || epoch_.load(std::memory_order_acquire) != seen_epoch; + }); + if (stop_.load(std::memory_order_acquire)) + break; + seen_epoch = epoch_.load(std::memory_order_acquire); + } + + const std::size_t n = region_n_.load(std::memory_order_acquire); + if (tid >= n) + continue; // not participating this region + + // Always decrement once for participating workers. + struct DoneGuard { + std::atomic& rem; + std::mutex& m; + std::condition_variable& cv; + ~DoneGuard() { + if (rem.fetch_sub(1, std::memory_order_acq_rel) == 1) { + std::lock_guard lk(m); + cv.notify_one(); + } + } + } guard{remaining_, done_m_, done_cv_}; + + void* ctx = region_ctx_.load(std::memory_order_acquire); + call_fn_t call = region_call_.load(std::memory_order_acquire); + + if (call) { + try { + call(ctx, tid); + } catch (...) { + std::lock_guard lk(exc_m_); + if (exc_ptr_ && *exc_ptr_ == nullptr) { + *exc_ptr_ = std::current_exception(); + } + } + } + } + + in_worker_ = false; + }); + } + + // Wait for all workers to reach the wait loop once before returning. + { + std::unique_lock lk(wake_m_); + ready_cv_.wait(lk, [&] { + return ready_count_.load(std::memory_order_acquire) == workers_.size(); + }); + } + } + + ~TeamThreadPool() { + stop_.store(true, std::memory_order_release); + { + std::lock_guard lk(wake_m_); + // bump epoch to ensure wake predicate flips + epoch_.fetch_add(1, std::memory_order_acq_rel); + } + wake_cv_.notify_all(); + + for (auto& t : workers_) { + if (t.joinable()) + t.join(); + } + } + + static inline thread_local bool in_worker_ = false; + + std::vector workers_; + std::atomic stop_; + + // Serialize regions. + std::mutex region_m_; + + // Region publication. + std::atomic epoch_; + std::atomic region_n_; + std::atomic region_ctx_; + std::atomic region_call_; + + // Wake workers. + std::mutex wake_m_; + std::condition_variable wake_cv_; + + // Startup barrier. + std::condition_variable ready_cv_; + std::atomic ready_count_; + + // Completion. + std::atomic remaining_; + std::mutex done_m_; + std::condition_variable done_cv_; + + // Exceptions. 
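+  // (First exception wins: parallel_region points exc_ptr_ at a region-local
+  // std::exception_ptr, workers record into it under exc_m_, and the caller
+  // rethrows once the region completes.)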
+ std::mutex exc_m_; + std::exception_ptr* exc_ptr_; +}; + +} // namespace math +} // namespace stan + +#endif diff --git a/stan/math/rev/functor/map_rect_concurrent.hpp b/stan/math/rev/functor/map_rect_concurrent.hpp index 878261056fb..fe2dee2a380 100644 --- a/stan/math/rev/functor/map_rect_concurrent.hpp +++ b/stan/math/rev/functor/map_rect_concurrent.hpp @@ -7,11 +7,11 @@ #include #include #include - -#include -#include +#include #include +#include +#include #include namespace stan { @@ -32,7 +32,7 @@ map_rect_concurrent( = map_rect_reduce, T_job_param>; using CombineF = map_rect_combine; - const int num_jobs = job_params.size(); + const std::size_t num_jobs = job_params.size(); const vector_d shared_params_dbl = value_of(shared_params); std::vector job_output(num_jobs); std::vector world_f_out(num_jobs, 0); @@ -46,18 +46,25 @@ map_rect_concurrent( }; #ifdef STAN_THREADS - // we must use task isolation as described here: - // https://software.intel.com/content/www/us/en/develop/documentation/tbb-documentation/top/intel-threading-building-blocks-developer-guide/task-isolation.html - // this is to ensure that the thread local AD tape ressource is - // not being modified from a different task which may happen - // whenever this function is being used itself in a parallel - // context (like running multiple chains for Stan) - tbb::this_task_arena::isolate([&] { - tbb::parallel_for(tbb::blocked_range(0, num_jobs), - [&](const tbb::blocked_range& r) { - execute_chunk(r.begin(), r.end()); - }); - }); + auto& pool = stan::math::TeamThreadPool::instance(); + + // Total participants includes caller (tid=0). + const std::size_t max_team = pool.team_size(); + const std::size_t n + = std::min(max_team, num_jobs == 0 ? 1u : num_jobs); + + if (n <= 1 || num_jobs <= 1) { + execute_chunk(0, num_jobs); + } else { + pool.parallel_region(n, [&](std::size_t tid) { + const std::size_t nj = num_jobs; + const std::size_t b0 = (nj * tid) / n; + const std::size_t b1 = (nj * (tid + 1)) / n; + if (b0 < b1) { + execute_chunk(b0, b1); + } + }); + } #else execute_chunk(0, num_jobs); #endif diff --git a/stan/math/rev/functor/reduce_sum.hpp b/stan/math/rev/functor/reduce_sum.hpp index 25cbef1e073..814bc6e3f59 100644 --- a/stan/math/rev/functor/reduce_sum.hpp +++ b/stan/math/rev/functor/reduce_sum.hpp @@ -4,13 +4,14 @@ #include #include #include +#include -#include -#include -#include - -#include +#include +#include #include +#include +#include +#include #include #include @@ -18,214 +19,139 @@ namespace stan { namespace math { namespace internal { -/** - * Var specialization of reduce_sum_impl - * - * @tparam ReduceFunction Type of reducer function - * @tparam ReturnType Must be var - * @tparam Vec Type of sliced argument - * @tparam Args Types of shared arguments - */ template struct reduce_sum_impl, ReturnType, Vec, Args...> { struct scoped_args_tuple { ScopedChainableStack stack_; - using args_tuple_t - = std::tuple()))...>; + using args_tuple_t = std::tuple>()))...>; std::unique_ptr args_tuple_holder_; - scoped_args_tuple() : stack_(), args_tuple_holder_(nullptr) {} }; - /** - * This struct is used by the TBB to accumulate partial - * sums over consecutive ranges of the input. To distribute the workload, - * the TBB can split larger partial sums into smaller ones in which - * case the splitting copy constructor is used. It is designed to - * meet the Imperative form requirements of `tbb::parallel_reduce`. - * - * @note see link [here](https://tinyurl.com/vp7xw2t) for requirements. 
- */ struct recursive_reducer { - const size_t num_vars_per_term_; - const size_t num_vars_shared_terms_; // Number of vars in shared arguments - double* sliced_partials_; // Points to adjoints of the partial calculations - Vec vmapped_; - std::stringstream msgs_; - std::tuple args_tuple_; + using VecRef = std::decay_t; + using args_ptrs_t = std::tuple*...>; + + // Apply a callable to tuple of pointers, dereferencing each pointer. + template + static inline decltype(auto) apply_ptr_tuple_impl( + Fn&& fn, Tuple& t, std::index_sequence) { + return std::forward(fn)((*std::get(t))...); + } + template + static inline decltype(auto) apply_ptr_tuple(Fn&& fn, + std::tuple& t) { + return apply_ptr_tuple_impl(std::forward(fn), t, + std::index_sequence_for{}); + } + + const std::size_t num_vars_per_term_; + const std::size_t num_vars_shared_terms_; + double* sliced_partials_; + + const VecRef* vmapped_; + args_ptrs_t args_ptrs_; + + // msgs only if requested + std::unique_ptr msgs_; + std::ostream* msgs_out_; + scoped_args_tuple local_args_tuple_scope_; + double sum_{0.0}; - Eigen::VectorXd args_adjoints_{0}; + std::vector args_adjoints_; // faster than Eigen::VectorXd + + // Reusable buffer to avoid realloc per chunk + VecRef local_sub_slice_; + std::size_t reserved_ = 0; - template - recursive_reducer(size_t num_vars_per_term, size_t num_vars_shared_terms, - double* sliced_partials, VecT&& vmapped, ArgsT&&... args) + recursive_reducer(std::size_t num_vars_per_term, + std::size_t num_vars_shared_terms, + double* sliced_partials, const VecRef& vmapped, + std::ostream* msgs, const std::decay_t&... args) : num_vars_per_term_(num_vars_per_term), num_vars_shared_terms_(num_vars_shared_terms), sliced_partials_(sliced_partials), - vmapped_(std::forward(vmapped)), - args_tuple_(std::forward(args)...) {} + vmapped_(&vmapped), + args_ptrs_(&args...), + msgs_out_(msgs) { + if (msgs_out_) + msgs_ = std::make_unique(); + } - /* - * This is the copy operator as required for tbb::parallel_reduce - * Imperative form. This requires sum_ and arg_adjoints_ be reset - * to zero since the newly created reducer is used to accumulate - * an independent partial sum. - */ - recursive_reducer(recursive_reducer& other, tbb::split) - : num_vars_per_term_(other.num_vars_per_term_), - num_vars_shared_terms_(other.num_vars_shared_terms_), - sliced_partials_(other.sliced_partials_), - vmapped_(other.vmapped_), - args_tuple_(other.args_tuple_) {} - - /** - * Compute, using nested autodiff, the value and Jacobian of - * `ReduceFunction` called over the range defined by r and accumulate those - * in member variable sum_ (for the value) and args_adjoints_ (for the - * Jacobian). The nested autodiff uses deep copies of the involved operands - * ensuring that no side effects are implied to the adjoints of the input - * operands which reside potentially on a autodiff tape stored in a - * different thread other than the current thread of execution. This - * function may be called multiple times per object instantiation (so the - * sum_ and args_adjoints_ must be accumulated, not just assigned). 
- * - * @param r Range over which to compute reduce_sum - */ - inline void operator()(const tbb::blocked_range& r) { - if (r.empty()) { + inline void operator()(std::size_t begin, std::size_t end) { + if (begin >= end) return; - } - if (args_adjoints_.size() == 0) { - args_adjoints_ = Eigen::VectorXd::Zero(num_vars_shared_terms_); + if (args_adjoints_.empty()) { + args_adjoints_.assign(num_vars_shared_terms_, 0.0); } - // Obtain reference to a local copy of all shared arguments that do - // not point - // back to main autodiff stack - if (!local_args_tuple_scope_.args_tuple_holder_) { - // shared arguments need to be copied to reducer-specific - // scope. In this case no need for zeroing adjoints, since the - // fresh copy has all adjoints set to zero. local_args_tuple_scope_.stack_.execute([&]() { - math::apply( - [&](auto&&... args) { + apply_ptr_tuple( + [&](auto const&... a) { local_args_tuple_scope_.args_tuple_holder_ = std::make_unique< typename scoped_args_tuple::args_tuple_t>( - deep_copy_vars(args)...); + deep_copy_vars(a)...); }, - args_tuple_); + args_ptrs_); }); } else { - // set adjoints of shared arguments to zero local_args_tuple_scope_.stack_.execute([] { set_zero_all_adjoints(); }); } auto& args_tuple_local = *(local_args_tuple_scope_.args_tuple_holder_); - // Initialize nested autodiff stack const nested_rev_autodiff begin_nest; - // Create nested autodiff copies of sliced argument that do not point - // back to main autodiff stack - std::decay_t local_sub_slice; - local_sub_slice.reserve(r.size()); - for (size_t i = r.begin(); i < r.end(); ++i) { - local_sub_slice.emplace_back(deep_copy_vars(vmapped_[i])); + // Reuse per-worker buffer + const std::size_t n = end - begin; + local_sub_slice_.clear(); + if (reserved_ < n) { + local_sub_slice_.reserve(n); + reserved_ = n; + } + + for (std::size_t i = begin; i < end; ++i) { + local_sub_slice_.emplace_back(deep_copy_vars((*vmapped_)[i])); } - // Perform calculation + std::ostream* local_msgs + = msgs_ ? static_cast(msgs_.get()) : nullptr; + var sub_sum_v = math::apply( - [&](auto&&... args) { - return ReduceFunction()(local_sub_slice, r.begin(), r.end() - 1, - &msgs_, args...); + [&](auto&&... args_local) { + return ReduceFunction()(local_sub_slice_, begin, end - 1, + local_msgs, args_local...); }, args_tuple_local); - // Compute Jacobian sub_sum_v.grad(); - - // Accumulate value of reduce_sum sum_ += sub_sum_v.val(); - // Accumulate adjoints of sliced_arguments - accumulate_adjoints(sliced_partials_ + r.begin() * num_vars_per_term_, - std::move(local_sub_slice)); + accumulate_adjoints(sliced_partials_ + begin * num_vars_per_term_, + std::move(local_sub_slice_)); + + // local_sub_slice_ got moved-from; restore it to a valid empty state + local_sub_slice_.clear(); - // Accumulate adjoints of shared_arguments math::apply( - [&](auto&&... args) { - accumulate_adjoints(args_adjoints_.data(), args...); + [&](auto&&... args_local) { + accumulate_adjoints(args_adjoints_.data(), args_local...); }, args_tuple_local); } - - /** - * Join reducers. Accumuluate the value (sum_) and Jacobian (arg_adoints_) - * of the other reducer. 
- * - * @param rhs Another partial sum - */ - inline void join(const recursive_reducer& rhs) { - sum_ += rhs.sum_; - if (args_adjoints_.size() != 0 && rhs.args_adjoints_.size() != 0) { - args_adjoints_ += rhs.args_adjoints_; - } else if (args_adjoints_.size() == 0 && rhs.args_adjoints_.size() != 0) { - args_adjoints_ = rhs.args_adjoints_; - } - msgs_ << rhs.msgs_.str(); - } }; - /** - * Call an instance of the function `ReduceFunction` on every element - * of an input sequence and sum these terms. - * - * This specialization is parallelized using tbb and works for reverse - * mode autodiff. - * - * ReduceFunction must define an operator() with the same signature as: - * var f(Vec&& vmapped_subset, int start, int end, std::ostream* msgs, - * Args&&... args) - * - * `ReduceFunction` must be default constructible without any arguments - * - * Each call to `ReduceFunction` is responsible for computing the - * start through end (inclusive) terms of the overall sum. All args are - * passed from this function through to the `ReduceFunction` instances. - * However, only the start through end (inclusive) elements of the vmapped - * argument are passed to the `ReduceFunction` instances (as the - * `vmapped_subset` argument). - * - * This function distributes computation of the desired sum and the Jacobian - * of that sum over multiple threads by coordinating calls to `ReduceFunction` - * instances. Results are stored as precomputed varis in the autodiff tree. - * - * If auto partitioning is true, break work into pieces automatically, - * taking grainsize as a recommended work size. The partitioning is - * not deterministic nor is the order guaranteed in which partial - * sums are accumulated. Due to floating point imprecisions this will likely - * lead to slight differences in the accumulated results between - * multiple runs. If false, break work deterministically into pieces smaller - * than or equal to grainsize and accumulate all the partial sums - * in the same order. This still may not achieve bitwise reproducibility. - * - * @param vmapped Vector containing one element per term of sum - * @param auto_partitioning Work partitioning style - * @param grainsize Suggested grainsize for tbb - * @param[in, out] msgs The print stream for warning messages - * @param args Shared arguments used in every sum term - * @return Summation of all terms - */ - inline var operator()(Vec&& vmapped, bool auto_partitioning, int grainsize, - std::ostream* msgs, Args&&... args) const { - if (vmapped.empty()) { + inline var operator()(Vec&& vmapped, bool /*auto_partitioning*/, + int grainsize, std::ostream* msgs, + Args&&... args) const { + if (vmapped.empty()) return var(0.0); - } const std::size_t num_terms = vmapped.size(); const std::size_t num_vars_per_term = count_vars(vmapped[0]); @@ -240,47 +166,117 @@ struct reduce_sum_impl, ReturnType, save_varis(varis, vmapped); save_varis(varis + num_vars_sliced_terms, args...); - for (size_t i = 0; i < num_vars_sliced_terms; ++i) { + for (std::size_t i = 0; i < num_vars_sliced_terms; ++i) partials[i] = 0.0; + + auto& pool = stan::math::TeamThreadPool::instance(); + const std::size_t max_team = pool.team_size(); + + // Choose workers. (Caller participates, so total participants = n) + std::size_t n + = std::min(max_team, num_terms == 0 ? 1 : num_terms); + if (n < 1) + n = 1; + + // Chunking: default to ~2 chunks per participant (lower overhead). 
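+    // Worked example (hypothetical sizes): num_terms = 1000 and n = 4 with
+    // grainsize <= 0 gives target_chunks = 8 and gs = ceil(1000 / 8) = 125;
+    // num_terms > gs, so the parallel path below runs. With grainsize = 2000
+    // instead, gs = 2000 >= num_terms and the serial cutoff handles
+    // everything on the caller's thread.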
+    std::size_t gs;
+    if (grainsize > 0) {
+      gs = static_cast<std::size_t>(grainsize);
+      if (gs < 1)
+        gs = 1;
+    } else {
+      const std::size_t target_chunks = n * 2;
+      gs = (num_terms + target_chunks - 1) / target_chunks;
+      if (gs < 1)
+        gs = 1;
     }
-    recursive_reducer worker(num_vars_per_term, num_vars_shared_terms, partials,
-                             std::forward<Vec>(vmapped),
-                             std::forward<Args>(args)...);
-
-    // we must use task isolation as described here:
-    // https://software.intel.com/content/www/us/en/develop/documentation/tbb-documentation/top/intel-threading-building-blocks-developer-guide/task-isolation.html
-    // this is to ensure that the thread local AD tape ressource is
-    // not being modified from a different task which may happen
-    // whenever this function is being used itself in a parallel
-    // context (like running multiple chains for Stan)
-    tbb::this_task_arena::isolate([&] {
-      if (auto_partitioning) {
-        tbb::parallel_reduce(
-            tbb::blocked_range<std::size_t>(0, num_terms, grainsize), worker);
+    // Serial cutoff: if too few terms, don't parallelize.
+    if (n == 1 || num_terms <= gs) {
+      recursive_reducer r(num_vars_per_term, num_vars_shared_terms, partials,
+                          vmapped, msgs, args...);
+      r(0, num_terms);
+
+      // write shared adjoints
+      if (!r.args_adjoints_.empty()) {
+        for (std::size_t i = 0; i < num_vars_shared_terms; ++i) {
+          partials[num_vars_sliced_terms + i] = r.args_adjoints_[i];
+        }
       } else {
-        tbb::simple_partitioner partitioner;
-        tbb::parallel_deterministic_reduce(
-            tbb::blocked_range<std::size_t>(0, num_terms, grainsize), worker,
-            partitioner);
+        for (std::size_t i = 0; i < num_vars_shared_terms; ++i) {
+          partials[num_vars_sliced_terms + i] = 0.0;
+        }
+      }
+
+      if (msgs && r.msgs_)
+        *msgs << r.msgs_->str();
+
+      return var(new precomputed_gradients_vari(
+          r.sum_, num_vars_sliced_terms + num_vars_shared_terms, varis,
+          partials));
+    }
+
+    // One reducer per participant (0..n-1) for static partitioning.
+    // NOTE: we avoid copying vmapped/args by taking references/pointers inside
+    // reducer.
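+    // (Each recursive_reducer stores only &vmapped and pointers to the shared
+    // args, so building one per participant is cheap; the pointers stay valid
+    // because parallel_region blocks until every participant has finished.)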
+    std::vector<std::unique_ptr<recursive_reducer>> workers;
+    workers.reserve(n);
+    for (std::size_t tid = 0; tid < n; ++tid) {
+      workers.emplace_back(std::make_unique<recursive_reducer>(
+          num_vars_per_term, num_vars_shared_terms, partials, vmapped, msgs,
+          args...));
+    }
+
+    // Static partition: each participant gets a contiguous block once.
+    pool.parallel_region(n, [&](std::size_t tid) {
+      const std::size_t b0 = (num_terms * tid) / n;
+      const std::size_t b1 = (num_terms * (tid + 1)) / n;
+      if (b0 < b1) {
+        (*workers[tid])(b0, b1);
+      }
     });
-    for (size_t i = 0; i < num_vars_shared_terms; ++i) {
-      partials[num_vars_sliced_terms + i] = worker.args_adjoints_.coeff(i);
+    // Aggregate partial sums, shared-argument adjoints, and messages.
+    double total_sum = 0.0;
+    std::vector<double> shared_adj(num_vars_shared_terms, 0.0);
+    std::stringstream all_msgs;
+
+    for (auto& wptr : workers) {
+      auto& w = *wptr;
+      total_sum += w.sum_;
+      if (!w.args_adjoints_.empty()) {
+        for (std::size_t i = 0; i < num_vars_shared_terms; ++i) {
+          shared_adj[i] += w.args_adjoints_[i];
+        }
+      }
+      if (msgs && w.msgs_) {
+        all_msgs << w.msgs_->str();
+      }
     }
-    if (msgs) {
-      *msgs << worker.msgs_.str();
+    for (std::size_t i = 0; i < num_vars_shared_terms; ++i) {
+      partials[num_vars_sliced_terms + i] = shared_adj[i];
     }
+    if (msgs)
+      *msgs << all_msgs.str();
+
     return var(new precomputed_gradients_vari(
-        worker.sum_, num_vars_sliced_terms + num_vars_shared_terms, varis,
+        total_sum, num_vars_sliced_terms + num_vars_shared_terms, varis,
         partials));
   }
 };
-} // namespace internal
+}  // namespace internal
 }  // namespace math
 }  // namespace stan
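
A minimal end-to-end sketch of how the two new pieces above compose (illustration only, not part of the patch; `compute`, `fill_concurrently`, and `N` are hypothetical stand-ins, and the translation unit is assumed to already pull in the reverse-mode core so ChainableStack is available):

  #include <stan/math/opencl/concurrent_vector.hpp>
  #include <stan/math/rev/core/team_thread_pool.hpp>

  #include <cstddef>
  #include <vector>

  // Hypothetical per-element work.
  static double compute(std::size_t i) { return 0.5 * static_cast<double>(i); }

  std::vector<double> fill_concurrently(std::size_t N) {
    stan::math::internal::concurrent_vector<double> out;
    out.reserve(N);  // optional: pre-allocate segments up front

    auto& pool = stan::math::TeamThreadPool::instance();
    const std::size_t team = pool.team_size();

    // Static partition, mirroring reduce_sum above; the caller runs tid = 0.
    pool.parallel_region(team, [&](std::size_t tid) {
      const std::size_t b0 = (N * tid) / team;
      const std::size_t b1 = (N * (tid + 1)) / team;
      for (std::size_t i = b0; i < b1; ++i) {
        out.push_back(compute(i));  // concurrent append; order not deterministic
      }
    });

    // parallel_region has joined all participants, so reading is now safe.
    return std::vector<double>(out.begin(), out.end());
  }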