Semaphore.h

Fully qualified name: carb/cpp/Semaphore.h

File members: carb/cpp/Semaphore.h

// Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "Atomic.h"
#include "../thread/Futex.h"

#include <atomic>
#include <chrono>

namespace carb
{
namespace cpp
{
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail
{

// Hard upper limit for the `least_max_value` template argument of counting_semaphore: INT_MAX on every
// platform except Windows, where LONG_MAX is used instead.
#    if !CARB_PLATFORM_WINDOWS
constexpr ptrdiff_t kSemaphoreValueMax = INT_MAX;
#    else
constexpr ptrdiff_t kSemaphoreValueMax = LONG_MAX;
#    endif

} // namespace detail
#endif

// Handle case where Windows.h may have defined 'max'
#pragma push_macro("max")
#undef max

//! A counting semaphore following the contract cited in its diagnostics (N4950 [thread.sema.cnt]).
//!
//! The whole state lives in a single 64-bit atomic word: the low 32 bits hold the semaphore counter and the high 32
//! bits hold the number of threads currently registered as waiters. Keeping both in one word lets a waiter
//! unregister and take a token with one compare-exchange.
//!
//! The type is neither copyable nor movable.
//!
//! @tparam least_max_value The largest counter value this semaphore supports. Must be in the range
//!     `[0, detail::kSemaphoreValueMax]`; this is enforced with `static_assert` in the constructor.
template <ptrdiff_t least_max_value = detail::kSemaphoreValueMax>
class CARB_VIZ counting_semaphore
{
    CARB_PREVENT_COPY_AND_MOVE(counting_semaphore);

public:
    //! Constructs the semaphore with an initial counter value.
    //!
    //! The stored value is @p desired clamped into `[0, least_max_value]`; an out-of-range @p desired additionally
    //! trips `CARB_ASSERT`.
    //! @param desired The initial counter value. Precondition: `0 <= desired <= least_max_value`.
    constexpr explicit counting_semaphore(ptrdiff_t desired) noexcept
        : m_data((uint64_t)::carb_min(::carb_max(ptrdiff_t(0), desired), least_max_value))
    {
        CARB_ASSERT(desired >= 0 && desired <= least_max_value, "Precondition violation: N4950 [thread.sema.cnt]/5");
        static_assert(least_max_value >= 0, "semaphore with negative count is ill-formed");
        static_assert(least_max_value <= detail::kSemaphoreValueMax,
                      "semaphore count too high (must not exceed detail::kSemaphoreValueMax)");
    }

    //! Destructor.
    //!
    //! On Linux builds this checks (via `CARB_CHECK`) that the waiter count in the upper bits is zero, i.e. that no
    //! thread is still registered as waiting on this semaphore.
    ~counting_semaphore() noexcept
    {
#if CARB_PLATFORM_LINUX
        // Make sure we don't have any waiters when we are destroyed
        CARB_CHECK((m_data.load(std::memory_order_acquire) >> kWaitersShift) == 0, "Semaphore destroyed with waiters");
#endif
    }

    //! Returns the maximum counter value supported by this semaphore type.
    //! @returns `least_max_value`.
    static constexpr ptrdiff_t max() noexcept
    {
        return least_max_value;
    }

