diff --git a/mpint.c b/mpint.c
index 7a70d400..ddd21411 100644
--- a/mpint.c
+++ b/mpint.c
@@ -2225,6 +2225,90 @@ mp_int *mp_mod(mp_int *n, mp_int *d)
     return r;
 }
 
+mp_int *mp_nthroot(mp_int *y, unsigned n, mp_int *remainder_out)
+{
+    /*
+     * There is no such thing as a zeroth root, and n == 0 would make
+     * the divisions by n below divide by zero.
+     */
+    assert(n > 0);
+
+    /*
+     * Allocate scratch space: n+1 values each for 'powers' and
+     * 'newpowers' (indices 0 to n inclusive), plus one more for
+     * 'scratch', every one of them made with size y->nw + 1.
+     */
+    mp_int **alloc, **powers, **newpowers, *scratch;
+    size_t nalloc = 2*(n+1)+1;
+    alloc = snewn(nalloc, mp_int *);
+    for (size_t i = 0; i < nalloc; i++)
+        alloc[i] = mp_make_sized(y->nw + 1);
+    powers = alloc;
+    newpowers = alloc + (n+1);
+    scratch = alloc[2*n+2];
+
+    /*
+     * We're computing the rounded-down nth root of y, i.e. the
+     * maximal x such that x^n <= y. We try to add 2^i to it for each
+     * possible value of i, starting from the largest one that might
+     * fit (i.e. such that 2^{n*i} fits in the size of y) downwards to
+     * i=0.
+     *
+     * We track all the smaller powers of x in the array 'powers'. In
+     * each iteration, if we update x, we update all of those values
+     * to match.
+     */
+    mp_copy_integer_into(powers[0], 1);
+    for (size_t s = mp_max_bits(y) / n + 1; s-- > 0 ;) {
+        /*
+         * Let b = 2^s. We need to compute the powers (x+b)^i for each
+         * i, starting from our recorded values of x^i.
+         */
+        for (size_t i = 0; i < n+1; i++) {
+            /*
+             * (x+b)^i = x^i
+             *         + (i choose 1) x^{i-1} b
+             *         + (i choose 2) x^{i-2} b^2
+             *         + ...
+             *         + b^i
+             */
+            uint16_t binom = 1;       /* coefficient of b^i */
+            mp_copy_into(newpowers[i], powers[i]);
+            for (size_t j = 0; j < i; j++) {
+                /* newpowers[i] += binom * powers[j] * 2^{(i-j)*s} */
+                mp_mul_integer_into(scratch, powers[j], binom);
+                mp_lshift_fixed_into(scratch, scratch, (i-j) * s);
+                mp_add_into(newpowers[i], newpowers[i], scratch);
+
+                uint32_t binom_mul = binom;
+                binom_mul *= (i-j);
+                binom_mul /= (j+1);
+                assert(binom_mul < 0x10000);
+                binom = binom_mul;
+            }
+        }
+
+        /*
+         * Now, is the new value of x^n still <= y? If so, update.
+         */
+        unsigned newbit = mp_cmp_hs(y, newpowers[n]);
+        for (size_t i = 0; i < n+1; i++)
+            mp_select_into(powers[i], powers[i], newpowers[i], newbit);
+    }
+
+    if (remainder_out)
+        mp_sub_into(remainder_out, y, powers[n]);
+
+    mp_int *root = mp_new(mp_max_bits(y) / n);
+    mp_copy_into(root, powers[1]);
+
+    for (size_t i = 0; i < nalloc; i++)
+        mp_free(alloc[i]);
+    sfree(alloc);
+
+    return root;
+}
+
 mp_int *mp_modmul(mp_int *x, mp_int *y, mp_int *modulus)
 {
     mp_int *product = mp_mul(x, y);
diff --git a/mpint.h b/mpint.h
index ffc1e160..bcbe442e 100644
--- a/mpint.h
+++ b/mpint.h
@@ -256,6 +256,20 @@ void mp_divmod_into(mp_int *n, mp_int *d, mp_int *q, mp_int *r);
 mp_int *mp_div(mp_int *n, mp_int *d);
 mp_int *mp_mod(mp_int *x, mp_int *modulus);
 
+/*
+ * Integer nth root. mp_nthroot returns the largest integer x such
+ * that x^n <= y, and if 'remainder' is non-NULL then it fills it with
+ * the residue (y - x^n).
+ *
+ * The root is returned as a newly allocated mp_int, which the caller
+ * is responsible for freeing.
+ *
+ * n must be at least 1. Currently, it also has to be small enough
+ * that the largest binomial coefficient (n choose k) fits in 16
+ * bits, which works out to at most 18.
+ */
+mp_int *mp_nthroot(mp_int *y, unsigned n, mp_int *remainder);
+
 /*
  * Trivially easy special case of mp_mod: reduce a number mod a power
  * of two.
diff --git a/test/cryptsuite.py b/test/cryptsuite.py
index c5a04c92..c48a71fc 100755
--- a/test/cryptsuite.py
+++ b/test/cryptsuite.py
@@ -391,6 +391,32 @@ class mpint(MyTestBase):
                     # No tests we can do after that last one - we just
                     # insist that it isn't allowed to have crashed!
 
+    def testNthRoot(self):
+        # Each test case is (input, n, expected root, expected remainder).
+        # For every known root r, check r^n itself and the values
+        # immediately above and below it.
+        roots = [1, 13, 1234567654321,
+                 57721566490153286060651209008240243104215933593992]
+        tests = []
+        tests.append((0, 2, 0, 0))
+        tests.append((0, 3, 0, 0))
+        for r in roots:
+            for n in 2, 3, 5:
+                tests.append((r**n, n, r, 0))
+                tests.append((r**n+1, n, r, 1))
+                tests.append((r**n-1, n, r-1, r**n - (r-1)**n - 1))
+        for x, n, eroot, eremainder in tests:
+            with self.subTest(x=x, n=n):
+                mx = mp_copy(x)
+                remainder = mp_copy(mx)
+                root = mp_nthroot(x, n, remainder)
+                self.assertEqual(int(root), eroot)
+                self.assertEqual(int(remainder), eremainder)
+        self.assertEqual(int(mp_nthroot(2*10**100, 2, None)),
+                         141421356237309504880168872420969807856967187537694)
+        self.assertEqual(int(mp_nthroot(3*10**150, 3, None)),
+                         144224957030740838232163831078010958839186925349935)
+
     def testBitwise(self):
         p = 0x3243f6a8885a308d313198a2e03707344a4093822299f31d0082efa98ec4e
         e = 0x2b7e151628aed2a6abf7158809cf4f3c762e7160f38b4da56a784d9045190
diff --git a/testcrypt.h b/testcrypt.h
index 617f746b..3eddc9ba 100644
--- a/testcrypt.h
+++ b/testcrypt.h
@@ -51,6 +51,7 @@ FUNC2(void, mp_cond_clear, val_mpint, uint)
 FUNC4(void, mp_divmod_into, val_mpint, val_mpint, opt_val_mpint, opt_val_mpint)
 FUNC2(val_mpint, mp_div, val_mpint, val_mpint)
 FUNC2(val_mpint, mp_mod, val_mpint, val_mpint)
+FUNC3(val_mpint, mp_nthroot, val_mpint, uint, opt_val_mpint)
 FUNC2(void, mp_reduce_mod_2to, val_mpint, uint)
 FUNC2(val_mpint, mp_invert_mod_2to, val_mpint, uint)
 FUNC2(val_mpint, mp_invert, val_mpint, val_mpint)
diff --git a/testsc.c b/testsc.c
index bbe15faf..7558297e 100644
--- a/testsc.c
+++ b/testsc.c
@@ -294,6 +294,7 @@ VOLATILE_WRAPPED_DEFN(static, size_t, looplimit, (size_t x))
     X(mp_mul)                                   \
     X(mp_rshift_safe)                           \
     X(mp_divmod)                                \
+    X(mp_nthroot)                               \
     X(mp_modadd)                                \
     X(mp_modsub)                                \
     X(mp_modmul)                                \
@@ -577,6 +578,28 @@ static void test_mp_divmod(void)
     mp_free(r);
 }
 
+/*
+ * Run mp_nthroot (always taking a cube root) on inputs filled in by
+ * mp_random_bits_into, using a bit count read from one random byte
+ * (so 0 to 255). Each call is bracketed by log_start() and log_end().
+ */
+static void test_mp_nthroot(void)
+{
+    mp_int *x = mp_new(256), *remainder = mp_new(256);
+
+    for (size_t i = 0; i < looplimit(32); i++) {
+        uint8_t sizes[1];
+        random_read(sizes, 1);
+        mp_random_bits_into(x, sizes[0]);
+        log_start();
+        mp_free(mp_nthroot(x, 3, remainder));
+        log_end();
+    }
+
+    mp_free(x);
+    mp_free(remainder);
+}
+
 static void test_mp_modarith(
     mp_int *(*mp_modarith)(mp_int *x, mp_int *y, mp_int *modulus))
 {