MulDiv.h

Fully qualified name: carb/math/MulDiv.h

File members: carb/math/MulDiv.h

// Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "../cpp/Optional.h"

// Toolchain-specific declarations for the intrinsics used by the carb::math::detail helpers below.
#if CARB_COMPILER_GNUC || CARB_TOOLCHAIN_CLANG
// Nothing to declare here: for these toolchains the helpers below use compiler builtins and inline assembly.
#elif CARB_COMPILER_MSC
// MSVC: forward-declare the intrinsics that the helpers call and mark them with `#pragma intrinsic`.
extern "C"
{
#    if CARB_X86_64
    // Used by detail::adc(), detail::umul128() and detail::udiv128() respectively.
    unsigned char _addcarry_u64(unsigned char, unsigned __int64, unsigned __int64, unsigned __int64*);
    unsigned __int64 _umul128(unsigned __int64, unsigned __int64, unsigned __int64*);
    unsigned __int64 _udiv128(unsigned __int64, unsigned __int64, unsigned __int64, unsigned __int64*);
#        pragma intrinsic(_addcarry_u64)
#        pragma intrinsic(_umul128)
#        pragma intrinsic(_udiv128)
#    elif CARB_AARCH64
    // Used by detail::umul128() to obtain the high 64 bits of the product.
    unsigned __int64 __umulh(unsigned __int64, unsigned __int64);
#        pragma intrinsic(__umulh)
#    else
    CARB_UNSUPPORTED_ARCHITECTURE();
#    endif
}
#else
CARB_UNSUPPORTED_PLATFORM();
#endif

