carb/extras/Base64.h

File members: carb/extras/Base64.h

// Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//
#pragma once

#include <stdint.h>
#include <string.h>

namespace carb
{
namespace extras
{

/**
 * Encoder/decoder for base64 data with a configurable alphabet.
 *
 * Every alphabet shares the same first 62 symbols ('A'-'Z', 'a'-'z', '0'-'9').  Only the symbols
 * for the values 62 and 63 and the padding byte differ between variants.  This class only selects
 * the alphabet and padding byte; it does not insert or strip line breaks, headers, or checksums
 * that some of the named formats may otherwise require.
 */
class Base64
{
public:
    /** Special size value indicating that the input buffer is a null terminated C string. */
    static constexpr size_t kNullTerminated = ~0ull;

    /** Names of the predefined alphabets that can be passed to the constructor. */
    enum class Variant
    {
        eDefault, ///< Uses '+' and '/' for values 62 and 63, with '=' as padding.
        ePem, ///< Same alphabet and padding as @ref Variant::eDefault.
        eMime, ///< Same alphabet and padding as @ref Variant::eDefault.
        eRfc4648, ///< Same alphabet and padding as @ref Variant::eDefault.
        eFilenameSafe, ///< Uses '-' and '_' for values 62 and 63, with '=' as padding.
        eOpenPgp, ///< Same alphabet and padding as @ref Variant::eDefault.
        eUtf7, ///< Same alphabet and padding as @ref Variant::eDefault.
        eImap, ///< Uses '+' and ',' for values 62 and 63, with '=' as padding.
        eYui, ///< Uses '.' and '_' for values 62 and 63, with '-' as padding.
        eProgramId1, ///< Uses '_' and '-' for values 62 and 63, with '=' as padding.
        eProgramId2, ///< Uses '.' and '_' for values 62 and 63, with '=' as padding.
        eFreenetUrl, ///< Uses '~' and '-' for values 62 and 63, with '=' as padding.
    };

    /** Constructs a codec that uses the @ref Variant::eDefault alphabet. */
    Base64() : Base64(Variant::eDefault)
    {
    }

    /**
     * Constructs a codec that uses one of the predefined alphabets.
     *
     * @param variant   The alphabet to use.  Unknown values select the same alphabet as
     *                  @ref Variant::eDefault.
     */
    Base64(Variant variant)
    {
        switch (variant)
        {
            default:
            case Variant::eDefault:
            case Variant::ePem:
            case Variant::eMime:
            case Variant::eRfc4648:
            case Variant::eOpenPgp:
            case Variant::eUtf7:
                initCodec('+', '/', '=');
                break;

            case Variant::eFilenameSafe:
                initCodec('-', '_', '=');
                break;

            case Variant::eImap:
                initCodec('+', ',', '=');
                break;

            case Variant::eYui:
                initCodec('.', '_', '-');
                break;

            case Variant::eProgramId1:
                initCodec('_', '-', '=');
                break;

            case Variant::eProgramId2:
                initCodec('.', '_', '=');
                break;

            case Variant::eFreenetUrl:
                initCodec('~', '-', '=');
                break;
        }
    }

    /**
     * Constructs a codec that uses a custom alphabet.
     *
     * @param byte62    The symbol that represents the value 62.  Any byte value is accepted,
     *                  including values of 0x80 and above.
     * @param byte63    The symbol that represents the value 63.
     * @param padding   The padding byte to append to unaligned encodings.  Passing 0 disables
     *                  padding in the encoded output.
     */
    Base64(uint8_t byte62, uint8_t byte63, uint8_t padding = 0)
    {
        initCodec(byte62, byte63, padding);
    }

    /**
     * Calculates the output buffer size required to encode a given number of input bytes.
     *
     * @param inputSize     The number of bytes that will be encoded.
     * @returns The required size of the output buffer in bytes.  This includes room for padding
     *          bytes and for the null terminator that encode() writes.
     */
    static size_t getEncodeOutputSize(size_t inputSize)
    {
        return 4 * ((inputSize + 2) / 3) + 1;
    }

    /**
     * Calculates how many input bytes can be encoded into an output buffer of a given size.
     *
     * @param outputSize    The size of the output buffer in bytes, including room for the null
     *                      terminator.
     * @returns The largest multiple of three input bytes whose encoding (plus terminator) fits
     *          in @p outputSize bytes.  Returns 0 if the buffer cannot hold any encoded block.
     */
    static size_t getEncodeInputSize(size_t outputSize)
    {
        // an empty buffer cannot even hold the terminator.  Without this check the subtraction
        // below would wrap around and report a huge input size.
        if (outputSize == 0)
            return 0;

        // one byte is reserved for the terminator; each remaining full block of four output
        // bytes holds three input bytes.
        return ((outputSize - 1) / 4) * 3;
    }

