mirror of
https://git.tartarus.org/simon/putty.git
synced 2025-01-10 01:48:00 +00:00
mpint: add a gcd function.
This is another application of the existing mp_bezout_into, which needed a tweak or two to cope with the numbers not necessarily being coprime, plus a wrapper function to deal with shared factors of 2. It reindents the entire second half of mp_bezout_into, so the patch is best viewed with whitespace differences ignored.
This commit is contained in:
parent
957f14088f
commit
2debb352b0
168
mpint.c
168
mpint.c
@ -1556,10 +1556,10 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
|
||||
}
|
||||
|
||||
/*
|
||||
* Given two coprime nonzero input integers a,b, returns two integers
|
||||
* A,B such that A*a - B*b = 1. A,B will be the minimal non-negative
|
||||
* pair satisfying that criterion, which is equivalent to saying that
|
||||
* 0<=A<b and 0<=B<a.
|
||||
* Given two input integers a,b which are not both even, computes d =
|
||||
* gcd(a,b) and also two integers A,B such that A*a - B*b = d. A,B
|
||||
* will be the minimal non-negative pair satisfying that criterion,
|
||||
* which is equivalent to saying that 0 <= A < b/d and 0 <= B < a/d.
|
||||
*
|
||||
* This algorithm is an adapted form of Stein's algorithm, which
|
||||
* computes gcd(a,b) using only addition and bit shifts (i.e. without
|
||||
@ -1571,9 +1571,11 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
|
||||
* - if both of a,b are odd, then WLOG a>b, and gcd(a,b) =
|
||||
* gcd(b,(a-b)/2).
|
||||
*
|
||||
* For this application, I always expect the actual gcd to be coprime,
|
||||
* so we can rule out the 'both even' initial case. So this function
|
||||
* just performs a sequence of reductions in the following form:
|
||||
* Sometimes this function is used for modular inversion, in which
|
||||
* case we already know we expect the two inputs to be coprime, so to
|
||||
* save time the 'both even' initial case is assumed not to arise (or
|
||||
* to have been handled already by the caller). So this function just
|
||||
* performs a sequence of reductions in the following form:
|
||||
*
|
||||
* - if a,b are both odd, sort them so that a > b, and replace a with
|
||||
* b-a; otherwise sort them so that a is the even one
|
||||
@ -1584,14 +1586,14 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
|
||||
* generate those in each case, based on the coefficients from the
|
||||
* reduced pair of numbers:
|
||||
*
|
||||
* - If a is even, and u,v are such that u*(a/2) + v*b = 1:
|
||||
* + if u is also even, then this is just (u/2)*a + v*b = 1
|
||||
* + otherwise, (u+b)*(a/2) + (v-a/2)*b is also equal to 1, and
|
||||
* - If a is even, and u,v are such that u*(a/2) + v*b = d:
|
||||
* + if u is also even, then this is just (u/2)*a + v*b = d
|
||||
* + otherwise, (u+b)*(a/2) + (v-a/2)*b is also equal to d, and
|
||||
* since u and b are both odd, (u+b)/2 is an integer, so we have
|
||||
* ((u+b)/2)*a + (v-a/2)*b = 1.
|
||||
* ((u+b)/2)*a + (v-a/2)*b = d.
|
||||
*
|
||||
* - If a,b are both odd, and u,v are such that u*b + v*(a-b) = 1,
|
||||
* then v*a + (u-v)*b = 1.
|
||||
* - If a,b are both odd, and u,v are such that u*b + v*(a-b) = d,
|
||||
* then v*a + (u-v)*b = d.
|
||||
*
|
||||
* In the case where we passed from (a,b) to (b,(a-b)/2), we regard it
|
||||
* as having first subtracted b from a and then halved a, so both of
|
||||
@ -1609,11 +1611,11 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
|
||||
* Also, since these mp_ints are generally treated as unsigned, we
|
||||
* store the coefficients by absolute value, with the semantics that
|
||||
* they always have opposite sign, and in the unwinding loop we keep a
|
||||
* bit indicating whether Aa-Bb is currently expected to be +1 or -1,
|
||||
* so that we can do one final conditional adjustment if it's -1.
|
||||
* bit indicating whether Aa-Bb is currently expected to be +d or -d,
|
||||
* so that we can do one final conditional adjustment if it's -d.
|
||||
*
|
||||
* Once the reduction rules have managed to reduce the input numbers
|
||||
* to (0,1), then they are stable (the next reduction will always
|
||||
* to (0,d), then they are stable (the next reduction will always
|
||||
* divide the even one by 2, which maps 0 to 0). So it doesn't matter
|
||||
* if we do more steps of the algorithm than necessary; hence, for
|
||||
* constant time, we just need to find the maximum number we could
|
||||
@ -1632,7 +1634,7 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
|
||||
* n further steps each of which subtracts 1 from y and halves it.
|
||||
*/
|
||||
static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
|
||||
mp_int *a_in, mp_int *b_in)
|
||||
mp_int *gcd_out, mp_int *a_in, mp_int *b_in)
|
||||
{
|
||||
size_t nw = size_t_max(1, size_t_max(a_in->nw, b_in->nw));
|
||||
|
||||
@ -1691,31 +1693,59 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
|
||||
}
|
||||
|
||||
/*
|
||||
* Now we expect to have reduced the two numbers to 0 and 1,
|
||||
* Now we expect to have reduced the two numbers to 0 and d,
|
||||
* although we don't know which way round. (But we avoid checking
|
||||
* this by assertion; sometimes we'll need to do this computation
|
||||
* without giving away that we already know the inputs were bogus.
|
||||
* So we'd prefer to just press on and return nonsense.)
|
||||
*/
|
||||
|
||||
if (gcd_out) {
|
||||
/*
|
||||
* So their Bezout coefficients at this point are simply
|
||||
* themselves.
|
||||
* At this point we can return the actual gcd. Since one of
|
||||
* a,b is it and the other is zero, the easiest way to get it
|
||||
* is to add them together.
|
||||
*/
|
||||
mp_copy_into(ac, a);
|
||||
mp_copy_into(bc, b);
|
||||
mp_add_into(gcd_out, a, b);
|
||||
}
|
||||
|
||||
/*
|
||||
* We'll maintain the invariant as we unwind that ac * a - bc * b
|
||||
* is either +1 or -1, and we'll remember which. (We _could_ keep
|
||||
* it at +1 the whole time, but it would cost more work every time
|
||||
* round the loop, so it's cheaper to fix that up once at the
|
||||
* end.)
|
||||
*
|
||||
* Initially, the result is +1 if a was the nonzero value after
|
||||
* reduction, and -1 if b was.
|
||||
* If the caller _only_ wanted the gcd, and neither Bezout
|
||||
* coefficient is even required, we can skip the entire unwind
|
||||
* stage.
|
||||
*/
|
||||
unsigned minus_one = b->w[0];
|
||||
if (a_coeff_out || b_coeff_out) {
|
||||
|
||||
/*
|
||||
* The Bezout coefficients of a,b at this point are simply 0
|
||||
* for whichever of a,b is zero, and 1 for whichever is
|
||||
* nonzero. The nonzero number equals gcd(a,b), which by
|
||||
* assumption is odd, so we can do this by just taking the low
|
||||
* bit of each one.
|
||||
*/
|
||||
ac->w[0] = mp_get_bit(a, 0);
|
||||
bc->w[0] = mp_get_bit(b, 0);
|
||||
|
||||
/*
|
||||
* Overwrite a,b themselves with those same numbers. This has
|
||||
* the effect of dividing both of them by d, which will
|
||||
* arrange that during the unwind stage we generate the
|
||||
* minimal coefficients instead of a larger pair.
|
||||
*/
|
||||
mp_copy_into(a, ac);
|
||||
mp_copy_into(b, bc);
|
||||
|
||||
/*
|
||||
* We'll maintain the invariant as we unwind that ac * a - bc
|
||||
* * b is either +d or -d (or rather, +1/-1 after scaling by
|
||||
* d), and we'll remember which. (We _could_ keep it at +d the
|
||||
* whole time, but it would cost more work every time round
|
||||
* the loop, so it's cheaper to fix that up once at the end.)
|
||||
*
|
||||
* Initially, the result is +d if a was the nonzero value after
|
||||
* reduction, and -d if b was.
|
||||
*/
|
||||
unsigned minus_d = b->w[0];
|
||||
|
||||
for (size_t step = steps; step-- > 0 ;) {
|
||||
/*
|
||||
@ -1757,25 +1787,22 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
|
||||
*/
|
||||
mp_cond_swap(a, b, swap);
|
||||
mp_cond_swap(ac, bc, swap);
|
||||
minus_one ^= swap;
|
||||
minus_d ^= swap;
|
||||
}
|
||||
|
||||
/*
|
||||
* Now we expect to have recovered the input a,b.
|
||||
*/
|
||||
assert(mp_cmp_eq(a, a_in) & mp_cmp_eq(b, b_in));
|
||||
|
||||
/*
|
||||
* But we might find that our current result is -1 instead of +1,
|
||||
* that is, we have A',B' such that A'a - B'b = -1.
|
||||
* Now we expect to have recovered the input a,b (or rather,
|
||||
* the versions of them divided by d). But we might find that
|
||||
* our current result is -d instead of +d, that is, we have
|
||||
* A',B' such that A'a - B'b = -d.
|
||||
*
|
||||
* In that situation, we set A = b-A' and B = a-B', giving us
|
||||
* Aa-Bb = ab - A'a - ab + B'b = +1.
|
||||
*/
|
||||
mp_sub_into(tmp, b, ac);
|
||||
mp_select_into(ac, ac, tmp, minus_one);
|
||||
mp_select_into(ac, ac, tmp, minus_d);
|
||||
mp_sub_into(tmp, a, bc);
|
||||
mp_select_into(bc, bc, tmp, minus_one);
|
||||
mp_select_into(bc, bc, tmp, minus_d);
|
||||
|
||||
/*
|
||||
* Now we really are done. Return the outputs.
|
||||
@ -1785,6 +1812,8 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
|
||||
if (b_coeff_out)
|
||||
mp_copy_into(b_coeff_out, bc);
|
||||
|
||||
}
|
||||
|
||||
mp_free(a);
|
||||
mp_free(b);
|
||||
mp_free(ac);
|
||||
@ -1796,10 +1825,65 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
|
||||
mp_int *mp_invert(mp_int *x, mp_int *m)
|
||||
{
|
||||
mp_int *result = mp_make_sized(m->nw);
|
||||
mp_bezout_into(result, NULL, x, m);
|
||||
mp_bezout_into(result, NULL, NULL, x, m);
|
||||
return result;
|
||||
}
|
||||
|
||||
void mp_gcd_into(mp_int *a, mp_int *b, mp_int *gcd, mp_int *A, mp_int *B)
|
||||
{
|
||||
/*
|
||||
* Identify shared factors of 2. To do this we OR the two numbers
|
||||
* to get something whose lowest set bit is in the right place,
|
||||
* remove all higher bits by ANDing it with its own negation, and
|
||||
* use mp_get_nbits to find the location of the single remaining
|
||||
* set bit.
|
||||
*/
|
||||
mp_int *tmp = mp_make_sized(size_t_max(a->nw, b->nw));
|
||||
for (size_t i = 0; i < tmp->nw; i++)
|
||||
tmp->w[i] = mp_word(a, i) | mp_word(b, i);
|
||||
BignumCarry carry = 1;
|
||||
for (size_t i = 0; i < tmp->nw; i++) {
|
||||
BignumInt negw;
|
||||
BignumADC(negw, carry, 0, ~tmp->w[i], carry);
|
||||
tmp->w[i] &= negw;
|
||||
}
|
||||
size_t shift = mp_get_nbits(tmp) - 1;
|
||||
mp_free(tmp);
|
||||
|
||||
/*
|
||||
* Make copies of a,b with those shared factors of 2 divided off,
|
||||
* so that at least one is odd (which is the precondition for
|
||||
* mp_bezout_into). Compute the gcd of those.
|
||||
*/
|
||||
mp_int *as = mp_rshift_safe(a, shift);
|
||||
mp_int *bs = mp_rshift_safe(b, shift);
|
||||
mp_bezout_into(A, B, gcd, as, bs);
|
||||
mp_free(as);
|
||||
mp_free(bs);
|
||||
|
||||
/*
|
||||
* And finally shift the gcd back up (unless the caller didn't
|
||||
* even ask for it), to put the shared factors of 2 back in.
|
||||
*/
|
||||
if (gcd)
|
||||
mp_lshift_safe_in_place(gcd, shift);
|
||||
}
|
||||
|
||||
mp_int *mp_gcd(mp_int *a, mp_int *b)
|
||||
{
|
||||
mp_int *gcd = mp_make_sized(size_t_min(a->nw, b->nw));
|
||||
mp_gcd_into(a, b, gcd, NULL, NULL);
|
||||
return gcd;
|
||||
}
|
||||
|
||||
unsigned mp_coprime(mp_int *a, mp_int *b)
|
||||
{
|
||||
mp_int *gcd = mp_gcd(a, b);
|
||||
unsigned toret = mp_eq_integer(gcd, 1);
|
||||
mp_free(gcd);
|
||||
return toret;
|
||||
}
|
||||
|
||||
static uint32_t recip_approx_32(uint32_t x)
|
||||
{
|
||||
/*
|
||||
|
19
mpint.h
19
mpint.h
@ -270,6 +270,25 @@ void mp_reduce_mod_2to(mp_int *x, size_t p);
|
||||
mp_int *mp_invert_mod_2to(mp_int *x, size_t p);
|
||||
mp_int *mp_invert(mp_int *x, mp_int *modulus);
|
||||
|
||||
/*
|
||||
* Greatest common divisor.
|
||||
*
|
||||
* mp_gcd_into also returns a pair of Bezout coefficients, namely A,B
|
||||
* such that a*A - b*B = gcd. (The minus sign is so that both returned
|
||||
* coefficients can be positive.)
|
||||
*
|
||||
* You can pass any of mp_gcd_into's output pointers as NULL if you
|
||||
* don't need that output value.
|
||||
*
|
||||
* mp_gcd is a wrapper with a less cumbersome API, for the case where
|
||||
* the only output value you need is the gcd itself. mp_coprime is
|
||||
* even easier, if all you care about is whether or not that gcd is 1.
|
||||
*/
|
||||
mp_int *mp_gcd(mp_int *a, mp_int *b);
|
||||
void mp_gcd_into(mp_int *a, mp_int *b,
|
||||
mp_int *gcd_out, mp_int *A_out, mp_int *B_out);
|
||||
unsigned mp_coprime(mp_int *a, mp_int *b);
|
||||
|
||||
/*
|
||||
* System for taking square roots modulo an odd prime.
|
||||
*
|
||||
|
@ -437,6 +437,48 @@ class mpint(MyTestBase):
|
||||
int(monty_invert(mc, monty_import(mc, x))),
|
||||
int(monty_import(mc, inv)))
|
||||
|
||||
def testGCD(self):
|
||||
powerpairs = [(0,0), (1,0), (1,1), (2,1), (2,2), (75,3), (17,23)]
|
||||
for a2, b2 in powerpairs:
|
||||
for a3, b3 in powerpairs:
|
||||
for a5, b5 in powerpairs:
|
||||
a = 2**a2 * 3**a3 * 5**a5 * 17 * 19 * 23
|
||||
b = 2**b2 * 3**b3 * 5**b5 * 65423
|
||||
d = 2**min(a2, b2) * 3**min(a3, b3) * 5**min(a5, b5)
|
||||
|
||||
ma = mp_copy(a)
|
||||
mb = mp_copy(b)
|
||||
|
||||
self.assertEqual(int(mp_gcd(ma, mb)), d)
|
||||
|
||||
md = mp_new(nbits(d))
|
||||
mA = mp_new(nbits(b))
|
||||
mB = mp_new(nbits(a))
|
||||
mp_gcd_into(ma, mb, md, mA, mB)
|
||||
self.assertEqual(int(md), d)
|
||||
A = int(mA)
|
||||
B = int(mB)
|
||||
self.assertEqual(a*A - b*B, d)
|
||||
self.assertTrue(0 <= A < b//d)
|
||||
self.assertTrue(0 <= B < a//d)
|
||||
|
||||
self.assertEqual(mp_coprime(ma, mb), 1 if d==1 else 0)
|
||||
|
||||
# Make sure gcd_into can handle not getting some
|
||||
# of its output pointers.
|
||||
mp_clear(md)
|
||||
mp_gcd_into(ma, mb, md, None, None)
|
||||
self.assertEqual(int(md), d)
|
||||
mp_clear(mA)
|
||||
mp_gcd_into(ma, mb, None, mA, None)
|
||||
self.assertEqual(int(mA), A)
|
||||
mp_clear(mB)
|
||||
mp_gcd_into(ma, mb, None, None, mB)
|
||||
self.assertEqual(int(mB), B)
|
||||
mp_gcd_into(ma, mb, None, None, None)
|
||||
# No tests we can do after that last one - we just
|
||||
# insist that it isn't allowed to have crashed!
|
||||
|
||||
def testMonty(self):
|
||||
moduli = [5, 19, 2**16+1, 2**31-1, 2**128-159, 2**255-19,
|
||||
293828847201107461142630006802421204703,
|
||||
|
@ -54,6 +54,9 @@ FUNC2(val_mpint, mp_mod, val_mpint, val_mpint)
|
||||
FUNC2(void, mp_reduce_mod_2to, val_mpint, uint)
|
||||
FUNC2(val_mpint, mp_invert_mod_2to, val_mpint, uint)
|
||||
FUNC2(val_mpint, mp_invert, val_mpint, val_mpint)
|
||||
FUNC5(void, mp_gcd_into, val_mpint, val_mpint, opt_val_mpint, opt_val_mpint, opt_val_mpint)
|
||||
FUNC2(val_mpint, mp_gcd, val_mpint, val_mpint)
|
||||
FUNC2(uint, mp_coprime, val_mpint, val_mpint)
|
||||
FUNC2(val_modsqrt, modsqrt_new, val_mpint, val_mpint)
|
||||
/* The modsqrt functions' 'success' pointer becomes a second return value */
|
||||
FUNC3(val_mpint, mp_modsqrt, val_modsqrt, val_mpint, out_uint)
|
||||
|
Loading…
Reference in New Issue
Block a user