1
0
mirror of https://git.tartarus.org/simon/putty.git synced 2025-02-03 21:52:24 +00:00

Make the rdpkt functions output to a PacketQueue.

Each of the coroutines that parses the incoming wire data into a
stream of 'struct Packet' now delivers those packets to a PacketQueue
called ssh->pq_full (containing the full, unfiltered stream of all
packets received on the SSH connection), replacing the old API in
which each coroutine would directly return a 'struct Packet *' to its
caller, or NULL if it didn't have one ready yet.

This simplifies the function-call API of the rdpkt coroutines (they
now return void). It increases the complexity at the other end,
because we've now got a function ssh_process_pq_full (scheduled as an
idempotent callback whenever rdpkt appends anything to the queue)
which pulls things out of the queue and passes them to ssh->protocol.
But that's only a temporary complexity increase; by the time I finish
the upcoming stream of refactorings, there won't be two chained
functions there any more.

One small workaround I had to add in this commit is a flag called
'pending_newkeys', which ssh2_rdpkt sets when it's just returned an
SSH_MSG_NEWKEYS packet, and then waits for the transport layer to
process the NEWKEYS and set up the new encryption context before
processing any more wire data. This wasn't necessary before, because
the old architecture was naturally synchronous - ssh2_rdpkt would
return a NEWKEYS, which would be immediately passed to
do_ssh2_transport, which would finish processing it immediately, and
by the time ssh2_rdpkt was next called, the keys would already be in
place.

This change adds a big while loop around the whole of each rdpkt
function, so it's easiest to read it as a whitespace-ignored diff.
This commit is contained in:
Simon Tatham 2018-05-18 07:22:57 +01:00
parent 9d495b2176
commit 2b57b84fa5

117
ssh.c
View File

