carb/container/LocklessStack.h

File members: carb/container/LocklessStack.h

// Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "../Defines.h"
#include "../cpp/Atomic.h"
#include "../thread/Mutex.h"
#include "../thread/Util.h"

#include <algorithm>
#include <atomic>
#include <chrono>
#include <thread>

#if CARB_POSIX
#    include <dlfcn.h>
#endif

namespace carb
{
namespace container
{

// Forward declarations.
// LocklessStackLink<T> is the link type embedded as a member of T; LocklessStack<T, U> is parameterized on the
// pointer-to-member `U` that identifies which LocklessStackLink<T> member of T the stack uses.
template <class T>
class LocklessStackLink;
template <class T, LocklessStackLink<T> T::*U>
class LocklessStack;

// Implementation details are hidden from documentation builds.
#ifndef DOXYGEN_BUILD
namespace detail
{
// Forward declarations of the implementation classes so that LocklessStackLink can name them as friends before
// they are defined.
template <class T, LocklessStackLink<T> T::*U>
class LocklessStackHelpers;
template <class T, LocklessStackLink<T> T::*U>
class LocklessStackBase;
} // namespace detail
#endif

//! Link object that a type `T` embeds as a member so that instances of `T` can be chained together by a
//! LocklessStack. The member is selected through the pointer-to-member template parameter of LocklessStack.
template <class T>
class LocklessStackLink
{
    // Only the containing type and the stack implementation classes may touch the next-pointer.
    friend T;
    template <class Obj, LocklessStackLink<Obj> Obj::*Member>
    friend class LocklessStack;
    template <class Obj, LocklessStackLink<Obj> Obj::*Member>
    friend class detail::LocklessStackBase;
    template <class Obj, LocklessStackLink<Obj> Obj::*Member>
    friend class detail::LocklessStackHelpers;

    // Pointer to the link of the next element in the chain. The defaulted constructor does not initialize it; it is
    // assigned by the stack when the containing object is pushed.
    CARB_VIZ LocklessStackLink* m_next;

public:
    //! Default constructor. Leaves the link's next-pointer uninitialized.
    constexpr LocklessStackLink() = default;
};

#if !defined(DOXYGEN_BUILD)
namespace detail
{

// Static helpers that translate between a `T` object and the LocklessStackLink<T> member of it selected by `U`.
template <class T, LocklessStackLink<T> T::*U>
class LocklessStackHelpers
{
public:
    // Returns the address of the link member (selected by `U`) inside the object `p`.
    // `p` is dereferenced, so it must point to a valid object.
    static LocklessStackLink<T>* link(T* p)
    {
        LocklessStackLink<T>& member = p->*U;
        return std::addressof(member);
    }

    // Returns the address of the `T` object that contains the link `p`.
    static T* convert(LocklessStackLink<T>* p)
    {
        // The byte offset of the link member within T is obtained by taking the member's address relative to a null
        // T*. Subtracting that offset from the link's address yields the start of the containing object.
        // Note that this doesn't work if T uses virtual inheritance.
        const size_t memberOffset = (size_t) reinterpret_cast<char*>(&(static_cast<T*>(nullptr)->*U));
        char* const linkAddress = reinterpret_cast<char*>(p);
        return reinterpret_cast<T*>(linkAddress - memberOffset);
    }
};

// Base implementations
//
// The stack is a singly-linked chain of LocklessStackLink nodes hanging off the atomic `m_head` pointer. The low bit
// of `m_head` (`kLock`) is used as a lock bit: a popper sets it to gain exclusive access to the head node while it
// unlinks it. Pushers never set the bit; their cmpxchg fails for as long as it is set.
template <class T, LocklessStackLink<T> T::*U>
class LocklessStackBase : protected LocklessStackHelpers<T, U>
{
    using Base = LocklessStackHelpers<T, U>;
    using Link = LocklessStackLink<T>;

public:
    // Returns true if the head pointer was null at the time of the load. A head that is currently locked by a popper
    // is non-null and therefore reports as not empty.
    bool _isEmpty() const
    {
        return !m_head.load(std::memory_order_acquire);
    }

