// carb/cpp/Semaphore.h
// File members: carb/cpp/Semaphore.h
// Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once
#include "Atomic.h"
#include "../thread/Futex.h"
#include <algorithm>
#include <thread>
namespace carb
{
namespace cpp
{
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail
{
// Default and upper bound for counting_semaphore's least_max_value template parameter on this platform.
#if CARB_PLATFORM_WINDOWS
constexpr ptrdiff_t kSemaphoreValueMax = static_cast<ptrdiff_t>(LONG_MAX);
#else
constexpr ptrdiff_t kSemaphoreValueMax = static_cast<ptrdiff_t>(INT_MAX);
#endif
} // namespace detail
#endif
// Handle case where Windows.h may have defined 'max'
#pragma push_macro("max")
#undef max
template <ptrdiff_t least_max_value = detail::kSemaphoreValueMax>
class CARB_VIZ counting_semaphore
{
CARB_PREVENT_COPY_AND_MOVE(counting_semaphore);
public:
constexpr explicit counting_semaphore(ptrdiff_t desired) noexcept
: m_data(::carb_min(::carb_max(ptrdiff_t(0), desired), least_max_value))
{
static_assert(least_max_value >= 1, "semaphore needs a count of at least 1");
static_assert(least_max_value <= detail::kSemaphoreValueMax, "semaphore count too high");
}
~counting_semaphore() noexcept
{
#if CARB_PLATFORM_LINUX
// Make sure we don't have any waiters when we are destroyed
CARB_CHECK((m_data.load(std::memory_order_acquire) >> kWaitersShift) == 0, "Semaphore destroyed with waiters");
#endif
}
static constexpr ptrdiff_t max() noexcept
{
return least_max_value;
}
void release(ptrdiff_t update = 1) noexcept
{
CARB_ASSERT(update >= 0);
uint64_t d = m_data.load(std::memory_order_relaxed), u;
for (;;)
{
// The standard is somewhat unclear here. Preconditions are that update >= 0 is true and update <= max() -
// counter is true. And it throws system_error when an exception is required. So I supposed that it's likely
// that violating the precondition would cause a system_error exception which doesn't completely make sense
// (I would think runtime_error would make more sense). However, throwing at all is inconvenient, as is
// asserting/crashing/etc. Therefore, we clamp the update value here.
u = ::carb_min(update, max() - ptrdiff_t(d & kValueMask));
if (CARB_LIKELY(m_data.compare_exchange_weak(d, d + u, std::memory_order_release, std::memory_order_relaxed)))
break;
}
// At this point, the Semaphore could be destroyed by another thread. Therefore, we shouldn't access any other
// members (taking the address of m_data below is okay because that would not actually read any memory that
// may be destroyed)
// waiters with a value have been notified already by whatever thread added the value. Only wake threads that
// haven't been woken yet.
ptrdiff_t waiters = ptrdiff_t(d >> kWaitersShift);
ptrdiff_t value = ptrdiff_t(d & kValueMask);
ptrdiff_t wake = ::carb_min(ptrdiff_t(u), waiters - value);
if (wake > 0)
{
// cpp::atomic only has notify_one() and notify_all(). Call the futex system directly to wake N.
thread::futex::notify(m_data, unsigned(size_t(wake)), unsigned(size_t(waiters)));
}
}
void acquire() noexcept
{
if (CARB_LIKELY(fast_acquire(false)))
return;
// Register as a waiter
uint64_t d =
m_data.fetch_add(uint64_t(1) << kWaitersShift, std::memory_order_relaxed) + (uint64_t(1) << kWaitersShift);
for (;;)
{
if ((d & kValueMask) == 0)
{
// Need to wait
m_data.wait(d, std::memory_order_relaxed);
// Reload
d = m_data.load(std::memory_order_relaxed);
}
else
{
// Try to unregister as a waiter and grab a token at the same time
if (CARB_LIKELY(m_data.compare_exchange_weak(d, d - 1 - (uint64_t(1) << kWaitersShift),
std::memory_order_acquire, std::memory_order_relaxed)))
return;
}
}
}
bool try_acquire() noexcept
{
return fast_acquire(true);
}
template <class Rep, class Period>
bool try_acquire_for(const std::chrono::duration<Rep, Period>& duration) noexcept
{
if (CARB_LIKELY(fast_acquire(false)))
return true;
if (duration.count() <= 0)
return false;
// Register as a waiter
uint64_t d =
m_data.fetch_add(uint64_t(1) << kWaitersShift, std::memory_order_relaxed) + (uint64_t(1) << kWaitersShift);
while ((d & kValueMask) != 0)
{
// Try to unregister as a waiter and grab a token at the same time
if (CARB_LIKELY(m_data.compare_exchange_weak(
d, d - 1 - (uint64_t(1) << kWaitersShift), std::memory_order_acquire, std::memory_order_relaxed)))
return true;
}
// Now we need to wait, but do it with absolute time so that we properly handle spurious futex wakeups
auto time_point = std::chrono::steady_clock::now() + thread::detail::clampDuration(duration);
for (;;)
{
if (!m_data.wait_until(d, time_point, std::memory_order_relaxed))
{
// Timed out. Unregister as a waiter
m_data.fetch_sub(uint64_t(1) << kWaitersShift, std::memory_order_relaxed);
return false;
}
// Reload after wait
d = m_data.load(std::memory_order_relaxed);
if ((d & kValueMask) != 0)
{
// Try to unreference as a waiter and grab a token at the same time
if (CARB_LIKELY(m_data.compare_exchange_weak(d, d - 1 - (uint64_t(1) << kWaitersShift),
std::memory_order_acquire, std::memory_order_relaxed)))
return true;
}
}
}
template <class Clock, class Duration>
bool try_acquire_until(const std::chrono::time_point<Clock, Duration>& time_point) noexcept
{
if (CARB_LIKELY(fast_acquire(false)))
return true;
// Register as a waiter
uint64_t d =
m_data.fetch_add(uint64_t(1) << kWaitersShift, std::memory_order_relaxed) + (uint64_t(1) << kWaitersShift);
for (;;)
{
if ((d & kValueMask) == 0)
{
// Need to wait
if (!m_data.wait_until(d, time_point, std::memory_order_relaxed))
{
// Timed out. Unregister as a waiter
m_data.fetch_sub(uint64_t(1) << kWaitersShift, std::memory_order_relaxed);
return false;
}
// Reload after wait
d = m_data.load(std::memory_order_relaxed);
}
else
{
// Try to unregister as a waiter and grab a token at the same time
if (CARB_LIKELY(m_data.compare_exchange_weak(d, d - 1 - (uint64_t(1) << kWaitersShift),
std::memory_order_acquire, std::memory_order_relaxed)))
return true;
}
}
}
#ifndef DOXYGEN_SHOULD_SKIP_THIS
protected:
// The 32 most significant bits are the waiters; the lower 32 bits is the value of the semaphore
CARB_VIZ cpp::atomic_uint64_t m_data;
constexpr static int kWaitersShift = 32;
constexpr static unsigned kValueMask = 0xffffffff;
CARB_ALWAYS_INLINE bool fast_acquire(bool needResolution) noexcept
{
uint64_t d = m_data.load(needResolution ? std::memory_order_acquire : std::memory_order_relaxed);
for (;;)
{
if (uint32_t(d & kValueMask) == 0)
return false;
if (CARB_LIKELY(m_data.compare_exchange_weak(d, d - 1, std::memory_order_acquire, std::memory_order_relaxed)))
return true;
if (!needResolution)
return false;
}
}
#endif
};
#ifndef DOXYGEN_SHOULD_SKIP_THIS
template <>
class CARB_VIZ counting_semaphore<1>
{
    CARB_PREVENT_COPY_AND_MOVE(counting_semaphore);

public:
    /**
     * Constructs the binary semaphore.
     * @param desired Initial count: any value greater than zero yields 1; anything else yields 0.
     */
    constexpr explicit counting_semaphore(ptrdiff_t desired) noexcept : m_val(uint8_t(desired > 0 ? 1 : 0))
    {
    }

