diff --git a/ssh.c b/ssh.c index b4ad1a76..605c8e8a 100644 --- a/ssh.c +++ b/ssh.c @@ -1429,19 +1429,8 @@ static void ssh_disconnect(Ssh ssh, const char *client_reason, error = dupprintf("Disconnected: %s", client_reason); else error = dupstr("Disconnected"); - if (wire_reason) { - if (ssh->version == 1) { - PktOut *pktout = ssh_bpp_new_pktout(ssh->bpp, SSH1_MSG_DISCONNECT); - put_stringz(pktout, wire_reason); - ssh_pkt_write(ssh, pktout); - } else if (ssh->version == 2) { - PktOut *pktout = ssh_bpp_new_pktout(ssh->bpp, SSH2_MSG_DISCONNECT); - put_uint32(pktout, code); - put_stringz(pktout, wire_reason); - put_stringz(pktout, "en"); /* language tag */ - ssh_pkt_write(ssh, pktout); - } - } + if (wire_reason) + ssh_bpp_queue_disconnect(ssh->bpp, wire_reason, code); ssh->close_expected = TRUE; ssh->clean_exit = clean_exit; ssh_closing(&ssh->plugvt, error, 0, 0); diff --git a/ssh1bpp.c b/ssh1bpp.c index 24eaeab4..36e03b5d 100644 --- a/ssh1bpp.c +++ b/ssh1bpp.c @@ -31,6 +31,8 @@ struct ssh1_bpp_state { static void ssh1_bpp_free(BinaryPacketProtocol *bpp); static void ssh1_bpp_handle_input(BinaryPacketProtocol *bpp); static void ssh1_bpp_handle_output(BinaryPacketProtocol *bpp); +static void ssh1_bpp_queue_disconnect(BinaryPacketProtocol *bpp, + const char *msg, int category); static PktOut *ssh1_bpp_new_pktout(int type); static const struct BinaryPacketProtocolVtable ssh1_bpp_vtable = { @@ -38,6 +40,7 @@ static const struct BinaryPacketProtocolVtable ssh1_bpp_vtable = { ssh1_bpp_handle_input, ssh1_bpp_handle_output, ssh1_bpp_new_pktout, + ssh1_bpp_queue_disconnect, }; BinaryPacketProtocol *ssh1_bpp_new(void) @@ -302,3 +305,11 @@ static void ssh1_bpp_handle_output(BinaryPacketProtocol *bpp) ssh_free_pktout(pkt); } } + +static void ssh1_bpp_queue_disconnect(BinaryPacketProtocol *bpp, + const char *msg, int category) +{ + PktOut *pkt = ssh_bpp_new_pktout(bpp, SSH1_MSG_DISCONNECT); + put_stringz(pkt, msg); + pq_push(&bpp->out_pq, pkt); +} diff --git a/ssh2bpp-bare.c b/ssh2bpp-bare.c index 1c838303..5b58a3c7 100644 --- a/ssh2bpp-bare.c +++ b/ssh2bpp-bare.c @@ -30,6 +30,7 @@ static const struct BinaryPacketProtocolVtable ssh2_bare_bpp_vtable = { ssh2_bare_bpp_handle_input, ssh2_bare_bpp_handle_output, ssh2_bare_bpp_new_pktout, + ssh2_bpp_queue_disconnect, /* in sshcommon.c */ }; BinaryPacketProtocol *ssh2_bare_bpp_new(void) diff --git a/ssh2bpp.c b/ssh2bpp.c index 1bc27340..b476f33b 100644 --- a/ssh2bpp.c +++ b/ssh2bpp.c @@ -48,6 +48,7 @@ static const struct BinaryPacketProtocolVtable ssh2_bpp_vtable = { ssh2_bpp_handle_input, ssh2_bpp_handle_output, ssh2_bpp_new_pktout, + ssh2_bpp_queue_disconnect, /* in sshcommon.c */ }; BinaryPacketProtocol *ssh2_bpp_new(struct DataTransferStats *stats) diff --git a/sshbpp.h b/sshbpp.h index c3f26728..b3eb9985 100644 --- a/sshbpp.h +++ b/sshbpp.h @@ -10,6 +10,8 @@ struct BinaryPacketProtocolVtable { void (*handle_input)(BinaryPacketProtocol *); void (*handle_output)(BinaryPacketProtocol *); PktOut *(*new_pktout)(int type); + void (*queue_disconnect)(BinaryPacketProtocol *, + const char *msg, int category); }; struct BinaryPacketProtocol { @@ -39,6 +41,8 @@ struct BinaryPacketProtocol { #define ssh_bpp_handle_input(bpp) ((bpp)->vt->handle_input(bpp)) #define ssh_bpp_handle_output(bpp) ((bpp)->vt->handle_output(bpp)) #define ssh_bpp_new_pktout(bpp, type) ((bpp)->vt->new_pktout(type)) +#define ssh_bpp_queue_disconnect(bpp, msg, cat) \ + ((bpp)->vt->queue_disconnect(bpp, msg, cat)) /* ssh_bpp_free is more than just a macro wrapper on the vtable; it * does centralised parts of the freeing too. */ @@ -58,7 +62,9 @@ void ssh1_bpp_requested_compression(BinaryPacketProtocol *bpp); * up in_pq and out_pq, and initialising input_consumer. */ void ssh_bpp_common_setup(BinaryPacketProtocol *); -/* Common helper function between the SSH-2 full and bare BPPs */ +/* Common helper functions between the SSH-2 full and bare BPPs */ +void ssh2_bpp_queue_disconnect(BinaryPacketProtocol *bpp, + const char *msg, int category); int ssh2_bpp_check_unimplemented(BinaryPacketProtocol *bpp, PktIn *pktin); /* diff --git a/sshcommon.c b/sshcommon.c index 088e1c88..5a2270b4 100644 --- a/sshcommon.c +++ b/sshcommon.c @@ -664,6 +664,16 @@ void ssh_bpp_free(BinaryPacketProtocol *bpp) bpp->vt->free(bpp); } +void ssh2_bpp_queue_disconnect(BinaryPacketProtocol *bpp, + const char *msg, int category) +{ + PktOut *pkt = ssh_bpp_new_pktout(bpp, SSH2_MSG_DISCONNECT); + put_uint32(pkt, category); + put_stringz(pkt, msg); + put_stringz(pkt, "en"); /* language tag */ + pq_push(&bpp->out_pq, pkt); +} + #define BITMAP_UNIVERSAL(y, name, value) \ | (value >= y && value < y+32 ? 1UL << (value-y) : 0) #define BITMAP_CONDITIONAL(y, name, value, ctx) \ diff --git a/sshverstring.c b/sshverstring.c index f02b0e9b..ffead60c 100644 --- a/sshverstring.c +++ b/sshverstring.c @@ -43,12 +43,15 @@ static void ssh_verstring_free(BinaryPacketProtocol *bpp); static void ssh_verstring_handle_input(BinaryPacketProtocol *bpp); static void ssh_verstring_handle_output(BinaryPacketProtocol *bpp); static PktOut *ssh_verstring_new_pktout(int type); +static void ssh_verstring_queue_disconnect(BinaryPacketProtocol *bpp, + const char *msg, int category); static const struct BinaryPacketProtocolVtable ssh_verstring_vtable = { ssh_verstring_free, ssh_verstring_handle_input, ssh_verstring_handle_output, ssh_verstring_new_pktout, + ssh_verstring_queue_disconnect, }; static void ssh_detect_bugs(struct ssh_verstring_state *s); @@ -608,3 +611,9 @@ int ssh_verstring_get_bugs(BinaryPacketProtocol *bpp) FROMFIELD(bpp, struct ssh_verstring_state, bpp); return s->remote_bugs; } + +static void ssh_verstring_queue_disconnect(BinaryPacketProtocol *bpp, + const char *msg, int category) +{ + /* No way to send disconnect messages at this stage of the protocol! */ +}