// carb/cpp/Latch.h
// File members: carb/cpp/Latch.h
// Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once
#include "Atomic.h"
#include <algorithm>
#include <thread>
namespace carb
{
namespace cpp
{
// Handle case where Windows.h may have defined 'max'
#pragma push_macro("max")
#undef max
class latch
{
CARB_PREVENT_COPY_AND_MOVE(latch);
public:
static constexpr ptrdiff_t max() noexcept
{
return ptrdiff_t(UINT_MAX);
}
constexpr explicit latch(ptrdiff_t expected) noexcept : m_counter(uint32_t(::carb_min(max(), expected)))
{
CARB_ASSERT(expected >= 0 && expected <= max());
}
~latch() noexcept
{
// Wait until we have no waiters
while (m_waiters.load(std::memory_order_acquire) != 0)
std::this_thread::yield();
}
void count_down(ptrdiff_t update = 1) noexcept
{
CARB_ASSERT(update >= 0);
// `fetch_sub` returns value before operation
uint32_t count = m_counter.fetch_sub(uint32_t(update), std::memory_order_release);
CARB_CHECK((count - uint32_t(update)) <= count); // Malformed if we go below zero or overflow
if ((count - uint32_t(update)) == 0)
{
// Wake all waiters
m_counter.notify_all();
}
}
// Returns whether the latch has completed. Allowed to return spuriously false with very low probability.
bool try_wait() const noexcept
{
return m_counter.load(std::memory_order_acquire) == 0;
}
void wait() const noexcept
{
uint32_t count = m_counter.load(std::memory_order_acquire);
if (count != 0)
{
// Register as a waiter
m_waiters.fetch_add(1, std::memory_order_relaxed);
_wait(count);
}
}
void arrive_and_wait(ptrdiff_t update = 1) noexcept
{
uint32_t original = m_counter.load(std::memory_order_acquire);
if (original == uint32_t(update))
{
// We're the last and won't be waiting.
#if CARB_ASSERT_ENABLED
uint32_t updated = m_counter.exchange(0, std::memory_order_release);
CARB_ASSERT(updated == original);
#else
m_counter.store(0, std::memory_order_release);
#endif
// Wake all waiters
m_counter.notify_all();
return;
}
// Speculatively register as a waiter
m_waiters.fetch_add(1, std::memory_order_relaxed);
original = m_counter.fetch_sub(uint32_t(update), std::memory_order_release);
if (CARB_UNLIKELY(original == uint32_t(update)))
{
// Wake all waiters and unregister as a waiter
m_counter.notify_all();
m_waiters.fetch_sub(1, std::memory_order_release);
}
else
{
CARB_CHECK(original >= uint32_t(update)); // Malformed if we underflow
_wait(original - uint32_t(update));
}
}
private:
mutable atomic_uint32_t m_counter;
mutable atomic_uint32_t m_waiters{ 0 };
CARB_ALWAYS_INLINE void _wait(uint32_t count) const noexcept
{
CARB_ASSERT(count != 0);
do
{
m_counter.wait(count, std::memory_order_relaxed);
count = m_counter.load(std::memory_order_acquire);
} while (count != 0);
// Done waiting
m_waiters.fetch_sub(1, std::memory_order_release);
}
};
#pragma pop_macro("max")
} // namespace cpp
} // namespace carb