diff --git a/defs.h b/defs.h index f45abfec..61ada0b3 100644 --- a/defs.h +++ b/defs.h @@ -38,6 +38,8 @@ typedef uint32_t uint32; typedef struct BinarySink BinarySink; typedef struct BinarySource BinarySource; +typedef struct IdempotentCallback IdempotentCallback; + typedef struct SockAddr_tag *SockAddr; typedef struct Socket_vtable Socket_vtable; diff --git a/ssh.c b/ssh.c index 1d7580cb..b4c873d8 100644 --- a/ssh.c +++ b/ssh.c @@ -1043,7 +1043,6 @@ static void ssh_process_pq_full(void *ctx) if (ssh->general_packet_processing) ssh->general_packet_processing(ssh, pktin); ssh->packet_dispatch[pktin->type](ssh, pktin); - ssh_unref_packet(pktin); } } @@ -3316,14 +3315,12 @@ static void ssh1_connection_input(Ssh ssh) static void ssh1_coro_wrapper_initial(Ssh ssh, PktIn *pktin) { - pktin->refcount++; /* avoid packet being freed when we return */ pq_push(&ssh->pq_ssh1_login, pktin); queue_idempotent_callback(&ssh->ssh1_login_icb); } static void ssh1_coro_wrapper_session(Ssh ssh, PktIn *pktin) { - pktin->refcount++; /* avoid packet being freed when we return */ pq_push(&ssh->pq_ssh1_connection, pktin); queue_idempotent_callback(&ssh->ssh1_connection_icb); } @@ -6661,7 +6658,6 @@ static void ssh2_setup_env(struct ssh_channel *c, PktIn *pktin, */ static void ssh2_msg_userauth(Ssh ssh, PktIn *pktin) { - pktin->refcount++; /* avoid packet being freed when we return */ pq_push(&ssh->pq_ssh2_userauth, pktin); if (pktin->type == SSH2_MSG_USERAUTH_SUCCESS) { /* @@ -8136,7 +8132,6 @@ static void ssh2_userauth_input(Ssh ssh) */ static void ssh2_msg_connection(Ssh ssh, PktIn *pktin) { - pktin->refcount++; /* avoid packet being freed when we return */ pq_push(&ssh->pq_ssh2_connection, pktin); queue_idempotent_callback(&ssh->ssh2_connection_icb); } @@ -8610,7 +8605,6 @@ static void ssh2_msg_debug(Ssh ssh, PktIn *pktin) static void ssh2_msg_transport(Ssh ssh, PktIn *pktin) { - pktin->refcount++; /* avoid packet being freed when we return */ pq_push(&ssh->pq_ssh2_transport, pktin); queue_idempotent_callback(&ssh->ssh2_transport_icb); } diff --git a/ssh.h b/ssh.h index 200a5407..36bd90de 100644 --- a/ssh.h +++ b/ssh.h @@ -53,10 +53,10 @@ struct ssh_channel; typedef struct PacketQueueNode PacketQueueNode; struct PacketQueueNode { PacketQueueNode *next, *prev; + int on_free_queue; /* is this packet scheduled for freeing? */ }; typedef struct PktIn { - int refcount; int type; unsigned long sequence; /* SSH-2 incoming sequence number */ PacketQueueNode qnode; /* for linking this packet on to a queue */ @@ -157,7 +157,6 @@ int ssh2_censor_packet( ptrlen pkt, logblank_t *blanks); PktOut *ssh_new_packet(void); -void ssh_unref_packet(PktIn *pkt); void ssh_free_pktout(PktOut *pkt); extern Socket ssh_connection_sharing_init( diff --git a/ssh1bpp.c b/ssh1bpp.c index 90942587..92c89435 100644 --- a/ssh1bpp.c +++ b/ssh1bpp.c @@ -59,8 +59,7 @@ static void ssh1_bpp_free(BinaryPacketProtocol *bpp) ssh_decompressor_free(s->decompctx); if (s->crcda_ctx) crcda_free_context(s->crcda_ctx); - if (s->pktin) - ssh_unref_packet(s->pktin); + sfree(s->pktin); sfree(s); } @@ -125,7 +124,7 @@ static void ssh1_bpp_handle_input(BinaryPacketProtocol *bpp) */ s->pktin = snew_plus(PktIn, s->biglen); s->pktin->qnode.prev = s->pktin->qnode.next = NULL; - s->pktin->refcount = 1; + s->pktin->qnode.on_free_queue = FALSE; s->pktin->type = 0; s->maxlen = s->biglen; diff --git a/ssh2bpp-bare.c b/ssh2bpp-bare.c index c818e218..db645e2f 100644 --- a/ssh2bpp-bare.c +++ b/ssh2bpp-bare.c @@ -44,8 +44,7 @@ static void ssh2_bare_bpp_free(BinaryPacketProtocol *bpp) { struct ssh2_bare_bpp_state *s = FROMFIELD(bpp, struct ssh2_bare_bpp_state, bpp); - if (s->pktin) - ssh_unref_packet(s->pktin); + sfree(s->pktin); sfree(s); } @@ -75,8 +74,8 @@ static void ssh2_bare_bpp_handle_input(BinaryPacketProtocol *bpp) */ s->pktin = snew_plus(PktIn, s->packetlen); s->pktin->qnode.prev = s->pktin->qnode.next = NULL; + s->pktin->qnode.on_free_queue = FALSE; s->maxlen = 0; - s->pktin->refcount = 1; s->data = snew_plus_get_aux(s->pktin); s->pktin->sequence = s->incoming_sequence++; diff --git a/ssh2bpp.c b/ssh2bpp.c index b0dd2ea7..272e1711 100644 --- a/ssh2bpp.c +++ b/ssh2bpp.c @@ -74,8 +74,7 @@ static void ssh2_bpp_free(BinaryPacketProtocol *bpp) ssh2_mac_free(s->in.mac); if (s->in_decomp) ssh_decompressor_free(s->in_decomp); - if (s->pktin) - ssh_unref_packet(s->pktin); + sfree(s->pktin); sfree(s); } @@ -249,8 +248,8 @@ static void ssh2_bpp_handle_input(BinaryPacketProtocol *bpp) */ s->pktin = snew_plus(PktIn, s->maxlen); s->pktin->qnode.prev = s->pktin->qnode.next = NULL; - s->pktin->refcount = 1; s->pktin->type = 0; + s->pktin->qnode.on_free_queue = FALSE; s->data = snew_plus_get_aux(s->pktin); memcpy(s->data, s->buf, s->maxlen); } else if (s->in.mac && s->in.etm_mode) { @@ -300,8 +299,8 @@ static void ssh2_bpp_handle_input(BinaryPacketProtocol *bpp) */ s->pktin = snew_plus(PktIn, OUR_V2_PACKETLIMIT + s->maclen); s->pktin->qnode.prev = s->pktin->qnode.next = NULL; - s->pktin->refcount = 1; s->pktin->type = 0; + s->pktin->qnode.on_free_queue = FALSE; s->data = snew_plus_get_aux(s->pktin); memcpy(s->data, s->buf, 4); @@ -369,8 +368,8 @@ static void ssh2_bpp_handle_input(BinaryPacketProtocol *bpp) s->maxlen = s->packetlen + s->maclen; s->pktin = snew_plus(PktIn, s->maxlen); s->pktin->qnode.prev = s->pktin->qnode.next = NULL; - s->pktin->refcount = 1; s->pktin->type = 0; + s->pktin->qnode.on_free_queue = FALSE; s->data = snew_plus_get_aux(s->pktin); memcpy(s->data, s->buf, s->cipherblk); diff --git a/sshcommon.c b/sshcommon.c index ac23e791..4090f9c6 100644 --- a/sshcommon.c +++ b/sshcommon.c @@ -15,10 +15,20 @@ * Implementation of PacketQueue. */ +static void pq_ensure_unlinked(PacketQueueNode *node) +{ + if (node->on_free_queue) { + node->next->prev = node->prev; + node->prev->next = node->next; + } else { + assert(!node->next); + assert(!node->prev); + } +} + void pq_base_push(PacketQueueBase *pqb, PacketQueueNode *node) { - assert(!node->next); - assert(!node->prev); + pq_ensure_unlinked(node); node->next = &pqb->end; node->prev = pqb->end.prev; node->next->prev = node; @@ -27,14 +37,33 @@ void pq_base_push(PacketQueueBase *pqb, PacketQueueNode *node) void pq_base_push_front(PacketQueueBase *pqb, PacketQueueNode *node) { - assert(!node->next); - assert(!node->prev); + pq_ensure_unlinked(node); node->prev = &pqb->end; node->next = pqb->end.next; node->next->prev = node; node->prev->next = node; } +static PacketQueueNode pktin_freeq_head = { + &pktin_freeq_head, &pktin_freeq_head, TRUE +}; + +static void pktin_free_queue_callback(void *vctx) +{ + while (pktin_freeq_head.next != &pktin_freeq_head) { + PacketQueueNode *node = pktin_freeq_head.next; + PktIn *pktin = FROMFIELD(node, PktIn, qnode); + pktin_freeq_head.next = node->next; + sfree(pktin); + } + + pktin_freeq_head.prev = &pktin_freeq_head; +} + +static IdempotentCallback ic_pktin_free = { + pktin_free_queue_callback, NULL, FALSE +}; + static PktIn *pq_in_get(PacketQueueBase *pqb, int pop) { PacketQueueNode *node = pqb->end.next; @@ -44,7 +73,13 @@ static PktIn *pq_in_get(PacketQueueBase *pqb, int pop) if (pop) { node->next->prev = node->prev; node->prev->next = node->next; - node->prev = node->next = NULL; + + node->prev = pktin_freeq_head.prev; + node->next = &pktin_freeq_head; + node->next->prev = node; + node->prev->next = node; + node->on_free_queue = TRUE; + queue_idempotent_callback(&ic_pktin_free); } return FROMFIELD(node, PktIn, qnode); @@ -80,8 +115,11 @@ void pq_out_init(PktOutQueue *pq) void pq_in_clear(PktInQueue *pq) { PktIn *pkt; - while ((pkt = pq_pop(pq)) != NULL) - ssh_unref_packet(pkt); + while ((pkt = pq_pop(pq)) != NULL) { + /* No need to actually free these packets: pq_pop on a + * PktInQueue will automatically move them to the free + * queue. */ + } } void pq_out_clear(PktOutQueue *pq) @@ -170,6 +208,7 @@ PktOut *ssh_new_packet(void) pkt->downstream_id = 0; pkt->additional_log_text = NULL; pkt->qnode.next = pkt->qnode.prev = NULL; + pkt->qnode.on_free_queue = FALSE; return pkt; } @@ -195,12 +234,6 @@ static void ssh_pkt_BinarySink_write(BinarySink *bs, ssh_pkt_adddata(pkt, data, len); } -void ssh_unref_packet(PktIn *pkt) -{ - if (--pkt->refcount <= 0) - sfree(pkt); -} - void ssh_free_pktout(PktOut *pkt) { sfree(pkt->data);