diff --git a/include/exec/static_thread_pool.hpp b/include/exec/static_thread_pool.hpp index 5f18060f7..fa8900af6 100644 --- a/include/exec/static_thread_pool.hpp +++ b/include/exec/static_thread_pool.hpp @@ -1749,7 +1749,8 @@ namespace experimental::execution std::size_t nthreads = this->pool_.available_parallelism(); bwos_params params = this->pool_.params(); std::size_t local_size = params.blockSize * params.numBlocks; - std::size_t chunk_size = __umin({size / nthreads, local_size * nthreads}); + std::size_t chunk_size = + __umax({std::size_t{1}, __umin({size / nthreads, local_size * nthreads})}); auto& remote_queue = *this->pool_.get_remote_queue(); auto it = std::ranges::begin(this->range_); std::size_t i0 = 0; @@ -1765,6 +1766,7 @@ namespace experimental::execution std::unique_lock lock{this->start_mutex_}; this->pool_.bulk_enqueue(remote_queue, std::move(this->tasks_), this->tasks_size_); + this->tasks_size_ = 0; lock.unlock(); i0 += chunk_size; } @@ -1778,6 +1780,7 @@ namespace experimental::execution std::unique_lock lock{this->start_mutex_}; this->has_started_ = true; this->pool_.bulk_enqueue(remote_queue, std::move(this->tasks_), this->tasks_size_); + this->tasks_size_ = 0; } }; diff --git a/test/exec/test_static_thread_pool.cpp b/test/exec/test_static_thread_pool.cpp index 4d38f825e..83d499c8b 100644 --- a/test/exec/test_static_thread_pool.cpp +++ b/test/exec/test_static_thread_pool.cpp @@ -1,8 +1,12 @@ #include "catch2/catch_all.hpp" +#include +#include #include #include +#include #include +#include #include #include namespace ex = STDEXEC; @@ -65,3 +69,22 @@ TEST_CASE("bulk on static_thread_pool executes on multiple threads, take 2", ex::sync_wait(std::move(sender)); REQUIRE(thread_ids.size() == num_of_threads); } + +TEST_CASE("schedule_all on static_thread_pool handles fewer items than threads", + "[types][static_thread_pool]") +{ + constexpr size_t const num_of_threads = 4; + exec::static_thread_pool pool{num_of_threads}; + std::array visited{}; + + auto sender = exec::schedule_all(pool, std::views::iota(size_t{0}, visited.size())) + | exec::transform_each(ex::then([&](size_t i) noexcept { visited[i] = true; })) + | exec::ignore_all_values(); + + ex::sync_wait(std::move(sender)); + + for (bool item_visited: visited) + { + REQUIRE(item_visited); + } +}