// carb/extras/Options.h
// File members: carb/extras/Options.h
// Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>
namespace carb
{
namespace options
{
/** Result codes that an option parser callback (see ArgParserFunc) reports back to
 *  parseOptions().
 */
enum class ParseResult
{
/** The option was accepted.  parseOptions() moves on to the next argument. */
eSuccess,
/** The option or its value was rejected.  parseOptions() prints an error message to
 *  stderr and returns `false`.
 */
eInvalidValue,
};
/** Identifies the kind of data held by a Value object and the kind of value that an
 *  Option expects to receive.
 */
enum class ValueType
{
    /** The option's value (if any) is consumed but not converted.  parseOptions() matches
     *  options of this type by name prefix only.
     */
    eIgnore = -1,

    /** No value is stored.  This is the state of a new or cleared Value object. */
    eNone,

    /** The value is a string. */
    eString,

    /** The value is a `long` integer. */
    eLong,

    /** The value is a `long long` integer. */
    eLongLong,

    /** The value is a single precision floating point number. */
    eFloat,

    /** The value is a double precision floating point number. */
    eDouble,
};

/** Holds a single option value: either a string or one of several numeric types.
 *
 *  The stored type is tracked so that a numeric value can be retrieved as any of the
 *  supported numeric types (converted with `static_cast`).
 */
class Value
{
public:
    /** Creates an empty value whose type is ValueType::eNone. */
    Value() : m_type(ValueType::eNone)
    {
    }

    /** Resets this object to the empty state (ValueType::eNone) and drops any stored string. */
    void clear()
    {
        m_string.clear();
        m_type = ValueType::eNone;
    }

    /** Stores a string value.
     *
     *  @param[in] value    The string to take ownership of.
     */
    void set(std::string value)
    {
        m_string = std::move(value);
        m_type = ValueType::eString;
    }

    /** Stores a `long` value.
     *
     *  @param[in] value    The number to store.
     */
    void set(long value)
    {
        m_value.integer = value;
        m_type = ValueType::eLong;
    }

    /** Stores a `long long` value.
     *
     *  @param[in] value    The number to store.
     */
    void set(long long value)
    {
        m_value.longInteger = value;
        m_type = ValueType::eLongLong;
    }

    /** Stores a `float` value.
     *
     *  @param[in] value    The number to store.
     */
    void set(float value)
    {
        m_value.floatValue = value;
        m_type = ValueType::eFloat;
    }

    /** Stores a `double` value.
     *
     *  @param[in] value    The number to store.
     */
    void set(double value)
    {
        m_value.doubleValue = value;
        m_type = ValueType::eDouble;
    }

    /** @returns The type of the value currently stored in this object. */
    ValueType getType() const
    {
        return m_type;
    }

    /** @returns The stored string if this object holds a string, or `nullptr` otherwise.
     *           The pointer refers to storage owned by this object.
     */
    const char* getString() const
    {
        return (m_type == ValueType::eString) ? m_string.c_str() : nullptr;
    }

    /** @returns The stored number converted to `long`, or 0 if no number is stored. */
    long getLong() const
    {
        return getNumber<long>();
    }

    /** @returns The stored number converted to `long long`, or 0 if no number is stored. */
    long long getLongLong() const
    {
        return getNumber<long long>();
    }

    /** @returns The stored number converted to `float`, or 0 if no number is stored. */
    float getFloat() const
    {
        return getNumber<float>();
    }

    /** @returns The stored number converted to `double`, or 0 if no number is stored. */
    double getDouble() const
    {
        return getNumber<double>();
    }

private:
    /** Retrieves the stored number as type @p T.
     *
     *  @returns The stored number cast to @p T.  Returns 0 when the object holds a string
     *           or no value at all.
     */
    template <typename T>
    T getNumber() const
    {
        if (m_type == ValueType::eLong)
            return static_cast<T>(m_value.integer);

        if (m_type == ValueType::eLongLong)
            return static_cast<T>(m_value.longInteger);

        if (m_type == ValueType::eFloat)
            return static_cast<T>(m_value.floatValue);

        if (m_type == ValueType::eDouble)
            return static_cast<T>(m_value.doubleValue);

        // string, empty, or any other type => no numeric value available.
        return static_cast<T>(0);
    }