    /**
     * The largest count this specialization can hold.
     * @returns 1
     */
    static constexpr ptrdiff_t max() noexcept
    {
        return 1;
    }

    /**
     * Makes the token available and, if it was not already available, wakes one waiter.
     * @param update Values of zero or less are ignored; CARB_ASSERT requires positive values to be exactly 1.
     */
    void release(ptrdiff_t update = 1) noexcept
    {
        if (CARB_LIKELY(update > 0))
        {
            CARB_ASSERT(update == 1); // precondition failure
            // Waiters only block while m_val is 0, so a wake is needed only when the previous value was 0.
            if (m_val.exchange(1, std::memory_order_release) == 0)
                m_val.notify_one();
        }
    }

    /**
     * Takes the token, blocking until it is available.
     */
    void acquire() noexcept
    {
        // try_acquire() clears m_val; if the token was not there, sleep while m_val remains 0 and try again.
        while (!try_acquire())
            m_val.wait(0, std::memory_order_relaxed);
    }

    /**
     * Takes the token if it is available, without blocking.
     * @returns `true` if the token was taken; `false` otherwise.
     */
    bool try_acquire() noexcept
    {
        const uint8_t prev = m_val.exchange(0, std::memory_order_acquire);
        CARB_ASSERT(prev <= 1); // m_val can only be 0 or 1
        return prev == 1;
    }

    /**
     * Takes the token, blocking for up to @p duration while it is unavailable.
     * @param duration The maximum time to wait (measured on std::chrono::steady_clock).
     * @returns `true` if the token was taken; `false` on timeout.
     */
    template <class Rep, class Period>
    bool try_acquire_for(const std::chrono::duration<Rep, Period>& duration) noexcept
    {
        const auto deadline = std::chrono::steady_clock::now() + thread::detail::clampDuration(duration);
        return try_acquire_until(deadline);
    }

    /**
     * Takes the token, blocking until @p time_point while it is unavailable.
     * @param time_point The point in time after which to stop waiting.
     * @returns `true` if the token was taken; `false` on timeout.
     */
    template <class Clock, class Duration>
    bool try_acquire_until(const std::chrono::time_point<Clock, Duration>& time_point) noexcept
    {
        while (!try_acquire())
        {
            // Sleep while the token is absent; a false result means time_point has passed.
            if (!m_val.wait_until(0, time_point, std::memory_order_relaxed))
                return false;
        }
        return true;
    }

protected:
    //! 1 when the token is available, 0 otherwise.
    CARB_VIZ cpp::atomic_uint8_t m_val;
};
#endif
//! A counting_semaphore whose maximum count is 1.
using binary_semaphore = counting_semaphore<1>;
#pragma pop_macro("max")
} // namespace cpp
} // namespace carb