carb/cpp/Latch.h

File members: carb/cpp/Latch.h

// Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "Atomic.h"

#include <algorithm>
#include <thread>

namespace carb
{
namespace cpp
{

// Handle case where Windows.h may have defined 'max'
#pragma push_macro("max")
#undef max

//! A single-use downward counter used to synchronize threads; an implementation of C++20 `std::latch`.
//!
//! The counter is set once at construction and can only be decremented (via count_down() or arrive_and_wait()).
//! Threads calling wait() or arrive_and_wait() block until the counter reaches zero. There is no way to reset the
//! counter, so once it reaches zero the latch stays complete.
//!
//! The counter is stored as a 32-bit signed integer, which is why max() is `INT_MAX`.
//!
//! The destructor spins until every thread registered in `m_waiters` has finished waiting. This allows the latch to be
//! destroyed once the counter has reached zero, without a registered waiter touching freed memory on its way out.
//!
//! NOTE(review): wait() and arrive_and_wait() register in `m_waiters` only *after* they observe a non-zero counter.
//! Suppose the counter reaches zero and another thread destroys the latch in the gap between that observation and the
//! registration. The destructor does not wait for such a thread, which then accesses destroyed memory. Confirm that
//! callers never destroy a latch while another thread may still be entering wait()/arrive_and_wait().
class latch
{
    CARB_PREVENT_COPY_AND_MOVE(latch);

public:
    //! Returns the maximum counter value supported by this implementation.
    //! @returns `INT_MAX`, because the counter is a 32-bit signed integer.
    CARB_NODISCARD static constexpr ptrdiff_t max() noexcept
    {
        return ptrdiff_t(INT_MAX);
    }

    //! Constructs the latch and initializes its counter.
    //! @param expected Initial counter value. Must be non-negative (asserted). Values greater than max() are clamped to
    //!   max() so they fit in the 32-bit counter.
    constexpr explicit latch(ptrdiff_t expected) noexcept : m_counter(int32_t(::carb_min(max(), expected)))
    {
        CARB_ASSERT(expected >= 0, "latch::latch: Precondition violation: expected >= 0");
    }

    //! Destroys the latch.
    //! Spins, yielding the thread, until no registered waiters remain. This keeps threads that are still returning
    //! from wait() or arrive_and_wait() from touching a destroyed object.
    ~latch() noexcept
    {
        // Wait until we have no waiters. The acquire load pairs with the release decrement at the end of _wait().
        while (m_waiters.load(std::memory_order_acquire) != 0)
            std::this_thread::yield();
    }

    //! Atomically decrements the counter by @p update without blocking.
    //! If this call brings the counter to zero, all blocked threads are woken.
    //! @param update Amount to subtract. Preconditions: `0 <= update <= max()` (asserted) and `update <= counter`
    //!   (verified with `CARB_CHECK` after the subtraction).
    void count_down(ptrdiff_t update = 1) noexcept
    {
        CARB_ASSERT(update >= 0, "latch::count_down: Precondition violation: update >= 0");
        // The counter is only 32 bits wide. An `update` above max() would be silently truncated by the `int32_t`
        // cast below, and could then slip past the `count >= 0` check.
        CARB_ASSERT(update <= max(), "latch::count_down: Precondition violation: update <= max()");

        // `fetch_sub` returns the value before the operation. Release ordering publishes this thread's prior writes
        // to any thread whose acquire load observes the counter at zero.
        auto count = m_counter.fetch_sub(int32_t(update), std::memory_order_release) - int32_t(update);
        if (count == 0)
        {
            // Wake all waiters
            m_counter.notify_all();
        }
        else
        {
            CARB_CHECK(count >= 0, "latch::count_down: Precondition violation: update <= counter");
        }
    }

    //! Tests, without blocking, whether the counter has reached zero.
    //! @returns `true` if the counter was zero when read, otherwise `false`. The standard allows `try_wait()` to return
    //!   `false` spuriously; this implementation returns `false` only when it actually reads a non-zero counter.
    CARB_NODISCARD bool try_wait() const noexcept
    {
        return m_counter.load(std::memory_order_acquire) == 0;
    }

    //! Blocks the calling thread until the counter reaches zero. Returns immediately if it is already zero.
    void wait() const noexcept
    {
        int32_t count = m_counter.load(std::memory_order_acquire);
        if (count != 0)
        {
            // Register as a waiter so the destructor does not finish while this thread is still inside _wait().
            m_waiters.fetch_add(1, std::memory_order_relaxed);
            _wait(count);
        }
    }

    //! Decrements the counter by @p update, then blocks until the counter reaches zero.
    //! If this call brings the counter to zero, all blocked threads are woken and this call returns without blocking.
    //! @param update Amount to subtract. Preconditions: `0 <= update <= max()` (asserted) and `update <= counter`
    //!   (verified with `CARB_CHECK` after the subtraction).
    void arrive_and_wait(ptrdiff_t update = 1) noexcept
    {
        CARB_ASSERT(update >= 0, "latch::arrive_and_wait: Precondition: update >= 0");
        // Guard against silent truncation by the `int32_t` cast below (see count_down()).
        CARB_ASSERT(update <= max(), "latch::arrive_and_wait: Precondition: update <= max()");

        // acq_rel: release publishes this thread's writes. If this call is the one that completes the latch, acquire
        // also makes the writes of the earlier (release) decrements visible to this thread.
        auto count = m_counter.fetch_sub(int32_t(update), std::memory_order_acq_rel) - int32_t(update);
        if (count == 0)
        {
            // We're the last, wake all waiters
            m_counter.notify_all();
        }
        else
        {
            CARB_CHECK(count >= 0, "latch::arrive_and_wait: Precondition: update <= counter");
            // Register as a waiter so the destructor does not finish while this thread is still inside _wait().
            m_waiters.fetch_add(1, std::memory_order_relaxed);
            _wait(count);
        }
    }

private:
    //! The latch counter; the latch is complete once this reaches zero.
    atomic_int32_t m_counter;
    //! Number of threads currently registered as waiting. Incremented before _wait(), decremented at its end, and
    //! polled by the destructor. `mutable` because the const wait() must update it.
    mutable std::atomic_uint32_t m_waiters{ 0 };

    //! Blocks until the counter reaches zero, then deregisters the caller from `m_waiters`.
    //! @param count The most recently observed counter value; must be non-zero. The caller must already have
    //!   incremented `m_waiters`.
    CARB_ALWAYS_INLINE void _wait(int32_t count) const noexcept
    {
        CARB_ASSERT(count != 0);
        do
        {
            // Wait for the counter to move away from `count`, then re-read it. This assumes
            // `std::atomic::wait`-like semantics (return once the value differs, possibly spuriously). Loop because
            // the counter may have dropped to another non-zero value.
            m_counter.wait(count, std::memory_order_relaxed);
            count = m_counter.load(std::memory_order_acquire);
        } while (count != 0);

        // Done waiting. Release pairs with the destructor's acquire load of `m_waiters`.
        m_waiters.fetch_sub(1, std::memory_order_release);
    }
};

#pragma pop_macro("max")

} // namespace cpp
} // namespace carb