diff --git a/keygen/millerrabin.c b/keygen/millerrabin.c index 19ca1bd3..24ee6193 100644 --- a/keygen/millerrabin.c +++ b/keygen/millerrabin.c @@ -95,10 +95,8 @@ struct MillerRabin { MontyContext *mc; - size_t k; - mp_int *q; - - mp_int *two, *pm1, *m_pm1; + mp_int *pm1, *m_pm1; + mp_int *lowbit, *two; }; MillerRabin *miller_rabin_new(mp_int *p) @@ -108,16 +106,19 @@ MillerRabin *miller_rabin_new(mp_int *p) assert(mp_hs_integer(p, 2)); assert(mp_get_bit(p, 0) == 1); - mr->k = 1; - while (!mp_get_bit(p, mr->k)) - mr->k++; - mr->q = mp_rshift_safe(p, mr->k); + mr->pm1 = mp_copy(p); + mp_sub_integer_into(mr->pm1, mr->pm1, 1); + + /* + * Standard bit-twiddling trick for isolating the lowest set bit + * of a number: x & (-x) + */ + mr->lowbit = mp_new(mp_max_bits(mr->pm1)); + mp_sub_into(mr->lowbit, mr->lowbit, mr->pm1); + mp_and_into(mr->lowbit, mr->lowbit, mr->pm1); mr->two = mp_from_integer(2); - mr->pm1 = mp_unsafe_copy(p); - mp_sub_integer_into(mr->pm1, mr->pm1, 1); - mr->mc = monty_new(p); mr->m_pm1 = monty_import(mr->mc, mr->pm1); @@ -126,10 +127,10 @@ MillerRabin *miller_rabin_new(mp_int *p) void miller_rabin_free(MillerRabin *mr) { - mp_free(mr->q); - mp_free(mr->two); mp_free(mr->pm1); mp_free(mr->m_pm1); + mp_free(mr->lowbit); + mp_free(mr->two); monty_free(mr->mc); smemclr(mr, sizeof(*mr)); sfree(mr); @@ -144,35 +145,93 @@ void miller_rabin_free(MillerRabin *mr) */ static struct mr_result miller_rabin_test_inner(MillerRabin *mr, mp_int *mw) { - /* - * Compute w^q mod p. - */ - mp_int *wqp = monty_pow(mr->mc, mw, mr->q); + mp_int *acc = mp_copy(monty_identity(mr->mc)); + mp_int *spare = mp_new(mp_max_bits(mr->pm1)); + size_t bit = mp_max_bits(mr->pm1); /* - * See if this is 1, or if it is -1, or if it becomes -1 - * when squared at most k-1 times. + * The obvious approach to Miller-Rabin would be to start by + * calling monty_pow to raise w to the power q, and then square it + * k times ourselves. But that introduces a timing leak that gives + * away the value of k, i.e., how many factors of 2 there are in + * p-1. + * + * Instead, we don't call monty_pow at all. We do a modular + * exponentiation ourselves to compute w^((p-1)/2), using the + * technique that works from the top bit of the exponent + * downwards. That is, in each iteration we compute + * w^floor(exponent/2^i) for i one less than the previous + * iteration, by squaring the value we previously had and then + * optionally multiplying in w if the next exponent bit is 1. + * + * At the end of that process, once i <= k, the division + * (exponent/2^i) yields an integer, so the values we're computing + * are not just w^(floor of that), but w^(exactly that). In other + * words, the last k intermediate values of this modexp are + * precisely the values M-R wants to check against +1 or -1. + * + * So we interleave those checks with the modexp loop itself, and + * to avoid a timing leak, we check _every_ intermediate result + * against (the Montgomery representations of) both +1 and -1. And + * then we do bitwise masking to arrange that only the sensible + * ones of those checks find their way into our final answer. */ + + unsigned active = 0; + struct mr_result result; - result.passed = false; - result.potential_primitive_root = false; + result.passed = result.potential_primitive_root = 0; - if (mp_cmp_eq(wqp, monty_identity(mr->mc))) { - result.passed = true; - } else { - for (size_t i = 0; i < mr->k; i++) { - if (mp_cmp_eq(wqp, mr->m_pm1)) { - result.passed = true; - result.potential_primitive_root = (i == mr->k - 1); - break; - } - if (i == mr->k - 1) - break; - monty_mul_into(mr->mc, wqp, wqp, wqp); - } + while (bit-- > 1) { + /* + * In this iteration, we're computing w^(2e) or w^(2e+1), + * where we have w^e from the previous iteration. So we square + * the value we had already, and then optionally multiply in + * another copy of w depending on the next bit of the exponent. + */ + monty_mul_into(mr->mc, acc, acc, acc); + monty_mul_into(mr->mc, spare, acc, mw); + mp_select_into(acc, acc, spare, mp_get_bit(mr->pm1, bit)); + + /* + * mr->lowbit is a number with only one bit set, corresponding + * to the lowest set bit in p-1. So when that's the bit of the + * exponent we've just processed, we'll detect it by setting + * first_iter to true. That's our indication that we're now + * generating intermediate results useful to M-R, so we also + * set 'active', which stays set from then on. + */ + unsigned first_iter = mp_get_bit(mr->lowbit, bit); + active |= first_iter; + + /* + * Check the intermediate result against both +1 and -1. + */ + unsigned is_plus_1 = mp_cmp_eq(acc, monty_identity(mr->mc)); + unsigned is_minus_1 = mp_cmp_eq(acc, mr->m_pm1); + + /* + * M-R must report success iff either: the first of the useful + * intermediate results (which is w^q) is 1, or _any_ of them + * (from w^q all the way up to w^((p-1)/2)) is -1. + * + * So we want to pass the test if is_plus_1 is set on the + * first iteration, or if is_minus_1 is set on any iteration. + */ + result.passed |= (first_iter & is_plus_1); + result.passed |= (active & is_minus_1); + + /* + * In the final iteration, is_minus_1 is also used to set the + * 'potential primitive root' flag, because we haven't found + * any exponent smaller than p-1 for which w^(that) == 1. + */ + if (bit == 1) + result.potential_primitive_root = is_minus_1; } - mp_free(wqp); + mp_free(acc); + mp_free(spare); return result; } diff --git a/sshkeygen.h b/sshkeygen.h index fae6fa83..60b2e836 100644 --- a/sshkeygen.h +++ b/sshkeygen.h @@ -97,8 +97,8 @@ void miller_rabin_free(MillerRabin *mr); /* Perform a single Miller-Rabin test, using a specified witness value. * Used in the test suite. */ struct mr_result { - bool passed; - bool potential_primitive_root; + unsigned passed; + unsigned potential_primitive_root; }; struct mr_result miller_rabin_test(MillerRabin *mr, mp_int *w); diff --git a/test/cryptsuite.py b/test/cryptsuite.py index 2993dbd4..2ab95481 100755 --- a/test/cryptsuite.py +++ b/test/cryptsuite.py @@ -1255,6 +1255,25 @@ class keygen(MyTestBase): assert(pow(2, n-1, n) == 1) # Fermat test would pass, but ... self.assertEqual(miller_rabin_test(mr, 2), "failed") # ... this fails + # A white-box test for the side-channel-safe M-R + # implementation, which has to check a^e against +-1 for every + # exponent e of the form floor((n-1) / power of 2), so as to + # avoid giving away exactly how many of the trailing values of + # that sequence are significant to the test. + # + # When the power of 2 is large enough that the division was + # not exact, the results of these comparisons are _not_ + # significant to the test, and we're required to ignore them! + # + # This pair of values has the property that none of the values + # legitimately computed by M-R is either +1 _or_ -1, but if + # you shift n-1 right by one too many bits (losing the lowest + # set bit of 0x6d00 to get 0x36), then _that_ power of the + # witness integer is -1. This should not cause a spurious pass. + n = 0x6d01 + mr = miller_rabin_new(n) + self.assertEqual(miller_rabin_test(mr, 0x251), "failed") + class crypt(MyTestBase): def testSSH1Fingerprint(self): # Example key and reference fingerprint value generated by