1
0
mirror of https://git.tartarus.org/simon/putty.git synced 2025-01-09 17:38:00 +00:00

Speed up and simplify mp_invert.

When I was originally designing my knockoff of Stein's algorithm, I
simplified it for my own understanding by replacing the step that
turns a into (a-b)/2 with a step that simply turned it into a-b, on
the basis that the next step would do the division by 2 in any case.
This made it easier to get my head round in the first place, and in
the initial Python prototype of the algorithm, it looked more sensible
to have two different kinds of simple step rather than one simple and
one complicated.

But actually, when it's rewritten under the constraints of time
invariance, the standard way is better, because we had to do the
computation for both kinds of step _anyway_, and this way we sometimes
make both of them useful at once instead of only ever using one.

So I've put it back to the more standard version of Stein, which is a
big improvement, because now we can run in at most 2n iterations
instead of 3n _and_ the code implementing each step is simpler. A
quick timing test suggests that modular inversion is now faster by a
factor of about 1.75.

Also, since I went to the effort of thinking up and commenting a pair
of worst-case inputs for the iteration count of Stein's algorithm, it
seems like an omission not to have made sure they were in the test
suite! Added extra tests that include 2^128-1 as a modulus and 2^127
as a value to invert.
This commit is contained in:
Simon Tatham 2019-01-05 13:47:26 +00:00
parent 4a0fa90979
commit 8e399f9aa7
2 changed files with 55 additions and 54 deletions

104
mpint.c
View File

