carb/extras/Options.h

File members: carb/extras/Options.h

// Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once

#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <string>
#include <utility>

namespace carb
{
namespace options
{

/**
 * Outcome reported by an option parser callback for a single recognised option.
 */
enum class ParseResult : int
{
    /// The option (and its value, if any) was accepted; parsing continues with the next argument.
    eSuccess = 0,

    /// The option's value was rejected; parseOptions() prints an error and returns false.
    eInvalidValue = 1,
};

/**
 * Type of value an option expects, and the type currently held by a Value.
 */
enum class ValueType
{
    /// Used as an option's expected type: the option name is matched as a prefix and, when the
    /// option expects a value, that value is skipped without calling the option's parser.
    eIgnore = -1,

    /// No value is held.
    eNone,

    /// A string value.
    eString,

    /// A `long` value.
    eLong,

    /// A `long long` value.
    eLongLong,

    /// A `float` value.
    eFloat,

    /// A `double` value.
    eDouble,
};

/**
 * A tagged value holding nothing, a string, or a single number.
 *
 * The active alternative is reported by getType().  Numeric getters convert from whichever
 * numeric type is stored and return 0 when the value is empty or holds a string.
 */
class Value
{
public:
    /// Constructs an empty value (ValueType::eNone).
    Value()
    {
        clear();
    }

    /// Resets the value to ValueType::eNone, discarding any stored string or number.
    void clear()
    {
        m_type = ValueType::eNone;
        m_string.clear();

        // Zero the numeric storage so an empty Value never carries indeterminate bytes
        // (e.g. when it is copied).
        m_value.longInteger = 0;
    }

    /// Stores a string value (ValueType::eString).
    void set(std::string value)
    {
        m_type = ValueType::eString;
        m_string = std::move(value);
    }

    /**
     * Stores an `int` as a ValueType::eLong value.
     *
     * Without this overload a call such as `set(5)` does not compile: `int` converts equally
     * well to `long`, `long long`, `float` and `double`, so overload resolution is ambiguous.
     */
    void set(int value)
    {
        set(static_cast<long>(value));
    }

    /// Stores a `long` value (ValueType::eLong).
    void set(long value)
    {
        m_type = ValueType::eLong;
        m_value.integer = value;
    }

    /// Stores a `long long` value (ValueType::eLongLong).
    void set(long long value)
    {
        m_type = ValueType::eLongLong;
        m_value.longInteger = value;
    }

    /// Stores a `float` value (ValueType::eFloat).
    void set(float value)
    {
        m_type = ValueType::eFloat;
        m_value.floatValue = value;
    }

    /// Stores a `double` value (ValueType::eDouble).
    void set(double value)
    {
        m_type = ValueType::eDouble;
        m_value.doubleValue = value;
    }

    /// Returns the type of the value currently held.
    ValueType getType() const
    {
        return m_type;
    }

    /**
     * Returns the stored string, or nullptr if the value is not a string.
     *
     * The pointer refers to internal storage and is invalidated when this Value is modified
     * or destroyed.
     */
    const char* getString() const
    {
        if (m_type != ValueType::eString)
            return nullptr;

        return m_string.c_str();
    }

    /// Returns the stored number converted to `long`, or 0 if no number is held.
    long getLong() const
    {
        return getNumber<long>();
    }

    /// Returns the stored number converted to `long long`, or 0 if no number is held.
    long long getLongLong() const
    {
        return getNumber<long long>();
    }

    /// Returns the stored number converted to `float`, or 0 if no number is held.
    float getFloat() const
    {
        return getNumber<float>();
    }

    /// Returns the stored number converted to `double`, or 0 if no number is held.
    double getDouble() const
    {
        return getNumber<double>();
    }

private:
    /// Converts whichever numeric member is active to @p T; non-numeric types yield 0.
    template <typename T>
    T getNumber() const
    {
        switch (m_type)
        {
            default:
            case ValueType::eString:
                return 0;

            case ValueType::eLong:
                return static_cast<T>(m_value.integer);

            case ValueType::eLongLong:
                return static_cast<T>(m_value.longInteger);

            case ValueType::eFloat:
                return static_cast<T>(m_value.floatValue);

            case ValueType::eDouble:
                return static_cast<T>(m_value.doubleValue);
        }
    }

