Mutex.h

Fully qualified name: carb/thread/Mutex.h

File members: carb/thread/Mutex.h

// SPDX-FileCopyrightText: Copyright (c) 2020-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: LicenseRef-NvidiaProprietary
//
// NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
// property and proprietary rights in and to this material, related
// documentation and any modifications thereto. Any use, reproduction,
// disclosure or distribution of this material and related documentation
// without an express license agreement from NVIDIA CORPORATION or
// its affiliates is strictly prohibited.

#pragma once

#include "Futex.h"
#include "Util.h"

#include <type_traits>

#if CARB_PLATFORM_WINDOWS
#    include "../CarbWindows.h"
#endif

namespace carb
{
namespace thread
{

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail
{
#    if CARB_PLATFORM_WINDOWS
// Windows implementation of the machinery shared by carb::thread::mutex and
// carb::thread::recursive_mutex. One class template covers both flavors; the
// behavioral difference is isolated in the tag-dispatched try_recurse()
// overloads below. Mutual exclusion itself is provided by the LowLevelLock
// member; this class layers ownership tracking and recursion counting on top.
template <bool Recursive>
class BaseMutex
{
public:
    // True for the recursive flavor; selects which try_recurse() overload
    // lock()/try_lock() dispatch to.
    constexpr static bool kRecursive = Recursive;

    CARB_PREVENT_COPY_AND_MOVE(BaseMutex);

    constexpr BaseMutex() noexcept = default;
    // Fatal if any thread still holds the lock (m_count != 0) at destruction.
    ~BaseMutex()
    {
        CARB_FATAL_UNLESS(m_count == 0, "Mutex destroyed while busy");
    }

    // Acquires the lock, blocking until available. For the recursive flavor, a
    // thread that already owns the lock just bumps the depth count; for the
    // non-recursive flavor, re-acquisition by the owner is a fatal error
    // (raised inside try_recurse()).
    void lock()
    {
        uint32_t const tid = this_thread::getId();
        if (try_recurse(tid, std::bool_constant<kRecursive>()))
            return;

        m_lll.lock();
        // Now inside the lock: safe to publish ownership.
        m_owner = tid;
        m_count = 1;
    }
    // Non-blocking acquire. Returns true if the lock was obtained (or, for the
    // recursive flavor, if the calling thread already owned it).
    bool try_lock()
    {
        uint32_t const tid = this_thread::getId();
        if (try_recurse(tid, std::bool_constant<kRecursive>()))
            return true;

        if (m_lll.try_lock())
        {
            m_owner = tid;
            m_count = 1;
            return true;
        }
        return false;
    }
    // Releases one level of the lock. Fatal if the calling thread is not the
    // owner. The underlying lock is released only when the depth reaches zero;
    // note that m_owner is cleared *before* m_lll.unlock() so a new owner
    // never observes a stale owner id while holding the lock.
    void unlock()
    {
        uint32_t tid = this_thread::getId();
        CARB_FATAL_UNLESS(m_owner == tid, "Not owner");
        if (--m_count == 0)
        {
            m_owner = kInvalidOwner;
            m_lll.unlock();
        }
    }
    bool is_current_thread_owner() const noexcept
    {
        // We don't need this to be an atomic op because one of the following must be true during this call:
        // - m_owner is equal to this_thread::getId() and cannot change
        // - m_owner is not equal and cannot become equal to this_thread::getId()
        return m_owner == this_thread::getId();
    }

private:
    // Recursive flavor: if the caller already owns the lock, increment the
    // depth and report success; otherwise fall through to a real acquire.
    bool try_recurse(ThreadId who, std::true_type) // recursive
    {
        if (who == m_owner)
        {
            // Already inside the lock
            ++m_count;
            return true;
        }
        return false;
    }
    // Non-recursive flavor: re-entry by the current owner is a hard error;
    // otherwise always report "not recursed" so the caller does a real acquire.
    bool try_recurse(ThreadId who, std::false_type) // non-recursive
    {
        CARB_FATAL_UNLESS(who != m_owner, "Attempted recursion on non-recursive mutex");
        return false;
    }

    // Sentinel "no owner" value; a default-constructed ThreadId is presumably
    // never a valid thread id — TODO confirm against this_thread::getId().
    constexpr static ThreadId kInvalidOwner = ThreadId{};

    LowLevelLock m_lll; // the lock that actually provides mutual exclusion
    ThreadId m_owner{ kInvalidOwner }; // current owning thread, or kInvalidOwner
    int m_count{ 0 }; // recursion depth; 0 means unlocked
};

#    else
// Primary template (pthread platforms); intentionally undefined — only the
// explicit <false> (plain) and <true> (recursive) specializations below exist.
template <bool Recursive>
class BaseMutex;

// BaseMutex (non-recursive)
// pthread-platform implementation for carb::thread::mutex. Mutual exclusion is
// provided by the LowLevelLock; this class adds owner tracking purely so that
// erroneous recursion and non-owner unlock can be diagnosed as fatal errors.
template <>
class BaseMutex<false>
{
public:
    constexpr static bool kRecursive = false;

    // Not constexpr under TSan, presumably because the owner field becomes a
    // std::atomic in that configuration (see AtomicIfTSan below).
    CARB_IF_NOT_TSAN(constexpr) BaseMutex() noexcept
    {
    }

    // Fatal if a thread still owns the mutex at destruction.
    ~BaseMutex()
    {
        CARB_FATAL_UNLESS(m_owner == kInvalidOwner, "Mutex destroyed while busy");
    }

    // Blocking acquire. Fatal if the calling thread already owns the mutex
    // (this flavor never allows recursion).
    void lock()
    {
        auto self = pthread_self();
        CARB_FATAL_UNLESS(m_owner != self, "Attempted recursive lock on non-recursive mutex");
        m_lll.lock();
        // Now inside the lock
        m_owner = self;
    }

    // Non-blocking acquire; returns true on success. Fatal if the calling
    // thread already owns the mutex.
    bool try_lock()
    {
        auto self = pthread_self();
        CARB_FATAL_UNLESS(m_owner != self, "Attempted recursive lock on non-recursive mutex");
        if (m_lll.try_lock())
        {
            // Now inside the lock
            m_owner = self;
            return true;
        }
        return false;
    }

    // Releases the mutex. Fatal if called by a thread that doesn't own it.
    // The owner is cleared *before* the low-level unlock so a new owner never
    // sees a stale owner id while holding the lock.
    void unlock()
    {
        CARB_FATAL_UNLESS(m_owner == pthread_self(), "unlock() called by non-owning thread");
        m_owner = kInvalidOwner;
        m_lll.unlock();
    }

    bool is_current_thread_owner() const noexcept
    {
        // We don't need this to be an atomic op because one of the following must be true during this call:
        // - m_owner is equal to pthread_self() and cannot change
        // - m_owner is not equal and cannot become equal to pthread_self()
        return m_owner == pthread_self();
    }

private:
    // Sentinel "no owner" value; a value-initialized pthread_t is assumed to
    // never collide with a real thread handle — TODO confirm per-platform.
    constexpr static pthread_t kInvalidOwner = pthread_t();

    LowLevelLock m_lll;
    // this_thread::getId() is incredibly slow because it makes a syscall; use pthread_self() instead.
    // TSan complains about a data race here, but see is_current_thread_owner() for why it's not actually an issue.
    carb::AtomicIfTSan<pthread_t> m_owner = kInvalidOwner;
};

// BaseMutex (recursive)
// pthread-platform implementation for carb::thread::recursive_mutex. The
// LowLevelLock is taken once per thread; nested acquisitions by the owning
// thread only adjust m_depth. The __tsan_mutex_* calls keep ThreadSanitizer's
// model of the lock depth in sync on the recursion-only paths where the
// low-level lock itself is not touched (presumably they compile away when
// TSan is disabled — see CARB_IF_NOT_TSAN usage).
template <>
class BaseMutex<true>
{
public:
    constexpr static bool kRecursive = true;

    // Not constexpr under TSan, presumably because the owner field becomes a
    // std::atomic in that configuration (see AtomicIfTSan below).
    CARB_IF_NOT_TSAN(constexpr) BaseMutex() noexcept = default;

    // Fatal if a thread still owns the mutex at destruction.
    ~BaseMutex()
    {
        CARB_FATAL_UNLESS(m_owner == kInvalidOwner, "Mutex destroyed while busy");
    }

    // Blocking acquire. If the calling thread already owns the mutex, only the
    // recursion depth is incremented (fatal on depth overflow); otherwise the
    // low-level lock is taken and ownership is claimed at depth 1.
    void lock()
    {
        auto self = pthread_self();
        if (self == m_owner)
        {
            CARB_FATAL_UNLESS(m_depth != kMaxDepth, "Recursion overflow");
            ++m_depth;
            // Keep tsan updated wrt recursion
            __tsan_mutex_pre_lock(this, 0);
            __tsan_mutex_post_lock(this, 0, 0);
            return;
        }
        m_lll.lock();
        // Now inside the lock
        m_owner = self;
        CARB_ASSERT(m_depth == 0);
        m_depth = 1;
    }

    // Non-blocking acquire; returns true on success. Recursion by the owner
    // always succeeds (depth permitting); otherwise success depends on the
    // low-level try_lock.
    bool try_lock()
    {
        auto self = pthread_self();
        if (self == m_owner)
        {
            CARB_FATAL_UNLESS(m_depth != kMaxDepth, "Recursion overflow");
            ++m_depth;
            // Keep tsan updated wrt recursion
            __tsan_mutex_pre_lock(this, 0);
            __tsan_mutex_post_lock(this, 0, 0);
            return true;
        }

        if (m_lll.try_lock())
        {
            // Now inside the lock
            m_owner = self;
            CARB_ASSERT(m_depth == 0);
            m_depth = 1;
            return true;
        }
        return false;
    }

    // Releases one level. Fatal if called by a non-owner. Only the final
    // unlock (depth reaching 0) clears ownership and releases the low-level
    // lock; owner is cleared *before* m_lll.unlock() so a new owner never sees
    // a stale owner id while holding the lock.
    void unlock()
    {
        CARB_FATAL_UNLESS(m_owner == pthread_self(), "unlock() called by non-owning thread");
        CARB_ASSERT(m_depth > 0);
        if (--m_depth == 0)
        {
            m_owner = kInvalidOwner;
            m_lll.unlock();
        }
        else
        {
            // Keep tsan updated wrt recursion
            __tsan_mutex_pre_unlock(this, 0);
            __tsan_mutex_post_unlock(this, 0);
        }
    }

    bool is_current_thread_owner() const noexcept
    {
        // We don't need this to be an atomic op because one of the following must be true during this call:
        // - m_owner is equal to pthread_self() and cannot change
        // - m_owner is not equal and cannot become equal to pthread_self()
        return m_owner == pthread_self();
    }

private:
    // Sentinel "no owner" value; a value-initialized pthread_t is assumed to
    // never collide with a real thread handle — TODO confirm per-platform.
    constexpr static pthread_t kInvalidOwner = pthread_t();
    constexpr static size_t kMaxDepth = size_t(INT_MAX); // For tsan since we can only pass an int for depth

    LowLevelLock m_lll;
    // this_thread::getId() is incredibly slow because it makes a syscall; use pthread_self() instead.
    // TSan complains about a data race here, but see is_current_thread_owner for why it's not an issue.
    carb::AtomicIfTSan<pthread_t> m_owner = kInvalidOwner;
    size_t m_depth = 0; // recursion depth; valid only while m_owner is set
};
#    endif
} // namespace detail
#endif

/// A non-recursive mutex with an interface matching `std::mutex`, implemented
/// on top of the platform-specific carb::thread::detail::BaseMutex machinery.
/// Recursive acquisition and unlocking from a non-owning thread trip
/// `CARB_FATAL_UNLESS` in the underlying implementation rather than producing
/// undefined behavior.
class mutex : protected detail::BaseMutex<false>
{
    using BaseType = detail::BaseMutex<false>;

public:
    /// Constructs the mutex in the unlocked state.
    /// @note Not `constexpr` when building under ThreadSanitizer.
    CARB_IF_NOT_TSAN(constexpr) mutex() noexcept = default;

