diff --git a/tests/test_thread.cpp b/tests/test_thread.cpp index 9cc11228fd0..69901c0bb8d 100644 --- a/tests/test_thread.cpp +++ b/tests/test_thread.cpp @@ -116,6 +116,7 @@ class threadpool : public dnnl::threadpool_interop::threadpool_iface { #include "tbb/parallel_for.h" #include "tbb/task_arena.h" +#include "src/cpu/platform.hpp" namespace dnnl { namespace testing { @@ -123,7 +124,8 @@ class threadpool : public dnnl::threadpool_interop::threadpool_iface { public: explicit threadpool(int num_threads = 0) { (void)num_threads; } int get_num_threads() const override { - return tbb::this_task_arena::max_concurrency(); + return std::min(tbb::this_task_arena::max_concurrency(), + (int)dnnl::impl::cpu::platform::get_max_threads_to_use()); } bool get_in_parallel() const override { return 0; } uint64_t get_flags() const override { return 0; }