carb/thread/Mutex.h


// Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "Futex.h"
#include "Util.h"
#include "../cpp/Atomic.h"

#include <system_error>

#if CARB_PLATFORM_WINDOWS
#    include "../CarbWindows.h"
#endif

namespace carb
{
namespace thread
{

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail
{
#    if CARB_PLATFORM_WINDOWS
template <bool Recursive>
class BaseMutex
{
public:
    constexpr static bool kRecursive = Recursive;

    CARB_PREVENT_COPY_AND_MOVE(BaseMutex);

    constexpr BaseMutex() noexcept = default;
    ~BaseMutex()
    {
        CARB_FATAL_UNLESS(m_count == 0, "Mutex destroyed while busy");
    }

    void lock()
    {
        uint32_t const tid = this_thread::getId();
        if (!Recursive)
        {
            CARB_FATAL_UNLESS(tid != m_owner, "Recursion not allowed");
        }
        else if (tid == m_owner)
        {
            ++m_count;
            return;
        }
        AcquireSRWLockExclusive((PSRWLOCK)&m_lock);
        m_owner = tid;
        m_count = 1;
    }
    bool try_lock()
    {
        uint32_t const tid = this_thread::getId();
        if (!Recursive)
        {
            CARB_FATAL_UNLESS(tid != m_owner, "Recursion not allowed");
        }
        else if (tid == m_owner)
        {
            ++m_count;
            return true;
        }
        if (CARB_LIKELY(TryAcquireSRWLockExclusive((PSRWLOCK)&m_lock)))
        {
            m_owner = tid;
            m_count = 1;
            return true;
        }
        return false;
    }
    void unlock()
    {
        uint32_t tid = this_thread::getId();
        CARB_FATAL_UNLESS(m_owner == tid, "Not owner");
        if (--m_count == 0)
        {
            m_owner = kInvalidOwner;
            ReleaseSRWLockExclusive((PSRWLOCK)&m_lock);
        }
    }
    bool is_current_thread_owner() const noexcept
    {
        // We don't need this to be an atomic op because one of the following must be true during this call:
        // - m_owner is equal to this_thread::getId() and cannot change
        // - m_owner is not equal and cannot become equal to this_thread::getId()
        return m_owner == this_thread::getId();
    }

private:
    constexpr static uint32_t kInvalidOwner = 0;

    CARBWIN_SRWLOCK m_lock{ CARBWIN_SRWLOCK_INIT };
    uint32_t m_owner{ kInvalidOwner };
    int m_count{ 0 };
};

#    else
template <bool Recursive>
class BaseMutex;

// BaseMutex (non-recursive)
template <>
class BaseMutex<false>
{
public:
    constexpr static bool kRecursive = false;

    constexpr BaseMutex() noexcept = default;
    ~BaseMutex()
    {
        CARB_FATAL_UNLESS(m_lock.load(std::memory_order_relaxed) == Unlocked, "Mutex destroyed while busy");
    }

    void lock()
    {
        // Blindly attempt to lock
        LockState val = Unlocked;
        if (CARB_UNLIKELY(
                !m_lock.compare_exchange_strong(val, Locked, std::memory_order_acquire, std::memory_order_relaxed)))
        {
            CARB_FATAL_UNLESS(m_owner != this_thread::getId(), "Recursive locking not allowed");

            // Failed to lock and need to wait
            if (val == LockedMaybeWaiting)
            {
                m_lock.wait(LockedMaybeWaiting, std::memory_order_relaxed);
            }

            while (m_lock.exchange(LockedMaybeWaiting, std::memory_order_acquire) != Unlocked)
            {
                m_lock.wait(LockedMaybeWaiting, std::memory_order_relaxed);
            }

            CARB_ASSERT(m_owner == kInvalidOwner);
        }
        // Now inside the lock
        m_owner = this_thread::getId();
    }

    bool try_lock()
    {
        // Blindly attempt to lock
        LockState val = Unlocked;
        if (CARB_LIKELY(m_lock.compare_exchange_strong(val, Locked, std::memory_order_acquire, std::memory_order_relaxed)))
        {
            m_owner = this_thread::getId();
            return true;
        }
        CARB_FATAL_UNLESS(m_owner != this_thread::getId(), "Recursive locking not allowed");
        return false;
    }

    void unlock()
    {
        CARB_FATAL_UNLESS(is_current_thread_owner(), "Not owner");
        m_owner = kInvalidOwner;
        LockState val = m_lock.exchange(Unlocked, std::memory_order_release);
        if (val == LockedMaybeWaiting)
        {
            m_lock.notify_one();
        }
    }

    bool is_current_thread_owner() const noexcept
    {
        // We don't need this to be an atomic op because one of the following must be true during this call:
        // - m_owner is equal to this_thread::getId() and cannot change
        // - m_owner is not equal and cannot become equal to this_thread::getId()
        return m_owner == this_thread::getId();
    }

private:
    // Futex-style lock word: an uncontended acquire moves Unlocked -> Locked; a contended acquire
    // upgrades the state to LockedMaybeWaiting so that unlock() knows it must notify a waiter.
    enum LockState : uint8_t
    {
        Unlocked = 0, // no owner
        Locked = 1, // owned, no waiters observed
        LockedMaybeWaiting = 2, // owned, and at least one thread may be waiting on the futex
    };

    constexpr static uint32_t kInvalidOwner = 0;

    cpp::atomic<LockState> m_lock{ Unlocked };
    uint32_t m_owner{ kInvalidOwner };
};

// BaseMutex (recursive)
template <>
class BaseMutex<true>
{
public:
    constexpr static bool kRecursive = true;

    constexpr BaseMutex() noexcept = default;
    ~BaseMutex()
    {
        CARB_FATAL_UNLESS(m_lock.load(std::memory_order_relaxed) == 0, "Mutex destroyed while busy");
    }

    void lock()
    {
        // Blindly attempt to lock
        uint32_t val = Unlocked;
        if (CARB_UNLIKELY(
                !m_lock.compare_exchange_strong(val, Locked, std::memory_order_acquire, std::memory_order_relaxed)))
        {
            // Failed to lock (or recursive)
            if (m_owner == this_thread::getId())
            {
                val = m_lock.fetch_add(DepthUnit, std::memory_order_relaxed);
                CARB_FATAL_UNLESS((val & DepthMask) != DepthMask, "Recursion overflow");
                return;
            }

            // Failed to lock and need to wait
            if ((val & ~DepthMask) == LockedMaybeWaiting)
            {
                m_lock.wait(val, std::memory_order_relaxed);
            }

            for (;;)
            {
                // Atomically set to LockedMaybeWaiting in a loop since the owning thread could be changing the depth
                while (!m_lock.compare_exchange_weak(
                    val, (val & DepthMask) | LockedMaybeWaiting, std::memory_order_acquire, std::memory_order_relaxed))
                    CARB_HARDWARE_PAUSE();
                if ((val & ~DepthMask) == Unlocked)
                    break;
                m_lock.wait((val & DepthMask) | LockedMaybeWaiting, std::memory_order_relaxed);
            }

            CARB_ASSERT(m_owner == kInvalidOwner);
        }
        // Now inside the lock
        m_owner = this_thread::getId();
    }

    bool try_lock()
    {
        // Blindly attempt to lock
        uint32_t val = Unlocked;
        if (CARB_LIKELY(m_lock.compare_exchange_strong(val, Locked, std::memory_order_acquire, std::memory_order_relaxed)))
        {
            // Succeeded, we now own the lock
            m_owner = this_thread::getId();
            return true;
        }
        // Failed (or recursive)
        if (m_owner == this_thread::getId())
        {
            // Recursive, increment the depth
            val = m_lock.fetch_add(DepthUnit, std::memory_order_acquire);
            CARB_FATAL_UNLESS((val & DepthMask) != DepthMask, "Recursion overflow");
            return true;
        }
        return false;
    }

    void unlock()
    {
        CARB_FATAL_UNLESS(is_current_thread_owner(), "Not owner");
        uint32_t val = m_lock.load(std::memory_order_relaxed);
        if (!(val & DepthMask))
        {
            // Depth count is at zero, so this is the last unlock().
            m_owner = kInvalidOwner;
            uint32_t prev = m_lock.exchange(Unlocked, std::memory_order_release);
            if (prev == LockedMaybeWaiting)
            {
                m_lock.notify_one();
            }
        }
        else
            m_lock.fetch_sub(DepthUnit, std::memory_order_release);
    }

    bool is_current_thread_owner() const noexcept
    {
        // We don't need this to be an atomic op because one of the following must be true during this call:
        // - m_owner is equal to this_thread::getId() and cannot change
        // - m_owner is not equal and cannot become equal to this_thread::getId()
        return m_owner == this_thread::getId();
    }

private:
    enum LockState : uint32_t
    {
        Unlocked = 0, // no owner
        Locked = 1, // owned, no waiters observed
        LockedMaybeWaiting = 2, // owned, and at least one thread may be waiting on the futex

        DepthUnit = 1 << 2, // Each recursion count increment
        DepthMask = 0xFFFFFFFC // The 30 MSBs are used for the recursion count
    };

    constexpr static uint32_t kInvalidOwner = 0;

    cpp::atomic_uint32_t m_lock{ Unlocked };
    uint32_t m_owner{ kInvalidOwner };
};
#    endif
} // namespace detail
#endif

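/**
 * A non-recursive mutex with the same interface as std::mutex.
 *
 * Attempting to lock the mutex again from the thread that already owns it terminates with a fatal
 * error instead of deadlocking, and unlocking from a non-owning thread is likewise fatal. On
 * Windows the lock wraps an SRWLOCK; on other platforms it is a futex-based lock. Ownership can be
 * queried with is_current_thread_owner().
 */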
class mutex : public detail::BaseMutex<false>
{
    using Base = detail::BaseMutex<false>;

public:
    constexpr mutex() noexcept = default;

    ~mutex() = default;

    void lock()
    {
        Base::lock();
    }

    bool try_lock()
    {
        return Base::try_lock();
    }

    void unlock()
    {
        Base::unlock();
    }

    bool is_current_thread_owner() const noexcept
    {
        return Base::is_current_thread_owner();
    }
};

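/**
 * A recursive mutex with the same interface as std::recursive_mutex.
 *
 * The owning thread may call lock() or try_lock() repeatedly; each successful lock must be
 * balanced by a call to unlock(), and only the final unlock() releases the mutex. Unlocking from
 * a non-owning thread terminates with a fatal error. Ownership can be queried with
 * is_current_thread_owner().
 */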
class recursive_mutex : public detail::BaseMutex<true>
{
    using Base = detail::BaseMutex<true>;

public:
    constexpr recursive_mutex() noexcept = default;

    ~recursive_mutex() = default;

    void lock()
    {
        Base::lock();
    }

    bool try_lock()
    {
        return Base::try_lock();
    }

    void unlock()
    {
        Base::unlock();
    }

    bool is_current_thread_owner() const noexcept
    {
        return Base::is_current_thread_owner();
    }
};

} // namespace thread
} // namespace carb
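
Example usage

The classes above are interface-compatible with std::mutex and std::recursive_mutex, so they work with the standard lock helpers. The sketch below is illustrative and not part of the header; the include path and all names used are assumptions.

#include <carb/thread/Mutex.h>

#include <cassert>
#include <mutex> // std::lock_guard

namespace
{

carb::thread::mutex g_mutex; // non-recursive: re-locking on the owning thread is a fatal error
carb::thread::recursive_mutex g_recursiveMutex; // may be re-locked by the owning thread

int g_counter = 0;

void increment()
{
    std::lock_guard<carb::thread::mutex> guard(g_mutex);
    assert(g_mutex.is_current_thread_owner()); // ownership query not offered by std::mutex
    ++g_counter;
}

void addTwice()
{
    std::lock_guard<carb::thread::recursive_mutex> outer(g_recursiveMutex);
    {
        // Re-locking on the same thread increments the recursion depth instead of deadlocking.
        std::lock_guard<carb::thread::recursive_mutex> inner(g_recursiveMutex);
        g_counter += 2;
    } // depth drops back here; the mutex is released when `outer` is destroyed
}

} // namespace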