diff --git a/sshhmac.c b/sshhmac.c index 31e15ea8..9bbce124 100644 --- a/sshhmac.c +++ b/sshhmac.c @@ -6,14 +6,16 @@ #include "ssh.h" struct hmac { + const ssh_hashalg *hashalg; ssh_hash *h_outer, *h_inner, *h_live; + bool keyed; uint8_t *digest; strbuf *text_name; ssh2_mac mac; }; struct hmac_extra { - const ssh_hashalg *hashalg; + const ssh_hashalg *hashalg_base; const char *suffix; }; @@ -22,6 +24,14 @@ static ssh2_mac *hmac_new(const ssh2_macalg *alg, ssh_cipher *cipher) struct hmac *ctx = snew(struct hmac); const struct hmac_extra *extra = (const struct hmac_extra *)alg->extra; + ctx->h_outer = ssh_hash_new(extra->hashalg_base); + /* In case that hashalg was a selector vtable, we'll now switch to + * using whatever real one it selected, for all future purposes. */ + ctx->hashalg = ssh_hash_alg(ctx->h_outer); + ctx->h_inner = ssh_hash_new(ctx->hashalg); + ctx->h_live = ssh_hash_new(ctx->hashalg); + ctx->keyed = false; + /* * HMAC is not well defined as a wrapper on an absolutely general * hash function; it expects that the function it's wrapping will @@ -29,17 +39,16 @@ static ssh2_mac *hmac_new(const ssh2_macalg *alg, ssh_cipher *cipher) * in terms of that block size. So we insist that the hash we're * given must have defined a meaningful block size. */ - assert(extra->hashalg->blocklen); + assert(ctx->hashalg->blocklen); - ctx->h_outer = ctx->h_inner = ctx->h_live = NULL; - ctx->digest = snewn(extra->hashalg->hlen, uint8_t); + ctx->digest = snewn(ctx->hashalg->hlen, uint8_t); ctx->text_name = strbuf_new(); strbuf_catf(ctx->text_name, "HMAC-%s%s", - extra->hashalg->text_name, extra->suffix); + ctx->hashalg->text_name, extra->suffix); ctx->mac.vt = alg; - BinarySink_DELEGATE_CLEAR(&ctx->mac); + BinarySink_DELEGATE_INIT(&ctx->mac, ctx->h_live); return &ctx->mac; } @@ -49,13 +58,10 @@ static void hmac_free(ssh2_mac *mac) struct hmac *ctx = container_of(mac, struct hmac, mac); const struct hmac_extra *extra = (const struct hmac_extra *)mac->vt->extra; - if (ctx->h_outer) - ssh_hash_free(ctx->h_outer); - if (ctx->h_inner) - ssh_hash_free(ctx->h_inner); - if (ctx->h_live) - ssh_hash_free(ctx->h_live); - smemclr(ctx->digest, extra->hashalg->hlen); + ssh_hash_free(ctx->h_outer); + ssh_hash_free(ctx->h_inner); + ssh_hash_free(ctx->h_live); + smemclr(ctx->digest, ctx->hashalg->hlen); sfree(ctx->digest); strbuf_free(ctx->text_name); @@ -75,16 +81,28 @@ static void hmac_key(ssh2_mac *mac, ptrlen key) size_t klen; strbuf *sb = NULL; - if (key.len > extra->hashalg->blocklen) { + if (ctx->keyed) { + /* + * If we've already been keyed, throw away the existing hash + * objects and make a fresh pair to put the new key in. + */ + ssh_hash_free(ctx->h_outer); + ssh_hash_free(ctx->h_inner); + ctx->h_outer = ssh_hash_new(ctx->hashalg); + ctx->h_inner = ssh_hash_new(ctx->hashalg); + } + ctx->keyed = true; + + if (key.len > ctx->hashalg->blocklen) { /* * RFC 2104 section 2: if the key exceeds the block length of * the underlying hash, then we start by hashing the key, and * use that hash as the 'true' key for the HMAC construction. */ sb = strbuf_new(); - strbuf_append(sb, extra->hashalg->hlen); + strbuf_append(sb, ctx->hashalg->hlen); - ssh_hash *htmp = ssh_hash_new(extra->hashalg); + ssh_hash *htmp = ssh_hash_new(ctx->hashalg); put_datapl(htmp, key); ssh_hash_final(htmp, sb->u); @@ -103,16 +121,16 @@ static void hmac_key(ssh2_mac *mac, ptrlen key) if (ctx->h_inner) ssh_hash_free(ctx->h_inner); - ctx->h_outer = ssh_hash_new(extra->hashalg); + ctx->h_outer = ssh_hash_new(ctx->hashalg); for (size_t i = 0; i < klen; i++) put_byte(ctx->h_outer, PAD_OUTER ^ kp[i]); - for (size_t i = klen; i < extra->hashalg->blocklen; i++) + for (size_t i = klen; i < ctx->hashalg->blocklen; i++) put_byte(ctx->h_outer, PAD_OUTER); - ctx->h_inner = ssh_hash_new(extra->hashalg); + ctx->h_inner = ssh_hash_new(ctx->hashalg); for (size_t i = 0; i < klen; i++) put_byte(ctx->h_inner, PAD_INNER ^ kp[i]); - for (size_t i = klen; i < extra->hashalg->blocklen; i++) + for (size_t i = klen; i < ctx->hashalg->blocklen; i++) put_byte(ctx->h_inner, PAD_INNER); if (sb) @@ -123,10 +141,7 @@ static void hmac_start(ssh2_mac *mac) { struct hmac *ctx = container_of(mac, struct hmac, mac); - assert(ctx->h_outer); - if (ctx->h_live) - ssh_hash_free(ctx->h_live); - + ssh_hash_free(ctx->h_live); ctx->h_live = ssh_hash_copy(ctx->h_inner); BinarySink_DELEGATE_INIT(&ctx->mac, ctx->h_live); } @@ -135,15 +150,16 @@ static void hmac_genresult(ssh2_mac *mac, unsigned char *output) { struct hmac *ctx = container_of(mac, struct hmac, mac); const struct hmac_extra *extra = (const struct hmac_extra *)mac->vt->extra; + ssh_hash *htmp; - assert(ctx->h_live); - ssh_hash_final(ctx->h_live, ctx->digest); + /* Leave h_live in place, so that the SSH-2 BPP can continue + * regenerating test results from different-length prefixes of the + * packet */ + htmp = ssh_hash_copy(ctx->h_live); + ssh_hash_final(htmp, ctx->digest); - ctx->h_live = NULL; - BinarySink_DELEGATE_CLEAR(&ctx->mac); - - ssh_hash *htmp = ssh_hash_copy(ctx->h_outer); - put_data(htmp, ctx->digest, extra->hashalg->hlen); + htmp = ssh_hash_copy(ctx->h_outer); + put_data(htmp, ctx->digest, ctx->hashalg->hlen); ssh_hash_final(htmp, ctx->digest); /* @@ -152,7 +168,7 @@ static void hmac_genresult(ssh2_mac *mac, unsigned char *output) * full-length buffer, and now we copy the required amount. */ memcpy(output, ctx->digest, mac->vt->len); - smemclr(ctx->digest, extra->hashalg->hlen); + smemclr(ctx->digest, ctx->hashalg->hlen); } static const char *hmac_text_name(ssh2_mac *mac)