    //! Adds @p update to the counter and wakes waiting threads.
    //!
    //! The addition uses `memory_order_release`. If there are registered waiters, either all of them are notified
    //! (when there are no more waiters than @p update) or at most @p update of them are woken through the futex
    //! interface.
    //!
    //! @param update The amount to add to the counter. Preconditions (both enforced with `CARB_FATAL_UNLESS`):
    //!     `0 <= update <= least_max_value`, and `update <= max() - counter`.
    void release(ptrdiff_t update = 1) noexcept
    {
        CARB_FATAL_UNLESS(update >= 0 && update <= least_max_value, "Precondition violation: N4950 [thread.sema.cnt]/5");

        // `d` is the state *before* the addition, so the overflow precondition below is evaluated against the
        // previous counter value after the add has already been applied.
        uint64_t d = m_data.fetch_add(update, std::memory_order_release);
        const ptrdiff_t value(int32_t(d & kValueMask));
        CARB_FATAL_UNLESS(value >= 0 && update <= (least_max_value - value),
                          "Precondition violation: update <= max() - counter (N4950 [thread.sema.cnt]/5)");

        // At this point, the Semaphore could be destroyed by another thread. Therefore, we shouldn't access any other
        // members (taking the reference of m_data below is okay because that would not actually read any memory that
        // may be destroyed)

        // Always issue a wake if there are waiters
        uint64_t const waiters = d >> kWaitersShift;
        if (waiters == 0)
        {
            // nothing to wake
        }
        else if (waiters <= size_t(update))
        {
            // Every registered waiter can be handed a token, so wake them all.
            m_data.notify_all();
        }
        else
        {
            // Wake at most `update`. Since cpp::atomic does not have a way to notify `n`, we use futex directly.
            thread::futex::notify(m_data, unsigned(size_t(update)));
        }
    }

    //! Takes one token from the semaphore, blocking until one is available.
    //!
    //! A single lock-free attempt is made first. If that does not succeed the calling thread registers itself as a
    //! waiter and then loops: whenever the counter is non-zero it tries to decrement the counter and the waiter
    //! count together; otherwise it waits on the atomic word.
    void acquire() noexcept
    {
        CARB_LIKELY_IF(fast_acquire(false))
        {
            return;
        }

        // Register as a waiter
        uint64_t d = m_data.fetch_add(kOneWaiter, std::memory_order_relaxed) + kOneWaiter;
        for (;;)
        {
            if ((d & kValueMask) != 0)
            {
                // Try to unregister as a waiter and grab a token at the same time
                // NOTE(review): relies on compare_exchange_weak refreshing `d` on failure, as std::atomic does.
                CARB_LIKELY_IF(m_data.compare_exchange_weak(
                    d, d - 1 - kOneWaiter, std::memory_order_acquire, std::memory_order_relaxed))
                {
                    return;
                }
            }
            else
            {
                // Need to wait
                m_data.wait(d, std::memory_order_relaxed);

                // Reload
                d = m_data.load(std::memory_order_relaxed);
            }
        }
    }

    //! Attempts to take one token without blocking.
    //!
    //! Unlike the internal single-shot fast path, this keeps retrying after a failed compare-exchange and only
    //! reports failure once the counter has been observed to be zero.
    //! @returns `true` if a token was taken; `false` if the counter was zero.
    bool try_acquire() noexcept
    {
        return fast_acquire(true);
    }

    //! Attempts to take one token, waiting up to a relative timeout.
    //!
    //! A non-positive @p duration degenerates to try_acquire(). Otherwise the relative duration is converted once to
    //! an absolute `std::chrono::steady_clock` deadline before waiting.
    //! @param duration The maximum amount of time to wait.
    //! @returns `true` if a token was taken; `false` if the wait timed out.
    template <class Rep, class Period>
    bool try_acquire_for(const std::chrono::duration<Rep, Period>& duration) noexcept
    {
        if (duration.count() <= 0)
            return try_acquire();

        CARB_LIKELY_IF(fast_acquire(false))
        {
            return true;
        }

        // Register as a waiter
        uint64_t d = m_data.fetch_add(kOneWaiter, std::memory_order_relaxed) + kOneWaiter;
        // Tokens may have appeared since the fast path; try for one before paying for the deadline computation.
        while ((d & kValueMask) != 0)
        {
            // Try to unregister as a waiter and grab a token at the same time
            CARB_LIKELY_IF(m_data.compare_exchange_weak(
                d, d - 1 - kOneWaiter, std::memory_order_acquire, std::memory_order_relaxed))
            {
                return true;
            }
        }

        // Now we need to wait, but do it with absolute time so that we properly handle spurious futex wakeups
        auto const deadline = cpp::detail::absTime<std::chrono::steady_clock>(duration);
        for (;;)
        {
            if ((d & kValueMask) != 0)
            {
                // Try to unregister as a waiter and grab a token at the same time
                CARB_LIKELY_IF(m_data.compare_exchange_weak(
                    d, d - 1 - kOneWaiter, std::memory_order_acquire, std::memory_order_relaxed))
                {
                    return true;
                }
            }
            else
            {
                // Need to wait
                CARB_UNLIKELY_IF(!m_data.wait_until(d, deadline, std::memory_order_relaxed))
                {
                    // Timed out. Unregister as a waiter
                    m_data.fetch_sub(kOneWaiter, std::memory_order_relaxed);
                    return false;
                }

                // Reload after wait
                d = m_data.load(std::memory_order_relaxed);
            }
        }
    }

