// carb/dictionary/DictionaryUtils.h
//
// File members: carb/dictionary/DictionaryUtils.h

// Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once

#include "../Framework.h"
#include "../InterfaceUtils.h"
#include "../datasource/IDataSource.h"
#include "../extras/CmdLineParser.h"
#include "../filesystem/IFileSystem.h"
#include "../logging/Log.h"
#include "IDictionary.h"
#include "ISerializer.h"

#include <algorithm>
#include <string>

namespace carb
{
namespace dictionary
{

//! Convenience accessor for the IDictionary interface.
//!
//! Simply forwards to getCachedInterface<IDictionary>() and hands back whatever pointer that
//! call produces, without any additional checking.
//!
//! @returns The pointer returned by getCachedInterface<IDictionary>().
inline IDictionary* getCachedDictionaryInterface()
{
    IDictionary* const dictionaryInterface = getCachedInterface<IDictionary>();
    return dictionaryInterface;
}

//! Function-pointer type for a per-item callback that threads a value of type @p ElementData
//! through a dictionary traversal.
//!
//! The callback receives the visited item (`srcItem`), a value handed down to it (`elementData`)
//! and an opaque `userData` pointer, and returns a new `ElementData`. This is the same call shape
//! that walkDictionary() uses when it invokes its `onItemFn` argument; there, the value returned
//! for a dictionary item is what gets handed to that item's children.
//!
//! @tparam ElementData Type of the value passed to and returned from the callback.
template <typename ElementData>
using OnItemFn = ElementData (*)(const Item* srcItem, ElementData elementData, void* userData);

//! Retrieves the child of @p item located at index @p idx, using @p dict to perform the lookup.
//!
//! Only a declaration is provided for the general case; the `const Item` specialization that
//! follows supplies the implementation used for read-only traversal.
//!
//! @tparam ItemPtrType Pointee type of the item pointers (e.g. `const Item`).
//! @param dict The dictionary interface used for the lookup.
//! @param item The parent item.
//! @param idx  Index of the requested child.
//! @returns Pointer to the requested child as reported by the dictionary interface.
template <typename ItemPtrType>
inline ItemPtrType* getChildByIndex(IDictionary* dict, ItemPtrType* item, size_t idx);

//! Read-only specialization of getChildByIndex(): forwards to IDictionary::getItemChildByIndex().
template <>
inline const Item* getChildByIndex<const Item>(IDictionary* dict, const Item* item, size_t idx)
{
    const Item* const child = dict->getItemChildByIndex(item, idx);
    return child;
}

//! Selects whether walkDictionary() invokes its callback for the root item itself.
enum class WalkerMode
{
    //! The root item is visited first (the callback is invoked for it with the root element
    //! data); its descendants follow.
    eIncludeRoot,

