
// File: carb/container/LocklessStack.h

// Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#pragma once

#include "../Defines.h"
#include "../cpp/Atomic.h"
#include "../thread/Mutex.h"
#include "../thread/Util.h"

#include <algorithm>
#include <atomic>
#include <chrono>
#include <thread>

#    include <dlfcn.h>


namespace carb
namespace container

template <class T>
class LocklessStackLink;
template <class T, LocklessStackLink<T> T::*U>
class LocklessStack;

namespace detail
template <class T, LocklessStackLink<T> T::*U>
class LocklessStackHelpers;
template <class T, LocklessStackLink<T> T::*U>
class LocklessStackBase;
} // namespace detail

/**
 * Intrusive link node used by LocklessStack.
 *
 * Every type stored in a LocklessStack must contain a member of type LocklessStackLink<T>; a pointer to that member is
 * the second template argument of LocklessStack. The link only holds the pointer to the next link and is opaque to
 * everything except `T` and the stack implementation classes befriended below.
 */
template <class T>
class LocklessStackLink
{
public:
    //! Default constructor. `m_next` is intentionally left uninitialized; the stack writes it on every push.
    constexpr LocklessStackLink() = default;

private:
    // Next link toward the bottom of the stack; nullptr for the bottom-most item.
    CARB_VIZ LocklessStackLink<T>* m_next;

    friend T;
    template <class U, LocklessStackLink<U> U::*V>
    friend class detail::LocklessStackHelpers;
    template <class U, LocklessStackLink<U> U::*V>
    friend class detail::LocklessStackBase;
    template <class U, LocklessStackLink<U> U::*V>
    friend class LocklessStack;
};

#if !defined(DOXYGEN_BUILD)
namespace detail

class SignalHandler
    static bool readNext(void** out, void* in)
        // We do this in a SEH block (on Windows) because it's possible (though rare) that another thread could have
        // already popped and destroyed `cur` which would cause EXCEPTION_ACCESS_VIOLATION. By handling it in an
        // exception handler, we recover cleanly and try again. On 64-bit Windows, there is zero cost unless an
        // exception is thrown, at which point the kernel will find the Exception info and Unwind Info for the function
        // that we're in.
            *out = *(void**)in;
            return true;
        __except (1)
            return false;
#    endif

template <class T, LocklessStackLink<T> T::*U>
class LocklessStackHelpers
    // Access the LocklessStackLink member of `p`
    static LocklessStackLink<T>* link(T* p)
        return std::addressof(p->*U);

    // Converts a LocklessStackLink to the containing object
    static T* convert(LocklessStackLink<T>* p)
        // We need to calculate the offset of our link member and calculate where T is.
        // Note that this doesn't work if T uses virtual inheritance
        size_t offset = (size_t) reinterpret_cast<char*>(&(((T*)0)->*U));
        return reinterpret_cast<T*>(reinterpret_cast<char*>(p) - offset);

// Base implementations
template <class T, LocklessStackLink<T> T::*U>
class LocklessStackBase : protected LocklessStackHelpers<T, U>
    using Base = LocklessStackHelpers<T, U>;

    bool _isEmpty() const
        return !m_head.load();

    bool _push(T* first, T* last)
        std::lock_guard<Lock> g(m_lock);
        // Relaxed because under the lock
        Base::link(last)->m_next = m_head.load(std::memory_order_relaxed);
        bool const wasEmpty = !m_head.load(std::memory_order_relaxed);, std::memory_order_relaxed);
        return wasEmpty;

    T* _popOne()
        std::unique_lock<Lock> g(m_lock);
        // Relaxed because under the lock
        auto cur = m_head.load(std::memory_order_relaxed);
        if (!cur)
            return nullptr;
        }>m_next, std::memory_order_relaxed);
        return Base::convert(cur);

    T* _popAll()
        std::lock_guard<Lock> g(m_lock);
        // Relaxed because under the lock
        LocklessStackLink<T>* head =, std::memory_order_relaxed);
        return head ? Base::convert(head) : nullptr;

    void _wait()
        auto p = m_head.load();
        while (!p)
            p = m_head.load();

    template <class Rep, class Period>
    bool _waitFor(const std::chrono::duration<Rep, Period>& dur)
        return _waitUntil(std::chrono::steady_clock::now() + dur);

    template <class Clock, class Duration>
    bool _waitUntil(const std::chrono::time_point<Clock, Duration>& tp)
        auto p = m_head.load();
        while (!p)
            if (!m_head.wait_until(p, tp))
                return false;
            p = m_head.load();
        return true;

    void _notifyOne()

    void _notifyAll()

    // Cannot be lockless if we don't have SEH
    // Testing reveals that mutex is significantly faster than spinlock in highly-contended cases.
    using Lock = carb::thread::mutex;
    Lock m_lock;
    cpp::atomic<LocklessStackLink<T>*> m_head{ nullptr };
