diff --git a/ssh.h b/ssh.h index c771604b..0014a502 100644 --- a/ssh.h +++ b/ssh.h @@ -478,6 +478,7 @@ void BinarySource_get_rsa_ssh1_priv( BinarySource *src, struct RSAKey *rsa); int rsa_ssh1_encrypt(unsigned char *data, int length, struct RSAKey *key); Bignum rsa_ssh1_decrypt(Bignum input, struct RSAKey *key); +int rsa_ssh1_decrypt_pkcs1(Bignum input, struct RSAKey *key, strbuf *outbuf); void rsasanitise(struct RSAKey *key); int rsastr_len(struct RSAKey *key); void rsastr_fmt(char *str, struct RSAKey *key); @@ -504,12 +505,17 @@ int detect_attack(struct crcda_ctx *ctx, unsigned char *buf, uint32 len, * SSH2 RSA key exchange functions */ struct ssh_hashalg; +struct ssh_rsa_kex_extra { + int minklen; +}; struct RSAKey *ssh_rsakex_newkey(const void *data, int len); void ssh_rsakex_freekey(struct RSAKey *key); int ssh_rsakex_klen(struct RSAKey *key); void ssh_rsakex_encrypt(const struct ssh_hashalg *h, unsigned char *in, int inlen, unsigned char *out, int outlen, struct RSAKey *key); +Bignum ssh_rsakex_decrypt(const struct ssh_hashalg *h, ptrlen ciphertext, + struct RSAKey *rsa); /* * SSH2 ECDH key exchange functions diff --git a/sshrsa.c b/sshrsa.c index 4d4ea5b1..86a392eb 100644 --- a/sshrsa.c +++ b/sshrsa.c @@ -285,6 +285,42 @@ Bignum rsa_ssh1_decrypt(Bignum input, struct RSAKey *key) return rsa_privkey_op(input, key); } +int rsa_ssh1_decrypt_pkcs1(Bignum input, struct RSAKey *key, strbuf *outbuf) +{ + strbuf *data = strbuf_new(); + int success = FALSE; + BinarySource src[1]; + + { + Bignum *b = rsa_ssh1_decrypt(input, key); + int i; + for (i = (bignum_bitcount(key->modulus) + 7) / 8; i-- > 0 ;) { + put_byte(data, bignum_byte(b, i)); + } + freebn(b); + } + + BinarySource_BARE_INIT(src, data->u, data->len); + + /* Check PKCS#1 formatting prefix */ + if (get_byte(src) != 0) goto out; + if (get_byte(src) != 2) goto out; + while (1) { + unsigned char byte = get_byte(src); + if (get_err(src)) goto out; + if (byte == 0) + break; + } + + /* Everything else is the payload */ + success = TRUE; + put_data(outbuf, get_ptr(src), get_avail(src)); + + out: + strbuf_free(data); + return success; +} + int rsastr_len(struct RSAKey *key) { Bignum md, ex; @@ -903,12 +939,88 @@ void ssh_rsakex_encrypt(const struct ssh_hashalg *h, */ } +Bignum ssh_rsakex_decrypt(const struct ssh_hashalg *h, ptrlen ciphertext, + struct RSAKey *rsa) +{ + Bignum b1, b2; + int outlen, i; + unsigned char *out; + unsigned char labelhash[64]; + ssh_hash *hash; + BinarySource src[1]; + const int HLEN = h->hlen; + + /* + * Decryption side of the RSA key exchange operation. + */ + + /* The length of the encrypted data should be exactly the length + * in octets of the RSA modulus.. */ + outlen = (7 + bignum_bitcount(rsa->modulus)) / 8; + if (ciphertext.len != outlen) + return NULL; + + /* Do the RSA decryption, and extract the result into a byte array. */ + b1 = bignum_from_bytes(ciphertext.ptr, ciphertext.len); + b2 = rsa_privkey_op(b1, rsa); + out = snewn(outlen, unsigned char); + for (i = 0; i < outlen; i++) + out[i] = bignum_byte(b2, outlen-1-i); + freebn(b1); + freebn(b2); + + /* Do the OAEP masking operations, in the reverse order from encryption */ + oaep_mask(h, out+HLEN+1, outlen-HLEN-1, out+1, HLEN); + oaep_mask(h, out+1, HLEN, out+HLEN+1, outlen-HLEN-1); + + /* Check the leading byte is zero. */ + if (out[0] != 0) { + sfree(out); + return NULL; + } + /* Check the label hash at position 1+HLEN */ + assert(HLEN <= lenof(labelhash)); + hash = ssh_hash_new(h); + ssh_hash_final(hash, labelhash); + if (memcmp(out + HLEN + 1, labelhash, HLEN)) { + sfree(out); + return NULL; + } + /* Expect zero bytes followed by a 1 byte */ + for (i = 1 + 2 * HLEN; i < outlen; i++) { + if (out[i] == 1) { + i++; /* skip over the 1 byte */ + break; + } else if (out[i] != 1) { + sfree(out); + return NULL; + } + } + /* And what's left is the input message data, which should be + * encoded as an ordinary SSH-2 mpint. */ + BinarySource_BARE_INIT(src, out + i, outlen - i); + b1 = get_mp_ssh2(src); + sfree(out); + if (get_err(src) || get_avail(src) != 0) { + freebn(b1); + return NULL; + } + + /* Success! */ + return b1; +} + +static const struct ssh_rsa_kex_extra ssh_rsa_kex_extra_sha1 = { 1024 }; +static const struct ssh_rsa_kex_extra ssh_rsa_kex_extra_sha256 = { 2048 }; + static const struct ssh_kex ssh_rsa_kex_sha1 = { - "rsa1024-sha1", NULL, KEXTYPE_RSA, &ssh_sha1, NULL, + "rsa1024-sha1", NULL, KEXTYPE_RSA, + &ssh_sha1, &ssh_rsa_kex_extra_sha1, }; static const struct ssh_kex ssh_rsa_kex_sha256 = { - "rsa2048-sha256", NULL, KEXTYPE_RSA, &ssh_sha256, NULL, + "rsa2048-sha256", NULL, KEXTYPE_RSA, + &ssh_sha256, &ssh_rsa_kex_extra_sha256, }; static const struct ssh_kex *const rsa_kex_list[] = {