diff --git a/ssh2bpp.c b/ssh2bpp.c index 2e73b56e..8cb82356 100644 --- a/ssh2bpp.c +++ b/ssh2bpp.c @@ -67,22 +67,42 @@ BinaryPacketProtocol *ssh2_bpp_new( return &s->bpp; } +static void ssh2_bpp_free_outgoing_crypto(struct ssh2_bpp_state *s) +{ + /* + * We must free the MAC before the cipher, because sometimes the + * MAC is not actually separately allocated but just a different + * facet of the same object as the cipher, in which case + * ssh2_mac_free does nothing and ssh2_cipher_free does the actual + * freeing. So if we freed the cipher first and then tried to + * dereference the MAC's vtable pointer to find out how to free + * that too, we'd be accessing freed memory. + */ + if (s->out.mac) + ssh2_mac_free(s->out.mac); + if (s->out.cipher) + ssh2_cipher_free(s->out.cipher); + if (s->out_comp) + ssh_compressor_free(s->out_comp); +} + +static void ssh2_bpp_free_incoming_crypto(struct ssh2_bpp_state *s) +{ + /* As above, take care to free in.mac before in.cipher */ + if (s->in.mac) + ssh2_mac_free(s->in.mac); + if (s->in.cipher) + ssh2_cipher_free(s->in.cipher); + if (s->in_decomp) + ssh_decompressor_free(s->in_decomp); +} + static void ssh2_bpp_free(BinaryPacketProtocol *bpp) { struct ssh2_bpp_state *s = container_of(bpp, struct ssh2_bpp_state, bpp); sfree(s->buf); - if (s->out.cipher) - ssh2_cipher_free(s->out.cipher); - if (s->out.mac) - ssh2_mac_free(s->out.mac); - if (s->out_comp) - ssh_compressor_free(s->out_comp); - if (s->in.cipher) - ssh2_cipher_free(s->in.cipher); - if (s->in.mac) - ssh2_mac_free(s->in.mac); - if (s->in_decomp) - ssh_decompressor_free(s->in_decomp); + ssh2_bpp_free_outgoing_crypto(s); + ssh2_bpp_free_incoming_crypto(s); sfree(s->pktin); sfree(s); } @@ -97,12 +117,7 @@ void ssh2_bpp_new_outgoing_crypto( assert(bpp->vt == &ssh2_bpp_vtable); s = container_of(bpp, struct ssh2_bpp_state, bpp); - if (s->out.cipher) - ssh2_cipher_free(s->out.cipher); - if (s->out.mac) - ssh2_mac_free(s->out.mac); - if (s->out_comp) - ssh_compressor_free(s->out_comp); + ssh2_bpp_free_outgoing_crypto(s); if (cipher) { s->out.cipher = ssh2_cipher_new(cipher); @@ -164,12 +179,7 @@ void ssh2_bpp_new_incoming_crypto( assert(bpp->vt == &ssh2_bpp_vtable); s = container_of(bpp, struct ssh2_bpp_state, bpp); - if (s->in.cipher) - ssh2_cipher_free(s->in.cipher); - if (s->in.mac) - ssh2_mac_free(s->in.mac); - if (s->in_decomp) - ssh_decompressor_free(s->in_decomp); + ssh2_bpp_free_incoming_crypto(s); if (cipher) { s->in.cipher = ssh2_cipher_new(cipher);