carb/thread/Spinlock.h

File members: carb/thread/Spinlock.h

// Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "../Defines.h"

#include <atomic>
#include <thread>

namespace carb
{
namespace thread
{

namespace detail
{

#ifndef DOXYGEN_SHOULD_SKIP_THIS
//! Recursion policy for a non-recursive spinlock: the owning thread may not lock it again.
//!
//! Records which thread currently holds the lock so that recursive locking (in enter()) and unlocking
//! from a non-owning thread (in tryLeave()) are reported via CARB_FATAL_UNLESS.
//!
//! Why m_owner is atomic: SpinlockImpl::lock() and try_lock() call ownsLock() *before* acquiring the
//! lock bit, so m_owner is read by threads that do not hold the lock while the owner writes it in
//! enter()/tryLeave(). A plain member would be a data race. Relaxed ordering is sufficient: a thread's
//! own id is only ever stored and cleared by that same thread, so a thread observes its own id exactly
//! when it really is the owner. Ordering of the data protected by the lock comes from the lock word
//! (acquire on lock, release on unlock), not from m_owner.
class RecursionPolicyDisallow
{
public:
    constexpr RecursionPolicyDisallow() = default;

    //! @returns `true` if the calling thread is recorded as the current owner.
    bool ownsLock() const
    {
        return m_owner.load(std::memory_order_relaxed) == std::this_thread::get_id();
    }

    //! Records the calling thread as the owner.
    //! Raises a fatal error if the calling thread is already the owner (recursion is disallowed).
    void enter()
    {
        const std::thread::id cur = std::this_thread::get_id();
        const std::thread::id owner = m_owner.load(std::memory_order_relaxed);
        CARB_FATAL_UNLESS(cur != owner, "Recursion is not allowed");
        m_owner.store(cur, std::memory_order_relaxed);
    }

    //! Clears the owner. Raises a fatal error if the calling thread is not the owner.
    //! @returns always `true`: with recursion disallowed, every leave fully releases the lock.
    bool tryLeave()
    {
        CARB_FATAL_UNLESS(ownsLock(), "Not owning thread");
        m_owner.store(std::thread::id(), std::memory_order_relaxed); // clear the owner
        return true;
    }

private:
    //! Id of the owning thread; a default-constructed id means "not owned".
    std::atomic<std::thread::id> m_owner{ std::thread::id() };
};

//! Recursion policy for a recursive spinlock: the owning thread may lock it repeatedly.
//!
//! Keeps the owning thread and a recursion count. The lock is released only when tryLeave() has been
//! called as many times as enter(). Unlocking from a non-owning thread is reported via
//! CARB_FATAL_UNLESS.
//!
//! Why m_owner is atomic: SpinlockImpl::lock() and try_lock() call ownsLock() *before* acquiring the
//! lock bit, so m_owner is read by threads that do not hold the lock while the owner writes it. A plain
//! member would be a data race. Relaxed ordering is sufficient because a thread's own id is only ever
//! stored and cleared by that same thread. m_recursion is only touched by the current owner, so it
//! needs no atomicity.
class RecursionPolicyAllow
{
public:
    constexpr RecursionPolicyAllow() = default;

    //! @returns `true` if the calling thread is recorded as the current owner.
    bool ownsLock() const
    {
        return m_owner.load(std::memory_order_relaxed) == std::this_thread::get_id();
    }

    //! Registers one more level of ownership for the calling thread.
    //! If the caller already owns the lock, the recursion count is incremented. Otherwise the caller
    //! becomes the owner with a count of 1; the previous owner is expected to have been cleared.
    void enter()
    {
        const std::thread::id cur = std::this_thread::get_id();
        if (cur == m_owner.load(std::memory_order_relaxed))
        {
            ++m_recursion;
            return;
        }
        CARB_ASSERT(m_owner.load(std::memory_order_relaxed) == std::thread::id()); // owner should be clear
        m_owner.store(cur, std::memory_order_relaxed);
        m_recursion = 1;
    }

