diff --git a/mpint.c b/mpint.c index 9541535e..7a70d400 100644 --- a/mpint.c +++ b/mpint.c @@ -1556,10 +1556,10 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus) } /* - * Given two coprime nonzero input integers a,b, returns two integers - * A,B such that A*a - B*b = 1. A,B will be the minimal non-negative - * pair satisfying that criterion, which is equivalent to saying that - * 0<=Ab, and gcd(a,b) = * gcd(b,(a-b)/2). * - * For this application, I always expect the actual gcd to be coprime, - * so we can rule out the 'both even' initial case. So this function - * just performs a sequence of reductions in the following form: + * Sometimes this function is used for modular inversion, in which + * case we already know we expect the two inputs to be coprime, so to + * save time the 'both even' initial case is assumed not to arise (or + * to have been handled already by the caller). So this function just + * performs a sequence of reductions in the following form: * * - if a,b are both odd, sort them so that a > b, and replace a with * b-a; otherwise sort them so that a is the even one @@ -1584,14 +1586,14 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus) * generate those in each case, based on the coefficients from the * reduced pair of numbers: * - * - If a is even, and u,v are such that u*(a/2) + v*b = 1: - * + if u is also even, then this is just (u/2)*a + v*b = 1 - * + otherwise, (u+b)*(a/2) + (v-a/2)*b is also equal to 1, and + * - If a is even, and u,v are such that u*(a/2) + v*b = d: + * + if u is also even, then this is just (u/2)*a + v*b = d + * + otherwise, (u+b)*(a/2) + (v-a/2)*b is also equal to d, and * since u and b are both odd, (u+b)/2 is an integer, so we have - * ((u+b)/2)*a + (v-a/2)*b = 1. + * ((u+b)/2)*a + (v-a/2)*b = d. * - * - If a,b are both odd, and u,v are such that u*b + v*(a-b) = 1, - * then v*a + (u-v)*b = 1. + * - If a,b are both odd, and u,v are such that u*b + v*(a-b) = d, + * then v*a + (u-v)*b = d. * * In the case where we passed from (a,b) to (b,(a-b)/2), we regard it * as having first subtracted b from a and then halved a, so both of @@ -1609,11 +1611,11 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus) * Also, since these mp_ints are generally treated as unsigned, we * store the coefficients by absolute value, with the semantics that * they always have opposite sign, and in the unwinding loop we keep a - * bit indicating whether Aa-Bb is currently expected to be +1 or -1, - * so that we can do one final conditional adjustment if it's -1. + * bit indicating whether Aa-Bb is currently expected to be +d or -d, + * so that we can do one final conditional adjustment if it's -d. * * Once the reduction rules have managed to reduce the input numbers - * to (0,1), then they are stable (the next reduction will always + * to (0,d), then they are stable (the next reduction will always * divide the even one by 2, which maps 0 to 0). So it doesn't matter * if we do more steps of the algorithm than necessary; hence, for * constant time, we just need to find the maximum number we could @@ -1632,7 +1634,7 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus) * n further steps each of which subtracts 1 from y and halves it. */ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out, - mp_int *a_in, mp_int *b_in) + mp_int *gcd_out, mp_int *a_in, mp_int *b_in) { size_t nw = size_t_max(1, size_t_max(a_in->nw, b_in->nw)); @@ -1691,99 +1693,126 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out, } /* - * Now we expect to have reduced the two numbers to 0 and 1, + * Now we expect to have reduced the two numbers to 0 and d, * although we don't know which way round. (But we avoid checking * this by assertion; sometimes we'll need to do this computation * without giving away that we already know the inputs were bogus. * So we'd prefer to just press on and return nonsense.) */ - /* - * So their Bezout coefficients at this point are simply - * themselves. - */ - mp_copy_into(ac, a); - mp_copy_into(bc, b); - - /* - * We'll maintain the invariant as we unwind that ac * a - bc * b - * is either +1 or -1, and we'll remember which. (We _could_ keep - * it at +1 the whole time, but it would cost more work every time - * round the loop, so it's cheaper to fix that up once at the - * end.) - * - * Initially, the result is +1 if a was the nonzero value after - * reduction, and -1 if b was. - */ - unsigned minus_one = b->w[0]; - - for (size_t step = steps; step-- > 0 ;) { + if (gcd_out) { /* - * Recover the data from the step we're unwinding. + * At this point we can return the actual gcd. Since one of + * a,b is it and the other is zero, the easiest way to get it + * is to add them together. */ - unsigned both_odd = mp_get_bit(record, step*2); - unsigned swap = mp_get_bit(record, step*2+1); - - /* - * Unwind the division: if our coefficient of a is odd, we - * adjust the coefficients by +b and +a respectively. - */ - unsigned adjust = ac->w[0] & 1; - mp_cond_add_into(ac, ac, b, adjust); - mp_cond_add_into(bc, bc, a, adjust); - - /* - * Now ac is definitely even, so we divide it by two. - */ - mp_rshift_fixed_into(ac, ac, 1); - - /* - * Now unwind the subtraction, if there was one, by adding - * ac to bc. - */ - mp_cond_add_into(bc, bc, ac, both_odd); - - /* - * Undo the transformation of the input numbers, by - * multiplying a by 2 and then adding b to a (the latter - * only if both_odd). - */ - mp_lshift_fixed_into(a, a, 1); - mp_cond_add_into(a, a, b, both_odd); - - /* - * Finally, undo the swap. If we do swap, this also - * reverses the sign of the current result ac*a+bc*b. - */ - mp_cond_swap(a, b, swap); - mp_cond_swap(ac, bc, swap); - minus_one ^= swap; + mp_add_into(gcd_out, a, b); } /* - * Now we expect to have recovered the input a,b. + * If the caller _only_ wanted the gcd, and neither Bezout + * coefficient is even required, we can skip the entire unwind + * stage. */ - assert(mp_cmp_eq(a, a_in) & mp_cmp_eq(b, b_in)); + if (a_coeff_out || b_coeff_out) { - /* - * But we might find that our current result is -1 instead of +1, - * that is, we have A',B' such that A'a - B'b = -1. - * - * In that situation, we set A = b-A' and B = a-B', giving us - * Aa-Bb = ab - A'a - ab + B'b = +1. - */ - mp_sub_into(tmp, b, ac); - mp_select_into(ac, ac, tmp, minus_one); - mp_sub_into(tmp, a, bc); - mp_select_into(bc, bc, tmp, minus_one); + /* + * The Bezout coefficients of a,b at this point are simply 0 + * for whichever of a,b is zero, and 1 for whichever is + * nonzero. The nonzero number equals gcd(a,b), which by + * assumption is odd, so we can do this by just taking the low + * bit of each one. + */ + ac->w[0] = mp_get_bit(a, 0); + bc->w[0] = mp_get_bit(b, 0); - /* - * Now we really are done. Return the outputs. - */ - if (a_coeff_out) - mp_copy_into(a_coeff_out, ac); - if (b_coeff_out) - mp_copy_into(b_coeff_out, bc); + /* + * Overwrite a,b themselves with those same numbers. This has + * the effect of dividing both of them by d, which will + * arrange that during the unwind stage we generate the + * minimal coefficients instead of a larger pair. + */ + mp_copy_into(a, ac); + mp_copy_into(b, bc); + + /* + * We'll maintain the invariant as we unwind that ac * a - bc + * * b is either +d or -d (or rather, +1/-1 after scaling by + * d), and we'll remember which. (We _could_ keep it at +d the + * whole time, but it would cost more work every time round + * the loop, so it's cheaper to fix that up once at the end.) + * + * Initially, the result is +d if a was the nonzero value after + * reduction, and -d if b was. + */ + unsigned minus_d = b->w[0]; + + for (size_t step = steps; step-- > 0 ;) { + /* + * Recover the data from the step we're unwinding. + */ + unsigned both_odd = mp_get_bit(record, step*2); + unsigned swap = mp_get_bit(record, step*2+1); + + /* + * Unwind the division: if our coefficient of a is odd, we + * adjust the coefficients by +b and +a respectively. + */ + unsigned adjust = ac->w[0] & 1; + mp_cond_add_into(ac, ac, b, adjust); + mp_cond_add_into(bc, bc, a, adjust); + + /* + * Now ac is definitely even, so we divide it by two. + */ + mp_rshift_fixed_into(ac, ac, 1); + + /* + * Now unwind the subtraction, if there was one, by adding + * ac to bc. + */ + mp_cond_add_into(bc, bc, ac, both_odd); + + /* + * Undo the transformation of the input numbers, by + * multiplying a by 2 and then adding b to a (the latter + * only if both_odd). + */ + mp_lshift_fixed_into(a, a, 1); + mp_cond_add_into(a, a, b, both_odd); + + /* + * Finally, undo the swap. If we do swap, this also + * reverses the sign of the current result ac*a+bc*b. + */ + mp_cond_swap(a, b, swap); + mp_cond_swap(ac, bc, swap); + minus_d ^= swap; + } + + /* + * Now we expect to have recovered the input a,b (or rather, + * the versions of them divided by d). But we might find that + * our current result is -d instead of +d, that is, we have + * A',B' such that A'a - B'b = -d. + * + * In that situation, we set A = b-A' and B = a-B', giving us + * Aa-Bb = ab - A'a - ab + B'b = +1. + */ + mp_sub_into(tmp, b, ac); + mp_select_into(ac, ac, tmp, minus_d); + mp_sub_into(tmp, a, bc); + mp_select_into(bc, bc, tmp, minus_d); + + /* + * Now we really are done. Return the outputs. + */ + if (a_coeff_out) + mp_copy_into(a_coeff_out, ac); + if (b_coeff_out) + mp_copy_into(b_coeff_out, bc); + + } mp_free(a); mp_free(b); @@ -1796,10 +1825,65 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out, mp_int *mp_invert(mp_int *x, mp_int *m) { mp_int *result = mp_make_sized(m->nw); - mp_bezout_into(result, NULL, x, m); + mp_bezout_into(result, NULL, NULL, x, m); return result; } +void mp_gcd_into(mp_int *a, mp_int *b, mp_int *gcd, mp_int *A, mp_int *B) +{ + /* + * Identify shared factors of 2. To do this we OR the two numbers + * to get something whose lowest set bit is in the right place, + * remove all higher bits by ANDing it with its own negation, and + * use mp_get_nbits to find the location of the single remaining + * set bit. + */ + mp_int *tmp = mp_make_sized(size_t_max(a->nw, b->nw)); + for (size_t i = 0; i < tmp->nw; i++) + tmp->w[i] = mp_word(a, i) | mp_word(b, i); + BignumCarry carry = 1; + for (size_t i = 0; i < tmp->nw; i++) { + BignumInt negw; + BignumADC(negw, carry, 0, ~tmp->w[i], carry); + tmp->w[i] &= negw; + } + size_t shift = mp_get_nbits(tmp) - 1; + mp_free(tmp); + + /* + * Make copies of a,b with those shared factors of 2 divided off, + * so that at least one is odd (which is the precondition for + * mp_bezout_into). Compute the gcd of those. + */ + mp_int *as = mp_rshift_safe(a, shift); + mp_int *bs = mp_rshift_safe(b, shift); + mp_bezout_into(A, B, gcd, as, bs); + mp_free(as); + mp_free(bs); + + /* + * And finally shift the gcd back up (unless the caller didn't + * even ask for it), to put the shared factors of 2 back in. + */ + if (gcd) + mp_lshift_safe_in_place(gcd, shift); +} + +mp_int *mp_gcd(mp_int *a, mp_int *b) +{ + mp_int *gcd = mp_make_sized(size_t_min(a->nw, b->nw)); + mp_gcd_into(a, b, gcd, NULL, NULL); + return gcd; +} + +unsigned mp_coprime(mp_int *a, mp_int *b) +{ + mp_int *gcd = mp_gcd(a, b); + unsigned toret = mp_eq_integer(gcd, 1); + mp_free(gcd); + return toret; +} + static uint32_t recip_approx_32(uint32_t x) { /* diff --git a/mpint.h b/mpint.h index f0bb0fe6..ffc1e160 100644 --- a/mpint.h +++ b/mpint.h @@ -270,6 +270,25 @@ void mp_reduce_mod_2to(mp_int *x, size_t p); mp_int *mp_invert_mod_2to(mp_int *x, size_t p); mp_int *mp_invert(mp_int *x, mp_int *modulus); +/* + * Greatest common divisor. + * + * mp_gcd_into also returns a pair of Bezout coefficients, namely A,B + * such that a*A - b*B = gcd. (The minus sign is so that both returned + * coefficients can be positive.) + * + * You can pass any of mp_gcd_into's output pointers as NULL if you + * don't need that output value. + * + * mp_gcd is a wrapper with a less cumbersome API, for the case where + * the only output value you need is the gcd itself. mp_coprime is + * even easier, if all you care about is whether or not that gcd is 1. + */ +mp_int *mp_gcd(mp_int *a, mp_int *b); +void mp_gcd_into(mp_int *a, mp_int *b, + mp_int *gcd_out, mp_int *A_out, mp_int *B_out); +unsigned mp_coprime(mp_int *a, mp_int *b); + /* * System for taking square roots modulo an odd prime. * diff --git a/test/cryptsuite.py b/test/cryptsuite.py index b25123c8..02f77d56 100755 --- a/test/cryptsuite.py +++ b/test/cryptsuite.py @@ -437,6 +437,48 @@ class mpint(MyTestBase): int(monty_invert(mc, monty_import(mc, x))), int(monty_import(mc, inv))) + def testGCD(self): + powerpairs = [(0,0), (1,0), (1,1), (2,1), (2,2), (75,3), (17,23)] + for a2, b2 in powerpairs: + for a3, b3 in powerpairs: + for a5, b5 in powerpairs: + a = 2**a2 * 3**a3 * 5**a5 * 17 * 19 * 23 + b = 2**b2 * 3**b3 * 5**b5 * 65423 + d = 2**min(a2, b2) * 3**min(a3, b3) * 5**min(a5, b5) + + ma = mp_copy(a) + mb = mp_copy(b) + + self.assertEqual(int(mp_gcd(ma, mb)), d) + + md = mp_new(nbits(d)) + mA = mp_new(nbits(b)) + mB = mp_new(nbits(a)) + mp_gcd_into(ma, mb, md, mA, mB) + self.assertEqual(int(md), d) + A = int(mA) + B = int(mB) + self.assertEqual(a*A - b*B, d) + self.assertTrue(0 <= A < b//d) + self.assertTrue(0 <= B < a//d) + + self.assertEqual(mp_coprime(ma, mb), 1 if d==1 else 0) + + # Make sure gcd_into can handle not getting some + # of its output pointers. + mp_clear(md) + mp_gcd_into(ma, mb, md, None, None) + self.assertEqual(int(md), d) + mp_clear(mA) + mp_gcd_into(ma, mb, None, mA, None) + self.assertEqual(int(mA), A) + mp_clear(mB) + mp_gcd_into(ma, mb, None, None, mB) + self.assertEqual(int(mB), B) + mp_gcd_into(ma, mb, None, None, None) + # No tests we can do after that last one - we just + # insist that it isn't allowed to have crashed! + def testMonty(self): moduli = [5, 19, 2**16+1, 2**31-1, 2**128-159, 2**255-19, 293828847201107461142630006802421204703, diff --git a/testcrypt.h b/testcrypt.h index 6832e017..2fec0ba2 100644 --- a/testcrypt.h +++ b/testcrypt.h @@ -54,6 +54,9 @@ FUNC2(val_mpint, mp_mod, val_mpint, val_mpint) FUNC2(void, mp_reduce_mod_2to, val_mpint, uint) FUNC2(val_mpint, mp_invert_mod_2to, val_mpint, uint) FUNC2(val_mpint, mp_invert, val_mpint, val_mpint) +FUNC5(void, mp_gcd_into, val_mpint, val_mpint, opt_val_mpint, opt_val_mpint, opt_val_mpint) +FUNC2(val_mpint, mp_gcd, val_mpint, val_mpint) +FUNC2(uint, mp_coprime, val_mpint, val_mpint) FUNC2(val_modsqrt, modsqrt_new, val_mpint, val_mpint) /* The modsqrt functions' 'success' pointer becomes a second return value */ FUNC3(val_mpint, mp_modsqrt, val_modsqrt, val_mpint, out_uint)