diff --git a/marshal.h b/marshal.h index e7603adb..3b7a089f 100644 --- a/marshal.h +++ b/marshal.h @@ -243,6 +243,10 @@ struct BinarySource { BinarySource_get_mp_ssh1(BinarySource_UPCAST(src)) #define get_mp_ssh2(src) \ BinarySource_get_mp_ssh2(BinarySource_UPCAST(src)) +#define get_rsa_ssh1_pub(src, rsa, keystr, order) \ + BinarySource_get_rsa_ssh1_pub(BinarySource_UPCAST(src), rsa, keystr, order) +#define get_rsa_ssh1_priv(src, rsa) \ + BinarySource_get_rsa_ssh1_priv(BinarySource_UPCAST(src), rsa) #define get_err(src) (BinarySource_UPCAST(src)->err) #define get_avail(src) (BinarySource_UPCAST(src)->len - \ diff --git a/ssh.h b/ssh.h index 3ab3a526..21ab1593 100644 --- a/ssh.h +++ b/ssh.h @@ -182,8 +182,13 @@ typedef enum { RSA_SSH1_EXPONENT_FIRST, RSA_SSH1_MODULUS_FIRST } RsaSsh1Order; int rsa_ssh1_readpub(const unsigned char *data, int len, struct RSAKey *result, const unsigned char **keystr, RsaSsh1Order order); +void BinarySource_get_rsa_ssh1_pub( + BinarySource *src, struct RSAKey *result, + ptrlen *keystr, RsaSsh1Order order); int rsa_ssh1_readpriv(const unsigned char *data, int len, struct RSAKey *result); +void BinarySource_get_rsa_ssh1_priv( + BinarySource *src, struct RSAKey *rsa); int rsa_ssh1_encrypt(unsigned char *data, int length, struct RSAKey *key); Bignum rsa_ssh1_decrypt(Bignum input, struct RSAKey *key); void rsasanitise(struct RSAKey *key); diff --git a/sshrsa.c b/sshrsa.c index 2fcf5e6b..c0008dc1 100644 --- a/sshrsa.c +++ b/sshrsa.c @@ -10,53 +10,79 @@ #include "ssh.h" #include "misc.h" +void BinarySource_get_rsa_ssh1_pub( + BinarySource *src, struct RSAKey *rsa, ptrlen *keystr, RsaSsh1Order order) +{ + const unsigned char *start, *end; + unsigned bits; + Bignum e, m; + + bits = get_uint32(src); + if (order == RSA_SSH1_EXPONENT_FIRST) { + e = get_mp_ssh1(src); + start = get_ptr(src); + m = get_mp_ssh1(src); + end = get_ptr(src); + } else { + start = get_ptr(src); + m = get_mp_ssh1(src); + end = get_ptr(src); + e = get_mp_ssh1(src); + } + + if (keystr) { + start += (end-start >= 2 ? 2 : end-start); + keystr->ptr = start; + keystr->len = end - start; + } + + if (rsa) { + rsa->bits = bits; + rsa->exponent = e; + rsa->modulus = m; + rsa->bytes = (bignum_bitcount(m) + 7) / 8; + } else { + freebn(e); + freebn(m); + } +} + int rsa_ssh1_readpub(const unsigned char *data, int len, struct RSAKey *result, const unsigned char **keystr, RsaSsh1Order order) { - const unsigned char *p = data; - int i, n; + BinarySource src; + ptrlen key_pl; - if (len < 4) - return -1; + BinarySource_BARE_INIT(&src, data, len); + get_rsa_ssh1_pub(&src, result, &key_pl, order); - if (result) { - result->bits = 0; - for (i = 0; i < 4; i++) - result->bits = (result->bits << 8) + *p++; - } else - p += 4; - - len -= 4; - - if (order == RSA_SSH1_EXPONENT_FIRST) { - n = ssh1_read_bignum(p, len, result ? &result->exponent : NULL); - if (n < 0) return -1; - p += n; - len -= n; - } - - n = ssh1_read_bignum(p, len, result ? &result->modulus : NULL); - if (n < 0 || (result && bignum_bitcount(result->modulus) == 0)) return -1; - if (result) - result->bytes = n - 2; if (keystr) - *keystr = p + 2; - p += n; - len -= n; + *keystr = key_pl.ptr; - if (order == RSA_SSH1_MODULUS_FIRST) { - n = ssh1_read_bignum(p, len, result ? &result->exponent : NULL); - if (n < 0) return -1; - p += n; - len -= n; - } - return p - data; + if (get_err(&src)) + return -1; + else + return key_pl.len; +} + +void BinarySource_get_rsa_ssh1_priv( + BinarySource *src, struct RSAKey *rsa) +{ + rsa->private_exponent = get_mp_ssh1(src); } int rsa_ssh1_readpriv(const unsigned char *data, int len, struct RSAKey *result) { - return ssh1_read_bignum(data, len, &result->private_exponent); + BinarySource src; + + BinarySource_BARE_INIT(&src, data, len); + get_rsa_ssh1_priv(&src, result); + + if (get_err(&src)) + return -1; + else + return src.pos; } int rsa_ssh1_encrypt(unsigned char *data, int length, struct RSAKey *key) @@ -455,25 +481,21 @@ void rsa_ssh1_public_blob(BinarySink *bs, struct RSAKey *key, /* Given a public blob, determine its length. */ int rsa_public_blob_len(void *data, int maxlen) { - unsigned char *p = (unsigned char *)data; - int n; + BinarySource src[1]; - if (maxlen < 4) + BinarySource_BARE_INIT(src, data, maxlen); + + /* Expect a length word, then exponent and modulus. (It doesn't + * even matter which order.) */ + get_uint32(src); + freebn(get_mp_ssh1(src)); + freebn(get_mp_ssh1(src)); + + if (get_err(src)) return -1; - p += 4; /* length word */ - maxlen -= 4; - n = ssh1_read_bignum(p, maxlen, NULL); /* exponent */ - if (n < 0) - return -1; - p += n; - - n = ssh1_read_bignum(p, maxlen, NULL); /* modulus */ - if (n < 0) - return -1; - p += n; - - return p - (unsigned char *)data; + /* Return the number of bytes consumed. */ + return src->pos; } void freersakey(struct RSAKey *key)