    /// Destructor. Destroying a mutex that is still held is a fatal error.
    ~mutex() = default;

    /// Blocks until the calling thread has acquired the mutex.
    /// Acquiring recursively is a fatal error.
    void lock()
    {
        BaseType::lock();
    }

    /// Attempts to acquire the mutex without blocking.
    /// Attempting to acquire recursively is a fatal error.
    /// @returns `true` if the mutex was acquired; `false` otherwise.
    bool try_lock()
    {
        return BaseType::try_lock();
    }

    /// Releases the mutex. Calling from a non-owning thread is a fatal error.
    void unlock()
    {
        BaseType::unlock();
    }

    /// Reports whether the calling thread currently owns this mutex.
    /// @returns `true` if owned by the calling thread; `false` otherwise.
    bool is_current_thread_owner() const noexcept
    {
        return BaseType::is_current_thread_owner();
    }
};

/// A recursive mutex with an interface matching `std::recursive_mutex`,
/// implemented on top of the platform-specific
/// carb::thread::detail::BaseMutex machinery. The owning thread may lock
/// repeatedly; the mutex is released when unlock() has been called once per
/// successful lock()/try_lock(). Unlocking from a non-owning thread trips
/// `CARB_FATAL_UNLESS` in the underlying implementation.
class recursive_mutex : protected detail::BaseMutex<true>
{
    using BaseType = detail::BaseMutex<true>;

public:
    /// Constructs the mutex in the unlocked state.
    /// @note Not `constexpr` when building under ThreadSanitizer.
    CARB_IF_NOT_TSAN(constexpr) recursive_mutex() noexcept = default;

    /// Destructor. Destroying a mutex that is still held is a fatal error.
    ~recursive_mutex() = default;

    /// Blocks until the calling thread has acquired the mutex; if the calling
    /// thread already owns it, the recursion depth is increased instead.
    void lock()
    {
        BaseType::lock();
    }

    /// Attempts to acquire the mutex without blocking; always succeeds if the
    /// calling thread already owns it.
    /// @returns `true` if the mutex was acquired; `false` otherwise.
    bool try_lock()
    {
        return BaseType::try_lock();
    }

    /// Releases one level of the mutex; the mutex becomes available to other
    /// threads once released as many times as it was acquired.
    /// Calling from a non-owning thread is a fatal error.
    void unlock()
    {
        BaseType::unlock();
    }

    /// Reports whether the calling thread currently owns this mutex.
    /// @returns `true` if owned by the calling thread; `false` otherwise.
    bool is_current_thread_owner() const noexcept
    {
        return BaseType::is_current_thread_owner();
    }
};

} // namespace thread
} // namespace carb