From d437e5402ead209315c38cd9daff8cce8b4d7966 Mon Sep 17 00:00:00 2001 From: Simon Tatham Date: Fri, 14 Sep 2018 09:16:41 +0100 Subject: [PATCH] Make ssh_compress into a pair of linked classoids. This was mildly fiddly because there's a single vtable structure that implements two distinct interface types, one for compression and one for decompression - and I have actually confused them before now (commit d4304f1b7), so I think it's important to make them actually be separate types! --- ssh.c | 32 +++++++++++++++++---------- ssh.h | 50 ++++++++++++++++++++++-------------------- ssh1bpp.c | 21 +++++++++--------- ssh2bpp.c | 57 ++++++++++++++++++++++++------------------------ sshbpp.h | 4 ++-- sshzlib.c | 65 +++++++++++++++++++++++++++++++++++-------------------- 6 files changed, 130 insertions(+), 99 deletions(-) diff --git a/ssh.c b/ssh.c index f4061894..29f25466 100644 --- a/ssh.c +++ b/ssh.c @@ -335,31 +335,39 @@ const static struct ssh2_macalg *const buggymacs[] = { &ssh_hmac_sha1_buggy, &ssh_hmac_sha1_96_buggy, &ssh_hmac_md5 }; -static void *ssh_comp_none_init(void) +static ssh_compressor *ssh_comp_none_init(void) { return NULL; } -static void ssh_comp_none_cleanup(void *handle) +static void ssh_comp_none_cleanup(ssh_compressor *handle) { } -static void ssh_comp_none_block(void *handle, unsigned char *block, int len, +static ssh_decompressor *ssh_decomp_none_init(void) +{ + return NULL; +} +static void ssh_decomp_none_cleanup(ssh_decompressor *handle) +{ +} +static void ssh_comp_none_block(ssh_compressor *handle, + unsigned char *block, int len, unsigned char **outblock, int *outlen, int minlen) { } -static int ssh_decomp_none_block(void *handle, unsigned char *block, int len, +static int ssh_decomp_none_block(ssh_decompressor *handle, + unsigned char *block, int len, unsigned char **outblock, int *outlen) { return 0; } -const static struct ssh_compress ssh_comp_none = { +const static struct ssh_compression_alg ssh_comp_none = { "none", NULL, ssh_comp_none_init, ssh_comp_none_cleanup, ssh_comp_none_block, - ssh_comp_none_init, ssh_comp_none_cleanup, ssh_decomp_none_block, + ssh_decomp_none_init, ssh_decomp_none_cleanup, ssh_decomp_none_block, NULL }; -extern const struct ssh_compress ssh_zlib; -const static struct ssh_compress *const compressions[] = { +const static struct ssh_compression_alg *const compressions[] = { &ssh_zlib, &ssh_comp_none }; @@ -4853,7 +4861,7 @@ struct kexinit_algorithm { const struct ssh2_macalg *mac; int etm; } mac; - const struct ssh_compress *comp; + const struct ssh_compression_alg *comp; } u; }; @@ -5025,7 +5033,7 @@ static void do_ssh2_transport(void *vctx) const struct ssh2_cipheralg *cipher; const struct ssh2_macalg *mac; int etm_mode; - const struct ssh_compress *comp; + const struct ssh_compression_alg *comp; } in, out; ptrlen hostkeydata, sigdata; char *keystr, *fingerprint; @@ -5042,7 +5050,7 @@ static void do_ssh2_transport(void *vctx) int preferred_hk[HK_MAX]; int n_preferred_ciphers; const struct ssh2_ciphers *preferred_ciphers[CIPHER_MAX]; - const struct ssh_compress *preferred_comp; + const struct ssh_compression_alg *preferred_comp; int userauth_succeeded; /* for delayed compression */ int pending_compression; int got_session_id; @@ -5413,7 +5421,7 @@ static void do_ssh2_transport(void *vctx) alg->u.comp = s->preferred_comp; } for (i = 0; i < lenof(compressions); i++) { - const struct ssh_compress *c = compressions[i]; + const struct ssh_compression_alg *c = compressions[i]; alg = ssh2_kexinit_addalg(s->kexlists[j], c->name); alg->u.comp = c; if (s->userauth_succeeded && c->delayed_name) { diff --git a/ssh.h b/ssh.h index 377c0568..7f038b9c 100644 --- a/ssh.h +++ b/ssh.h @@ -603,23 +603,39 @@ struct ssh_keyalg { #define ssh_key_ssh_id(key) ((*(key))->ssh_id) #define ssh_key_cache_id(key) ((*(key))->cache_id) -struct ssh_compress { +typedef struct ssh_compressor { + const struct ssh_compression_alg *vt; +} ssh_compressor; +typedef struct ssh_decompressor { + const struct ssh_compression_alg *vt; +} ssh_decompressor; + +struct ssh_compression_alg { const char *name; /* For zlib@openssh.com: if non-NULL, this name will be considered once * userauth has completed successfully. */ const char *delayed_name; - void *(*compress_init) (void); - void (*compress_cleanup) (void *); - void (*compress) (void *, unsigned char *block, int len, - unsigned char **outblock, int *outlen, - int minlen); - void *(*decompress_init) (void); - void (*decompress_cleanup) (void *); - int (*decompress) (void *, unsigned char *block, int len, - unsigned char **outblock, int *outlen); + ssh_compressor *(*compress_new)(void); + void (*compress_free)(ssh_compressor *); + void (*compress)(ssh_compressor *, unsigned char *block, int len, + unsigned char **outblock, int *outlen, + int minlen); + ssh_decompressor *(*decompress_new)(void); + void (*decompress_free)(ssh_decompressor *); + int (*decompress)(ssh_decompressor *, unsigned char *block, int len, + unsigned char **outblock, int *outlen); const char *text_name; }; +#define ssh_compressor_new(alg) ((alg)->compress_new()) +#define ssh_compressor_free(comp) ((comp)->vt->compress_free(comp)) +#define ssh_compressor_compress(comp, in, inlen, out, outlen, minlen) \ + ((comp)->vt->compress(comp, in, inlen, out, outlen, minlen)) +#define ssh_decompressor_new(alg) ((alg)->decompress_new()) +#define ssh_decompressor_free(comp) ((comp)->vt->decompress_free(comp)) +#define ssh_decompressor_decompress(comp, in, inlen, out, outlen) \ + ((comp)->vt->decompress(comp, in, inlen, out, outlen)) + struct ssh2_userkey { ssh_key *key; /* the key itself */ char *comment; /* the key comment */ @@ -659,6 +675,7 @@ extern const struct ssh2_macalg ssh_hmac_sha1_buggy; extern const struct ssh2_macalg ssh_hmac_sha1_96; extern const struct ssh2_macalg ssh_hmac_sha1_96_buggy; extern const struct ssh2_macalg ssh_hmac_sha256; +extern const struct ssh_compression_alg ssh_zlib; typedef struct AESContext AESContext; AESContext *aes_make_context(void); @@ -1017,19 +1034,6 @@ Bignum primegen(int bits, int modulus, int residue, Bignum factor, int phase, progfn_t pfn, void *pfnparam, unsigned firstbits); void invent_firstbits(unsigned *one, unsigned *two); - -/* - * zlib compression. - */ -void *zlib_compress_init(void); -void zlib_compress_cleanup(void *); -void *zlib_decompress_init(void); -void zlib_decompress_cleanup(void *); -void zlib_compress_block(void *, unsigned char *block, int len, - unsigned char **outblock, int *outlen, int minlen); -int zlib_decompress_block(void *, unsigned char *block, int len, - unsigned char **outblock, int *outlen); - /* * Connection-sharing API provided by platforms. This function must * either: diff --git a/ssh1bpp.c b/ssh1bpp.c index 627d55a5..a1fd78c9 100644 --- a/ssh1bpp.c +++ b/ssh1bpp.c @@ -21,7 +21,8 @@ struct ssh1_bpp_state { struct crcda_ctx *crcda_ctx; - void *compctx, *decompctx; + ssh_compressor *compctx; + ssh_decompressor *decompctx; BinaryPacketProtocol bpp; }; @@ -52,9 +53,9 @@ static void ssh1_bpp_free(BinaryPacketProtocol *bpp) if (s->cipher) ssh1_cipher_free(s->cipher); if (s->compctx) - zlib_compress_cleanup(s->compctx); + ssh_compressor_free(s->compctx); if (s->decompctx) - zlib_decompress_cleanup(s->decompctx); + ssh_decompressor_free(s->decompctx); if (s->crcda_ctx) crcda_free_context(s->crcda_ctx); if (s->pktin) @@ -90,8 +91,8 @@ void ssh1_bpp_start_compression(BinaryPacketProtocol *bpp) assert(!s->compctx); assert(!s->decompctx); - s->compctx = zlib_compress_init(); - s->decompctx = zlib_decompress_init(); + s->compctx = ssh_compressor_new(&ssh_zlib); + s->decompctx = ssh_decompressor_new(&ssh_zlib); } static void ssh1_bpp_handle_input(BinaryPacketProtocol *bpp) @@ -157,9 +158,9 @@ static void ssh1_bpp_handle_input(BinaryPacketProtocol *bpp) if (s->decompctx) { unsigned char *decompblk; int decomplen; - if (!zlib_decompress_block(s->decompctx, - s->data + s->pad, s->length + 1, - &decompblk, &decomplen)) { + if (!ssh_decompressor_decompress( + s->decompctx, s->data + s->pad, s->length + 1, + &decompblk, &decomplen)) { s->bpp.error = dupprintf( "Zlib decompression encountered invalid data"); crStopV; @@ -248,8 +249,8 @@ static void ssh1_bpp_format_packet(BinaryPacketProtocol *bpp, PktOut *pkt) if (s->compctx) { unsigned char *compblk; int complen; - zlib_compress_block(s->compctx, pkt->data + 12, pkt->length - 12, - &compblk, &complen, 0); + ssh_compressor_compress(s->compctx, pkt->data + 12, pkt->length - 12, + &compblk, &complen, 0); /* Replace the uncompressed packet data with the compressed * version. */ pkt->length = 12; diff --git a/ssh2bpp.c b/ssh2bpp.c index ff957b05..0b95d7fe 100644 --- a/ssh2bpp.c +++ b/ssh2bpp.c @@ -14,8 +14,6 @@ struct ssh2_bpp_direction { ssh2_cipher *cipher; ssh2_mac *mac; int etm_mode; - const struct ssh_compress *comp; - void *comp_ctx; }; struct ssh2_bpp_state { @@ -28,6 +26,11 @@ struct ssh2_bpp_state { PktIn *pktin; struct ssh2_bpp_direction in, out; + /* comp and decomp logically belong in the per-direction + * substructure, except that they have different types */ + ssh_decompressor *in_decomp; + ssh_compressor *out_comp; + int pending_newkeys; BinaryPacketProtocol bpp; @@ -61,14 +64,14 @@ static void ssh2_bpp_free(BinaryPacketProtocol *bpp) ssh2_cipher_free(s->out.cipher); if (s->out.mac) ssh2_mac_free(s->out.mac); - if (s->out.comp_ctx) - s->out.comp->compress_cleanup(s->out.comp_ctx); + if (s->out_comp) + ssh_compressor_free(s->out_comp); if (s->in.cipher) ssh2_cipher_free(s->in.cipher); if (s->in.mac) ssh2_mac_free(s->in.mac); - if (s->in.comp_ctx) - s->in.comp->decompress_cleanup(s->in.comp_ctx); + if (s->in_decomp) + ssh_decompressor_free(s->in_decomp); if (s->pktin) ssh_unref_packet(s->pktin); sfree(s); @@ -78,7 +81,7 @@ void ssh2_bpp_new_outgoing_crypto( BinaryPacketProtocol *bpp, const struct ssh2_cipheralg *cipher, const void *ckey, const void *iv, const struct ssh2_macalg *mac, int etm_mode, const void *mac_key, - const struct ssh_compress *compression) + const struct ssh_compression_alg *compression) { struct ssh2_bpp_state *s; assert(bpp->vt == &ssh2_bpp_vtable); @@ -88,8 +91,8 @@ void ssh2_bpp_new_outgoing_crypto( ssh2_cipher_free(s->out.cipher); if (s->out.mac) ssh2_mac_free(s->out.mac); - if (s->out.comp_ctx) - s->out.comp->compress_cleanup(s->out.comp_ctx); + if (s->out_comp) + ssh_compressor_free(s->out_comp); if (cipher) { s->out.cipher = ssh2_cipher_new(cipher); @@ -106,18 +109,17 @@ void ssh2_bpp_new_outgoing_crypto( s->out.mac = NULL; } - s->out.comp = compression; - /* out_comp is always non-NULL, because no compression is - * indicated by ssh_comp_none. So compress_init always exists, but - * it may return a null out_comp_ctx. */ - s->out.comp_ctx = compression->compress_init(); + /* 'compression' is always non-NULL, because no compression is + * indicated by ssh_comp_none. But this setup call may return a + * null out_comp. */ + s->out_comp = ssh_compressor_new(compression); } void ssh2_bpp_new_incoming_crypto( BinaryPacketProtocol *bpp, const struct ssh2_cipheralg *cipher, const void *ckey, const void *iv, const struct ssh2_macalg *mac, int etm_mode, const void *mac_key, - const struct ssh_compress *compression) + const struct ssh_compression_alg *compression) { struct ssh2_bpp_state *s; assert(bpp->vt == &ssh2_bpp_vtable); @@ -127,8 +129,8 @@ void ssh2_bpp_new_incoming_crypto( ssh2_cipher_free(s->in.cipher); if (s->in.mac) ssh2_mac_free(s->in.mac); - if (s->in.comp_ctx) - s->in.comp->decompress_cleanup(s->in.comp_ctx); + if (s->in_decomp) + ssh_decompressor_free(s->in_decomp); if (cipher) { s->in.cipher = ssh2_cipher_new(cipher); @@ -145,11 +147,10 @@ void ssh2_bpp_new_incoming_crypto( s->in.mac = NULL; } - s->in.comp = compression; - /* in_comp is always non-NULL, because no compression is - * indicated by ssh_comp_none. So compress_init always exists, but - * it may return a null in_comp_ctx. */ - s->in.comp_ctx = compression->decompress_init(); + /* 'compression' is always non-NULL, because no compression is + * indicated by ssh_comp_none. But this setup call may return a + * null in_decomp. */ + s->in_decomp = ssh_decompressor_new(compression); /* Clear the pending_newkeys flag, so that handle_input below will * start consuming the input data again. */ @@ -419,8 +420,8 @@ static void ssh2_bpp_handle_input(BinaryPacketProtocol *bpp) { unsigned char *newpayload; int newlen; - if (s->in.comp && s->in.comp->decompress( - s->in.comp_ctx, s->data + 5, s->length - 5, + if (s->in_decomp && ssh_decompressor_decompress( + s->in_decomp, s->data + 5, s->length - 5, &newpayload, &newlen)) { if (s->maxlen < newlen + 5) { PktIn *old_pktin = s->pktin; @@ -526,7 +527,7 @@ static void ssh2_bpp_format_packet_inner(struct ssh2_bpp_state *s, PktOut *pkt) cipherblk = s->out.cipher ? ssh2_cipher_alg(s->out.cipher)->blksize : 8; cipherblk = cipherblk < 8 ? 8 : cipherblk; /* or 8 if blksize < 8 */ - if (s->out.comp && s->out.comp_ctx) { + if (s->out_comp) { unsigned char *newpayload; int minlen, newlen; @@ -544,8 +545,8 @@ static void ssh2_bpp_format_packet_inner(struct ssh2_bpp_state *s, PktOut *pkt) minlen -= 8; /* length field + min padding */ } - s->out.comp->compress(s->out.comp_ctx, pkt->data + 5, pkt->length - 5, - &newpayload, &newlen, minlen); + ssh_compressor_compress(s->out_comp, pkt->data + 5, pkt->length - 5, + &newpayload, &newlen, minlen); pkt->length = 5; put_data(pkt, newpayload, newlen); sfree(newpayload); @@ -608,7 +609,7 @@ static void ssh2_bpp_format_packet(BinaryPacketProtocol *bpp, PktOut *pkt) { struct ssh2_bpp_state *s = FROMFIELD(bpp, struct ssh2_bpp_state, bpp); - if (pkt->minlen > 0 && !(s->out.comp && s->out.comp_ctx)) { + if (pkt->minlen > 0 && !s->out_comp) { /* * If we've been told to pad the packet out to a given minimum * length, but we're not compressing (and hence can't get the diff --git a/sshbpp.h b/sshbpp.h index f7e2fe2c..6e45262e 100644 --- a/sshbpp.h +++ b/sshbpp.h @@ -41,12 +41,12 @@ void ssh2_bpp_new_outgoing_crypto( BinaryPacketProtocol *bpp, const struct ssh2_cipheralg *cipher, const void *ckey, const void *iv, const struct ssh2_macalg *mac, int etm_mode, const void *mac_key, - const struct ssh_compress *compression); + const struct ssh_compression_alg *compression); void ssh2_bpp_new_incoming_crypto( BinaryPacketProtocol *bpp, const struct ssh2_cipheralg *cipher, const void *ckey, const void *iv, const struct ssh2_macalg *mac, int etm_mode, const void *mac_key, - const struct ssh_compress *compression); + const struct ssh_compression_alg *compression); BinaryPacketProtocol *ssh2_bare_bpp_new(void); diff --git a/sshzlib.c b/sshzlib.c index 17550477..ec8708b5 100644 --- a/sshzlib.c +++ b/sshzlib.c @@ -66,6 +66,10 @@ #define TRUE 1 #endif +typedef struct { const struct dummy *vt; } ssh_compressor; +typedef struct { const struct dummy *vt; } ssh_decompressor; +static const struct dummy { int i; } ssh_zlib; + #else #include "ssh.h" #endif @@ -600,37 +604,45 @@ static void zlib_match(struct LZ77Context *ectx, int distance, int len) } } -void *zlib_compress_init(void) +struct ssh_zlib_compressor { + struct LZ77Context ectx; + ssh_compressor sc; +}; + +ssh_compressor *zlib_compress_init(void) { struct Outbuf *out; - struct LZ77Context *ectx = snew(struct LZ77Context); + struct ssh_zlib_compressor *comp = snew(struct ssh_zlib_compressor); - lz77_init(ectx); - ectx->literal = zlib_literal; - ectx->match = zlib_match; + lz77_init(&comp->ectx); + comp->sc.vt = &ssh_zlib; + comp->ectx.literal = zlib_literal; + comp->ectx.match = zlib_match; out = snew(struct Outbuf); out->outbits = out->noutbits = 0; out->firstblock = 1; - ectx->userdata = out; + comp->ectx.userdata = out; - return ectx; + return &comp->sc; } -void zlib_compress_cleanup(void *handle) +void zlib_compress_cleanup(ssh_compressor *sc) { - struct LZ77Context *ectx = (struct LZ77Context *)handle; - sfree(ectx->userdata); - sfree(ectx->ictx); - sfree(ectx); + struct ssh_zlib_compressor *comp = + FROMFIELD(sc, struct ssh_zlib_compressor, sc); + sfree(comp->ectx.userdata); + sfree(comp->ectx.ictx); + sfree(comp); } -void zlib_compress_block(void *handle, unsigned char *block, int len, +void zlib_compress_block(ssh_compressor *sc, unsigned char *block, int len, unsigned char **outblock, int *outlen, int minlen) { - struct LZ77Context *ectx = (struct LZ77Context *)handle; - struct Outbuf *out = (struct Outbuf *) ectx->userdata; + struct ssh_zlib_compressor *comp = + FROMFIELD(sc, struct ssh_zlib_compressor, sc); + struct Outbuf *out = (struct Outbuf *) comp->ectx.userdata; int in_block; out->outbuf = NULL; @@ -662,7 +674,7 @@ void zlib_compress_block(void *handle, unsigned char *block, int len, /* * Do the compression. */ - lz77_compress(ectx, block, len, TRUE); + lz77_compress(&comp->ectx, block, len, TRUE); /* * End the block (by transmitting code 256, which is @@ -875,9 +887,11 @@ struct zlib_decompress_ctx { int winpos; unsigned char *outblk; int outlen, outsize; + + ssh_decompressor dc; }; -void *zlib_decompress_init(void) +ssh_decompressor *zlib_decompress_init(void) { struct zlib_decompress_ctx *dctx = snew(struct zlib_decompress_ctx); unsigned char lengths[288]; @@ -895,12 +909,14 @@ void *zlib_decompress_init(void) dctx->nbits = 0; dctx->winpos = 0; - return dctx; + dctx->dc.vt = &ssh_zlib; + return &dctx->dc; } -void zlib_decompress_cleanup(void *handle) +void zlib_decompress_cleanup(ssh_decompressor *dc) { - struct zlib_decompress_ctx *dctx = (struct zlib_decompress_ctx *)handle; + struct zlib_decompress_ctx *dctx = + FROMFIELD(dc, struct zlib_decompress_ctx, dc); if (dctx->currlentable && dctx->currlentable != dctx->staticlentable) zlib_freetable(&dctx->currlentable); @@ -958,10 +974,11 @@ static void zlib_emit_char(struct zlib_decompress_ctx *dctx, int c) #define EATBITS(n) ( dctx->nbits -= (n), dctx->bits >>= (n) ) -int zlib_decompress_block(void *handle, unsigned char *block, int len, +int zlib_decompress_block(ssh_decompressor *dc, unsigned char *block, int len, unsigned char **outblock, int *outlen) { - struct zlib_decompress_ctx *dctx = (struct zlib_decompress_ctx *)handle; + struct zlib_decompress_ctx *dctx = + FROMFIELD(dc, struct zlib_decompress_ctx, dc); const coderecord *rec; int code, blktype, rep, dist, nlen, header; static const unsigned char lenlenmap[] = { @@ -1217,7 +1234,7 @@ int main(int argc, char **argv) { unsigned char buf[16], *outbuf; int ret, outlen; - void *handle; + ssh_decompressor *handle; int noheader = FALSE, opts = TRUE; char *filename = NULL; FILE *fp; @@ -1289,7 +1306,7 @@ int main(int argc, char **argv) #else -const struct ssh_compress ssh_zlib = { +const struct ssh_compression_alg ssh_zlib = { "zlib", "zlib@openssh.com", /* delayed version */ zlib_compress_init,