/* * Centralised machinery for hybridised post-quantum + classical key * exchange setups, using the same message structure as ECDH but the * strings sent each way are the concatenation of a key or ciphertext * of each type, and the output shared secret is obtained by hashing * together both of the sub-methods' outputs. */ #include #include #include #include "putty.h" #include "ssh.h" #include "mpint.h" /* ---------------------------------------------------------------------- * Common definitions between client and server sides. */ typedef struct hybrid_alg hybrid_alg; struct hybrid_alg { const ssh_hashalg *combining_hash; const pq_kemalg *pq_alg; const ssh_kex *classical_alg; void (*reformat)(ptrlen input, BinarySink *output); }; static char *hybrid_description(const ssh_kex *kex) { const struct hybrid_alg *alg = kex->extra; /* Bit of a bodge, but think up a short name to describe the * classical algorithm */ const char *classical_name; if (alg->classical_alg == &ssh_ec_kex_curve25519) classical_name = "Curve25519"; else if (alg->classical_alg == &ssh_ec_kex_nistp256) classical_name = "NIST P256"; else if (alg->classical_alg == &ssh_ec_kex_nistp384) classical_name = "NIST P384"; else unreachable("don't have a name for this classical alg"); return dupprintf("%s / %s hybrid key exchange", alg->pq_alg->description, classical_name); } static void reformat_mpint_be(ptrlen input, BinarySink *output, size_t bytes) { BinarySource src[1]; BinarySource_BARE_INIT_PL(src, input); mp_int *mp = get_mp_ssh2(src); assert(!get_err(src)); assert(get_avail(src) == 0); for (size_t i = bytes; i-- > 0 ;) put_byte(output, mp_get_byte(mp, i)); mp_free(mp); } static void reformat_mpint_be_32(ptrlen input, BinarySink *output) { reformat_mpint_be(input, output, 32); } static void reformat_mpint_be_48(ptrlen input, BinarySink *output) { reformat_mpint_be(input, output, 48); } /* ---------------------------------------------------------------------- * Client side. */ typedef struct hybrid_client_state hybrid_client_state; static const ecdh_keyalg hybrid_client_vt; struct hybrid_client_state { const hybrid_alg *alg; strbuf *pq_ek; pq_kem_dk *pq_dk; ecdh_key *classical; ecdh_key ek; }; static ecdh_key *hybrid_client_new(const ssh_kex *kex, bool is_server) { assert(!is_server); hybrid_client_state *s = snew(hybrid_client_state); s->alg = kex->extra; s->ek.vt = &hybrid_client_vt; s->pq_ek = strbuf_new(); s->pq_dk = pq_kem_keygen(s->alg->pq_alg, BinarySink_UPCAST(s->pq_ek)); s->classical = ecdh_key_new(s->alg->classical_alg, is_server); return &s->ek; } static void hybrid_client_free(ecdh_key *ek) { hybrid_client_state *s = container_of(ek, hybrid_client_state, ek); strbuf_free(s->pq_ek); pq_kem_free_dk(s->pq_dk); ecdh_key_free(s->classical); sfree(s); } /* * In the client, getpublic is called first: we make up a KEM key * pair, and send the public key along with a classical DH value. */ static void hybrid_client_getpublic(ecdh_key *ek, BinarySink *bs) { hybrid_client_state *s = container_of(ek, hybrid_client_state, ek); put_datapl(bs, ptrlen_from_strbuf(s->pq_ek)); ecdh_key_getpublic(s->classical, bs); } /* * In the client, getkey is called second, after the server sends its * response: we use our KEM private key to decapsulate the server's * ciphertext. */ static bool hybrid_client_getkey(ecdh_key *ek, ptrlen remoteKey, BinarySink *bs) { hybrid_client_state *s = container_of(ek, hybrid_client_state, ek); BinarySource src[1]; BinarySource_BARE_INIT_PL(src, remoteKey); ssh_hash *h = ssh_hash_new(s->alg->combining_hash); ptrlen pq_ciphertext = get_data(src, s->alg->pq_alg->c_len); if (get_err(src)) { ssh_hash_free(h); return false; /* not enough data */ } if (!pq_kem_decaps(s->pq_dk, BinarySink_UPCAST(h), pq_ciphertext)) { ssh_hash_free(h); return false; /* pq ciphertext didn't validate */ } ptrlen classical_data = get_data(src, get_avail(src)); strbuf *classical_key = strbuf_new(); if (!ecdh_key_getkey(s->classical, classical_data, BinarySink_UPCAST(classical_key))) { ssh_hash_free(h); return false; /* classical DH key didn't validate */ } s->alg->reformat(ptrlen_from_strbuf(classical_key), BinarySink_UPCAST(h)); strbuf_free(classical_key); /* * Finish up: compute the final output hash and return it encoded * as a string. */ unsigned char hashdata[MAX_HASH_LEN]; ssh_hash_final(h, hashdata); put_stringpl(bs, make_ptrlen(hashdata, s->alg->combining_hash->hlen)); smemclr(hashdata, sizeof(hashdata)); return true; } static const ecdh_keyalg hybrid_client_vt = { .new = hybrid_client_new, /* but normally the selector calls this */ .free = hybrid_client_free, .getpublic = hybrid_client_getpublic, .getkey = hybrid_client_getkey, .description = hybrid_description, .packet_naming_ctx = SSH2_PKTCTX_HYBRIDKEX, }; /* ---------------------------------------------------------------------- * Server side. */ typedef struct hybrid_server_state hybrid_server_state; static const ecdh_keyalg hybrid_server_vt; struct hybrid_server_state { const hybrid_alg *alg; strbuf *pq_ciphertext; ecdh_key *classical; ecdh_key ek; }; static ecdh_key *hybrid_server_new(const ssh_kex *kex, bool is_server) { assert(is_server); hybrid_server_state *s = snew(hybrid_server_state); s->alg = kex->extra; s->ek.vt = &hybrid_server_vt; s->pq_ciphertext = strbuf_new_nm(); s->classical = ecdh_key_new(s->alg->classical_alg, is_server); return &s->ek; } static void hybrid_server_free(ecdh_key *ek) { hybrid_server_state *s = container_of(ek, hybrid_server_state, ek); strbuf_free(s->pq_ciphertext); ecdh_key_free(s->classical); sfree(s); } /* * In the server, getkey is called first: we receive a KEM encryption * key from the client and encapsulate a secret with it. We write the * output secret to bs; the data we'll send to the client is saved to * return from getpublic. */ static bool hybrid_server_getkey(ecdh_key *ek, ptrlen remoteKey, BinarySink *bs) { hybrid_server_state *s = container_of(ek, hybrid_server_state, ek); BinarySource src[1]; BinarySource_BARE_INIT_PL(src, remoteKey); ssh_hash *h = ssh_hash_new(s->alg->combining_hash); ptrlen pq_ek = get_data(src, s->alg->pq_alg->ek_len); if (get_err(src)) { ssh_hash_free(h); return false; /* not enough data */ } if (!pq_kem_encaps(s->alg->pq_alg, BinarySink_UPCAST(s->pq_ciphertext), BinarySink_UPCAST(h), pq_ek)) { ssh_hash_free(h); return false; /* pq encryption key didn't validate */ } ptrlen classical_data = get_data(src, get_avail(src)); strbuf *classical_key = strbuf_new(); if (!ecdh_key_getkey(s->classical, classical_data, BinarySink_UPCAST(classical_key))) { ssh_hash_free(h); return false; /* classical DH key didn't validate */ } s->alg->reformat(ptrlen_from_strbuf(classical_key), BinarySink_UPCAST(h)); strbuf_free(classical_key); /* * Finish up: compute the final output hash and return it encoded * as a string. */ unsigned char hashdata[MAX_HASH_LEN]; ssh_hash_final(h, hashdata); put_stringpl(bs, make_ptrlen(hashdata, s->alg->combining_hash->hlen)); smemclr(hashdata, sizeof(hashdata)); return true; } static void hybrid_server_getpublic(ecdh_key *ek, BinarySink *bs) { hybrid_server_state *s = container_of(ek, hybrid_server_state, ek); put_datapl(bs, ptrlen_from_strbuf(s->pq_ciphertext)); ecdh_key_getpublic(s->classical, bs); } static const ecdh_keyalg hybrid_server_vt = { .new = hybrid_server_new, /* but normally the selector calls this */ .free = hybrid_server_free, .getkey = hybrid_server_getkey, .getpublic = hybrid_server_getpublic, .description = hybrid_description, .packet_naming_ctx = SSH2_PKTCTX_HYBRIDKEX, }; /* ---------------------------------------------------------------------- * Selector vtable that instantiates the appropriate one of the above, * depending on is_server. */ static ecdh_key *hybrid_selector_new(const ssh_kex *kex, bool is_server) { if (is_server) return hybrid_server_new(kex, is_server); else return hybrid_client_new(kex, is_server); } static const ecdh_keyalg hybrid_selector_vt = { /* This is a never-instantiated vtable which only implements the * functions that don't require an instance. */ .new = hybrid_selector_new, .description = hybrid_description, .packet_naming_ctx = SSH2_PKTCTX_HYBRIDKEX, }; /* ---------------------------------------------------------------------- * Actual KEX methods. */ static const hybrid_alg ssh_ntru_curve25519_hybrid = { .combining_hash = &ssh_sha512, .pq_alg = &ssh_ntru, .classical_alg = &ssh_ec_kex_curve25519, .reformat = reformat_mpint_be_32, }; static const ssh_kex ssh_ntru_curve25519 = { .name = "sntrup761x25519-sha512", .main_type = KEXTYPE_ECDH, .hash = &ssh_sha512, .ecdh_vt = &hybrid_selector_vt, .extra = &ssh_ntru_curve25519_hybrid, }; static const ssh_kex ssh_ntru_curve25519_openssh = { .name = "sntrup761x25519-sha512@openssh.com", .main_type = KEXTYPE_ECDH, .hash = &ssh_sha512, .ecdh_vt = &hybrid_selector_vt, .extra = &ssh_ntru_curve25519_hybrid, }; static const ssh_kex *const ntru_hybrid_list[] = { &ssh_ntru_curve25519, &ssh_ntru_curve25519_openssh, }; const ssh_kexes ssh_ntru_hybrid_kex = { lenof(ntru_hybrid_list), ntru_hybrid_list, }; static const hybrid_alg ssh_mlkem768_curve25519_hybrid = { .combining_hash = &ssh_sha256, .pq_alg = &ssh_mlkem768, .classical_alg = &ssh_ec_kex_curve25519, .reformat = reformat_mpint_be_32, }; static const ssh_kex ssh_mlkem768_curve25519 = { .name = "mlkem768x25519-sha256", .main_type = KEXTYPE_ECDH, .hash = &ssh_sha256, .ecdh_vt = &hybrid_selector_vt, .extra = &ssh_mlkem768_curve25519_hybrid, }; static const ssh_kex *const mlkem_curve25519_hybrid_list[] = { &ssh_mlkem768_curve25519, }; const ssh_kexes ssh_mlkem_curve25519_hybrid_kex = { lenof(mlkem_curve25519_hybrid_list), mlkem_curve25519_hybrid_list, }; static const hybrid_alg ssh_mlkem768_p256_hybrid = { .combining_hash = &ssh_sha256, .pq_alg = &ssh_mlkem768, .classical_alg = &ssh_ec_kex_nistp256, .reformat = reformat_mpint_be_32, }; static const ssh_kex ssh_mlkem768_p256 = { .name = "mlkem768nistp256-sha256", .main_type = KEXTYPE_ECDH, .hash = &ssh_sha256, .ecdh_vt = &hybrid_selector_vt, .extra = &ssh_mlkem768_p256_hybrid, }; static const hybrid_alg ssh_mlkem1024_p384_hybrid = { .combining_hash = &ssh_sha384, .pq_alg = &ssh_mlkem1024, .classical_alg = &ssh_ec_kex_nistp384, .reformat = reformat_mpint_be_48, }; static const ssh_kex ssh_mlkem1024_p384 = { .name = "mlkem1024nistp384-sha384", .main_type = KEXTYPE_ECDH, .hash = &ssh_sha384, .ecdh_vt = &hybrid_selector_vt, .extra = &ssh_mlkem1024_p384_hybrid, }; static const ssh_kex *const mlkem_nist_hybrid_list[] = { &ssh_mlkem1024_p384, &ssh_mlkem768_p256, }; const ssh_kexes ssh_mlkem_nist_hybrid_kex = { lenof(mlkem_nist_hybrid_list), mlkem_nist_hybrid_list, };