carb/cpp/Barrier.h

File members: carb/cpp/Barrier.h

// Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "Atomic.h"
#include "Bit.h"

#include <utility>

namespace carb
{
namespace cpp
{

namespace detail
{

constexpr uint32_t kInvalidPhase = 0;

struct NullFunction
{
    constexpr NullFunction() noexcept = default;
    constexpr void operator()() noexcept
    {
    }
};

} // namespace detail

// Handle case where Windows.h may have defined 'max'
#pragma push_macro("max")
#undef max

template <class CompletionFunction = detail::NullFunction>
class barrier
{
    CARB_PREVENT_COPY_AND_MOVE(barrier);

public:
    static constexpr ptrdiff_t max() noexcept
    {
        return ptrdiff_t(INT_MAX);
    }

    constexpr explicit barrier(ptrdiff_t expected, CompletionFunction f = CompletionFunction{})
        : m_emo(InitBoth{}, std::move(f), (uint64_t(1) << kPhaseBitShift) + uint32_t(::carb_min(expected, max()))),
          m_expected(uint32_t(::carb_min(expected, max())))
    {
        CARB_ASSERT(expected >= 0 && expected <= max());
    }

    ~barrier()
    {
        // Wait for destruction until all waiters are clear
        while (m_waiters.load(std::memory_order_acquire) != 0)
            std::this_thread::yield();
    }

    class arrival_token
    {
        CARB_PREVENT_COPY(arrival_token);
        friend class barrier;
        uint32_t m_token{ detail::kInvalidPhase };
        arrival_token(uint32_t token) : m_token(token)
        {
        }

    public:
        arrival_token() = default;
        arrival_token(arrival_token&& rhs) : m_token(std::exchange(rhs.m_token, detail::kInvalidPhase))
        {
        }
        arrival_token& operator=(arrival_token&& rhs)
        {
            m_token = std::exchange(rhs.m_token, detail::kInvalidPhase);
            return *this;
        }
    };

    CARB_NODISCARD arrival_token arrive(ptrdiff_t update = 1)
    {
        return arrival_token(uint32_t(_arrive(update).first >> kPhaseBitShift));
    }

    void wait(arrival_token&& arrival) const
    {
        // Precondition: arrival is associated with the phase synchronization point for the current phase or the
        // immediately preceding phase.
        CARB_CHECK(arrival.m_token != 0); // No invalid tokens
        uint64_t data = m_emo.second.load(std::memory_order_acquire);
        uint32_t phase = uint32_t(data >> kPhaseBitShift);
        CARB_CHECK((phase - arrival.m_token) <= 1, "arrival %u is not the previous or current phase %u",
                   arrival.m_token, phase);

        if (phase != arrival.m_token)
            return;

        // Register as a waiter
        m_waiters.fetch_add(1, std::memory_order_relaxed);

        do
        {
            // Wait for the phase to change
            m_emo.second.wait(data, std::memory_order_relaxed);

            // Reload after waiting
            data = m_emo.second.load(std::memory_order_acquire);
            phase = uint32_t(data >> kPhaseBitShift);
        } while (phase == arrival.m_token);

        // Unregister as a waiter
        m_waiters.fetch_sub(1, std::memory_order_release);
    }

    void arrive_and_wait()
    {
        // Two main differences over just doing arrive(wait()):
        // - We return immediately if _arrive() did the phase shift
        // - We don't CARB_CHECK that the phase is the current or preceding one since it is guaranteed
        auto result = _arrive(1);
        if (result.second)
            return;

        // Register as a waiter
        m_waiters.fetch_add(1, std::memory_order_relaxed);

        uint64_t data = result.first;
        uint32_t origPhase = uint32_t(data >> kPhaseBitShift), phase;
        do
        {
            // Wait for the phase to change
            m_emo.second.wait(data, std::memory_order_relaxed);

            // Reload after waiting
            data = m_emo.second.load(std::memory_order_acquire);
            phase = uint32_t(data >> kPhaseBitShift);
        } while (phase == origPhase);

        // Unregister as a waiter
        m_waiters.fetch_sub(1, std::memory_order_release);
    }

    void arrive_and_drop()
    {
        uint32_t prev = m_expected.fetch_sub(1, std::memory_order_relaxed);
        CARB_CHECK(prev != 0); // Precondition failure: expected count for the current barrier phase must be greater
                               // than zero.

        _arrive(1);
    }

private:
    constexpr static int kPhaseBitShift = 32;
    constexpr static uint64_t kCounterMask = 0xffffffffull;

    CARB_ALWAYS_INLINE std::pair<uint64_t, bool> _arrive(ptrdiff_t update)
    {
        CARB_CHECK(update > 0 && update <= max());

        uint64_t pre = m_emo.second.fetch_sub(uint32_t(update), std::memory_order_acq_rel);
        CARB_CHECK(ptrdiff_t(int32_t(uint32_t(pre & kCounterMask))) >= update); // Precondition check

        bool completed = false;
        if (uint32_t(pre & kCounterMask) == uint32_t(update))
        {
            // Phase is now complete
            std::atomic_thread_fence(std::memory_order_acquire);
            _completePhase(pre - uint32_t(update));
            completed = true;
        }

        return std::make_pair(pre - uint32_t(update), completed);
    }

    void _completePhase(uint64_t data)
    {
        uint32_t expected = m_expected.load(std::memory_order_relaxed);

        // Run the completion routine before releasing threads
        m_emo.first()();

        // Increment the phase and don't allow the invalid phase
        uint32_t phase = uint32_t(data >> kPhaseBitShift);
        if (++phase == detail::kInvalidPhase)
            ++phase;

#if CARB_ASSERT_ENABLED
        // Should not have changed during completion function.
        uint64_t old = m_emo.second.exchange((uint64_t(phase) << kPhaseBitShift) + expected, std::memory_order_release);
        CARB_ASSERT(old == data);
#else
        m_emo.second.store((uint64_t(phase) << kPhaseBitShift) + expected, std::memory_order_release);
#endif

        // Release all waiting threads
        m_emo.second.notify_all();
    }

    // The MSB 32 bits of the atomic_uint64_t are the Phase; the other bits are the Counter
    EmptyMemberPair<CompletionFunction, atomic_uint64_t> m_emo;
    std::atomic_uint32_t m_expected;
    mutable std::atomic_uint32_t m_waiters{ 0 };
};

#pragma pop_macro("max")

} // namespace cpp
} // namespace carb