diff --git a/include/exec/static_thread_pool.hpp b/include/exec/static_thread_pool.hpp index 5f18060f7..42746b6d6 100644 --- a/include/exec/static_thread_pool.hpp +++ b/include/exec/static_thread_pool.hpp @@ -1747,9 +1747,13 @@ namespace experimental::execution { std::size_t size = items_.size(); std::size_t nthreads = this->pool_.available_parallelism(); + STDEXEC_ASSERT(nthreads > 0); bwos_params params = this->pool_.params(); std::size_t local_size = params.blockSize * params.numBlocks; - std::size_t chunk_size = __umin({size / nthreads, local_size * nthreads}); + std::size_t chunk_size = size == 0 + ? 0 + : __umax({std::size_t{1}, + __umin({size / nthreads, local_size * nthreads})}); auto& remote_queue = *this->pool_.get_remote_queue(); auto it = std::ranges::begin(this->range_); std::size_t i0 = 0; @@ -1765,6 +1769,7 @@ namespace experimental::execution std::unique_lock lock{this->start_mutex_}; this->pool_.bulk_enqueue(remote_queue, std::move(this->tasks_), this->tasks_size_); + this->tasks_size_ = 0; lock.unlock(); i0 += chunk_size; } diff --git a/test/exec/test_static_thread_pool.cpp b/test/exec/test_static_thread_pool.cpp index 4d38f825e..d1f7d3299 100644 --- a/test/exec/test_static_thread_pool.cpp +++ b/test/exec/test_static_thread_pool.cpp @@ -1,8 +1,12 @@ #include "catch2/catch_all.hpp" +#include "exec/sequence/ignore_all_values.hpp" +#include "exec/sequence/transform_each.hpp" #include #include +#include #include +#include #include #include namespace ex = STDEXEC; @@ -45,6 +49,32 @@ TEST_CASE("bulk on static_thread_pool executes on multiple threads", "[types][st REQUIRE(thread_ids.size() == num_of_threads); } +TEST_CASE("schedule_all on static_thread_pool handles ranges smaller than available parallelism", + "[types][static_thread_pool]") +{ + constexpr size_t const num_of_threads = 5; + constexpr int const range_size = 3; + + exec::static_thread_pool pool{num_of_threads}; + REQUIRE(range_size < pool.available_parallelism()); + + std::atomic count{0}; + std::atomic sum{0}; + auto sender = + exec::schedule_all(pool, std::views::iota(0, range_size)) + | exec::transform_each(ex::then( + [&](int x) noexcept + { + count.fetch_add(1, std::memory_order_relaxed); + sum.fetch_add(x, std::memory_order_relaxed); + })) + | exec::ignore_all_values(); + + CHECK(ex::sync_wait(std::move(sender))); + CHECK(count.load(std::memory_order_relaxed) == range_size); + CHECK(sum.load(std::memory_order_relaxed) == 3); +} + TEST_CASE("bulk on static_thread_pool executes on multiple threads, take 2", "[types][static_thread_pool]") {