From 3c149087e4d33cce70e3e856b9f43a58d242e5c4 Mon Sep 17 00:00:00 2001 From: Ben Harris Date: Wed, 3 Oct 2007 20:29:27 +0000 Subject: [PATCH] Take the code that does flow control in SSH-1, and make it work in SSH-2 as well. This won't be triggered in the usual case, but it's useful if the remote end ignores our window, or if we're in "simple" mode and setting the window far larger than is necessary. [originally from svn r7756] --- ssh.c | 86 ++++++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 59 insertions(+), 27 deletions(-) diff --git a/ssh.c b/ssh.c index fb24f435..87fabfbd 100644 --- a/ssh.c +++ b/ssh.c @@ -560,10 +560,12 @@ struct ssh_channel { * A channel is completely finished with when all four bits are set. */ int closes; + /* + * True if this channel is causing the underlying connection to be + * throttled. + */ + int throttling_conn; union { - struct ssh1_data_channel { - int throttling; - } v1; struct ssh2_data_channel { bufchain outbuffer; unsigned remwindow, remmaxpkt; @@ -810,7 +812,7 @@ struct ssh_tag { void *x11auth; int version; - int v1_throttle_count; + int conn_throttle_count; int overall_bufsize; int throttled_all; int v1_stdout_throttling; @@ -2873,14 +2875,14 @@ static const char *connect_to_host(Ssh ssh, char *host, int port, /* * Throttle or unthrottle the SSH connection. */ -static void ssh1_throttle(Ssh ssh, int adjust) +static void ssh_throttle_conn(Ssh ssh, int adjust) { - int old_count = ssh->v1_throttle_count; - ssh->v1_throttle_count += adjust; - assert(ssh->v1_throttle_count >= 0); - if (ssh->v1_throttle_count && !old_count) { + int old_count = ssh->conn_throttle_count; + ssh->conn_throttle_count += adjust; + assert(ssh->conn_throttle_count >= 0); + if (ssh->conn_throttle_count && !old_count) { ssh_set_frozen(ssh, 1); - } else if (!ssh->v1_throttle_count && old_count) { + } else if (!ssh->conn_throttle_count && old_count) { ssh_set_frozen(ssh, 0); } } @@ -4054,17 +4056,20 @@ int sshfwd_write(struct ssh_channel *c, char *buf, int len) void sshfwd_unthrottle(struct ssh_channel *c, int bufsize) { Ssh ssh = c->ssh; + int buflimit; if (ssh->state == SSH_STATE_CLOSED) return; if (ssh->version == 1) { - if (c->v.v1.throttling && bufsize < SSH1_BUFFER_LIMIT) { - c->v.v1.throttling = 0; - ssh1_throttle(ssh, -1); - } + buflimit = SSH1_BUFFER_LIMIT; } else { - ssh2_set_window(c, c->v.v2.locmaxwin - bufsize); + buflimit = c->v.v2.locmaxwin; + ssh2_set_window(c, bufsize < buflimit ? buflimit - bufsize : 0); + } + if (c->throttling_conn && bufsize <= buflimit) { + c->throttling_conn = 0; + ssh_throttle_conn(ssh, -1); } } @@ -4493,7 +4498,7 @@ static void ssh1_smsg_stdout_stderr_data(Ssh ssh, struct Packet *pktin) string, stringlen); if (!ssh->v1_stdout_throttling && bufsize > SSH1_BUFFER_LIMIT) { ssh->v1_stdout_throttling = 1; - ssh1_throttle(ssh, +1); + ssh_throttle_conn(ssh, +1); } } @@ -4527,7 +4532,7 @@ static void ssh1_smsg_x11_open(Ssh ssh, struct Packet *pktin) c->halfopen = FALSE; c->localid = alloc_channel_id(ssh); c->closes = 0; - c->v.v1.throttling = 0; + c->throttling_conn = 0; c->type = CHAN_X11; /* identify channel type */ add234(ssh->channels, c); send_packet(ssh, SSH1_MSG_CHANNEL_OPEN_CONFIRMATION, @@ -4556,7 +4561,7 @@ static void ssh1_smsg_agent_open(Ssh ssh, struct Packet *pktin) c->halfopen = FALSE; c->localid = alloc_channel_id(ssh); c->closes = 0; - c->v.v1.throttling = 0; + c->throttling_conn = 0; c->type = CHAN_AGENT; /* identify channel type */ c->u.a.lensofar = 0; add234(ssh->channels, c); @@ -4610,7 +4615,7 @@ static void ssh1_msg_port_open(Ssh ssh, struct Packet *pktin) c->halfopen = FALSE; c->localid = alloc_channel_id(ssh); c->closes = 0; - c->v.v1.throttling = 0; + c->throttling_conn = 0; c->type = CHAN_SOCKDATA; /* identify channel type */ add234(ssh->channels, c); send_packet(ssh, SSH1_MSG_CHANNEL_OPEN_CONFIRMATION, @@ -4632,7 +4637,7 @@ static void ssh1_msg_channel_open_confirmation(Ssh ssh, struct Packet *pktin) c->remoteid = localid; c->halfopen = FALSE; c->type = CHAN_SOCKDATA; - c->v.v1.throttling = 0; + c->throttling_conn = 0; pfd_confirm(c->u.pfd.s); } @@ -4768,9 +4773,9 @@ static void ssh1_msg_channel_data(Ssh ssh, struct Packet *pktin) bufsize = 0; /* agent channels never back up */ break; } - if (!c->v.v1.throttling && bufsize > SSH1_BUFFER_LIMIT) { - c->v.v1.throttling = 1; - ssh1_throttle(ssh, +1); + if (!c->throttling_conn && bufsize > SSH1_BUFFER_LIMIT) { + c->throttling_conn = 1; + ssh_throttle_conn(ssh, +1); } } } @@ -6458,6 +6463,17 @@ static void ssh2_msg_channel_data(Ssh ssh, struct Packet *pktin) */ ssh2_set_window(c, bufsize < c->v.v2.locmaxwin ? c->v.v2.locmaxwin - bufsize : 0); + /* + * If we're either buffering way too much data, or if we're + * buffering anything at all and we're in "simple" mode, + * throttle the whole channel. + */ + if ((bufsize > c->v.v2.locmaxwin || + (ssh->cfg.ssh_simple && bufsize > 0)) && + !c->throttling_conn) { + c->throttling_conn = 1; + ssh_throttle_conn(ssh, +1); + } } } @@ -6905,6 +6921,7 @@ static void ssh2_msg_channel_open(Ssh ssh, struct Packet *pktin) } else { c->localid = alloc_channel_id(ssh); c->closes = 0; + c->throttling_conn = FALSE; c->v.v2.locwindow = c->v.v2.locmaxwin = OUR_V2_WINSIZE; c->v.v2.remwindow = winsize; c->v.v2.remmaxpkt = pktsize; @@ -8119,6 +8136,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen, s->pktout = ssh2_pkt_init(SSH2_MSG_CHANNEL_OPEN); ssh2_pkt_addstring(s->pktout, "direct-tcpip"); ssh2_pkt_adduint32(s->pktout, ssh->mainchan->localid); + ssh->mainchan->throttling_conn = FALSE; ssh->mainchan->v.v2.locwindow = ssh->mainchan->v.v2.locmaxwin = ssh->mainchan->v.v2.remlocwin = ssh->cfg.ssh_simple ? OUR_V2_BIGWIN : OUR_V2_WINSIZE; @@ -8166,6 +8184,7 @@ static void do_ssh2_authconn(Ssh ssh, unsigned char *in, int inlen, s->pktout = ssh2_pkt_init(SSH2_MSG_CHANNEL_OPEN); ssh2_pkt_addstring(s->pktout, "session"); ssh2_pkt_adduint32(s->pktout, ssh->mainchan->localid); + ssh->mainchan->throttling_conn = FALSE; ssh->mainchan->v.v2.locwindow = ssh->mainchan->v.v2.locmaxwin = ssh->mainchan->v.v2.remlocwin = ssh->cfg.ssh_simple ? OUR_V2_BIGWIN : OUR_V2_WINSIZE; @@ -8792,7 +8811,7 @@ static const char *ssh_init(void *frontend_handle, void **backend_handle, ssh->send_ok = 0; ssh->editing = 0; ssh->echoing = 0; - ssh->v1_throttle_count = 0; + ssh->conn_throttle_count = 0; ssh->overall_bufsize = 0; ssh->fallback_cmd = 0; @@ -9239,15 +9258,27 @@ void *new_sock_channel(void *handle, Socket s) static void ssh_unthrottle(void *handle, int bufsize) { Ssh ssh = (Ssh) handle; + int buflimit; + if (ssh->version == 1) { if (ssh->v1_stdout_throttling && bufsize < SSH1_BUFFER_LIMIT) { ssh->v1_stdout_throttling = 0; - ssh1_throttle(ssh, -1); + ssh_throttle_conn(ssh, -1); } } else { - if (ssh->mainchan) + if (ssh->mainchan) { ssh2_set_window(ssh->mainchan, - ssh->mainchan->v.v2.locmaxwin - bufsize); + bufsize < ssh->mainchan->v.v2.locmaxwin ? + ssh->mainchan->v.v2.locmaxwin - bufsize : 0); + if (ssh->cfg.ssh_simple) + buflimit = 0; + else + buflimit = ssh->mainchan->v.v2.locmaxwin; + if (ssh->mainchan->throttling_conn && bufsize <= buflimit) { + ssh->mainchan->throttling_conn = 0; + ssh_throttle_conn(ssh, -1); + } + } } } @@ -9270,6 +9301,7 @@ void ssh_send_port_open(void *channel, char *hostname, int port, char *org) pktout = ssh2_pkt_init(SSH2_MSG_CHANNEL_OPEN); ssh2_pkt_addstring(pktout, "direct-tcpip"); ssh2_pkt_adduint32(pktout, c->localid); + c->throttling_conn = FALSE; c->v.v2.locwindow = c->v.v2.locmaxwin = OUR_V2_WINSIZE; c->v.v2.remlocwin = OUR_V2_WINSIZE; c->v.v2.winadj_head = c->v.v2.winadj_head = NULL;