// carb/thread/Mutex.h
//
// File members: carb/thread/Mutex.h

// Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "Futex.h"
#include "Util.h"
#include "../cpp/Atomic.h"

#include <system_error>

#if CARB_PLATFORM_WINDOWS
#    include "../CarbWindows.h"
#endif

namespace carb
{
namespace thread
{

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail
{
#    if CARB_PLATFORM_WINDOWS
// Windows implementation backing both carb::thread::mutex and carb::thread::recursive_mutex.
// The `Recursive` template parameter selects at compile time (via tag dispatch in try_recurse())
// whether re-entrant locking by the owning thread is permitted or a fatal error.
template <bool Recursive>
class BaseMutex
{
public:
    //! Compile-time flag indicating whether this mutex allows recursive locking.
    constexpr static bool kRecursive = Recursive;

    CARB_PREVENT_COPY_AND_MOVE(BaseMutex);

    //! Constructor.
    constexpr BaseMutex() noexcept = default;

    //! Destructor. It is a fatal error to destroy the mutex while it is still locked.
    ~BaseMutex()
    {
        CARB_FATAL_UNLESS(m_count == 0, "Mutex destroyed while busy");
    }

    //! Acquires the lock, blocking until it becomes available.
    //! For the recursive variant, a lock attempt by the current owner merely increments the lock
    //! depth; for the non-recursive variant such an attempt is fatal (checked in try_recurse()).
    void lock()
    {
        uint32_t const tid = this_thread::getId();
        if (try_recurse(tid, cpp::bool_constant<kRecursive>()))
            return;

        m_lll.lock();
        // Now inside the lock; safe to record ownership.
        m_owner = tid;
        m_count = 1;
    }
    //! Attempts to acquire the lock without blocking.
    //! @returns `true` if the lock was acquired (or re-entered recursively); `false` otherwise.
    bool try_lock()
    {
        uint32_t const tid = this_thread::getId();
        if (try_recurse(tid, cpp::bool_constant<kRecursive>()))
            return true;

        if (m_lll.try_lock())
        {
            // Now inside the lock; safe to record ownership.
            m_owner = tid;
            m_count = 1;
            return true;
        }
        return false;
    }
    //! Releases one level of the lock. Fatal if the calling thread is not the current owner.
    //! The underlying lock is released only when the depth returns to zero; ownership is cleared
    //! before unlocking so a stale owner id is never visible to the next acquirer.
    void unlock()
    {
        uint32_t tid = this_thread::getId();
        CARB_FATAL_UNLESS(m_owner == tid, "Not owner");
        if (--m_count == 0)
        {
            m_owner = kInvalidOwner;
            m_lll.unlock();
        }
    }
    //! Checks whether the calling thread currently owns this mutex.
    bool is_current_thread_owner() const noexcept
    {
        // We don't need this to be an atomic op because one of the following must be true during this call:
        // - m_owner is equal to this_thread::getId() and cannot change
        // - m_owner is not equal and cannot become equal to this_thread::getId()
        return m_owner == this_thread::getId();
    }

private:
    //! Recursive-case helper: if the caller already owns the lock, bump the depth and succeed.
    bool try_recurse(ThreadId who, std::true_type) // recursive
    {
        if (who == m_owner)
        {
            // Already inside the lock
            ++m_count;
            return true;
        }
        return false;
    }
    //! Non-recursive-case helper: attempted recursion is fatal; otherwise always returns false so
    //! the caller falls through to the normal acquisition path.
    bool try_recurse(ThreadId who, std::false_type) // non-recursive
    {
        CARB_FATAL_UNLESS(who != m_owner, "Attempted recursion on non-recursive mutex");
        return false;
    }

    //! Sentinel thread id meaning "no owner".
    constexpr static ThreadId kInvalidOwner = ThreadId();

    LowLevelLock m_lll; // underlying (non-recursive) low-level lock
    ThreadId m_owner = kInvalidOwner; // id of the owning thread, or kInvalidOwner when unlocked
    int m_count = 0; // recursion depth; 0 means unlocked
};

#    else
template <bool Recursive>
class BaseMutex;

// BaseMutex (non-recursive): POSIX implementation backing carb::thread::mutex.
// The owning thread is tracked with pthread_self() so that attempted recursion and
// unlock-by-non-owner can be reported as fatal errors.
template <>
class BaseMutex<false>
{
public:
    //! Compile-time flag: this specialization does not allow recursive locking.
    constexpr static bool kRecursive = false;

    //! Constructor. (Presumably constexpr only when not building under ThreadSanitizer —
    //! see CARB_IF_NOT_TSAN.)
    CARB_IF_NOT_TSAN(constexpr) BaseMutex() noexcept
    {
    }

    //! Destructor. It is a fatal error to destroy the mutex while it is still locked.
    ~BaseMutex()
    {
        CARB_FATAL_UNLESS(m_owner == kInvalidOwner, "Mutex destroyed while busy");
    }

    //! Acquires the lock, blocking until it becomes available.
    //! Attempting to lock recursively from the owning thread is a fatal error.
    void lock()
    {
        auto self = pthread_self();
        CARB_FATAL_UNLESS(m_owner != self, "Attempted recursive lock on non-recursive mutex");
        m_lll.lock();
        // Now inside the lock
        m_owner = self;
    }

    //! Attempts to acquire the lock without blocking.
    //! @returns `true` if the lock was acquired; `false` if it is held by another thread.
    bool try_lock()
    {
        auto self = pthread_self();
        CARB_FATAL_UNLESS(m_owner != self, "Attempted recursive lock on non-recursive mutex");
        if (m_lll.try_lock())
        {
            // Now inside the lock
            m_owner = self;
            return true;
        }
        return false;
    }

    //! Releases the lock. Fatal if the calling thread is not the owner.
    //! Ownership is cleared before releasing so a stale owner id is never visible to the next
    //! acquirer.
    void unlock()
    {
        CARB_FATAL_UNLESS(m_owner == pthread_self(), "unlock() called by non-owning thread");
        m_owner = kInvalidOwner;
        m_lll.unlock();
    }

    //! Checks whether the calling thread currently owns this mutex.
    bool is_current_thread_owner() const noexcept
    {
        // We don't need this to be an atomic op because one of the following must be true during this call:
        // - m_owner is equal to pthread_self() and cannot change
        // - m_owner is not equal and cannot become equal to pthread_self()
        return m_owner == pthread_self();
    }

private:
    //! Sentinel pthread_t meaning "no owner".
    constexpr static pthread_t kInvalidOwner = pthread_t();

    LowLevelLock m_lll; // underlying low-level lock
    // this_thread::getId() is incredibly slow because it makes a syscall; use pthread_self() instead.
    pthread_t m_owner = kInvalidOwner;
};

// BaseMutex (recursive): POSIX implementation backing carb::thread::recursive_mutex.
// Re-entrant locking by the owner is tracked in m_depth: the underlying LowLevelLock is acquired
// only on the first entry and released only when the depth returns to zero. The __tsan_mutex_*
// calls annotate the recursive entries/exits that never touch the real lock so ThreadSanitizer
// can still model them.
template <>
class BaseMutex<true>
{
public:
    //! Compile-time flag: this specialization allows recursive locking.
    constexpr static bool kRecursive = true;

    //! Constructor. (Presumably constexpr only when not building under ThreadSanitizer —
    //! see CARB_IF_NOT_TSAN.)
    CARB_IF_NOT_TSAN(constexpr) BaseMutex() noexcept = default;

    //! Destructor. It is a fatal error to destroy the mutex while it is still locked.
    ~BaseMutex()
    {
        CARB_FATAL_UNLESS(m_owner == kInvalidOwner, "Mutex destroyed while busy");
    }

    //! Acquires the lock, blocking until it becomes available.
    //! Re-entry by the owning thread only increments m_depth (fatal if the depth would exceed
    //! kMaxDepth); the underlying lock is acquired only on first entry.
    void lock()
    {
        auto self = pthread_self();
        if (self == m_owner)
        {
            CARB_FATAL_UNLESS(m_depth != kMaxDepth, "Recursion overflow");
            // Recursive entry never touches m_lll, so inform tsan directly.
            __tsan_mutex_pre_lock(this, 0);
            ++m_depth;
            __tsan_mutex_post_lock(this, __tsan_mutex_recursive_lock, (int)(ptrdiff_t)m_depth);
            return;
        }
        m_lll.lock();
        // Now inside the lock
        m_owner = self;
        CARB_ASSERT(m_depth == 0);
        m_depth = 1;
    }

    //! Attempts to acquire the lock without blocking.
    //! @returns `true` if the lock was acquired or re-entered recursively; `false` otherwise.
    bool try_lock()
    {
        auto self = pthread_self();
        if (self == m_owner)
        {
            CARB_FATAL_UNLESS(m_depth != kMaxDepth, "Recursion overflow");
            // Recursive entry never touches m_lll, so inform tsan directly.
            __tsan_mutex_pre_lock(this, __tsan_mutex_try_lock);
            ++m_depth;
            __tsan_mutex_post_lock(this, __tsan_mutex_try_lock | __tsan_mutex_recursive_lock, (int)(ptrdiff_t)m_depth);
            return true;
        }

        if (m_lll.try_lock())
        {
            // Now inside the lock
            m_owner = self;
            CARB_ASSERT(m_depth == 0);
            m_depth = 1;
            return true;
        }
        return false;
    }

    //! Releases one level of the lock. Fatal if the calling thread is not the owner.
    //! The underlying lock is released only when the depth returns to zero; ownership is cleared
    //! before releasing so a stale owner id is never visible to the next acquirer.
    void unlock()
    {
        CARB_FATAL_UNLESS(m_owner == pthread_self(), "unlock() called by non-owning thread");
        CARB_ASSERT(m_depth > 0);
        if (--m_depth == 0)
        {
            m_owner = kInvalidOwner;
            m_lll.unlock();
        }
        else
        {
            // Keep tsan updated
            __tsan_mutex_pre_unlock(this, 0);
            __tsan_mutex_post_unlock(this, 0);
        }
    }

    //! Checks whether the calling thread currently owns this mutex.
    bool is_current_thread_owner() const noexcept
    {
        // We don't need this to be an atomic op because one of the following must be true during this call:
        // - m_owner is equal to pthread_self() and cannot change
        // - m_owner is not equal and cannot become equal to pthread_self()
        return m_owner == pthread_self();
    }

private:
    //! Sentinel pthread_t meaning "no owner".
    constexpr static pthread_t kInvalidOwner = pthread_t();
    constexpr static size_t kMaxDepth = size_t(INT_MAX); // For tsan since we can only pass an int for depth

    LowLevelLock m_lll; // underlying (non-recursive) low-level lock
    // this_thread::getId() is incredibly slow because it makes a syscall; use pthread_self() instead.
    pthread_t m_owner = kInvalidOwner;
    size_t m_depth = 0; // recursion depth; 0 means unlocked
};
#    endif
} // namespace detail
#endif

/**
 * A non-recursive mutex exposing the `std::mutex`-style interface (`lock()` / `try_lock()` /
 * `unlock()`), plus an ownership query. Forwards to the platform-specific implementation in
 * carb::thread::detail.
 */
class mutex : protected detail::BaseMutex<false>
{
    using Impl = detail::BaseMutex<false>;

public:
    //! Constructor.
    CARB_IF_NOT_TSAN(constexpr) mutex() noexcept = default;

