mirror of
https://git.tartarus.org/simon/putty.git
synced 2025-01-10 09:58:01 +00:00
mpint: add a gcd function.
This is another application of the existing mp_bezout_into, which needed a tweak or two to cope with the numbers not necessarily being coprime, plus a wrapper function to deal with shared factors of 2. It reindents the entire second half of mp_bezout_into, so the patch is best viewed with whitespace differences ignored.
This commit is contained in:
parent
957f14088f
commit
2debb352b0
168
mpint.c
168
mpint.c
@ -1556,10 +1556,10 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
|
|||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Given two coprime nonzero input integers a,b, returns two integers
|
* Given two input integers a,b which are not both even, computes d =
|
||||||
* A,B such that A*a - B*b = 1. A,B will be the minimal non-negative
|
* gcd(a,b) and also two integers A,B such that A*a - B*b = d. A,B
|
||||||
* pair satisfying that criterion, which is equivalent to saying that
|
* will be the minimal non-negative pair satisfying that criterion,
|
||||||
* 0<=A<b and 0<=B<a.
|
* which is equivalent to saying that 0 <= A < b/d and 0 <= B < a/d.
|
||||||
*
|
*
|
||||||
* This algorithm is an adapted form of Stein's algorithm, which
|
* This algorithm is an adapted form of Stein's algorithm, which
|
||||||
* computes gcd(a,b) using only addition and bit shifts (i.e. without
|
* computes gcd(a,b) using only addition and bit shifts (i.e. without
|
||||||
@ -1571,9 +1571,11 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
|
|||||||
* - if both of a,b are odd, then WLOG a>b, and gcd(a,b) =
|
* - if both of a,b are odd, then WLOG a>b, and gcd(a,b) =
|
||||||
* gcd(b,(a-b)/2).
|
* gcd(b,(a-b)/2).
|
||||||
*
|
*
|
||||||
* For this application, I always expect the actual gcd to be coprime,
|
* Sometimes this function is used for modular inversion, in which
|
||||||
* so we can rule out the 'both even' initial case. So this function
|
* case we already know we expect the two inputs to be coprime, so to
|
||||||
* just performs a sequence of reductions in the following form:
|
* save time the 'both even' initial case is assumed not to arise (or
|
||||||
|
* to have been handled already by the caller). So this function just
|
||||||
|
* performs a sequence of reductions in the following form:
|
||||||
*
|
*
|
||||||
* - if a,b are both odd, sort them so that a > b, and replace a with
|
* - if a,b are both odd, sort them so that a > b, and replace a with
|
||||||
* b-a; otherwise sort them so that a is the even one
|
* b-a; otherwise sort them so that a is the even one
|
||||||
@ -1584,14 +1586,14 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
|
|||||||
* generate those in each case, based on the coefficients from the
|
* generate those in each case, based on the coefficients from the
|
||||||
* reduced pair of numbers:
|
* reduced pair of numbers:
|
||||||
*
|
*
|
||||||
* - If a is even, and u,v are such that u*(a/2) + v*b = 1:
|
* - If a is even, and u,v are such that u*(a/2) + v*b = d:
|
||||||
* + if u is also even, then this is just (u/2)*a + v*b = 1
|
* + if u is also even, then this is just (u/2)*a + v*b = d
|
||||||
* + otherwise, (u+b)*(a/2) + (v-a/2)*b is also equal to 1, and
|
* + otherwise, (u+b)*(a/2) + (v-a/2)*b is also equal to d, and
|
||||||
* since u and b are both odd, (u+b)/2 is an integer, so we have
|
* since u and b are both odd, (u+b)/2 is an integer, so we have
|
||||||
* ((u+b)/2)*a + (v-a/2)*b = 1.
|
* ((u+b)/2)*a + (v-a/2)*b = d.
|
||||||
*
|
*
|
||||||
* - If a,b are both odd, and u,v are such that u*b + v*(a-b) = 1,
|
* - If a,b are both odd, and u,v are such that u*b + v*(a-b) = d,
|
||||||
* then v*a + (u-v)*b = 1.
|
* then v*a + (u-v)*b = d.
|
||||||
*
|
*
|
||||||
* In the case where we passed from (a,b) to (b,(a-b)/2), we regard it
|
* In the case where we passed from (a,b) to (b,(a-b)/2), we regard it
|
||||||
* as having first subtracted b from a and then halved a, so both of
|
* as having first subtracted b from a and then halved a, so both of
|
||||||
@ -1609,11 +1611,11 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
|
|||||||
* Also, since these mp_ints are generally treated as unsigned, we
|
* Also, since these mp_ints are generally treated as unsigned, we
|
||||||
* store the coefficients by absolute value, with the semantics that
|
* store the coefficients by absolute value, with the semantics that
|
||||||
* they always have opposite sign, and in the unwinding loop we keep a
|
* they always have opposite sign, and in the unwinding loop we keep a
|
||||||
* bit indicating whether Aa-Bb is currently expected to be +1 or -1,
|
* bit indicating whether Aa-Bb is currently expected to be +d or -d,
|
||||||
* so that we can do one final conditional adjustment if it's -1.
|
* so that we can do one final conditional adjustment if it's -d.
|
||||||
*
|
*
|
||||||
* Once the reduction rules have managed to reduce the input numbers
|
* Once the reduction rules have managed to reduce the input numbers
|
||||||
* to (0,1), then they are stable (the next reduction will always
|
* to (0,d), then they are stable (the next reduction will always
|
||||||
* divide the even one by 2, which maps 0 to 0). So it doesn't matter
|
* divide the even one by 2, which maps 0 to 0). So it doesn't matter
|
||||||
* if we do more steps of the algorithm than necessary; hence, for
|
* if we do more steps of the algorithm than necessary; hence, for
|
||||||
* constant time, we just need to find the maximum number we could
|
* constant time, we just need to find the maximum number we could
|
||||||
@ -1632,7 +1634,7 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
|
|||||||
* n further steps each of which subtracts 1 from y and halves it.
|
* n further steps each of which subtracts 1 from y and halves it.
|
||||||
*/
|
*/
|
||||||
static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
|
static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
|
||||||
mp_int *a_in, mp_int *b_in)
|
mp_int *gcd_out, mp_int *a_in, mp_int *b_in)
|
||||||
{
|
{
|
||||||
size_t nw = size_t_max(1, size_t_max(a_in->nw, b_in->nw));
|
size_t nw = size_t_max(1, size_t_max(a_in->nw, b_in->nw));
|
||||||
|
|
||||||
@ -1691,31 +1693,59 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
|
|||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Now we expect to have reduced the two numbers to 0 and 1,
|
* Now we expect to have reduced the two numbers to 0 and d,
|
||||||
* although we don't know which way round. (But we avoid checking
|
* although we don't know which way round. (But we avoid checking
|
||||||
* this by assertion; sometimes we'll need to do this computation
|
* this by assertion; sometimes we'll need to do this computation
|
||||||
* without giving away that we already know the inputs were bogus.
|
* without giving away that we already know the inputs were bogus.
|
||||||
* So we'd prefer to just press on and return nonsense.)
|
* So we'd prefer to just press on and return nonsense.)
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
if (gcd_out) {
|
||||||
/*
|
/*
|
||||||
* So their Bezout coefficients at this point are simply
|
* At this point we can return the actual gcd. Since one of
|
||||||
* themselves.
|
* a,b is it and the other is zero, the easiest way to get it
|
||||||
|
* is to add them together.
|
||||||
*/
|
*/
|
||||||
mp_copy_into(ac, a);
|
mp_add_into(gcd_out, a, b);
|
||||||
mp_copy_into(bc, b);
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* We'll maintain the invariant as we unwind that ac * a - bc * b
|
* If the caller _only_ wanted the gcd, and neither Bezout
|
||||||
* is either +1 or -1, and we'll remember which. (We _could_ keep
|
* coefficient is even required, we can skip the entire unwind
|
||||||
* it at +1 the whole time, but it would cost more work every time
|
* stage.
|
||||||
* round the loop, so it's cheaper to fix that up once at the
|
|
||||||
* end.)
|
|
||||||
*
|
|
||||||
* Initially, the result is +1 if a was the nonzero value after
|
|
||||||
* reduction, and -1 if b was.
|
|
||||||
*/
|
*/
|
||||||
unsigned minus_one = b->w[0];
|
if (a_coeff_out || b_coeff_out) {
|
||||||
|
|
||||||
|
/*
|
||||||
|
* The Bezout coefficients of a,b at this point are simply 0
|
||||||
|
* for whichever of a,b is zero, and 1 for whichever is
|
||||||
|
* nonzero. The nonzero number equals gcd(a,b), which by
|
||||||
|
* assumption is odd, so we can do this by just taking the low
|
||||||
|
* bit of each one.
|
||||||
|
*/
|
||||||
|
ac->w[0] = mp_get_bit(a, 0);
|
||||||
|
bc->w[0] = mp_get_bit(b, 0);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Overwrite a,b themselves with those same numbers. This has
|
||||||
|
* the effect of dividing both of them by d, which will
|
||||||
|
* arrange that during the unwind stage we generate the
|
||||||
|
* minimal coefficients instead of a larger pair.
|
||||||
|
*/
|
||||||
|
mp_copy_into(a, ac);
|
||||||
|
mp_copy_into(b, bc);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* We'll maintain the invariant as we unwind that ac * a - bc
|
||||||
|
* * b is either +d or -d (or rather, +1/-1 after scaling by
|
||||||
|
* d), and we'll remember which. (We _could_ keep it at +d the
|
||||||
|
* whole time, but it would cost more work every time round
|
||||||
|
* the loop, so it's cheaper to fix that up once at the end.)
|
||||||
|
*
|
||||||
|
* Initially, the result is +d if a was the nonzero value after
|
||||||
|
* reduction, and -d if b was.
|
||||||
|
*/
|
||||||
|
unsigned minus_d = b->w[0];
|
||||||
|
|
||||||
for (size_t step = steps; step-- > 0 ;) {
|
for (size_t step = steps; step-- > 0 ;) {
|
||||||
/*
|
/*
|
||||||
@ -1757,25 +1787,22 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
|
|||||||
*/
|
*/
|
||||||
mp_cond_swap(a, b, swap);
|
mp_cond_swap(a, b, swap);
|
||||||
mp_cond_swap(ac, bc, swap);
|
mp_cond_swap(ac, bc, swap);
|
||||||
minus_one ^= swap;
|
minus_d ^= swap;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Now we expect to have recovered the input a,b.
|
* Now we expect to have recovered the input a,b (or rather,
|
||||||
*/
|
* the versions of them divided by d). But we might find that
|
||||||
assert(mp_cmp_eq(a, a_in) & mp_cmp_eq(b, b_in));
|
* our current result is -d instead of +d, that is, we have
|
||||||
|
* A',B' such that A'a - B'b = -d.
|
||||||
/*
|
|
||||||
* But we might find that our current result is -1 instead of +1,
|
|
||||||
* that is, we have A',B' such that A'a - B'b = -1.
|
|
||||||
*
|
*
|
||||||
* In that situation, we set A = b-A' and B = a-B', giving us
|
* In that situation, we set A = b-A' and B = a-B', giving us
|
||||||
* Aa-Bb = ab - A'a - ab + B'b = +1.
|
* Aa-Bb = ab - A'a - ab + B'b = +1.
|
||||||
*/
|
*/
|
||||||
mp_sub_into(tmp, b, ac);
|
mp_sub_into(tmp, b, ac);
|
||||||
mp_select_into(ac, ac, tmp, minus_one);
|
mp_select_into(ac, ac, tmp, minus_d);
|
||||||
mp_sub_into(tmp, a, bc);
|
mp_sub_into(tmp, a, bc);
|
||||||
mp_select_into(bc, bc, tmp, minus_one);
|
mp_select_into(bc, bc, tmp, minus_d);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Now we really are done. Return the outputs.
|
* Now we really are done. Return the outputs.
|
||||||
@ -1785,6 +1812,8 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
|
|||||||
if (b_coeff_out)
|
if (b_coeff_out)
|
||||||
mp_copy_into(b_coeff_out, bc);
|
mp_copy_into(b_coeff_out, bc);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
mp_free(a);
|
mp_free(a);
|
||||||
mp_free(b);
|
mp_free(b);
|
||||||
mp_free(ac);
|
mp_free(ac);
|
||||||
@ -1796,10 +1825,65 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
|
|||||||
mp_int *mp_invert(mp_int *x, mp_int *m)
|
mp_int *mp_invert(mp_int *x, mp_int *m)
|
||||||
{
|
{
|
||||||
mp_int *result = mp_make_sized(m->nw);
|
mp_int *result = mp_make_sized(m->nw);
|
||||||
mp_bezout_into(result, NULL, x, m);
|
mp_bezout_into(result, NULL, NULL, x, m);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void mp_gcd_into(mp_int *a, mp_int *b, mp_int *gcd, mp_int *A, mp_int *B)
|
||||||
|
{
|
||||||
|
/*
|
||||||
|
* Identify shared factors of 2. To do this we OR the two numbers
|
||||||
|
* to get something whose lowest set bit is in the right place,
|
||||||
|
* remove all higher bits by ANDing it with its own negation, and
|
||||||
|
* use mp_get_nbits to find the location of the single remaining
|
||||||
|
* set bit.
|
||||||
|
*/
|
||||||
|
mp_int *tmp = mp_make_sized(size_t_max(a->nw, b->nw));
|
||||||
|
for (size_t i = 0; i < tmp->nw; i++)
|
||||||
|
tmp->w[i] = mp_word(a, i) | mp_word(b, i);
|
||||||
|
BignumCarry carry = 1;
|
||||||
|
for (size_t i = 0; i < tmp->nw; i++) {
|
||||||
|
BignumInt negw;
|
||||||
|
BignumADC(negw, carry, 0, ~tmp->w[i], carry);
|
||||||
|
tmp->w[i] &= negw;
|
||||||
|
}
|
||||||
|
size_t shift = mp_get_nbits(tmp) - 1;
|
||||||
|
mp_free(tmp);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Make copies of a,b with those shared factors of 2 divided off,
|
||||||
|
* so that at least one is odd (which is the precondition for
|
||||||
|
* mp_bezout_into). Compute the gcd of those.
|
||||||
|
*/
|
||||||
|
mp_int *as = mp_rshift_safe(a, shift);
|
||||||
|
mp_int *bs = mp_rshift_safe(b, shift);
|
||||||
|
mp_bezout_into(A, B, gcd, as, bs);
|
||||||
|
mp_free(as);
|
||||||
|
mp_free(bs);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* And finally shift the gcd back up (unless the caller didn't
|
||||||
|
* even ask for it), to put the shared factors of 2 back in.
|
||||||
|
*/
|
||||||
|
if (gcd)
|
||||||
|
mp_lshift_safe_in_place(gcd, shift);
|
||||||
|
}
|
||||||
|
|
||||||
|
mp_int *mp_gcd(mp_int *a, mp_int *b)
|
||||||
|
{
|
||||||
|
mp_int *gcd = mp_make_sized(size_t_min(a->nw, b->nw));
|
||||||
|
mp_gcd_into(a, b, gcd, NULL, NULL);
|
||||||
|
return gcd;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned mp_coprime(mp_int *a, mp_int *b)
|
||||||
|
{
|
||||||
|
mp_int *gcd = mp_gcd(a, b);
|
||||||
|
unsigned toret = mp_eq_integer(gcd, 1);
|
||||||
|
mp_free(gcd);
|
||||||
|
return toret;
|
||||||
|
}
|
||||||
|
|
||||||
static uint32_t recip_approx_32(uint32_t x)
|
static uint32_t recip_approx_32(uint32_t x)
|
||||||
{
|
{
|
||||||
/*
|
/*
|
||||||
|
19
mpint.h
19
mpint.h
@ -270,6 +270,25 @@ void mp_reduce_mod_2to(mp_int *x, size_t p);
|
|||||||
mp_int *mp_invert_mod_2to(mp_int *x, size_t p);
|
mp_int *mp_invert_mod_2to(mp_int *x, size_t p);
|
||||||
mp_int *mp_invert(mp_int *x, mp_int *modulus);
|
mp_int *mp_invert(mp_int *x, mp_int *modulus);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Greatest common divisor.
|
||||||
|
*
|
||||||
|
* mp_gcd_into also returns a pair of Bezout coefficients, namely A,B
|
||||||
|
* such that a*A - b*B = gcd. (The minus sign is so that both returned
|
||||||
|
* coefficients can be positive.)
|
||||||
|
*
|
||||||
|
* You can pass any of mp_gcd_into's output pointers as NULL if you
|
||||||
|
* don't need that output value.
|
||||||
|
*
|
||||||
|
* mp_gcd is a wrapper with a less cumbersome API, for the case where
|
||||||
|
* the only output value you need is the gcd itself. mp_coprime is
|
||||||
|
* even easier, if all you care about is whether or not that gcd is 1.
|
||||||
|
*/
|
||||||
|
mp_int *mp_gcd(mp_int *a, mp_int *b);
|
||||||
|
void mp_gcd_into(mp_int *a, mp_int *b,
|
||||||
|
mp_int *gcd_out, mp_int *A_out, mp_int *B_out);
|
||||||
|
unsigned mp_coprime(mp_int *a, mp_int *b);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* System for taking square roots modulo an odd prime.
|
* System for taking square roots modulo an odd prime.
|
||||||
*
|
*
|
||||||
|
@ -437,6 +437,48 @@ class mpint(MyTestBase):
|
|||||||
int(monty_invert(mc, monty_import(mc, x))),
|
int(monty_invert(mc, monty_import(mc, x))),
|
||||||
int(monty_import(mc, inv)))
|
int(monty_import(mc, inv)))
|
||||||
|
|
||||||
|
def testGCD(self):
|
||||||
|
powerpairs = [(0,0), (1,0), (1,1), (2,1), (2,2), (75,3), (17,23)]
|
||||||
|
for a2, b2 in powerpairs:
|
||||||
|
for a3, b3 in powerpairs:
|
||||||
|
for a5, b5 in powerpairs:
|
||||||
|
a = 2**a2 * 3**a3 * 5**a5 * 17 * 19 * 23
|
||||||
|
b = 2**b2 * 3**b3 * 5**b5 * 65423
|
||||||
|
d = 2**min(a2, b2) * 3**min(a3, b3) * 5**min(a5, b5)
|
||||||
|
|
||||||
|
ma = mp_copy(a)
|
||||||
|
mb = mp_copy(b)
|
||||||
|
|
||||||
|
self.assertEqual(int(mp_gcd(ma, mb)), d)
|
||||||
|
|
||||||
|
md = mp_new(nbits(d))
|
||||||
|
mA = mp_new(nbits(b))
|
||||||
|
mB = mp_new(nbits(a))
|
||||||
|
mp_gcd_into(ma, mb, md, mA, mB)
|
||||||
|
self.assertEqual(int(md), d)
|
||||||
|
A = int(mA)
|
||||||
|
B = int(mB)
|
||||||
|
self.assertEqual(a*A - b*B, d)
|
||||||
|
self.assertTrue(0 <= A < b//d)
|
||||||
|
self.assertTrue(0 <= B < a//d)
|
||||||
|
|
||||||
|
self.assertEqual(mp_coprime(ma, mb), 1 if d==1 else 0)
|
||||||
|
|
||||||
|
# Make sure gcd_into can handle not getting some
|
||||||
|
# of its output pointers.
|
||||||
|
mp_clear(md)
|
||||||
|
mp_gcd_into(ma, mb, md, None, None)
|
||||||
|
self.assertEqual(int(md), d)
|
||||||
|
mp_clear(mA)
|
||||||
|
mp_gcd_into(ma, mb, None, mA, None)
|
||||||
|
self.assertEqual(int(mA), A)
|
||||||
|
mp_clear(mB)
|
||||||
|
mp_gcd_into(ma, mb, None, None, mB)
|
||||||
|
self.assertEqual(int(mB), B)
|
||||||
|
mp_gcd_into(ma, mb, None, None, None)
|
||||||
|
# No tests we can do after that last one - we just
|
||||||
|
# insist that it isn't allowed to have crashed!
|
||||||
|
|
||||||
def testMonty(self):
|
def testMonty(self):
|
||||||
moduli = [5, 19, 2**16+1, 2**31-1, 2**128-159, 2**255-19,
|
moduli = [5, 19, 2**16+1, 2**31-1, 2**128-159, 2**255-19,
|
||||||
293828847201107461142630006802421204703,
|
293828847201107461142630006802421204703,
|
||||||
|
@ -54,6 +54,9 @@ FUNC2(val_mpint, mp_mod, val_mpint, val_mpint)
|
|||||||
FUNC2(void, mp_reduce_mod_2to, val_mpint, uint)
|
FUNC2(void, mp_reduce_mod_2to, val_mpint, uint)
|
||||||
FUNC2(val_mpint, mp_invert_mod_2to, val_mpint, uint)
|
FUNC2(val_mpint, mp_invert_mod_2to, val_mpint, uint)
|
||||||
FUNC2(val_mpint, mp_invert, val_mpint, val_mpint)
|
FUNC2(val_mpint, mp_invert, val_mpint, val_mpint)
|
||||||
|
FUNC5(void, mp_gcd_into, val_mpint, val_mpint, opt_val_mpint, opt_val_mpint, opt_val_mpint)
|
||||||
|
FUNC2(val_mpint, mp_gcd, val_mpint, val_mpint)
|
||||||
|
FUNC2(uint, mp_coprime, val_mpint, val_mpint)
|
||||||
FUNC2(val_modsqrt, modsqrt_new, val_mpint, val_mpint)
|
FUNC2(val_modsqrt, modsqrt_new, val_mpint, val_mpint)
|
||||||
/* The modsqrt functions' 'success' pointer becomes a second return value */
|
/* The modsqrt functions' 'success' pointer becomes a second return value */
|
||||||
FUNC3(val_mpint, mp_modsqrt, val_modsqrt, val_mpint, out_uint)
|
FUNC3(val_mpint, mp_modsqrt, val_modsqrt, val_mpint, out_uint)
|
||||||
|
Loading…
Reference in New Issue
Block a user