    //! The callback is never invoked for the root item; the walk starts at the root's direct
    //! children, each of which is handed the root element data.
    eSkipRoot
};

//! Iteratively walks the tree under @p root depth-first, invoking @p onItemFn on each visited item.
//!
//! The traversal uses an explicit stack instead of recursion. Children are pushed in reverse
//! index order so that they are popped (and therefore visited) in ascending index order, and a
//! dictionary item is always visited before any of its children.
//!
//! For an item of type ItemType::eDictionary the value returned by @p onItemFn is handed to each
//! of its children as their element data. For any other item type the return value is discarded.
//! Note that the child count of a dictionary item is queried *before* the callback runs for it.
//!
//! @tparam ElementData Type of the value threaded from parent to children.
//! @tparam OnItemFnType Callable invoked as `onItemFn(item, elementData, userData)`.
//! @tparam ItemPtrType Pointee type of the item pointers (`const Item` by default).
//! @tparam GetChildByIndexFuncType Callable invoked as `getChildByIndexFunc(dict, item, index)`.
//! @param dict The dictionary interface used to query item types and child counts.
//! @param walkerMode Whether the callback is invoked for @p root itself (see WalkerMode).
//! @param root The item to start from. If null, the function returns without doing anything.
//! @param rootElementData Element data given to the root (eIncludeRoot) or to each of the root's
//!                        direct children (eSkipRoot).
//! @param onItemFn The per-item callback.
//! @param userData Opaque pointer forwarded to every @p onItemFn invocation.
//! @param getChildByIndexFunc Accessor used to obtain the children of an item.
template <typename ElementData,
          typename OnItemFnType,
          typename ItemPtrType = const Item,
          typename GetChildByIndexFuncType CARB_NO_DOC(= decltype(getChildByIndex<ItemPtrType>))>
inline void walkDictionary(IDictionary* dict,
                           WalkerMode walkerMode,
                           ItemPtrType* root,
                           ElementData rootElementData,
                           OnItemFnType onItemFn,
                           void* userData,
                           GetChildByIndexFuncType getChildByIndexFunc = getChildByIndex<ItemPtrType>)
{
    if (root == nullptr)
    {
        return;
    }

    // One entry of the explicit traversal stack: an item plus the data handed down to it.
    struct PendingVisit
    {
        ItemPtrType* item;
        ElementData inherited;
    };

    std::vector<PendingVisit> pending;
    pending.reserve(100);

    // Pushes the children of `parent` last-to-first, so that child 0 ends up on top of the stack.
    const auto pushChildrenReversed = [&](ItemPtrType* parent, size_t childCount, const ElementData& data) {
        for (size_t remaining = childCount; remaining > 0; --remaining)
        {
            pending.push_back({ getChildByIndexFunc(dict, parent, remaining - 1), data });
        }
    };

    if (walkerMode != WalkerMode::eSkipRoot)
    {
        pending.push_back({ root, rootElementData });
    }
    else
    {
        const size_t rootChildCount = dict->getItemChildCount(root);
        pushChildrenReversed(root, rootChildCount, rootElementData);
    }

    while (!pending.empty())
    {
        const PendingVisit current = pending.back();
        const ItemType currentType = dict->getItemType(current.item);
        pending.pop_back();

        if (currentType != ItemType::eDictionary)
        {
            // Leaf (non-dictionary) item: the callback result is not propagated anywhere.
            onItemFn(current.item, current.inherited, userData);
            continue;
        }

        // Query the child count first, then run the callback, then schedule the children.
        const size_t childCount = dict->getItemChildCount(current.item);
        ElementData childData = onItemFn(current.item, current.inherited, userData);
        pushChildrenReversed(current.item, childCount, childData);
    }
}

//! Returns the name of the item found at @p path (relative to @p baseItem) as a std::string.
//!
//! The name is obtained through IDictionary::createStringBufferFromItemName(); the temporary
//! buffer is released with IDictionary::destroyStringBuffer() before returning.
//!
//! @param dict The dictionary interface.
//! @param baseItem The item that @p path is resolved against.
//! @param path Path passed to IDictionary::getItem() (may be `nullptr`).
//! @returns The item name, or an empty string if the item is not found or no name buffer could
//!          be created.
inline std::string getStringFromItemName(const IDictionary* dict, const Item* baseItem, const char* path = nullptr)
{
    const Item* item = dict->getItem(baseItem, path);
    if (!item)
    {
        return std::string();
    }
    const char* itemNameBuf = dict->createStringBufferFromItemName(item);
    if (!itemNameBuf)
    {
        // Constructing a std::string from a null pointer is undefined behavior, and there is
        // nothing to release in this case.
        return std::string();
    }
    std::string returnString = itemNameBuf;
    dict->destroyStringBuffer(itemNameBuf);
    return returnString;
}

//! Returns the value of the item found at @p path (relative to @p baseItem) as a std::string.
//!
//! The value is obtained through IDictionary::createStringBufferFromItemValue(); the temporary
//! buffer is released with IDictionary::destroyStringBuffer() before returning.
//!
//! @param dict The dictionary interface.
//! @param baseItem The item that @p path is resolved against.
//! @param path Path passed to IDictionary::getItem() (may be `nullptr`).
//! @returns The value as a string, or an empty string if the item is not found or no string
//!          buffer could be created for its value.
inline std::string getStringFromItemValue(const IDictionary* dict, const Item* baseItem, const char* path = nullptr)
{
    const Item* item = dict->getItem(baseItem, path);
    if (!item)
    {
        return std::string();
    }
    const char* stringBuf = dict->createStringBufferFromItemValue(item);
    if (!stringBuf)
    {
        // Constructing a std::string from a null pointer is undefined behavior, and there is
        // nothing to release in this case.
        return std::string();
    }
    std::string returnString = stringBuf;
    dict->destroyStringBuffer(stringBuf);
    return returnString;
}

//! Copies the array stored at @p path (relative to @p baseItem) into a vector of strings.
//!
//! The result has one entry per array element, as reported by IDictionary::getArrayLength().
//! Each entry is copied from the pointer returned by IDictionary::getStringBufferAt(); elements
//! for which that call returns a null pointer are left as empty strings.
//!
//! @param dict The dictionary interface.
//! @param baseItem The item that @p path is resolved against.
//! @param path Path passed to IDictionary::getItem().
//! @returns The array contents, or an empty vector if no item exists at @p path.
inline std::vector<std::string> getStringArray(const IDictionary* dict, const Item* baseItem, const char* path)
{
    const Item* itemAtKey = dict->getItem(baseItem, path);
    if (!itemAtKey)
    {
        return {};
    }

    std::vector<std::string> stringArray(dict->getArrayLength(itemAtKey));
    for (size_t i = 0; i < stringArray.size(); i++)
    {
        // Assigning a null `const char*` to a std::string is undefined behavior, so only copy
        // elements for which a buffer is actually available.
        if (const char* element = dict->getStringBufferAt(itemAtKey, i))
        {
            stringArray[i] = element;
        }
    }
    return stringArray;
}

//! Copies the array held directly by @p item into a vector of strings.
//!
//! Equivalent to calling the path-taking overload of getStringArray() with a null path.
//!
//! @param dict The dictionary interface.
//! @param item The array item to read.
//! @returns The array contents as produced by the path-taking overload.
inline std::vector<std::string> getStringArray(const IDictionary* dict, const Item* item)
{
    const char* const noSubPath = nullptr;
    return getStringArray(dict, item, noSubPath);
}

//! Replaces the contents of the item at @p path (relative to @p baseItem) with a string array.
//!
//! Any existing children of the target item are removed first, then every entry of
//! @p stringArray is written with IDictionary::setStringAt() at its index.
//!
//! The target item itself must stay alive while it is being repopulated, so only its children
//! are destroyed (IDictionary::destroyItemChildren()). Destroying the item itself and then
//! writing through the stale pointer would be a use-after-destroy.
//!
//! @param dict The dictionary interface.
//! @param baseItem The item that @p path is resolved against.
//! @param path Path passed to IDictionary::getItemMutable().
//! @param stringArray The strings to store.
inline void setStringArray(IDictionary* dict, Item* baseItem, const char* path, const std::vector<std::string>& stringArray)
{
    Item* itemAtKey = dict->getItemMutable(baseItem, path);
    // NOTE(review): ItemType::eCount appears to be used here as the "no valid item" marker -
    // confirm against IDictionary::getItemType().
    if (dict->getItemType(itemAtKey) != dictionary::ItemType::eCount)
    {
        // Clear the previous contents but keep `itemAtKey` valid for the writes below.
        dict->destroyItemChildren(itemAtKey);
    }
    for (size_t i = 0, stringCount = stringArray.size(); i < stringCount; ++i)
    {
        dict->setStringAt(itemAtKey, i, stringArray[i].c_str());
    }
}

//! Replaces the contents of @p item with a string array.
//!
//! Equivalent to calling the path-taking overload of setStringArray() with a null path.
//!
//! @param dict The dictionary interface.
//! @param item The item to populate.
//! @param stringArray The strings to store.
inline void setStringArray(IDictionary* dict, Item* item, const std::vector<std::string>& stringArray)
{
    const char* const noSubPath = nullptr;
    setStringArray(dict, item, noSubPath, stringArray);
}

//! Stores @p value at @p path, choosing the item type from the textual form of the value.
//!
//! The following rules are applied in order:
//!  - an empty @p value is stored as an empty string;
//!  - a value wrapped in matching double or single quotes is stored as a string, without the quotes;
//!  - `true` / `false` (case-insensitive) are stored as a bool;
//!  - a value fully consumed by `strtoll` (radix 0: decimal, octal or hexadecimal) is stored as a
//!    64-bit integer;
//!  - a value fully consumed by `strtod` is stored as a 64-bit float;
//!  - anything else is stored as a string, verbatim.
//!
//! Nothing is done if @p path is empty.
//!
//! @param id The dictionary interface.
//! @param dict The item that @p path is resolved against.
//! @param path Path of the element to create or overwrite.
//! @param value Textual representation of the value.
inline void setDictionaryElementAutoType(IDictionary* id, Item* dict, const std::string& path, const std::string& value)
{
    // We should validate that provided path is a proper path but for now we just use it
    if (path.empty())
    {
        return;
    }

    const char* const pathStr = path.c_str();

    // Special case, if the string is empty, write an empty string early
    if (value.empty())
    {
        constexpr const char* kEmptyString = "";
        id->makeStringAtPath(dict, pathStr, kEmptyString);
        return;
    }

    if (value.size() > 1 &&
        ((value.front() == '"' && value.back() == '"') || (value.front() == '\'' && value.back() == '\'')))
    {
        // string value - chop off quotes
        id->makeStringAtPath(dict, pathStr, value.substr(1, value.size() - 2).c_str());
        return;
    }

    // Convert the value to upper case to simplify checks.
    // The argument of toupper() must be representable as unsigned char (or be EOF); passing a
    // negative plain `char` (any non-ASCII byte where char is signed) is undefined behavior.
    std::string uppercaseValue = value;
    std::transform(value.begin(), value.end(), uppercaseValue.begin(), [](const char c) {
        return static_cast<char>(::toupper(static_cast<unsigned char>(c)));
    });

    // let's see if it's a boolean
    if (uppercaseValue == "TRUE")
    {
        id->makeBoolAtPath(dict, pathStr, true);
        return;
    }
    if (uppercaseValue == "FALSE")
    {
        id->makeBoolAtPath(dict, pathStr, false);
        return;
    }

    // let's see if it's an integer
    const size_t valueLen = value.length();
    char* endptr;
    // Use a radix of 0 to allow for decimal, octal, and hexadecimal values to all be parsed.
    // NOTE(review): out-of-range integers are not detected here (errno is not checked), so they
    // are stored as the value strtoll() reports for an overflow - confirm this is intended.
    const long long int valueAsInt = strtoll(value.c_str(), &endptr, 0);
    if (endptr - value.c_str() == static_cast<ptrdiff_t>(valueLen))
    {
        id->makeInt64AtPath(dict, pathStr, valueAsInt);
        return;
    }

    // let's see if it's a float
    const double valueAsFloat = strtod(value.c_str(), &endptr);
    if (endptr - value.c_str() == static_cast<ptrdiff_t>(valueLen))
    {
        id->makeFloat64AtPath(dict, pathStr, valueAsFloat);
        return;
    }

    // neither quoted, boolean nor numeric: store the value verbatim as a string
    id->makeStringAtPath(dict, pathStr, value.c_str());
}

//! Applies every `path -> value` pair of @p mapping to the dictionary.
//!
//! Each pair is forwarded to setDictionaryElementAutoType(), so the stored type of every element
//! is derived from the textual form of its value. Pairs are processed in the map's iteration
//! order.
//!
//! @param id The dictionary interface.
//! @param dict The item that the paths are resolved against.
//! @param mapping Paths (keys) and textual values to store.
inline void setDictionaryFromStringMapping(IDictionary* id, Item* dict, const std::map<std::string, std::string>& mapping)
{
    for (auto entry = mapping.cbegin(), last = mapping.cend(); entry != last; ++entry)
    {
        const std::string& elementPath = entry->first;
        const std::string& elementValue = entry->second;
        setDictionaryElementAutoType(id, dict, elementPath, elementValue);
    }
}

//! Populates the dictionary from command line arguments.
//!
//! The arguments are parsed with carb::extras::CmdLineParser constructed with @p prefix, and the
//! resulting options are applied through setDictionaryFromStringMapping().
//!
//! @param id The dictionary interface.
//! @param dict The item that the option paths are resolved against.
//! @param argv The argument vector.
//! @param argc The number of entries in @p argv.
//! @param prefix The prefix handed to the command line parser (`"--/"` by default).
inline void setDictionaryFromCmdLine(IDictionary* id, Item* dict, char** argv, int argc, const char* prefix = "--/")
{
    carb::extras::CmdLineParser parser(prefix);
    parser.parse(argv, argc);

    setDictionaryFromStringMapping(id, dict, parser.getOptions());
}

//! Parses a textual array such as `[10, 20, "a"]` and stores it at @p elementPath.
//!
//! Any item already present at @p elementPath is destroyed and replaced with a fresh dictionary
//! item flagged ItemFlag::eUnitSubtree. The text between the enclosing square brackets is split
//! on commas; every piece is trimmed and stored through setDictionaryElementAutoType() at
//! `<elementPath>/<index>`. Ex. "/some/path=[10,20]" is processed as "/some/path/0=10" and
//! "/some/path/1=20".
//!
//! An empty piece followed by a comma is skipped with a warning and does not consume an index;
//! an empty piece after the last comma (a trailing comma, or an empty array) is skipped silently.
//! Empty strings must therefore be spelled out as `""`.
//!
//! Nothing is done if @p elementPath is empty. @p elementValue is expected to be enclosed in
//! `[` and `]`, which is only verified with CARB_ASSERT.
//!
//! @param dictionaryInterface The dictionary interface.
//! @param targetDictionary The item that @p elementPath is resolved against.
//! @param elementPath Path of the array to (re)create.
//! @param elementValue Textual array, including the enclosing square brackets.
inline void setDictionaryArrayElementFromStringValue(dictionary::IDictionary* dictionaryInterface,
                                                     dictionary::Item* targetDictionary,
                                                     const std::string& elementPath,
                                                     const std::string& elementValue)
{
    if (elementPath.empty())
    {
        return;
    }

    CARB_ASSERT(elementValue.size() >= 2 && elementValue.front() == '[' && elementValue.back() == ']');

    const char* const arrayPath = elementPath.c_str();

    // Drop whatever currently lives at the target path so the array starts out empty
    if (dictionary::Item* const staleItem = dictionaryInterface->getItemMutable(targetDictionary, arrayPath))
    {
        dictionaryInterface->destroyItem(staleItem);
    }

    // Re-create the item and flag it as a unit subtree, so that a dictionary merge replaces the
    // array as a whole
    dictionary::Item* const freshArray = dictionaryInterface->makeDictionaryAtPath(targetDictionary, arrayPath);
    dictionaryInterface->setItemFlag(freshArray, dictionary::ItemFlag::eUnitSubtree, true);

    const std::string indexedPathPrefix = elementPath + '/';
    size_t nextIndex = 0;

    // Stores an already-trimmed, non-empty value at the next free index
    const auto storeAtNextIndex = [&](const std::string& trimmedValue) {
        carb::dictionary::setDictionaryElementAutoType(
            dictionaryInterface, targetDictionary, indexedPathPrefix + std::to_string(nextIndex), trimmedValue);
    };

    // Position 0 is the opening bracket, so the first value starts at position 1
    std::string::size_type tokenBegin = 1;
    std::string::size_type commaPos = elementValue.find(',', tokenBegin);

    // Every value that is terminated by a comma
    while (commaPos != std::string::npos)
    {
        std::string token = elementValue.substr(tokenBegin, commaPos - tokenBegin);
        carb::extras::trimStringInplace(token);

        if (!token.empty())
        {
            storeAtNextIndex(token);
            ++nextIndex;
        }
        else
        {
            // Empty string values should be stated as "": [ "a", "", "b" ]
            CARB_LOG_WARN(
                "Encountered and skipped an empty value for dictionary array element '%s' while parsing value '%s'",
                elementPath.c_str(), elementValue.c_str());
        }

        tokenBegin = commaPos + 1;
        commaPos = elementValue.find(',', tokenBegin);
    }

    // The remaining text, up to (but excluding) the closing bracket
    std::string tailToken = elementValue.substr(tokenBegin, elementValue.size() - tokenBegin - 1);
    carb::extras::trimStringInplace(tailToken);

    // An empty tail means a trailing comma ([ 1, 2, 3, ]) or an empty array: nothing to add
    if (tailToken.empty())
    {
        return;
    }
    storeAtNextIndex(tailToken);
}

//! Reads @p filename completely and deserializes its contents into a new dictionary.
//!
//! The file is read into a temporary NUL-terminated buffer (stack-allocated when the buffer fits
//! into 4096 bytes, heap-allocated otherwise) which is then passed to
//! ISerializer::createDictionaryFromStringBuffer() with `fDeserializerOptionInSitu`.
//!
//! If fewer bytes than the reported file size could be read, an error is logged and whatever was
//! read is still handed to the serializer.
//!
//! @param serializer The serializer used to parse the file contents.
//! @param filename Path of the file to read.
//! @returns The item returned by the serializer, or `nullptr` if the file could not be opened.
inline Item* createDictionaryFromFile(ISerializer* serializer, const char* filename)
{
    carb::filesystem::IFileSystem* fs = carb::getCachedInterface<carb::filesystem::IFileSystem>();

    auto file = fs->openFileToRead(filename);
    if (!file)
        return nullptr;

    const size_t fileSize = fs->getFileSize(file);
    // One extra byte for the NUL terminator
    const size_t contentLen = fileSize + 1;

    std::unique_ptr<char[]> heap;
    char* content;

    if (contentLen <= 4096)
    {
        content = CARB_STACK_ALLOC(char, contentLen);
    }
    else
    {
        heap.reset(new char[contentLen]);
        content = heap.get();
    }

    // Read at most `fileSize` bytes so the last byte of the buffer always stays available for
    // the terminator. Requesting `contentLen` bytes could fill the whole buffer if the file grew
    // after its size was queried, making the terminator write below land one past the end.
    const size_t readBytes = fs->readFileChunk(file, content, fileSize);
    fs->closeFile(file);

    if (readBytes != fileSize)
    {
        CARB_LOG_ERROR("Only read %zu bytes of a total of %zu bytes from file '%s'", readBytes, fileSize, filename);
    }

    // NUL terminate
    content[readBytes] = '\0';

    return serializer->createDictionaryFromStringBuffer(content, readBytes, fDeserializerOptionInSitu);
}

//! Serializes @p dictionary with @p serializer and writes the result to @p filename.
//!
//! The serialized text is produced by ISerializer::createStringBufferFromDictionary() and is
//! always released again with ISerializer::destroyStringBuffer(), including when the file cannot
//! be opened. If serialization yields no buffer, an error is logged and the file is not touched.
//!
//! @param serializer The serializer used to produce the text.
//! @param dictionary The dictionary to save.
//! @param filename Path of the file to write.
//! @param serializerOptions Options forwarded to the serializer.
inline void saveFileFromDictionary(ISerializer* serializer,
                                   const dictionary::Item* dictionary,
                                   const char* filename,
                                   SerializerOptions serializerOptions)
{
    const char* serializedString = serializer->createStringBufferFromDictionary(dictionary, serializerOptions);
    if (serializedString == nullptr)
    {
        // Nothing to write (and strlen() on a null pointer would be undefined behavior); bail out
        // before opening the file for writing.
        CARB_LOG_ERROR("failed to serialize the dictionary - unable to save it to file '%s'", filename);
        return;
    }

    filesystem::IFileSystem* fs = carb::getCachedInterface<filesystem::IFileSystem>();
    filesystem::File* sFile = fs->openFileToWrite(filename);
    if (sFile == nullptr)
    {
        CARB_LOG_ERROR("failed to open file '%s' - unable to save the dictionary", filename);
        // The buffer is owned by this function and must not leak on the error path
        serializer->destroyStringBuffer(serializedString);
        return;
    }
    fs->writeFileChunk(sFile, serializedString, strlen(serializedString));
    fs->closeFile(sFile);
    serializer->destroyStringBuffer(serializedString);
}

//! Serializes a dictionary into a pretty-printed string.
//!
//! The serializer is acquired from the framework: first from the plugin named
//! @p serializerName (if given), then, if that is not available, from any plugin providing
//! dictionary::ISerializer. The text is produced with `fSerializerOptionMakePretty`.
//!
//! @param c The dictionary to serialize.
//! @param serializerName Optional name of the serializer plugin to prefer.
//! @returns The serialized text, or an empty string if no framework or serializer is available
//!          or if the serializer produced no output.
inline std::string dumpToString(const dictionary::Item* c, const char* serializerName = nullptr)
{
    std::string serializedDictionary;

    Framework* framework = carb::getFramework();
    if (!framework)
    {
        // Without a framework there is no way to acquire a serializer
        return serializedDictionary;
    }

    dictionary::ISerializer* configSerializer = nullptr;

    // First, try to acquire interface with provided plugin name, if any
    if (serializerName)
    {
        configSerializer = framework->tryAcquireInterface<dictionary::ISerializer>(serializerName);
    }
    // If not available, or plugin name is not provided, try to acquire any serializer interface
    if (!configSerializer)
    {
        configSerializer = framework->tryAcquireInterface<dictionary::ISerializer>();
    }
    // tryAcquireInterface() is allowed to fail; do not dereference a null serializer
    if (!configSerializer)
    {
        CARB_LOG_ERROR("failed to acquire a dictionary serializer interface - unable to dump the dictionary");
        return serializedDictionary;
    }

    const char* configString =
        configSerializer->createStringBufferFromDictionary(c, dictionary::fSerializerOptionMakePretty);

    if (configString != nullptr)
    {
        serializedDictionary = configString;

        configSerializer->destroyStringBuffer(configString);
    }

    return serializedDictionary;
}

//! Builds the `/`-separated path leading from the top-most ancestor of @p item down to @p item.
//!
//! The names of the item and of each of its ancestors are collected by following
//! IDictionary::getItemParent() until it returns null. Every path element is emitted as a `/`
//! followed by the element's name; elements whose name is null contribute only the `/`.
//!
//! @param dict The dictionary interface.
//! @param item The item whose path is requested.
//! @param includeRoot If `false`, the top-most ancestor (the last element collected) is left out
//!                    of the path.
//! @returns The path, or an empty string if @p item is null.
inline std::string getItemFullPath(dictionary::IDictionary* dict,
                                   const carb::dictionary::Item* item,
                                   bool includeRoot = true)
{
    std::string fullPath;
    if (item == nullptr)
    {
        return fullPath;
    }

    // Names ordered from the item itself up to the top-most ancestor
    std::vector<const char*> namesLeafToRoot;
    for (const carb::dictionary::Item* cursor = item; cursor != nullptr; cursor = dict->getItemParent(cursor))
    {
        namesLeafToRoot.push_back(dict->getItemName(cursor));
    }

    if (!includeRoot)
    {
        namesLeafToRoot.pop_back();
    }

    // Pre-compute the final length: one '/' separator per element plus the name itself
    size_t expectedLength = 0;
    for (const char* name : namesLeafToRoot)
    {
        expectedLength += 1;
        if (name != nullptr)
        {
            expectedLength += std::strlen(name);
        }
    }
    fullPath.reserve(expectedLength);

    // Emit the elements in the opposite order: top-most ancestor first
    for (auto name = namesLeafToRoot.rbegin(); name != namesLeafToRoot.rend(); ++name)
    {
        fullPath += '/';
        if (*name != nullptr)
        {
            fullPath += *name;
        }
    }
    return fullPath;
}

//! Maps a C++ type to the dictionary ItemType used to represent it.
//!
//! Only the declaration of the primary template is provided here; the mapping is defined through
//! the explicit specializations below.
//!
//! @tparam Type The C++ type to map.
//! @returns The ItemType corresponding to @p Type.
template <typename Type>
inline ItemType toItemType();

//! Specialization for `int32_t`: maps to ItemType::eInt.
template <>
inline ItemType toItemType<int32_t>()
{
    return ItemType::eInt;
}

//! Specialization for `int64_t`: maps to ItemType::eInt.
template <>
inline ItemType toItemType<int64_t>()
{
    return ItemType::eInt;
}

//! Specialization for `float`: maps to ItemType::eFloat.
template <>
inline ItemType toItemType<float>()
{
    return ItemType::eFloat;
}

//! Specialization for `double`: maps to ItemType::eFloat.
template <>
inline ItemType toItemType<double>()
{
    return ItemType::eFloat;
}

//! Specialization for `bool`: maps to ItemType::eBool.
template <>
inline ItemType toItemType<bool>()
{
    return ItemType::eBool;
}

//! Specialization for `char*`: maps to ItemType::eString.
template <>
inline ItemType toItemType<char*>()
{
    return ItemType::eString;
}

//! Specialization for `const char*`: maps to ItemType::eString.
template <>
inline ItemType toItemType<const char*>()
{
    return ItemType::eString;
}

//! Unsubscribes @p item and every item below it from node and tree change events.
//!
//! The subtree is traversed with walkDictionary() in WalkerMode::eIncludeRoot; for each visited
//! item both IDictionary::unsubscribeItemFromNodeChangeEvents() and
//! IDictionary::unsubscribeItemFromTreeChangeEvents() are called.
//!
//! @param dict The dictionary interface.
//! @param item Root of the subtree to process.
inline void unsubscribeTreeFromAllEvents(IDictionary* dict, Item* item)
{
    // The walk operates on mutable items, so children are fetched through the mutable accessor
    const auto mutableChildAt = [](IDictionary* iface, Item* parent, size_t childIndex) {
        return iface->getItemChildByIndexMutable(parent, childIndex);
    };

    // Per-item visitor; the dictionary interface travels through the opaque user data pointer and
    // the element data is passed along unchanged
    auto dropSubscriptions = [](Item* visited, uint32_t passthrough, void* context) -> uint32_t {
        IDictionary* const iface = static_cast<IDictionary*>(context);
        iface->unsubscribeItemFromNodeChangeEvents(visited);
        iface->unsubscribeItemFromTreeChangeEvents(visited);
        return passthrough;
    };

    walkDictionary(dict, WalkerMode::eIncludeRoot, item, 0, dropSubscriptions, dict, mutableChildAt);
}

//! Update callback that overwrites destination items, replacing whole subtrees for unit subtrees.
//!
//! Returns UpdateAction::eReplaceSubtree when a destination item exists, a dictionary interface
//! was supplied, and @p srcItem carries the ItemFlag::eUnitSubtree flag. In every other case
//! UpdateAction::eOverwrite is returned.
//!
//! @param dstItem The destination item (may be null).
//! @param dstItemType Unused.
//! @param srcItem The source item whose flag is inspected.
//! @param srcItemType Unused.
//! @param dictionaryInterface Pointer to the carb::dictionary::IDictionary interface, passed as
//!                            an opaque pointer (may be null).
//! @returns The action to perform for this pair of items.
inline UpdateAction overwriteOriginalWithArrayHandling(
    const Item* dstItem, ItemType dstItemType, const Item* srcItem, ItemType srcItemType, void* dictionaryInterface)
{
    CARB_UNUSED(dstItemType, srcItemType);

    // Without a destination item or an interface to query, fall back to a plain overwrite
    if (!dstItem || !dictionaryInterface)
    {
        return carb::dictionary::UpdateAction::eOverwrite;
    }

    auto* const iface = static_cast<carb::dictionary::IDictionary*>(dictionaryInterface);
    if (iface->getItemFlag(srcItem, carb::dictionary::ItemFlag::eUnitSubtree))
    {
        return carb::dictionary::UpdateAction::eReplaceSubtree;
    }

    return carb::dictionary::UpdateAction::eOverwrite;
}

} // namespace dictionary
} // namespace carb