From ece788240c3fed2a77d42a7783907fb8a29640e1 Mon Sep 17 00:00:00 2001 From: Simon Tatham Date: Sat, 29 Feb 2020 09:10:47 +0000 Subject: [PATCH] Introduce a vtable system for prime generation. The functions primegen() and primegen_add_progress_phase() are gone. In their place is a small vtable system with two methods corresponding to them, plus the usual admin of allocating and freeing contexts. This API change is the starting point for being able to drop in different prime generation algorithms at run time in response to user configuration. --- cmdgen.c | 9 ++++-- ssh1login-server.c | 7 ++++- ssh2kex-server.c | 13 +++++++-- sshdssg.c | 11 ++++---- sshkeygen.h | 44 ++++++++++++++++++++++++----- sshprime.c | 33 ++++++++++++++++++---- sshrsag.c | 17 +++++------ test/testcrypt.py | 2 +- testcrypt.c | 70 +++++++++++++++++++++++++++++++--------------- testcrypt.h | 9 +++--- windows/winpgen.c | 9 ++++-- 11 files changed, 163 insertions(+), 61 deletions(-) diff --git a/cmdgen.c b/cmdgen.c index 6d06c443..d48b3166 100644 --- a/cmdgen.c +++ b/cmdgen.c @@ -668,9 +668,12 @@ int main(int argc, char **argv) smemclr(entropy, bits/8); sfree(entropy); + PrimeGenerationContext *pgc = primegen_new_context( + &primegen_probabilistic); + if (keytype == DSA) { struct dss_key *dsskey = snew(struct dss_key); - dsa_generate(dsskey, bits, &cmdgen_progress); + dsa_generate(dsskey, bits, pgc, &cmdgen_progress); ssh2key = snew(ssh2_userkey); ssh2key->key = &dsskey->sshk; ssh1key = NULL; @@ -688,7 +691,7 @@ int main(int argc, char **argv) ssh1key = NULL; } else { RSAKey *rsakey = snew(RSAKey); - rsa_generate(rsakey, bits, &cmdgen_progress); + rsa_generate(rsakey, bits, pgc, &cmdgen_progress); rsakey->comment = NULL; if (keytype == RSA1) { ssh1key = rsakey; @@ -698,6 +701,8 @@ int main(int argc, char **argv) } } + primegen_free_context(pgc); + if (ssh2key) ssh2key->comment = dupstr(default_comment); if (ssh1key) diff --git a/ssh1login-server.c b/ssh1login-server.c index 528c9d09..bd8de7f9 100644 --- a/ssh1login-server.c +++ b/ssh1login-server.c @@ -140,9 +140,14 @@ static void ssh1_login_server_process_queue(PacketProtocolLayer *ppl) if (server_key_bits < 512) server_key_bits = s->hostkey->bytes + 256; s->servkey = snew(RSAKey); + + PrimeGenerationContext *pgc = primegen_new_context( + &primegen_probabilistic); ProgressReceiver null_progress; null_progress.vt = &null_progress_vt; - rsa_generate(s->servkey, server_key_bits, &null_progress); + rsa_generate(s->servkey, server_key_bits, pgc, &null_progress); + primegen_free_context(pgc); + s->servkey->comment = NULL; s->servkey_generated_here = true; } diff --git a/ssh2kex-server.c b/ssh2kex-server.c index d2aef99d..8a8c0f3a 100644 --- a/ssh2kex-server.c +++ b/ssh2kex-server.c @@ -97,9 +97,13 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) * group! It's good enough for testing a client against, * but not for serious use. */ + PrimeGenerationContext *pgc = primegen_new_context( + &primegen_probabilistic); ProgressReceiver null_progress; null_progress.vt = &null_progress_vt; - s->p = primegen(pcs_new(s->pbits), &null_progress); + s->p = primegen_generate(pgc, pcs_new(s->pbits), &null_progress); + primegen_free_context(pgc); + s->g = mp_from_integer(2); s->dh_ctx = dh_setup_gex(s->p, s->g); s->kex_init_value = SSH2_MSG_KEX_DH_GEX_INIT; @@ -261,9 +265,14 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) ppl_logevent("Generating a %d-bit RSA key", extra->minklen); s->rsa_kex_key = snew(RSAKey); + + PrimeGenerationContext *pgc = primegen_new_context( + &primegen_probabilistic); ProgressReceiver null_progress; null_progress.vt = &null_progress_vt; - rsa_generate(s->rsa_kex_key, extra->minklen, &null_progress); + rsa_generate(s->rsa_kex_key, extra->minklen, pgc, &null_progress); + primegen_free_context(pgc); + s->rsa_kex_key->comment = NULL; s->rsa_kex_key_needs_freeing = true; } diff --git a/sshdssg.c b/sshdssg.c index eb369a7a..ad3fcc02 100644 --- a/sshdssg.c +++ b/sshdssg.c @@ -7,7 +7,8 @@ #include "sshkeygen.h" #include "mpint.h" -int dsa_generate(struct dss_key *key, int bits, ProgressReceiver *prog) +int dsa_generate(struct dss_key *key, int bits, PrimeGenerationContext *pgc, + ProgressReceiver *prog) { /* * Progress-reporting setup. @@ -26,8 +27,8 @@ int dsa_generate(struct dss_key *key, int bits, ProgressReceiver *prog) * (So the probability of success will end up indistinguishable * from 1 in IEEE standard floating point! But what can you do.) */ - ProgressPhase phase_q = primegen_add_progress_phase(prog, 160); - ProgressPhase phase_p = primegen_add_progress_phase(prog, bits); + ProgressPhase phase_q = primegen_add_progress_phase(pgc, prog, 160); + ProgressPhase phase_p = primegen_add_progress_phase(pgc, prog, bits); double g_failure_probability = 1.0 / (double)(1ULL << 53) / (double)(1ULL << 53) @@ -43,7 +44,7 @@ int dsa_generate(struct dss_key *key, int bits, ProgressReceiver *prog) */ progress_start_phase(prog, phase_q); pcs = pcs_new(160); - mp_int *q = primegen(pcs, prog); + mp_int *q = primegen_generate(pgc, pcs, prog); progress_report_phase_complete(prog); /* @@ -53,7 +54,7 @@ int dsa_generate(struct dss_key *key, int bits, ProgressReceiver *prog) progress_start_phase(prog, phase_p); pcs = pcs_new(bits); pcs_require_residue_1(pcs, q); - mp_int *p = primegen(pcs, prog); + mp_int *p = primegen_generate(pgc, pcs, prog); progress_report_phase_complete(prog); /* diff --git a/sshkeygen.h b/sshkeygen.h index f82d0dc9..2565302a 100644 --- a/sshkeygen.h +++ b/sshkeygen.h @@ -148,18 +148,48 @@ double estimate_modexp_cost(unsigned bits); * The top-level API for generating primes. */ -/* This function consumes and frees the PrimeCandidateSource you give it */ -mp_int *primegen(PrimeCandidateSource *pcs, ProgressReceiver *prog); +typedef struct PrimeGenerationPolicy PrimeGenerationPolicy; +typedef struct PrimeGenerationContext PrimeGenerationContext; -/* Estimate how long it will take, and add a phase to a ProgressReceiver */ -ProgressPhase primegen_add_progress_phase(ProgressReceiver *prog, - unsigned bits); +struct PrimeGenerationContext { + const PrimeGenerationPolicy *vt; +}; + +struct PrimeGenerationPolicy { + ProgressPhase (*add_progress_phase)(const PrimeGenerationPolicy *policy, + ProgressReceiver *prog, unsigned bits); + PrimeGenerationContext *(*new_context)( + const PrimeGenerationPolicy *policy); + void (*free_context)(PrimeGenerationContext *ctx); + mp_int *(*generate)( + PrimeGenerationContext *ctx, + PrimeCandidateSource *pcs, ProgressReceiver *prog); + + const void *extra; /* additional data a particular impl might need */ +}; + +static inline ProgressPhase primegen_add_progress_phase( + PrimeGenerationContext *ctx, ProgressReceiver *prog, unsigned bits) +{ return ctx->vt->add_progress_phase(ctx->vt, prog, bits); } +static inline PrimeGenerationContext *primegen_new_context( + const PrimeGenerationPolicy *policy) +{ return policy->new_context(policy); } +static inline void primegen_free_context(PrimeGenerationContext *ctx) +{ ctx->vt->free_context(ctx); } +static inline mp_int *primegen_generate( + PrimeGenerationContext *ctx, + PrimeCandidateSource *pcs, ProgressReceiver *prog) +{ return ctx->vt->generate(ctx, pcs, prog); } + +extern const PrimeGenerationPolicy primegen_probabilistic; /* ---------------------------------------------------------------------- * The overall top-level API for generating entire key pairs. */ -int rsa_generate(RSAKey *key, int bits, ProgressReceiver *prog); -int dsa_generate(struct dss_key *key, int bits, ProgressReceiver *prog); +int rsa_generate(RSAKey *key, int bits, PrimeGenerationContext *pgc, + ProgressReceiver *prog); +int dsa_generate(struct dss_key *key, int bits, PrimeGenerationContext *pgc, + ProgressReceiver *prog); int ecdsa_generate(struct ecdsa_key *key, int bits); int eddsa_generate(struct eddsa_key *key, int bits); diff --git a/sshprime.c b/sshprime.c index 9220d4e6..0ee435e9 100644 --- a/sshprime.c +++ b/sshprime.c @@ -10,9 +10,8 @@ #include "mpunsafe.h" #include "sshkeygen.h" -/* - * This prime generation algorithm is pretty much cribbed from - * OpenSSL. The algorithm is: +/* ---------------------------------------------------------------------- + * Standard probabilistic prime-generation algorithm: * * - invent a B-bit random number and ensure the top and bottom * bits are set (so it's definitely B-bit, and it's definitely @@ -28,8 +27,22 @@ * - go back to square one if any M-R test fails. */ -ProgressPhase primegen_add_progress_phase(ProgressReceiver *prog, - unsigned bits) +static PrimeGenerationContext *probprime_new_context( + const PrimeGenerationPolicy *policy) +{ + PrimeGenerationContext *ctx = snew(PrimeGenerationContext); + ctx->vt = policy; + return ctx; +} + +static void probprime_free_context(PrimeGenerationContext *ctx) +{ + sfree(ctx); +} + +static ProgressPhase probprime_add_progress_phase( + const PrimeGenerationPolicy *policy, + ProgressReceiver *prog, unsigned bits) { /* * The density of primes near x is 1/(log x). When x is about 2^b, @@ -56,7 +69,9 @@ ProgressPhase primegen_add_progress_phase(ProgressReceiver *prog, return progress_add_probabilistic(prog, cost, prob); } -mp_int *primegen(PrimeCandidateSource *pcs, ProgressReceiver *prog) +static mp_int *probprime_generate( + PrimeGenerationContext *ctx, + PrimeCandidateSource *pcs, ProgressReceiver *prog) { pcs_ready(pcs); @@ -88,6 +103,12 @@ mp_int *primegen(PrimeCandidateSource *pcs, ProgressReceiver *prog) } } +const PrimeGenerationPolicy primegen_probabilistic = { + probprime_add_progress_phase, + probprime_new_context, + probprime_free_context, + probprime_generate, +}; /* ---------------------------------------------------------------------- * Reusable null implementation of the progress-reporting API. diff --git a/sshrsag.c b/sshrsag.c index 91dbcdb1..278c9848 100644 --- a/sshrsag.c +++ b/sshrsag.c @@ -14,10 +14,9 @@ static void invent_firstbits(unsigned *one, unsigned *two, unsigned min_separation); -int rsa_generate(RSAKey *key, int bits, ProgressReceiver *prog) +int rsa_generate(RSAKey *key, int bits, PrimeGenerationContext *pgc, + ProgressReceiver *prog) { - unsigned pfirst, qfirst; - key->sshk.vt = &ssh_rsa; /* @@ -38,27 +37,29 @@ int rsa_generate(RSAKey *key, int bits, ProgressReceiver *prog) * more so than an attacker guessing a whole 256-bit session key - * but it doesn't cost much to make sure.) */ - invent_firstbits(&pfirst, &qfirst, 2); int qbits = bits / 2; int pbits = bits - qbits; assert(pbits >= qbits); - ProgressPhase phase_p = primegen_add_progress_phase(prog, pbits); - ProgressPhase phase_q = primegen_add_progress_phase(prog, qbits); + ProgressPhase phase_p = primegen_add_progress_phase(pgc, prog, pbits); + ProgressPhase phase_q = primegen_add_progress_phase(pgc, prog, qbits); progress_ready(prog); + unsigned pfirst, qfirst; + invent_firstbits(&pfirst, &qfirst, 2); + PrimeCandidateSource *pcs; progress_start_phase(prog, phase_p); pcs = pcs_new_with_firstbits(pbits, pfirst, NFIRSTBITS); pcs_avoid_residue_small(pcs, RSA_EXPONENT, 1); - mp_int *p = primegen(pcs, prog); + mp_int *p = primegen_generate(pgc, pcs, prog); progress_report_phase_complete(prog); progress_start_phase(prog, phase_q); pcs = pcs_new_with_firstbits(qbits, qfirst, NFIRSTBITS); pcs_avoid_residue_small(pcs, RSA_EXPONENT, 1); - mp_int *q = primegen(pcs, prog); + mp_int *q = primegen_generate(pgc, pcs, prog); progress_report_phase_complete(prog); /* diff --git a/test/testcrypt.py b/test/testcrypt.py index 7412fcc7..de27ee38 100644 --- a/test/testcrypt.py +++ b/test/testcrypt.py @@ -177,7 +177,7 @@ def make_argword(arg, argtype, fnname, argindex, to_preserve): return "0x{:x}".format(arg) if typename in { "hashalg", "macalg", "keyalg", "cipheralg", - "dh_group", "ecdh_alg", "rsaorder"}: + "dh_group", "ecdh_alg", "rsaorder", "primegenpolicy"}: arg = unicode_to_bytes(arg) if isinstance(arg, bytes) and b" " not in arg: return arg diff --git a/testcrypt.c b/testcrypt.c index 61204bd7..c2c1e3cd 100644 --- a/testcrypt.c +++ b/testcrypt.c @@ -95,6 +95,7 @@ uint64_t prng_reseed_time_ms(void) X(prng, prng *, prng_free(v)) \ X(keycomponents, key_components *, key_components_free(v)) \ X(pcs, PrimeCandidateSource *, pcs_free(v)) \ + X(pgc, PrimeGenerationContext *, primegen_free_context(v)) \ /* end of list */ typedef struct Value Value; @@ -113,6 +114,12 @@ static const char *const type_names[] = { #undef VALTYPE_NAME }; +#define VALTYPE_TYPEDEF(n,t,f) \ + typedef t TD_val_##n; \ + typedef t *TD_out_val_##n; +VALUE_TYPES(VALTYPE_TYPEDEF) +#undef VALTYPE_TYPEDEF + struct Value { /* * Protocol identifier assigned to this value when it was created. @@ -363,6 +370,23 @@ static RsaSsh1Order get_rsaorder(BinarySource *in) fatal_error("rsaorder '%.*s': not found", PTRLEN_PRINTF(name)); } +static const PrimeGenerationPolicy *get_primegenpolicy(BinarySource *in) +{ + static const struct { + const char *key; + const PrimeGenerationPolicy *value; + } algs[] = { + {"probabilistic", &primegen_probabilistic}, + }; + + ptrlen name = get_word(in); + for (size_t i = 0; i < lenof(algs); i++) + if (ptrlen_eq_string(name, algs[i].key)) + return algs[i].value; + + fatal_error("primegenpolicy '%.*s': not found", PTRLEN_PRINTF(name)); +} + static uintmax_t get_uint(BinarySource *in) { ptrlen word = get_word(in); @@ -536,14 +560,18 @@ static BinarySource *get_val_string_binarysource(BinarySource *in) return src; } -static ssh_hash *get_consumed_val_hash(BinarySource *in) -{ - Value *val = get_value_hash(in); - ssh_hash *toret = val->vu_hash; - del234(values, val); - sfree(val); - return toret; -} +#define GET_CONSUMED_FN(type) \ + typedef TD_val_##type TD_consumed_val_##type; \ + static TD_val_##type get_consumed_val_##type(BinarySource *in) \ + { \ + Value *val = get_value_##type(in); \ + TD_val_##type toret = val->vu_##type; \ + del234(values, val); \ + sfree(val); \ + return toret; \ + } +GET_CONSUMED_FN(hash) +GET_CONSUMED_FN(pcs) static void return_int(strbuf *out, intmax_t u) { @@ -1049,30 +1077,31 @@ strbuf *rsa1_save_sb_wrapper(RSAKey *key, const char *comment, static ProgressReceiver null_progress = { .vt = &null_progress_vt }; -mp_int *primegen_wrapper(PrimeCandidateSource *pcs) +mp_int *primegen_generate_wrapper( + PrimeGenerationContext *ctx, PrimeCandidateSource *pcs) { - return primegen(pcs, &null_progress); + return primegen_generate(ctx, pcs, &null_progress); } -#define primegen primegen_wrapper +#define primegen_generate primegen_generate_wrapper -RSAKey *rsa1_generate(int bits) +RSAKey *rsa1_generate(int bits, PrimeGenerationContext *pgc) { RSAKey *rsakey = snew(RSAKey); - rsa_generate(rsakey, bits, &null_progress); + rsa_generate(rsakey, bits, pgc, &null_progress); rsakey->comment = NULL; return rsakey; } -ssh_key *rsa_generate_wrapper(int bits) +ssh_key *rsa_generate_wrapper(int bits, PrimeGenerationContext *pgc) { - return &rsa1_generate(bits)->sshk; + return &rsa1_generate(bits, pgc)->sshk; } #define rsa_generate rsa_generate_wrapper -ssh_key *dsa_generate_wrapper(int bits) +ssh_key *dsa_generate_wrapper(int bits, PrimeGenerationContext *pgc) { struct dss_key *dsskey = snew(struct dss_key); - dsa_generate(dsskey, bits, &null_progress); + dsa_generate(dsskey, bits, pgc, &null_progress); return &dsskey->sshk; } #define dsa_generate dsa_generate_wrapper @@ -1118,12 +1147,6 @@ mp_int *key_components_nth_mp(key_components *kc, size_t n) mp_copy(kc->components[n].mp)); } -#define VALTYPE_TYPEDEF(n,t,f) \ - typedef t TD_val_##n; \ - typedef t *TD_out_val_##n; -VALUE_TYPES(VALTYPE_TYPEDEF) -#undef VALTYPE_TYPEDEF - #define OPTIONAL_PTR_FUNC(type) \ typedef TD_val_##type TD_opt_val_##type; \ static TD_opt_val_##type get_opt_val_##type(BinarySource *in) { \ @@ -1155,6 +1178,7 @@ typedef const ssh_kex *TD_dh_group; typedef const ssh_kex *TD_ecdh_alg; typedef RsaSsh1Order TD_rsaorder; typedef key_components *TD_keycomponents; +typedef const PrimeGenerationPolicy *TD_primegenpolicy; #define FUNC0(rettype, function) \ static void handle_##function(BinarySource *in, strbuf *out) { \ diff --git a/testcrypt.h b/testcrypt.h index 4469b391..617f746b 100644 --- a/testcrypt.h +++ b/testcrypt.h @@ -260,12 +260,13 @@ FUNC3(val_string, rsa1_save_sb, val_rsa, opt_val_string_asciz, opt_val_string_as /* * Key generation functions. */ -FUNC1(val_key, rsa_generate, uint) -FUNC1(val_key, dsa_generate, uint) +FUNC2(val_key, rsa_generate, uint, val_pgc) +FUNC2(val_key, dsa_generate, uint, val_pgc) FUNC1(opt_val_key, ecdsa_generate, uint) FUNC1(opt_val_key, eddsa_generate, uint) -FUNC1(val_rsa, rsa1_generate, uint) -FUNC1(val_mpint, primegen, val_pcs) +FUNC2(val_rsa, rsa1_generate, uint, val_pgc) +FUNC1(val_pgc, primegen_new_context, primegenpolicy) +FUNC2(opt_val_mpint, primegen_generate, val_pgc, consumed_val_pcs) FUNC1(val_pcs, pcs_new, uint) FUNC3(val_pcs, pcs_new_with_firstbits, uint, uint, uint) FUNC3(void, pcs_require_residue, val_pcs, val_mpint, val_mpint) diff --git a/windows/winpgen.c b/windows/winpgen.c index d1ec8d0b..435500a6 100644 --- a/windows/winpgen.c +++ b/windows/winpgen.c @@ -386,14 +386,19 @@ static DWORD WINAPI generate_key_thread(void *param) win_progress_initialise(&prog); + PrimeGenerationContext *pgc = primegen_new_context( + &primegen_probabilistic); + if (params->keytype == DSA) - dsa_generate(params->dsskey, params->key_bits, &prog.rec); + dsa_generate(params->dsskey, params->key_bits, pgc, &prog.rec); else if (params->keytype == ECDSA) ecdsa_generate(params->eckey, params->curve_bits); else if (params->keytype == ED25519) eddsa_generate(params->edkey, 255); else - rsa_generate(params->key, params->key_bits, &prog.rec); + rsa_generate(params->key, params->key_bits, pgc, &prog.rec); + + primegen_free_context(pgc); PostMessage(params->dialog, WM_DONEKEY, 0, 0);