#    else
// Windows implementation: requires SEH and relies upon the fact that aligned pointers on modern OSes don't
// use at least 10 bits of the 64-bit space, so it uses those bits as a sequence number to ensure uniqueness between
// different threads competing to pop.
template <class T, LocklessStackLink<T> T::*U>
class LocklessStackBase : protected LocklessStackHelpers<T, U>
{
    using Base = LocklessStackHelpers<T, U>;

public:
    constexpr LocklessStackBase()
    {
    }

    // Returns true if the stack held no items at the moment of the (acquire) load of `m_head`.
    bool _isEmpty() const
    {
        return !decode(m_head.load(std::memory_order_acquire));
    }

    // Pushes the chain `first`..`last` (already linked through their `m_next` members by the caller) so that `first`
    // becomes the new top. Returns true if the stack was empty before the push.
    bool _push(T* first, T* last)
    {
        // All OS bits should either be zero or one, and it needs to be 8-byte-aligned.
        LocklessStackLink<T>* lnk = Base::link(first);
        CARB_ASSERT((size_t(lnk) & kCPUMask) == 0 || (size_t(lnk) & kCPUMask) == kCPUMask, "Unexpected OS bits set");
        CARB_ASSERT((size_t(lnk) & ((1 << 3) - 1)) == 0, "Pointer not aligned properly");

        uint16_t seq;
        uint64_t expected = m_head.load(std::memory_order_acquire), temp;
        decltype(lnk) next;
        do
        {
            next = decode(expected, seq);
            Base::link(last)->m_next = next;
            temp = encode(lnk, seq + 1); // Increase sequence
            // Release publishes the `m_next` write above to poppers. Relaxed failure is fine here: the reloaded head
            // is only stored into `m_next`, never dereferenced.
        } while (CARB_UNLIKELY(
            !m_head.compare_exchange_strong(expected, temp, std::memory_order_release, std::memory_order_relaxed)));
        return !next;
    }

    // Pops the top item, or returns nullptr if the stack is empty.
    T* _popOne()
    {
        uint64_t expected = m_head.load(std::memory_order_acquire);
        LocklessStackLink<T>* cur = nullptr;
        uint16_t seq;

        bool isNull = false;

        this_thread::spinWaitWithBackoff([&] {
            cur = decode(expected, seq);
            if (!cur)
            {
                // End attempts because the stack is empty
                isNull = true;
                return true;
            }

            // Attempt to read the next value
            LocklessStackLink<T>* newhead;
            if (!detail::SignalHandler::readNext((void**)&newhead, cur))
            {
                // Another thread changed `cur`, so reload and try again.
                expected = m_head.load(std::memory_order_acquire);
                return false;
            }

            // Only push needs to increase `seq`
            uint64_t temp = encode(newhead, seq);
            // On failure `expected` is reloaded and its node is dereferenced (readNext) on the next attempt, so the
            // failure load must be acquire to synchronize with the push that published that node. Success uses
            // acq_rel so the failure order is never stronger than the success order.
            return m_head.compare_exchange_strong(expected, temp, std::memory_order_acq_rel, std::memory_order_acquire);
        });

        return isNull ? nullptr : Base::convert(cur);
    }