namespace carb
{
namespace math
{
namespace detail
{

// Add-with-carry: computes addend1 + addend2 + cf in 64 bits.
//
// cf       incoming carry that is added on top of the two addends
// addend1  first value to add
// addend2  second value to add
// sum      receives the low 64 bits of the total
//
// Returns a non-zero value when the addition carried out of 64 bits, zero otherwise.
inline uint8_t adc(uint8_t cf, uint64_t addend1, uint64_t addend2, uint64_t& sum) noexcept
{
#if CARB_COMPILER_GNUC || CARB_TOOLCHAIN_CLANG
    // Two checked additions; a carry out of either one is a carry out of the whole operation.
    uint64_t partial;
    bool const addendsCarried = __builtin_add_overflow(addend1, addend2, &partial);
    bool const carryInCarried = __builtin_add_overflow(partial, cf, &sum);
    return uint8_t(addendsCarried | carryInCarried);
#elif CARB_COMPILER_MSC
#    if CARB_X86_64
    return _addcarry_u64(cf, addend1, addend2, &sum);
#    elif CARB_AARCH64
    // No carry-flag intrinsic is used here, so detect wrap-around manually: an unsigned addition wrapped if the
    // result is smaller than one of its inputs.
    uint64_t const partial = addend1 + addend2;
    sum = partial + cf;
    bool const carried = (partial < addend1) || (sum < partial);
    return uint8_t(carried);
#    else
    CARB_UNSUPPORTED_ARCHITECTURE();
#    endif
#else
    CARB_UNSUPPORTED_COMPILER();
#endif
}

// Counts the leading zero bits of a 64-bit value.
//
// val  value to inspect
//
// Returns the number of zero bits above the most significant set bit, or 64 when val is zero.
inline unsigned clz(uint64_t val) noexcept
{
    // Toolchain dispatch matches the other helpers in this file: GCC/Clang builtins are preferred, MSVC intrinsics are
    // the fallback.
#if CARB_COMPILER_GNUC || CARB_TOOLCHAIN_CLANG
    // The zero case is handled explicitly instead of being passed to the builtin.
    return val ? unsigned(__builtin_clzll(val)) : 64u;
#elif CARB_COMPILER_MSC
    // _BitScanReverse64 reports the index of the highest set bit (and returns zero when no bit is set); convert that
    // index into a leading-zero count.
    unsigned long index;
    return _BitScanReverse64(&index, val) ? 64u - 1u - index : 64u;
#else
    CARB_UNSUPPORTED_COMPILER();
#endif
}

// Full 64 x 64 -> 128-bit unsigned multiplication.
//
// multiplier    first factor
// multiplicand  second factor
// productHigh   receives the upper 64 bits of the product
//
// Returns the lower 64 bits of the product.
inline uint64_t umul128(uint64_t multiplier, uint64_t multiplicand, uint64_t& productHigh) noexcept
{
#if CARB_COMPILER_GNUC || CARB_TOOLCHAIN_CLANG
#    if CARB_X86_64
    // `mulq` multiplies RAX by its operand and leaves the 128-bit result in RDX:RAX.
    uint64_t productLow;
    asm("mulq %[rhs]"
        : "=d"(productHigh), "=a"(productLow)
        : "a"(multiplier), [rhs] "rm"(multiplicand)
        : "cc" /* flags */
    );
    return productLow;
#    elif CARB_AARCH64
    // The two instructions are deliberately kept in separate asm statements: combined into one block, the compiler
    // usually assigns the x0 register both to the multiplier and to the umulh result, which leads to a runtime error.
    uint64_t productLow;
    asm("umulh %[hi], %[lhs], %[rhs]" : [hi] "=r"(productHigh) : [lhs] "r"(multiplier), [rhs] "r"(multiplicand));
    asm("mul %[lo], %[lhs], %[rhs]" : [lo] "=r"(productLow) : [lhs] "r"(multiplier), [rhs] "r"(multiplicand));
    return productLow;
#    else
    CARB_UNSUPPORTED_ARCHITECTURE();
#    endif
#elif CARB_COMPILER_MSC
#    if CARB_X86_64
    return _umul128(multiplier, multiplicand, &productHigh);
#    elif CARB_AARCH64
    // High half comes from the intrinsic; the low half is the ordinary (wrapping) 64-bit product.
    productHigh = __umulh(multiplier, multiplicand);
    return multiplier * multiplicand;
#    else
    CARB_UNSUPPORTED_ARCHITECTURE();
#    endif
#else
    CARB_UNSUPPORTED_COMPILER();
#endif
}

inline uint64_t udiv128(uint64_t dividendHigh, uint64_t dividendLow, uint64_t divisor, uint64_t* remainder = nullptr) noexcept
{
#if CARB_AARCH64
    // aarch64 doesn't have 128-bit divide, so we do Knuth's long division (Algorithm D)
    CARB_ASSERT(divisor != 0); // divide-by-zero
    CARB_ASSERT(dividendHigh < divisor); // overflow

    unsigned s = clz(divisor);
    if (s)
    {
        divisor <<= s;
        dividendHigh <<= s;
        dividendHigh |= (dividendLow >> (64 - s));
        dividendLow <<= s;
    }

    // High quotient
    uint64_t qhat = dividendHigh / unsigned(divisor >> 32);
    uint64_t rhat = dividendHigh % unsigned(divisor >> 32);

    while (unsigned(qhat >> 32) != 0 ||
           uint64_t(unsigned(qhat)) * unsigned(divisor) > ((rhat << 32) | unsigned(dividendLow >> 32)))
    {
        --qhat;
        rhat += unsigned(divisor >> 32);
        if (unsigned(rhat >> 32))
            break;
    }

    unsigned q1 = unsigned(qhat);
    uint64_t uhat = ((dividendHigh << 32) | unsigned(dividendLow >> 32)) - q1 * divisor;

    qhat = uhat / unsigned(divisor >> 32);
    rhat = uhat % unsigned(divisor >> 32);

    while (unsigned(qhat >> 32) != 0 ||
           uint64_t(unsigned(qhat)) * unsigned(divisor) > ((rhat << 32) | unsigned(dividendLow)))
    {
        --qhat;
        rhat += unsigned(divisor >> 32);
        if (unsigned(rhat >> 32))
            break;
    }

    unsigned q0 = unsigned(qhat);

    if (remainder)
        *remainder = (((uhat << 32) | unsigned(dividendLow)) - q0 * divisor) >> s;
    return uint64_t(q1) << 32 | q0;
#elif CARB_X86_64
#    if CARB_COMPILER_GNUC || CARB_TOOLCHAIN_CLANG
    uint64_t quotient, rem;
    asm("divq %4" : "=a"(quotient), "=d"(rem) : "1"(dividendHigh), "0"(dividendLow), "rm"(divisor) : "cc" /*flags*/);
    if (remainder)
        *remainder = rem;
    return quotient;
#    elif CARB_COMPILER_MSC
    uint64_t rem;
    uint64_t quotient = _udiv128(dividendHigh, dividendLow, divisor, &rem);
    if (remainder)
        *remainder = rem;
    return quotient;
#    else
    CARB_UNSUPPORTED_COMPILER();
#    endif
#else
    CARB_UNSUPPORTED_ARCHITECTURE();
#endif
}

} // namespace detail

// Rounding policy for mulDiv() that truncates the quotient (rounds toward zero).
struct round_toward_zero_t
{
    // Called by mulDiv() with the 128-bit product (high:low) and the divisor before the division takes place.
    // Integer division already rounds toward zero, so the product is intentionally left untouched.
    void operator()(uint64_t& /*high*/, uint64_t& /*low*/, uint64_t /*divisor*/) noexcept
    {
    }
};
// Policy instances to pass as the first argument of mulDiv(). `round_floor` has the same policy type as
// `round_toward_zero`.
// NOTE(review): the signed mulDiv() overload applies the policy to the magnitude of the result, so with `round_floor`
// a negative result is truncated toward zero rather than rounded down.
CARB_WEAKLINK round_toward_zero_t round_toward_zero;
CARB_WEAKLINK round_toward_zero_t round_floor;

// Rounding policy for mulDiv() that rounds any inexact quotient up in magnitude (away from zero).
struct round_away_from_zero_t
{
    // Called by mulDiv() with the 128-bit product (high:low) and the divisor before the division takes place.
    // Adds (divisor - 1) to the product so that the truncating division that follows rounds up whenever there is a
    // non-zero remainder.
    void operator()(uint64_t& high, uint64_t& low, uint64_t divisor) noexcept
    {
        uint64_t const bias = divisor - 1;
        // Add the bias to the low half and propagate the carry into the high half
        uint8_t const carry = detail::adc(0, low, bias, low);
        high += carry;
    }
};
// Policy instances to pass as the first argument of mulDiv(). `round_ceil` has the same policy type as
// `round_away_from_zero`.
// NOTE(review): the signed mulDiv() overload applies the policy to the magnitude of the result, so with `round_ceil`
// a negative result is rounded away from zero rather than up.
CARB_WEAKLINK round_away_from_zero_t round_away_from_zero;
CARB_WEAKLINK round_away_from_zero_t round_ceil;

// Rounding policy for mulDiv() that rounds the quotient to the nearest integer.
struct round_nearest_neighbor_t
{
    // Called by mulDiv() with the 128-bit product (high:low) and the divisor before the division takes place.
    // Adds half of the divisor (integer division, so truncated for odd divisors) to the product; the truncating
    // division that follows then yields the nearest integer, with exact halves rounding up in magnitude.
    void operator()(uint64_t& high, uint64_t& low, uint64_t divisor) noexcept
    {
        uint64_t const halfDivisor = divisor / 2;
        // Add the bias to the low half and propagate the carry into the high half
        uint8_t const carry = detail::adc(0, low, halfDivisor, low);
        high += carry;
    }
};
// Policy instance to pass as the first argument of mulDiv().
CARB_WEAKLINK round_nearest_neighbor_t round_nearest_neighbor;

// Computes (num * multiplier) / divisor for unsigned 64-bit values using a 128-bit intermediate product, so the
// multiplication itself cannot overflow.
//
// round       rounding policy; invoked as round(high, low, divisor) on the 128-bit product before the division
// num         value to scale
// multiplier  value to multiply num by
// divisor     value to divide the product by
//
// Returns the quotient, or cpp::nullopt when divisor is zero or the quotient does not fit in 64 bits.
template <class RoundPolicy>
[[nodiscard]] cpp::optional<uint64_t> mulDiv(RoundPolicy round, uint64_t num, uint64_t multiplier, uint64_t divisor) noexcept
{
    // Division by zero has no result
    CARB_UNLIKELY_IF(divisor == 0)
    {
        return cpp::nullopt;
    }

    // 128-bit product as a high:low pair
    uint64_t productHigh;
    uint64_t productLow = detail::umul128(num, multiplier, productHigh);

    // Let the policy bias the product so that the truncating division rounds as requested
    round(productHigh, productLow, divisor);

    // The quotient only fits in 64 bits when the high half is smaller than the divisor
    if (productHigh < divisor)
        return cpp::make_optional(detail::udiv128(productHigh, productLow, divisor));

    return cpp::nullopt;
}

// Computes (num * multiplier) / divisor for signed 64-bit values using a 128-bit intermediate product.
//
// The work is delegated to the unsigned overload using the magnitudes of the operands, so `round` is applied to the
// magnitude of the quotient; the sign is re-applied afterwards.
//
// round       rounding policy, forwarded to the unsigned overload
// num         value to scale
// multiplier  value to multiply num by
// divisor     value to divide the product by
//
// Returns the quotient, or cpp::nullopt when the unsigned computation fails (zero divisor or 64-bit overflow) or when
// the signed result is not representable as int64_t.
template <class RoundPolicy>
[[nodiscard]] cpp::optional<int64_t> mulDiv(RoundPolicy round, int64_t num, int64_t multiplier, int64_t divisor) noexcept
{
    // XOR of the operands has its sign bit set exactly when an odd number of them is negative
    bool const negative = (num ^ multiplier ^ divisor) < 0;

    // Magnitude via unsigned arithmetic: `-val` is signed overflow (undefined behavior) for the minimum int64_t
    // value, whereas unsigned subtraction wraps and yields the correct magnitude (2^63).
    auto const absolute = [](int64_t val) noexcept { return val >= 0 ? uint64_t(val) : uint64_t(0) - uint64_t(val); };

    auto const uresult = mulDiv(round, absolute(num), absolute(multiplier), absolute(divisor));
    if (!uresult)
        return cpp::nullopt;

    uint64_t const magnitude = uresult.value();

    // 2^63: magnitude of the minimum int64_t value, and one more than the maximum int64_t value
    constexpr uint64_t kMinMagnitude = uint64_t(1) << 63;

    if (negative)
    {
        // Negative results may have a magnitude of up to 2^63
        if (magnitude > kMinMagnitude)
            return cpp::nullopt;

        // -(magnitude - 1) - 1 equals -magnitude but never negates a value outside the int64_t range
        return cpp::make_optional(magnitude != 0 ? -int64_t(magnitude - 1) - 1 : int64_t(0));
    }

    // Non-negative results must be strictly below 2^63
    if (magnitude >= kMinMagnitude)
        return cpp::nullopt;

    return cpp::make_optional(int64_t(magnitude));
}

// Floating-point counterpart of the integer mulDiv() overloads: computes (num * multiplier) / divisor.
//
// The rounding policy argument is accepted for signature parity with the integer overloads and is ignored.
//
// num         value to scale
// multiplier  value to multiply num by
// divisor     value to divide the product by
//
// Returns the quotient, or cpp::nullopt when divisor compares equal to zero.
template <class RoundPolicy>
[[nodiscard]] cpp::optional<double> mulDiv(RoundPolicy /*ignored*/, double num, double multiplier, double divisor) noexcept
{
    if (divisor != 0.)
    {
        return (num * multiplier) / divisor;
    }
    return cpp::nullopt;
}

// Convenience overload: unsigned mulDiv() with the round_toward_zero policy.
// Returns (num * multiplier) / divisor truncated, or cpp::nullopt on zero divisor or overflow.
[[nodiscard]] inline cpp::optional<uint64_t> mulDiv(uint64_t num, uint64_t multiplier, uint64_t divisor) noexcept
{
    return mulDiv<round_toward_zero_t>(round_toward_zero, num, multiplier, divisor);
}

// Convenience overload: signed mulDiv() with the round_toward_zero policy.
// Returns (num * multiplier) / divisor truncated, or cpp::nullopt on zero divisor or overflow.
[[nodiscard]] inline cpp::optional<int64_t> mulDiv(int64_t num, int64_t multiplier, int64_t divisor) noexcept
{
    return mulDiv<round_toward_zero_t>(round_toward_zero, num, multiplier, divisor);
}

// Convenience overload: floating-point mulDiv() without a rounding policy.
// Returns (num * multiplier) / divisor, or cpp::nullopt when divisor compares equal to zero.
[[nodiscard]] inline cpp::optional<double> mulDiv(double num, double multiplier, double divisor) noexcept
{
    // The floating-point overload ignores its policy argument, so an empty lambda serves as a placeholder
    auto const unusedPolicy = [] {};
    return mulDiv(unusedPolicy, num, multiplier, divisor);
}

} // namespace math
} // namespace carb