From e103ab1fb6d6b3a41f2b53905f6eea26bdababe6 Mon Sep 17 00:00:00 2001 From: Simon Tatham Date: Thu, 14 Apr 2022 06:23:11 +0100 Subject: [PATCH] Refactor handling of SSH kex shared secret. Until now, every kex method has represented the output as an mp_int. So we were storing it in the mp_int field s->K, and adding it to the exchange hash and key derivation hashes via put_mp_ssh2. But there's now going to be the first kex method that represents the output as a string (so that it might have the top bit set, or multiple leading zero bytes, without its length varying). So we now need to be more general. The most general thing it's sensible to do is to replace s->K with a strbuf containing _already-encoded_ data to become part of the hash, including length fields if necessary. So every existing kex method still derives an mp_int, but then immediately puts it into that strbuf using put_mp_ssh2 and frees it. --- ssh/kex2-client.c | 26 +++++++++++++++++++------- ssh/kex2-server.c | 18 +++++++++++++----- ssh/transport2.c | 28 ++++++++++++++++------------ ssh/transport2.h | 3 ++- 4 files changed, 50 insertions(+), 25 deletions(-) diff --git a/ssh/kex2-client.c b/ssh/kex2-client.c index 9a8f75e2..633360ec 100644 --- a/ssh/kex2-client.c +++ b/ssh/kex2-client.c @@ -156,7 +156,9 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) return; } } - s->K = dh_find_K(s->dh_ctx, s->f); + mp_int *K = dh_find_K(s->dh_ctx, s->f); + put_mp_ssh2(s->kex_shared_secret, K); + mp_free(K); /* We assume everything from now on will be quick, and it might * involve user interaction. */ @@ -230,13 +232,15 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) { ptrlen keydata = get_string(pktin); put_stringpl(s->exhash, keydata); - s->K = ssh_ecdhkex_getkey(s->ecdh_key, keydata); - if (!get_err(pktin) && !s->K) { + mp_int *K = ssh_ecdhkex_getkey(s->ecdh_key, keydata); + if (!get_err(pktin) && !K) { ssh_proto_error(s->ppl.ssh, "Received invalid elliptic curve " "point in ECDH reply"); *aborted = true; return; } + put_mp_ssh2(s->kex_shared_secret, K); + mp_free(K); } s->sigdata = get_string(pktin); @@ -485,7 +489,9 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) return; } } - s->K = dh_find_K(s->dh_ctx, s->f); + mp_int *K = dh_find_K(s->dh_ctx, s->f); + put_mp_ssh2(s->kex_shared_secret, K); + mp_free(K); /* We assume everything from now on will be quick, and it might * involve user interaction. */ @@ -584,15 +590,21 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) strbuf *buf, *outstr; mp_int *tmp = mp_random_bits(nbits - 1); - s->K = mp_power_2(nbits - 1); - mp_add_into(s->K, s->K, tmp); + mp_int *K = mp_power_2(nbits - 1); + mp_add_into(K, K, tmp); mp_free(tmp); /* * Encode this as an mpint. */ buf = strbuf_new_nm(); - put_mp_ssh2(buf, s->K); + put_mp_ssh2(buf, K); + + /* + * Store a copy as the output shared secret from the kex. + */ + put_mp_ssh2(s->kex_shared_secret, K); + mp_free(K); /* * Encrypt it with the given RSA key. diff --git a/ssh/kex2-server.c b/ssh/kex2-server.c index 3c017077..9657589b 100644 --- a/ssh/kex2-server.c +++ b/ssh/kex2-server.c @@ -161,7 +161,9 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) return; } } - s->K = dh_find_K(s->dh_ctx, s->f); + mp_int *K = dh_find_K(s->dh_ctx, s->f); + put_mp_ssh2(s->kex_shared_secret, K); + mp_free(K); if (dh_is_gex(s->kex_alg)) { if (s->dh_got_size_bounds) @@ -217,13 +219,15 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) ptrlen keydata = get_string(pktin); put_stringpl(s->exhash, keydata); - s->K = ssh_ecdhkex_getkey(s->ecdh_key, keydata); - if (!get_err(pktin) && !s->K) { + mp_int *K = ssh_ecdhkex_getkey(s->ecdh_key, keydata); + if (!get_err(pktin) && !K) { ssh_proto_error(s->ppl.ssh, "Received invalid elliptic curve " "point in ECDH initial packet"); *aborted = true; return; } + put_mp_ssh2(s->kex_shared_secret, K); + mp_free(K); } pktout = ssh_bpp_new_pktout(s->ppl.bpp, SSH2_MSG_KEX_ECDH_REPLY); @@ -301,19 +305,23 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) return; } + mp_int *K; { ptrlen encrypted_secret = get_string(pktin); put_stringpl(s->exhash, encrypted_secret); - s->K = ssh_rsakex_decrypt( + K = ssh_rsakex_decrypt( s->rsa_kex_key, s->kex_alg->hash, encrypted_secret); } - if (!s->K) { + if (!K) { ssh_proto_error(s->ppl.ssh, "Unable to decrypt RSA kex secret"); *aborted = true; return; } + put_mp_ssh2(s->kex_shared_secret, K); + mp_free(K); + if (s->rsa_kex_key_needs_freeing) { ssh_rsakex_freekey(s->rsa_kex_key); sfree(s->rsa_kex_key); diff --git a/ssh/transport2.c b/ssh/transport2.c index 2f9b02af..4ce4156d 100644 --- a/ssh/transport2.c +++ b/ssh/transport2.c @@ -219,7 +219,7 @@ static void ssh2_transport_free(PacketProtocolLayer *ppl) if (s->f) mp_free(s->f); if (s->p) mp_free(s->p); if (s->g) mp_free(s->g); - if (s->K) mp_free(s->K); + if (s->kex_shared_secret) strbuf_free(s->kex_shared_secret); if (s->dh_ctx) dh_cleanup(s->dh_ctx); if (s->rsa_kex_key_needs_freeing) { @@ -245,7 +245,7 @@ static void ssh2_transport_free(PacketProtocolLayer *ppl) */ static void ssh2_mkkey( struct ssh2_transport_state *s, strbuf *out, - mp_int *K, unsigned char *H, char chr, int keylen) + strbuf *kex_shared_secret, unsigned char *H, char chr, int keylen) { int hlen = s->kex_alg->hash->hlen; int keylen_padded; @@ -273,7 +273,7 @@ static void ssh2_mkkey( /* First hlen bytes. */ h = ssh_hash_new(s->kex_alg->hash); if (!(s->ppl.remote_bugs & BUG_SSH2_DERIVEKEY)) - put_mp_ssh2(h, K); + put_datapl(h, ptrlen_from_strbuf(kex_shared_secret)); put_data(h, H, hlen); put_byte(h, chr); put_data(h, s->session_id, s->session_id_len); @@ -285,7 +285,7 @@ static void ssh2_mkkey( ssh_hash_reset(h); if (!(s->ppl.remote_bugs & BUG_SSH2_DERIVEKEY)) - put_mp_ssh2(h, K); + put_datapl(h, ptrlen_from_strbuf(kex_shared_secret)); put_data(h, H, hlen); for (offset = hlen; offset < keylen_padded; offset += hlen) { @@ -1093,7 +1093,7 @@ static bool ssh2_scan_kexinits( void ssh2transport_finalise_exhash(struct ssh2_transport_state *s) { - put_mp_ssh2(s->exhash, s->K); + put_datapl(s->exhash, ptrlen_from_strbuf(s->kex_shared_secret)); assert(ssh_hash_alg(s->exhash)->hlen <= sizeof(s->exchange_hash)); ssh_hash_final(s->exhash, s->exchange_hash); s->exhash = NULL; @@ -1363,6 +1363,9 @@ static void ssh2_transport_process_queue(PacketProtocolLayer *ppl) * Actually perform the key exchange. */ s->exhash = ssh_hash_new(s->kex_alg->hash); + if (s->kex_shared_secret) + strbuf_free(s->kex_shared_secret); + s->kex_shared_secret = strbuf_new_nm(); put_stringz(s->exhash, s->client_greeting); put_stringz(s->exhash, s->server_greeting); put_string(s->exhash, s->client_kexinit->u, s->client_kexinit->len); @@ -1416,14 +1419,14 @@ static void ssh2_transport_process_queue(PacketProtocolLayer *ppl) strbuf *mac_key = strbuf_new_nm(); if (s->out.cipher) { - ssh2_mkkey(s, cipher_iv, s->K, s->exchange_hash, + ssh2_mkkey(s, cipher_iv, s->kex_shared_secret, s->exchange_hash, 'A' + s->out.mkkey_adjust, s->out.cipher->blksize); - ssh2_mkkey(s, cipher_key, s->K, s->exchange_hash, + ssh2_mkkey(s, cipher_key, s->kex_shared_secret, s->exchange_hash, 'C' + s->out.mkkey_adjust, s->out.cipher->padded_keybytes); } if (s->out.mac) { - ssh2_mkkey(s, mac_key, s->K, s->exchange_hash, + ssh2_mkkey(s, mac_key, s->kex_shared_secret, s->exchange_hash, 'E' + s->out.mkkey_adjust, s->out.mac->keylen); } @@ -1508,14 +1511,14 @@ static void ssh2_transport_process_queue(PacketProtocolLayer *ppl) strbuf *mac_key = strbuf_new_nm(); if (s->in.cipher) { - ssh2_mkkey(s, cipher_iv, s->K, s->exchange_hash, + ssh2_mkkey(s, cipher_iv, s->kex_shared_secret, s->exchange_hash, 'A' + s->in.mkkey_adjust, s->in.cipher->blksize); - ssh2_mkkey(s, cipher_key, s->K, s->exchange_hash, + ssh2_mkkey(s, cipher_key, s->kex_shared_secret, s->exchange_hash, 'C' + s->in.mkkey_adjust, s->in.cipher->padded_keybytes); } if (s->in.mac) { - ssh2_mkkey(s, mac_key, s->K, s->exchange_hash, + ssh2_mkkey(s, mac_key, s->kex_shared_secret, s->exchange_hash, 'E' + s->in.mkkey_adjust, s->in.mac->keylen); } @@ -1533,7 +1536,8 @@ static void ssh2_transport_process_queue(PacketProtocolLayer *ppl) /* * Free shared secret. */ - mp_free(s->K); s->K = NULL; + strbuf_free(s->kex_shared_secret); + s->kex_shared_secret = NULL; /* * Update the specials menu to list the remaining uncertified host diff --git a/ssh/transport2.h b/ssh/transport2.h index 5f0df80e..aaec91a8 100644 --- a/ssh/transport2.h +++ b/ssh/transport2.h @@ -167,7 +167,8 @@ struct ssh2_transport_state { int nbits, pbits; bool warn_kex, warn_hk, warn_cscipher, warn_sccipher; - mp_int *p, *g, *e, *f, *K; + mp_int *p, *g, *e, *f; + strbuf *kex_shared_secret; strbuf *outgoing_kexinit, *incoming_kexinit; strbuf *client_kexinit, *server_kexinit; /* aliases to the above */ int kex_init_value, kex_reply_value;