diff --git a/ssh/bpp.h b/ssh/bpp.h index 87e7d7e7..23af5236 100644 --- a/ssh/bpp.h +++ b/ssh/bpp.h @@ -138,12 +138,14 @@ void ssh2_bpp_new_outgoing_crypto( BinaryPacketProtocol *bpp, const ssh_cipheralg *cipher, const void *ckey, const void *iv, const ssh2_macalg *mac, bool etm_mode, const void *mac_key, - const ssh_compression_alg *compression, bool delayed_compression); + const ssh_compression_alg *compression, bool delayed_compression, + bool reset_sequence_number); void ssh2_bpp_new_incoming_crypto( BinaryPacketProtocol *bpp, const ssh_cipheralg *cipher, const void *ckey, const void *iv, const ssh2_macalg *mac, bool etm_mode, const void *mac_key, - const ssh_compression_alg *compression, bool delayed_compression); + const ssh_compression_alg *compression, bool delayed_compression, + bool reset_sequence_number); /* * A query method specific to the interface between ssh2transport and diff --git a/ssh/bpp2.c b/ssh/bpp2.c index e019dd2e..88003e82 100644 --- a/ssh/bpp2.c +++ b/ssh/bpp2.c @@ -106,7 +106,8 @@ void ssh2_bpp_new_outgoing_crypto( BinaryPacketProtocol *bpp, const ssh_cipheralg *cipher, const void *ckey, const void *iv, const ssh2_macalg *mac, bool etm_mode, const void *mac_key, - const ssh_compression_alg *compression, bool delayed_compression) + const ssh_compression_alg *compression, bool delayed_compression, + bool reset_sequence_number) { struct ssh2_bpp_state *s; assert(bpp->vt == &ssh2_bpp_vtable); @@ -150,6 +151,9 @@ void ssh2_bpp_new_outgoing_crypto( s->out.mac = NULL; } + if (reset_sequence_number) + s->out.sequence = 0; + if (delayed_compression && !s->seen_userauth_success) { s->out.pending_compression = compression; s->out_comp = NULL; @@ -174,7 +178,8 @@ void ssh2_bpp_new_incoming_crypto( BinaryPacketProtocol *bpp, const ssh_cipheralg *cipher, const void *ckey, const void *iv, const ssh2_macalg *mac, bool etm_mode, const void *mac_key, - const ssh_compression_alg *compression, bool delayed_compression) + const ssh_compression_alg *compression, bool delayed_compression, + bool reset_sequence_number) { struct ssh2_bpp_state *s; assert(bpp->vt == &ssh2_bpp_vtable); @@ -231,6 +236,9 @@ void ssh2_bpp_new_incoming_crypto( * start consuming the input data again. */ s->pending_newkeys = false; + if (reset_sequence_number) + s->in.sequence = 0; + /* And schedule a run of handle_input, in case there's already * input data in the queue. */ queue_idempotent_callback(&s->bpp.ic_in_raw); diff --git a/ssh/transport2.c b/ssh/transport2.c index 053ca570..e60c1b30 100644 --- a/ssh/transport2.c +++ b/ssh/transport2.c @@ -29,6 +29,10 @@ const static ssh2_macalg *const buggymacs[] = { const static ptrlen ext_info_c = PTRLEN_DECL_LITERAL("ext-info-c"); const static ptrlen ext_info_s = PTRLEN_DECL_LITERAL("ext-info-s"); +const static ptrlen kex_strict_c = + PTRLEN_DECL_LITERAL("kex-strict-c-v00@openssh.com"); +const static ptrlen kex_strict_s = + PTRLEN_DECL_LITERAL("kex-strict-s-v00@openssh.com"); static ssh_compressor *ssh_comp_none_init(void) { @@ -465,6 +469,31 @@ static bool ssh2_transport_filter_queue(struct ssh2_transport_state *s) { PktIn *pktin; + if (!s->enabled_incoming_crypto) { + /* + * Record the fact that we've seen any non-KEXINIT packet at + * the head of our queue. + * + * This enables us to check later that the initial incoming + * KEXINIT was the very first packet, if scanning the KEXINITs + * turns out to enable strict-kex mode. + */ + PktIn *pktin = pq_peek(s->ppl.in_pq); + if (pktin && pktin->type != SSH2_MSG_KEXINIT) + s->seen_non_kexinit = true; + + if (s->strict_kex) { + /* + * Also, if we're already in strict-KEX mode and haven't + * turned on crypto yet, don't do any actual filtering. + * This ensures that extraneous packets _after_ the + * KEXINIT will go to the main coroutine, which will + * complain about them. + */ + return false; + } + } + while (1) { if (ssh2_common_filter_queue(&s->ppl)) return true; @@ -940,10 +969,13 @@ static void ssh2_write_kexinit_lists( add_to_commasep_pl(list, kexlists[i].algs[j].name); } if (i == KEXLIST_KEX && first_time) { - if (our_hostkeys) /* we're the server */ + if (our_hostkeys) { /* we're the server */ add_to_commasep_pl(list, ext_info_s); - else /* we're the client */ + add_to_commasep_pl(list, kex_strict_s); + } else { /* we're the client */ add_to_commasep_pl(list, ext_info_c); + add_to_commasep_pl(list, kex_strict_c); + } } put_stringsb(pktout, list); } @@ -974,7 +1006,7 @@ static bool ssh2_scan_kexinits( bool *warn_kex, bool *warn_hk, bool *warn_cscipher, bool *warn_sccipher, Ssh *ssh, bool *ignore_guess_cs_packet, bool *ignore_guess_sc_packet, struct server_hostkeys *server_hostkeys, unsigned *hkflags, - bool *can_send_ext_info) + bool *can_send_ext_info, bool first_time, bool *strict_kex) { BinarySource client[1], server[1]; int i; @@ -1181,6 +1213,14 @@ static bool ssh2_scan_kexinits( we_are_server ? ext_info_c : ext_info_s)) *can_send_ext_info = true; + /* + * Check whether the other side advertised support for kex-strict. + */ + if (first_time && kexinit_keyword_found( + we_are_server ? clists[KEXLIST_KEX] : slists[KEXLIST_KEX], + we_are_server ? kex_strict_c : kex_strict_s)) + *strict_kex = true; + if (server_hostkeys) { /* * Finally, make an auxiliary pass over the server's host key @@ -1244,10 +1284,26 @@ static void filter_outgoing_kexinit(struct ssh2_transport_state *s) strbuf_clear(out); ptrlen olist = get_string(osrc), ilist = get_string(isrc); for (ptrlen oword; get_commasep_word(&olist, &oword) ;) { + ptrlen searchword = oword; ptrlen ilist_copy = ilist; + + /* + * Special case: the kex_strict keywords are + * asymmetrically named, so if we're contemplating + * including one of them in our filtered KEXINIT, we + * should search the other side's KEXINIT for the _other_ + * one, not the same one. + */ + if (i == KEXLIST_KEX) { + if (ptrlen_eq_ptrlen(oword, kex_strict_c)) + searchword = kex_strict_s; + else if (ptrlen_eq_ptrlen(oword, kex_strict_s)) + searchword = kex_strict_c; + } + bool add = false; for (ptrlen iword; get_commasep_word(&ilist_copy, &iword) ;) { - if (ptrlen_eq_ptrlen(oword, iword)) { + if (ptrlen_eq_ptrlen(searchword, iword)) { /* Found this word in the incoming list. */ add = true; break; @@ -1472,11 +1528,25 @@ static void ssh2_transport_process_queue(PacketProtocolLayer *ppl) s->kexlists, &s->kex_alg, &s->hostkey_alg, s->cstrans, s->sctrans, &s->warn_kex, &s->warn_hk, &s->warn_cscipher, &s->warn_sccipher, s->ppl.ssh, NULL, &s->ignorepkt, &hks, - &s->hkflags, &s->can_send_ext_info)) { + &s->hkflags, &s->can_send_ext_info, !s->got_session_id, + &s->strict_kex)) { sfree(hks.indices); return; /* false means a fatal error function was called */ } + /* + * If we've just turned on strict kex mode, say so, and + * retrospectively fault any pre-KEXINIT extraneous packets. + */ + if (!s->got_session_id && s->strict_kex) { + ppl_logevent("Enabling strict key exchange semantics"); + if (s->seen_non_kexinit) { + ssh_proto_error(s->ppl.ssh, "Received a packet before KEXINIT " + "in strict-kex mode"); + return; + } + } + /* * In addition to deciding which host key we're actually going * to use, we should make a list of the host keys offered by @@ -1669,7 +1739,9 @@ static void ssh2_transport_process_queue(PacketProtocolLayer *ppl) s->ppl.bpp, s->out.cipher, cipher_key->u, cipher_iv->u, s->out.mac, s->out.etm_mode, mac_key->u, - s->out.comp, s->out.comp_delayed); + s->out.comp, s->out.comp_delayed, + s->strict_kex); + s->enabled_outgoing_crypto = true; strbuf_free(cipher_key); strbuf_free(cipher_iv); @@ -1761,7 +1833,9 @@ static void ssh2_transport_process_queue(PacketProtocolLayer *ppl) s->ppl.bpp, s->in.cipher, cipher_key->u, cipher_iv->u, s->in.mac, s->in.etm_mode, mac_key->u, - s->in.comp, s->in.comp_delayed); + s->in.comp, s->in.comp_delayed, + s->strict_kex); + s->enabled_incoming_crypto = true; strbuf_free(cipher_key); strbuf_free(cipher_iv); diff --git a/ssh/transport2.h b/ssh/transport2.h index 204573fb..1322cf5b 100644 --- a/ssh/transport2.h +++ b/ssh/transport2.h @@ -202,6 +202,8 @@ struct ssh2_transport_state { bool warned_about_no_gss_transient_hostkey; bool got_session_id; bool can_send_ext_info, post_newkeys_ext_info; + bool strict_kex, enabled_outgoing_crypto, enabled_incoming_crypto; + bool seen_non_kexinit; SeatPromptResult spr; bool guessok; bool ignorepkt;