diff --git a/ssh/transport2.c b/ssh/transport2.c index 748187d2..0263673d 100644 --- a/ssh/transport2.c +++ b/ssh/transport2.c @@ -1005,13 +1005,27 @@ static bool kexinit_keyword_found(ptrlen list, ptrlen keyword) return false; } -static bool ssh2_scan_kexinits( +typedef struct ScanKexinitsResult { + bool success; + + /* only if success is false */ + enum { + SKR_INCOMPLETE, + SKR_UNKNOWN_ID, + SKR_NO_AGREEMENT, + } error; + + const char *kind; /* what kind of thing did we fail to sort out? */ + ptrlen desc; /* and what was it? or what was the available list? */ +} ScanKexinitsResult; + +static ScanKexinitsResult ssh2_scan_kexinits( ptrlen client_kexinit, ptrlen server_kexinit, bool we_are_server, struct kexinit_algorithm_list kexlists[NKEXLIST], const ssh_kex **kex_alg, const ssh_keyalg **hostkey_alg, transport_direction *cs, transport_direction *sc, bool *warn_kex, bool *warn_hk, bool *warn_cscipher, bool *warn_sccipher, - Ssh *ssh, bool *ignore_guess_cs_packet, bool *ignore_guess_sc_packet, + bool *ignore_guess_cs_packet, bool *ignore_guess_sc_packet, struct server_hostkeys *server_hostkeys, unsigned *hkflags, bool *can_send_ext_info, bool first_time, bool *strict_kex) { @@ -1040,11 +1054,10 @@ static bool ssh2_scan_kexinits( clists[i] = get_string(client); slists[i] = get_string(server); if (get_err(client) || get_err(server)) { - /* Report a better error than the spurious "Couldn't - * agree" that we'd generate if we pressed on regardless - * and treated the empty get_string() result as genuine */ - ssh_proto_error(ssh, "KEXINIT packet was incomplete"); - return false; + ScanKexinitsResult skr = { + .success = false, .error = SKR_INCOMPLETE, + }; + return skr; } for (cfirst = true, clist = clists[i]; @@ -1092,10 +1105,11 @@ static bool ssh2_scan_kexinits( * produce a reasonably useful message instead of an * assertion failure. */ - ssh_sw_abort(ssh, "Selected %s \"%.*s\" does not correspond to " - "any supported algorithm", - kexlist_descr[i], PTRLEN_PRINTF(found)); - return false; + ScanKexinitsResult skr = { + .success = false, .error = SKR_UNKNOWN_ID, + .kind = kexlist_descr[i], .desc = found, + }; + return skr; } /* @@ -1150,9 +1164,11 @@ static bool ssh2_scan_kexinits( /* * Otherwise, any match failure _is_ a fatal error. */ - ssh_sw_abort(ssh, "Couldn't agree a %s (available: %.*s)", - kexlist_descr[i], PTRLEN_PRINTF(slists[i])); - return false; + ScanKexinitsResult skr = { + .success = false, .error = SKR_UNKNOWN_ID, + .kind = kexlist_descr[i], .desc = slists[i], + }; + return skr; } switch (i) { @@ -1248,7 +1264,33 @@ static bool ssh2_scan_kexinits( } } - return true; + ScanKexinitsResult skr = { .success = true }; + return skr; +} + +static void ssh2_report_scan_kexinits_error(Ssh *ssh, ScanKexinitsResult skr) +{ + assert(!skr.success); + + switch (skr.error) { + case SKR_INCOMPLETE: + /* Report a better error than the spurious "Couldn't + * agree" that we'd generate if we pressed on regardless + * and treated the empty get_string() result as genuine */ + ssh_proto_error(ssh, "KEXINIT packet was incomplete"); + break; + case SKR_UNKNOWN_ID: + ssh_sw_abort(ssh, "Selected %s \"%.*s\" does not correspond to " + "any supported algorithm", + skr.kind, PTRLEN_PRINTF(skr.desc)); + break; + case SKR_NO_AGREEMENT: + ssh_sw_abort(ssh, "Couldn't agree a %s (available: %.*s)", + skr.kind, PTRLEN_PRINTF(skr.desc)); + break; + default: + unreachable("bad ScanKexinitsResult"); + } } static inline bool delay_outgoing_kexinit(struct ssh2_transport_state *s) @@ -1529,16 +1571,19 @@ static void ssh2_transport_process_queue(PacketProtocolLayer *ppl) { struct server_hostkeys hks = { NULL, 0, 0 }; - if (!ssh2_scan_kexinits( + ScanKexinitsResult skr = ssh2_scan_kexinits( ptrlen_from_strbuf(s->client_kexinit), ptrlen_from_strbuf(s->server_kexinit), s->ssc != NULL, s->kexlists, &s->kex_alg, &s->hostkey_alg, s->cstrans, s->sctrans, &s->warn_kex, &s->warn_hk, &s->warn_cscipher, - &s->warn_sccipher, s->ppl.ssh, NULL, &s->ignorepkt, &hks, + &s->warn_sccipher, NULL, &s->ignorepkt, &hks, &s->hkflags, &s->can_send_ext_info, !s->got_session_id, - &s->strict_kex)) { + &s->strict_kex); + + if (!skr.success) { sfree(hks.indices); - return; /* false means a fatal error function was called */ + ssh2_report_scan_kexinits_error(s->ppl.ssh, skr); + return; /* we just called a fatal error function */ } /*