1
0
mirror of https://git.tartarus.org/simon/putty.git synced 2025-01-09 17:38:00 +00:00

Add mp_nthroot function.

This takes ordinary integer square and cube roots (i.e. not mod
anything) of mp_ints.
This commit is contained in:
Simon Tatham 2020-02-18 20:07:55 +00:00
parent ece788240c
commit 6b27999500
5 changed files with 129 additions and 0 deletions

76
mpint.c
View File

@ -2225,6 +2225,82 @@ mp_int *mp_mod(mp_int *n, mp_int *d)
return r; return r;
} }
mp_int *mp_nthroot(mp_int *y, unsigned n, mp_int *remainder_out)
{
/*
* Allocate scratch space.
*/
mp_int **alloc, **powers, **newpowers, *scratch;
size_t nalloc = 2*(n+1)+1;
alloc = snewn(nalloc, mp_int *);
for (size_t i = 0; i < nalloc; i++)
alloc[i] = mp_make_sized(y->nw + 1);
powers = alloc;
newpowers = alloc + (n+1);
scratch = alloc[2*n+2];
/*
* We're computing the rounded-down nth root of y, i.e. the
* maximal x such that x^n <= y. We try to add 2^i to it for each
* possible value of i, starting from the largest one that might
* fit (i.e. such that 2^{n*i} fits in the size of y) downwards to
* i=0.
*
* We track all the smaller powers of x in the array 'powers'. In
* each iteration, if we update x, we update all of those values
* to match.
*/
mp_copy_integer_into(powers[0], 1);
for (size_t s = mp_max_bits(y) / n + 1; s-- > 0 ;) {
/*
* Let b = 2^s. We need to compute the powers (x+b)^i for each
* i, starting from our recorded values of x^i.
*/
for (size_t i = 0; i < n+1; i++) {
/*
* (x+b)^i = x^i
* + (i choose 1) x^{i-1} b
* + (i choose 2) x^{i-2} b^2
* + ...
* + b^i
*/
uint16_t binom = 1; /* coefficient of b^i */
mp_copy_into(newpowers[i], powers[i]);
for (size_t j = 0; j < i; j++) {
/* newpowers[i] += binom * powers[j] * 2^{(i-j)*s} */
mp_mul_integer_into(scratch, powers[j], binom);
mp_lshift_fixed_into(scratch, scratch, (i-j) * s);
mp_add_into(newpowers[i], newpowers[i], scratch);
uint32_t binom_mul = binom;
binom_mul *= (i-j);
binom_mul /= (j+1);
assert(binom_mul < 0x10000);
binom = binom_mul;
}
}
/*
* Now, is the new value of x^n still <= y? If so, update.
*/
unsigned newbit = mp_cmp_hs(y, newpowers[n]);
for (size_t i = 0; i < n+1; i++)
mp_select_into(powers[i], powers[i], newpowers[i], newbit);
}
if (remainder_out)
mp_sub_into(remainder_out, y, powers[n]);
mp_int *root = mp_new(mp_max_bits(y) / n);
mp_copy_into(root, powers[1]);
for (size_t i = 0; i < nalloc; i++)
mp_free(alloc[i]);
sfree(alloc);
return root;
}
mp_int *mp_modmul(mp_int *x, mp_int *y, mp_int *modulus) mp_int *mp_modmul(mp_int *x, mp_int *y, mp_int *modulus)
{ {
mp_int *product = mp_mul(x, y); mp_int *product = mp_mul(x, y);

11
mpint.h
View File

@ -256,6 +256,17 @@ void mp_divmod_into(mp_int *n, mp_int *d, mp_int *q, mp_int *r);
mp_int *mp_div(mp_int *n, mp_int *d); mp_int *mp_div(mp_int *n, mp_int *d);
mp_int *mp_mod(mp_int *x, mp_int *modulus); mp_int *mp_mod(mp_int *x, mp_int *modulus);
/*
* Integer nth root. mp_nthroot returns the largest integer x such
* that x^n <= y, and if 'remainder' is non-NULL then it fills it with
* the residue (y - x^n).
*
* Currently, n has to be small enough that the largest binomial
* coefficient (n choose k) fits in 16 bits, which works out to at
* most 18.
*/
mp_int *mp_nthroot(mp_int *y, unsigned n, mp_int *remainder);
/* /*
* Trivially easy special case of mp_mod: reduce a number mod a power * Trivially easy special case of mp_mod: reduce a number mod a power
* of two. * of two.

View File

@ -391,6 +391,29 @@ class mpint(MyTestBase):
# No tests we can do after that last one - we just # No tests we can do after that last one - we just
# insist that it isn't allowed to have crashed! # insist that it isn't allowed to have crashed!
def testNthRoot(self):
roots = [1, 13, 1234567654321,
57721566490153286060651209008240243104215933593992]
tests = []
tests.append((0, 2, 0, 0))
tests.append((0, 3, 0, 0))
for r in roots:
for n in 2, 3, 5:
tests.append((r**n, n, r, 0))
tests.append((r**n+1, n, r, 1))
tests.append((r**n-1, n, r-1, r**n - (r-1)**n - 1))
for x, n, eroot, eremainder in tests:
with self.subTest(x=x):
mx = mp_copy(x)
remainder = mp_copy(mx)
root = mp_nthroot(x, n, remainder)
self.assertEqual(int(root), eroot)
self.assertEqual(int(remainder), eremainder)
self.assertEqual(int(mp_nthroot(2*10**100, 2, None)),
141421356237309504880168872420969807856967187537694)
self.assertEqual(int(mp_nthroot(3*10**150, 3, None)),
144224957030740838232163831078010958839186925349935)
def testBitwise(self): def testBitwise(self):
p = 0x3243f6a8885a308d313198a2e03707344a4093822299f31d0082efa98ec4e p = 0x3243f6a8885a308d313198a2e03707344a4093822299f31d0082efa98ec4e
e = 0x2b7e151628aed2a6abf7158809cf4f3c762e7160f38b4da56a784d9045190 e = 0x2b7e151628aed2a6abf7158809cf4f3c762e7160f38b4da56a784d9045190

View File

@ -51,6 +51,7 @@ FUNC2(void, mp_cond_clear, val_mpint, uint)
FUNC4(void, mp_divmod_into, val_mpint, val_mpint, opt_val_mpint, opt_val_mpint) FUNC4(void, mp_divmod_into, val_mpint, val_mpint, opt_val_mpint, opt_val_mpint)
FUNC2(val_mpint, mp_div, val_mpint, val_mpint) FUNC2(val_mpint, mp_div, val_mpint, val_mpint)
FUNC2(val_mpint, mp_mod, val_mpint, val_mpint) FUNC2(val_mpint, mp_mod, val_mpint, val_mpint)
FUNC3(val_mpint, mp_nthroot, val_mpint, uint, opt_val_mpint)
FUNC2(void, mp_reduce_mod_2to, val_mpint, uint) FUNC2(void, mp_reduce_mod_2to, val_mpint, uint)
FUNC2(val_mpint, mp_invert_mod_2to, val_mpint, uint) FUNC2(val_mpint, mp_invert_mod_2to, val_mpint, uint)
FUNC2(val_mpint, mp_invert, val_mpint, val_mpint) FUNC2(val_mpint, mp_invert, val_mpint, val_mpint)

View File

@ -294,6 +294,7 @@ VOLATILE_WRAPPED_DEFN(static, size_t, looplimit, (size_t x))
X(mp_mul) \ X(mp_mul) \
X(mp_rshift_safe) \ X(mp_rshift_safe) \
X(mp_divmod) \ X(mp_divmod) \
X(mp_nthroot) \
X(mp_modadd) \ X(mp_modadd) \
X(mp_modsub) \ X(mp_modsub) \
X(mp_modmul) \ X(mp_modmul) \
@ -577,6 +578,23 @@ static void test_mp_divmod(void)
mp_free(r); mp_free(r);
} }
static void test_mp_nthroot(void)
{
mp_int *x = mp_new(256), *remainder = mp_new(256);
for (size_t i = 0; i < looplimit(32); i++) {
uint8_t sizes[1];
random_read(sizes, 1);
mp_random_bits_into(x, sizes[0]);
log_start();
mp_free(mp_nthroot(x, 3, remainder));
log_end();
}
mp_free(x);
mp_free(remainder);
}
static void test_mp_modarith( static void test_mp_modarith(
mp_int *(*mp_modarith)(mp_int *x, mp_int *y, mp_int *modulus)) mp_int *(*mp_modarith)(mp_int *x, mp_int *y, mp_int *modulus))
{ {