    //! Attempts to take one token, waiting until an absolute point in time.
    //!
    //! @param time_point The deadline after which the attempt is abandoned.
    //! @returns `true` if a token was taken; `false` if the wait timed out.
    template <class Clock, class Duration>
    bool try_acquire_until(const std::chrono::time_point<Clock, Duration>& time_point) noexcept
    {
        CARB_LIKELY_IF(fast_acquire(false))
        {
            return true;
        }

        // Register as a waiter
        uint64_t d = m_data.fetch_add(kOneWaiter, std::memory_order_relaxed) + kOneWaiter;
        for (;;)
        {
            if ((d & kValueMask) != 0)
            {
                // Try to unregister as a waiter and grab a token at the same time
                CARB_LIKELY_IF(m_data.compare_exchange_weak(
                    d, d - 1 - kOneWaiter, std::memory_order_acquire, std::memory_order_relaxed))
                {
                    return true;
                }
            }
            else
            {
                // Need to wait
                CARB_UNLIKELY_IF(!m_data.wait_until(d, time_point, std::memory_order_relaxed))
                {
                    // Timed out. Unregister as a waiter
                    m_data.fetch_sub(kOneWaiter, std::memory_order_relaxed);
                    return false;
                }

                // Reload after wait
                d = m_data.load(std::memory_order_relaxed);
            }
        }
    }

private:
    // The 32 most significant bits are the waiters; the lower 32 bits are the value of the semaphore
    CARB_VIZ cpp::atomic_uint64_t m_data;
    // Bit position of the waiter count within m_data.
    constexpr static int kWaitersShift = 32;
    // The increment that adds exactly one waiter to m_data.
    constexpr static uint64_t kOneWaiter = uint64_t(1) << kWaitersShift;
    // Mask selecting the counter (low 32 bits) from m_data.
    constexpr static unsigned kValueMask = 0xffffffff;

    // Lock-free attempt to decrement the counter without touching the waiter count.
    //
    // needResolution == true:  the initial load uses memory_order_acquire and a failed compare-exchange is retried
    //                          until a token is taken or the counter is observed to be zero (used by try_acquire()).
    // needResolution == false: the initial load is relaxed and the function gives up after the first failed
    //                          compare-exchange, leaving the caller to fall back to its slow path.
    //
    // Returns true if a token was taken, false otherwise.
    CARB_ALWAYS_INLINE bool fast_acquire(bool needResolution) noexcept
    {
        // According to the standard, release() strongly happens before invocations of try_acquire that observe the
        // result of the effects, so synchronize-with release() with memory_order_acquire here.
        uint64_t d = m_data.load(needResolution ? std::memory_order_acquire : std::memory_order_relaxed);
        for (;;)
        {
            // No tokens available
            if (!(d & kValueMask))
                return false;

            CARB_LIKELY_IF(m_data.compare_exchange_weak(d, d - 1, std::memory_order_acquire, std::memory_order_relaxed))
            {
                return true;
            }

            // Single-shot mode: do not spin on contention
            if (!needResolution)
                return false;
        }
    }
};

#ifndef DOXYGEN_SHOULD_SKIP_THIS
// Specialization for a maximum count of one. The entire state is a single byte that is either 0 (no token) or
// 1 (token available), so no separate waiter bookkeeping is kept.
template <>
class CARB_VIZ counting_semaphore<1>
{
    CARB_PREVENT_COPY_AND_MOVE(counting_semaphore);

public:
    // Largest counter value this specialization supports; always 1.
    static constexpr ptrdiff_t max() noexcept
    {
        return 1;
    }

