diff --git a/ssh.c b/ssh.c index b22c2179..d0d7e066 100644 --- a/ssh.c +++ b/ssh.c @@ -969,6 +969,9 @@ struct ssh_tag { struct PacketQueue pq_ssh1_login; struct IdempotentCallback ssh1_login_icb; + struct PacketQueue pq_ssh1_connection; + struct IdempotentCallback ssh1_connection_icb; + struct PacketQueue pq_ssh2_authconn; struct IdempotentCallback ssh2_authconn_icb; @@ -6023,8 +6026,11 @@ int ssh_agent_forwarding_permitted(Ssh ssh) return conf_get_int(ssh->conf, CONF_agentfwd) && agent_exists(); } -static void do_ssh1_connection(Ssh ssh, struct Packet *pktin) +static void do_ssh1_connection(void *vctx) { + Ssh ssh = (Ssh)vctx; + struct Packet *pktin; + crBegin(ssh->do_ssh1_connection_crstate); ssh->packet_dispatch[SSH1_SMSG_STDOUT_DATA] = @@ -6044,9 +6050,7 @@ static void do_ssh1_connection(Ssh ssh, struct Packet *pktin) if (ssh_agent_forwarding_permitted(ssh)) { logevent("Requesting agent forwarding"); send_packet(ssh, SSH1_CMSG_AGENT_REQUEST_FORWARDING, PKT_END); - do { - crReturnV; - } while (!pktin); + crMaybeWaitUntilV((pktin = pq_pop(&ssh->pq_ssh1_connection)) != NULL); if (pktin->type != SSH1_SMSG_SUCCESS && pktin->type != SSH1_SMSG_FAILURE) { bombout(("Protocol confusion")); @@ -6086,9 +6090,8 @@ static void do_ssh1_connection(Ssh ssh, struct Packet *pktin) PKT_STR, ssh->x11auth->datastring, PKT_END); } - do { - crReturnV; - } while (!pktin); + crMaybeWaitUntilV((pktin = pq_pop(&ssh->pq_ssh1_connection)) + != NULL); if (pktin->type != SSH1_SMSG_SUCCESS && pktin->type != SSH1_SMSG_FAILURE) { bombout(("Protocol confusion")); @@ -6127,9 +6130,7 @@ static void do_ssh1_connection(Ssh ssh, struct Packet *pktin) ssh_pkt_addbyte(pkt, SSH_TTY_OP_END); s_wrpkt(ssh, pkt); ssh->state = SSH_STATE_INTERMED; - do { - crReturnV; - } while (!pktin); + crMaybeWaitUntilV((pktin = pq_pop(&ssh->pq_ssh1_connection)) != NULL); if (pktin->type != SSH1_SMSG_SUCCESS && pktin->type != SSH1_SMSG_FAILURE) { bombout(("Protocol confusion")); @@ -6148,9 +6149,7 @@ static void do_ssh1_connection(Ssh ssh, struct Packet *pktin) if (conf_get_int(ssh->conf, CONF_compression)) { send_packet(ssh, SSH1_CMSG_REQUEST_COMPRESSION, PKT_INT, 6, PKT_END); - do { - crReturnV; - } while (!pktin); + crMaybeWaitUntilV((pktin = pq_pop(&ssh->pq_ssh1_connection)) != NULL); if (pktin->type != SSH1_SMSG_SUCCESS && pktin->type != SSH1_SMSG_FAILURE) { bombout(("Protocol confusion")); @@ -6206,8 +6205,7 @@ static void do_ssh1_connection(Ssh ssh, struct Packet *pktin) * attention to the unusual ones. */ - crReturnV; - if (pktin) { + while ((pktin = pq_pop(&ssh->pq_ssh1_connection)) != NULL) { if (pktin->type == SSH1_SMSG_SUCCESS) { /* may be from EXEC_SHELL on some servers */ } else if (pktin->type == SSH1_SMSG_FAILURE) { @@ -6217,19 +6215,19 @@ static void do_ssh1_connection(Ssh ssh, struct Packet *pktin) bombout(("Strange packet received: type %d", pktin->type)); crStopV; } - } else { - while (bufchain_size(&ssh->user_input) > 0) { - void *data; - int len; - bufchain_prefix(&ssh->user_input, &data, &len); - if (len > 512) - len = 512; - send_packet(ssh, SSH1_CMSG_STDIN_DATA, - PKT_INT, len, PKT_DATA, data, len, - PKT_END); - bufchain_consume(&ssh->user_input, len); - } } + while (bufchain_size(&ssh->user_input) > 0) { + void *data; + int len; + bufchain_prefix(&ssh->user_input, &data, &len); + if (len > 512) + len = 512; + send_packet(ssh, SSH1_CMSG_STDIN_DATA, + PKT_INT, len, PKT_DATA, data, len, + PKT_END); + bufchain_consume(&ssh->user_input, len); + } + crReturnV; } crFinishV; @@ -6270,7 +6268,7 @@ static void ssh1_login_input(Ssh ssh) static void ssh1_connection_input(Ssh ssh) { - do_ssh1_connection(ssh, NULL); + do_ssh1_connection(ssh); } static void ssh1_coro_wrapper_initial(Ssh ssh, struct Packet *pktin) @@ -6282,7 +6280,9 @@ static void ssh1_coro_wrapper_initial(Ssh ssh, struct Packet *pktin) static void ssh1_coro_wrapper_session(Ssh ssh, struct Packet *pktin) { - do_ssh1_connection(ssh, pktin); + pktin->refcount++; /* avoid packet being freed when we return */ + pq_push(&ssh->pq_ssh1_connection, pktin); + queue_idempotent_callback(&ssh->ssh1_connection_icb); } static void ssh1_protocol_setup(Ssh ssh) @@ -12302,6 +12302,10 @@ static const char *ssh_init(void *frontend_handle, void **backend_handle, ssh->ssh1_login_icb.fn = do_ssh1_login; ssh->ssh1_login_icb.ctx = ssh; ssh->ssh1_login_icb.queued = FALSE; + pq_init(&ssh->pq_ssh1_connection); + ssh->ssh1_connection_icb.fn = do_ssh1_connection; + ssh->ssh1_connection_icb.ctx = ssh; + ssh->ssh1_connection_icb.queued = FALSE; pq_init(&ssh->pq_ssh2_authconn); ssh->ssh2_authconn_icb.fn = do_ssh2_authconn; ssh->ssh2_authconn_icb.ctx = ssh; @@ -12485,6 +12489,7 @@ static void ssh_free(void *handle) sfree(ssh->incoming_data_eof_message); pq_clear(&ssh->pq_full); pq_clear(&ssh->pq_ssh1_login); + pq_clear(&ssh->pq_ssh1_connection); pq_clear(&ssh->pq_ssh2_authconn); bufchain_clear(&ssh->user_input); sfree(ssh->v_c);