diff --git a/crypto/ecc-ssh.c b/crypto/ecc-ssh.c index b2eb6e51..e57e8bb1 100644 --- a/crypto/ecc-ssh.c +++ b/crypto/ecc-ssh.c @@ -1367,121 +1367,131 @@ const ssh_keyalg ssh_ecdsa_nistp521 = { }; /* ---------------------------------------------------------------------- - * Exposed ECDH interface + * Exposed ECDH interfaces */ struct eckex_extra { struct ec_curve *(*curve)(void); - void (*setup)(ecdh_key *dh); - void (*cleanup)(ecdh_key *dh); - void (*getpublic)(ecdh_key *dh, BinarySink *bs); - mp_int *(*getkey)(ecdh_key *dh, ptrlen remoteKey); }; -struct ecdh_key { +typedef struct ecdh_key_w { const struct eckex_extra *extra; const struct ec_curve *curve; mp_int *private; - union { - WeierstrassPoint *w_public; - MontgomeryPoint *m_public; - }; -}; + WeierstrassPoint *w_public; -const char *ssh_ecdhkex_curve_textname(const ssh_kex *kex) -{ - const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra; - struct ec_curve *curve = extra->curve(); - return curve->textname; -} + ecdh_key ek; +} ecdh_key_w; -static void ssh_ecdhkex_w_setup(ecdh_key *dh) -{ - mp_int *one = mp_from_integer(1); - dh->private = mp_random_in_range(one, dh->curve->w.G_order); - mp_free(one); +typedef struct ecdh_key_m { + const struct eckex_extra *extra; + const struct ec_curve *curve; + mp_int *private; + MontgomeryPoint *m_public; - dh->w_public = ecc_weierstrass_multiply(dh->curve->w.G, dh->private); -} + ecdh_key ek; +} ecdh_key_m; -static void ssh_ecdhkex_m_setup(ecdh_key *dh) -{ - strbuf *bytes = strbuf_new_nm(); - random_read(strbuf_append(bytes, dh->curve->fieldBytes), - dh->curve->fieldBytes); - - dh->private = mp_from_bytes_le(ptrlen_from_strbuf(bytes)); - - /* Ensure the private key has the highest valid bit set, and no - * bits _above_ the highest valid one */ - mp_reduce_mod_2to(dh->private, dh->curve->fieldBits); - mp_set_bit(dh->private, dh->curve->fieldBits - 1, 1); - - /* Clear a curve-specific number of low bits */ - for (unsigned bit = 0; bit < dh->curve->m.log2_cofactor; bit++) - mp_set_bit(dh->private, bit, 0); - - strbuf_free(bytes); - - dh->m_public = ecc_montgomery_multiply(dh->curve->m.G, dh->private); -} - -ecdh_key *ssh_ecdhkex_newkey(const ssh_kex *kex) +ecdh_key *ssh_ecdhkex_w_new(const ssh_kex *kex, bool is_server) { const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra; const struct ec_curve *curve = extra->curve(); - ecdh_key *dh = snew(ecdh_key); - dh->extra = extra; - dh->curve = curve; - dh->extra->setup(dh); - return dh; + ecdh_key_w *dhw = snew(ecdh_key_w); + dhw->ek.vt = kex->ecdh_vt; + dhw->extra = extra; + dhw->curve = curve; + + mp_int *one = mp_from_integer(1); + dhw->private = mp_random_in_range(one, dhw->curve->w.G_order); + mp_free(one); + + dhw->w_public = ecc_weierstrass_multiply(dhw->curve->w.G, dhw->private); + + return &dhw->ek; +} + +ecdh_key *ssh_ecdhkex_m_new(const ssh_kex *kex, bool is_server) +{ + const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra; + const struct ec_curve *curve = extra->curve(); + + ecdh_key_m *dhm = snew(ecdh_key_m); + dhm->ek.vt = kex->ecdh_vt; + dhm->extra = extra; + dhm->curve = curve; + + strbuf *bytes = strbuf_new_nm(); + random_read(strbuf_append(bytes, dhm->curve->fieldBytes), + dhm->curve->fieldBytes); + + dhm->private = mp_from_bytes_le(ptrlen_from_strbuf(bytes)); + + /* Ensure the private key has the highest valid bit set, and no + * bits _above_ the highest valid one */ + mp_reduce_mod_2to(dhm->private, dhm->curve->fieldBits); + mp_set_bit(dhm->private, dhm->curve->fieldBits - 1, 1); + + /* Clear a curve-specific number of low bits */ + for (unsigned bit = 0; bit < dhm->curve->m.log2_cofactor; bit++) + mp_set_bit(dhm->private, bit, 0); + + strbuf_free(bytes); + + dhm->m_public = ecc_montgomery_multiply(dhm->curve->m.G, dhm->private); + + return &dhm->ek; } static void ssh_ecdhkex_w_getpublic(ecdh_key *dh, BinarySink *bs) { - put_wpoint(bs, dh->w_public, dh->curve, true); + ecdh_key_w *dhw = container_of(dh, ecdh_key_w, ek); + put_wpoint(bs, dhw->w_public, dhw->curve, true); } static void ssh_ecdhkex_m_getpublic(ecdh_key *dh, BinarySink *bs) { + ecdh_key_m *dhm = container_of(dh, ecdh_key_m, ek); mp_int *x; - ecc_montgomery_get_affine(dh->m_public, &x); - for (size_t i = 0; i < dh->curve->fieldBytes; ++i) + ecc_montgomery_get_affine(dhm->m_public, &x); + for (size_t i = 0; i < dhm->curve->fieldBytes; ++i) put_byte(bs, mp_get_byte(x, i)); mp_free(x); } -void ssh_ecdhkex_getpublic(ecdh_key *dh, BinarySink *bs) +static bool ssh_ecdhkex_w_getkey(ecdh_key *dh, ptrlen remoteKey, + BinarySink *bs) { - dh->extra->getpublic(dh, bs); -} + ecdh_key_w *dhw = container_of(dh, ecdh_key_w, ek); -static mp_int *ssh_ecdhkex_w_getkey(ecdh_key *dh, ptrlen remoteKey) -{ - WeierstrassPoint *remote_p = ecdsa_decode(remoteKey, dh->curve); + WeierstrassPoint *remote_p = ecdsa_decode(remoteKey, dhw->curve); if (!remote_p) - return NULL; + return false; if (ecc_weierstrass_is_identity(remote_p)) { /* Not a sensible Diffie-Hellman input value */ ecc_weierstrass_point_free(remote_p); - return NULL; + return false; } - WeierstrassPoint *p = ecc_weierstrass_multiply(remote_p, dh->private); + WeierstrassPoint *p = ecc_weierstrass_multiply(remote_p, dhw->private); mp_int *x; ecc_weierstrass_get_affine(p, &x, NULL); + put_mp_ssh2(bs, x); + mp_free(x); ecc_weierstrass_point_free(remote_p); ecc_weierstrass_point_free(p); - return x; + return true; } -static mp_int *ssh_ecdhkex_m_getkey(ecdh_key *dh, ptrlen remoteKey) +static bool ssh_ecdhkex_m_getkey(ecdh_key *dh, ptrlen remoteKey, + BinarySink *bs) { + ecdh_key_m *dhm = container_of(dh, ecdh_key_m, ek); + mp_int *remote_x = mp_from_bytes_le(remoteKey); /* Per RFC 7748 section 5, discard any set bits of the other @@ -1489,18 +1499,18 @@ static mp_int *ssh_ecdhkex_m_getkey(ecdh_key *dh, ptrlen remoteKey) * to represent all valid values. However, an overlarge value that * still fits into the remaining number of bits is accepted, and * will be reduced mod p. */ - mp_reduce_mod_2to(remote_x, dh->curve->fieldBits); + mp_reduce_mod_2to(remote_x, dhm->curve->fieldBits); MontgomeryPoint *remote_p = ecc_montgomery_point_new( - dh->curve->m.mc, remote_x); + dhm->curve->m.mc, remote_x); mp_free(remote_x); - MontgomeryPoint *p = ecc_montgomery_multiply(remote_p, dh->private); + MontgomeryPoint *p = ecc_montgomery_multiply(remote_p, dhm->private); if (ecc_montgomery_is_identity(p)) { ecc_montgomery_point_free(remote_p); ecc_montgomery_point_free(p); - return NULL; + return false; } mp_int *x; @@ -1524,48 +1534,54 @@ static mp_int *ssh_ecdhkex_m_getkey(ecdh_key *dh, ptrlen remoteKey) * with the _low_ byte zero, i.e. a multiple of 256. */ strbuf *sb = strbuf_new(); - for (size_t i = 0; i < dh->curve->fieldBytes; ++i) + for (size_t i = 0; i < dhm->curve->fieldBytes; ++i) put_byte(sb, mp_get_byte(x, i)); mp_free(x); x = mp_from_bytes_be(ptrlen_from_strbuf(sb)); strbuf_free(sb); + put_mp_ssh2(bs, x); + mp_free(x); - return x; + return true; } -mp_int *ssh_ecdhkex_getkey(ecdh_key *dh, ptrlen remoteKey) +static void ssh_ecdhkex_w_free(ecdh_key *dh) { - return dh->extra->getkey(dh, remoteKey); + ecdh_key_w *dhw = container_of(dh, ecdh_key_w, ek); + mp_free(dhw->private); + ecc_weierstrass_point_free(dhw->w_public); + sfree(dhw); } -static void ssh_ecdhkex_w_cleanup(ecdh_key *dh) +static void ssh_ecdhkex_m_free(ecdh_key *dh) { - ecc_weierstrass_point_free(dh->w_public); + ecdh_key_m *dhm = container_of(dh, ecdh_key_m, ek); + mp_free(dhm->private); + ecc_montgomery_point_free(dhm->m_public); + sfree(dhm); } -static void ssh_ecdhkex_m_cleanup(ecdh_key *dh) +static char *ssh_ecdhkex_description(const ssh_kex *kex) { - ecc_montgomery_point_free(dh->m_public); + const struct eckex_extra *extra = (const struct eckex_extra *)kex->extra; + const struct ec_curve *curve = extra->curve(); + return dupprintf("ECDH key exchange with curve %s", curve->textname); } -void ssh_ecdhkex_freekey(ecdh_key *dh) -{ - mp_free(dh->private); - dh->extra->cleanup(dh); - sfree(dh); -} +static const struct eckex_extra kex_extra_curve25519 = { ec_curve25519 }; -static const struct eckex_extra kex_extra_curve25519 = { - ec_curve25519, - ssh_ecdhkex_m_setup, - ssh_ecdhkex_m_cleanup, - ssh_ecdhkex_m_getpublic, - ssh_ecdhkex_m_getkey, +static const ecdh_keyalg ssh_ecdhkex_m_alg = { + .new = ssh_ecdhkex_m_new, + .free = ssh_ecdhkex_m_free, + .getpublic = ssh_ecdhkex_m_getpublic, + .getkey = ssh_ecdhkex_m_getkey, + .description = ssh_ecdhkex_description, }; const ssh_kex ssh_ec_kex_curve25519 = { .name = "curve25519-sha256", .main_type = KEXTYPE_ECDH, .hash = &ssh_sha256, + .ecdh_vt = &ssh_ecdhkex_m_alg, .extra = &kex_extra_curve25519, }; /* Pre-RFC alias */ @@ -1573,62 +1589,50 @@ const ssh_kex ssh_ec_kex_curve25519_libssh = { .name = "curve25519-sha256@libssh.org", .main_type = KEXTYPE_ECDH, .hash = &ssh_sha256, + .ecdh_vt = &ssh_ecdhkex_m_alg, .extra = &kex_extra_curve25519, }; -static const struct eckex_extra kex_extra_curve448 = { - ec_curve448, - ssh_ecdhkex_m_setup, - ssh_ecdhkex_m_cleanup, - ssh_ecdhkex_m_getpublic, - ssh_ecdhkex_m_getkey, -}; +static const struct eckex_extra kex_extra_curve448 = { ec_curve448 }; const ssh_kex ssh_ec_kex_curve448 = { .name = "curve448-sha512", .main_type = KEXTYPE_ECDH, .hash = &ssh_sha512, + .ecdh_vt = &ssh_ecdhkex_m_alg, .extra = &kex_extra_curve448, }; -static const struct eckex_extra kex_extra_nistp256 = { - ec_p256, - ssh_ecdhkex_w_setup, - ssh_ecdhkex_w_cleanup, - ssh_ecdhkex_w_getpublic, - ssh_ecdhkex_w_getkey, +static const ecdh_keyalg ssh_ecdhkex_w_alg = { + .new = ssh_ecdhkex_w_new, + .free = ssh_ecdhkex_w_free, + .getpublic = ssh_ecdhkex_w_getpublic, + .getkey = ssh_ecdhkex_w_getkey, + .description = ssh_ecdhkex_description, }; +static const struct eckex_extra kex_extra_nistp256 = { ec_p256 }; const ssh_kex ssh_ec_kex_nistp256 = { .name = "ecdh-sha2-nistp256", .main_type = KEXTYPE_ECDH, .hash = &ssh_sha256, + .ecdh_vt = &ssh_ecdhkex_w_alg, .extra = &kex_extra_nistp256, }; -static const struct eckex_extra kex_extra_nistp384 = { - ec_p384, - ssh_ecdhkex_w_setup, - ssh_ecdhkex_w_cleanup, - ssh_ecdhkex_w_getpublic, - ssh_ecdhkex_w_getkey, -}; +static const struct eckex_extra kex_extra_nistp384 = { ec_p384 }; const ssh_kex ssh_ec_kex_nistp384 = { .name = "ecdh-sha2-nistp384", .main_type = KEXTYPE_ECDH, .hash = &ssh_sha384, + .ecdh_vt = &ssh_ecdhkex_w_alg, .extra = &kex_extra_nistp384, }; -static const struct eckex_extra kex_extra_nistp521 = { - ec_p521, - ssh_ecdhkex_w_setup, - ssh_ecdhkex_w_cleanup, - ssh_ecdhkex_w_getpublic, - ssh_ecdhkex_w_getkey, -}; +static const struct eckex_extra kex_extra_nistp521 = { ec_p521 }; const ssh_kex ssh_ec_kex_nistp521 = { .name = "ecdh-sha2-nistp521", .main_type = KEXTYPE_ECDH, .hash = &ssh_sha512, + .ecdh_vt = &ssh_ecdhkex_w_alg, .extra = &kex_extra_nistp521, }; diff --git a/defs.h b/defs.h index 354c208f..1edf3442 100644 --- a/defs.h +++ b/defs.h @@ -167,6 +167,7 @@ typedef struct ssh_cipher ssh_cipher; typedef struct ssh2_ciphers ssh2_ciphers; typedef struct dh_ctx dh_ctx; typedef struct ecdh_key ecdh_key; +typedef struct ecdh_keyalg ecdh_keyalg; typedef struct dlgparam dlgparam; diff --git a/ssh.h b/ssh.h index 8187b394..60880482 100644 --- a/ssh.h +++ b/ssh.h @@ -615,15 +615,6 @@ strbuf *ssh_rsakex_encrypt( mp_int *ssh_rsakex_decrypt( RSAKey *key, const ssh_hashalg *h, ptrlen ciphertext); -/* - * SSH2 ECDH key exchange functions - */ -const char *ssh_ecdhkex_curve_textname(const ssh_kex *kex); -ecdh_key *ssh_ecdhkex_newkey(const ssh_kex *kex); -void ssh_ecdhkex_freekey(ecdh_key *key); -void ssh_ecdhkex_getpublic(ecdh_key *key, BinarySink *bs); -mp_int *ssh_ecdhkex_getkey(ecdh_key *key, ptrlen remoteKey); - /* * Helper function for k generation in DSA, reused in ECDSA */ @@ -806,6 +797,9 @@ struct ssh_kex { const char *name, *groupname; enum { KEXTYPE_DH, KEXTYPE_RSA, KEXTYPE_ECDH, KEXTYPE_GSS } main_type; const ssh_hashalg *hash; + union { /* publicly visible data for each type */ + const ecdh_keyalg *ecdh_vt; /* for KEXTYPE_ECDH */ + }; const void *extra; /* private to the kex methods */ }; @@ -884,6 +878,35 @@ static inline const char *ssh_key_ssh_id(ssh_key *key) static inline const char *ssh_key_cache_id(ssh_key *key) { return key->vt->cache_id; } +/* + * SSH2 ECDH key exchange vtable + */ +struct ecdh_key { + const ecdh_keyalg *vt; +}; +struct ecdh_keyalg { + /* Unusually, the 'new' method here doesn't directly take a vt + * pointer, because it will also need the containing ssh_kex + * structure for top-level parameters, and since that contains a + * vt pointer anyway, we might as well _only_ pass that. */ + ecdh_key *(*new)(const ssh_kex *kex, bool is_server); + void (*free)(ecdh_key *key); + void (*getpublic)(ecdh_key *key, BinarySink *bs); + bool (*getkey)(ecdh_key *key, ptrlen remoteKey, BinarySink *bs); + char *(*description)(const ssh_kex *kex); +}; +static inline ecdh_key *ecdh_key_new(const ssh_kex *kex, bool is_server) +{ return kex->ecdh_vt->new(kex, is_server); } +static inline void ecdh_key_free(ecdh_key *key) +{ key->vt->free(key); } +static inline void ecdh_key_getpublic(ecdh_key *key, BinarySink *bs) +{ key->vt->getpublic(key, bs); } +static inline bool ecdh_key_getkey(ecdh_key *key, ptrlen remoteKey, + BinarySink *bs) +{ return key->vt->getkey(key, remoteKey, bs); } +static inline char *ecdh_keyalg_description(const ssh_kex *kex) +{ return kex->ecdh_vt->description(kex); } + /* * Enumeration of signature flags from draft-miller-ssh-agent-02 */ diff --git a/ssh/kex2-client.c b/ssh/kex2-client.c index 633360ec..ff78840a 100644 --- a/ssh/kex2-client.c +++ b/ssh/kex2-client.c @@ -185,13 +185,14 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) mp_free(s->p); s->p = NULL; } } else if (s->kex_alg->main_type == KEXTYPE_ECDH) { - - ppl_logevent("Doing ECDH key exchange with curve %s and hash %s", - ssh_ecdhkex_curve_textname(s->kex_alg), + char *desc = ecdh_keyalg_description(s->kex_alg); + ppl_logevent("Doing %s, using hash %s", desc, ssh_hash_alg(s->exhash)->text_name); + sfree(desc); + s->ppl.bpp->pls->kctx = SSH2_PKTCTX_ECDHKEX; - s->ecdh_key = ssh_ecdhkex_newkey(s->kex_alg); + s->ecdh_key = ecdh_key_new(s->kex_alg, false); if (!s->ecdh_key) { ssh_sw_abort(s->ppl.ssh, "Unable to generate key for ECDH"); *aborted = true; @@ -201,7 +202,7 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) pktout = ssh_bpp_new_pktout(s->ppl.bpp, SSH2_MSG_KEX_ECDH_INIT); { strbuf *pubpoint = strbuf_new(); - ssh_ecdhkex_getpublic(s->ecdh_key, BinarySink_UPCAST(pubpoint)); + ecdh_key_getpublic(s->ecdh_key, BinarySink_UPCAST(pubpoint)); put_stringsb(pktout, pubpoint); } @@ -224,7 +225,7 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) { strbuf *pubpoint = strbuf_new(); - ssh_ecdhkex_getpublic(s->ecdh_key, BinarySink_UPCAST(pubpoint)); + ecdh_key_getpublic(s->ecdh_key, BinarySink_UPCAST(pubpoint)); put_string(s->exhash, pubpoint->u, pubpoint->len); strbuf_free(pubpoint); } @@ -232,15 +233,14 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) { ptrlen keydata = get_string(pktin); put_stringpl(s->exhash, keydata); - mp_int *K = ssh_ecdhkex_getkey(s->ecdh_key, keydata); - if (!get_err(pktin) && !K) { + bool ok = ecdh_key_getkey(s->ecdh_key, keydata, + BinarySink_UPCAST(s->kex_shared_secret)); + if (!get_err(pktin) && !ok) { ssh_proto_error(s->ppl.ssh, "Received invalid elliptic curve " "point in ECDH reply"); *aborted = true; return; } - put_mp_ssh2(s->kex_shared_secret, K); - mp_free(K); } s->sigdata = get_string(pktin); @@ -250,7 +250,7 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) return; } - ssh_ecdhkex_freekey(s->ecdh_key); + ecdh_key_free(s->ecdh_key); s->ecdh_key = NULL; #ifndef NO_GSSAPI } else if (s->kex_alg->main_type == KEXTYPE_GSS) { diff --git a/ssh/kex2-server.c b/ssh/kex2-server.c index 9657589b..570d7750 100644 --- a/ssh/kex2-server.c +++ b/ssh/kex2-server.c @@ -191,12 +191,12 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) mp_free(s->p); s->p = NULL; } } else if (s->kex_alg->main_type == KEXTYPE_ECDH) { - ppl_logevent("Doing ECDH key exchange with curve %s and hash %s", - ssh_ecdhkex_curve_textname(s->kex_alg), + char *desc = ecdh_keyalg_description(s->kex_alg); + ppl_logevent("Doing %s, using hash %s", desc, ssh_hash_alg(s->exhash)->text_name); - s->ppl.bpp->pls->kctx = SSH2_PKTCTX_ECDHKEX; + sfree(desc); - s->ecdh_key = ssh_ecdhkex_newkey(s->kex_alg); + s->ecdh_key = ecdh_key_new(s->kex_alg, true); if (!s->ecdh_key) { ssh_sw_abort(s->ppl.ssh, "Unable to generate key for ECDH"); *aborted = true; @@ -219,29 +219,28 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) ptrlen keydata = get_string(pktin); put_stringpl(s->exhash, keydata); - mp_int *K = ssh_ecdhkex_getkey(s->ecdh_key, keydata); - if (!get_err(pktin) && !K) { + bool ok = ecdh_key_getkey(s->ecdh_key, keydata, + BinarySink_UPCAST(s->kex_shared_secret)); + if (!get_err(pktin) && !ok) { ssh_proto_error(s->ppl.ssh, "Received invalid elliptic curve " "point in ECDH initial packet"); *aborted = true; return; } - put_mp_ssh2(s->kex_shared_secret, K); - mp_free(K); } pktout = ssh_bpp_new_pktout(s->ppl.bpp, SSH2_MSG_KEX_ECDH_REPLY); put_stringpl(pktout, s->hostkeydata); { strbuf *pubpoint = strbuf_new(); - ssh_ecdhkex_getpublic(s->ecdh_key, BinarySink_UPCAST(pubpoint)); + ecdh_key_getpublic(s->ecdh_key, BinarySink_UPCAST(pubpoint)); put_string(s->exhash, pubpoint->u, pubpoint->len); put_stringsb(pktout, pubpoint); } put_stringsb(pktout, finalise_and_sign_exhash(s)); pq_push(s->ppl.out_pq, pktout); - ssh_ecdhkex_freekey(s->ecdh_key); + ecdh_key_free(s->ecdh_key); s->ecdh_key = NULL; } else if (s->kex_alg->main_type == KEXTYPE_GSS) { ssh_sw_abort(s->ppl.ssh, "GSS key exchange not supported in server"); diff --git a/ssh/transport2.c b/ssh/transport2.c index 4ce4156d..e00507e1 100644 --- a/ssh/transport2.c +++ b/ssh/transport2.c @@ -227,7 +227,7 @@ static void ssh2_transport_free(PacketProtocolLayer *ppl) sfree(s->rsa_kex_key); } if (s->ecdh_key) - ssh_ecdhkex_freekey(s->ecdh_key); + ecdh_key_free(s->ecdh_key); if (s->exhash) ssh_hash_free(s->exhash); strbuf_free(s->outgoing_kexinit); diff --git a/test/cryptsuite.py b/test/cryptsuite.py index 4971a599..958980b9 100755 --- a/test/cryptsuite.py +++ b/test/cryptsuite.py @@ -1772,13 +1772,13 @@ class crypt(MyTestBase): ] with random_prng("doesn't matter"): - ecdh25519 = ssh_ecdhkex_newkey('curve25519') - ecdh448 = ssh_ecdhkex_newkey('curve448') + ecdh25519 = ecdh_key_new('curve25519', False) + ecdh448 = ecdh_key_new('curve448', False) for pub in bad_keys_25519: - key = ssh_ecdhkex_getkey(ecdh25519, unhex(pub)) + key = ecdh_key_getkey(ecdh25519, unhex(pub)) self.assertEqual(key, None) for pub in bad_keys_448: - key = ssh_ecdhkex_getkey(ecdh448, unhex(pub)) + key = ecdh_key_getkey(ecdh448, unhex(pub)) self.assertEqual(key, None) def testPRNG(self): @@ -3107,9 +3107,9 @@ class standard_test_vectors(MyTestBase): for method, priv, pub, expected in rfc7748s5_2: with queued_specific_random_data(unhex(priv)): - ecdh = ssh_ecdhkex_newkey(method) - key = ssh_ecdhkex_getkey(ecdh, unhex(pub)) - self.assertEqual(int(key), expected) + ecdh = ecdh_key_new(method, False) + key = ecdh_key_getkey(ecdh, unhex(pub)) + self.assertEqual(key, ssh2_mpint(expected)) # Bidirectional tests, consisting of the input random number # strings for both parties, and the expected public values and @@ -3131,15 +3131,15 @@ class standard_test_vectors(MyTestBase): for method, apriv, apub, bpriv, bpub, expected in rfc7748s6: with queued_specific_random_data(unhex(apriv)): - alice = ssh_ecdhkex_newkey(method) + alice = ecdh_key_new(method, False) with queued_specific_random_data(unhex(bpriv)): - bob = ssh_ecdhkex_newkey(method) - self.assertEqualBin(ssh_ecdhkex_getpublic(alice), unhex(apub)) - self.assertEqualBin(ssh_ecdhkex_getpublic(bob), unhex(bpub)) - akey = ssh_ecdhkex_getkey(alice, unhex(bpub)) - bkey = ssh_ecdhkex_getkey(bob, unhex(apub)) - self.assertEqual(int(akey), expected) - self.assertEqual(int(bkey), expected) + bob = ecdh_key_new(method, False) + self.assertEqualBin(ecdh_key_getpublic(alice), unhex(apub)) + self.assertEqualBin(ecdh_key_getpublic(bob), unhex(bpub)) + akey = ecdh_key_getkey(alice, unhex(bpub)) + bkey = ecdh_key_getkey(bob, unhex(apub)) + self.assertEqual(akey, ssh2_mpint(expected)) + self.assertEqual(bkey, ssh2_mpint(expected)) def testCRC32(self): self.assertEqual(crc32_rfc1662("123456789"), 0xCBF43926) diff --git a/test/testcrypt-func.h b/test/testcrypt-func.h index 2cb0b3dc..1188d2c4 100644 --- a/test/testcrypt-func.h +++ b/test/testcrypt-func.h @@ -346,11 +346,11 @@ FUNC(val_mpint, dh_find_K, ARG(val_dh, ctx), ARG(val_mpint, f)) /* * Elliptic-curve Diffie-Hellman. */ -FUNC(val_ecdh, ssh_ecdhkex_newkey, ARG(ecdh_alg, alg)) -FUNC(void, ssh_ecdhkex_getpublic, ARG(val_ecdh, key), +FUNC(val_ecdh, ecdh_key_new, ARG(ecdh_alg, alg), ARG(boolean, is_server)) +FUNC(void, ecdh_key_getpublic, ARG(val_ecdh, key), ARG(out_val_string_binarysink, pub)) -FUNC(opt_val_mpint, ssh_ecdhkex_getkey, ARG(val_ecdh, key), - ARG(val_string_ptrlen, pub)) +FUNC_WRAPPED(opt_val_string, ecdh_key_getkey, ARG(val_ecdh, key), + ARG(val_string_ptrlen, pub)) /* * RSA key exchange, and also the BinarySource get function diff --git a/test/testcrypt.c b/test/testcrypt.c index 01b5c501..cfe28ae2 100644 --- a/test/testcrypt.c +++ b/test/testcrypt.c @@ -87,7 +87,7 @@ uint64_t prng_reseed_time_ms(void) X(cipher, ssh_cipher *, ssh_cipher_free(v)) \ X(mac, ssh2_mac *, ssh2_mac_free(v)) \ X(dh, dh_ctx *, dh_cleanup(v)) \ - X(ecdh, ecdh_key *, ssh_ecdhkex_freekey(v)) \ + X(ecdh, ecdh_key *, ecdh_key_free(v)) \ X(rsakex, RSAKey *, ssh_rsakex_freekey(v)) \ X(rsa, RSAKey *, rsa_free(v)) \ X(prng, prng *, prng_free(v)) \ @@ -787,6 +787,18 @@ static RSAKey *rsa_new(void) return rsa; } +strbuf *ecdh_key_getkey_wrapper(ecdh_key *ek, ptrlen remoteKey) +{ + /* Fold the boolean return value in C into the string return value + * for this purpose, by returning NULL on failure */ + strbuf *sb = strbuf_new(); + if (!ecdh_key_getkey(ek, remoteKey, BinarySink_UPCAST(sb))) { + strbuf_free(sb); + return NULL; + } + return sb; +} + strbuf *rsa_ssh1_encrypt_wrapper(ptrlen input, RSAKey *key) { /* Fold the boolean return value in C into the string return value