carb/thread/Mutex.h
// Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once
#include "Futex.h"
#include "Util.h"
#include "../cpp20/Atomic.h"
#include <system_error>
#if CARB_PLATFORM_WINDOWS
# include "../CarbWindows.h"
#endif
namespace carb
{
namespace thread
{
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail
{
# if CARB_PLATFORM_WINDOWS
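// On Windows, BaseMutex wraps a slim reader/writer lock (SRWLOCK) held exclusively, and
// tracks the owning thread id plus a depth count so that invalid recursion (in the
// non-recursive case) and non-owner unlocks are reported as fatal errors rather than
// deadlocking or invoking undefined behavior.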
template <bool Recursive>
class BaseMutex
{
public:
constexpr static bool kRecursive = Recursive;
CARB_PREVENT_COPY_AND_MOVE(BaseMutex);
constexpr BaseMutex() noexcept = default;
~BaseMutex()
{
CARB_FATAL_UNLESS(m_count == 0, "Mutex destroyed while busy");
}
void lock()
{
uint32_t const tid = this_thread::getId();
if (!Recursive)
{
CARB_FATAL_UNLESS(tid != m_owner, "Recursion not allowed");
}
else if (tid == m_owner)
{
// Already owned by this thread; just increase the depth count.
++m_count;
return;
}
AcquireSRWLockExclusive((PSRWLOCK)&m_lock);
m_owner = tid;
m_count = 1;
}
bool try_lock()
{
uint32_t const tid = this_thread::getId();
if (!Recursive)
{
CARB_FATAL_UNLESS(tid != m_owner, "Recursion not allowed");
}
else if (tid == m_owner)
{
++m_count;
return true;
}
if (CARB_LIKELY(TryAcquireSRWLockExclusive((PSRWLOCK)&m_lock)))
{
m_owner = tid;
m_count = 1;
return true;
}
return false;
}
void unlock()
{
uint32_t const tid = this_thread::getId();
CARB_FATAL_UNLESS(m_owner == tid, "Not owner");
if (--m_count == 0)
{
// Depth count reached zero: release the underlying SRW lock.
m_owner = kInvalidOwner;
ReleaseSRWLockExclusive((PSRWLOCK)&m_lock);
}
}
private:
constexpr static uint32_t kInvalidOwner = uint32_t(0);
CARBWIN_SRWLOCK m_lock{ CARBWIN_SRWLOCK_INIT };
uint32_t m_owner{ kInvalidOwner };
int m_count{ 0 };
};
# else
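// On other platforms, BaseMutex is a futex-style lock built directly on cpp20::atomic:
// the lock word holds the state (Unlocked, Locked, or LockedMaybeWaiting, plus a recursion
// depth for the recursive variant), and waiting threads block via atomic wait()/notify_one().
// The LockedMaybeWaiting state records that a waiter may exist, so unlock() only pays for a
// notify when contention was actually observed.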
template <bool Recursive>
class BaseMutex;
// BaseMutex (non-recursive)
template <>
class BaseMutex<false>
{
public:
constexpr static bool kRecursive = false;
CARB_PREVENT_COPY_AND_MOVE(BaseMutex);
constexpr BaseMutex() noexcept = default;
~BaseMutex()
{
CARB_FATAL_UNLESS(m_lock.load(std::memory_order_relaxed) == Unlocked, "Mutex destroyed while busy");
}
void lock()
{
// Blindly attempt to lock
LockState val = Unlocked;
if (CARB_UNLIKELY(
!m_lock.compare_exchange_strong(val, Locked, std::memory_order_acquire, std::memory_order_relaxed)))
{
CARB_FATAL_UNLESS(m_owner != this_thread::getId(), "Recursive locking not allowed");
// Failed to lock and need to wait. If the lock already appears contended, wait first.
if (val == LockedMaybeWaiting)
{
m_lock.wait(LockedMaybeWaiting, std::memory_order_relaxed);
}
// Mark the lock as contended; we acquire it only when the previous state was Unlocked.
while (m_lock.exchange(LockedMaybeWaiting, std::memory_order_acquire) != Unlocked)
{
m_lock.wait(LockedMaybeWaiting, std::memory_order_relaxed);
}
CARB_ASSERT(m_owner == kInvalidOwner);
}
// Now inside the lock
m_owner = this_thread::getId();
}
bool try_lock()
{
// Blindly attempt to lock
LockState val = Unlocked;
if (CARB_LIKELY(m_lock.compare_exchange_strong(val, Locked, std::memory_order_acquire, std::memory_order_relaxed)))
{
m_owner = this_thread::getId();
return true;
}
CARB_FATAL_UNLESS(m_owner != this_thread::getId(), "Recursive locking not allowed");
return false;
}
void unlock()
{
CARB_FATAL_UNLESS(m_owner == this_thread::getId(), "Not owner");
m_owner = kInvalidOwner;
LockState val = m_lock.exchange(Unlocked, std::memory_order_release);
// Only issue a wake-up if a waiter may exist.
if (val == LockedMaybeWaiting)
{
m_lock.notify_one();
}
}
private:
enum LockState : uint8_t
{
Unlocked = 0,
Locked = 1, // Locked with no waiters observed
LockedMaybeWaiting = 2, // Locked and at least one other thread may be waiting
};
constexpr static uint32_t kInvalidOwner = 0;
cpp20::atomic<LockState> m_lock{ Unlocked };
uint32_t m_owner{ kInvalidOwner };
};
// BaseMutex (recursive)
template <>
class BaseMutex<true>
{
public:
constexpr static bool kRecursive = true;
CARB_PREVENT_COPY_AND_MOVE(BaseMutex);
constexpr BaseMutex() noexcept = default;
~BaseMutex()
{
CARB_FATAL_UNLESS(m_lock.load(std::memory_order_relaxed) == 0, "Mutex destroyed while busy");
}
void lock()
{
// Blindly attempt to lock
uint32_t val = Unlocked;
if (CARB_UNLIKELY(
!m_lock.compare_exchange_strong(val, Locked, std::memory_order_acquire, std::memory_order_relaxed)))
{
// Failed to lock (or recursive)
if (m_owner == this_thread::getId())
{
val = m_lock.fetch_add(DepthUnit, std::memory_order_relaxed);
CARB_FATAL_UNLESS((val & DepthMask) != DepthMask, "Recursion overflow");
return;
}
// Failed to lock and need to wait
if ((val & ~DepthMask) == LockedMaybeWaiting)
{
m_lock.wait(val, std::memory_order_relaxed);
}
for (;;)
{
// Atomically set to LockedMaybeWaiting in a loop since the owning thread could be changing the depth
while (!m_lock.compare_exchange_weak(
val, (val & DepthMask) | LockedMaybeWaiting, std::memory_order_acquire, std::memory_order_relaxed))
CARB_HARDWARE_PAUSE();
if ((val & ~DepthMask) == Unlocked)
break;
m_lock.wait((val & DepthMask) | LockedMaybeWaiting, std::memory_order_relaxed);
}
CARB_ASSERT(m_owner == kInvalidOwner);
}
// Now inside the lock
m_owner = this_thread::getId();
}
bool try_lock()
{
// Blindly attempt to lock
uint32_t val = Unlocked;
if (CARB_LIKELY(m_lock.compare_exchange_strong(val, Locked, std::memory_order_acquire, std::memory_order_relaxed)))
{
// Succeeded, we now own the lock
m_owner = this_thread::getId();
return true;
}
// Failed (or recursive)
if (m_owner == this_thread::getId())
{
// Recursive, increment the depth
val = m_lock.fetch_add(DepthUnit, std::memory_order_acquire);
CARB_FATAL_UNLESS((val & DepthMask) != DepthMask, "Recursion overflow");
return true;
}
return false;
}
void unlock()
{
CARB_FATAL_UNLESS(m_owner == this_thread::getId(), "Not owner");
uint32_t val = m_lock.load(std::memory_order_relaxed);
if (!(val & DepthMask))
{
// Depth count is at zero, so this is the last unlock().
m_owner = kInvalidOwner;
val = m_lock.exchange(Unlocked, std::memory_order_release);
// Only issue a wake-up if a waiter may exist.
if (val == LockedMaybeWaiting)
{
m_lock.notify_one();
}
}
else
{
// Still locked recursively; just decrement the depth count.
m_lock.fetch_sub(DepthUnit, std::memory_order_release);
}
}
private:
enum LockState : uint32_t
{
Unlocked = 0,
Locked = 1,
LockedMaybeWaiting = 2,
DepthUnit = 1 << 2, // Each recursion count increment
DepthMask = 0xFFFFFFFC // The 30 MSBs are used for the recursion count
};
constexpr static uint32_t kInvalidOwner = 0;
cpp20::atomic_uint32_t m_lock{ Unlocked };
uint32_t m_owner{ kInvalidOwner };
};
# endif
} // namespace detail
#endif
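/**
 * A mutex with the same basic interface as std::mutex.
 *
 * Unlike std::mutex, attempting to lock recursively is a fatal error rather than undefined
 * behavior, as is calling unlock() from a thread that does not own the lock.
 *
 * Illustrative sketch (not part of this header; `g_mutex`, `g_counter` and `increment` are
 * hypothetical names): since mutex provides lock(), try_lock() and unlock(), it satisfies the
 * C++ Lockable requirements and works with the standard lock helpers from `<mutex>`:
 * @code{.cpp}
 * carb::thread::mutex g_mutex;
 * int g_counter = 0;
 *
 * void increment()
 * {
 *     std::lock_guard<carb::thread::mutex> guard(g_mutex); // blocks until acquired
 *     ++g_counter;
 * } // guard's destructor calls g_mutex.unlock()
 * @endcode
 */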
class mutex : public detail::BaseMutex<false>
{
using Base = detail::BaseMutex<false>;
public:
constexpr mutex() noexcept = default;
~mutex() = default;
//! Locks the mutex; a fatal error occurs if the calling thread already owns it.
void lock()
{
Base::lock();
}
//! Attempts to lock the mutex without blocking; a fatal error occurs if the calling thread already owns it.
//! @returns `true` if the lock was acquired; `false` if another thread owns the lock.
bool try_lock()
{
return Base::try_lock();
}
//! Unlocks the mutex; a fatal error occurs if the calling thread does not own it.
void unlock()
{
Base::unlock();
}
};
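/**
 * A recursive mutex with the same basic interface as std::recursive_mutex.
 *
 * The owning thread may call lock() or try_lock() repeatedly; the mutex is released once
 * unlock() has been called for each successful lock. On non-Windows platforms the recursion
 * depth is tracked in a 30-bit counter, and exceeding it is a fatal error ("Recursion
 * overflow").
 *
 * Illustrative sketch (not part of this header; `g_mutex`, `outer` and `inner` are
 * hypothetical names):
 * @code{.cpp}
 * carb::thread::recursive_mutex g_mutex;
 *
 * void inner()
 * {
 *     std::lock_guard<carb::thread::recursive_mutex> guard(g_mutex); // re-locks on the same thread
 * }
 *
 * void outer()
 * {
 *     std::lock_guard<carb::thread::recursive_mutex> guard(g_mutex);
 *     inner(); // safe: the depth count is incremented, then decremented on return
 * }
 * @endcode
 */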
class recursive_mutex : public detail::BaseMutex<true>
{
using Base = detail::BaseMutex<true>;
public:
constexpr recursive_mutex() noexcept = default;
~recursive_mutex() = default;
//! Locks the mutex; may be called repeatedly by the owning thread.
void lock()
{
Base::lock();
}
//! Attempts to lock the mutex without blocking; always succeeds if the calling thread already owns it.
//! @returns `true` if the lock was acquired or is already owned; `false` otherwise.
bool try_lock()
{
return Base::try_lock();
}
//! Unlocks the mutex once; it is released when every lock()/try_lock() has been matched by an unlock().
void unlock()
{
Base::unlock();
}
};
} // namespace thread
} // namespace carb