examples/example.stats/plugins/carb.stats/Stats.cpp

File members: examples/example.stats/plugins/carb.stats/Stats.cpp

// Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#ifndef DOXYGEN_BUILD
#    include "Stats.h"

namespace carb
{
namespace stats
{

// Lock guard that acquires m_mutex through carb::thread::shared_lock; used by the read-only queries getValue() and
// getCount().  NOTE(review): presumably mirrors std::shared_lock (shared/reader ownership) -- confirm in carb headers.
using ReadLock = carb::thread::shared_lock<carb::thread::shared_mutex>;
// Lock guard that acquires m_mutex exclusively via std::lock_guard; used by every operation that modifies the table
// or a record in it (addStat(), removeStat(), addValue(), clear()).
using WriteLock = std::lock_guard<carb::thread::shared_mutex>;

/// Constructs an empty statistics registry.
///
/// All members rely on their own default construction / in-class initialization (declared in Stats.h); there is no
/// additional setup to perform here.
StatsImpl::StatsImpl()
{
    // intentionally empty.
}

/// Destroys the registry, discarding every registered statistic.
///
/// Teardown deliberately goes through clear() so that it follows the same (write-locked) path as an explicit clear.
StatsImpl::~StatsImpl()
{
    // drop all remaining records.
    clear();
}

/// Registers a new statistic described by @p desc.
///
/// @param desc Descriptor of the new stat. `name` must be non-null. `description` may be null, in which case an
///             empty description is stored. `aggregationType` must be one of eReplace, eAccumulate, eAverage, eMax
///             or eMin. `value.type` must be eInt or eDouble. `value` also supplies the stat's initial value.
/// @returns The id assigned to the new stat, or kBadStatId if @p desc is rejected.
StatId StatsImpl::addStat(const StatDesc& desc)
{
    // a stat must be named.
    if (desc.name == nullptr)
        return kBadStatId;

    // only the aggregation modes that addValue() knows how to apply are accepted.
    const auto aggregation = desc.aggregationType;
    const bool knownAggregation =
        aggregation == AggregationType::eReplace || aggregation == AggregationType::eAccumulate ||
        aggregation == AggregationType::eAverage || aggregation == AggregationType::eMax ||
        aggregation == AggregationType::eMin;
    if (!knownAggregation)
        return kBadStatId;

    // only integer and floating point stats are supported.
    const auto valueType = desc.value.type;
    if (valueType != StatType::eInt && valueType != StatType::eDouble)
        return kBadStatId;

    // build the record outside of the lock so that the string copies don't extend the critical section.
    StatRecord record;
    record.name = desc.name;
    if (desc.description != nullptr)
        record.description = desc.description;
    else
        record.description = "";
    record.aggregationType = desc.aggregationType;

    // the initial value counts as the first sample for eAverage aggregation.
    record.totalAdded = 1;

    record.accumulation.type = desc.value.type;
    record.accumulation.intValue = desc.value.intValue;
    record.accumulation.doubleValue = desc.value.doubleValue;

    record.value.type = desc.value.type;
    record.value.intValue = desc.value.intValue;
    record.value.doubleValue = desc.value.doubleValue;

    // id assignment and insertion must happen atomically with respect to other writers.
    WriteLock guard(m_mutex);
    StatId assigned = m_nextId;
    m_nextId = StatId(m_nextId.get() + 1);
    record.id = assigned;
    m_table[assigned] = record;
    return assigned;
}

/// Unregisters a statistic.
///
/// @param stat Id of the stat to remove.
/// @returns `true` if a stat with that id was registered and has been removed, `false` otherwise.
bool StatsImpl::removeStat(StatId stat)
{
    WriteLock guard(m_mutex);

    // erase() reports how many entries it removed; any non-zero count means the id was known.
    const auto erasedCount = m_table.erase(stat);
    return erasedCount != 0;
}

/// Folds a new sample into an existing statistic according to that statistic's aggregation type.
///
/// - eAccumulate: the sample is added to the current value.
/// - eAverage:    the sample is added to a running sum and the current value becomes sum / number of samples.
/// - eMax / eMin: the current value becomes the larger / smaller of itself and the sample.
/// - eReplace:    the sample overwrites the current value.
///
/// @param stat  Id of a stat previously returned by addStat().
/// @param value The new sample. Its `type` must match the stat's type.
/// @returns `true` if the sample was aggregated. `false` if @p stat is not registered, the sample's type does not
///          match the stat's type, or the stat's type or aggregation type is not one handled here.
bool StatsImpl::addValue(StatId stat, const Value& value)
{
    // exclusive lock: the record is modified in place.
    WriteLock lock(m_mutex);

    // find the stat in the table.
    auto it = m_table.find(stat);
    if (it == m_table.end())
        return false;

    // aggregate the new value into the existing stat.
    StatRecord& record = it->second;

    // make sure the input value is the same type as this stat.
    if (record.value.type != value.type)
        return false;

    switch (record.aggregationType)
    {
        case AggregationType::eAccumulate:
            if (record.value.type == StatType::eInt)
                record.value.intValue += value.intValue;

            else if (record.value.type == StatType::eDouble)
                record.value.doubleValue += value.doubleValue;

            else
                return false;

            break;

        case AggregationType::eAverage:
            // `accumulation` holds the sum of every sample, including the initial value passed to addStat() (which
            // also starts `totalAdded` at 1), so the count used below is never zero.
            if (record.value.type == StatType::eInt)
            {
                record.accumulation.intValue += value.intValue;
                record.totalAdded++;

                // Convert the count to the accumulator's own type before dividing. `totalAdded` is declared in
                // Stats.h. If it is an unsigned type at least as wide as the accumulator (e.g. size_t next to an
                // int64_t), the usual arithmetic conversions would turn a negative sum into a huge unsigned value
                // and yield a wildly wrong average. The cast is a no-op when the types already agree.
                using IntAccumType = decltype(record.accumulation.intValue);
                record.value.intValue = record.accumulation.intValue / static_cast<IntAccumType>(record.totalAdded);
            }

            else if (record.value.type == StatType::eDouble)
            {
                record.accumulation.doubleValue += value.doubleValue;
                record.totalAdded++;

                record.value.doubleValue = record.accumulation.doubleValue / record.totalAdded;
            }

            else
                return false;

            break;

        case AggregationType::eMax:
            if (record.value.type == StatType::eInt)
                record.value.intValue = CARB_MAX(record.value.intValue, value.intValue);

            else if (record.value.type == StatType::eDouble)
                record.value.doubleValue = CARB_MAX(record.value.doubleValue, value.doubleValue);

            else
                return false;

            break;

        case AggregationType::eMin:
            if (record.value.type == StatType::eInt)
                record.value.intValue = CARB_MIN(record.value.intValue, value.intValue);

            else if (record.value.type == StatType::eDouble)
                record.value.doubleValue = CARB_MIN(record.value.doubleValue, value.doubleValue);

            else
                return false;

            break;

        case AggregationType::eReplace:
            // the types were already checked to match, so copying both fields is safe for either stat type.
            record.value.intValue = value.intValue;
            record.value.doubleValue = value.doubleValue;
            break;

        default:
            return false;
    }

    return true;
}

/// Reports the current state of a statistic.
///
/// @param stat  Id of the stat to query.
/// @param value Receives the stat's name, description, aggregation type and current value. It is left untouched
///              when the lookup fails.
/// @returns `true` if @p stat is registered, `false` otherwise.
///
/// @note The `name` and `description` pointers written into @p value point at strings owned by the internal table,
///       not at copies. The read lock is released when this function returns, so the pointers are only meaningful
///       while the entry stays in the table. removeStat() and clear() erase entries. NOTE(review): whether other
///       table mutations can also invalidate them depends on the `Table` container declared in Stats.h -- confirm.
bool StatsImpl::getValue(StatId stat, StatDesc& value)
{
    // shared lock: this is a read-only query.
    ReadLock lock(m_mutex);

    const auto found = m_table.find(stat);
    if (found == m_table.end())
    {
        return false;
    }

    const auto& record = found->second;

    value.name = record.name.c_str();
    value.description = record.description.c_str();
    value.aggregationType = record.aggregationType;
    value.value.type = record.value.type;
    value.value.intValue = record.value.intValue;
    value.value.doubleValue = record.value.doubleValue;

    return true;
}

/// Returns the number of currently registered statistics.
///
/// The count is a snapshot taken under the read lock. It may be stale as soon as this function returns.
size_t StatsImpl::getCount()
{
    ReadLock guard(m_mutex);
    const size_t registered = m_table.size();
    return registered;
}

/// Removes every registered statistic.
///
/// This takes the write lock, so it waits for in-flight readers and writers to finish first. Ids handed out before
/// the call are no longer registered afterwards.
void StatsImpl::clear()
{
    WriteLock guard(m_mutex);

    // discard every record in the table.
    m_table.clear();
}

} // namespace stats
} // namespace carb
#endif