    // Atomically detaches the whole chain and returns its top item (nullptr if the stack was empty). The returned
    // items remain linked through their `m_next` members.
    T* _popAll()
    {
        uint16_t seq;
        uint64_t expected = m_head.load(std::memory_order_acquire), temp;
        for (;;)
        {
            LocklessStackLink<T>* head = decode(expected, seq);
            if (!head)
            {
                return nullptr;
            }

            // Keep the same sequence since only push() needs to increment the sequence
            temp = encode(nullptr, seq);
            // The caller walks the returned chain, so a reloaded `expected` (failure path) must be an acquire load.
            if (CARB_LIKELY(m_head.compare_exchange_weak(
                    expected, temp, std::memory_order_acq_rel, std::memory_order_acquire)))
            {
                return Base::convert(head);
            }
        }
    }

    // Blocks until the stack is observed non-empty; relies on _notifyOne()/_notifyAll() to be woken.
    void _wait()
    {
        uint64_t head = m_head.load(std::memory_order_acquire);
        while (!decode(head))
        {
            m_head.wait(head, std::memory_order_relaxed);
            head = m_head.load(std::memory_order_acquire);
        }
    }

    template <class Rep, class Period>
    bool _waitFor(const std::chrono::duration<Rep, Period>& dur)
    {
        return _waitUntil(std::chrono::steady_clock::now() + dur);
    }

    // Returns true once the stack is observed non-empty, or false if `tp` is reached first.
    template <class Clock, class Duration>
    bool _waitUntil(const std::chrono::time_point<Clock, Duration>& tp)
    {
        uint64_t head = m_head.load(std::memory_order_acquire);
        while (!decode(head))
        {
            if (!m_head.wait_until(head, tp, std::memory_order_relaxed))
            {
                return false;
            }
            head = m_head.load(std::memory_order_acquire);
        }
        return true;
    }

    void _notifyOne()
    {
        m_head.notify_one();
    }

    void _notifyAll()
    {
        m_head.notify_all();
    }

private:
    // On 64-bit architectures, we make use of the fact that CPUs only use a certain number of address bits.
    // Intel CPUs require that these 8 to 16 most-significant-bits match (all 1s or 0s). Since 8 appears to be the
    // lowest common denominator, we steal 7 bits (to save the value of one of the bits so that they can match) for a
    // sequence number. The sequence is important as it causes the resulting stored value to change even if the stack is
    // pushing and popping the same value.
    // Pointer compression drops the `m` and `z` bits from the pointer. `m` are expected to be consistent (all 1 or 0)
    // and match the most-significant `P` bit. `z` are expected to be zeros:
    // 63 ------------------------------ BITS ------------------------------ 0
    // mmmmmmmP PPPPPPPP PPPPPPPP PPPPPPPP PPPPPPPP PPPPPPPP PPPPPPPP PPPPPzzz
    // `m_head` is encoded as the shifted compressed pointer bits `P` with sequence bits `s`:
    // 63 ------------------------------ BITS ------------------------------ 0
    // PPPPPPPP PPPPPPPP PPPPPPPP PPPPPPPP PPPPPPPP PPPPPPPP PPPPPPss ssssssss

    static_assert(sizeof(size_t) == 8, "64-bit only");
    CARB_VIZ constexpr const static size_t kCpuBits = 7; // MSBs that are limited by CPU hardware and must match the
                                                         // 56th bit
    constexpr const static size_t kCPUMask = ((size_t(1) << (kCpuBits + 1)) - 1) << (63 - kCpuBits);
    CARB_VIZ constexpr const static size_t kSeqBits = kCpuBits + 3; // We also use the lowest 3 bits as part of the
                                                                    // sequence
    CARB_VIZ constexpr const static size_t kSeqMask = (size_t(1) << kSeqBits) - 1;
    CARB_VIZ cpp::atomic_uint64_t m_head{ 0 };

    static LocklessStackLink<T>* decode(size_t val)
    {
        // Clear the `s` sequence bits and shift as a signed value to sign-extend so that the `m` bits are filled in to
        // match the most-significant `P` bit.
        return reinterpret_cast<LocklessStackLink<T>*>(ptrdiff_t(val & ~kSeqMask) >> kCpuBits);
    }

    static LocklessStackLink<T>* decode(size_t val, uint16_t& seq)
    {
        seq = val & kSeqMask;
        return decode(val);
    }

