Question

I'm looking for a fast method to compute (a·b) modulo n (in the mathematical sense) for a, b, n of type uint64_t. I could live with preconditions such as n != 0, or even a < n && b < n.

Notice that the C expression (a*b)%n won't cut it, because the product is truncated to 64 bits. I'm looking for (uint64_t)(((uint128_t)a*b)%n), except that I do not have a uint128_t (that I know of, in Visual C++).
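
For example, a minimal demonstration of the truncation (2^64 mod 3 is 1, but the truncated product gives 0):

#include <stdint.h>
#include <stdio.h>

int main(void) {
    uint64_t a = 1ULL << 32, b = 1ULL << 32, n = 3;
    // a*b is 2**64, which truncates to 0 in uint64_t,
    // so this prints 0 instead of the correct 1
    printf("%llu\n", (unsigned long long)((a * b) % n));
    return 0;
}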

I'm after a Visual C++ (preferably) or GCC/clang intrinsic making best use of the underlying hardware available on x86-64 platforms; or, if that can't be done, a portable inline function.


Solution 3

Seven years later, I got a solution working in Visual Studio 2019:

#include <stdint.h>
#include <intrin.h>
#pragma intrinsic(_umul128)
#pragma intrinsic(_udiv128)

// compute (a*b)%n with 128-bit intermediary result
// assumes n > 0 and a*b < n * 2**64 (always the case when a <= n || b <= n)
inline uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
  uint64_t r, s = _umul128(a, b, &r);
  (void)_udiv128(r, s, n, &r);
  return r;
}

// compute (a*b)%n with 128-bit intermediary result
// assumes n > 0; works even when a*b >= n * 2**64
inline uint64_t mulmod1(uint64_t a, uint64_t b, uint64_t n) {
  uint64_t r, s = _umul128(a % n, b, &r);
  (void)_udiv128(r, s, n, &r);
  return r;
}
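
As a quick sanity check (my own example, not part of the answer), here is a product that wraps 64 bits:

#include <stdio.h>

int main(void) {
    // 2**63 * 2 = 2**64, and 2**64 mod (2**64 - 1) = 1
    uint64_t a = 1ULL << 63, b = 2, n = ~0ULL;
    printf("%llu\n", mulmod(a, b, n));   // prints 1
    printf("%llu\n", mulmod1(a, b, n));  // prints 1
    return 0;
}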

Other tips

OK, how about this (not tested):

modmul:
; rcx = a
; rdx = b
; r8 = n
mov rax, rdx    ; rax = b
mul rcx         ; rdx:rax = a * b
div r8          ; divide rdx:rax by n: quotient in rax, remainder in rdx
mov rax, rdx    ; return the remainder
ret

The precondition is that a * b / n <= ~0ULL; otherwise there will be a divide error. That's a slightly less strict condition than a < n && b < n: one of them can be bigger than n as long as the other is small enough. (For example, a = 2^63, b = 4, n = 2 faults, because the quotient 2^64 doesn't fit in 64 bits.)

Unfortunately it has to be assembled and linked in separately, because MSVC doesn't support inline asm for 64-bit targets.

It's also still slow; the real problem is that 64-bit div, which can take nearly a hundred cycles (seriously: up to 90 cycles on Nehalem, for example).

You could do it the old-fashioned way with shift/add/subtract. The code below assumes a < n and n < 2^63 (so things don't overflow):

uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t rv = 0;
    while (b) {
        if (b&1)
            if ((rv += a) >= n) rv -= n;
        if ((a += a) >= n) a -= n;
        b >>= 1; }
    return rv;
}

You could use while (a && b) for the loop instead to short-circuit things if it's likely that a will be a factor of n. It will be slightly slower (more comparisons, and likely correctly predicted branches) if a is not a factor of n; a sketch of that variant follows below.
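
Here is that variant (mulmod_sc is a name of my choosing; same preconditions as above, a < n and n < 2^63):

uint64_t mulmod_sc(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t rv = 0;
    // stop early if a ever collapses to 0 modulo n
    while (a && b) {
        if (b & 1)
            if ((rv += a) >= n) rv -= n;
        if ((a += a) >= n) a -= n;
        b >>= 1;
    }
    return rv;
}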

If you really, absolutely, need that last bit (allowing n up to 2^64 - 1), you can use:

uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t rv = 0;
    while (b) {
        if (b&1) {
            rv += a;
            if (rv < a || rv >= n) rv -= n; }
        uint64_t t = a;
        a += a;
        if (a < t || a >= n) a -= n;
        b >>= 1; }
    return rv;
}

Alternately, just use GCC inline assembly to access the underlying x64 instructions:

inline uint64_t mulmod(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t rv;
    asm ("mul %3" : "=d"(rv), "=a"(a) : "1"(a), "r"(b));          // rdx:rax = a * b
    asm ("div %4" : "=d"(rv), "=a"(a) : "0"(rv), "1"(a), "r"(n)); // remainder lands in rdx (rv)
    return rv;
}

The 64-bit div instruction is really slow, however, so the loop might actually be faster. You'd need to profile to be sure.
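
For completeness (not in the original answers): on GCC/clang targeting x86-64, the 128-bit type the question wished for does exist, spelled unsigned __int128, so the expression from the question compiles as-is. Depending on the compiler, this may become a mul plus a library call such as __umodti3 rather than a bare div, which also removes the divide-fault precondition:

#include <stdint.h>

// the compiler handles the 128-bit intermediate itself
static inline uint64_t mulmod_u128(uint64_t a, uint64_t b, uint64_t n) {
    return (uint64_t)(((unsigned __int128)a * b) % n);
}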

This intrinsic is named _umul128 (MSVC's unsigned 64×64→128 multiply).

typedef unsigned long long BIG;

// computes (v << by) % n; handles only the "hard" case when the high bit of n is set
BIG shl_mod( BIG v, BIG n, int by )
{
    if (v >= n) v -= n;      // bring v into [0, n); v < 2n is guaranteed since n >= 2**63
    while (by--) {
        if (v >= (n-v))
            v -= n-v;        // doubling reaches or exceeds n: v = 2v - n
        else
            v <<= 1;         // safe to double without overflow
    }
    return v;
}

Now you can use shl_mod(B, n, 64) to reduce a high 64-bit half B of the product modulo n, since shl_mod computes (v · 2^by) mod n. A sketch completing the idea follows.
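
One plausible way to finish the job (my sketch, not from the answer: it pairs shl_mod with _umul128 from the first solution and the wraparound test from the loop versions, and assumes the high bit of n is set, as shl_mod requires):

// sketch under the stated assumption: n has its high bit set (n >= 2**63);
// needs <intrin.h> for _umul128, as in the first solution
uint64_t mulmod_hard(uint64_t a, uint64_t b, uint64_t n) {
    uint64_t hi, lo = _umul128(a, b, &hi);   // 128-bit product hi:lo
    uint64_t r = shl_mod(hi, n, 64);         // (hi * 2**64) mod n
    uint64_t l = lo >= n ? lo - n : lo;      // lo mod n; one subtraction suffices since n >= 2**63
    r += l;                                  // true sum < 2n, may wrap past 2**64
    if (r < l || r >= n) r -= n;             // fold back into [0, n)
    return r;
}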

Having no inline assembly kind of sucks. Anyway, the function call overhead is actually extremely small. Parameters are passed in volatile registers and no cleanup is needed.

I don't have an assembler, and x64 targets don't support __asm, so I had no choice but to "assemble" my function from opcodes myself.

Obviously the demo below depends on Windows (for VirtualProtect) and on mpir; I'm using mpir (gmp) as a reference to show the function produces correct results.


#include "stdafx.h"

// mulmod64(a, b, m) == (a * b) % m
typedef uint64_t(__cdecl *mulmod64_fnptr_t)(uint64_t a, uint64_t b, uint64_t m);

uint8_t mulmod64_opcodes[] = {
    0x48, 0x89, 0xC8, // mov rax, rcx
    0x48, 0xF7, 0xE2, // mul rdx
    0x4C, 0x89, 0xC1, // mov rcx, r8
    0x48, 0xF7, 0xF1, // div rcx
    0x48, 0x89, 0xD0, // mov rax,rdx
    0xC3              // ret
};

mulmod64_fnptr_t mulmod64_fnptr;

void init() {
    DWORD dwOldProtect;
    VirtualProtect(
        &mulmod64_opcodes,
        sizeof(mulmod64_opcodes),
        PAGE_EXECUTE_READWRITE,
        &dwOldProtect);
    // NOTE: reinterpret byte array as a function pointer
    mulmod64_fnptr = (mulmod64_fnptr_t)(void*)mulmod64_opcodes;
}

int main() {
    init();

    uint64_t a64 = 2139018971924123ull;
    uint64_t b64 = 1239485798578921ull;
    uint64_t m64 = 8975489368910167ull;

    // reference code
    mpz_t a, b, c, m, r;
    mpz_inits(a, b, c, m, r, NULL);
    mpz_set_ui(a, a64);
    mpz_set_ui(b, b64);
    mpz_set_ui(m, m64);
    mpz_mul(c, a, b);
    mpz_mod(r, c, m);

    gmp_printf("(%Zd * %Zd) mod %Zd = %Zd\n", a, b, m, r);

    // using mulmod64
    uint64_t r64 = mulmod64_fnptr(a64, b64, m64);
    printf("(%llu * %llu) mod %llu = %llu\n", a64, b64, m64, r64);
    return 0;
}

Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow