diff --git a/primecandidate.c b/primecandidate.c index 2b191052..353dc4ad 100644 --- a/primecandidate.c +++ b/primecandidate.c @@ -9,6 +9,10 @@ #include "mpunsafe.h" #include "sshkeygen.h" +struct avoid { + unsigned mod, res; +}; + struct PrimeCandidateSource { unsigned bits; bool ready; @@ -23,6 +27,11 @@ struct PrimeCandidateSource { * avoid it being a multiple of any small prime. Also, for RSA, we * may need to avoid it being _this_ multiple of _this_: */ unsigned avoid_residue, avoid_modulus; + + /* Once we're actually running, this will be the complete list of + * (modulus, residue) pairs we want to avoid. */ + struct avoid *avoids; + size_t navoids, avoidsize; }; PrimeCandidateSource *pcs_new_with_firstbits(unsigned bits, @@ -35,6 +44,9 @@ PrimeCandidateSource *pcs_new_with_firstbits(unsigned bits, s->bits = bits; s->ready = false; + s->avoids = NULL; + s->navoids = s->avoidsize = 0; + /* Make the number that's the lower limit of our range */ mp_int *firstmp = mp_from_integer(first); mp_int *base = mp_lshift_fixed(firstmp, bits - nfirst); @@ -71,6 +83,7 @@ void pcs_free(PrimeCandidateSource *s) mp_free(s->limit); mp_free(s->factor); mp_free(s->addend); + sfree(s->avoids); sfree(s); } @@ -188,22 +201,111 @@ void pcs_avoid_residue_small(PrimeCandidateSource *s, { assert(!s->avoid_modulus); /* can't cope with more than one */ s->avoid_modulus = mod; - s->avoid_residue = res; + s->avoid_residue = res % mod; /* reduce, just in case */ +} + +static int avoid_cmp(const void *av, const void *bv) +{ + const struct avoid *a = (const struct avoid *)av; + const struct avoid *b = (const struct avoid *)bv; + return a->mod < b->mod ? -1 : a->mod > b->mod ? +1 : 0; +} + +static uint64_t invert(uint64_t a, uint64_t m) +{ + int64_t v0 = a, i0 = 1; + int64_t v1 = m, i1 = 0; + while (v0) { + int64_t tmp, q = v1 / v0; + tmp = v0; v0 = v1 - q*v0; v1 = tmp; + tmp = i0; i0 = i1 - q*i0; i1 = tmp; + } + assert(v1 == 1 || v1 == -1); + return i1 * v1; } void pcs_ready(PrimeCandidateSource *s) { /* - * Reduce the upper limit of the range we're searching, to account - * for the fact that in the generation loop we may add up to 2^16 - * product to the random number we pick from that range. - * - * We can't do this until we've finished dividing limit by things, - * of course. + * List all the small (modulus, residue) pairs we want to avoid. */ - assert(mp_hs_integer(s->limit, 0x10001)); - mp_sub_integer_into(s->limit, s->limit, 0x10000); + init_smallprimes(); + +#define ADD_AVOID(newmod, newres) do { \ + sgrowarray(s->avoids, s->avoidsize, s->navoids); \ + s->avoids[s->navoids].mod = (newmod); \ + s->avoids[s->navoids].res = (newres); \ + s->navoids++; \ + } while (0) + + unsigned limit = (mp_hs_integer(s->addend, 65536) ? 65536 : + mp_get_integer(s->addend)); + + /* + * Don't be divisible by any small prime, or at least, any prime + * smaller than our output number might actually manage to be. (If + * asked to generate a really small prime, it would be + * embarrassing to rule out legitimate answers on the grounds that + * they were divisible by themselves.) + */ + for (size_t i = 0; i < NSMALLPRIMES && smallprimes[i] < limit; i++) + ADD_AVOID(smallprimes[i], 0); + + /* + * Finally, if there's a particular modulus and residue we've been + * told to avoid, put it on the list. + */ + if (s->avoid_modulus) + ADD_AVOID(s->avoid_modulus, s->avoid_residue); + +#undef ADD_AVOID + + /* + * Sort our to-avoid list by modulus. Partly this is so that we'll + * check the smaller moduli first during the live runs, which lets + * us spot most failing cases earlier rather than later. Also, it + * brings equal moduli together, so that we can reuse the residue + * we computed from a previous one. + */ + qsort(s->avoids, s->navoids, sizeof(*s->avoids), avoid_cmp); + + /* + * Next, adjust each of these moduli to take account of our factor + * and addend. If we want factor*x+addend to avoid being congruent + * to 'res' modulo 'mod', then x itself must avoid being congruent + * to (res - addend) * factor^{-1}. + * + * If factor == 0 modulo mod, then the answer will have a fixed + * residue anyway, so we can discard it from our list to test. + */ + int64_t factor_m = 0, addend_m = 0, last_mod = 0; + + size_t out = 0; + for (size_t i = 0; i < s->navoids; i++) { + int64_t mod = s->avoids[i].mod, res = s->avoids[i].res; + if (mod != last_mod) { + last_mod = mod; + addend_m = mp_unsafe_mod_integer(s->addend, mod); + factor_m = mp_unsafe_mod_integer(s->factor, mod); + } + + if (factor_m == 0) { + assert(res != addend_m); + continue; + } + + res = (res - addend_m) * invert(factor_m, mod); + res %= mod; + if (res < 0) + res += mod; + + s->avoids[out].mod = mod; + s->avoids[out].res = res; + out++; + } + + s->navoids = out; s->ready = true; } @@ -212,83 +314,37 @@ mp_int *pcs_generate(PrimeCandidateSource *s) { assert(s->ready); - /* List the (modulus, residue) pairs we want to avoid. Mostly this - * will be 'don't be 0 mod any small prime', but we may have one - * to add from our parameters. */ - init_smallprimes(); - uint64_t avoidmod[NSMALLPRIMES + 1], avoidres[NSMALLPRIMES + 1]; - size_t navoid = 0; - for (size_t i = 0; i < NSMALLPRIMES; i++) { - avoidmod[navoid] = smallprimes[i]; - avoidres[navoid] = 0; - navoid++; - } - if (s->avoid_modulus) { - avoidmod[navoid] = s->avoid_modulus; - avoidres[navoid] = s->avoid_residue % s->avoid_modulus; - navoid++; - } - while (true) { mp_int *x = mp_random_upto(s->limit); - uint64_t xres[NSMALLPRIMES + 1], xmul[NSMALLPRIMES + 1]; - for (size_t i = 0; i < navoid; i++) { - uint64_t mod = avoidmod[i], res = avoidres[i]; + int64_t x_res = 0, last_mod = 0; + bool ok = true; - uint64_t factor_m = mp_unsafe_mod_integer(s->factor, mod); - uint64_t addend_m = mp_unsafe_mod_integer(s->addend, mod); - uint64_t x_m = mp_unsafe_mod_integer(x, mod); + for (size_t i = 0; i < s->navoids; i++) { + int64_t mod = s->avoids[i].mod, avoid_res = s->avoids[i].res; - xmul[i] = factor_m; - xres[i] = (addend_m + x_m * factor_m - res + mod) % mod; - } - - /* - * Try to find a value delta such that x + delta * factor - * avoids all the residues we want to avoid. We select - * candidates at random to avoid a directional bias, and if we - * don't find one quickly enough, give up and try a fresh - * random x. - */ - unsigned delta; - for (unsigned delta_attempts = 0; delta_attempts < 1024 ;) { - unsigned char randbuf[64]; - random_read(randbuf, sizeof(randbuf)); - - for (size_t pos = 0; pos+2 <= sizeof(randbuf); - pos += 2, delta_attempts++) { - - delta = GET_16BIT_MSB_FIRST(randbuf + pos); - - bool ok = true; - for (size_t i = 0; i < navoid; i++) - if (!((xres[i] + delta * xmul[i]) % avoidmod[i])) { - ok = false; - break; - } - - if (ok) - goto found; + if (mod != last_mod) { + last_mod = mod; + x_res = mp_unsafe_mod_integer(x, mod); } - smemclr(randbuf, sizeof(randbuf)); + if (x_res == avoid_res) { + ok = false; + break; + } } - mp_free(x); - continue; /* try a new x */ + if (!ok) { + mp_free(x); + continue; /* try a new x */ + } - found:; /* - * We've found a viable delta. Make the final output value. + * We've found a viable x. Make the final output value. */ - mp_int *mpdelta = mp_from_integer(delta); - mp_int *xplus = mp_add(x, mpdelta); mp_int *toret = mp_new(s->bits); - mp_mul_into(toret, xplus, s->factor); + mp_mul_into(toret, x, s->factor); mp_add_into(toret, toret, s->addend); - mp_free(mpdelta); - mp_free(xplus); mp_free(x); return toret; }