5 #include <condition_variable>
11 #include <shared_mutex>
15 #ifdef NDEBUG // assert needs to work even in release mode
20 using std::chrono::steady_clock
;
22 class BenchHarnessBase::ThreadCache final
25 std::vector
<std::thread
> threads
;
26 std::shared_mutex state_lock
;
27 std::unique_lock
<std::shared_mutex
> locked_state
;
28 std::condition_variable_any cond_var
;
29 struct UnlockGuard final
31 std::shared_mutex
&state_lock
;
32 UnlockGuard(std::shared_mutex
&state_lock
) : state_lock(state_lock
)
43 std::function
<void()> fn
;
45 struct ThreadState final
47 std::unique_ptr
<Task
> task
;
50 std::vector
<std::shared_ptr
<ThreadState
>> states
;
51 bool shutting_down
= false;
52 std::atomic_size_t tasks_left_to_drain
= 0;
55 auto thread_state
= std::make_shared
<ThreadState
>();
56 states
.push_back(thread_state
);
57 threads
.push_back(std::thread([this, thread_state
]() {
58 auto shared_lock
= std::shared_lock(state_lock
);
61 auto lock
= std::unique_lock(thread_state
->mutex
);
62 auto task
= std::move(thread_state
->task
);
68 tasks_left_to_drain
--;
69 cond_var
.notify_all();
73 if (this->shutting_down
)
76 cond_var
.wait(shared_lock
);
84 locked_state
= std::unique_lock(state_lock
);
86 ThreadCache(const ThreadCache
&) = delete;
87 ThreadCache
&operator=(const ThreadCache
&) = delete;
91 cond_var
.notify_all();
92 locked_state
.unlock();
93 for (auto &thread
: threads
)
98 static std::shared_ptr
<ThreadCache
> get()
100 // weak so it's destroyed before returning from main()
101 static std::weak_ptr
<ThreadCache
> static_thread_cache
;
103 std::shared_ptr
<ThreadCache
> thread_cache
= static_thread_cache
.lock();
106 thread_cache
= std::make_shared
<ThreadCache
>();
107 static_thread_cache
= thread_cache
;
111 static std::shared_ptr
<ThreadCache
> get(BenchHarnessBase
&bhb
,
112 std::uint32_t thread_count
)
114 std::shared_ptr
<ThreadCache
> thread_cache
= get();
115 bhb
.thread_cache
= thread_cache
;
116 while (thread_cache
->threads
.size() < thread_count
)
117 thread_cache
->add_thread();
122 while (tasks_left_to_drain
> 0)
124 // unlocks state_lock, allowing all threads to proceed
126 cond_var
.wait(locked_state
);
129 template <typename Fn
> void schedule_on(std::uint32_t thread_num
, Fn fn
)
131 auto lock
= std::unique_lock(states
[thread_num
]->mutex
);
132 assert(!states
[thread_num
]->task
);
133 tasks_left_to_drain
++;
134 states
[thread_num
]->task
= std::make_unique
<Task
>(Task
{.fn
= fn
});
135 cond_var
.notify_all();
139 struct WriteDuration final
141 std::chrono::duration
<double> dur
;
142 friend std::ostream
&operator<<(std::ostream
&os
,
143 const WriteDuration
&wdur
)
145 double dur
= wdur
.dur
.count();
146 if (!std::isfinite(dur
) || std::fabs(dur
) > 0.1)
150 else if (std::fabs(dur
) > 0.1e-3)
152 os
<< dur
* 1e3
<< " ms";
154 else if (std::fabs(dur
) > 0.1e-6)
156 os
<< dur
* 1e6
<< " us";
158 else if (std::fabs(dur
) > 0.1e-9)
160 os
<< dur
* 1e9
<< " ns";
162 else if (std::fabs(dur
) > 0.1e-12)
164 os
<< dur
* 1e12
<< " ps";
174 void BenchHarnessBase::base_run(
176 void (*fn
)(BenchHarnessBase
*bench_harness_base
,
177 std::uint64_t iteration_count
, std::uint32_t thread_num
))
180 std::uint32_t thread_count
=
181 config
.thread_count
.value_or(std::thread::hardware_concurrency());
183 thread_count
== 0 || (thread_count
== 1 && !config
.thread_count
);
189 std::vector
<steady_clock::duration
> elapsed(thread_count
);
190 auto run_base
= [&](std::uint64_t iteration_count
,
191 std::uint32_t thread_num
) {
192 auto start_time
= steady_clock::now();
193 fn(this, iteration_count
, thread_num
);
194 auto end_time
= steady_clock::now();
195 elapsed
[thread_num
] = end_time
- start_time
;
197 auto run
= [&](std::uint64_t iteration_count
) {
200 return run_base(iteration_count
, 0);
202 auto thread_cache
= ThreadCache::get(*this, thread_count
);
203 for (std::uint32_t thread_num
= 0; thread_num
< thread_count
;
206 thread_cache
->schedule_on(
207 thread_num
, [&run_base
, iteration_count
, thread_num
]() {
208 run_base(iteration_count
, thread_num
);
211 thread_cache
->drain();
213 std::uint64_t iteration_count
= 1;
214 if (config
.iteration_count
)
216 iteration_count
= *config
.iteration_count
;
217 run(iteration_count
);
223 run(iteration_count
);
224 steady_clock::duration total_elapsed
{};
225 for (auto i
: elapsed
)
229 auto target_average_elapsed
= std::chrono::milliseconds(500);
230 if (total_elapsed
> thread_count
* target_average_elapsed
)
234 iteration_count
<<= 1;
237 steady_clock::duration total_elapsed
{};
238 for (std::uint32_t thread_num
= 0; thread_num
< thread_count
; thread_num
++)
240 total_elapsed
+= elapsed
[thread_num
];
241 if (thread_count
> 1)
243 auto dur
= std::chrono::duration
<double>(elapsed
[thread_num
]);
244 std::cout
<< "Thread #" << thread_num
<< " took "
245 << WriteDuration
{dur
} << " for " << iteration_count
247 << WriteDuration
{dur
/ iteration_count
} << "/iter.\n";
250 auto total
= std::chrono::duration
<double>(total_elapsed
);
251 std::cout
<< "Average elapsed time: "
252 << WriteDuration
{total
/ thread_count
} << " for "
253 << iteration_count
<< " iterations -- "
254 << WriteDuration
{total
/ thread_count
/ iteration_count
}
259 std::shared_ptr
<void> BenchHarnessBase::get_thread_cache()
261 return ThreadCache::get();