diff --git a/ssh.c b/ssh.c index 8bfd53eb..0300cad2 100644 --- a/ssh.c +++ b/ssh.c @@ -397,6 +397,8 @@ static void ssh_channel_destroy(struct ssh_channel *c); static void ssh_channel_unthrottle(struct ssh_channel *c, int bufsize); static void ssh2_msg_something_unimplemented(Ssh ssh, struct Packet *pktin); static void ssh2_general_packet_processing(Ssh ssh, struct Packet *pktin); +static void ssh1_login_input(Ssh ssh); +static void ssh2_authconn_input(Ssh ssh); /* * Buffer management constants. There are several of these for @@ -969,6 +971,9 @@ struct ssh_tag { struct PacketQueue pq_full; struct IdempotentCallback pq_full_consumer; + bufchain user_input; + struct IdempotentCallback user_input_consumer; + struct rdpkt1_state_tag rdpkt1_state; struct rdpkt2_state_tag rdpkt2_state; struct rdpkt2_bare_state_tag rdpkt2_bare_state; @@ -979,6 +984,7 @@ struct ssh_tag { void (*protocol) (Ssh ssh, const void *vin, int inlen); void (*general_packet_processing)(Ssh ssh, struct Packet *pkt); void (*current_incoming_data_fn) (Ssh ssh); + void (*current_user_input_fn) (Ssh ssh); /* * We maintain our own copy of a Conf structure here. That way, @@ -3349,6 +3355,7 @@ static void do_ssh_init(Ssh ssh) ssh2_protocol_setup(ssh); ssh->general_packet_processing = ssh2_general_packet_processing; ssh->current_incoming_data_fn = ssh2_rdpkt; + ssh->current_user_input_fn = NULL; } else { /* * Initialise SSH-1 protocol. @@ -3356,6 +3363,7 @@ static void do_ssh_init(Ssh ssh) ssh->protocol = ssh1_protocol; ssh1_protocol_setup(ssh); ssh->current_incoming_data_fn = ssh1_rdpkt; + ssh->current_user_input_fn = ssh1_login_input; } queue_idempotent_callback(&ssh->incoming_data_consumer); if (ssh->version == 2) @@ -3579,6 +3587,13 @@ static void ssh_process_pq_full(void *ctx) } } +static void ssh_process_user_input(void *ctx) +{ + Ssh ssh = (Ssh)ctx; + if (ssh->current_user_input_fn) + ssh->current_user_input_fn(ssh); +} + static int ssh_do_close(Ssh ssh, int notify_exit) { int ret = 0; @@ -6268,6 +6283,28 @@ static void ssh_msg_ignore(Ssh ssh, struct Packet *pktin) /* Do nothing, because we're ignoring it! Duhh. */ } +static void ssh1_login_input(Ssh ssh) +{ + while (bufchain_size(&ssh->user_input) > 0) { + void *data; + int len; + bufchain_prefix(&ssh->user_input, &data, &len); + do_ssh1_login(ssh, data, len, NULL); + bufchain_consume(&ssh->user_input, len); + } +} + +static void ssh1_connection_input(Ssh ssh) +{ + while (bufchain_size(&ssh->user_input) > 0) { + void *data; + int len; + bufchain_prefix(&ssh->user_input, &data, &len); + do_ssh1_connection(ssh, data, len, NULL); + bufchain_consume(&ssh->user_input, len); + } +} + static void ssh1_coro_wrapper_session(Ssh ssh, struct Packet *pktin); static void ssh1_coro_wrapper_initial(Ssh ssh, struct Packet *pktin) @@ -6277,6 +6314,7 @@ static void ssh1_coro_wrapper_initial(Ssh ssh, struct Packet *pktin) for (i = 0; i < 256; i++) if (ssh->packet_dispatch[i] == ssh1_coro_wrapper_initial) ssh->packet_dispatch[i] = ssh1_coro_wrapper_session; + ssh->current_user_input_fn = ssh1_connection_input; } } @@ -6310,10 +6348,12 @@ static void ssh1_protocol(Ssh ssh, const void *vin, int inlen) return; if (!ssh->protocol_initial_phase_done) { - if (do_ssh1_login(ssh, in, inlen, NULL)) + if (do_ssh1_login(ssh, in, inlen, NULL)) { ssh->protocol_initial_phase_done = TRUE; - else + ssh->current_user_input_fn = ssh1_connection_input; + } else { return; + } } do_ssh1_connection(ssh, in, inlen, NULL); @@ -8449,6 +8489,7 @@ static void do_ssh2_transport(Ssh ssh, const void *vin, int inlen, * Allow authconn to initialise itself. */ do_ssh2_authconn(ssh, NULL, 0, NULL); + ssh->current_user_input_fn = ssh2_authconn_input; } crReturnV; } @@ -12254,6 +12295,17 @@ static void ssh2_protocol(Ssh ssh, const void *vin, int inlen) do_ssh2_authconn(ssh, in, inlen, NULL); } +static void ssh2_authconn_input(Ssh ssh) +{ + while (bufchain_size(&ssh->user_input) > 0) { + void *data; + int len; + bufchain_prefix(&ssh->user_input, &data, &len); + do_ssh2_authconn(ssh, data, len, NULL); + bufchain_consume(&ssh->user_input, len); + } +} + static void ssh2_bare_connection_protocol(Ssh ssh, const void *vin, int inlen) { const unsigned char *in = (const unsigned char *)vin; @@ -12343,6 +12395,10 @@ static const char *ssh_init(void *frontend_handle, void **backend_handle, ssh->pq_full_consumer.fn = ssh_process_pq_full; ssh->pq_full_consumer.ctx = ssh; ssh->pq_full_consumer.queued = FALSE; + bufchain_init(&ssh->user_input); + ssh->user_input_consumer.fn = ssh_process_user_input; + ssh->user_input_consumer.ctx = ssh; + ssh->user_input_consumer.queued = FALSE; ssh->pending_newkeys = FALSE; ssh->v_c = NULL; ssh->v_s = NULL; @@ -12520,6 +12576,7 @@ static void ssh_free(void *handle) bufchain_clear(&ssh->incoming_data); sfree(ssh->incoming_data_eof_message); pq_clear(&ssh->pq_full); + bufchain_clear(&ssh->user_input); sfree(ssh->v_c); sfree(ssh->v_s); sfree(ssh->fullhostname); @@ -12626,7 +12683,8 @@ static int ssh_send(void *handle, const char *buf, int len) if (ssh == NULL || ssh->s == NULL || ssh->protocol == NULL) return 0; - ssh->protocol(ssh, (const unsigned char *)buf, len, 0); + bufchain_add(&ssh->user_input, buf, len); + queue_idempotent_callback(&ssh->user_input_consumer); return ssh_sendbuffer(ssh); }