diff --git a/ssh/kex2-client.c b/ssh/kex2-client.c index 9a8f75e2..633360ec 100644 --- a/ssh/kex2-client.c +++ b/ssh/kex2-client.c @@ -156,7 +156,9 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) return; } } - s->K = dh_find_K(s->dh_ctx, s->f); + mp_int *K = dh_find_K(s->dh_ctx, s->f); + put_mp_ssh2(s->kex_shared_secret, K); + mp_free(K); /* We assume everything from now on will be quick, and it might * involve user interaction. */ @@ -230,13 +232,15 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) { ptrlen keydata = get_string(pktin); put_stringpl(s->exhash, keydata); - s->K = ssh_ecdhkex_getkey(s->ecdh_key, keydata); - if (!get_err(pktin) && !s->K) { + mp_int *K = ssh_ecdhkex_getkey(s->ecdh_key, keydata); + if (!get_err(pktin) && !K) { ssh_proto_error(s->ppl.ssh, "Received invalid elliptic curve " "point in ECDH reply"); *aborted = true; return; } + put_mp_ssh2(s->kex_shared_secret, K); + mp_free(K); } s->sigdata = get_string(pktin); @@ -485,7 +489,9 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) return; } } - s->K = dh_find_K(s->dh_ctx, s->f); + mp_int *K = dh_find_K(s->dh_ctx, s->f); + put_mp_ssh2(s->kex_shared_secret, K); + mp_free(K); /* We assume everything from now on will be quick, and it might * involve user interaction. */ @@ -584,15 +590,21 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) strbuf *buf, *outstr; mp_int *tmp = mp_random_bits(nbits - 1); - s->K = mp_power_2(nbits - 1); - mp_add_into(s->K, s->K, tmp); + mp_int *K = mp_power_2(nbits - 1); + mp_add_into(K, K, tmp); mp_free(tmp); /* * Encode this as an mpint. */ buf = strbuf_new_nm(); - put_mp_ssh2(buf, s->K); + put_mp_ssh2(buf, K); + + /* + * Store a copy as the output shared secret from the kex. + */ + put_mp_ssh2(s->kex_shared_secret, K); + mp_free(K); /* * Encrypt it with the given RSA key. diff --git a/ssh/kex2-server.c b/ssh/kex2-server.c index 3c017077..9657589b 100644 --- a/ssh/kex2-server.c +++ b/ssh/kex2-server.c @@ -161,7 +161,9 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) return; } } - s->K = dh_find_K(s->dh_ctx, s->f); + mp_int *K = dh_find_K(s->dh_ctx, s->f); + put_mp_ssh2(s->kex_shared_secret, K); + mp_free(K); if (dh_is_gex(s->kex_alg)) { if (s->dh_got_size_bounds) @@ -217,13 +219,15 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) ptrlen keydata = get_string(pktin); put_stringpl(s->exhash, keydata); - s->K = ssh_ecdhkex_getkey(s->ecdh_key, keydata); - if (!get_err(pktin) && !s->K) { + mp_int *K = ssh_ecdhkex_getkey(s->ecdh_key, keydata); + if (!get_err(pktin) && !K) { ssh_proto_error(s->ppl.ssh, "Received invalid elliptic curve " "point in ECDH initial packet"); *aborted = true; return; } + put_mp_ssh2(s->kex_shared_secret, K); + mp_free(K); } pktout = ssh_bpp_new_pktout(s->ppl.bpp, SSH2_MSG_KEX_ECDH_REPLY); @@ -301,19 +305,23 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted) return; } + mp_int *K; { ptrlen encrypted_secret = get_string(pktin); put_stringpl(s->exhash, encrypted_secret); - s->K = ssh_rsakex_decrypt( + K = ssh_rsakex_decrypt( s->rsa_kex_key, s->kex_alg->hash, encrypted_secret); } - if (!s->K) { + if (!K) { ssh_proto_error(s->ppl.ssh, "Unable to decrypt RSA kex secret"); *aborted = true; return; } + put_mp_ssh2(s->kex_shared_secret, K); + mp_free(K); + if (s->rsa_kex_key_needs_freeing) { ssh_rsakex_freekey(s->rsa_kex_key); sfree(s->rsa_kex_key); diff --git a/ssh/transport2.c b/ssh/transport2.c index 2f9b02af..4ce4156d 100644 --- a/ssh/transport2.c +++ b/ssh/transport2.c @@ -219,7 +219,7 @@ static void ssh2_transport_free(PacketProtocolLayer *ppl) if (s->f) mp_free(s->f); if (s->p) mp_free(s->p); if (s->g) mp_free(s->g); - if (s->K) mp_free(s->K); + if (s->kex_shared_secret) strbuf_free(s->kex_shared_secret); if (s->dh_ctx) dh_cleanup(s->dh_ctx); if (s->rsa_kex_key_needs_freeing) { @@ -245,7 +245,7 @@ static void ssh2_transport_free(PacketProtocolLayer *ppl) */ static void ssh2_mkkey( struct ssh2_transport_state *s, strbuf *out, - mp_int *K, unsigned char *H, char chr, int keylen) + strbuf *kex_shared_secret, unsigned char *H, char chr, int keylen) { int hlen = s->kex_alg->hash->hlen; int keylen_padded; @@ -273,7 +273,7 @@ static void ssh2_mkkey( /* First hlen bytes. */ h = ssh_hash_new(s->kex_alg->hash); if (!(s->ppl.remote_bugs & BUG_SSH2_DERIVEKEY)) - put_mp_ssh2(h, K); + put_datapl(h, ptrlen_from_strbuf(kex_shared_secret)); put_data(h, H, hlen); put_byte(h, chr); put_data(h, s->session_id, s->session_id_len); @@ -285,7 +285,7 @@ static void ssh2_mkkey( ssh_hash_reset(h); if (!(s->ppl.remote_bugs & BUG_SSH2_DERIVEKEY)) - put_mp_ssh2(h, K); + put_datapl(h, ptrlen_from_strbuf(kex_shared_secret)); put_data(h, H, hlen); for (offset = hlen; offset < keylen_padded; offset += hlen) { @@ -1093,7 +1093,7 @@ static bool ssh2_scan_kexinits( void ssh2transport_finalise_exhash(struct ssh2_transport_state *s) { - put_mp_ssh2(s->exhash, s->K); + put_datapl(s->exhash, ptrlen_from_strbuf(s->kex_shared_secret)); assert(ssh_hash_alg(s->exhash)->hlen <= sizeof(s->exchange_hash)); ssh_hash_final(s->exhash, s->exchange_hash); s->exhash = NULL; @@ -1363,6 +1363,9 @@ static void ssh2_transport_process_queue(PacketProtocolLayer *ppl) * Actually perform the key exchange. */ s->exhash = ssh_hash_new(s->kex_alg->hash); + if (s->kex_shared_secret) + strbuf_free(s->kex_shared_secret); + s->kex_shared_secret = strbuf_new_nm(); put_stringz(s->exhash, s->client_greeting); put_stringz(s->exhash, s->server_greeting); put_string(s->exhash, s->client_kexinit->u, s->client_kexinit->len); @@ -1416,14 +1419,14 @@ static void ssh2_transport_process_queue(PacketProtocolLayer *ppl) strbuf *mac_key = strbuf_new_nm(); if (s->out.cipher) { - ssh2_mkkey(s, cipher_iv, s->K, s->exchange_hash, + ssh2_mkkey(s, cipher_iv, s->kex_shared_secret, s->exchange_hash, 'A' + s->out.mkkey_adjust, s->out.cipher->blksize); - ssh2_mkkey(s, cipher_key, s->K, s->exchange_hash, + ssh2_mkkey(s, cipher_key, s->kex_shared_secret, s->exchange_hash, 'C' + s->out.mkkey_adjust, s->out.cipher->padded_keybytes); } if (s->out.mac) { - ssh2_mkkey(s, mac_key, s->K, s->exchange_hash, + ssh2_mkkey(s, mac_key, s->kex_shared_secret, s->exchange_hash, 'E' + s->out.mkkey_adjust, s->out.mac->keylen); } @@ -1508,14 +1511,14 @@ static void ssh2_transport_process_queue(PacketProtocolLayer *ppl) strbuf *mac_key = strbuf_new_nm(); if (s->in.cipher) { - ssh2_mkkey(s, cipher_iv, s->K, s->exchange_hash, + ssh2_mkkey(s, cipher_iv, s->kex_shared_secret, s->exchange_hash, 'A' + s->in.mkkey_adjust, s->in.cipher->blksize); - ssh2_mkkey(s, cipher_key, s->K, s->exchange_hash, + ssh2_mkkey(s, cipher_key, s->kex_shared_secret, s->exchange_hash, 'C' + s->in.mkkey_adjust, s->in.cipher->padded_keybytes); } if (s->in.mac) { - ssh2_mkkey(s, mac_key, s->K, s->exchange_hash, + ssh2_mkkey(s, mac_key, s->kex_shared_secret, s->exchange_hash, 'E' + s->in.mkkey_adjust, s->in.mac->keylen); } @@ -1533,7 +1536,8 @@ static void ssh2_transport_process_queue(PacketProtocolLayer *ppl) /* * Free shared secret. */ - mp_free(s->K); s->K = NULL; + strbuf_free(s->kex_shared_secret); + s->kex_shared_secret = NULL; /* * Update the specials menu to list the remaining uncertified host diff --git a/ssh/transport2.h b/ssh/transport2.h index 5f0df80e..aaec91a8 100644 --- a/ssh/transport2.h +++ b/ssh/transport2.h @@ -167,7 +167,8 @@ struct ssh2_transport_state { int nbits, pbits; bool warn_kex, warn_hk, warn_cscipher, warn_sccipher; - mp_int *p, *g, *e, *f, *K; + mp_int *p, *g, *e, *f; + strbuf *kex_shared_secret; strbuf *outgoing_kexinit, *incoming_kexinit; strbuf *client_kexinit, *server_kexinit; /* aliases to the above */ int kex_init_value, kex_reply_value;