diff --git a/Recipe b/Recipe index d4ef8c94..84c93d8f 100644 --- a/Recipe +++ b/Recipe @@ -282,7 +282,7 @@ UXSSH = SSH uxnoise uxagentc uxgss uxshare SFTP = psftpcommon sftp sftpcommon logging cmdline # Components of the prime-generation system. -SSHPRIME = sshprime smallprimes mpunsafe +SSHPRIME = sshprime smallprimes primecandidate mpunsafe # Miscellaneous objects appearing in all the utilities, or all the # network ones, or the Unix or Windows subsets of those in turn. diff --git a/primecandidate.c b/primecandidate.c new file mode 100644 index 00000000..f884f591 --- /dev/null +++ b/primecandidate.c @@ -0,0 +1,297 @@ +/* + * primecandidate.c: implementation of the PrimeCandidateSource + * abstraction declared in sshkeygen.h. + */ + +#include +#include "ssh.h" +#include "mpint.h" +#include "mpunsafe.h" +#include "sshkeygen.h" + +struct PrimeCandidateSource { + unsigned bits; + bool ready; + + /* We'll start by making up a random number strictly less than this ... */ + mp_int *limit; + + /* ... then we'll multiply by 'factor', and add 'addend'. */ + mp_int *factor, *addend; + + /* Then we'll try to add a small multiple of 'factor' to it to + * avoid it being a multiple of any small prime. Also, for RSA, we + * may need to avoid it being _this_ multiple of _this_: */ + unsigned avoid_residue, avoid_modulus; +}; + +PrimeCandidateSource *pcs_new(unsigned bits, unsigned first, unsigned nfirst) +{ + PrimeCandidateSource *s = snew(PrimeCandidateSource); + + assert(first >> (nfirst-1) == 1); + + s->bits = bits; + s->ready = false; + + /* Make the number that's the lower limit of our range */ + mp_int *firstmp = mp_from_integer(first); + mp_int *base = mp_lshift_fixed(firstmp, bits - nfirst); + mp_free(firstmp); + + /* Set the low bit of that, because all (nontrivial) primes are odd */ + mp_set_bit(base, 0, 1); + + /* That's our addend. Now initialise factor to 2, to ensure we + * only generate odd numbers */ + s->factor = mp_from_integer(2); + s->addend = base; + + /* And that means the limit of our random numbers must be one + * factor of two _less_ than the position of the low bit of + * 'first', because we'll be multiplying the random number by + * 2 immediately afterwards. */ + s->limit = mp_power_2(bits - nfirst - 1); + + /* avoid_modulus == 0 signals that there's no extra residue to avoid */ + s->avoid_residue = 1; + s->avoid_modulus = 0; + + return s; +} + +void pcs_free(PrimeCandidateSource *s) +{ + mp_free(s->limit); + mp_free(s->factor); + mp_free(s->addend); + sfree(s); +} + +static void pcs_require_residue_inner(PrimeCandidateSource *s, + mp_int *mod, mp_int *res) +{ + /* + * We already have a factor and addend. Ensure this one doesn't + * contradict it. + */ + mp_int *gcd = mp_gcd(mod, s->factor); + mp_int *test1 = mp_mod(s->addend, gcd); + mp_int *test2 = mp_mod(res, gcd); + assert(mp_cmp_eq(test1, test2)); + mp_free(test1); + mp_free(test2); + + /* + * Reduce our input factor and addend, which are constraints on + * the ultimate output number, so that they're constraints on the + * initial cofactor we're going to make up. + * + * If we're generating x and we want to ensure ax+b == r (mod m), + * how does that work? We've already checked that b == r modulo g + * = gcd(a,m), i.e. r-b is a multiple of g, and so are a and m. So + * let's write a=gA, m=gM, (r-b)=gR, and then we can start by + * dividing that off: + * + * ax == r-b (mod m ) + * => gAx == gR (mod gM) + * => Ax == R (mod M) + * + * Now the moduli A,M are coprime, which makes things easier. + * + * We're going to need to generate the x in this equation by + * generating a new smaller value y, multiplying it by M, and + * adding some constant K. So we have x = My + K, and we need to + * work out what K will satisfy the above equation. In other + * words, we need A(My+K) == R (mod M), and the AMy term vanishes, + * so we just need AK == R (mod M). So our congruence is solved by + * setting K to be R * A^{-1} mod M. + */ + mp_int *A = mp_div(s->factor, gcd); + mp_int *M = mp_div(mod, gcd); + mp_int *Rpre = mp_modsub(res, s->addend, mod); + mp_int *R = mp_div(Rpre, gcd); + mp_int *Ainv = mp_invert(A, M); + mp_int *K = mp_modmul(R, Ainv, M); + + mp_free(gcd); + mp_free(Rpre); + mp_free(Ainv); + mp_free(A); + mp_free(R); + + /* + * So we know we have to transform our existing (factor, addend) + * pair into (factor * M, addend * factor * K). Now we just need + * to work out what the limit should be on the random value we're + * generating. + * + * If we need My+K < old_limit, then y < (old_limit-K)/M. But the + * RHS is a fraction, so in integers, we need y < ceil of it. + */ + assert(!mp_cmp_hs(K, s->limit)); + mp_int *dividend = mp_add(s->limit, M); + mp_sub_integer_into(dividend, dividend, 1); + mp_sub_into(dividend, dividend, K); + mp_free(s->limit); + s->limit = mp_div(dividend, M); + mp_free(dividend); + + /* + * Now just update the real factor and addend, and we're done. + */ + + mp_int *addend_old = s->addend; + mp_int *tmp = mp_mul(s->factor, K); /* use the _old_ value of factor */ + s->addend = mp_add(s->addend, tmp); + mp_free(tmp); + mp_free(addend_old); + + mp_int *factor_old = s->factor; + s->factor = mp_mul(s->factor, M); + mp_free(factor_old); + + mp_free(M); + mp_free(K); + s->factor = mp_unsafe_shrink(s->factor); + s->addend = mp_unsafe_shrink(s->addend); + s->limit = mp_unsafe_shrink(s->limit); +} + +void pcs_require_residue(PrimeCandidateSource *s, + mp_int *mod, mp_int *res_orig) +{ + /* + * Reduce the input residue to its least non-negative value, in + * case it was given as a larger equivalent value. + */ + mp_int *res_reduced = mp_mod(res_orig, mod); + pcs_require_residue_inner(s, mod, res_reduced); + mp_free(res_reduced); +} + +void pcs_require_residue_1(PrimeCandidateSource *s, mp_int *mod) +{ + mp_int *res = mp_from_integer(1); + pcs_require_residue(s, mod, res); + mp_free(res); +} + +void pcs_avoid_residue_small(PrimeCandidateSource *s, + unsigned mod, unsigned res) +{ + assert(!s->avoid_modulus); /* can't cope with more than one */ + s->avoid_modulus = mod; + s->avoid_residue = res; +} + +void pcs_ready(PrimeCandidateSource *s) +{ + /* + * Reduce the upper limit of the range we're searching, to account + * for the fact that in the generation loop we may add up to 2^16 + * product to the random number we pick from that range. + * + * We can't do this until we've finished dividing limit by things, + * of course. + */ + + assert(mp_hs_integer(s->limit, 0x10001)); + mp_sub_integer_into(s->limit, s->limit, 0x10000); + + s->ready = true; +} + +mp_int *pcs_generate(PrimeCandidateSource *s) +{ + assert(s->ready); + + /* List the (modulus, residue) pairs we want to avoid. Mostly this + * will be 'don't be 0 mod any small prime', but we may have one + * to add from our parameters. */ + init_smallprimes(); + uint64_t avoidmod[NSMALLPRIMES + 1], avoidres[NSMALLPRIMES + 1]; + size_t navoid = 0; + for (size_t i = 0; i < NSMALLPRIMES; i++) { + avoidmod[navoid] = smallprimes[i]; + avoidres[navoid] = 0; + navoid++; + } + if (s->avoid_modulus) { + avoidmod[navoid] = s->avoid_modulus; + avoidres[navoid] = s->avoid_residue % s->avoid_modulus; + navoid++; + } + + while (true) { + mp_int *x = mp_random_upto(s->limit); + + uint64_t xres[NSMALLPRIMES + 1], xmul[NSMALLPRIMES + 1]; + for (size_t i = 0; i < navoid; i++) { + uint64_t mod = avoidmod[i], res = avoidres[i]; + + uint64_t factor_m = mp_unsafe_mod_integer(s->factor, mod); + uint64_t addend_m = mp_unsafe_mod_integer(s->addend, mod); + uint64_t x_m = mp_unsafe_mod_integer(x, mod); + + xmul[i] = factor_m; + xres[i] = (addend_m + x_m * factor_m - res + mod) % mod; + } + + /* + * Try to find a value delta such that x + delta * factor + * avoids all the residues we want to avoid. We select + * candidates at random to avoid a directional bias, and if we + * don't find one quickly enough, give up and try a fresh + * random x. + */ + unsigned delta; + for (unsigned delta_attempts = 0; delta_attempts < 1024 ;) { + unsigned char randbuf[64]; + random_read(randbuf, sizeof(randbuf)); + + for (size_t pos = 0; pos+2 <= sizeof(randbuf); + pos += 2, delta_attempts++) { + + delta = GET_16BIT_MSB_FIRST(randbuf + pos); + + bool ok = true; + for (size_t i = 0; i < navoid; i++) + if (!((xres[i] + delta * xmul[i]) % avoidmod[i])) { + ok = false; + break; + } + + if (ok) + goto found; + } + + smemclr(randbuf, sizeof(randbuf)); + } + + mp_free(x); + continue; /* try a new x */ + + found:; + /* + * We've found a viable delta. Make the final output value. + */ + mp_int *mpdelta = mp_from_integer(delta); + mp_int *xplus = mp_add(x, mpdelta); + mp_int *toret = mp_new(s->bits); + mp_mul_into(toret, xplus, s->factor); + mp_add_into(toret, toret, s->addend); + mp_free(mpdelta); + mp_free(xplus); + mp_free(x); + return toret; + } +} + +void pcs_inspect(PrimeCandidateSource *pcs, mp_int **limit_out, + mp_int **factor_out, mp_int **addend_out) +{ + *limit_out = mp_copy(pcs->limit); + *factor_out = mp_copy(pcs->factor); + *addend_out = mp_copy(pcs->addend); +} diff --git a/sshdssg.c b/sshdssg.c index 2d957734..10f3d72f 100644 --- a/sshdssg.c +++ b/sshdssg.c @@ -56,17 +56,15 @@ int dsa_generate(struct dss_key *key, int bits, progfn_t pfn, pfn(pfnparam, PROGFN_READY, 0, 0); - unsigned pfirst, qfirst; - invent_firstbits(&pfirst, &qfirst, 0); /* * Generate q: a prime of length 160. */ - mp_int *q = primegen(160, 2, 2, NULL, 1, pfn, pfnparam, qfirst); + mp_int *q = primegen(160, 0, 0, NULL, 1, pfn, pfnparam, 1); /* * Now generate p: a prime of length `bits', such that p-1 is * divisible by q. */ - mp_int *p = primegen(bits-160, 2, 2, q, 2, pfn, pfnparam, pfirst); + mp_int *p = primegen(bits, 0, 0, q, 2, pfn, pfnparam, 1); /* * Next we need g. Raise 2 to the power (p-1)/q modulo p, and diff --git a/sshkeygen.h b/sshkeygen.h index 89b1a243..f5af2a10 100644 --- a/sshkeygen.h +++ b/sshkeygen.h @@ -2,7 +2,7 @@ * sshkeygen.h: routines used internally to key generation. */ -/* +/* ---------------------------------------------------------------------- * A table of all the primes that fit in a 16-bit integer. Call * init_primes_array to make sure it's been initialised. */ @@ -10,3 +10,54 @@ #define NSMALLPRIMES 6542 /* number of primes < 65536 */ extern const unsigned short *const smallprimes; void init_smallprimes(void); + +/* ---------------------------------------------------------------------- + * A system for making up random candidate integers during prime + * generation. This unconditionally ensures that the numbers have the + * right number of bits and are not divisible by any prime in the + * smallprimes[] array above. It can also impose further constraints, + * as documented below. + */ +typedef struct PrimeCandidateSource PrimeCandidateSource; + +/* + * pcs_new: you say how many bits you want the prime to have (with the + * usual semantics that an n-bit number is in the range [2^{n-1},2^n)) + * and also specify what you want its topmost 'nfirst' bits to be. + * + * (The 'first' system is used for RSA keys, where you need to arrange + * that the product of your two primes is in a more tightly + * constrained range than the factor of 4 you'd get by just generating + * two (n/2)-bit primes and multiplying them. Any application that + * doesn't need that can simply specify first = nfirst = 1.) + */ +PrimeCandidateSource *pcs_new(unsigned bits, unsigned first, unsigned nfirst); + +/* Insist that generated numbers must be congruent to 'res' mod 'mod' */ +void pcs_require_residue(PrimeCandidateSource *s, mp_int *mod, mp_int *res); + +/* Convenience wrapper for the common case where res = 1 */ +void pcs_require_residue_1(PrimeCandidateSource *s, mp_int *mod); + +/* Insist that generated numbers must _not_ be congruent to 'res' mod + * 'mod'. This is used to avoid being 1 mod the RSA public exponent, + * which is small, so it only needs ordinary integer parameters. */ +void pcs_avoid_residue_small(PrimeCandidateSource *s, + unsigned mod, unsigned res); + +/* Prepare a PrimeCandidateSource to actually generate numbers. This + * function does last-minute computation that has to be delayed until + * all constraints have been input. */ +void pcs_ready(PrimeCandidateSource *s); + +/* Actually generate a candidate integer. You must free the result, of + * course. */ +mp_int *pcs_generate(PrimeCandidateSource *s); + +/* Free a PrimeCandidateSource. */ +void pcs_free(PrimeCandidateSource *s); + +/* Return some internal fields of the PCS. Used by testcrypt for + * unit-testing this system. */ +void pcs_inspect(PrimeCandidateSource *pcs, mp_int **limit_out, + mp_int **factor_out, mp_int **addend_out); diff --git a/sshprime.c b/sshprime.c index f6c02500..79f35c14 100644 --- a/sshprime.c +++ b/sshprime.c @@ -133,107 +133,24 @@ mp_int *primegen( int bits, int modulus, int residue, mp_int *factor, int phase, progfn_t pfn, void *pfnparam, unsigned firstbits) { - init_smallprimes(); - int progress = 0; size_t fbsize = 0; while (firstbits >> fbsize) /* work out how to align this */ fbsize++; + PrimeCandidateSource *pcs = pcs_new(bits, firstbits, fbsize); + if (factor) + pcs_require_residue_1(pcs, factor); + if (modulus) + pcs_avoid_residue_small(pcs, modulus, residue); + pcs_ready(pcs); + STARTOVER: pfn(pfnparam, PROGFN_PROGRESS, phase, ++progress); - /* - * Generate a k-bit random number with top and bottom bits set. - * Alternatively, if `factor' is nonzero, generate a k-bit - * random number with the top bit set and the bottom bit clear, - * multiply it by `factor', and add one. - */ - mp_int *p = mp_power_2(bits - 1); /* ensure top bit is 1 */ - mp_int *r = mp_random_bits(bits - 1); - mp_or_into(p, p, r); - mp_free(r); - mp_set_bit(p, 0, factor ? 0 : 1); /* set bottom bit appropriately */ - - for (size_t i = 0; i < fbsize; i++) - mp_set_bit(p, bits-fbsize + i, 1 & (firstbits >> i)); - - if (factor) { - mp_int *tmp = p; - p = mp_mul(tmp, factor); - mp_free(tmp); - assert(mp_get_bit(p, 0) == 0); - mp_set_bit(p, 0, 1); - } - - /* - * We need to ensure this random number is coprime to the first - * few primes, by repeatedly adding either 2 or 2*factor to it - * until it is. To do this we make a list of (modulus, residue) - * pairs to avoid, and we also add to that list the extra pair our - * caller wants to avoid. - */ - - /* List the moduli */ - unsigned long moduli[NSMALLPRIMES + 1]; - for (size_t i = 0; i < NSMALLPRIMES; i++) - moduli[i] = smallprimes[i]; - moduli[NSMALLPRIMES] = modulus; - - /* Find the residue of our starting number mod each of them. Also - * set up the multipliers array which tells us how each one will - * change when we increment the number (which isn't just 1 if - * we're incrementing by multiples of factor). */ - unsigned long residues[NSMALLPRIMES + 1], multipliers[NSMALLPRIMES + 1]; - for (size_t i = 0; i < lenof(moduli); i++) { - residues[i] = mp_unsafe_mod_integer(p, moduli[i]); - if (factor) - multipliers[i] = mp_unsafe_mod_integer(factor, moduli[i]); - else - multipliers[i] = 1; - } - - /* Adjust the last entry so that it avoids a residue other than zero */ - residues[NSMALLPRIMES] = (residues[NSMALLPRIMES] + modulus - - residue) % modulus; - - /* - * Now loop until no residue in that list is zero, to find a - * sensible increment. We maintain the increment in an ordinary - * integer, so if it gets too big, we'll have to give up and go - * back to making up a fresh random large integer. - */ - unsigned delta = 0; - while (1) { - for (size_t i = 0; i < lenof(moduli); i++) - if (!((residues[i] + delta * multipliers[i]) % moduli[i])) - goto found_a_zero; - - /* If we didn't exit that loop by goto, we've got our candidate. */ - break; - - found_a_zero: - delta += 2; - if (delta > 65536) { - mp_free(p); - goto STARTOVER; - } - } - - /* - * Having found a plausible increment, actually add it on. - */ - if (factor) { - mp_int *d = mp_from_integer(delta); - mp_int *df = mp_mul(d, factor); - mp_add_into(p, p, df); - mp_free(d); - mp_free(df); - } else { - mp_add_integer_into(p, p, delta); - } + mp_int *p = pcs_generate(pcs); /* * Now apply the Miller-Rabin primality test a few times. First diff --git a/test/cryptsuite.py b/test/cryptsuite.py index e6f638d6..1df54117 100755 --- a/test/cryptsuite.py +++ b/test/cryptsuite.py @@ -130,6 +130,9 @@ def mac_str(alg, key, message, cipher=None): ssh2_mac_update(m, message) return ssh2_mac_genresult(m) +def lcm(a, b): + return a * b // gcd(a, b) + class MyTestBase(unittest.TestCase): "Intermediate class that adds useful helper methods." def assertEqualBin(self, x, y): @@ -864,6 +867,78 @@ class ecc(MyTestBase): self.assertEqual(int(x), int(rGi.x)) self.assertEqual(int(y), int(rGi.y)) +class keygen(MyTestBase): + def testPrimeCandidateSource(self): + def inspect(pcs): + # Returns (pcs->limit, pcs->factor, pcs->addend) as Python integers + return tuple(map(int, pcs_inspect(pcs))) + + # Test accumulating modular congruence requirements, by + # inspecting the internal values computed during + # require_residue. We ensure that the addend satisfies all our + # congruences and the factor is the lcm of all the moduli + # (hence, the arithmetic progression defined by those + # parameters is precisely the set of integers satisfying the + # requirements); we also ensure that the limiting values + # (addend itself at the low end, and addend + (limit-1) * + # factor at the high end) are the maximal subsequence of that + # progression that are within the originally specified range. + + def check(pcs, lo, hi, mod_res_pairs): + limit, factor, addend = inspect(pcs) + + for mod, res in mod_res_pairs: + self.assertEqual(addend % mod, res % mod) + + self.assertEqual(factor, functools.reduce( + lcm, [mod for mod, res in mod_res_pairs])) + + self.assertFalse(lo <= addend + (-1) * factor < hi) + self.assertTrue (lo <= addend < hi) + self.assertTrue (lo <= addend + (limit-1) * factor < hi) + self.assertFalse(lo <= addend + limit * factor < hi) + + pcs = pcs_new(64, 1, 1) + check(pcs, 2**63, 2**64, [(2, 1)]) + pcs_require_residue(pcs, 3, 2) + check(pcs, 2**63, 2**64, [(2, 1), (3, 2)]) + pcs_require_residue_1(pcs, 7) + check(pcs, 2**63, 2**64, [(2, 1), (3, 2), (7, 1)]) + pcs_require_residue(pcs, 16, 7) + check(pcs, 2**63, 2**64, [(2, 1), (3, 2), (7, 1), (16, 7)]) + pcs_require_residue(pcs, 49, 8) + check(pcs, 2**63, 2**64, [(2, 1), (3, 2), (7, 1), (16, 7), (49, 8)]) + + # Now test-generate some actual values, and ensure they + # satisfy all the congruences, and also avoid one residue mod + # 5 that we told them to. Also, give a nontrivial range. + pcs = pcs_new(64, 0xAB, 8) + pcs_require_residue(pcs, 0x100, 0xCD) + pcs_require_residue_1(pcs, 65537) + pcs_avoid_residue_small(pcs, 5, 3) + pcs_ready(pcs) + with random_prng("test seed"): + for i in range(100): + n = int(pcs_generate(pcs)) + self.assertTrue((0xAB<<56) < n < (0xAC<<56)) + self.assertEqual(n % 0x100, 0xCD) + self.assertEqual(n % 65537, 1) + self.assertNotEqual(n % 5, 3) + + # I'm not actually testing here that the outputs of + # pcs_generate are non-multiples of _all_ primes up to + # 2^16. But checking this many for 100 turns is enough + # to be pretty sure. (If you take the product of + # (1-1/p) over all p in the list below, you find that + # a given random number has about a 13% chance of + # avoiding being a multiple of any of them. So 100 + # trials without a mistake gives you 0.13^100 < 10^-88 + # as the probability of it happening by chance. More + # likely the code is actually working :-) + + for p in [2,3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61]: + self.assertNotEqual(n % p, 0) + class crypt(MyTestBase): def testSSH1Fingerprint(self): # Example key and reference fingerprint value generated by diff --git a/testcrypt.c b/testcrypt.c index e01ed606..413f5fa1 100644 --- a/testcrypt.c +++ b/testcrypt.c @@ -31,6 +31,7 @@ #include "defs.h" #include "ssh.h" +#include "sshkeygen.h" #include "misc.h" #include "mpint.h" #include "ecc.h" @@ -93,6 +94,7 @@ uint64_t prng_reseed_time_ms(void) X(rsa, RSAKey *, rsa_free(v)) \ X(prng, prng *, prng_free(v)) \ X(keycomponents, key_components *, key_components_free(v)) \ + X(pcs, PrimeCandidateSource *, pcs_free(v)) \ /* end of list */ typedef struct Value Value; diff --git a/testcrypt.h b/testcrypt.h index 4ed190af..9efbb87d 100644 --- a/testcrypt.h +++ b/testcrypt.h @@ -265,6 +265,13 @@ FUNC1(opt_val_key, ecdsa_generate, uint) FUNC1(opt_val_key, eddsa_generate, uint) FUNC1(val_rsa, rsa1_generate, uint) FUNC5(val_mpint, primegen, uint, uint, uint, val_mpint, uint) +FUNC3(val_pcs, pcs_new, uint, uint, uint) +FUNC3(void, pcs_require_residue, val_pcs, val_mpint, val_mpint) +FUNC2(void, pcs_require_residue_1, val_pcs, val_mpint) +FUNC3(void, pcs_avoid_residue_small, val_pcs, uint, uint) +FUNC1(void, pcs_ready, val_pcs) +FUNC4(void, pcs_inspect, val_pcs, out_val_mpint, out_val_mpint, out_val_mpint) +FUNC1(val_mpint, pcs_generate, val_pcs) /* * Miscellaneous.