    /**
     * Calculates an output buffer size that is large enough to decode a given encoded length.
     *
     * @param inputSize     The number of encoded bytes that will be decoded.
     * @returns Three bytes for each (possibly partial) block of four input bytes.  The number of
     *          bytes actually produced by decode() may be smaller.
     */
    static size_t getDecodeOutputSize(size_t inputSize)
    {
        return 3 * ((inputSize + 3) / 4);
    }

    /**
     * Encodes a block of binary data to a null terminated base64 string.
     *
     * @param buffer    The data to encode.
     * @param size      The number of bytes in @p buffer, or @ref kNullTerminated to treat
     *                  @p buffer as a null terminated C string.
     * @param output    Receives the encoded, null terminated string.
     * @param maxOut    The size of @p output in bytes.  This must be at least
     *                  getEncodeOutputSize() of the input size.
     * @returns The number of bytes written to @p output, not including the null terminator.
     *          Returns 0 without writing anything if @p maxOut is too small.
     */
    size_t encode(const void* buffer, size_t size, char* output, size_t maxOut)
    {
        const uint8_t* in = reinterpret_cast<const uint8_t*>(buffer);
        const size_t paddingCount = (m_padding == 0) ? 0 : 1;
        size_t j = 0;

        // null terminated C string input data => calculate its length.
        if (size == kNullTerminated)
            size = strlen(reinterpret_cast<const char*>(buffer));

        // the output buffer is not large enough => fail.
        if (maxOut < getEncodeOutputSize(size))
            return 0;

        // calculate the number of input bytes that can be bulk processed.
        const size_t stop = (size / 3) * 3;
        const size_t extra = size - stop;

        // bulk process all aligned input bytes.  Each three byte input block will produce
        // four output bytes.
        for (size_t i = 0; i < stop; i += 3)
        {
            const uint32_t data = (static_cast<uint32_t>(in[i + 0]) << 16) | (static_cast<uint32_t>(in[i + 1]) << 8) |
                                  (static_cast<uint32_t>(in[i + 2]) << 0);

            output[j + 0] = m_encode[(data >> 18) & 0x3f];
            output[j + 1] = m_encode[(data >> 12) & 0x3f];
            output[j + 2] = m_encode[(data >> 6) & 0x3f];
            output[j + 3] = m_encode[(data >> 0) & 0x3f];
            j += 4;
        }

        // process any remaining bytes.  A value of 0 means the input was a multiple of three
        // bytes and there is nothing left to do.
        switch (extra)
        {
            // one extra input byte => two output bytes followed by two padding bytes.  The
            // padding bytes are always stored (the buffer is known to be large enough) but are
            // only counted when padding is enabled, so the terminator overwrites them otherwise.
            case 1:
            {
                const uint32_t data = static_cast<uint32_t>(in[stop + 0]) << 16;

                output[j + 0] = m_encode[(data >> 18) & 0x3f];
                output[j + 1] = m_encode[(data >> 12) & 0x3f];
                output[j + 2] = m_padding;
                output[j + 3] = m_padding;
                j += 2 + (paddingCount * 2);
                break;
            }

            // two extra input bytes => three output bytes followed by one padding byte.
            case 2:
            {
                const uint32_t data =
                    (static_cast<uint32_t>(in[stop + 0]) << 16) | (static_cast<uint32_t>(in[stop + 1]) << 8);

                output[j + 0] = m_encode[(data >> 18) & 0x3f];
                output[j + 1] = m_encode[(data >> 12) & 0x3f];
                output[j + 2] = m_encode[(data >> 6) & 0x3f];
                output[j + 3] = m_padding;
                j += 3 + paddingCount;
                break;
            }

            // no extra bytes => nothing to do.
            default:
                break;
        }

        // always null terminate the output buffer so that it can be treated as a C string by
        // the caller.
        output[j] = 0;

        return j;
    }