    static size_t encode(LocklessStackLink<T>* p, uint16_t seq)
    {
        // Shift the pointer value, dropping the most significant `m` bits and write the sequence number over the `z`
        // and space created in the least-significant area.
        return ((reinterpret_cast<size_t>(p) << kCpuBits) & ~kSeqMask) + (seq & uint16_t(kSeqMask));
    }
};
#    endif

} // namespace detail
#endif // !defined(DOXYGEN_BUILD)

/**
 * An intrusive LIFO stack of `T` items that can be pushed and popped concurrently from multiple threads.
 *
 * Items are linked through their LocklessStackLink<T> member `U`. The stack never allocates, copies or destroys items;
 * it only links them, so an item must stay alive while it is in the stack and can be in at most one stack per link
 * member at a time.
 *
 * @tparam T The item type.
 * @tparam U Pointer to the LocklessStackLink<T> member of `T` used to chain items together.
 */
template <class T, LocklessStackLink<T> T::*U>
class LocklessStack final : protected detail::LocklessStackBase<T, U>
{
    using Base = detail::LocklessStackBase<T, U>;

public:
    //! Constructs an empty stack.
    constexpr LocklessStack() = default;

    //! Destructor. The stack is expected to be empty when destroyed (asserted).
    ~LocklessStack()
    {
        // Ensure the stack is empty
        CARB_ASSERT(isEmpty(), "LocklessStack destroyed while not empty");
    }

    /**
     * Checks whether the stack is empty.
     * @note The result is only a snapshot; other threads may push or pop immediately afterwards.
     * @returns `true` if the stack held no items at the time of the check; `false` otherwise.
     */
    bool isEmpty() const
    {
        return Base::_isEmpty();
    }

    /**
     * Pushes a single item onto the top of the stack.
     * @param p The item to push. Must not be nullptr and must not currently be in a stack through member `U`.
     * @returns `true` if the stack was empty before this push; `false` otherwise.
     */
    bool push(T* p)
    {
        return Base::_push(p, p);
    }

    /**
     * Pushes a range of items, given by reference, onto the stack in a single operation.
     *
     * The items are linked to each other first and then pushed at once, so `*begin` ends up on top of the stack.
     * @param begin Iterator to the first item; dereferencing it must yield something convertible to `T&`.
     * @param end Past-the-end iterator of the range.
     * @returns `false` if the range is empty (nothing is pushed); otherwise `true` if the stack was empty before the
     *   push and `false` if it was not.
     */
#ifndef DOXYGEN_BUILD
    template <class InputItRef,
              std::enable_if_t<std::is_convertible<decltype(std::declval<InputItRef&>()++, *std::declval<InputItRef&>()), T&>::value,
                               bool> = false>
#else
    template <class InputItRef>
#endif
    bool push(InputItRef begin, InputItRef end)
    {
        if (begin == end)
        {
            return false;
        }

        // Link the items to each other. Each element is dereferenced exactly once and only one iterator is advanced,
        // so single-pass input iterators work as well as forward iterators.
        T* const first = std::addressof(*begin);
        T* last = first;
        for (++begin; begin != end; ++begin)
        {
            T* const item = std::addressof(*begin);
            Base::link(last)->m_next = Base::link(item);
            last = item;
        }

        return Base::_push(first, last);
    }

    /**
     * Pushes a range of items, given by pointer, onto the stack in a single operation.
     *
     * The items are linked to each other first and then pushed at once, so `*begin` ends up on top of the stack.
     * @param begin Iterator to the first item pointer; dereferencing it must yield something convertible to `T*`.
     * @param end Past-the-end iterator of the range.
     * @returns `false` if the range is empty (nothing is pushed); otherwise `true` if the stack was empty before the
     *   push and `false` if it was not.
     */
#ifndef DOXYGEN_BUILD
    template <class InputItPtr,
              std::enable_if_t<std::is_convertible<decltype(std::declval<InputItPtr&>()++, *std::declval<InputItPtr&>()), T*>::value,
                               bool> = true>
#else
    template <class InputItPtr>
#endif
    bool push(InputItPtr begin, InputItPtr end)
    {
        if (begin == end)
        {
            return false;
        }

        // Link the items to each other, dereferencing each element exactly once (single-pass iterators are fine).
        T* const first = *begin;
        T* last = first;
        for (++begin; begin != end; ++begin)
        {
            T* const item = *begin;
            Base::link(last)->m_next = Base::link(item);
            last = item;
        }

        return Base::_push(first, last);
    }