    // Pushes a chain of objects whose links are already connected from `first` through `last` (for a single object
    // pass the same pointer for both). `first` becomes the new head and the link of `last` is pointed at the previous
    // head. Returns true if the stack was empty immediately before the push.
    bool _push(T* first, T* last)
    {
        Link* old = m_head.load(std::memory_order_relaxed);
        do
        {
            // NOTE: if m_head has the lock bit set, the cmpxchg will fail. This is by design. Paradoxically this is
            // faster than testing and reloading.
            old = withoutLockBit(old);
            Base::link(last)->m_next = old;
        } while (!m_head.compare_exchange_weak(
            old, Base::link(first), std::memory_order_release, std::memory_order_relaxed));
        return old == nullptr;
    }

    // Removes the head element and returns it, or returns nullptr if the stack is empty. The link of the returned
    // object has its next-pointer reset to nullptr.
    T* _popOne()
    {
        Link* old = lockHead();
        if (!old)
            return nullptr;

        // Now has exclusive access
        CARB_ASSERT(!hasLockBit(old->m_next));
        // Storing the successor as the new head also clears the lock bit.
        m_head.store(std::exchange(old->m_next, nullptr), std::memory_order_release);

        return Base::convert(old);
    }

    // Detaches the entire chain from the stack and returns the object that was at the head, or nullptr if the stack
    // is empty. The links of the detached objects are left connected so that the caller can walk them.
    T* _popAll()
    {
        Link* old = lockHead();
        if (!old)
            return nullptr;

        // Now has exclusive access
        // Storing nullptr both empties the stack and clears the lock bit.
        m_head.store(nullptr, std::memory_order_release);

        return Base::convert(old);
    }

    // Blocks for as long as the head pointer is observed to be null. The head may be locked or already popped by
    // another thread by the time this returns, so a subsequent pop can still yield nullptr.
    void _wait()
    {
        auto p = m_head.load();
        while (!p)
        {
            m_head.wait(p);
            p = m_head.load();
        }
    }

    // Same as _waitUntil() with a deadline of `dur` from now on the steady clock. Computing the deadline once keeps
    // the total wait bounded by `dur` even when the wait loop iterates more than once.
    template <class Rep, class Period>
    bool _waitFor(const std::chrono::duration<Rep, Period>& dur)
    {
        return _waitUntil(std::chrono::steady_clock::now() + dur);
    }

    // Blocks while the head pointer is observed to be null. Returns true once a non-null head is observed, or false
    // as soon as `m_head.wait_until()` returns false (presumably meaning `tp` was reached -- confirm against
    // cpp::atomic).
    template <class Clock, class Duration>
    bool _waitUntil(const std::chrono::time_point<Clock, Duration>& tp)
    {
        auto p = m_head.load();
        while (!p)
        {
            if (!m_head.wait_until(p, tp))
                return false;
            p = m_head.load();
        }
        return true;
    }

    // Forwards to `m_head.notify_one()` to wake a thread blocked in one of the wait functions.
    void _notifyOne()
    {
        m_head.notify_one();
    }

    // Forwards to `m_head.notify_all()` to wake threads blocked in one of the wait functions.
    void _notifyAll()
    {
        m_head.notify_all();
    }

private:
    // Lock bit stored in the least significant bit of the head pointer. This relies on Link objects never living at
    // odd addresses; Link holds a pointer member, so its alignment is expected to guarantee that.
    constexpr static uintptr_t kLock = 1;

    // Returns true if the lock bit is set in `in`.
    static bool hasLockBit(Link* in)
    {
        return !!(reinterpret_cast<uintptr_t>(in) & kLock);
    }

    // Returns `in` with the lock bit set.
    static Link* withLockBit(Link* in)
    {
        return reinterpret_cast<Link*>(reinterpret_cast<uintptr_t>(in) | kLock);
    }

    // Returns `in` with the lock bit cleared.
    static Link* withoutLockBit(Link* in)
    {
        return reinterpret_cast<Link*>(reinterpret_cast<uintptr_t>(in) & ~kLock);
    }

