diff --git a/ssh.h b/ssh.h index 0be698d7..ec972e05 100644 --- a/ssh.h +++ b/ssh.h @@ -52,6 +52,7 @@ struct ssh_channel; typedef struct PacketQueueNode PacketQueueNode; struct PacketQueueNode { PacketQueueNode *next, *prev; + size_t formal_size; /* contribution to PacketQueueBase's total_size */ bool on_free_queue; /* is this packet scheduled for freeing? */ }; @@ -84,6 +85,7 @@ typedef struct PktOut { typedef struct PacketQueueBase { PacketQueueNode end; + size_t total_size; /* sum of all formal_size fields on the queue */ struct IdempotentCallback *ic; } PacketQueueBase; diff --git a/ssh1bpp.c b/ssh1bpp.c index 0bc729e7..2fd684c5 100644 --- a/ssh1bpp.c +++ b/ssh1bpp.c @@ -236,6 +236,7 @@ static void ssh1_bpp_handle_input(BinaryPacketProtocol *bpp) NULL, 0, NULL); } + s->pktin->qnode.formal_size = get_avail(s->pktin); pq_push(&s->bpp.in_pq, s->pktin); { diff --git a/ssh2bpp-bare.c b/ssh2bpp-bare.c index 8dc3b117..dd2f918a 100644 --- a/ssh2bpp-bare.c +++ b/ssh2bpp-bare.c @@ -129,6 +129,7 @@ static void ssh2_bare_bpp_handle_input(BinaryPacketProtocol *bpp) continue; } + s->pktin->qnode.formal_size = get_avail(s->pktin); pq_push(&s->bpp.in_pq, s->pktin); s->pktin = NULL; } diff --git a/ssh2bpp.c b/ssh2bpp.c index 79b97b31..97179ca9 100644 --- a/ssh2bpp.c +++ b/ssh2bpp.c @@ -589,6 +589,7 @@ static void ssh2_bpp_handle_input(BinaryPacketProtocol *bpp) continue; } + s->pktin->qnode.formal_size = get_avail(s->pktin); pq_push(&s->bpp.in_pq, s->pktin); { diff --git a/sshcommon.c b/sshcommon.c index 26073843..0d56460f 100644 --- a/sshcommon.c +++ b/sshcommon.c @@ -35,6 +35,7 @@ void pq_base_push(PacketQueueBase *pqb, PacketQueueNode *node) node->prev = pqb->end.prev; node->next->prev = node; node->prev->next = node; + pqb->total_size += node->formal_size; if (pqb->ic) queue_idempotent_callback(pqb->ic); @@ -47,6 +48,7 @@ void pq_base_push_front(PacketQueueBase *pqb, PacketQueueNode *node) node->next = pqb->end.next; node->next->prev = node; node->prev->next = node; + pqb->total_size += node->formal_size; if (pqb->ic) queue_idempotent_callback(pqb->ic); @@ -72,6 +74,23 @@ static IdempotentCallback ic_pktin_free = { pktin_free_queue_callback, NULL, false }; +static inline void pq_unlink_common(PacketQueueBase *pqb, + PacketQueueNode *node) +{ + node->next->prev = node->prev; + node->prev->next = node->next; + + /* Check total_size doesn't drift out of sync downwards, by + * ensuring it doesn't underflow when we do this subtraction */ + assert(pqb->total_size >= node->formal_size); + pqb->total_size -= node->formal_size; + + /* Check total_size doesn't drift out of sync upwards, by checking + * that it's returned to exactly zero whenever a queue is + * emptied */ + assert(pqb->end.next != &pqb->end || pqb->total_size == 0); +} + static PktIn *pq_in_after(PacketQueueBase *pqb, PacketQueueNode *prev, bool pop) { @@ -80,14 +99,14 @@ static PktIn *pq_in_after(PacketQueueBase *pqb, return NULL; if (pop) { - node->next->prev = node->prev; - node->prev->next = node->next; + pq_unlink_common(pqb, node); node->prev = pktin_freeq_head.prev; node->next = &pktin_freeq_head; node->next->prev = node; node->prev->next = node; node->on_free_queue = true; + queue_idempotent_callback(&ic_pktin_free); } @@ -102,8 +121,8 @@ static PktOut *pq_out_after(PacketQueueBase *pqb, return NULL; if (pop) { - node->next->prev = node->prev; - node->prev->next = node->next; + pq_unlink_common(pqb, node); + node->prev = node->next = NULL; } @@ -115,6 +134,7 @@ void pq_in_init(PktInQueue *pq) pq->pqb.ic = NULL; pq->pqb.end.next = pq->pqb.end.prev = &pq->pqb.end; pq->after = pq_in_after; + pq->pqb.total_size = 0; } void pq_out_init(PktOutQueue *pq) @@ -122,6 +142,7 @@ void pq_out_init(PktOutQueue *pq) pq->pqb.ic = NULL; pq->pqb.end.next = pq->pqb.end.prev = &pq->pqb.end; pq->after = pq_out_after; + pq->pqb.total_size = 0; } void pq_in_clear(PktInQueue *pq) @@ -153,6 +174,8 @@ void pq_base_concatenate(PacketQueueBase *qdest, { struct PacketQueueNode *head1, *tail1, *head2, *tail2; + size_t total_size = q1->total_size + q2->total_size; + /* * Extract the contents from both input queues, and empty them. */ @@ -164,6 +187,7 @@ void pq_base_concatenate(PacketQueueBase *qdest, q1->end.next = q1->end.prev = &q1->end; q2->end.next = q2->end.prev = &q2->end; + q1->total_size = q2->total_size = 0; /* * Link the two lists together, handling the case where one or @@ -206,6 +230,8 @@ void pq_base_concatenate(PacketQueueBase *qdest, if (qdest->ic) queue_idempotent_callback(qdest->ic); } + + qdest->total_size = total_size; } /* ---------------------------------------------------------------------- @@ -235,6 +261,7 @@ static void ssh_pkt_adddata(PktOut *pkt, const void *data, int len) sgrowarrayn_nm(pkt->data, pkt->maxlen, pkt->length, len); memcpy(pkt->data + pkt->length, data, len); pkt->length += len; + pkt->qnode.formal_size = pkt->length; } static void ssh_pkt_BinarySink_write(BinarySink *bs,