omni/kit/exec/core/unstable/ParallelScheduler.h
// Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once
#include <omni/graph/exec/unstable/AtomicBackoff.h>
#include <omni/kit/exec/core/unstable/ITbbSchedulerState.h>
#include <omni/kit/exec/core/unstable/TbbSchedulerState.h>
#ifndef __TBB_ALLOW_MUTABLE_FUNCTORS
# define __TBB_ALLOW_MUTABLE_FUNCTORS 1
#endif
#include <tbb/task.h>
#include <tbb/task_group.h>
#include <atomic>
#include <thread>
namespace omni
{
namespace kit
{
namespace exec
{
namespace core
{
namespace unstable
{
namespace detail
{
class ParallelSchedulerDebugger;
} // namespace detail
#ifndef DOXYGEN_BUILD
class ParallelScheduler
{
public:
static ParallelScheduler& getSingleton()
{
// this static is per-DLL, but the actual private data (s_state) is shared across DLLs
static ParallelScheduler sScheduler;
return sScheduler;
}
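// A minimal usage sketch (hypothetical caller): each module linking this header gets its own
// sScheduler instance, but every instance resolves the same shared TbbSchedulerState, so task
// accounting stays coherent across DLLs. The task object below is an illustrative placeholder.
//
//     ParallelScheduler& scheduler = ParallelScheduler::getSingleton();
//     scheduler.pushSerialTask(someTaskStackEntry); // someTaskStackEntry: a TbbSchedulerState::TaskStackEntry&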
template <typename Fn>
class TaskBody : public TbbSchedulerState::TaskStackEntry
{
public:
TaskBody(Fn&& f) : m_fn(std::forward<Fn>(f)), m_sharedState(ParallelScheduler::getSingleton().s_state)
{
++m_sharedState->totalTasksInFlight;
}
~TaskBody()
{
--m_sharedState->totalTasksInFlight;
}
task* execute() override
{
struct ExecutingCountedTaskScope
{
ExecutingCountedTaskScope(tbb::enumerable_thread_specific<int>& executingPerThread)
: _executingPerThread(executingPerThread)
{
_executingPerThread.local()++;
}
~ExecutingCountedTaskScope()
{
_executingPerThread.local()--;
}
tbb::enumerable_thread_specific<int>& _executingPerThread;
} executingScope(m_sharedState->tasksExecutingPerThread);
return m_fn(); // Allow for potential continuation.
}
private:
Fn m_fn;
TbbSchedulerState* m_sharedState{ nullptr };
};
template <typename Fn>
class IsolateTaskBody : public TbbSchedulerState::TaskStackEntry
{
public:
IsolateTaskBody(Fn&& f) : m_fn(std::forward<Fn>(f))
{
}
task* execute() override
{
// Isolate tasks are not counted in the total number of in-flight tasks, which allows us to
// detect when all other work is finished. Once we enter the scope of an isolate task, we make
// sure PauseTaskScope doesn't consider this task as taking part in totalTasksInFlight.
struct ExecutingNotCountedTaskScope
{
ExecutingNotCountedTaskScope(tbb::enumerable_thread_specific<int>& executingPerThread)
: _executingPerThread(executingPerThread)
{
auto& perThreadCount = _executingPerThread.local();
_originalPerThread = perThreadCount;
perThreadCount = 0;
}
~ExecutingNotCountedTaskScope()
{
_executingPerThread.local() = _originalPerThread;
}
tbb::enumerable_thread_specific<int>& _executingPerThread;
int _originalPerThread;
} executingScope(ParallelScheduler::getSingleton().s_state->tasksExecutingPerThread);
return m_fn(); // Allow for potential continuation.
}
private:
Fn m_fn;
};
class PauseTaskScope
{
public:
PauseTaskScope()
: m_scheduler(ParallelScheduler::getSingleton()),
m_isCountedTask(m_scheduler.s_state->tasksExecutingPerThread.local() > 0)
{
if (m_isCountedTask)
{
--m_scheduler.s_state->totalTasksInFlight;
}
}
~PauseTaskScope()
{
if (m_isCountedTask)
{
++m_scheduler.s_state->totalTasksInFlight;
}
}
private:
ParallelScheduler& m_scheduler;
const bool m_isCountedTask;
};
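// A minimal sketch of the intended use, mirroring getStatus() below: wrap a nested wait in a
// PauseTaskScope so the waiting task is not counted as in flight, which would otherwise keep
// isolate tasks from ever starting. waitForNestedExecution() is a hypothetical blocking call.
//
//     {
//         ParallelScheduler::PauseTaskScope pause; // suspends in-flight accounting
//         waitForNestedExecution();
//     } // destructor restores totalTasksInFlight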
void pushParallelTask(TbbSchedulerState::TaskStackEntry& t, bool isExecutingThread)
{
if (!isProcessingIsolate())
{
if (isExecutingThread && (s_state->totalTasksInFlight < 2))
{
// The thread that invokes context execution never joins the arena, so it won't process its
// local queue. Add the task to the global queue instead. We only do this if there aren't
// already tasks in flight that could pick up the task if we spawn() it.
tbb::task::enqueue(t);
}
else
{
tbb::task::spawn(t); // add to local queue
}
}
else
{
s_state->stackParallel.push(&t);
}
}
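// For example (illustrative): the first task pushed from the kickstarting thread is typically
// the only one counted in totalTasksInFlight, so it is enqueue()d to the shared queue where an
// arena worker can pick it up; once other tasks are in flight, spawn() suffices because running
// workers can steal the task from the local queue.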
void pushSerialTask(TbbSchedulerState::TaskStackEntry& t)
{
s_state->stackSerial.push(&t);
}
void pushIsolateTask(TbbSchedulerState::TaskStackEntry& t)
{
s_state->stackIsolate.push(&t);
}
bool isWithinSerialTask() const
{
return (s_state->serialThreadId.load() == std::this_thread::get_id());
}
bool isWithinIsolateTask() const
{
return (s_state->isolateThreadId.load() == std::this_thread::get_id());
}
bool isProcessingIsolate() const
{
return (s_state->isolateThreadId.load() != std::thread::id());
}
void processContextThread(tbb::empty_task* const rootTask)
{
// See the note in TbbSchedulerState.h as to why we need this
// mutex lock here.
std::unique_lock<tbb::recursive_mutex> uniqueLock(s_state->executingThreadMutex);
graph::exec::unstable::AtomicBackoff backoff;
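// When this thread is already inside an isolate task, nested execution must drain tasks
// belonging to this rootTask from all three stacks itself (see processQueuedTasks), since no
// other thread processes tasks while isolation is in effect. Otherwise, the kickstarter thread
// services only the serial and isolate stacks while waiting, leaving parallel work to the arena.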
if (isWithinIsolateTask())
{
while (rootTask->ref_count() > 1)
{
if (processQueuedTasks(rootTask))
backoff.reset();
else
backoff.pause();
}
}
else
{
while (rootTask->ref_count() > 1)
{
if (processSerialTasks() || processIsolateTasks())
backoff.reset();
else
backoff.pause();
}
}
}
private:
bool processSerialTasks()
{
tbb::task* task = s_state->stackSerial.pop();
// No work.
if (!task)
{
return false;
}
{
// Try to become the thread that handles serial task evaluation. Note that serial tasks
// can originate from many different execution contexts, kickstarted by multiple
// different threads.
std::thread::id currentThreadId = std::this_thread::get_id();
std::thread::id originalThreadId;
if (!s_state->serialThreadId.compare_exchange_strong(originalThreadId, currentThreadId) &&
originalThreadId != currentThreadId)
{
return false;
}
// Once serial task evaluation is complete, we will want to restore the serial thread ID
// back to its default value; employ an RAII helper to do so.
struct ScopeRelease
{
std::thread::id _originalThreadId;
std::atomic<std::thread::id>& _serialThreadId;
~ScopeRelease()
{
_serialThreadId = _originalThreadId;
}
} threadScope = { originalThreadId, s_state->serialThreadId };
// Run the loop over tasks. We are responsible for deleting each task here, since it is consumed outside of TBB.
do
{
(void)task->execute();
tbb::task::destroy(*task);
} while ((task = s_state->stackSerial.pop()));
}
// Let the caller know that we had something to do
return true;
}
bool processIsolateTasks()
{
if (s_state->totalTasksInFlight > 0 || s_state->stackIsolate.isEmpty())
{
return false;
}
// Try to acquire the right to process isolate tasks. We need to support nested executions.
{
std::thread::id currentThreadId = std::this_thread::get_id();
std::thread::id originalThreadId;
if (!s_state->isolateThreadId.compare_exchange_strong(originalThreadId, currentThreadId) &&
originalThreadId != currentThreadId)
{
return false;
}
// We acquired the thread; nothing else will be running until the end of this scope.
struct ScopeRelease
{
std::thread::id _originalThreadId;
std::atomic<std::thread::id>& _isolateThreadId;
~ScopeRelease()
{
_isolateThreadId = _originalThreadId;
}
} threadScope = { originalThreadId, s_state->isolateThreadId };
// Run the loop over tasks. We are responsible for deleting each task here, since it is consumed outside of TBB.
while (tbb::task* task = s_state->stackIsolate.pop())
{
(void)task->execute();
tbb::task::destroy(*task);
}
// Here we release this thread from processing isolate tasks.
// We don't need to synchronize task pushes against this operation, because all push calls
// can only happen from within the above loop (indirectly via execute). There is no other
// concurrent execution happening while we are in here.
}
// Do NOT start parallel work in nested isolate task; it has to be all consumed on this thread until we can
// exit isolation.
if (!isWithinIsolateTask())
{
// Restart parallel task execution.
while (tbb::task* dispatchTask = s_state->stackParallel.pop())
{
if (s_state->totalTasksInFlight < 2)
{
tbb::task::enqueue(*dispatchTask);
}
else
{
tbb::task::spawn(*dispatchTask);
}
}
}
return true;
}
bool processQueuedTasks(tbb::empty_task* const rootTask)
{
OMNI_ASSERT(isWithinIsolateTask());
// First, attempt to process only the top task in each scheduling-type stack. In most cases this
// is enough to allow subsequent execution to occur.
bool ret = processTaskStackTop(s_state->stackIsolate, rootTask) ||
processTaskStackTop(s_state->stackParallel, rootTask) ||
processTaskStackTop(s_state->stackSerial, rootTask);
// If no task has been processed yet because the top task in each stack belongs to a different
// ParallelSpawner instance than the one currently running (i.e., not because all of the task
// stacks are empty), perform more intensive processing by looping through each stack to find at
// least one task that was enqueued by the current ParallelSpawner/Executor. This situation has
// not been observed in any case where only a single ExecutionGraph is present in the runtime
// (as is the case in Kit at the moment, even when multiple contexts are associated with it for
// evaluation, which is exercised in the unit tests). It only seems to happen occasionally
// (~5% of the time) in a couple of unit tests, specifically ones that either contain very large
// execution graphs or evaluate multiple different execution graphs across multiple threads
// simultaneously, and even those cases have so far only manifested on Linux (or at least have
// not been observed on Windows). The relative rarity of this scenario, combined with the extra
// overhead the looping logic adds to potentially hot code paths, is why the section below is
// broken out into its own scope.
if (!ret &&
(!s_state->stackIsolate.isEmpty() || !s_state->stackParallel.isEmpty() || !s_state->stackSerial.isEmpty()))
{
ret = processTaskStack(s_state->stackIsolate, rootTask) ||
processTaskStack(s_state->stackParallel, rootTask) ||
processTaskStack(s_state->stackSerial, rootTask);
}
return ret;
}
// Helper method that will execute and clean up the top task in the specified stack if said task was originally
// inserted by the current ParallelSpawner instance.
bool processTaskStackTop(carb::container::LocklessStack<TbbSchedulerState::TaskStackEntry,
&TbbSchedulerState::TaskStackEntry::m_link>& taskStack,
tbb::empty_task* const rootTask) noexcept
{
if (TbbSchedulerState::TaskStackEntry* const task = taskStack.pop())
{
if (task->parent() == rootTask) // Prevent task execute continuation if the current task
// belongs to a different ParallelSpawner instance.
{
(void)task->execute();
tbb::task::destroy(*task); // We are responsible for deleting the task here, since it is consumed manually.
return true;
}
else
{
taskStack.push(task); // Add the task back to the stack without processing it.
}
}
return false;
}
// Helper method that will execute and clean up the first task it finds in the specified stack that was originally
// inserted by the current ParallelSpawner instance.
bool processTaskStack(carb::container::LocklessStack<TbbSchedulerState::TaskStackEntry,
&TbbSchedulerState::TaskStackEntry::m_link>& taskStack,
tbb::empty_task* const rootTask) noexcept
{
if (!taskStack.isEmpty())
{
TbbSchedulerState::TaskStackEntry* taskToExecute = nullptr;
carb::container::LocklessStack<TbbSchedulerState::TaskStackEntry, &TbbSchedulerState::TaskStackEntry::m_link> tempStack;
while (TbbSchedulerState::TaskStackEntry* const task = taskStack.pop())
{
if (task->parent() == rootTask)
{
taskToExecute = task;
break;
}
tempStack.push(task);
}
while (TbbSchedulerState::TaskStackEntry* const task = tempStack.pop())
{
taskStack.push(task);
}
if (taskToExecute)
{
(void)taskToExecute->execute();
tbb::task::destroy(*taskToExecute); // We are responsible for deleting the task here, since it is
// consumed manually.
return true;
}
}
return false;
}
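// The constructor resolves the process-wide scheduler state through the ITbbSchedulerState
// interface; this is what lets every per-DLL singleton instance (see getSingleton) operate on
// the same shared data.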
explicit ParallelScheduler() noexcept
{
omni::core::ObjectPtr<ITbbSchedulerState> sInterface = omni::core::createType<ITbbSchedulerState>();
OMNI_ASSERT(sInterface);
s_state = sInterface->getState();
OMNI_ASSERT(s_state);
}
TbbSchedulerState* s_state;
friend class detail::ParallelSchedulerDebugger;
};
namespace detail
{
class ParallelSchedulerDebugger
{
public:
static size_t enqueuedIsolateTasksUnsafeCount()
{
ParallelScheduler& scheduler = ParallelScheduler::getSingleton();
// First obtain the total number of enqueued isolate tasks by emptying the stack and counting all
// popped tasks; store said tasks in a temporary container.
std::size_t count = 0;
carb::container::LocklessStack<TbbSchedulerState::TaskStackEntry, &TbbSchedulerState::TaskStackEntry::m_link> tempStack;
while (auto task = scheduler.s_state->stackIsolate.pop())
{
++count;
tempStack.push(task);
}
// Refill the isolate tasks stack.
while (auto task = tempStack.pop())
{
scheduler.s_state->stackIsolate.push(task);
}
return count;
}
static size_t enqueuedParallelTasksUnsafeCount()
{
ParallelScheduler& scheduler = ParallelScheduler::getSingleton();
// First obtain the total number of enqueued parallel tasks by emptying the stack and counting all
// popped tasks; store said tasks in a temporary container.
std::size_t count = 0;
carb::container::LocklessStack<TbbSchedulerState::TaskStackEntry, &TbbSchedulerState::TaskStackEntry::m_link> tempStack;
while (auto task = scheduler.s_state->stackParallel.pop())
{
++count;
tempStack.push(task);
}
// Refill the parallel tasks stack.
while (auto task = tempStack.pop())
{
scheduler.s_state->stackParallel.push(task);
}
return count;
}
static size_t enqueuedSerialTasksUnsafeCount()
{
ParallelScheduler& scheduler = ParallelScheduler::getSingleton();
// First obtain the total number of enqueued serial tasks by emptying the stack and counting all
// popped tasks; store said tasks in a temporary container.
std::size_t count = 0;
carb::container::LocklessStack<TbbSchedulerState::TaskStackEntry, &TbbSchedulerState::TaskStackEntry::m_link> tempStack;
while (auto task = scheduler.s_state->stackSerial.pop())
{
++count;
tempStack.push(task);
}
// Refill the serial tasks stack.
while (auto task = tempStack.pop())
{
scheduler.s_state->stackSerial.push(task);
}
return count;
}
};
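// A minimal sketch (hypothetical debug-only caller): these counters pop and re-push every task,
// so they should only be called when no other thread is mutating the stacks, e.g. in an
// assertion at a known quiescent point.
//
//     OMNI_ASSERT(detail::ParallelSchedulerDebugger::enqueuedSerialTasksUnsafeCount() == 0);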
template <template <class...> typename TaskT, typename A, typename Fn>
TaskT<Fn>* makeTask(A&& allocTask, Fn&& f)
{
return new (allocTask) TaskT<Fn>(std::forward<Fn>(f));
}
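// A minimal usage sketch, mirroring ParallelSpawner::schedule() below: the task is
// placement-new'd into TBB-allocated storage, so it must be released through tbb::task::destroy
// rather than delete. m_root here stands for a tbb::empty_task* owned by the caller.
//
//     auto* t = detail::makeTask<ParallelScheduler::TaskBody>(
//         tbb::task::allocate_additional_child_of(*m_root),
//         []() -> tbb::task* { return nullptr; }); // no continuation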
} // namespace detail
#endif // DOXYGEN_BUILD
struct ParallelSpawner
{
ParallelSpawner(graph::exec::unstable::IExecutionContext* context)
: m_context(context), m_scheduler(ParallelScheduler::getSingleton())
{
m_root = new (tbb::task::allocate_root()) tbb::empty_task;
m_root->set_ref_count(1);
}
~ParallelSpawner()
{
tbb::task::destroy(*m_root);
}
template <typename Fn>
graph::exec::unstable::Status schedule(Fn&& task, graph::exec::unstable::SchedulingInfo schedInfo)
{
using namespace detail;
if (schedInfo == graph::exec::unstable::SchedulingInfo::eParallel)
{
auto* dispatchTask = detail::makeTask<ParallelScheduler::TaskBody>(
tbb::task::allocate_additional_child_of(*m_root),
[task = graph::exec::unstable::captureScheduleFunction(task), this]() mutable
{
graph::exec::unstable::Status ret = graph::exec::unstable::invokeScheduleFunction(task);
this->accumulateStatus(ret);
return nullptr;
});
m_scheduler.pushParallelTask(*dispatchTask, m_context->isExecutingThread());
}
else if (schedInfo == graph::exec::unstable::SchedulingInfo::eIsolate)
{
auto* dispatchTask = detail::makeTask<ParallelScheduler::IsolateTaskBody>(
tbb::task::allocate_additional_child_of(*m_root),
[task = graph::exec::unstable::captureScheduleFunction(task), this]() mutable
{
graph::exec::unstable::Status ret = graph::exec::unstable::invokeScheduleFunction(task);
this->accumulateStatus(ret);
return nullptr;
});
m_scheduler.pushIsolateTask(*dispatchTask);
}
else
{
auto* dispatchTask = detail::makeTask<ParallelScheduler::TaskBody>(
tbb::task::allocate_additional_child_of(*m_root),
[task = graph::exec::unstable::captureScheduleFunction(task), this]() mutable -> tbb::task*
{
graph::exec::unstable::Status ret = graph::exec::unstable::invokeScheduleFunction(task);
this->accumulateStatus(ret);
return nullptr;
});
m_scheduler.pushSerialTask(*dispatchTask);
}
return graph::exec::unstable::Status::eSuccess;
}
void accumulateStatus(graph::exec::unstable::Status ret)
{
graph::exec::unstable::Status current, newValue = graph::exec::unstable::Status::eUnknown;
do
{
current = m_status.load();
newValue = ret | current;
} while (!m_status.compare_exchange_weak(current, newValue));
}
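// The loop above is a lock-free read-modify-write: it OR-merges ret into m_status and retries
// if another thread updated m_status between the load and the compare_exchange_weak; the net
// effect is an atomic fetch-OR over the Status flags.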
graph::exec::unstable::Status getStatus()
{
// We are about to enter nested execution. This affects the total number of tasks in flight:
// we suspend the current task by reducing the counters, all of which is handled by the RAII
// class below.
ParallelScheduler::PauseTaskScope pauseTask;
// Note that when multiple contexts are running from multiple threads, the first check below
// (m_context->isExecutingThread()) is not sufficient on its own. We may be asking for the
// status of a serial/isolate task that was originally created in context A on thread 1 while
// context B on thread 2 is being processed; this can occur if context A/thread 1 was
// temporarily suspended after context B/thread 2 beat it to acquiring
// s_state->executingThreadMutex, meaning that we are currently processing nested tasks (which
// can be of any scheduling type) derived from some top-level set of serial/isolate tasks in
// context B. In such situations we must additionally check whether we are currently within a
// serial or isolate task scope; otherwise the task originally created in context A/thread 1
// would incorrectly skip processing on context B's running kickstarter thread, despite thread 2
// being the only evaluating kickstarter thread at the moment, and take a different code path
// that leads to hangs. Serial/isolate tasks don't need to run exclusively on the
// context-kickstarting thread from which they were scheduled; they can run on any such thread
// as long as it is the only one running/processing serial/isolate tasks.
if (m_context->isExecutingThread() || m_scheduler.isWithinSerialTask() || m_scheduler.isWithinIsolateTask())
{
m_scheduler.processContextThread(m_root);
}
else
{
if (m_root->ref_count() > 1)
{
m_root->wait_for_all();
m_root->set_ref_count(1);
}
}
return m_status;
}
protected:
graph::exec::unstable::IExecutionContext* m_context;
tbb::empty_task* m_root{ nullptr };
std::atomic<graph::exec::unstable::Status> m_status{ graph::exec::unstable::Status::eUnknown };
ParallelScheduler& m_scheduler;
};
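// A minimal end-to-end sketch (hypothetical caller; ctx is assumed to be a valid
// IExecutionContext* and the lambdas to be valid schedule functions):
//
//     ParallelSpawner spawner(ctx);
//     spawner.schedule([] { return graph::exec::unstable::Status::eSuccess; },
//                      graph::exec::unstable::SchedulingInfo::eParallel);
//     spawner.schedule([] { return graph::exec::unstable::Status::eSuccess; },
//                      graph::exec::unstable::SchedulingInfo::eSerial);
//     graph::exec::unstable::Status status = spawner.getStatus(); // blocks until all tasks finish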
} // namespace unstable
} // namespace core
} // namespace exec
} // namespace kit
} // namespace omni