    // Shared first half of _popOne()/_popAll(): spins until this thread has set the lock bit on a non-null, unlocked
    // head. Returns the head link (without the lock bit), which the caller then owns exclusively and must release by
    // storing a new, unlocked value to m_head. Returns nullptr, without locking, if the stack is observed empty.
    Link* lockHead()
    {
        this_thread::atomic_fence_seq_cst(); // force visibility
        Link* old = m_head.load(std::memory_order_relaxed);
        for (thread::AtomicBackoff<> backoff;; backoff.pause())
        {
            if (!old)
                return nullptr;

            // Try to lock. On average it is faster for our contention tests if we use AtomicBackoff and do not check
            // whether `old` has the lock bit set before attempting the cmpxchg.
            //
            // The success ordering must be acquire: it pairs with the release cmpxchg in _push() and with the release
            // store that unlocks in _popOne()/_popAll(), so that `old->m_next` and the contents of the node written
            // before it was published are visible before the caller reads them. If `old` already carries the lock
            // bit the cmpxchg can only "succeed" by rewriting the same locked value, which does not grant ownership,
            // hence the additional hasLockBit() check.
            if (m_head.compare_exchange_weak(
                    old, withLockBit(old), std::memory_order_acquire, std::memory_order_relaxed) &&
                !hasLockBit(old))
                return old;
        }
    }

    // Head of the chain; the low bit is the lock bit (see kLock).
    cpp::atomic<Link*> m_head{ nullptr };
};
} // namespace detail
#endif

//! An intrusive last-in-first-out stack of `T` objects. Each `T` must contain a LocklessStackLink<T> member, which
//! is identified by the pointer-to-member template parameter `U`. The stack stores pointers only; it never copies,
//! constructs or destroys the objects themselves.
//!
//! All operations forward to detail::LocklessStackBase, which maintains the chain through an atomic head pointer.
template <class T, LocklessStackLink<T> T::*U>
class LocklessStack final : protected detail::LocklessStackBase<T, U>
{
    using Base = detail::LocklessStackBase<T, U>;

public:
    //! Constructs an empty stack.
    constexpr LocklessStack() = default;

    //! Destructor. Asserts (via CARB_ASSERT) that the stack has been emptied beforehand.
    ~LocklessStack()
    {
        CARB_ASSERT(Base::_isEmpty());
    }

    //! Reports whether the stack had no elements at the time of the check.
    //! @returns `true` if the stack was observed empty; `false` otherwise.
    bool isEmpty() const
    {
        return Base::_isEmpty();
    }

    //! Pushes a single object onto the stack.
    //! @param p The object to push. It is dereferenced to reach its link member, so it must not be `nullptr`.
    //! @returns `true` if the stack was empty immediately before the push; `false` otherwise.
    bool push(T* p)
    {
        // A single object is a chain whose first and last element coincide.
        return Base::_push(p, p);
    }

    //! Pushes a range of objects, given by iterators that dereference to `T&`, as one operation.
    //! The object at `begin` ends up on top of the stack.
    //! @param begin Iterator to the first object to push.
    //! @param end Iterator one past the last object to push.
    //! @returns `false` if the range is empty; otherwise `true` if the stack was empty immediately before the push
    //!     and `false` if it was not.
#ifndef DOXYGEN_BUILD
    template <class InputItRef,
              std::enable_if_t<std::is_convertible<decltype(std::declval<InputItRef&>()++, *std::declval<InputItRef&>()), T&>::value,
                               bool> = false>
#else
    template <class InputItRef>
#endif
    bool push(InputItRef begin, InputItRef end)
    {
        if (begin == end)
        {
            return false;
        }

        // Chain the objects together in iteration order; `tail` trails one element behind `cursor` and ends up
        // referring to the final object in the range.
        InputItRef tail = begin;
        InputItRef cursor = begin;
        cursor++;
        while (cursor != end)
        {
            Base::link(std::addressof(*tail))->m_next = Base::link(std::addressof(*cursor));
            tail = cursor++;
        }

        // Publish the whole chain at once.
        return Base::_push(std::addressof(*begin), std::addressof(*tail));
    }

