1
0
mirror of https://git.tartarus.org/simon/putty.git synced 2025-01-25 01:02:24 +00:00

Speed up and simplify mp_invert.

When I was originally designing my knockoff of Stein's algorithm, I
simplified it for my own understanding by replacing the step that
turns a into (a-b)/2 with a step that simply turned it into a-b, on
the basis that the next step would do the division by 2 in any case.
This made it easier to get my head round in the first place, and in
the initial Python prototype of the algorithm, it looked more sensible
to have two different kinds of simple step rather than one simple and
one complicated.

But actually, when it's rewritten under the constraints of time
invariance, the standard way is better, because we had to do the
computation for both kinds of step _anyway_, and this way we sometimes
make both of them useful at once instead of only ever using one.

So I've put it back to the more standard version of Stein, which is a
big improvement, because now we can run in at most 2n iterations
instead of 3n _and_ the code implementing each step is simpler. A
quick timing test suggests that modular inversion is now faster by a
factor of about 1.75.

Also, since I went to the effort of thinking up and commenting a pair
of worst-case inputs for the iteration count of Stein's algorithm, it
seems like an omission not to have made sure they were in the test
suite! Added extra tests that include 2^128-1 as a modulus and 2^127
as a value to invert.
This commit is contained in:
Simon Tatham 2019-01-05 13:47:26 +00:00
parent 4a0fa90979
commit 8e399f9aa7
2 changed files with 55 additions and 54 deletions

104
mpint.c
View File