    /**
     * Decodes a base64 string back to binary data.
     *
     * No validation of the input is performed: bytes that are not part of the alphabet decode as
     * the value 0 and it is left up to the caller to verify the decoded data.  Trailing padding
     * is only recognized when the input length is a multiple of four.
     *
     * @param buffer    The encoded data to decode.
     * @param size      The number of bytes in @p buffer, or @ref kNullTerminated to treat
     *                  @p buffer as a null terminated C string.
     * @param output    Receives the decoded bytes.  This is not null terminated.
     * @param maxOut    The size of @p output in bytes.  This must be at least
     *                  getDecodeOutputSize() of the input size.
     * @returns The number of bytes written to @p output.  Returns 0 if @p maxOut is too small or
     *          if the input is shorter than two bytes.
     */
    size_t decode(const char* buffer, size_t size, void* output, size_t maxOut)
    {
        uint8_t* out = reinterpret_cast<uint8_t*>(output);
        size_t j = 0;

        // the input buffer is a null terminated C string => calculate its length.
        if (size == kNullTerminated)
            size = strlen(buffer);

        // the output buffer is not large enough => fail.
        if (maxOut < getDecodeOutputSize(size))
            return 0;

        // invalid encoding length -> decodes to less than 1 byte => fail.
        if (size < 2)
            return 0;

        // calculate the number of input bytes that can be bulk processed.
        size_t extra = size & 3;
        size_t stop = size - extra;

        // no extra unaligned input bytes were provided -> the input buffer is either exactly
        // aligned or it ends in padding bytes => determine which and adjust the sizes.  Note
        // that `size` is at least 4 here, so the final block can safely be split off.
        if (extra == 0)
        {
            // two padding bytes were specified => produces one byte in the last block.
            if (buffer[size - 2] == m_padding)
            {
                extra = 2;
                stop -= 4;
            }

            // one padding byte was specified => produces two bytes in the last block.
            else if (buffer[size - 1] == m_padding)
            {
                extra = 3;
                stop -= 4;
            }

            // otherwise the input is either block aligned or corrupt.  Either way decoding
            // simply continues.
        }

        // bulk process the aligned input data.  Each four byte block will produce three output
        // bytes.
        for (size_t i = 0; i < stop; i += 4)
        {
            const uint32_t data = (decodeSymbol(buffer[i + 0]) << 18) | (decodeSymbol(buffer[i + 1]) << 12) |
                                  (decodeSymbol(buffer[i + 2]) << 6) | (decodeSymbol(buffer[i + 3]) << 0);

            out[j + 0] = (data >> 16) & 0xff;
            out[j + 1] = (data >> 8) & 0xff;
            out[j + 2] = (data >> 0) & 0xff;
            j += 3;
        }

        // process the final partial block.  Padding bytes, if any, have already been excluded.
        switch (extra)
        {
            // two symbols (plus two optional padding bytes) => one output byte.
            case 2:
            {
                const uint32_t data = (decodeSymbol(buffer[stop + 0]) << 18) | (decodeSymbol(buffer[stop + 1]) << 12);

                out[j + 0] = (data >> 16) & 0xff;
                j += 1;
                break;
            }

            // three symbols (plus one optional padding byte) => two output bytes.
            case 3:
            {
                const uint32_t data = (decodeSymbol(buffer[stop + 0]) << 18) | (decodeSymbol(buffer[stop + 1]) << 12) |
                                      (decodeSymbol(buffer[stop + 2]) << 6);

                out[j + 0] = (data >> 16) & 0xff;
                out[j + 1] = (data >> 8) & 0xff;
                j += 2;
                break;
            }

            // no partial block, or a single dangling symbol that cannot produce a byte => the
            // dangling symbol is ignored.
            default:
                break;
        }

        // return the number of decoded bytes.
        return j;
    }

private:
    /**
     * Looks up the 6-bit value of an encoded symbol.
     *
     * The symbol is converted to an unsigned byte before being used as a table index.  Indexing
     * with a plain `char` would produce an out of range index for bytes of 0x80 and above on
     * platforms where `char` is signed.
     *
     * @param symbol    The encoded byte to look up.
     * @returns The value in the range [0, 63] assigned to @p symbol, or 0 if @p symbol is not
     *          part of the alphabet.
     */
    uint32_t decodeSymbol(char symbol) const
    {
        return static_cast<uint8_t>(m_decode[static_cast<uint8_t>(symbol)]);
    }

    /**
     * Builds the encoding and decoding tables for an alphabet.
     *
     * @param char62    The symbol that represents the value 62.
     * @param char63    The symbol that represents the value 63.
     * @param padding   The padding byte, or 0 to disable padding on encode.
     */
    void initCodec(uint8_t char62, uint8_t char63, uint8_t padding)
    {
        // generate the encoding map.
        for (size_t i = 0; i < 26; i++)
        {
            m_encode[i] = char(('A' + i) & 0xff);
            m_encode[i + 26] = char(('a' + i) & 0xff);
        }

        for (size_t i = 0; i < 10; i++)
            m_encode[i + 52] = char(('0' + i) & 0xff);

        m_encode[62] = static_cast<char>(char62);
        m_encode[63] = static_cast<char>(char63);

        // generate the decoding map as the reverse of the encoding table.  The symbol must be
        // treated as an unsigned byte so that symbols >= 0x80 stay inside the table.
        memset(m_decode, 0, sizeof(m_decode));

        for (size_t i = 0; i < 64; i++)
            m_decode[static_cast<uint8_t>(m_encode[i])] = char(i & 0xff);

        // store the padding byte.  Note that the padding byte does not need to be part of the
        // decoding table since it will never be decoded.
        m_padding = static_cast<char>(padding);
    }

    /** Maps each 6-bit value to its encoded symbol. */
    char m_encode[64];

    /** Maps each possible byte (as an unsigned value) to its 6-bit value.  Unused entries are 0. */
    char m_decode[256];

    /** The padding byte, or 0 if encoded output should not be padded. */
    char m_padding;
};

} // namespace extras
} // namespace carb