Latch.h
Fully qualified name: carb/cpp/Latch.h
File members: carb/cpp/Latch.h
// SPDX-FileCopyrightText: Copyright (c) 2019-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: LicenseRef-NvidiaProprietary
//
// NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
// property and proprietary rights in and to this material, related
// documentation and any modifications thereto. Any use, reproduction,
// disclosure or distribution of this material and related documentation
// without an express license agreement from NVIDIA CORPORATION or
// its affiliates is strictly prohibited.
#pragma once
#include "Atomic.h"
#include <atomic>
#include <climits>
#include <cstdint>
#include <thread>
namespace carb
{
namespace cpp
{
// Handle case where Windows.h may have defined 'max'
#pragma push_macro("max")
#undef max
/// A single-use downward counter that threads can block on until it reaches zero (modeled on C++20 `std::latch`).
///
/// Threads inside wait() / arrive_and_wait() are counted in `m_waiters`. The destructor spins until that count is
/// zero, so the latch may be destroyed while such threads are still returning from those calls.
class latch
{
    CARB_PREVENT_COPY_AND_MOVE(latch);

public:
    /// Returns the largest supported counter value (INT_MAX, because the counter is a 32-bit integer).
    [[nodiscard]] static constexpr ptrdiff_t max() noexcept
    {
        return ptrdiff_t(INT_MAX);
    }

    /// Constructs the latch with an initial count of @p expected.
    /// @param expected Initial count. Must be >= 0 (checked with CARB_ASSERT). Values above max() are clamped
    ///                 to max().
    constexpr explicit latch(ptrdiff_t expected) noexcept : m_counter(int32_t(::carb_min(max(), expected)))
    {
        CARB_ASSERT(expected >= 0, "latch::latch: Precondition violation: expected >= 0");
    }

    /// Destroys the latch. It first yields in a loop until no thread remains registered in wait() or
    /// arrive_and_wait().
    ~latch() noexcept
    {
        // seq_cst: this load must be ordered after the destroying thread's seq_cst observation of a zero counter.
        // Every registration that preceded a non-zero observation then precedes this load in the seq_cst total
        // order, so it cannot be missed. The load is also an acquire, so it pairs with the release deregistration.
        while (m_waiters.load(std::memory_order_seq_cst) != 0)
            std::this_thread::yield();
    }

    /// Subtracts @p update from the counter without blocking. Wakes all blocked waiters if the counter reaches zero.
    /// @param update Amount to subtract. Must be >= 0 (CARB_ASSERT) and must not exceed the current counter
    ///               (CARB_RELEASE_ASSERT).
    void count_down(ptrdiff_t update = 1) noexcept
    {
        CARB_ASSERT(update >= 0, "latch::count_down: Precondition violation: update >= 0");
        // `fetch_sub` returns the value before the operation. seq_cst (which includes release) places this
        // decrement in the total order the destructor's waiter check relies on.
        auto count = m_counter.fetch_sub(int32_t(update), std::memory_order_seq_cst) - int32_t(update);
        if (count == 0)
        {
            // Wake all waiters.
            // NOTE(review): this thread is not registered in m_waiters, so another thread may observe completion
            // and destroy the latch between the decrement above and this call. That is only safe if notify_all()
            // tolerates a destroyed address; confirm against carb's atomic wait/notify implementation.
            m_counter.notify_all();
        }
        else
        {
            CARB_RELEASE_ASSERT(count >= 0, "latch::count_down: Precondition violation: update <= counter");
        }
    }

    /// Returns whether the counter has reached zero, without blocking.
    /// The standard allows a spurious `false`. This implementation returns `false` only when it actually reads a
    /// non-zero counter.
    [[nodiscard]] bool try_wait() const noexcept
    {
        // seq_cst so a thread that destroys the latch based on this result is ordered after any waiter
        // registration (see wait()).
        return m_counter.load(std::memory_order_seq_cst) == 0;
    }

    /// Blocks until the counter reaches zero. If it already has, returns after a brief register/deregister.
    void wait() const noexcept
    {
        // Register *before* reading the counter. If the read came first, the latch could complete and be destroyed
        // between a non-zero read and the registration, and fetch_add() would write to freed memory.
        // Both operations are seq_cst, so a non-zero read here means the registration precedes, in the total
        // order, any read of zero by a thread that then runs ~latch(). The destructor's seq_cst load sees it.
        m_waiters.fetch_add(1, std::memory_order_seq_cst);
        int32_t count = m_counter.load(std::memory_order_seq_cst);
        if (count == 0)
        {
            // Already complete: deregister and return.
            m_waiters.fetch_sub(1, std::memory_order_release);
            return;
        }
        _wait(count); // deregisters before returning
    }

    /// Subtracts @p update from the counter, then blocks until the counter reaches zero.
    /// @param update Amount to subtract. Must be >= 0 (CARB_ASSERT) and must not exceed the current counter
    ///               (CARB_RELEASE_ASSERT).
    void arrive_and_wait(ptrdiff_t update = 1) noexcept
    {
        CARB_ASSERT(update >= 0, "latch::arrive_and_wait: Precondition: update >= 0");
        // Register *before* arriving. If we registered only after the decrement, the remaining arrivals could
        // complete the latch and another thread could destroy it before the registration, leaving fetch_add()
        // writing to freed memory.
        m_waiters.fetch_add(1, std::memory_order_seq_cst);
        auto count = m_counter.fetch_sub(int32_t(update), std::memory_order_seq_cst) - int32_t(update);
        if (count == 0)
        {
            // We're the last. Wake all waiters, then deregister; staying registered across notify_all() keeps the
            // destructor from running underneath it.
            m_counter.notify_all();
            m_waiters.fetch_sub(1, std::memory_order_release);
        }
        else
        {
            CARB_RELEASE_ASSERT(count >= 0, "latch::arrive_and_wait: Precondition: update <= counter");
            _wait(count); // deregisters before returning
        }
    }

private:
    /// Remaining count; the latch is complete once this reaches zero.
    atomic_int32_t m_counter;
    /// Threads currently registered inside wait() / arrive_and_wait(). ~latch() spins until this is zero.
    mutable std::atomic_uint32_t m_waiters{ 0 };

    /// Blocks until the counter reads zero, starting from @p count (the last non-zero value observed). Then it
    /// deregisters the calling thread from m_waiters. The caller must already be registered.
    CARB_ALWAYS_INLINE void _wait(int32_t count) const noexcept
    {
        CARB_ASSERT(count != 0);
        do
        {
            m_counter.wait(count, std::memory_order_relaxed);
            // seq_cst: another thread that sees zero here and then destroys the latch must be ordered after
            // every registration made by a thread that read a non-zero count.
            count = m_counter.load(std::memory_order_seq_cst);
        } while (count != 0);
        // Done waiting. The release pairs with the destructor's load, so our accesses happen-before destruction.
        m_waiters.fetch_sub(1, std::memory_order_release);
    }
};
#pragma pop_macro("max")
} // namespace cpp
} // namespace carb