@ -1469,17 +1469,17 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
* gcd(b,(a-b)/2). * gcd(b,(a-b)/2).
* *
* For this application, I always expect the actual gcd to be coprime, * For this application, I always expect the actual gcd to be coprime,
* so we can rule out the 'both even' initial case. For simplicity * so we can rule out the 'both even' initial case. So this function
* I've changed the 'both odd' case to turn (a,b) into (b,a-b) without * just performs a sequence of reductions in the following form:
* the division by 2 (the next iteration would divide by 2 anyway).
* *
* But the big change is that we need the Bezout coefficients as * - if a,b are both odd, sort them so that a > b, and replace a with
* output, not just the gcd. So we need to know how to generate those * b-a; otherwise sort them so that a is the even one
* in each case, based on the coefficients from the reduced pair of * - either way, now a is even and b is odd, so divide a by 2.
* numbers:
* *
* - If a,b are both odd, and u,v are such that u*b + v*(a-b) = 1, * The big change to Stein's algorithm is that we need the Bezout
* then v*a + (u-v)*b = 1. * coefficients as output, not just the gcd. So we need to know how to
* generate those in each case, based on the coefficients from the
* reduced pair of numbers:
* *
* - If a is even, and u,v are such that u*(a/2) + v*b = 1: * - If a is even, and u,v are such that u*(a/2) + v*b = 1:
* + if u is also even, then this is just (u/2)*a + v*b = 1 * + if u is also even, then this is just (u/2)*a + v*b = 1
@ -1487,13 +1487,21 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
* since u and b are both odd, (u+b)/2 is an integer, so we have * since u and b are both odd, (u+b)/2 is an integer, so we have
* ((u+b)/2)*a + (v-a/2)*b = 1. * ((u+b)/2)*a + (v-a/2)*b = 1.
* *
* - If a,b are both odd, and u,v are such that u*b + v*(a-b) = 1,
* then v*a + (u-v)*b = 1.
*
* In the case where we passed from (a,b) to (b,(a-b)/2), we regard it
* as having first subtracted b from a and then halved a, so both of
* these transformations must be done in sequence.
*
* The code below transforms this from a recursive to an iterative * The code below transforms this from a recursive to an iterative
* algorithm. We first reduce a,b to 0,1, recording at each stage * algorithm. We first reduce a,b to 0,1, recording at each stage
* whether one of them was even, and whether we had to swap them; then * whether we did the initial subtraction, and whether we had to swap
* we iterate backwards over that record of what we did, applying the * the two values; then we iterate backwards over that record of what
* above rules for building up the Bezout coefficients as we go. Of * we did, applying the above rules for building up the Bezout
* course, all the case analysis is done by the usual bit-twiddling * coefficients as we go. Of course, all the case analysis is done by
* conditionalisation to avoid data-dependent control flow. * the usual bit-twiddling conditionalisation to avoid data-dependent
* control flow.
* *
* Also, since these mp_ints are generally treated as unsigned, we * Also, since these mp_ints are generally treated as unsigned, we
* store the coefficients by absolute value, with the semantics that * store the coefficients by absolute value, with the semantics that
@ -1508,25 +1516,17 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
* constant time, we just need to find the maximum number we could * constant time, we just need to find the maximum number we could
* _possibly_ require, and do that many. * _possibly_ require, and do that many.
* *
* If a,b < 2^n, at most 3n iterations are required. Proof: consider * If a,b < 2^n, at most 2n iterations are required. Proof: consider
* the quantity Q = log_2(min(a,b)) + 2 log_2(max(a,b)). * the quantity Q = log_2(a) + log_2(b). Every step halves one of the
* - If the smaller number is even, then the next iteration halves * numbers (and may also reduce one of them further by doing a
* it, decreasing Q by 1. * subtraction beforehand, but in the worst case, not by much or not
* - If the larger number is even, then the next iteration halves * at all). So Q reduces by at least 1 per iteration, and it starts
* it, decreasing Q by 2. * off with a value at most 2n.
* - If the two numbers are both odd, then the combined effect of the
* next two steps will be to replace the larger number with
* something less than half its original value.
* In any of these cases, the effect is that in k steps (where k = 1
* or 2 depending on the case) Q decreases by at least k. So on
* average it decreases by at least 1 per step, and since it starts
* off at 3n, that's how many steps it might take.
* *
* The worst case inputs (I think) are where x=2^{n-1} and y=2^n-1 * The worst case inputs (I think) are where x=2^{n-1} and y=2^n-1
* (i.e. x is a power of 2 and y is all 1s). In that situation, the * (i.e. x is a power of 2 and y is all 1s). In that situation, the
* first n-1 steps repeatedly halve x until it's 1, and then there are * first n-1 steps repeatedly halve x until it's 1, and then there are
* n pairs of steps each of which subtracts 1 from y and then halves * n further steps each of which subtracts 1 from y and halves it.
* it.
*/ */
static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out, static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
mp_int *a_in, mp_int *b_in) mp_int *a_in, mp_int *b_in)
@ -1551,7 +1551,7 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
* mp_make_sized conveniently zeroes the allocation and mp_free * mp_make_sized conveniently zeroes the allocation and mp_free
* wipes it, and (b) this way I can use mp_dump() if I have to * wipes it, and (b) this way I can use mp_dump() if I have to
* debug this code. */ * debug this code. */
size_t steps = 3 * nw * BIGNUM_INT_BITS; size_t steps = 2 * nw * BIGNUM_INT_BITS;
mp_int *record = mp_make_sized( mp_int *record = mp_make_sized(
(steps*2 + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS); (steps*2 + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS);
@ -1570,13 +1570,15 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
mp_cond_swap(a, b, swap); mp_cond_swap(a, b, swap);
/* /*
* Now, if we've made a the even number, divide it by two; if * If a,b are both odd, then a is the larger number, so
* we've made it the larger of two odd numbers, subtract the * subtract the smaller one from it.
* smaller one from it.
*/ */
mp_rshift_fixed_into(tmp, a, 1); mp_cond_sub_into(a, a, b, both_odd);
mp_sub_into(a, a, b);
mp_select_into(a, tmp, a, both_odd); /*
* Now a is even, so divide it by two.
*/
mp_rshift_fixed_into(a, a, 1);
/* /*
* Record the two 1-bit values both_odd and swap. * Record the two 1-bit values both_odd and swap.
@ -1620,37 +1622,35 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
unsigned swap = mp_get_bit(record, step*2+1); unsigned swap = mp_get_bit(record, step*2+1);
/* /*
* If this was a division step (!both_odd), and our * Unwind the division: if our coefficient of a is odd, we
* coefficient of a is not the even one, we need to adjust the * adjust the coefficients by +b and +a respectively.
* coefficients by +b and +a respectively.
*/ */
unsigned adjust = (ac->w[0] & 1) & ~both_odd; unsigned adjust = ac->w[0] & 1;
mp_cond_add_into(ac, ac, b, adjust); mp_cond_add_into(ac, ac, b, adjust);
mp_cond_add_into(bc, bc, a, adjust); mp_cond_add_into(bc, bc, a, adjust);
/* /*
* Now, if it was a division step, then ac is even, and we * Now ac is definitely even, so we divide it by two.
* divide it by two.
*/ */
mp_rshift_fixed_into(tmp, ac, 1); mp_rshift_fixed_into(ac, ac, 1);
mp_select_into(ac, tmp, ac, both_odd);
/* /*
* But if it was a subtraction step, we add ac to bc instead. * Now unwind the subtraction, if there was one, by adding
* ac to bc.
*/ */
mp_cond_add_into(bc, bc, ac, both_odd); mp_cond_add_into(bc, bc, ac, both_odd);
/* /*
* Undo the transformation of the input numbers, by adding b * Undo the transformation of the input numbers, by
* to a (if both_odd) or multiplying a by 2 (otherwise). * multiplying a by 2 and then adding b to a (the latter
* only if both_odd).
*/ */
mp_lshift_fixed_into(tmp, a, 1); mp_lshift_fixed_into(a, a, 1);
mp_add_into(a, a, b); mp_cond_add_into(a, a, b, both_odd);
mp_select_into(a, tmp, a, both_odd);
/* /*
* Finally, undo the swap. If we do swap, this also reverses * Finally, undo the swap. If we do swap, this also
* the sign of the current result ac*a+bc*b. * reverses the sign of the current result ac*a+bc*b.
*/ */
mp_cond_swap(a, b, swap); mp_cond_swap(a, b, swap);
mp_cond_swap(ac, bc, swap); mp_cond_swap(ac, bc, swap);

View File

@ -359,13 +359,14 @@ class mpint(unittest.TestCase):
# Test mp_invert proper. # Test mp_invert proper.
moduli = [2, 3, 2**16+1, 2**32-1, 2**32+1, 2**128-159, moduli = [2, 3, 2**16+1, 2**32-1, 2**32+1, 2**128-159,
141421356237309504880168872420969807856967187537694807] 141421356237309504880168872420969807856967187537694807,
2**128-1]
for m in moduli: for m in moduli:
# Prepare a MontyContext for the monty_invert test below # Prepare a MontyContext for the monty_invert test below
# (unless m is even, in which case we can't) # (unless m is even, in which case we can't)
mc = monty_new(m) if m & 1 else None mc = monty_new(m) if m & 1 else None
to_invert = {1, 2, 3, 7, 19, m-1, 5*m//17} to_invert = {1, 2, 3, 7, 19, m-1, 5*m//17, (m-1)//2, (m+1)//2}
for x in sorted(to_invert): for x in sorted(to_invert):
if gcd(x, m) != 1: if gcd(x, m) != 1:
continue # filter out non-invertible cases continue # filter out non-invertible cases