// carb/tasking/ThreadPoolUtils.h
// File members: carb/tasking/ThreadPoolUtils.h
// Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once
#include "../cpp/Tuple.h"
#include "../logging/Log.h"
#include "IThreadPool.h"

#include <exception>
#include <future>
#include <type_traits>
namespace carb
{
namespace tasking
{
#ifndef DOXYGEN_BUILD
namespace detail
{
/**
 * Invokes a callable with arguments unpacked from a tuple (via `cpp::apply`) and stores the result in a promise.
 *
 * Used by ThreadPoolWrapper::enqueueJob() to fulfill the std::future handed back to the caller.
 *
 * @tparam ReturnType Result type of the invocation; the promise is a `std::promise<ReturnType>`.
 */
template <class ReturnType>
struct ApplyWithPromise
{
    /**
     * Calls @p f with the elements of @p t and passes the result to `promise.set_value()`.
     *
     * @param promise The promise to fulfill with the callable's return value.
     * @param f The callable; forwarded with the value category it was passed with.
     * @param t The tuple of arguments; forwarded with the value category it was passed with.
     */
    template <class Callable, class Tuple>
    void operator()(std::promise<ReturnType>& promise, Callable&& f, Tuple&& t)
    {
        promise.set_value(std::forward<ReturnType>(cpp::apply(std::forward<Callable>(f), std::forward<Tuple>(t))));
    }
};

/**
 * Specialization for callables returning `void`: the promise is fulfilled (with no value) once the call returns.
 */
template <>
struct ApplyWithPromise<void>
{
    /**
     * Calls @p f with the elements of @p t, then calls `promise.set_value()`.
     *
     * @param promise The promise to fulfill after the call returns.
     * @param f The callable. Taken as a forwarding reference (like the primary template) so that an lvalue is
     *          invoked as an lvalue; taking `Callable&` and then `std::forward`-ing it would silently move from
     *          the caller's object and would also reject rvalue callables.
     * @param t The tuple of arguments; forwarded with the value category it was passed with.
     */
    template <class Callable, class Tuple>
    void operator()(std::promise<void>& promise, Callable&& f, Tuple&& t)
    {
        cpp::apply(std::forward<Callable>(f), std::forward<Tuple>(t));
        promise.set_value();
    }
};
} // namespace detail
#endif
class ThreadPoolWrapper
{
public:
ThreadPoolWrapper(IThreadPool* poolInterface, size_t workerCount = 0) : m_interface(poolInterface)
{
if (m_interface == nullptr)
{
CARB_LOG_ERROR("IThreadPool interface used to create a thread pool wrapper is null.");
return;
}
if (workerCount == 0)
{
workerCount = m_interface->getDefaultWorkerCount();
}
m_pool = m_interface->createEx(workerCount);
if (m_pool == nullptr)
{
CARB_LOG_ERROR("Couldn't create a new thread pool.");
}
}
size_t getWorkerCount() const
{
if (!isValid())
{
CARB_LOG_ERROR("Attempt to call the 'getWorkerCount' method of an invalid thread pool wrapper.");
return 0;
}
return m_interface->getWorkerCount(m_pool);
}
template <class Callable, class... Args>
auto enqueueJob(Callable&& task, Args&&... args)
{
using ReturnType = typename cpp::invoke_result_t<Callable, Args...>;
using Future = std::future<ReturnType>;
using Tuple = std::tuple<std::decay_t<Args>...>;
struct Data
{
std::promise<ReturnType> promise{};
Callable f;
Tuple args;
Data(Callable&& f_, Args&&... args_) : f(std::forward<Callable>(f_)), args(std::forward<Args>(args_)...)
{
}
void callAndDelete()
{
detail::ApplyWithPromise<ReturnType>{}(promise, f, args);
delete this;
}
};
if (!isValid())
{
CARB_LOG_ERROR("Attempt to call the 'enqueueJob' method of an invalid thread pool wrapper.");
return Future{};
}
Data* pData = new (std::nothrow) Data{ std::forward<Callable>(task), std::forward<Args>(args)... };
if (!pData)
{
CARB_LOG_ERROR("ThreadPoolWrapper: No memory for job");
return Future{};
}
Future result = pData->promise.get_future();
if (CARB_LIKELY(m_interface->enqueueJob(
m_pool, [](void* userData) { static_cast<Data*>(userData)->callAndDelete(); }, pData)))
{
return result;
}
CARB_LOG_ERROR("ThreadPoolWrapper: failed to enqueue job");
delete pData;
return Future{};
}
size_t getCurrentlyRunningJobCount() const
{
if (!isValid())
{
CARB_LOG_ERROR("Attempt to call the 'getCurrentlyRunningJobCount' method of an invalid thread pool wrapper.");
return 0;
}
return m_interface->getCurrentlyRunningJobCount(m_pool);
}
void waitUntilFinished() const
{
if (!isValid())
{
CARB_LOG_ERROR("Attempt to call the 'waitUntilFinished' method of an invalid thread pool wrapper.");
return;
}
m_interface->waitUntilFinished(m_pool);
}
bool isValid() const
{
return m_pool != nullptr;
}
~ThreadPoolWrapper()
{
if (isValid())
{
m_interface->destroy(m_pool);
}
}
CARB_PREVENT_COPY_AND_MOVE(ThreadPoolWrapper);
private:
// ThreadPoolWrapper private members and functions
IThreadPool* m_interface = nullptr;
ThreadPool* m_pool = nullptr;
};
} // namespace tasking
} // namespace carb