diff --git a/mpint.c b/mpint.c index 577cc026..123e3ae3 100644 --- a/mpint.c +++ b/mpint.c @@ -1469,17 +1469,17 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus) * gcd(b,(a-b)/2). * * For this application, I always expect the actual gcd to be coprime, - * so we can rule out the 'both even' initial case. For simplicity - * I've changed the 'both odd' case to turn (a,b) into (b,a-b) without - * the division by 2 (the next iteration would divide by 2 anyway). + * so we can rule out the 'both even' initial case. So this function + * just performs a sequence of reductions in the following form: * - * But the big change is that we need the Bezout coefficients as - * output, not just the gcd. So we need to know how to generate those - * in each case, based on the coefficients from the reduced pair of - * numbers: + * - if a,b are both odd, sort them so that a > b, and replace a with + * a-b; otherwise sort them so that a is the even one + * - either way, now a is even and b is odd, so divide a by 2. * - * - If a,b are both odd, and u,v are such that u*b + v*(a-b) = 1, - * then v*a + (u-v)*b = 1. + * The big change to Stein's algorithm is that we need the Bezout + * coefficients as output, not just the gcd. So we need to know how to + * generate those in each case, based on the coefficients from the + * reduced pair of numbers: * * - If a is even, and u,v are such that u*(a/2) + v*b = 1: * + if u is also even, then this is just (u/2)*a + v*b = 1 @@ -1487,13 +1487,21 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus) * since u and b are both odd, (u+b)/2 is an integer, so we have * ((u+b)/2)*a + (v-a/2)*b = 1. * + * - If a,b are both odd, and u,v are such that u*b + v*(a-b) = 1, + * then v*a + (u-v)*b = 1. + * + * In the case where we passed from (a,b) to (b,(a-b)/2), we regard it + * as having first subtracted b from a and then halved a, so both of + * these transformations must be done in sequence.
+ * * The code below transforms this from a recursive to an iterative * algorithm. We first reduce a,b to 0,1, recording at each stage - * whether one of them was even, and whether we had to swap them; then - * we iterate backwards over that record of what we did, applying the - * above rules for building up the Bezout coefficients as we go. Of - * course, all the case analysis is done by the usual bit-twiddling - * conditionalisation to avoid data-dependent control flow. + * whether we did the initial subtraction, and whether we had to swap + * the two values; then we iterate backwards over that record of what + * we did, applying the above rules for building up the Bezout + * coefficients as we go. Of course, all the case analysis is done by + * the usual bit-twiddling conditionalisation to avoid data-dependent + * control flow. * * Also, since these mp_ints are generally treated as unsigned, we * store the coefficients by absolute value, with the semantics that @@ -1508,25 +1516,17 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus) * constant time, we just need to find the maximum number we could * _possibly_ require, and do that many. * - * If a,b < 2^n, at most 3n iterations are required. Proof: consider - * the quantity Q = log_2(min(a,b)) + 2 log_2(max(a,b)). - * - If the smaller number is even, then the next iteration halves - * it, decreasing Q by 1. - * - If the larger number is even, then the next iteration halves - * it, decreasing Q by 2. - * - If the two numbers are both odd, then the combined effect of the - * next two steps will be to replace the larger number with - * something less than half its original value. - * In any of these cases, the effect is that in k steps (where k = 1 - * or 2 depending on the case) Q decreases by at least k. So on - * average it decreases by at least 1 per step, and since it starts - * off at 3n, that's how many steps it might take. + * If a,b < 2^n, at most 2n iterations are required. 
Proof: consider + * the quantity Q = log_2(a) + log_2(b). Every step halves one of the + * numbers (and may also reduce one of them further by doing a + * subtraction beforehand, but in the worst case, not by much or not + * at all). So Q reduces by at least 1 per iteration, and it starts + * off with a value at most 2n. * * The worst case inputs (I think) are where x=2^{n-1} and y=2^n-1 * (i.e. x is a power of 2 and y is all 1s). In that situation, the * first n-1 steps repeatedly halve x until it's 1, and then there are - * n pairs of steps each of which subtracts 1 from y and then halves - * it. + * n further steps each of which subtracts 1 from y and halves it. */ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out, mp_int *a_in, mp_int *b_in) @@ -1551,7 +1551,7 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out, * mp_make_sized conveniently zeroes the allocation and mp_free * wipes it, and (b) this way I can use mp_dump() if I have to * debug this code. */ - size_t steps = 3 * nw * BIGNUM_INT_BITS; + size_t steps = 2 * nw * BIGNUM_INT_BITS; mp_int *record = mp_make_sized( (steps*2 + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS); @@ -1570,13 +1570,15 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out, mp_cond_swap(a, b, swap); /* - * Now, if we've made a the even number, divide it by two; if - * we've made it the larger of two odd numbers, subtract the - * smaller one from it. + * If a,b are both odd, then a is the larger number, so + * subtract the smaller one from it. */ - mp_rshift_fixed_into(tmp, a, 1); - mp_sub_into(a, a, b); - mp_select_into(a, tmp, a, both_odd); + mp_cond_sub_into(a, a, b, both_odd); + + /* + * Now a is even, so divide it by two. + */ + mp_rshift_fixed_into(a, a, 1); /* * Record the two 1-bit values both_odd and swap. 
@@ -1620,37 +1622,35 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out, unsigned swap = mp_get_bit(record, step*2+1); /* - * If this was a division step (!both_odd), and our - * coefficient of a is not the even one, we need to adjust the - * coefficients by +b and +a respectively. + * Unwind the division: if our coefficient of a is odd, we + * adjust the coefficients by +b and +a respectively. */ - unsigned adjust = (ac->w[0] & 1) & ~both_odd; + unsigned adjust = ac->w[0] & 1; mp_cond_add_into(ac, ac, b, adjust); mp_cond_add_into(bc, bc, a, adjust); /* - * Now, if it was a division step, then ac is even, and we - * divide it by two. + * Now ac is definitely even, so we divide it by two. */ - mp_rshift_fixed_into(tmp, ac, 1); - mp_select_into(ac, tmp, ac, both_odd); + mp_rshift_fixed_into(ac, ac, 1); /* - * But if it was a subtraction step, we add ac to bc instead. + * Now unwind the subtraction, if there was one, by adding + * ac to bc. */ mp_cond_add_into(bc, bc, ac, both_odd); /* - * Undo the transformation of the input numbers, by adding b - * to a (if both_odd) or multiplying a by 2 (otherwise). + * Undo the transformation of the input numbers, by + * multiplying a by 2 and then adding b to a (the latter + * only if both_odd). */ - mp_lshift_fixed_into(tmp, a, 1); - mp_add_into(a, a, b); - mp_select_into(a, tmp, a, both_odd); + mp_lshift_fixed_into(a, a, 1); + mp_cond_add_into(a, a, b, both_odd); /* - * Finally, undo the swap. If we do swap, this also reverses - * the sign of the current result ac*a+bc*b. + * Finally, undo the swap. If we do swap, this also + * reverses the sign of the current result ac*a+bc*b. */ mp_cond_swap(a, b, swap); mp_cond_swap(ac, bc, swap); diff --git a/test/cryptsuite.py b/test/cryptsuite.py index b3b23960..9bb1e022 100755 --- a/test/cryptsuite.py +++ b/test/cryptsuite.py @@ -359,13 +359,14 @@ class mpint(unittest.TestCase): # Test mp_invert proper. 
moduli = [2, 3, 2**16+1, 2**32-1, 2**32+1, 2**128-159, - 141421356237309504880168872420969807856967187537694807] + 141421356237309504880168872420969807856967187537694807, + 2**128-1] for m in moduli: # Prepare a MontyContext for the monty_invert test below # (unless m is even, in which case we can't) mc = monty_new(m) if m & 1 else None - to_invert = {1, 2, 3, 7, 19, m-1, 5*m//17} + to_invert = {1, 2, 3, 7, 19, m-1, 5*m//17, (m-1)//2, (m+1)//2} for x in sorted(to_invert): if gcd(x, m) != 1: continue # filter out non-invertible cases