carb/cpp/Latch.h
File members: carb/cpp/Latch.h
// Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once
#include "Atomic.h"
#include <algorithm>
#include <thread>
namespace carb
{
namespace cpp
{

// Handle case where Windows.h may have defined 'max'
#pragma push_macro("max")
#undef max

//! A single-use downward counter that threads can block on until it reaches zero.
//!
//! The interface is modeled on C++20 `std::latch`. The counter is stored as a 32-bit signed integer, so the largest
//! supported initial count is `max()` (`INT_MAX`).
//!
//! In addition to the counter, the latch tracks how many threads are currently registered as waiters inside `wait()`
//! or `arrive_and_wait()`. The destructor spins until that number drops to zero. This keeps the memory alive while a
//! woken waiter is still touching it, for example when a stack-allocated latch is destroyed right after it completes.
class latch
{
    CARB_PREVENT_COPY_AND_MOVE(latch);

public:
    //! Returns the maximum value the internal counter can hold.
    //! @returns `INT_MAX`, because the counter is a 32-bit signed integer.
    CARB_NODISCARD static constexpr ptrdiff_t max() noexcept
    {
        return ptrdiff_t(INT_MAX);
    }

    //! Constructs the latch with an initial count.
    //! @param expected The initial counter value. It must be non-negative (asserted). Values greater than `max()` are
    //!   clamped to `max()`.
    constexpr explicit latch(ptrdiff_t expected) noexcept : m_counter(int32_t(::carb_min(max(), expected)))
    {
        CARB_ASSERT(expected >= 0, "latch::latch: Precondition violation: expected >= 0");
    }

    //! Destroys the latch.
    //!
    //! Spins, calling `std::this_thread::yield()`, until no thread is registered as a waiter. Threads that were woken
    //! but have not yet returned from `wait()`/`arrive_and_wait()` thereby finish touching the latch before its
    //! storage goes away.
    ~latch() noexcept
    {
        // Wait until we have no waiters
        while (m_waiters.load(std::memory_order_acquire) != 0)
            std::this_thread::yield();
    }

    //! Atomically decrements the counter by @p update without blocking.
    //!
    //! If the counter reaches zero, all threads blocked on the latch are woken.
    //! @param update The amount to decrement by. It must be non-negative (asserted) and must not exceed the current
    //!   counter value. Underflow is detected after the fact via `CARB_CHECK`.
    void count_down(ptrdiff_t update = 1) noexcept
    {
        CARB_ASSERT(update >= 0, "latch::count_down: Precondition violation: update >= 0");
        // `fetch_sub` returns value before operation
        auto count = m_counter.fetch_sub(int32_t(update), std::memory_order_release) - int32_t(update);
        if (count == 0)
        {
            // Wake all waiters
            m_counter.notify_all();
        }
        else
        {
            CARB_CHECK(count >= 0, "latch::count_down: Precondition violation: update <= counter");
        }
    }

    //! Tests whether the latch has completed, without blocking.
    //! @returns `true` if the counter has reached zero. As with `std::latch::try_wait()`, callers should tolerate a
    //!   spurious `false`.
    CARB_NODISCARD bool try_wait() const noexcept
    {
        return m_counter.load(std::memory_order_acquire) == 0;
    }

    //! Blocks the calling thread until the counter reaches zero.
    //!
    //! Returns immediately, without registering as a waiter, if the counter is already zero.
    //! NOTE(review): the waiter registration happens after the counter is loaded. A thread that observes a non-zero
    //! count and is preempted before registering is therefore not yet protected by the destructor's waiter spin.
    //! Closing that window would also require sequentially-consistent ordering on the counter's decrements.
    void wait() const noexcept
    {
        int32_t count = m_counter.load(std::memory_order_acquire);
        if (count != 0)
        {
            // Register as a waiter
            m_waiters.fetch_add(1, std::memory_order_relaxed);
            _wait(count);
        }
    }

    //! Decrements the counter by @p update and then blocks until the counter reaches zero.
    //!
    //! If this call brings the counter to zero, all blocked threads are woken and the call returns without blocking.
    //! @param update The amount to decrement by. It must be non-negative (asserted) and must not exceed the current
    //!   counter value (checked via `CARB_CHECK`).
    void arrive_and_wait(ptrdiff_t update = 1) noexcept
    {
        CARB_ASSERT(update >= 0, "latch::arrive_and_wait: Precondition: update >= 0");

        // Register as a waiter *before* arriving. Otherwise another thread could perform the final count-down,
        // observe completion and destroy the latch (seeing no registered waiters) between our decrement and our
        // registration, leaving us to touch freed memory. The relaxed increment is sequenced before the release half
        // of the acq_rel `fetch_sub` below. Any thread that later observes the counter at zero with acquire semantics
        // therefore also observes this registration, and the destructor will wait for us.
        m_waiters.fetch_add(1, std::memory_order_relaxed);

        // `fetch_sub` returns value before operation
        auto count = m_counter.fetch_sub(int32_t(update), std::memory_order_acq_rel) - int32_t(update);
        if (count == 0)
        {
            // We're the last, wake all waiters
            m_counter.notify_all();
            // Deregister only after our final access to the latch.
            m_waiters.fetch_sub(1, std::memory_order_release);
        }
        else
        {
            CARB_CHECK(count >= 0, "latch::arrive_and_wait: Precondition: update <= counter");
            // Already registered above; `_wait` deregisters when the counter reaches zero.
            _wait(count);
        }
    }

private:
    // Remaining count; the latch is complete when this reaches zero.
    atomic_int32_t m_counter;
    // Number of threads registered in `wait()`/`arrive_and_wait()` that may still access the latch.
    // `mutable` because `wait()` is const.
    mutable std::atomic_uint32_t m_waiters{ 0 };

    // Blocks until the counter is zero, then deregisters the calling thread as a waiter.
    // The caller must already have incremented `m_waiters` and passes the last non-zero counter value it observed.
    CARB_ALWAYS_INLINE void _wait(int32_t count) const noexcept
    {
        CARB_ASSERT(count != 0);
        do
        {
            m_counter.wait(count, std::memory_order_relaxed);
            // Re-read with acquire so the completing thread's writes are visible once we see zero.
            count = m_counter.load(std::memory_order_acquire);
        } while (count != 0);
        // Done waiting
        m_waiters.fetch_sub(1, std::memory_order_release);
    }
};

#pragma pop_macro("max")

} // namespace cpp
} // namespace carb