diff --git a/crypto/mpint.c b/crypto/mpint.c index fca7530f..f015bd09 100644 --- a/crypto/mpint.c +++ b/crypto/mpint.c @@ -2263,6 +2263,87 @@ mp_int *mp_mod(mp_int *n, mp_int *d) return r; } +uint32_t mp_mod_known_integer(mp_int *x, uint32_t m) +{ + uint64_t reciprocal = ((uint64_t)1 << 48) / m; + uint64_t accumulator = 0; + for (size_t i = mp_max_bytes(x); i-- > 0 ;) { + accumulator = 0x100 * accumulator + mp_get_byte(x, i); + /* + * Let A be the value in 'accumulator' at this point, and let + * R be the value it will have after we subtract quot*m below. + * + * Lemma 1: if A < 2^48, then R < 2m. + * + * Proof: + * + * By construction, we have 2^48/m - 1 < reciprocal <= 2^48/m. + * Multiplying that by the accumulator gives + * + * A/m * 2^48 - A < unshifted_quot <= A/m * 2^48 + * i.e. 0 <= (A/m * 2^48) - unshifted_quot < A + * i.e. 0 <= A/m - unshifted_quot/2^48 < A/2^48 + * + * So when we shift this quotient right by 48 bits, i.e. take + * the floor of (unshifted_quot/2^48), the value we take the + * floor of is at most A/2^48 less than the true rational + * value A/m that we _wanted_ to take the floor of. + * + * Provided A < 2^48, this is less than 1. So the quotient + * 'quot' that we've just produced is either the true quotient + * floor(A/m), or one less than it. Hence, the output value R + * is less than 2m. [] + * + * Lemma 2: if A < 2^16 m, then the multiplication of + * accumulator*reciprocal does not overflow. + * + * Proof: as above, we have reciprocal <= 2^48/m. Multiplying + * by A gives unshifted_quot <= 2^48 * A / m < 2^48 * 2^16 = + * 2^64. [] + */ + uint64_t unshifted_quot = accumulator * reciprocal; + uint64_t quot = unshifted_quot >> 48; + accumulator -= quot * m; + } + + /* + * Theorem 1: accumulator < 2m at the end of every iteration of + * this loop. + * + * Proof: induction on the above loop. + * + * Base case: at the start of the first loop iteration, the + * accumulator is 0, which is certainly < 2m. + * + * Inductive step: in each loop iteration, we take a value at most + * 2m-1, multiply it by 2^8, and add another byte less than 2^8 to + * generate the input value A to the reduction process above. So + * we have A < 2m * 2^8 - 1. We know m < 2^32 (because it was + * passed in as a uint32_t), so A < 2^41, which is enough to allow + * us to apply Lemma 1, showing that the value of 'accumulator' at + * the end of the loop is still < 2m. [] + * + * Corollary: we need at most one final subtraction of m to + * produce the canonical residue of x mod m, i.e. in the range + * [0,m). + * + * Theorem 2: no multiplication in the inner loop overflows. + * + * Proof: in Theorem 1 we established A < 2m * 2^8 - 1 in every + * iteration. That is less than m * 2^16, so Lemma 2 applies. + * + * The other multiplication, of quot * m, cannot overflow because + * quot is at most A/m, so quot*m <= A < 2^64. [] + */ + + uint32_t result = accumulator; + uint32_t reduced = result - m; + uint32_t select = -(reduced >> 31); + result = reduced ^ ((result ^ reduced) & select); + assert(result < m); + return result; +} + mp_int *mp_nthroot(mp_int *y, unsigned n, mp_int *remainder_out) { /* diff --git a/keygen/mpunsafe.c b/keygen/mpunsafe.c index f33532c4..6265d40f 100644 --- a/keygen/mpunsafe.c +++ b/keygen/mpunsafe.c @@ -45,13 +45,3 @@ mp_int *mp_unsafe_copy(mp_int *x) mp_copy_into(copy, x); return copy; } - -uint32_t mp_unsafe_mod_integer(mp_int *x, uint32_t modulus) -{ - uint64_t accumulator = 0; - for (size_t i = mp_max_bytes(x); i-- > 0 ;) { - accumulator = 0x100 * accumulator + mp_get_byte(x, i); - accumulator %= modulus; - } - return accumulator; -} diff --git a/keygen/mpunsafe.h b/keygen/mpunsafe.h index 0b6ba3bd..07215372 100644 --- a/keygen/mpunsafe.h +++ b/keygen/mpunsafe.h @@ -36,11 +36,4 @@ mp_int *mp_unsafe_shrink(mp_int *m); mp_int *mp_unsafe_copy(mp_int *m); -/* - * Compute the residue of x mod m. This is implemented in the most - * obvious way using the C % operator, which won't be constant-time on - * many C implementations. - */ -uint32_t mp_unsafe_mod_integer(mp_int *x, uint32_t m); - #endif /* PUTTY_MPINT_UNSAFE_H */ diff --git a/keygen/primecandidate.c b/keygen/primecandidate.c index cf55919e..02c0259d 100644 --- a/keygen/primecandidate.c +++ b/keygen/primecandidate.c @@ -341,8 +341,8 @@ void pcs_ready(PrimeCandidateSource *s) int64_t mod = s->avoids[i].mod, res = s->avoids[i].res; if (mod != last_mod) { last_mod = mod; - addend_m = mp_unsafe_mod_integer(s->addend, mod); - factor_m = mp_unsafe_mod_integer(s->factor, mod); + addend_m = mp_mod_known_integer(s->addend, mod); + factor_m = mp_mod_known_integer(s->factor, mod); } if (factor_m == 0) { @@ -385,7 +385,7 @@ mp_int *pcs_generate(PrimeCandidateSource *s) if (mod != last_mod) { last_mod = mod; - x_res = mp_unsafe_mod_integer(x, mod); + x_res = mp_mod_known_integer(x, mod); } if (x_res == avoid_res) { diff --git a/mpint.h b/mpint.h index 5611a007..51322aa6 100644 --- a/mpint.h +++ b/mpint.h @@ -257,6 +257,12 @@ void mp_divmod_into(mp_int *n, mp_int *d, mp_int *q, mp_int *r); mp_int *mp_div(mp_int *n, mp_int *d); mp_int *mp_mod(mp_int *x, mp_int *modulus); +/* + * Compute the residue of x mod m, where m is a small integer. x is + * kept secret, but m is not. + */ +uint32_t mp_mod_known_integer(mp_int *x, uint32_t m); + /* * Integer nth root. mp_nthroot returns the largest integer x such * that x^n <= y, and if 'remainder' is non-NULL then it fills it with