    //! Drops one level of ownership. Raises a fatal error if the calling thread is not the owner.
    //! @returns `true` when the recursion count reaches zero and the lock bit should be released.
    bool tryLeave()
    {
        CARB_FATAL_UNLESS(ownsLock(), "Not owning thread");
        if (--m_recursion != 0)
        {
            return false;
        }
        m_owner.store(std::thread::id(), std::memory_order_relaxed); // clear the owner
        return true;
    }

private:
    //! Id of the owning thread; a default-constructed id means "not owned".
    std::atomic<std::thread::id> m_owner{ std::thread::id() };
    //! Number of outstanding enter() calls by the owner; only accessed while holding the lock.
    size_t m_recursion{ 0 };
};
#endif

//! A lock that busy-waits (spins) instead of blocking; recursion behavior is chosen by a policy.
//!
//! @tparam RecursionPolicy Tracks the owning thread. Must provide `ownsLock()`, `enter()` and
//!     `tryLeave()`; see RecursionPolicyAllow and RecursionPolicyDisallow.
//!
//! Provides `lock()`, `unlock()` and `try_lock()`, matching the standard *Lockable* requirements.
template <class RecursionPolicy>
class SpinlockImpl
{
public:
    constexpr SpinlockImpl() = default;

    ~SpinlockImpl() = default;

    CARB_PREVENT_COPY(SpinlockImpl);

    //! Acquires the lock, spinning until it becomes available.
    //! If the calling thread already owns the lock, no spinning happens and the policy decides what to
    //! do in enter(): increment a recursion count, or raise a fatal error.
    void lock()
    {
        if (!m_rp.ownsLock())
        {
            // Test-and-test-and-set: attempt the read-modify-write only when the lock looks free. While
            // it is held, waiters spin on a plain relaxed load. The cache line then stays shared between
            // them instead of being pulled exclusively to each spinning core on every iteration. The
            // acquire ordering still comes from the successful fetch_or.
            while (CARB_UNLIKELY(!!m_lock.fetch_or(1, std::memory_order_acquire)))
            {
                do
                {
                    CARB_HARDWARE_PAUSE();
                } while (m_lock.load(std::memory_order_relaxed) != 0);
            }
        }
        m_rp.enter();
    }

    //! Releases one level of ownership held by the calling thread.
    //! The lock bit is cleared, with release ordering, only when the policy reports that ownership has
    //! fully ended. Both provided policies raise a fatal error if the caller is not the owner.
    void unlock()
    {
        if (m_rp.tryLeave())
        {
            // Released the lock
            m_lock.store(0, std::memory_order_release);
        }
    }

    //! Attempts to acquire the lock without spinning.
    //! @returns `true` if the lock was acquired, or if the caller already owns it and the policy accepted
    //!     the recursion; `false` if another thread holds it.
    bool try_lock()
    {
        if (!m_rp.ownsLock())
        {
            // See if we can set the lock bit
            if (CARB_UNLIKELY(!!m_lock.fetch_or(1, std::memory_order_acquire)))
            {
                // Failed!
                return false;
            }
        }
        m_rp.enter();
        return true;
    }

    //! @returns `true` if the calling thread currently owns this lock.
    bool isLockedByThisThread() const
    {
        return m_rp.ownsLock();
    }

private:
    //! Lock word: 0 when free, 1 when held.
    std::atomic<size_t> m_lock{ 0 };
    //! Owner and recursion tracking.
    RecursionPolicy m_rp;
};

} // namespace detail

//! Non-recursive spinlock: locking it again from the thread that already owns it is a fatal error.
using Spinlock = detail::SpinlockImpl<detail::RecursionPolicyDisallow>;

//! Recursive spinlock: the owning thread may lock it repeatedly; the lock is released once unlock()
//! has been called as many times as lock()/try_lock() succeeded.
using RecursiveSpinlock = detail::SpinlockImpl<detail::RecursionPolicyAllow>;

} // namespace thread
} // namespace carb