    ValueType m_type;
    std::string m_string;

    // Only the member matching m_type is meaningful.
    union
    {
        long integer;
        long long longInteger;
        float floatValue;
        double doubleValue;
    } m_value;
};

/**
 * State filled in by parseOptions() and handed to every option parser callback.
 *
 * NOTE(review): presumably callers derive their own options struct from this class and
 * recover it inside parser callbacks via cast<T>().
 */
class Options
{
public:
    /// Argument count given to parseOptions(); stored once parsing succeeds.
    int argc{ 0 };

    /// Argument vector given to parseOptions(); stored once parsing succeeds.
    char** argv{ nullptr };

    /// Index of the first argument that matched no supported option, or -1 if there was none.
    int firstCommandArgument{ -1 };

    /**
     * Reinterprets this object as a @p T.
     *
     * The conversion goes through `void*`, which is exactly what `reinterpret_cast<T*>` does
     * for object pointers.  No type checking is performed.
     */
    template <typename T>
    T* cast()
    {
        void* const self = this;
        return static_cast<T*>(self);
    }
};

/**
 * Callback invoked by parseOptions() for each recognised option.
 *
 * @param name  The command-line argument that matched the option, exactly as it appears in
 *              argv (so it may still carry an "=value" suffix).
 * @param value The converted option value, or nullptr when the option's expectedArgs is not
 *              greater than zero.
 * @param args  The Options object that was passed to parseOptions().
 * @returns ParseResult::eSuccess to continue; ParseResult::eInvalidValue makes parseOptions()
 *          print an error and return false.
 */
using ArgParserFunc = ParseResult (*)(const char* name, const Value* value, Options* args);

/**
 * Describes one supported command-line option.
 *
 * Arrays of these are passed to parseOptions() and printOptionUsage().  The array is terminated
 * by an entry whose `parser` is nullptr.  Members are listed in aggregate-initialization order,
 * so their order must not change.
 */
struct Option
{
    /// Short spelling of the option, compared verbatim against the start of each argument
    /// (include any leading '-' characters).  May be nullptr.
    const char* shortName;

    /// Long spelling of the option, compared the same way as shortName.  May be nullptr.
    const char* longName;

    /// When greater than zero, the option takes a value given either as "name=value" or as the
    /// following argument.  At most one value is consumed regardless of the count.
    int expectedArgs;

    /// Type the value is converted to before calling `parser`.  ValueType::eIgnore additionally
    /// makes the name match as a prefix and, when a value is expected, skips that value without
    /// calling `parser`.
    ValueType expectedType;

    /// Callback invoked when the option matches.  A nullptr parser marks the end of the table.
    ArgParserFunc parser;