    /**
     * Pops the top item from the stack.
     * @returns The popped item, or nullptr if the stack was empty.
     */
    T* pop()
    {
        return Base::_popOne();
    }

    //! Removes all items from the stack without visiting them. The items themselves are not destroyed.
    void popAll()
    {
        Base::_popAll();
    }

    /**
     * Atomically removes all items from the stack, then calls `f` on each of them from top to bottom.
     *
     * An item's next link is read before `f` is called on it, so `f` may push the item again or destroy it.
     * @param f Callable invoked as `f(T*)` once per removed item.
     */
    template <class Func>
    void forEach(Func&& f)
    {
        T* p = Base::_popAll();
        LocklessStackLink<T>* h = p ? Base::link(p) : nullptr;
        while (h)
        {
            p = Base::convert(h);
            // Advance before calling `f`, which may reuse or destroy `p`.
            h = h->m_next;
            f(p);
        }
    }

    /**
     * Pushes an item and then wakes one thread blocked in wait(), popWait() or their timed variants.
     * @param p The item to push; same requirements as push(T*).
     * @returns `true` if the stack was empty before this push; `false` otherwise.
     */
    bool pushNotify(T* p)
    {
        bool b = push(p);
        notifyOne();
        return b;
    }

    /**
     * Pops an item, blocking until one is available.
     * @note push() does not notify; producers should use pushNotify() (or notifyOne()/notifyAll()) to wake waiters.
     * @returns The popped item; never nullptr.
     */
    T* popWait()
    {
        T* p = pop();
        while (!p)
        {
            wait();
            p = pop();
        }
        return p;
    }

    /**
     * Pops an item, blocking for at most `dur` until one is available.
     * @param dur Maximum time to wait, measured with `std::chrono::steady_clock`.
     * @returns The popped item, or nullptr if none became available in time.
     */
    template <class Rep, class Period>
    T* popWaitFor(const std::chrono::duration<Rep, Period>& dur)
    {
        return popWaitUntil(std::chrono::steady_clock::now() + dur);
    }

    /**
     * Pops an item, blocking until one is available or `tp` is reached.
     * @param tp Time point after which to give up waiting.
     * @returns The popped item, or nullptr if none became available in time.
     */
    template <class Clock, class Duration>
    T* popWaitUntil(const std::chrono::time_point<Clock, Duration>& tp)
    {
        T* p = pop();
        while (!p)
        {
            if (!waitUntil(tp))
            {
                // Timed out: one last attempt in case an item arrived right at the deadline.
                return pop();
            }
            p = pop();
        }
        return p;
    }

    //! Blocks until the stack is not empty. Does not pop; woken by notifyOne(), notifyAll() or pushNotify().
    void wait()
    {
        Base::_wait();
    }

    /**
     * Blocks until the stack is not empty or `dur` elapses. Does not pop.
     * @returns `true` if the stack was observed non-empty; `false` on timeout.
     */
    template <class Rep, class Period>
    bool waitFor(const std::chrono::duration<Rep, Period>& dur)
    {
        return Base::_waitFor(dur);
    }

    /**
     * Blocks until the stack is not empty or `tp` is reached. Does not pop.
     * @returns `true` if the stack was observed non-empty; `false` on timeout.
     */
    template <class Clock, class Duration>
    bool waitUntil(const std::chrono::time_point<Clock, Duration>& tp)
    {
        return Base::_waitUntil(tp);
    }

    //! Wakes one thread blocked in one of the wait functions.
    void notifyOne()
    {
        Base::_notifyOne();
    }

    //! Wakes all threads blocked in the wait functions.
    void notifyAll()
    {
        Base::_notifyAll();
    }
};
} // namespace container
} // namespace carb