@ -1469,17 +1469,17 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
* gcd(b,(a-b)/2).
*
* For this application, I always expect the actual gcd to be coprime,
* so we can rule out the 'both even' initial case. For simplicity
* I've changed the 'both odd' case to turn (a,b) into (b,a-b) without
* the division by 2 (the next iteration would divide by 2 anyway).
* so we can rule out the 'both even' initial case. So this function
* just performs a sequence of reductions in the following form:
*
* But the big change is that we need the Bezout coefficients as
* output, not just the gcd. So we need to know how to generate those
* in each case, based on the coefficients from the reduced pair of
* numbers:
* - if a,b are both odd, sort them so that a > b, and replace a with
* b-a; otherwise sort them so that a is the even one
* - either way, now a is even and b is odd, so divide a by 2.
*
* - If a,b are both odd, and u,v are such that u*b + v*(a-b) = 1,
* then v*a + (u-v)*b = 1.
* The big change to Stein's algorithm is that we need the Bezout
* coefficients as output, not just the gcd. So we need to know how to
* generate those in each case, based on the coefficients from the
* reduced pair of numbers:
*
* - If a is even, and u,v are such that u*(a/2) + v*b = 1:
* + if u is also even, then this is just (u/2)*a + v*b = 1
@ -1487,13 +1487,21 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
* since u and b are both odd, (u+b)/2 is an integer, so we have
* ((u+b)/2)*a + (v-a/2)*b = 1.
*
* - If a,b are both odd, and u,v are such that u*b + v*(a-b) = 1,
* then v*a + (u-v)*b = 1.
*
* In the case where we passed from (a,b) to (b,(a-b)/2), we regard it
* as having first subtracted b from a and then halved a, so both of
* these transformations must be done in sequence.
*
* The code below transforms this from a recursive to an iterative
* algorithm. We first reduce a,b to 0,1, recording at each stage
* whether one of them was even, and whether we had to swap them; then
* we iterate backwards over that record of what we did, applying the
* above rules for building up the Bezout coefficients as we go. Of
* course, all the case analysis is done by the usual bit-twiddling
* conditionalisation to avoid data-dependent control flow.
* whether we did the initial subtraction, and whether we had to swap
* the two values; then we iterate backwards over that record of what
* we did, applying the above rules for building up the Bezout
* coefficients as we go. Of course, all the case analysis is done by
* the usual bit-twiddling conditionalisation to avoid data-dependent
* control flow.
*
* Also, since these mp_ints are generally treated as unsigned, we
* store the coefficients by absolute value, with the semantics that
@ -1508,25 +1516,17 @@ mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
* constant time, we just need to find the maximum number we could
* _possibly_ require, and do that many.
*
* If a,b < 2^n, at most 3n iterations are required. Proof: consider
* the quantity Q = log_2(min(a,b)) + 2 log_2(max(a,b)).
* - If the smaller number is even, then the next iteration halves
* it, decreasing Q by 1.
* - If the larger number is even, then the next iteration halves
* it, decreasing Q by 2.
* - If the two numbers are both odd, then the combined effect of the
* next two steps will be to replace the larger number with
* something less than half its original value.
* In any of these cases, the effect is that in k steps (where k = 1
* or 2 depending on the case) Q decreases by at least k. So on
* average it decreases by at least 1 per step, and since it starts
* off at 3n, that's how many steps it might take.
* If a,b < 2^n, at most 2n iterations are required. Proof: consider
* the quantity Q = log_2(a) + log_2(b). Every step halves one of the
* numbers (and may also reduce one of them further by doing a
* subtraction beforehand, but in the worst case, not by much or not
* at all). So Q reduces by at least 1 per iteration, and it starts
* off with a value at most 2n.
*
* The worst case inputs (I think) are where x=2^{n-1} and y=2^n-1
* (i.e. x is a power of 2 and y is all 1s). In that situation, the
* first n-1 steps repeatedly halve x until it's 1, and then there are
* n pairs of steps each of which subtracts 1 from y and then halves
* it.
* n further steps each of which subtracts 1 from y and halves it.
*/
static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
mp_int *a_in, mp_int *b_in)
@ -1551,7 +1551,7 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
* mp_make_sized conveniently zeroes the allocation and mp_free
* wipes it, and (b) this way I can use mp_dump() if I have to
* debug this code. */
size_t steps = 3 * nw * BIGNUM_INT_BITS;
size_t steps = 2 * nw * BIGNUM_INT_BITS;
mp_int *record = mp_make_sized(
(steps*2 + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS);
@ -1570,13 +1570,15 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
mp_cond_swap(a, b, swap);
/*
* Now, if we've made a the even number, divide it by two; if
* we've made it the larger of two odd numbers, subtract the
* smaller one from it.
* If a,b are both odd, then a is the larger number, so
* subtract the smaller one from it.
*/
mp_rshift_fixed_into(tmp, a, 1);
mp_sub_into(a, a, b);
mp_select_into(a, tmp, a, both_odd);
mp_cond_sub_into(a, a, b, both_odd);
/*
* Now a is even, so divide it by two.
*/
mp_rshift_fixed_into(a, a, 1);
/*
* Record the two 1-bit values both_odd and swap.
@ -1620,37 +1622,35 @@ static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
unsigned swap = mp_get_bit(record, step*2+1);
/*
* If this was a division step (!both_odd), and our
* coefficient of a is not the even one, we need to adjust the
* coefficients by +b and +a respectively.
* Unwind the division: if our coefficient of a is odd, we
* adjust the coefficients by +b and +a respectively.
*/
unsigned adjust = (ac->w[0] & 1) & ~both_odd;
unsigned adjust = ac->w[0] & 1;
mp_cond_add_into(ac, ac, b, adjust);
mp_cond_add_into(bc, bc, a, adjust);
/*
* Now, if it was a division step, then ac is even, and we
* divide it by two.
* Now ac is definitely even, so we divide it by two.
*/
mp_rshift_fixed_into(tmp, ac, 1);
mp_select_into(ac, tmp, ac, both_odd);
mp_rshift_fixed_into(ac, ac, 1);
/*
* But if it was a subtraction step, we add ac to bc instead.
* Now unwind the subtraction, if there was one, by adding
* ac to bc.
*/
mp_cond_add_into(bc, bc, ac, both_odd);
/*
* Undo the transformation of the input numbers, by adding b
* to a (if both_odd) or multiplying a by 2 (otherwise).
* Undo the transformation of the input numbers, by
* multiplying a by 2 and then adding b to a (the latter
* only if both_odd).
*/
mp_lshift_fixed_into(tmp, a, 1);
mp_add_into(a, a, b);
mp_select_into(a, tmp, a, both_odd);
mp_lshift_fixed_into(a, a, 1);
mp_cond_add_into(a, a, b, both_odd);
/*
* Finally, undo the swap. If we do swap, this also reverses
* the sign of the current result ac*a+bc*b.
* Finally, undo the swap. If we do swap, this also
* reverses the sign of the current result ac*a+bc*b.
*/
mp_cond_swap(a, b, swap);
mp_cond_swap(ac, bc, swap);

View File

@ -359,13 +359,14 @@ class mpint(unittest.TestCase):
# Test mp_invert proper.
moduli = [2, 3, 2**16+1, 2**32-1, 2**32+1, 2**128-159,
141421356237309504880168872420969807856967187537694807]
141421356237309504880168872420969807856967187537694807,
2**128-1]
for m in moduli:
# Prepare a MontyContext for the monty_invert test below
# (unless m is even, in which case we can't)
mc = monty_new(m) if m & 1 else None
to_invert = {1, 2, 3, 7, 19, m-1, 5*m//17}
to_invert = {1, 2, 3, 7, 19, m-1, 5*m//17, (m-1)//2, (m+1)//2}
for x in sorted(to_invert):
if gcd(x, m) != 1:
continue # filter out non-invertible cases