    /// Help text printed by printOptionUsage(), which splits it at '\n' characters.
    const char* documentation;
};

namespace details
{
/// Returned by getArgString() when the option needs a value but no further argument is available.
constexpr int kArgFailExpectedArgument = -1;

/// Returned by getArgString() when a quoted value has no matching closing quotation mark.
constexpr int kArgFailExpectedQuote = -2;

/// Returned by getArgString() when the requested argument index is outside of [0, argc).
constexpr int kArgFailIndexOutOfRange = -3;

/**
 * Converts a base-10 string to a number of type @p T.
 *
 * This generic version parses with strtoull() and converts the result to @p T.  Explicit
 * specializations below handle `long`, `long long`, `float` and `double`.
 *
 * @param string NUL-terminated text to convert.  Must not be nullptr.
 * @param value  Receives the converted value.  It is written even when the conversion fails.
 * @returns true if the entire string is a number within `unsigned long long` range (a narrower
 *          @p T may still truncate it); false if the string is empty, has trailing non-numeric
 *          characters, or overflows.
 */
template <typename T>
inline bool stringToNumber(const char* string, T* value)
{
    char* endp = nullptr;
    errno = 0;
    *value = static_cast<T>(strtoull(string, &endp, 10));

    // endp == string means no digits were consumed (e.g. an empty string).
    return endp != string && endp[0] == 0 && errno != ERANGE;
}

/// `long` specialization: rejects empty input, trailing junk and values outside `long` range.
template <>
inline bool stringToNumber(const char* string, long* value)
{
    char* endp = nullptr;
    errno = 0;
    *value = strtol(string, &endp, 10);
    return endp != string && endp[0] == 0 && errno != ERANGE;
}

/// `long long` specialization: parses as signed so negative values and values above LLONG_MAX
/// are handled correctly instead of wrapping through `unsigned long long`.
template <>
inline bool stringToNumber(const char* string, long long* value)
{
    char* endp = nullptr;
    errno = 0;
    *value = strtoll(string, &endp, 10);
    return endp != string && endp[0] == 0 && errno != ERANGE;
}

/// `float` specialization: rejects empty input and trailing junk.  Out-of-range input follows
/// strtof() (it saturates or underflows) and is not rejected.
template <>
inline bool stringToNumber(const char* string, float* value)
{
    char* endp = nullptr;
    *value = strtof(string, &endp);
    return endp != string && endp[0] == 0;
}

/// `double` specialization: rejects empty input and trailing junk.  Out-of-range input follows
/// strtod() (it saturates or underflows) and is not rejected.
template <>
inline bool stringToNumber(const char* string, double* value)
{
    char* endp = nullptr;
    *value = strtod(string, &endp);
    return endp != string && endp[0] == 0;
}

/// Returns true if @p result (a getArgString() return value) is one of the kArgFail* codes.
inline bool isArgParseFailureCode(int result)
{
    return result < 0;
}

/**
 * Extracts the value belonging to the option at argv[argIndex].
 *
 * The value is the text after the first '=' in argv[argIndex] if there is one; otherwise it is
 * the whole of argv[argIndex + 1].  A value wrapped in matching single or double quotes has the
 * quotes removed.
 *
 * @param argc     Number of entries in @p argv.
 * @param argv     Argument list.
 * @param argIndex Index of the option argument whose value is wanted.
 * @param value    Receives the value on success; left unchanged on failure.
 * @returns The number of additional arguments consumed (0 for "name=value", 1 for
 *          "name value"), or one of the negative kArgFail* codes on failure.
 */
inline int getArgString(int argc, char** argv, int argIndex, std::string& value)
{
    // make sure the requested argument index is in range of the argument list.
    if (argIndex < 0 || argIndex >= argc)
        return kArgFailIndexOutOfRange;

    const char* valueOut = nullptr;
    int argsConsumed = 0;
    const char* equal = strchr(argv[argIndex], '=');

    // the argument is of the form "name=value" => take the value from the arg itself.
    if (equal != nullptr)
        valueOut = equal + 1;

    // the argument is of the form "name value" => take the value from the next arg.
    else if (argIndex + 1 < argc)
    {
        valueOut = argv[argIndex + 1];
        argsConsumed = 1;
    }

    // not enough args given => fail.
    else
    {
        fprintf(stderr, "expected another argument after '%s'.\n", argv[argIndex]);
        return kArgFailExpectedArgument;
    }

    size_t lengthToCopy = strlen(valueOut);

    if (valueOut[0] == '\"' || valueOut[0] == '\'')
    {
        const char openQuote = valueOut[0];

        // A distinct closing quote is required.  A lone quote character would otherwise be
        // accepted as its own closing quote, and the subtraction below would underflow.
        if (lengthToCopy < 2 || valueOut[lengthToCopy - 1] != openQuote)
            return kArgFailExpectedQuote;

        // drop the opening and closing quotation marks.
        lengthToCopy -= 2;
        valueOut++;
    }

    value.assign(valueOut, lengthToCopy);
    return argsConsumed;
}
} // namespace details
inline bool parseOptions(const Option* supportedArgs, int argc, char** argv, Options* args)
{
    bool handled;
    int argsConsumed;
    ParseResult result;
    Value value;
    Value* valueToSend;

    auto argMatches = [](const char* string, const char* arg, bool terminated) {
        size_t len = strlen(arg);

        if (!terminated)
            return strncmp(string, arg, len) == 0;

        return strncmp(string, arg, len) == 0 && (string[len] == 0 || string[len] == '=');
    };

    for (int i = 1; i < argc; i++)
    {
        handled = false;

        for (size_t j = 0; supportedArgs[j].parser != nullptr; j++)
        {
            bool checkTermination = supportedArgs[j].expectedType != ValueType::eIgnore;

            if ((supportedArgs[j].shortName != nullptr &&
                 argMatches(argv[i], supportedArgs[j].shortName, checkTermination)) ||
                (supportedArgs[j].longName != nullptr && argMatches(argv[i], supportedArgs[j].longName, checkTermination)))
            {
                std::string valueStr;
                valueToSend = nullptr;
                argsConsumed = 0;
                value.clear();

                if (supportedArgs[j].expectedArgs > 0)
                {
                    argsConsumed = details::getArgString(argc, argv, i, valueStr);

                    if (details::isArgParseFailureCode(argsConsumed))
                    {
                        switch (argsConsumed)
                        {
                            case details::kArgFailExpectedArgument:
                                fprintf(
                                    stderr, "ERROR-> expected an extra argument after argument %d ('%s').", i, argv[i]);
                                break;

                            case details::kArgFailExpectedQuote:
                                fprintf(
                                    stderr,
                                    "ERROR-> expected a matching quotation mark at the end of the argument value for argument %d ('%s').",
                                    i, argv[i]);
                                break;

                            // this error will never be returned given how the loop above
                            // iterates.  We'll print a basic message for it anyway.
                            case details::kArgFailIndexOutOfRange:
                                fprintf(stderr, "ERROR-> argument index out of range.\n");
                                break;

                            default:
                                break;
                        }

                        return false;
                    }

                    if (supportedArgs[j].expectedType == ValueType::eIgnore)
                    {
                        i += argsConsumed;
                        handled = true;
                        break;
                    }

#if !defined(DOXYGEN_SHOULD_SKIP_THIS)
#    define SETVALUE(value, str, type)                                                                                 \
        do                                                                                                             \
        {                                                                                                              \
            type convertedValue;                                                                                       \
            if (!details::stringToNumber(str.c_str(), &convertedValue))                                                \
            {                                                                                                          \
                fprintf(stderr, "ERROR-> expected a %s value after '%s'.\n", #type, argv[i]);                          \
                return false;                                                                                          \
            }                                                                                                          \
            value.set(convertedValue);                                                                                 \
        } while (0)
#endif

                    switch (supportedArgs[j].expectedType)
                    {
                        default:
                        case ValueType::eString:
                            value.set(valueStr);
                            break;

                        case ValueType::eLong:
                            SETVALUE(value, valueStr, long);
                            break;

                        case ValueType::eLongLong:
                            SETVALUE(value, valueStr, long long);
                            break;

                        case ValueType::eFloat:
                            SETVALUE(value, valueStr, float);
                            break;

                        case ValueType::eDouble:
                            SETVALUE(value, valueStr, double);
                            break;
                    }

                    valueToSend = &value;
#undef SETVALUE
                }

                result = supportedArgs[j].parser(argv[i], valueToSend, args);

                switch (result)
                {
                    default:
                    case ParseResult::eSuccess:
                        break;

                    case ParseResult::eInvalidValue:
                        fprintf(stderr, "ERROR-> unknown or invalid value in '%s'.\n", argv[i]);
                        return false;
                }

                i += argsConsumed;
                handled = true;
                break;
            }
        }

        if (!handled)
        {
            if (args->firstCommandArgument < 0)
                args->firstCommandArgument = i;

            break;
        }
    }

    args->argc = argc;
    args->argv = argv;
    return true;
}

inline void printOptionUsage(const Option* supportedArgs, const char* helpString, FILE* stream)
{
    const char* str;
    const char* newline;
    const char* argStr;

    fputs(helpString, stream);
    fputs("Supported options:\n", stream);

    for (size_t i = 0; supportedArgs[i].parser != nullptr; i++)
    {
        str = supportedArgs[i].documentation;
        argStr = "";

        if (supportedArgs[i].expectedArgs > 0)
            argStr = " [value]";

        if (supportedArgs[i].shortName != nullptr)
            fprintf(stream, "    %s%s:\n", supportedArgs[i].shortName, argStr);

        if (supportedArgs[i].longName != nullptr)
            fprintf(stream, "    %s%s:\n", supportedArgs[i].longName, argStr);

        for (newline = strchr(str, '\n'); newline != nullptr; str = newline + 1, newline = strchr(str + 1, '\n'))
            fprintf(stream, "        %.*s\n", static_cast<int>(newline - str), str);

        fputs("\n", stream);
    }

    fputs("\n", stream);
}

} // namespace options
} // namespace carb