    // Constructs the semaphore holding `desired` tokens. The stored value is `desired` clamped into [0, max()];
    // anything other than 0 or 1 also trips CARB_ASSERT.
    constexpr explicit counting_semaphore(ptrdiff_t desired) noexcept
        : m_val(uint8_t(size_t(::carb_min(::carb_max(ptrdiff_t(0), desired), max()))))
    {
        CARB_ASSERT(
            desired == 0 || desired == 1, "Precondition violation: desired must be 0 or 1 (N4950 [thread.sema.cnt]/5)");
    }

    // Makes the token available and wakes one waiter. An `update` of 0 is a no-op; any value other than 0 or 1 is
    // a fatal precondition violation.
    void release(ptrdiff_t update = 1) noexcept
    {
        if (update != 0)
        {
            CARB_FATAL_UNLESS(update == 1, "Precondition violation: update must be 0 or 1 (N4950 [thread.sema.cnt]/5)");

            // A plain release-store followed by notify_one() is used rather than an exchange(): notify_one() is
            // very fast when nothing is waiting, so the costlier read-modify-write buys nothing here.
            m_val.store(1, std::memory_order_release);
            m_val.notify_one();
        }
    }

    // Takes the token, blocking while the stored value is 0.
    void acquire() noexcept
    {
        for (;;)
        {
            if (try_acquire())
            {
                return;
            }

            m_val.wait(0, std::memory_order_relaxed);
        }
    }

    // Attempts to take the token without blocking by swapping in 0 with acquire ordering.
    // Returns true only if the previous value was 1.
    bool try_acquire() noexcept
    {
        const uint8_t old = m_val.exchange(0, std::memory_order_acquire);
        CARB_ASSERT(old <= 1); // m_val can only be 0 or 1 (possible precondition violation)
        return old == 1;
    }

    // Attempts to take the token, waiting at most `duration`.
    // Returns true if the token was taken, false if the wait timed out.
    template <class Rep, class Period>
    bool try_acquire_for(const std::chrono::duration<Rep, Period>& duration) noexcept
    {
        // Try once up front so the deadline is only computed when we actually have to wait
        CARB_LIKELY_IF(try_acquire())
        {
            return true;
        }

        auto const deadline = cpp::detail::absTime<std::chrono::steady_clock>(duration);
        for (;;)
        {
            CARB_UNLIKELY_IF(!m_val.wait_until(0, deadline, std::memory_order_relaxed))
            {
                return false;
            }

            if (try_acquire())
            {
                return true;
            }
        }
    }

    // Attempts to take the token, giving up once `time_point` has been reached.
    // Returns true if the token was taken, false if the wait timed out.
    template <class Clock, class Duration>
    bool try_acquire_until(const std::chrono::time_point<Clock, Duration>& time_point) noexcept
    {
        for (;;)
        {
            if (try_acquire())
            {
                return true;
            }

            CARB_UNLIKELY_IF(!m_val.wait_until(0, time_point, std::memory_order_relaxed))
            {
                return false;
            }
        }
    }

private:
    // 1 when the token is available, 0 when it has been taken.
    CARB_VIZ cpp::atomic_uint8_t m_val;
};
#endif

//! Alias for a semaphore whose maximum count is one, i.e. `counting_semaphore<1>`.
using binary_semaphore = counting_semaphore<1>;

#pragma pop_macro("max")

} // namespace cpp
} // namespace carb