    //! Pushes a range of objects, given by iterators that dereference to `T*`, as one operation.
    //! The object referred to by `begin` ends up on top of the stack.
    //! @param begin Iterator to the first object pointer to push.
    //! @param end Iterator one past the last object pointer to push.
    //! @returns `false` if the range is empty; otherwise `true` if the stack was empty immediately before the push
    //!     and `false` if it was not.
#ifndef DOXYGEN_BUILD
    template <class InputItPtr,
              std::enable_if_t<std::is_convertible<decltype(std::declval<InputItPtr&>()++, *std::declval<InputItPtr&>()), T*>::value,
                               bool> = true>
#else
    template <class InputItPtr>
#endif
    bool push(InputItPtr begin, InputItPtr end)
    {
        if (begin == end)
        {
            return false;
        }

        // Chain the objects together in iteration order; `tail` trails one element behind `cursor` and ends up
        // referring to the final object in the range.
        InputItPtr tail = begin;
        InputItPtr cursor = begin;
        cursor++;
        while (cursor != end)
        {
            Base::link(*tail)->m_next = Base::link(*cursor);
            tail = cursor++;
        }

        // Publish the whole chain at once.
        return Base::_push(*begin, *tail);
    }

    //! Removes the object on top of the stack.
    //! @returns The removed object, or `nullptr` if the stack was empty.
    T* pop()
    {
        return Base::_popOne();
    }

    //! Removes every object from the stack. The removed objects are not returned or visited.
    void popAll()
    {
        Base::_popAll();
    }

    //! Removes every object from the stack and invokes `f` on each of them, starting with the object that was on top.
    //! @param f A callable invoked as `f(T*)` once per removed object.
    template <class Func>
    void forEach(Func&& f)
    {
        T* item = Base::_popAll();
        for (LocklessStackLink<T>* node = item ? Base::link(item) : nullptr; node;)
        {
            item = Base::convert(node);
            // Advance before invoking the callback so that the link is no longer needed once `f` has the object.
            node = node->m_next;
            f(item);
        }
    }

    //! Pushes a single object and then calls notifyOne().
    //! @param p The object to push; must not be `nullptr`.
    //! @returns `true` if the stack was empty immediately before the push; `false` otherwise.
    bool pushNotify(T* p)
    {
        const bool wasEmpty = Base::_push(p, p);
        Base::_notifyOne();
        return wasEmpty;
    }

    //! Pops an object, blocking in wait() whenever the stack is found empty until a pop succeeds.
    //! @returns The removed object; never `nullptr`.
    T* popWait()
    {
        for (;;)
        {
            if (T* item = pop())
            {
                return item;
            }
            wait();
        }
    }

    //! Like popWaitUntil(), with a deadline of `dur` from now on `std::chrono::steady_clock`.
    //! @param dur The maximum duration to wait.
    //! @returns The removed object, or `nullptr` if none could be popped.
    template <class Rep, class Period>
    T* popWaitFor(const std::chrono::duration<Rep, Period>& dur)
    {
        const auto deadline = std::chrono::steady_clock::now() + dur;
        return popWaitUntil(deadline);
    }

    //! Pops an object, waiting with waitUntil() whenever the stack is found empty.
    //! @param tp The time point to wait until.
    //! @returns The removed object. Once waitUntil() reports `false`, one final pop is attempted and its result
    //!     (possibly `nullptr`) is returned.
    template <class Clock, class Duration>
    T* popWaitUntil(const std::chrono::time_point<Clock, Duration>& tp)
    {
        for (;;)
        {
            if (T* item = pop())
            {
                return item;
            }
            if (!waitUntil(tp))
            {
                // Last chance: something may have been pushed between the failed pop and the wait giving up.
                return pop();
            }
        }
    }

    //! Blocks while the stack is observed to be empty.
    void wait()
    {
        Base::_wait();
    }

    //! Blocks while the stack is observed to be empty, for at most `dur`.
    //! @param dur The maximum duration to wait.
    //! @returns `true` if the stack was observed non-empty; `false` if the wait gave up first.
    template <class Rep, class Period>
    bool waitFor(const std::chrono::duration<Rep, Period>& dur)
    {
        return Base::_waitFor(dur);
    }

    //! Blocks while the stack is observed to be empty, until at most `tp`.
    //! @param tp The time point to wait until.
    //! @returns `true` if the stack was observed non-empty; `false` if the wait gave up first.
    template <class Clock, class Duration>
    bool waitUntil(const std::chrono::time_point<Clock, Duration>& tp)
    {
        return Base::_waitUntil(tp);
    }

    //! Notifies one thread blocked in one of the wait functions.
    void notifyOne()
    {
        Base::_notifyOne();
    }

    //! Notifies all threads blocked in one of the wait functions.
    void notifyAll()
    {
        Base::_notifyAll();
    }
};

} // namespace container
} // namespace carb