diff --git a/ssh.c b/ssh.c index e8a90e69..9a71cd5d 100644 --- a/ssh.c +++ b/ssh.c @@ -674,8 +674,6 @@ struct PktIn { int refcount; int type; unsigned long sequence; /* SSH-2 incoming sequence number */ - unsigned char *data; /* allocated storage */ - long maxlen; /* amount of storage allocated for `data' */ long encrypted_len; /* for SSH-2 total-size counting */ BinarySource_IMPLEMENTATION; }; @@ -807,21 +805,26 @@ static int pq_empty_on_to_front_of(struct PacketQueue *src, } struct rdpkt1_state_tag { - long len, pad, biglen, length; + long len, pad, biglen, length, maxlen; + unsigned char *data; unsigned long realcrc, gotcrc; int chunk; PktIn *pktin; }; struct rdpkt2_state_tag { - long len, pad, payload, packetlen, maclen, length; + long len, pad, payload, packetlen, maclen, length, maxlen; + unsigned char *buf; + size_t bufsize; + unsigned char *data; int cipherblk; unsigned long incoming_sequence; PktIn *pktin; }; struct rdpkt2_bare_state_tag { - long packetlen; + long packetlen, maxlen; + unsigned char *data; unsigned long incoming_sequence; PktIn *pktin; }; @@ -1373,10 +1376,8 @@ static void c_write_str(Ssh ssh, const char *buf) static void ssh_unref_packet(PktIn *pkt) { - if (--pkt->refcount <= 0) { - sfree(pkt->data); + if (--pkt->refcount <= 0) sfree(pkt); - } } static void ssh_free_pktout(PktOut *pkt) @@ -1507,12 +1508,7 @@ static void ssh1_rdpkt(Ssh ssh) crBegin(ssh->ssh1_rdpkt_crstate); while (1) { - st->pktin = snew(PktIn); - st->pktin->data = NULL; - st->pktin->maxlen = 0; - st->pktin->refcount = 1; - - st->pktin->type = 0; + st->maxlen = 0; st->length = 0; { @@ -1533,14 +1529,20 @@ static void ssh1_rdpkt(Ssh ssh) crStopV; } - st->pktin->maxlen = st->biglen; - st->pktin->data = snewn(st->biglen, unsigned char); + /* + * Allocate the packet to return, now we know its length. + */ + st->pktin = snew_plus(PktIn, st->biglen); + st->pktin->refcount = 1; + st->pktin->type = 0; + + st->maxlen = st->biglen; + st->data = snew_plus_get_aux(st->pktin); crMaybeWaitUntilV(bufchain_try_fetch_consume( - &ssh->incoming_data, - st->pktin->data, st->biglen)); + &ssh->incoming_data, st->data, st->biglen)); - if (ssh->cipher && detect_attack(ssh->crcda_ctx, st->pktin->data, + if (ssh->cipher && detect_attack(ssh->crcda_ctx, st->data, st->biglen, NULL)) { bombout(("Network attack (CRC compensation) detected!")); ssh_unref_packet(st->pktin); @@ -1548,10 +1550,10 @@ static void ssh1_rdpkt(Ssh ssh) } if (ssh->cipher) - ssh->cipher->decrypt(ssh->v1_cipher_ctx, st->pktin->data, st->biglen); + ssh->cipher->decrypt(ssh->v1_cipher_ctx, st->data, st->biglen); - st->realcrc = crc32_compute(st->pktin->data, st->biglen - 4); - st->gotcrc = GET_32BIT(st->pktin->data + st->biglen - 4); + st->realcrc = crc32_compute(st->data, st->biglen - 4); + st->gotcrc = GET_32BIT(st->data + st->biglen - 4); if (st->gotcrc != st->realcrc) { bombout(("Incorrect CRC received on packet")); ssh_unref_packet(st->pktin); @@ -1562,33 +1564,36 @@ static void ssh1_rdpkt(Ssh ssh) unsigned char *decompblk; int decomplen; if (!zlib_decompress_block(ssh->sc_comp_ctx, - st->pktin->data + st->pad, - st->length + 1, + st->data + st->pad, st->length + 1, &decompblk, &decomplen)) { bombout(("Zlib decompression encountered invalid data")); ssh_unref_packet(st->pktin); crStopV; } - if (st->pktin->maxlen < st->pad + decomplen) { - st->pktin->maxlen = st->pad + decomplen; - st->pktin->data = sresize(st->pktin->data, st->pktin->maxlen, - unsigned char); + if (st->maxlen < st->pad + decomplen) { + PktIn *old_pktin = st->pktin; + + st->maxlen = st->pad + decomplen; + st->pktin = snew_plus(PktIn, st->maxlen); + *st->pktin = *old_pktin; /* structure copy */ + st->data = snew_plus_get_aux(st->pktin); + + smemclr(old_pktin, st->biglen); + sfree(old_pktin); } - memcpy(st->pktin->data + st->pad, decompblk, decomplen); + memcpy(st->data + st->pad, decompblk, decomplen); sfree(decompblk); st->length = decomplen - 1; } - st->pktin->type = st->pktin->data[st->pad]; - /* - * Now we know the bounds of the semantic content of the - * packet, excluding the initial type byte. + * Now we can find the bounds of the semantic content of the + * packet, and the initial type byte. */ - BinarySource_INIT(st->pktin, st->pktin->data + st->pad + 1, - st->length); + st->pktin->type = st->data[st->pad]; + BinarySource_INIT(st->pktin, st->data + st->pad + 1, st->length); if (ssh->logctx) ssh1_log_incoming_packet(ssh, st->pktin); @@ -1753,13 +1758,11 @@ static void ssh2_rdpkt(Ssh ssh) crBegin(ssh->ssh2_rdpkt_crstate); - while (1) { - st->pktin = snew(PktIn); - st->pktin->data = NULL; - st->pktin->maxlen = 0; - st->pktin->refcount = 1; + st->buf = NULL; + st->bufsize = 0; - st->pktin->type = 0; + while (1) { + st->maxlen = 0; st->length = 0; if (ssh->sccipher) st->cipherblk = ssh->sccipher->blksize; @@ -1789,14 +1792,18 @@ static void ssh2_rdpkt(Ssh ssh) * detecting it before we decrypt anything. */ - /* May as well allocate the whole lot now. */ - st->pktin->data = snewn(OUR_V2_PACKETLIMIT + st->maclen, - unsigned char); + /* + * Make sure we have buffer space for a maximum-size packet. + */ + int buflimit = OUR_V2_PACKETLIMIT + st->maclen; + if (st->bufsize < buflimit) { + st->bufsize = buflimit; + st->buf = sresize(st->buf, st->bufsize, unsigned char); + } /* Read an amount corresponding to the MAC. */ crMaybeWaitUntilV(bufchain_try_fetch_consume( - &ssh->incoming_data, - st->pktin->data, st->maclen)); + &ssh->incoming_data, st->buf, st->maclen)); st->packetlen = 0; ssh->scmac->start(ssh->sc_mac_ctx); @@ -1807,51 +1814,59 @@ static void ssh2_rdpkt(Ssh ssh) /* Read another cipher-block's worth, and tack it onto the end. */ crMaybeWaitUntilV(bufchain_try_fetch_consume( &ssh->incoming_data, - st->pktin->data + (st->packetlen + - st->maclen), + st->buf + (st->packetlen + st->maclen), st->cipherblk)); /* Decrypt one more block (a little further back in the stream). */ ssh->sccipher->decrypt(ssh->sc_cipher_ctx, - st->pktin->data + st->packetlen, + st->buf + st->packetlen, st->cipherblk); /* Feed that block to the MAC. */ put_data(ssh->sc_mac_bs, - st->pktin->data + st->packetlen, st->cipherblk); + st->buf + st->packetlen, st->cipherblk); st->packetlen += st->cipherblk; /* See if that gives us a valid packet. */ if (ssh->scmac->verresult(ssh->sc_mac_ctx, - st->pktin->data + st->packetlen) && - ((st->len = toint(GET_32BIT(st->pktin->data))) == + st->buf + st->packetlen) && + ((st->len = toint(GET_32BIT(st->buf))) == st->packetlen-4)) break; if (st->packetlen >= OUR_V2_PACKETLIMIT) { bombout(("No valid incoming packet found")); - ssh_unref_packet(st->pktin); crStopV; } } - st->pktin->maxlen = st->packetlen + st->maclen; - st->pktin->data = sresize(st->pktin->data, st->pktin->maxlen, - unsigned char); + st->maxlen = st->packetlen + st->maclen; + + /* + * Now transfer the data into an output packet. + */ + st->pktin = snew_plus(PktIn, st->maxlen); + st->pktin->refcount = 1; + st->pktin->type = 0; + st->data = snew_plus_get_aux(st->pktin); + memcpy(st->data, st->buf, st->maxlen); } else if (ssh->scmac && ssh->scmac_etm) { - st->pktin->data = snewn(4, unsigned char); + if (st->bufsize < 4) { + st->bufsize = 4; + st->buf = sresize(st->buf, st->bufsize, unsigned char); + } /* * OpenSSH encrypt-then-MAC mode: the packet length is * unencrypted, unless the cipher supports length encryption. */ crMaybeWaitUntilV(bufchain_try_fetch_consume( - &ssh->incoming_data, st->pktin->data, 4)); + &ssh->incoming_data, st->buf, 4)); /* Cipher supports length decryption, so do it */ if (ssh->sccipher && (ssh->sccipher->flags & SSH_CIPHER_SEPARATE_LENGTH)) { /* Keep the packet the same though, so the MAC passes */ unsigned char len[4]; - memcpy(len, st->pktin->data, 4); + memcpy(len, st->buf, 4); ssh->sccipher->decrypt_length(ssh->sc_cipher_ctx, len, 4, st->incoming_sequence); st->len = toint(GET_32BIT(len)); } else { - st->len = toint(GET_32BIT(st->pktin->data)); + st->len = toint(GET_32BIT(st->buf)); } /* @@ -1861,7 +1876,6 @@ static void ssh2_rdpkt(Ssh ssh) if (st->len < 0 || st->len > OUR_V2_PACKETLIMIT || st->len % st->cipherblk != 0) { bombout(("Incoming packet length field was garbled")); - ssh_unref_packet(st->pktin); crStopV; } @@ -1871,24 +1885,26 @@ static void ssh2_rdpkt(Ssh ssh) st->packetlen = st->len + 4; /* - * Allocate memory for the rest of the packet. + * Allocate the packet to return, now we know its length. */ - st->pktin->maxlen = st->packetlen + st->maclen; - st->pktin->data = sresize(st->pktin->data, st->pktin->maxlen, - unsigned char); + st->pktin = snew_plus(PktIn, OUR_V2_PACKETLIMIT + st->maclen); + st->pktin->refcount = 1; + st->pktin->type = 0; + st->data = snew_plus_get_aux(st->pktin); + memcpy(st->data, st->buf, 4); /* * Read the remainder of the packet. */ crMaybeWaitUntilV(bufchain_try_fetch_consume( - &ssh->incoming_data, st->pktin->data + 4, + &ssh->incoming_data, st->data + 4, st->packetlen + st->maclen - 4)); /* * Check the MAC. */ if (ssh->scmac - && !ssh->scmac->verify(ssh->sc_mac_ctx, st->pktin->data, + && !ssh->scmac->verify(ssh->sc_mac_ctx, st->data, st->len + 4, st->incoming_sequence)) { bombout(("Incorrect MAC received on packet")); ssh_unref_packet(st->pktin); @@ -1898,10 +1914,12 @@ static void ssh2_rdpkt(Ssh ssh) /* Decrypt everything between the length field and the MAC. */ if (ssh->sccipher) ssh->sccipher->decrypt(ssh->sc_cipher_ctx, - st->pktin->data + 4, - st->packetlen - 4); + st->data + 4, st->packetlen - 4); } else { - st->pktin->data = snewn(st->cipherblk, unsigned char); + if (st->bufsize < st->cipherblk) { + st->bufsize = st->cipherblk; + st->buf = sresize(st->buf, st->bufsize, unsigned char); + } /* * Acquire and decrypt the first block of the packet. This will @@ -1909,16 +1927,16 @@ static void ssh2_rdpkt(Ssh ssh) */ crMaybeWaitUntilV(bufchain_try_fetch_consume( &ssh->incoming_data, - st->pktin->data, st->cipherblk)); + st->buf, st->cipherblk)); if (ssh->sccipher) ssh->sccipher->decrypt(ssh->sc_cipher_ctx, - st->pktin->data, st->cipherblk); + st->buf, st->cipherblk); /* * Now get the length figure. */ - st->len = toint(GET_32BIT(st->pktin->data)); + st->len = toint(GET_32BIT(st->buf)); /* * _Completely_ silly lengths should be stomped on before they @@ -1937,31 +1955,34 @@ static void ssh2_rdpkt(Ssh ssh) st->packetlen = st->len + 4; /* - * Allocate memory for the rest of the packet. + * Allocate the packet to return, now we know its length. */ - st->pktin->maxlen = st->packetlen + st->maclen; - st->pktin->data = sresize(st->pktin->data, st->pktin->maxlen, - unsigned char); + st->maxlen = st->packetlen + st->maclen; + st->pktin = snew_plus(PktIn, st->maxlen); + st->pktin->refcount = 1; + st->pktin->type = 0; + st->data = snew_plus_get_aux(st->pktin); + memcpy(st->data, st->buf, st->cipherblk); /* * Read and decrypt the remainder of the packet. */ crMaybeWaitUntilV(bufchain_try_fetch_consume( &ssh->incoming_data, - st->pktin->data + st->cipherblk, + st->data + st->cipherblk, st->packetlen + st->maclen - st->cipherblk)); /* Decrypt everything _except_ the MAC. */ if (ssh->sccipher) ssh->sccipher->decrypt(ssh->sc_cipher_ctx, - st->pktin->data + st->cipherblk, + st->data + st->cipherblk, st->packetlen - st->cipherblk); /* * Check the MAC. */ if (ssh->scmac - && !ssh->scmac->verify(ssh->sc_mac_ctx, st->pktin->data, + && !ssh->scmac->verify(ssh->sc_mac_ctx, st->data, st->len + 4, st->incoming_sequence)) { bombout(("Incorrect MAC received on packet")); ssh_unref_packet(st->pktin); @@ -1969,7 +1990,7 @@ static void ssh2_rdpkt(Ssh ssh) } } /* Get and sanity-check the amount of random padding. */ - st->pad = st->pktin->data[4]; + st->pad = st->data[4]; if (st->pad < 4 || st->len - st->pad < 1) { bombout(("Invalid padding length on received packet")); ssh_unref_packet(st->pktin); @@ -1996,16 +2017,21 @@ static void ssh2_rdpkt(Ssh ssh) int newlen; if (ssh->sccomp && ssh->sccomp->decompress(ssh->sc_comp_ctx, - st->pktin->data + 5, st->length - 5, + st->data + 5, st->length - 5, &newpayload, &newlen)) { - if (st->pktin->maxlen < newlen + 5) { - st->pktin->maxlen = newlen + 5; - st->pktin->data = sresize(st->pktin->data, - st->pktin->maxlen, - unsigned char); + if (st->maxlen < newlen + 5) { + PktIn *old_pktin = st->pktin; + + st->maxlen = newlen + 5; + st->pktin = snew_plus(PktIn, st->maxlen); + *st->pktin = *old_pktin; /* structure copy */ + st->data = snew_plus_get_aux(st->pktin); + + smemclr(old_pktin, st->packetlen + st->maclen); + sfree(old_pktin); } st->length = 5 + newlen; - memcpy(st->pktin->data + 5, newpayload, newlen); + memcpy(st->data + 5, newpayload, newlen); sfree(newpayload); } } @@ -2019,14 +2045,15 @@ static void ssh2_rdpkt(Ssh ssh) ssh2_msg_something_unimplemented(ssh, st->pktin); crStopV; } + /* * Now we can identify the semantic content of the packet, * and also the initial type byte. */ - st->pktin->type = st->pktin->data[5]; + st->pktin->type = st->data[5]; st->length -= 6; assert(st->length >= 0); /* one last double-check */ - BinarySource_INIT(st->pktin, st->pktin->data + 6, st->length); + BinarySource_INIT(st->pktin, st->data + 6, st->length); if (ssh->logctx) ssh2_log_incoming_packet(ssh, st->pktin); @@ -2079,10 +2106,13 @@ static void ssh2_bare_connection_rdpkt(Ssh ssh) crStopV; } - st->pktin = snew(PktIn); - st->pktin->maxlen = 0; + /* + * Allocate the packet to return, now we know its length. + */ + st->pktin = snew_plus(PktIn, st->packetlen); + st->maxlen = 0; st->pktin->refcount = 1; - st->pktin->data = snewn(st->packetlen, unsigned char); + st->data = snew_plus_get_aux(st->pktin); st->pktin->encrypted_len = st->packetlen; @@ -2092,15 +2122,14 @@ static void ssh2_bare_connection_rdpkt(Ssh ssh) * Read the remainder of the packet. */ crMaybeWaitUntilV(bufchain_try_fetch_consume( - &ssh->incoming_data, - st->pktin->data, st->packetlen)); + &ssh->incoming_data, st->data, st->packetlen)); /* - * pktin->body and pktin->length should identify the semantic - * content of the packet, excluding the initial type byte. + * The data we just read is precisely the initial type byte + * followed by the packet payload. */ - st->pktin->type = st->pktin->data[0]; - BinarySource_INIT(st->pktin, st->pktin->data + 1, st->packetlen - 1); + st->pktin->type = st->data[0]; + BinarySource_INIT(st->pktin, st->data + 1, st->packetlen - 1); /* * Log incoming packet, possibly omitting sensitive fields. @@ -12045,6 +12074,7 @@ static const char *ssh_init(void *frontend_handle, void **backend_handle, ssh->ssh1_rdpkt_crstate = 0; ssh->ssh2_rdpkt_crstate = 0; ssh->ssh2_bare_rdpkt_crstate = 0; + ssh->rdpkt2_state.buf = NULL; ssh->do_ssh1_connection_crstate = 0; ssh->do_ssh_init_state = NULL; ssh->do_ssh_connection_init_state = NULL; @@ -12205,6 +12235,11 @@ static void ssh_free(void *handle) dh_cleanup(ssh->kex_ctx); sfree(ssh->savedhost); + if (ssh->rdpkt2_state.buf) { + smemclr(ssh->rdpkt2_state.buf, ssh->rdpkt2_state.bufsize); + sfree(ssh->rdpkt2_state.buf); + } + while (ssh->queuelen-- > 0) ssh_free_pktout(ssh->queue[ssh->queuelen]); sfree(ssh->queue);