1
0
mirror of https://git.tartarus.org/simon/putty.git synced 2025-01-10 01:48:00 +00:00

mpint: add a gcd function.

This is another application of the existing mp_bezout_into, which
needed a tweak or two to cope with the numbers not necessarily being
coprime, plus a wrapper function to deal with shared factors of 2.

It reindents the entire second half of mp_bezout_into, so the patch is
best viewed with whitespace differences ignored.
This commit is contained in:
Simon Tatham 2020-02-18 18:55:57 +00:00
parent 957f14088f
commit 2debb352b0
4 changed files with 245 additions and 97 deletions

168
mpint.c
View File

@ -1556,10 +1556,10 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
}
/*
* Given two coprime nonzero input integers a,b, returns two integers
* A,B such that A*a - B*b = 1. A,B will be the minimal non-negative
* pair satisfying that criterion, which is equivalent to saying that
* 0<=A<b and 0<=B<a.
* Given two input integers a,b which are not both even, computes d =
* gcd(a,b) and also two integers A,B such that A*a - B*b = d. A,B
* will be the minimal non-negative pair satisfying that criterion,
* which is equivalent to saying that 0 <= A < b/d and 0 <= B < a/d.
*
* This algorithm is an adapted form of Stein's algorithm, which
* computes gcd(a,b) using only addition and bit shifts (i.e. without
@ -1571,9 +1571,11 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
* - if both of a,b are odd, then WLOG a>b, and gcd(a,b) =
* gcd(b,(a-b)/2).
*
* For this application, I always expect the actual gcd to be coprime,
* so we can rule out the 'both even' initial case. So this function
* just performs a sequence of reductions in the following form:
* Sometimes this function is used for modular inversion, in which
* case we already know we expect the two inputs to be coprime, so to
* save time the 'both even' initial case is assumed not to arise (or
* to have been handled already by the caller). So this function just
* performs a sequence of reductions in the following form:
*
* - if a,b are both odd, sort them so that a > b, and replace a with
* b-a; otherwise sort them so that a is the even one
@ -1584,14 +1586,14 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
* generate those in each case, based on the coefficients from the
* reduced pair of numbers:
*
* - If a is even, and u,v are such that u*(a/2) + v*b = 1:
* + if u is also even, then this is just (u/2)*a + v*b = 1
* + otherwise, (u+b)*(a/2) + (v-a/2)*b is also equal to 1, and
* - If a is even, and u,v are such that u*(a/2) + v*b = d:
* + if u is also even, then this is just (u/2)*a + v*b = d
* + otherwise, (u+b)*(a/2) + (v-a/2)*b is also equal to d, and
* since u and b are both odd, (u+b)/2 is an integer, so we have
* ((u+b)/2)*a + (v-a/2)*b = 1.
* ((u+b)/2)*a + (v-a/2)*b = d.
*
* - If a,b are both odd, and u,v are such that u*b + v*(a-b) = 1,
* then v*a + (u-v)*b = 1.
* - If a,b are both odd, and u,v are such that u*b + v*(a-b) = d,
* then v*a + (u-v)*b = d.
*
* In the case where we passed from (a,b) to (b,(a-b)/2), we regard it
* as having first subtracted b from a and then halved a, so both of
@ -1609,11 +1611,11 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
* Also, since these mp_ints are generally treated as unsigned, we
* store the coefficients by absolute value, with the semantics that
* they always have opposite sign, and in the unwinding loop we keep a
* bit indicating whether Aa-Bb is currently expected to be +1 or -1,
* so that we can do one final conditional adjustment if it's -1.
* bit indicating whether Aa-Bb is currently expected to be +d or -d,
* so that we can do one final conditional adjustment if it's -d.
*
* Once the reduction rules have managed to reduce the input numbers
* to (0,1), then they are stable (the next reduction will always
* to (0,d), then they are stable (the next reduction will always
* divide the even one by 2, which maps 0 to 0). So it doesn't matter
* if we do more steps of the algorithm than necessary; hence, for
* constant time, we just need to find the maximum number we could
@ -1632,7 +1634,7 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
* n further steps each of which subtracts 1 from y and halves it.
*/
static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
mp_int *a_in, mp_int *b_in)
mp_int *gcd_out, mp_int *a_in, mp_int *b_in)
{
size_t nw = size_t_max(1, size_t_max(a_in->nw, b_in->nw));
@ -1691,31 +1693,59 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
}
/*
* Now we expect to have reduced the two numbers to 0 and 1,
* Now we expect to have reduced the two numbers to 0 and d,
* although we don't know which way round. (But we avoid checking
* this by assertion; sometimes we'll need to do this computation
* without giving away that we already know the inputs were bogus.
* So we'd prefer to just press on and return nonsense.)
*/
if (gcd_out) {
/*
* So their Bezout coefficients at this point are simply
* themselves.
* At this point we can return the actual gcd. Since one of
* a,b is it and the other is zero, the easiest way to get it
* is to add them together.
*/
mp_copy_into(ac, a);
mp_copy_into(bc, b);
mp_add_into(gcd_out, a, b);
}
/*
* We'll maintain the invariant as we unwind that ac * a - bc * b
* is either +1 or -1, and we'll remember which. (We _could_ keep
* it at +1 the whole time, but it would cost more work every time
* round the loop, so it's cheaper to fix that up once at the
* end.)
*
* Initially, the result is +1 if a was the nonzero value after
* reduction, and -1 if b was.
* If the caller _only_ wanted the gcd, and neither Bezout
* coefficient is even required, we can skip the entire unwind
* stage.
*/
unsigned minus_one = b->w[0];
if (a_coeff_out || b_coeff_out) {
/*
* The Bezout coefficients of a,b at this point are simply 0
* for whichever of a,b is zero, and 1 for whichever is
* nonzero. The nonzero number equals gcd(a,b), which by
* assumption is odd, so we can do this by just taking the low
* bit of each one.
*/
ac->w[0] = mp_get_bit(a, 0);
bc->w[0] = mp_get_bit(b, 0);
/*
* Overwrite a,b themselves with those same numbers. This has
* the effect of dividing both of them by d, which will
* arrange that during the unwind stage we generate the
* minimal coefficients instead of a larger pair.
*/
mp_copy_into(a, ac);
mp_copy_into(b, bc);
/*
* We'll maintain the invariant as we unwind that ac * a - bc
* * b is either +d or -d (or rather, +1/-1 after scaling by
* d), and we'll remember which. (We _could_ keep it at +d the
* whole time, but it would cost more work every time round
* the loop, so it's cheaper to fix that up once at the end.)
*
* Initially, the result is +d if a was the nonzero value after
* reduction, and -d if b was.
*/
unsigned minus_d = b->w[0];
for (size_t step = steps; step-- > 0 ;) {
/*
@ -1757,25 +1787,22 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
*/
mp_cond_swap(a, b, swap);
mp_cond_swap(ac, bc, swap);
minus_one ^= swap;
minus_d ^= swap;
}
/*
* Now we expect to have recovered the input a,b.
*/
assert(mp_cmp_eq(a, a_in) & mp_cmp_eq(b, b_in));
/*
* But we might find that our current result is -1 instead of +1,
* that is, we have A',B' such that A'a - B'b = -1.
* Now we expect to have recovered the input a,b (or rather,
* the versions of them divided by d). But we might find that
* our current result is -d instead of +d, that is, we have
* A',B' such that A'a - B'b = -d.
*
* In that situation, we set A = b-A' and B = a-B', giving us
* Aa-Bb = ab - A'a - ab + B'b = +1.
*/
mp_sub_into(tmp, b, ac);
mp_select_into(ac, ac, tmp, minus_one);
mp_select_into(ac, ac, tmp, minus_d);
mp_sub_into(tmp, a, bc);
mp_select_into(bc, bc, tmp, minus_one);
mp_select_into(bc, bc, tmp, minus_d);
/*
* Now we really are done. Return the outputs.
@ -1785,6 +1812,8 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
if (b_coeff_out)
mp_copy_into(b_coeff_out, bc);
}
mp_free(a);
mp_free(b);
mp_free(ac);
@ -1796,10 +1825,65 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
mp_int *mp_invert(mp_int *x, mp_int *m)
{
mp_int *result = mp_make_sized(m->nw);
mp_bezout_into(result, NULL, x, m);
mp_bezout_into(result, NULL, NULL, x, m);
return result;
}
void mp_gcd_into(mp_int *a, mp_int *b, mp_int *gcd, mp_int *A, mp_int *B)
{
/*
* Identify shared factors of 2. To do this we OR the two numbers
* to get something whose lowest set bit is in the right place,
* remove all higher bits by ANDing it with its own negation, and
* use mp_get_nbits to find the location of the single remaining
* set bit.
*/
mp_int *tmp = mp_make_sized(size_t_max(a->nw, b->nw));
for (size_t i = 0; i < tmp->nw; i++)
tmp->w[i] = mp_word(a, i) | mp_word(b, i);
BignumCarry carry = 1;
for (size_t i = 0; i < tmp->nw; i++) {
BignumInt negw;
BignumADC(negw, carry, 0, ~tmp->w[i], carry);
tmp->w[i] &= negw;
}
size_t shift = mp_get_nbits(tmp) - 1;
mp_free(tmp);
/*
* Make copies of a,b with those shared factors of 2 divided off,
* so that at least one is odd (which is the precondition for
* mp_bezout_into). Compute the gcd of those.
*/
mp_int *as = mp_rshift_safe(a, shift);
mp_int *bs = mp_rshift_safe(b, shift);
mp_bezout_into(A, B, gcd, as, bs);
mp_free(as);
mp_free(bs);
/*
* And finally shift the gcd back up (unless the caller didn't
* even ask for it), to put the shared factors of 2 back in.
*/
if (gcd)
mp_lshift_safe_in_place(gcd, shift);
}
mp_int *mp_gcd(mp_int *a, mp_int *b)
{
mp_int *gcd = mp_make_sized(size_t_min(a->nw, b->nw));
mp_gcd_into(a, b, gcd, NULL, NULL);
return gcd;
}
unsigned mp_coprime(mp_int *a, mp_int *b)
{
mp_int *gcd = mp_gcd(a, b);
unsigned toret = mp_eq_integer(gcd, 1);
mp_free(gcd);
return toret;
}
static uint32_t recip_approx_32(uint32_t x)
{
/*

19
mpint.h
View File

@ -270,6 +270,25 @@ void mp_reduce_mod_2to(mp_int *x, size_t p);
mp_int *mp_invert_mod_2to(mp_int *x, size_t p);
mp_int *mp_invert(mp_int *x, mp_int *modulus);
/*
* Greatest common divisor.
*
* mp_gcd_into also returns a pair of Bezout coefficients, namely A,B
* such that a*A - b*B = gcd. (The minus sign is so that both returned
* coefficients can be positive.)
*
* You can pass any of mp_gcd_into's output pointers as NULL if you
* don't need that output value.
*
* mp_gcd is a wrapper with a less cumbersome API, for the case where
* the only output value you need is the gcd itself. mp_coprime is
* even easier, if all you care about is whether or not that gcd is 1.
*/
mp_int *mp_gcd(mp_int *a, mp_int *b);
void mp_gcd_into(mp_int *a, mp_int *b,
mp_int *gcd_out, mp_int *A_out, mp_int *B_out);
unsigned mp_coprime(mp_int *a, mp_int *b);
/*
* System for taking square roots modulo an odd prime.
*

View File

@ -437,6 +437,48 @@ class mpint(MyTestBase):
int(monty_invert(mc, monty_import(mc, x))),
int(monty_import(mc, inv)))
def testGCD(self):
powerpairs = [(0,0), (1,0), (1,1), (2,1), (2,2), (75,3), (17,23)]
for a2, b2 in powerpairs:
for a3, b3 in powerpairs:
for a5, b5 in powerpairs:
a = 2**a2 * 3**a3 * 5**a5 * 17 * 19 * 23
b = 2**b2 * 3**b3 * 5**b5 * 65423
d = 2**min(a2, b2) * 3**min(a3, b3) * 5**min(a5, b5)
ma = mp_copy(a)
mb = mp_copy(b)
self.assertEqual(int(mp_gcd(ma, mb)), d)
md = mp_new(nbits(d))
mA = mp_new(nbits(b))
mB = mp_new(nbits(a))
mp_gcd_into(ma, mb, md, mA, mB)
self.assertEqual(int(md), d)
A = int(mA)
B = int(mB)
self.assertEqual(a*A - b*B, d)
self.assertTrue(0 <= A < b//d)
self.assertTrue(0 <= B < a//d)
self.assertEqual(mp_coprime(ma, mb), 1 if d==1 else 0)
# Make sure gcd_into can handle not getting some
# of its output pointers.
mp_clear(md)
mp_gcd_into(ma, mb, md, None, None)
self.assertEqual(int(md), d)
mp_clear(mA)
mp_gcd_into(ma, mb, None, mA, None)
self.assertEqual(int(mA), A)
mp_clear(mB)
mp_gcd_into(ma, mb, None, None, mB)
self.assertEqual(int(mB), B)
mp_gcd_into(ma, mb, None, None, None)
# No tests we can do after that last one - we just
# insist that it isn't allowed to have crashed!
def testMonty(self):
moduli = [5, 19, 2**16+1, 2**31-1, 2**128-159, 2**255-19,
293828847201107461142630006802421204703,

View File

@ -54,6 +54,9 @@ FUNC2(val_mpint, mp_mod, val_mpint, val_mpint)
FUNC2(void, mp_reduce_mod_2to, val_mpint, uint)
FUNC2(val_mpint, mp_invert_mod_2to, val_mpint, uint)
FUNC2(val_mpint, mp_invert, val_mpint, val_mpint)
FUNC5(void, mp_gcd_into, val_mpint, val_mpint, opt_val_mpint, opt_val_mpint, opt_val_mpint)
FUNC2(val_mpint, mp_gcd, val_mpint, val_mpint)
FUNC2(uint, mp_coprime, val_mpint, val_mpint)
FUNC2(val_modsqrt, modsqrt_new, val_mpint, val_mpint)
/* The modsqrt functions' 'success' pointer becomes a second return value */
FUNC3(val_mpint, mp_modsqrt, val_modsqrt, val_mpint, out_uint)