diff --git a/ssh.c b/ssh.c index a790f717..8d35f447 100644 --- a/ssh.c +++ b/ssh.c @@ -365,6 +365,7 @@ static int do_ssh1_login(Ssh ssh, const unsigned char *in, int inlen, static void do_ssh2_authconn(Ssh ssh, const unsigned char *in, int inlen, struct Packet *pktin); static void ssh_channel_init(struct ssh_channel *c); +static struct ssh_channel *ssh_channel_msg(Ssh ssh, struct Packet *pktin); static void ssh2_channel_check_close(struct ssh_channel *c); static void ssh_channel_destroy(struct ssh_channel *c); static void ssh_channel_unthrottle(struct ssh_channel *c, int bufsize); @@ -985,6 +986,14 @@ struct ssh_tag { int cross_certifying; }; +static const char *ssh_pkt_type(Ssh ssh, int type) +{ + if (ssh->version == 1) + return ssh1_pkt_type(type); + else + return ssh2_pkt_type(ssh->pkt_kctx, ssh->pkt_actx, type); +} + #define logevent(s) logevent(ssh->frontend, s) /* logevent, only printf-formatted. */ @@ -7971,19 +7980,23 @@ static void ssh2_set_window(struct ssh_channel *c, int newwin) * Find the channel associated with a message. If there's no channel, * or it's not properly open, make a noise about it and return NULL. */ -static struct ssh_channel *ssh2_channel_msg(Ssh ssh, struct Packet *pktin) +static struct ssh_channel *ssh_channel_msg(Ssh ssh, struct Packet *pktin) { unsigned localid = ssh_pkt_getuint32(pktin); struct ssh_channel *c; + int halfopen_ok; + /* Is this message OK on a half-open connection? */ + if (ssh->version == 1) + halfopen_ok = (pktin->type == SSH1_MSG_CHANNEL_OPEN_CONFIRMATION || + pktin->type == SSH1_MSG_CHANNEL_OPEN_FAILURE); + else + halfopen_ok = (pktin->type == SSH2_MSG_CHANNEL_OPEN_CONFIRMATION || + pktin->type == SSH2_MSG_CHANNEL_OPEN_FAILURE); c = find234(ssh->channels, &localid, ssh_channelfind); - if (!c || - (c->type != CHAN_SHARING && c->halfopen && - pktin->type != SSH2_MSG_CHANNEL_OPEN_CONFIRMATION && - pktin->type != SSH2_MSG_CHANNEL_OPEN_FAILURE)) { + if (!c || (c->type != CHAN_SHARING && c->halfopen && !halfopen_ok)) { char *buf = dupprintf("Received %s for %s channel %u", - ssh2_pkt_type(ssh->pkt_kctx, ssh->pkt_actx, - pktin->type), + ssh_pkt_type(ssh, pktin->type), c ? "half-open" : "nonexistent", localid); ssh_disconnect(ssh, NULL, buf, SSH2_DISCONNECT_PROTOCOL_ERROR, FALSE); sfree(buf); @@ -8018,7 +8031,7 @@ static void ssh2_handle_winadj_response(struct ssh_channel *c, static void ssh2_msg_channel_response(Ssh ssh, struct Packet *pktin) { - struct ssh_channel *c = ssh2_channel_msg(ssh, pktin); + struct ssh_channel *c = ssh_channel_msg(ssh, pktin); struct outstanding_channel_request *ocr; if (!c) return; @@ -8046,7 +8059,7 @@ static void ssh2_msg_channel_response(Ssh ssh, struct Packet *pktin) static void ssh2_msg_channel_window_adjust(Ssh ssh, struct Packet *pktin) { struct ssh_channel *c; - c = ssh2_channel_msg(ssh, pktin); + c = ssh_channel_msg(ssh, pktin); if (!c) return; if (c->type == CHAN_SHARING) { @@ -8065,7 +8078,7 @@ static void ssh2_msg_channel_data(Ssh ssh, struct Packet *pktin) char *data; int length; struct ssh_channel *c; - c = ssh2_channel_msg(ssh, pktin); + c = ssh_channel_msg(ssh, pktin); if (!c) return; if (c->type == CHAN_SHARING) { @@ -8283,7 +8296,7 @@ static void ssh2_msg_channel_eof(Ssh ssh, struct Packet *pktin) { struct ssh_channel *c; - c = ssh2_channel_msg(ssh, pktin); + c = ssh_channel_msg(ssh, pktin); if (!c) return; if (c->type == CHAN_SHARING) { @@ -8298,7 +8311,7 @@ static void ssh2_msg_channel_close(Ssh ssh, struct Packet *pktin) { struct ssh_channel *c; - c = ssh2_channel_msg(ssh, pktin); + c = ssh_channel_msg(ssh, pktin); if (!c) return; if (c->type == CHAN_SHARING) { @@ -8380,7 +8393,7 @@ static void ssh2_msg_channel_open_confirmation(Ssh ssh, struct Packet *pktin) { struct ssh_channel *c; - c = ssh2_channel_msg(ssh, pktin); + c = ssh_channel_msg(ssh, pktin); if (!c) return; if (c->type == CHAN_SHARING) { @@ -8388,7 +8401,7 @@ static void ssh2_msg_channel_open_confirmation(Ssh ssh, struct Packet *pktin) pktin->body, pktin->length); return; } - assert(c->halfopen); /* ssh2_channel_msg will have enforced this */ + assert(c->halfopen); /* ssh_channel_msg will have enforced this */ c->remoteid = ssh_pkt_getuint32(pktin); c->halfopen = FALSE; c->v.v2.remwindow = ssh_pkt_getuint32(pktin); @@ -8440,7 +8453,7 @@ static void ssh2_msg_channel_open_failure(Ssh ssh, struct Packet *pktin) int reason_length; struct ssh_channel *c; - c = ssh2_channel_msg(ssh, pktin); + c = ssh_channel_msg(ssh, pktin); if (!c) return; if (c->type == CHAN_SHARING) { @@ -8448,7 +8461,7 @@ static void ssh2_msg_channel_open_failure(Ssh ssh, struct Packet *pktin) pktin->body, pktin->length); return; } - assert(c->halfopen); /* ssh2_channel_msg will have enforced this */ + assert(c->halfopen); /* ssh_channel_msg will have enforced this */ if (c->type == CHAN_SOCKDATA_DORMANT) { reason_code = ssh_pkt_getuint32(pktin); @@ -8493,7 +8506,7 @@ static void ssh2_msg_channel_request(Ssh ssh, struct Packet *pktin) struct ssh_channel *c; struct Packet *pktout; - c = ssh2_channel_msg(ssh, pktin); + c = ssh_channel_msg(ssh, pktin); if (!c) return; if (c->type == CHAN_SHARING) {