diff --git a/ssh.c b/ssh.c index 5cc1373c..e3dba149 100644 --- a/ssh.c +++ b/ssh.c @@ -559,8 +559,8 @@ struct ssh_tag { * Track incoming and outgoing data sizes and time, for * size-based rekeys. */ - unsigned long incoming_data_size, outgoing_data_size; unsigned long max_data_size; + struct DataTransferStats stats; int kex_in_progress; unsigned long next_rekey, last_rekey; const char *deferred_rekey_reason; @@ -781,10 +781,9 @@ static void ssh_send_outgoing_data(void *ctx) backlog = s_write(ssh, data, len); bufchain_consume(&ssh->outgoing_data, len); - ssh->outgoing_data_size += len; if (ssh->version == 2 && !ssh->kex_in_progress && - !ssh->bare_connection && ssh->max_data_size != 0 && - ssh->outgoing_data_size > ssh->max_data_size) { + ssh->state != SSH_STATE_PREPACKET && + !ssh->bare_connection && !ssh->stats.out.running) { ssh->rekey_reason = "too much data sent"; ssh->rekey_class = RK_NORMAL; queue_idempotent_callback(&ssh->ssh2_transport_icb); @@ -5131,7 +5130,9 @@ static void do_ssh2_transport(void *vctx) */ s->pktout = ssh_bpp_new_pktout(ssh->bpp, SSH2_MSG_NEWKEYS); ssh_pkt_write(ssh, s->pktout); - ssh->outgoing_data_size = 0; /* start counting from here */ + /* Start counting down the outgoing-data limit for these cipher keys. */ + ssh->stats.out.running = TRUE; + ssh->stats.out.remaining = ssh->max_data_size; /* * We've sent client NEWKEYS, so create and initialise @@ -5204,7 +5205,9 @@ static void do_ssh2_transport(void *vctx) bombout(("expected new-keys packet from server")); crStopV; } - ssh->incoming_data_size = 0; /* start counting from here */ + /* Start counting down the incoming-data limit for these cipher keys. */ + ssh->stats.in.running = TRUE; + ssh->stats.in.remaining = ssh->max_data_size; /* * We've seen server NEWKEYS, so create and initialise @@ -5363,8 +5366,9 @@ static void do_ssh2_transport(void *vctx) ssh->rekey_reason); /* Reset the counters, so that at least this message doesn't * hit the event log _too_ often. */ - ssh->outgoing_data_size = 0; - ssh->incoming_data_size = 0; + ssh->stats.in.running = ssh->stats.out.running = TRUE; + ssh->stats.in.remaining = ssh->stats.out.remaining = + ssh->max_data_size; (void) ssh2_timer_update(ssh, 0); goto wait_for_rekey; /* this is still utterly horrid */ } else { @@ -8649,7 +8653,7 @@ static void ssh2_protocol_setup(Ssh ssh) { int i; - ssh->bpp = ssh2_bpp_new(); + ssh->bpp = ssh2_bpp_new(&ssh->stats); #ifndef NO_GSSAPI /* Load and pick the highest GSS library on the preference list. */ @@ -9055,10 +9059,8 @@ static void ssh2_timer(void *ctx, unsigned long now) static void ssh2_general_packet_processing(Ssh ssh, PktIn *pktin) { - ssh->incoming_data_size += pktin->encrypted_len; - if (!ssh->kex_in_progress && - ssh->max_data_size != 0 && - ssh->incoming_data_size > ssh->max_data_size) { + if (!ssh->kex_in_progress && ssh->max_data_size != 0 && + ssh->state != SSH_STATE_PREPACKET && !ssh->stats.in.running) { ssh->rekey_reason = "too much data received"; ssh->rekey_class = RK_NORMAL; queue_idempotent_callback(&ssh->ssh2_transport_icb); @@ -9213,7 +9215,7 @@ static const char *ssh_init(Frontend *frontend, Backend **backend_handle, ssh->pinger = NULL; - ssh->incoming_data_size = ssh->outgoing_data_size = 0L; + memset(&ssh->stats, 0, sizeof(ssh->stats)); ssh->max_data_size = parse_blocksize(conf_get_str(ssh->conf, CONF_ssh_rekey_data)); ssh->kex_in_progress = FALSE; @@ -9371,9 +9373,22 @@ static void ssh_reconfig(Backend *be, Conf *conf) CONF_ssh_rekey_data)); if (old_max_data_size != ssh->max_data_size && ssh->max_data_size != 0) { - if (ssh->outgoing_data_size > ssh->max_data_size || - ssh->incoming_data_size > ssh->max_data_size) - rekeying = "data limit lowered"; + if (ssh->max_data_size < old_max_data_size) { + unsigned long diff = old_max_data_size - ssh->max_data_size; + + /* Intentionally use bitwise OR instead of logical, so + * that we decrement both counters even if the first one + * runs out */ + if ((DTS_CONSUME(&ssh->stats, out, diff) != 0) | + (DTS_CONSUME(&ssh->stats, in, diff) != 0)) + rekeying = "data limit lowered"; + } else { + unsigned long diff = ssh->max_data_size - old_max_data_size; + if (ssh->stats.out.running) + ssh->stats.out.remaining += diff; + if (ssh->stats.in.running) + ssh->stats.in.remaining += diff; + } } if (conf_get_int(ssh->conf, CONF_compression) != diff --git a/ssh.h b/ssh.h index 16224330..fd8d1677 100644 --- a/ssh.h +++ b/ssh.h @@ -59,7 +59,6 @@ typedef struct PktIn { int refcount; int type; unsigned long sequence; /* SSH-2 incoming sequence number */ - long encrypted_len; /* for SSH-2 total-size counting */ PacketQueueNode qnode; /* for linking this packet on to a queue */ BinarySource_IMPLEMENTATION; } PktIn; @@ -71,7 +70,6 @@ typedef struct PktOut { long minlen; /* SSH-2: ensure wire length is at least this */ unsigned char *data; /* allocated storage */ long maxlen; /* amount of storage allocated for `data' */ - long encrypted_len; /* for SSH-2 total-size counting */ /* Extra metadata used in SSH packet logging mode, allowing us to * log in the packet header line that the packet came from a diff --git a/ssh2bpp-bare.c b/ssh2bpp-bare.c index bdf49a21..e16c8509 100644 --- a/ssh2bpp-bare.c +++ b/ssh2bpp-bare.c @@ -79,8 +79,6 @@ static void ssh2_bare_bpp_handle_input(BinaryPacketProtocol *bpp) s->pktin->refcount = 1; s->data = snew_plus_get_aux(s->pktin); - s->pktin->encrypted_len = s->packetlen; - s->pktin->sequence = s->incoming_sequence++; /* diff --git a/ssh2bpp.c b/ssh2bpp.c index 0b95d7fe..2b7f4654 100644 --- a/ssh2bpp.c +++ b/ssh2bpp.c @@ -24,6 +24,7 @@ struct ssh2_bpp_state { unsigned char *data; unsigned cipherblk; PktIn *pktin; + struct DataTransferStats *stats; struct ssh2_bpp_direction in, out; /* comp and decomp logically belong in the per-direction @@ -48,11 +49,12 @@ const struct BinaryPacketProtocolVtable ssh2_bpp_vtable = { ssh2_bpp_format_packet, }; -BinaryPacketProtocol *ssh2_bpp_new(void) +BinaryPacketProtocol *ssh2_bpp_new(struct DataTransferStats *stats) { struct ssh2_bpp_state *s = snew(struct ssh2_bpp_state); memset(s, 0, sizeof(*s)); s->bpp.vt = &ssh2_bpp_vtable; + s->stats = stats; return &s->bpp; } @@ -407,7 +409,8 @@ static void ssh2_bpp_handle_input(BinaryPacketProtocol *bpp) s->payload = s->len - s->pad - 1; s->length = s->payload + 5; - s->pktin->encrypted_len = s->packetlen; + + DTS_CONSUME(s->stats, in, s->packetlen); s->pktin->sequence = s->in.sequence++; @@ -601,7 +604,8 @@ static void ssh2_bpp_format_packet_inner(struct ssh2_bpp_state *s, PktOut *pkt) } s->out.sequence++; /* whether or not we MACed */ - pkt->encrypted_len = origlen + padding; + + DTS_CONSUME(s->stats, out, origlen + padding); } diff --git a/sshbpp.h b/sshbpp.h index 953add3a..a1c9d0ac 100644 --- a/sshbpp.h +++ b/sshbpp.h @@ -36,7 +36,32 @@ void ssh1_bpp_new_cipher(BinaryPacketProtocol *bpp, const void *session_key); void ssh1_bpp_start_compression(BinaryPacketProtocol *bpp); -BinaryPacketProtocol *ssh2_bpp_new(void); +/* + * Structure that tracks how much data is sent and received, for + * purposes of triggering an SSH-2 rekey when either one gets over a + * configured limit. In each direction, the flag 'running' indicates + * that we haven't hit the limit yet, and 'remaining' tracks how much + * longer until we do. The macro DTS_CONSUME subtracts a given amount + * from the counter in a particular direction, and evaluates to a + * boolean indicating whether the limit has been hit. + * + * The limit is sticky: once 'running' has flipped to false, + * 'remaining' is no longer decremented, so it shouldn't dangerously + * wrap round. + */ +struct DataTransferStats { + struct { + int running; + unsigned long remaining; + } in, out; +}; +#define DTS_CONSUME(stats, direction, size) \ + ((stats)->direction.running && \ + (stats)->direction.remaining <= (size) ? \ + ((stats)->direction.running = FALSE, TRUE) : \ + ((stats)->direction.remaining -= (size), FALSE)) + +BinaryPacketProtocol *ssh2_bpp_new(struct DataTransferStats *stats); void ssh2_bpp_new_outgoing_crypto( BinaryPacketProtocol *bpp, const struct ssh2_cipheralg *cipher, const void *ckey, const void *iv,