    //! Destructor.
    ~mutex() = default;

    //! Blocks until the mutex has been acquired by the calling thread.
    void lock()
    {
        this->Impl::lock();
    }

    //! Attempts to acquire the mutex without blocking.
    //! @returns `true` if the mutex was acquired; `false` otherwise.
    bool try_lock()
    {
        return this->Impl::try_lock();
    }

    //! Releases the mutex.
    void unlock()
    {
        this->Impl::unlock();
    }

    //! @returns `true` if the calling thread owns this mutex; `false` otherwise.
    bool is_current_thread_owner() const noexcept
    {
        return this->Impl::is_current_thread_owner();
    }
};

/**
 * A recursive mutex exposing the `std::recursive_mutex`-style interface (`lock()` /
 * `try_lock()` / `unlock()`), plus an ownership query. Forwards to the platform-specific
 * implementation in carb::thread::detail.
 */
class recursive_mutex : protected detail::BaseMutex<true>
{
    using Impl = detail::BaseMutex<true>;

public:
    //! Constructor.
    CARB_IF_NOT_TSAN(constexpr) recursive_mutex() noexcept = default;

    //! Destructor.
    ~recursive_mutex() = default;

    //! Blocks until the mutex has been acquired by the calling thread; may be called again by the
    //! owning thread to increase the lock depth.
    void lock()
    {
        this->Impl::lock();
    }

    //! Attempts to acquire (or re-enter) the mutex without blocking.
    //! @returns `true` if the mutex was acquired or re-entered; `false` otherwise.
    bool try_lock()
    {
        return this->Impl::try_lock();
    }

    //! Releases one level of the mutex.
    void unlock()
    {
        this->Impl::unlock();
    }

    //! @returns `true` if the calling thread owns this mutex; `false` otherwise.
    bool is_current_thread_owner() const noexcept
    {
        return this->Impl::is_current_thread_owner();
    }
};

} // namespace thread
} // namespace carb