@ -969,6 +969,9 @@ struct ssh_tag {
void *do_ssh2_authconn_state;
void *do_ssh_connection_init_state;
struct PacketQueue pq_full;
struct IdempotentCallback pq_full_consumer;
struct rdpkt1_state_tag rdpkt1_state;
struct rdpkt2_state_tag rdpkt2_state;
struct rdpkt2_bare_state_tag rdpkt2_bare_state;
@ -978,8 +981,7 @@ struct ssh_tag {
void (*protocol) (Ssh ssh, const void *vin, int inlen,
struct Packet *pkt);
struct Packet *(*s_rdpkt) (Ssh ssh, const unsigned char **data,
int *datalen);
void (*s_rdpkt) (Ssh ssh, const unsigned char **data, int *datalen);
int (*do_ssh_init)(Ssh ssh, unsigned char c);
/*
@ -1051,6 +1053,13 @@ struct ssh_tag {
unsigned long next_rekey, last_rekey;
const char *deferred_rekey_reason;
/*
* Inhibit processing of incoming raw data into packets while
* we're still waiting for a NEWKEYS message to complete and fill
* in the new details of how that should be done.
*/
int pending_newkeys;
/*
* Fully qualified host name, which we need if doing GSSAPI.
*/
@ -1467,13 +1476,13 @@ static void ssh1_log_outgoing_packet(Ssh ssh, struct Packet *pkt)
* Update the *data and *datalen variables.
* Return a Packet structure when a packet is completed.
*/
static struct Packet *ssh1_rdpkt(Ssh ssh, const unsigned char **data,
int *datalen)
static void ssh1_rdpkt(Ssh ssh, const unsigned char **data, int *datalen)
{
struct rdpkt1_state_tag *st = &ssh->rdpkt1_state;
crBegin(ssh->ssh1_rdpkt_crstate);
while (1) {
st->pktin = ssh_new_packet();
st->pktin->type = 0;
@ -1481,7 +1490,7 @@ static struct Packet *ssh1_rdpkt(Ssh ssh, const unsigned char **data,
for (st->i = st->len = 0; st->i < 4; st->i++) {
while ((*datalen) == 0)
crReturn(NULL);
crReturnV;
st->len = (st->len << 8) + **data;
(*data)++, (*datalen)--;
}
@ -1494,7 +1503,7 @@ static struct Packet *ssh1_rdpkt(Ssh ssh, const unsigned char **data,
bombout(("Extremely large packet length from server suggests"
" data stream corruption"));
ssh_unref_packet(st->pktin);
crStop(NULL);
crStopV;
}
st->pktin->maxlen = st->biglen;
@ -1505,7 +1514,7 @@ static struct Packet *ssh1_rdpkt(Ssh ssh, const unsigned char **data,
while (st->to_read > 0) {
st->chunk = st->to_read;
while ((*datalen) == 0)
crReturn(NULL);
crReturnV;
if (st->chunk > (*datalen))
st->chunk = (*datalen);
memcpy(st->p, *data, st->chunk);
@ -1519,7 +1528,7 @@ static struct Packet *ssh1_rdpkt(Ssh ssh, const unsigned char **data,
st->biglen, NULL)) {
bombout(("Network attack (CRC compensation) detected!"));
ssh_unref_packet(st->pktin);
crStop(NULL);
crStopV;
}
if (ssh->cipher)
@ -1530,7 +1539,7 @@ static struct Packet *ssh1_rdpkt(Ssh ssh, const unsigned char **data,
if (st->gotcrc != st->realcrc) {
bombout(("Incorrect CRC received on packet"));
ssh_unref_packet(st->pktin);
crStop(NULL);
crStopV;
}
st->pktin->body = st->pktin->data + st->pad + 1;
@ -1543,7 +1552,7 @@ static struct Packet *ssh1_rdpkt(Ssh ssh, const unsigned char **data,
&decompblk, &decomplen)) {
bombout(("Zlib decompression encountered invalid data"));
ssh_unref_packet(st->pktin);
crStop(NULL);
crStopV;
}
if (st->pktin->maxlen < st->pad + decomplen) {
@ -1571,7 +1580,11 @@ static struct Packet *ssh1_rdpkt(Ssh ssh, const unsigned char **data,
st->pktin->savedpos = 0;
crFinish(st->pktin);
pq_push(&ssh->pq_full, st->pktin);
queue_idempotent_callback(&ssh->pq_full_consumer);
crReturnV;
}
crFinishV;
}
static void ssh2_log_incoming_packet(Ssh ssh, struct Packet *pkt)
@ -1723,13 +1736,13 @@ static void ssh2_log_outgoing_packet(Ssh ssh, struct Packet *pkt)
pkt->length += (pkt->body - pkt->data);
}
static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
int *datalen)
static void ssh2_rdpkt(Ssh ssh, const unsigned char **data, int *datalen)
{
struct rdpkt2_state_tag *st = &ssh->rdpkt2_state;
crBegin(ssh->ssh2_rdpkt_crstate);
while (1) {
st->pktin = ssh_new_packet();
st->pktin->type = 0;
@ -1769,7 +1782,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
/* Read an amount corresponding to the MAC. */
for (st->i = 0; st->i < st->maclen; st->i++) {
while ((*datalen) == 0)
crReturn(NULL);
crReturnV;
st->pktin->data[st->i] = *(*data)++;
(*datalen)--;
}
@ -1786,7 +1799,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
/* Read another cipher-block's worth, and tack it onto the end. */
for (st->i = 0; st->i < st->cipherblk; st->i++) {
while ((*datalen) == 0)
crReturn(NULL);
crReturnV;
st->pktin->data[st->packetlen+st->maclen+st->i] = *(*data)++;
(*datalen)--;
}
@ -1807,7 +1820,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
if (st->packetlen >= OUR_V2_PACKETLIMIT) {
bombout(("No valid incoming packet found"));
ssh_unref_packet(st->pktin);
crStop(NULL);
crStopV;
}
}
st->pktin->maxlen = st->packetlen + st->maclen;
@ -1823,7 +1836,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
*/
for (st->i = st->len = 0; st->i < 4; st->i++) {
while ((*datalen) == 0)
crReturn(NULL);
crReturnV;
st->pktin->data[st->i] = *(*data)++;
(*datalen)--;
}
@ -1846,7 +1859,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
st->len % st->cipherblk != 0) {
bombout(("Incoming packet length field was garbled"));
ssh_unref_packet(st->pktin);
crStop(NULL);
crStopV;
}
/*
@ -1867,7 +1880,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
*/
for (st->i = 4; st->i < st->packetlen + st->maclen; st->i++) {
while ((*datalen) == 0)
crReturn(NULL);
crReturnV;
st->pktin->data[st->i] = *(*data)++;
(*datalen)--;
}
@ -1880,7 +1893,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
st->len + 4, st->incoming_sequence)) {
bombout(("Incorrect MAC received on packet"));
ssh_unref_packet(st->pktin);
crStop(NULL);
crStopV;
}
/* Decrypt everything between the length field and the MAC. */
@ -1897,7 +1910,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
*/
for (st->i = st->len = 0; st->i < st->cipherblk; st->i++) {
while ((*datalen) == 0)
crReturn(NULL);
crReturnV;
st->pktin->data[st->i] = *(*data)++;
(*datalen)--;
}
@ -1919,7 +1932,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
(st->len + 4) % st->cipherblk != 0) {
bombout(("Incoming packet was garbled on decryption"));
ssh_unref_packet(st->pktin);
crStop(NULL);
crStopV;
}
/*
@ -1941,7 +1954,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
for (st->i = st->cipherblk; st->i < st->packetlen + st->maclen;
st->i++) {
while ((*datalen) == 0)
crReturn(NULL);
crReturnV;
st->pktin->data[st->i] = *(*data)++;
(*datalen)--;
}
@ -1959,7 +1972,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
st->len + 4, st->incoming_sequence)) {
bombout(("Incorrect MAC received on packet"));
ssh_unref_packet(st->pktin);
crStop(NULL);
crStopV;
}
}
/* Get and sanity-check the amount of random padding. */
@ -1967,7 +1980,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
if (st->pad < 4 || st->len - st->pad < 1) {
bombout(("Invalid padding length on received packet"));
ssh_unref_packet(st->pktin);
crStop(NULL);
crStopV;
}
/*
* This enables us to deduce the payload length.
@ -2011,7 +2024,7 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
*/
if (st->pktin->length <= 5) { /* == 5 we hope, but robustness */
ssh2_msg_something_unimplemented(ssh, st->pktin);
crStop(NULL);
crStopV;
}
/*
* pktin->body and pktin->length should identify the semantic
@ -2027,23 +2040,34 @@ static struct Packet *ssh2_rdpkt(Ssh ssh, const unsigned char **data,
st->pktin->savedpos = 0;
crFinish(st->pktin);
pq_push(&ssh->pq_full, st->pktin);
queue_idempotent_callback(&ssh->pq_full_consumer);
if (st->pktin->type == SSH2_MSG_NEWKEYS) {
/* Mild layer violation: in this situation we must suspend
* processing of the input byte stream in order to ensure
* that the transport code has processed NEWKEYS and
* installed the new cipher. */
ssh->pending_newkeys = TRUE;
crReturnV;
}
}
crFinishV;
}
static struct Packet *ssh2_bare_connection_rdpkt(Ssh ssh,
const unsigned char **data,
int *datalen)
static void ssh2_bare_connection_rdpkt(
Ssh ssh, const unsigned char **data, int *datalen)
{
struct rdpkt2_bare_state_tag *st = &ssh->rdpkt2_bare_state;
crBegin(ssh->ssh2_bare_rdpkt_crstate);
while (1) {
/*
* Read the packet length field.
*/
for (st->i = 0; st->i < 4; st->i++) {
while ((*datalen) == 0)
crReturn(NULL);
crReturnV;
st->length[st->i] = *(*data)++;
(*datalen)--;
}
@ -2051,7 +2075,7 @@ static struct Packet *ssh2_bare_connection_rdpkt(Ssh ssh,
st->packetlen = toint(GET_32BIT_MSB_FIRST(st->length));
if (st->packetlen <= 0 || st->packetlen >= OUR_V2_PACKETLIMIT) {
bombout(("Invalid packet length received"));
crStop(NULL);
crStopV;
}
st->pktin = ssh_new_packet();
@ -2066,7 +2090,7 @@ static struct Packet *ssh2_bare_connection_rdpkt(Ssh ssh,
*/
for (st->i = 0; st->i < st->packetlen; st->i++) {
while ((*datalen) == 0)
crReturn(NULL);
crReturnV;
st->pktin->data[st->i] = *(*data)++;
(*datalen)--;
}
@ -2087,7 +2111,10 @@ static struct Packet *ssh2_bare_connection_rdpkt(Ssh ssh,
st->pktin->savedpos = 0;
crFinish(st->pktin);
pq_push(&ssh->pq_full, st->pktin);
queue_idempotent_callback(&ssh->pq_full_consumer);
}
crFinishV;
}
static int s_wrpkt_prepare(Ssh ssh, struct Packet *pkt, int *offset_p)
@ -3462,8 +3489,8 @@ static void ssh_process_incoming_data(Ssh ssh,
{
struct Packet *pktin;
pktin = ssh->s_rdpkt(ssh, data, datalen);
if (pktin) {
ssh->s_rdpkt(ssh, data, datalen);
while ((pktin = pq_pop(&ssh->pq_full)) != NULL) {
ssh->protocol(ssh, NULL, 0, pktin);
ssh_unref_packet(pktin);
}
@ -3503,6 +3530,17 @@ static void ssh_set_frozen(Ssh ssh, int frozen)
ssh->frozen = frozen;
}
static void ssh_process_pq_full(void *ctx)
{
Ssh ssh = (Ssh)ctx;
struct Packet *pktin;
while ((pktin = pq_pop(&ssh->pq_full)) != NULL) {
ssh->protocol(ssh, NULL, 0, pktin);
ssh_unref_packet(pktin);
}
}
static void ssh_gotdata(Ssh ssh, const unsigned char *data, int datalen)
{
/* Log raw data, if we're in that mode. */
@ -8311,6 +8349,7 @@ static void do_ssh2_transport(Ssh ssh, const void *vin, int inlen,
bombout(("expected new-keys packet from server"));
crStopV;
}
ssh->pending_newkeys = FALSE; /* resume processing incoming data */
ssh->incoming_data_size = 0; /* start counting from here */
/*
@ -12321,6 +12360,11 @@ static const char *ssh_init(void *frontend_handle, void **backend_handle,
ssh->do_ssh1_login_state = NULL;
ssh->do_ssh2_transport_state = NULL;
ssh->do_ssh2_authconn_state = NULL;
pq_init(&ssh->pq_full);
ssh->pq_full_consumer.fn = ssh_process_pq_full;
ssh->pq_full_consumer.ctx = ssh;
ssh->pq_full_consumer.queued = FALSE;
ssh->pending_newkeys = FALSE;
ssh->v_c = NULL;
ssh->v_s = NULL;
ssh->mainchan = NULL;
@ -12494,6 +12538,7 @@ static void ssh_free(void *handle)
sfree(ssh->do_ssh1_login_state);
sfree(ssh->do_ssh2_transport_state);
sfree(ssh->do_ssh2_authconn_state);
pq_clear(&ssh->pq_full);
sfree(ssh->v_c);
sfree(ssh->v_s);
sfree(ssh->fullhostname);