mirror of
https://git.tartarus.org/simon/putty.git
synced 2025-01-09 17:38:00 +00:00
Add mp_nthroot function.
This takes ordinary integer square and cube roots (i.e. not mod anything) of mp_ints.
This commit is contained in:
parent
ece788240c
commit
6b27999500
76
mpint.c
76
mpint.c
@ -2225,6 +2225,82 @@ mp_int *mp_mod(mp_int *n, mp_int *d)
|
||||
return r;
|
||||
}
|
||||
|
||||
mp_int *mp_nthroot(mp_int *y, unsigned n, mp_int *remainder_out)
|
||||
{
|
||||
/*
|
||||
* Allocate scratch space.
|
||||
*/
|
||||
mp_int **alloc, **powers, **newpowers, *scratch;
|
||||
size_t nalloc = 2*(n+1)+1;
|
||||
alloc = snewn(nalloc, mp_int *);
|
||||
for (size_t i = 0; i < nalloc; i++)
|
||||
alloc[i] = mp_make_sized(y->nw + 1);
|
||||
powers = alloc;
|
||||
newpowers = alloc + (n+1);
|
||||
scratch = alloc[2*n+2];
|
||||
|
||||
/*
|
||||
* We're computing the rounded-down nth root of y, i.e. the
|
||||
* maximal x such that x^n <= y. We try to add 2^i to it for each
|
||||
* possible value of i, starting from the largest one that might
|
||||
* fit (i.e. such that 2^{n*i} fits in the size of y) downwards to
|
||||
* i=0.
|
||||
*
|
||||
* We track all the smaller powers of x in the array 'powers'. In
|
||||
* each iteration, if we update x, we update all of those values
|
||||
* to match.
|
||||
*/
|
||||
mp_copy_integer_into(powers[0], 1);
|
||||
for (size_t s = mp_max_bits(y) / n + 1; s-- > 0 ;) {
|
||||
/*
|
||||
* Let b = 2^s. We need to compute the powers (x+b)^i for each
|
||||
* i, starting from our recorded values of x^i.
|
||||
*/
|
||||
for (size_t i = 0; i < n+1; i++) {
|
||||
/*
|
||||
* (x+b)^i = x^i
|
||||
* + (i choose 1) x^{i-1} b
|
||||
* + (i choose 2) x^{i-2} b^2
|
||||
* + ...
|
||||
* + b^i
|
||||
*/
|
||||
uint16_t binom = 1; /* coefficient of b^i */
|
||||
mp_copy_into(newpowers[i], powers[i]);
|
||||
for (size_t j = 0; j < i; j++) {
|
||||
/* newpowers[i] += binom * powers[j] * 2^{(i-j)*s} */
|
||||
mp_mul_integer_into(scratch, powers[j], binom);
|
||||
mp_lshift_fixed_into(scratch, scratch, (i-j) * s);
|
||||
mp_add_into(newpowers[i], newpowers[i], scratch);
|
||||
|
||||
uint32_t binom_mul = binom;
|
||||
binom_mul *= (i-j);
|
||||
binom_mul /= (j+1);
|
||||
assert(binom_mul < 0x10000);
|
||||
binom = binom_mul;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Now, is the new value of x^n still <= y? If so, update.
|
||||
*/
|
||||
unsigned newbit = mp_cmp_hs(y, newpowers[n]);
|
||||
for (size_t i = 0; i < n+1; i++)
|
||||
mp_select_into(powers[i], powers[i], newpowers[i], newbit);
|
||||
}
|
||||
|
||||
if (remainder_out)
|
||||
mp_sub_into(remainder_out, y, powers[n]);
|
||||
|
||||
mp_int *root = mp_new(mp_max_bits(y) / n);
|
||||
mp_copy_into(root, powers[1]);
|
||||
|
||||
for (size_t i = 0; i < nalloc; i++)
|
||||
mp_free(alloc[i]);
|
||||
sfree(alloc);
|
||||
|
||||
return root;
|
||||
}
|
||||
|
||||
mp_int *mp_modmul(mp_int *x, mp_int *y, mp_int *modulus)
|
||||
{
|
||||
mp_int *product = mp_mul(x, y);
|
||||
|
11
mpint.h
11
mpint.h
@ -256,6 +256,17 @@ void mp_divmod_into(mp_int *n, mp_int *d, mp_int *q, mp_int *r);
|
||||
mp_int *mp_div(mp_int *n, mp_int *d);
|
||||
mp_int *mp_mod(mp_int *x, mp_int *modulus);
|
||||
|
||||
/*
|
||||
* Integer nth root. mp_nthroot returns the largest integer x such
|
||||
* that x^n <= y, and if 'remainder' is non-NULL then it fills it with
|
||||
* the residue (y - x^n).
|
||||
*
|
||||
* Currently, n has to be small enough that the largest binomial
|
||||
* coefficient (n choose k) fits in 16 bits, which works out to at
|
||||
* most 18.
|
||||
*/
|
||||
mp_int *mp_nthroot(mp_int *y, unsigned n, mp_int *remainder);
|
||||
|
||||
/*
|
||||
* Trivially easy special case of mp_mod: reduce a number mod a power
|
||||
* of two.
|
||||
|
@ -391,6 +391,29 @@ class mpint(MyTestBase):
|
||||
# No tests we can do after that last one - we just
|
||||
# insist that it isn't allowed to have crashed!
|
||||
|
||||
def testNthRoot(self):
|
||||
roots = [1, 13, 1234567654321,
|
||||
57721566490153286060651209008240243104215933593992]
|
||||
tests = []
|
||||
tests.append((0, 2, 0, 0))
|
||||
tests.append((0, 3, 0, 0))
|
||||
for r in roots:
|
||||
for n in 2, 3, 5:
|
||||
tests.append((r**n, n, r, 0))
|
||||
tests.append((r**n+1, n, r, 1))
|
||||
tests.append((r**n-1, n, r-1, r**n - (r-1)**n - 1))
|
||||
for x, n, eroot, eremainder in tests:
|
||||
with self.subTest(x=x):
|
||||
mx = mp_copy(x)
|
||||
remainder = mp_copy(mx)
|
||||
root = mp_nthroot(x, n, remainder)
|
||||
self.assertEqual(int(root), eroot)
|
||||
self.assertEqual(int(remainder), eremainder)
|
||||
self.assertEqual(int(mp_nthroot(2*10**100, 2, None)),
|
||||
141421356237309504880168872420969807856967187537694)
|
||||
self.assertEqual(int(mp_nthroot(3*10**150, 3, None)),
|
||||
144224957030740838232163831078010958839186925349935)
|
||||
|
||||
def testBitwise(self):
|
||||
p = 0x3243f6a8885a308d313198a2e03707344a4093822299f31d0082efa98ec4e
|
||||
e = 0x2b7e151628aed2a6abf7158809cf4f3c762e7160f38b4da56a784d9045190
|
||||
|
@ -51,6 +51,7 @@ FUNC2(void, mp_cond_clear, val_mpint, uint)
|
||||
FUNC4(void, mp_divmod_into, val_mpint, val_mpint, opt_val_mpint, opt_val_mpint)
|
||||
FUNC2(val_mpint, mp_div, val_mpint, val_mpint)
|
||||
FUNC2(val_mpint, mp_mod, val_mpint, val_mpint)
|
||||
FUNC3(val_mpint, mp_nthroot, val_mpint, uint, opt_val_mpint)
|
||||
FUNC2(void, mp_reduce_mod_2to, val_mpint, uint)
|
||||
FUNC2(val_mpint, mp_invert_mod_2to, val_mpint, uint)
|
||||
FUNC2(val_mpint, mp_invert, val_mpint, val_mpint)
|
||||
|
18
testsc.c
18
testsc.c
@ -294,6 +294,7 @@ VOLATILE_WRAPPED_DEFN(static, size_t, looplimit, (size_t x))
|
||||
X(mp_mul) \
|
||||
X(mp_rshift_safe) \
|
||||
X(mp_divmod) \
|
||||
X(mp_nthroot) \
|
||||
X(mp_modadd) \
|
||||
X(mp_modsub) \
|
||||
X(mp_modmul) \
|
||||
@ -577,6 +578,23 @@ static void test_mp_divmod(void)
|
||||
mp_free(r);
|
||||
}
|
||||
|
||||
static void test_mp_nthroot(void)
|
||||
{
|
||||
mp_int *x = mp_new(256), *remainder = mp_new(256);
|
||||
|
||||
for (size_t i = 0; i < looplimit(32); i++) {
|
||||
uint8_t sizes[1];
|
||||
random_read(sizes, 1);
|
||||
mp_random_bits_into(x, sizes[0]);
|
||||
log_start();
|
||||
mp_free(mp_nthroot(x, 3, remainder));
|
||||
log_end();
|
||||
}
|
||||
|
||||
mp_free(x);
|
||||
mp_free(remainder);
|
||||
}
|
||||
|
||||
static void test_mp_modarith(
|
||||
mp_int *(*mp_modarith)(mp_int *x, mp_int *y, mp_int *modulus))
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user