    /** The type of the value currently stored. */
    ValueType m_type;

    /** Storage for a string value.  Only meaningful when m_type is ValueType::eString. */
    std::string m_string;

    /** Storage for a numeric value.  The active member is selected by m_type. */
    union
    {
        long integer;
        long long longInteger;
        float floatValue;
        double doubleValue;
    } m_value;
};
/** Receives the results of a parseOptions() call and is handed to every option parser
 *  callback.
 */
class Options
{
public:
    /** The argument count that was passed to parseOptions().  Zero until then. */
    int argc{ 0 };

    /** The argument vector that was passed to parseOptions().  `nullptr` until then. */
    char** argv{ nullptr };

    /** Index of the first argument that did not match any supported option, as recorded
     *  by parseOptions().  Stays at -1 if every argument was handled.
     */
    int firstCommandArgument{ -1 };

    /** Reinterprets this object as a pointer to another type.
     *
     *  @tparam T   The type to view this object as.  No checking is performed; the
     *              conversion is a plain `reinterpret_cast`.
     *  @returns This object's address as a `T*`.
     */
    template <typename T>
    T* cast()
    {
        Options* const self = this;
        return reinterpret_cast<T*>(self);
    }
};
/** Prototype of the callback that handles one matched option.
 *
 *  As called by parseOptions(): `name` is the command line argument that matched the
 *  option, `value` is the converted value (or `nullptr` when the option takes no value)
 *  and `args` is the Options object that was given to parseOptions().  Returning
 *  ParseResult::eInvalidValue makes parseOptions() report an error and fail.
 */
using ArgParserFunc = ParseResult (*)(const char* name, const Value* value, Options* args);
/** Describes one supported command line option.  Tables of these are passed to
 *  parseOptions() and printOptionUsage(); both stop at the first entry whose `parser`
 *  member is `nullptr`.
 */
struct Option
{
// Short form of the option's name as it appears on the command line.  May be nullptr.
const char* shortName;
// Long form of the option's name as it appears on the command line.  May be nullptr.
const char* longName;
// Greater than zero when the option takes a value.  parseOptions() reads a single value,
// either from `name=value` or from the following argument.
int expectedArgs;
// Type the value string is converted to before the parser is called.
ValueType expectedType;
// Callback invoked when the option is matched.  nullptr marks the end of the table.
ArgParserFunc parser;
// Help text for the option, printed line by line by printOptionUsage().
const char* documentation;
};
/** Implementation helpers for the option parser. */
namespace details
{

/** getArgString() failure: a value was required but there was no further argument. */
constexpr int kArgFailExpectedArgument = -1;

/** getArgString() failure: a quoted value did not end with a matching quotation mark. */
constexpr int kArgFailExpectedQuote = -2;

/** getArgString() failure: the requested argument index is outside the argument list. */
constexpr int kArgFailIndexOutOfRange = -3;

/** Converts a base-10 string to a number.
 *
 *  This generic version parses with `strtoull()` and casts the result to @p T.
 *  Specializations exist for `long`, `long long`, `float` and `double`.
 *
 *  @param[in] string   The string to convert.  May not be `nullptr`.
 *  @param[out] value   Receives the converted number.  This is written even on failure.
 *  @returns `true` if at least one character was converted and the whole string was
 *           consumed.  Returns `false` for an empty string, a string with no number in
 *           it, or a number followed by trailing characters.
 */
template <typename T>
inline bool stringToNumber(const char* string, T* value)
{
    char* endp = nullptr;
    *value = static_cast<T>(strtoull(string, &endp, 10));

    // `endp == string` means that nothing was converted (ie: empty or non-numeric input).
    return endp != nullptr && endp != string && endp[0] == 0;
}

/** Specialization of stringToNumber() for `long` values.  Parses with `strtol()`. */
template <>
inline bool stringToNumber<long>(const char* string, long* value)
{
    char* endp = nullptr;
    *value = strtol(string, &endp, 10);
    return endp != nullptr && endp != string && endp[0] == 0;
}

/** Specialization of stringToNumber() for `long long` values.  Parses with `strtoll()`
 *  so that out of range input is clamped to the signed range instead of wrapping around.
 */
template <>
inline bool stringToNumber<long long>(const char* string, long long* value)
{
    char* endp = nullptr;
    *value = strtoll(string, &endp, 10);
    return endp != nullptr && endp != string && endp[0] == 0;
}

/** Specialization of stringToNumber() for `float` values.  Parses with `strtof()`. */
template <>
inline bool stringToNumber<float>(const char* string, float* value)
{
    char* endp = nullptr;
    *value = strtof(string, &endp);
    return endp != nullptr && endp != string && endp[0] == 0;
}

/** Specialization of stringToNumber() for `double` values.  Parses with `strtod()`. */
template <>
inline bool stringToNumber<double>(const char* string, double* value)
{
    char* endp = nullptr;
    *value = strtod(string, &endp);
    return endp != nullptr && endp != string && endp[0] == 0;
}

/** Tests whether a getArgString() return value is one of its failure codes.
 *
 *  @param[in] result   The value returned from getArgString().
 *  @returns `true` if @p result is negative (ie: a failure code).
 */
inline bool isArgParseFailureCode(int result)
{
    return result < 0;
}

/** Retrieves the value string that belongs to the option at a given argument index.
 *
 *  The value is taken from the text after the first '=' in the argument itself
 *  (`name=value`) or, if there is no '=', from the next argument (`name value`).  If the
 *  value starts with a single or double quotation mark, it must end with the same mark
 *  and both marks are stripped.
 *
 *  @param[in] argc         The number of entries in @p argv.
 *  @param[in] argv         The argument list.
 *  @param[in] argIndex     Index of the option's argument in @p argv.
 *  @param[out] value       Receives the value string on success.  Not modified on failure.
 *  @returns The number of extra arguments consumed (0 or 1) on success.
 *  @returns One of the negative `kArgFail*` codes on failure.  A message is also written
 *           to stderr for `kArgFailExpectedArgument`.
 */
inline int getArgString(int argc, char** argv, int argIndex, std::string& value)
{
    // make sure the requested argument index is in range of the argument list.
    if (argIndex < 0 || argIndex >= argc)
        return kArgFailIndexOutOfRange;

    const char* valueOut = nullptr;
    int argsConsumed = 0;
    const char* equal = strchr(argv[argIndex], '=');

    // the argument is of the form "name=value" => take the value from the arg itself.
    if (equal != nullptr)
    {
        valueOut = equal + 1;
    }

    // the argument is of the form "name value" => take the value from the next arg.
    else if (argIndex + 1 < argc)
    {
        valueOut = argv[argIndex + 1];
        argsConsumed = 1;
    }

    // not enough args given => fail.
    else
    {
        fprintf(stderr, "expected another argument after '%s'.\n", argv[argIndex]);
        return kArgFailExpectedArgument;
    }

    size_t lengthToCopy = strlen(valueOut);

    if (valueOut[0] == '\"' || valueOut[0] == '\'')
    {
        const char openQuote = valueOut[0];

        // a lone quotation mark must not be accepted as both the opening and closing mark.
        // Doing so would underflow the length calculation below.
        if (lengthToCopy < 2 || valueOut[lengthToCopy - 1] != openQuote)
            return kArgFailExpectedQuote;

        // calculate the length of the value string.  This is two less than the string's
        // length - one each for the opening and closing quotation marks.
        lengthToCopy -= 2;
        valueOut++;
    }

    value.assign(valueOut, lengthToCopy);
    return argsConsumed;
}

} // namespace details
inline bool parseOptions(const Option* supportedArgs, int argc, char** argv, Options* args)
{
bool handled;
int argsConsumed;
ParseResult result;
Value value;
Value* valueToSend;
auto argMatches = [](const char* string, const char* arg, bool terminated) {
size_t len = strlen(arg);
if (!terminated)
return strncmp(string, arg, len) == 0;
return strncmp(string, arg, len) == 0 && (string[len] == 0 || string[len] == '=');
};
for (int i = 1; i < argc; i++)
{
handled = false;
for (size_t j = 0; supportedArgs[j].parser != nullptr; j++)
{
bool checkTermination = supportedArgs[j].expectedType != ValueType::eIgnore;
if ((supportedArgs[j].shortName != nullptr &&
argMatches(argv[i], supportedArgs[j].shortName, checkTermination)) ||
(supportedArgs[j].longName != nullptr && argMatches(argv[i], supportedArgs[j].longName, checkTermination)))
{
std::string valueStr;
valueToSend = nullptr;
argsConsumed = 0;
value.clear();
if (supportedArgs[j].expectedArgs > 0)
{
argsConsumed = details::getArgString(argc, argv, i, valueStr);
if (details::isArgParseFailureCode(argsConsumed))
{
switch (argsConsumed)
{
case details::kArgFailExpectedArgument:
fprintf(
stderr, "ERROR-> expected an extra argument after argument %d ('%s').", i, argv[i]);
break;
case details::kArgFailExpectedQuote:
fprintf(
stderr,
"ERROR-> expected a matching quotation mark at the end of the argument value for argument %d ('%s').",
i, argv[i]);
break;
// this error will never be returned given how the loop above
// iterates. We'll print a basic message for it anyway.
case details::kArgFailIndexOutOfRange:
fprintf(stderr, "ERROR-> argument index out of range.\n");
break;
default:
break;
}
return false;
}
if (supportedArgs[j].expectedType == ValueType::eIgnore)
{
i += argsConsumed;
handled = true;
break;
}
#if !defined(DOXYGEN_SHOULD_SKIP_THIS)
# define SETVALUE(value, str, type) \
do \
{ \
type convertedValue; \
if (!details::stringToNumber(str.c_str(), &convertedValue)) \
{ \
fprintf(stderr, "ERROR-> expected a %s value after '%s'.\n", #type, argv[i]); \
return false; \
} \
value.set(convertedValue); \
} while (0)
#endif
switch (supportedArgs[j].expectedType)
{
default:
case ValueType::eString:
value.set(valueStr);
break;
case ValueType::eLong:
SETVALUE(value, valueStr, long);
break;
case ValueType::eLongLong:
SETVALUE(value, valueStr, long long);
break;
case ValueType::eFloat:
SETVALUE(value, valueStr, float);
break;
case ValueType::eDouble:
SETVALUE(value, valueStr, double);
break;
}
valueToSend = &value;
#undef SETVALUE
}
result = supportedArgs[j].parser(argv[i], valueToSend, args);
switch (result)
{
default:
case ParseResult::eSuccess:
break;
case ParseResult::eInvalidValue:
fprintf(stderr, "ERROR-> unknown or invalid value in '%s'.\n", argv[i]);
return false;
}
i += argsConsumed;
handled = true;
break;
}
}
if (!handled)
{
if (args->firstCommandArgument < 0)
args->firstCommandArgument = i;
break;
}
}
args->argc = argc;
args->argv = argv;
return true;
}
inline void printOptionUsage(const Option* supportedArgs, const char* helpString, FILE* stream)
{
const char* str;
const char* newline;
const char* argStr;
fputs(helpString, stream);
fputs("Supported options:\n", stream);
for (size_t i = 0; supportedArgs[i].parser != nullptr; i++)
{
str = supportedArgs[i].documentation;
argStr = "";
if (supportedArgs[i].expectedArgs > 0)
argStr = " [value]";
if (supportedArgs[i].shortName != nullptr)
fprintf(stream, " %s%s:\n", supportedArgs[i].shortName, argStr);
if (supportedArgs[i].longName != nullptr)
fprintf(stream, " %s%s:\n", supportedArgs[i].longName, argStr);
for (newline = strchr(str, '\n'); newline != nullptr; str = newline + 1, newline = strchr(str + 1, '\n'))
fprintf(stream, " %.*s\n", static_cast<int>(newline - str), str);
fputs("\n", stream);
}
fputs("\n", stream);
}
} // namespace options
} // namespace carb