diff --git a/defs.h b/defs.h index 87b48499..55468241 100644 --- a/defs.h +++ b/defs.h @@ -37,6 +37,7 @@ struct RSAKey; typedef uint32_t uint32; typedef struct BinarySink BinarySink; +typedef struct BinarySource BinarySource; typedef struct SockAddr_tag *SockAddr; diff --git a/int64.h b/int64.h index 6ac7f3fc..17122974 100644 --- a/int64.h +++ b/int64.h @@ -24,5 +24,6 @@ uint64 uint64_shift_left(uint64 x, int shift); uint64 uint64_from_decimal(char *str); void BinarySink_put_uint64(BinarySink *, uint64); +uint64 BinarySource_get_uint64(BinarySource *); #endif diff --git a/marshal.c b/marshal.c index 48bbef78..76c7c05f 100644 --- a/marshal.c +++ b/marshal.c @@ -82,3 +82,144 @@ int BinarySink_put_pstring(BinarySink *bs, const char *str) bs->write(bs, str, len); return TRUE; } + +/* ---------------------------------------------------------------------- */ + +static int BinarySource_data_avail(BinarySource *src, size_t wanted) +{ + if (src->err) + return FALSE; + + if (wanted <= src->len - src->pos) + return TRUE; + + src->err = BSE_OUT_OF_DATA; + return FALSE; +} + +#define avail(wanted) BinarySource_data_avail(src, wanted) +#define advance(dist) (src->pos += dist) +#define here ((const void *)((const unsigned char *)src->data + src->pos)) +#define consume(dist) \ + ((const void *)((const unsigned char *)src->data + \ + ((src->pos += dist) - dist))) + +ptrlen BinarySource_get_data(BinarySource *src, size_t wanted) +{ + if (!avail(wanted)) + return make_ptrlen("", 0); + + return make_ptrlen(consume(wanted), wanted); +} + +unsigned char BinarySource_get_byte(BinarySource *src) +{ + const unsigned char *ucp; + + if (!avail(1)) + return 0; + + ucp = consume(1); + return *ucp; +} + +int BinarySource_get_bool(BinarySource *src) +{ + const unsigned char *ucp; + + if (!avail(1)) + return 0; + + ucp = consume(1); + return *ucp != 0; +} + +unsigned BinarySource_get_uint16(BinarySource *src) +{ + const unsigned char *ucp; + + if (!avail(2)) + return 0; + + ucp = consume(2); + return GET_16BIT_MSB_FIRST(ucp); +} + +unsigned long BinarySource_get_uint32(BinarySource *src) +{ + const unsigned char *ucp; + + if (!avail(4)) + return 0; + + ucp = consume(4); + return GET_32BIT_MSB_FIRST(ucp); +} + +uint64 BinarySource_get_uint64(BinarySource *src) +{ + const unsigned char *ucp; + uint64 toret; + + if (!avail(8)) { + toret.hi = toret.lo = 0; + return toret; + } + + ucp = consume(8); + toret.hi = GET_32BIT_MSB_FIRST(ucp); + toret.lo = GET_32BIT_MSB_FIRST(ucp + 4); + return toret; +} + +ptrlen BinarySource_get_string(BinarySource *src) +{ + const unsigned char *ucp; + size_t len; + + if (!avail(4)) + return make_ptrlen("", 0); + + ucp = consume(4); + len = GET_32BIT_MSB_FIRST(ucp); + + if (!avail(len)) + return make_ptrlen("", 0); + + return make_ptrlen(consume(len), len); +} + +const char *BinarySource_get_asciz(BinarySource *src) +{ + const char *start, *end; + + if (src->err) + return ""; + + start = here; + end = memchr(start, '\0', src->len - src->pos); + if (!end) { + src->err = BSE_OUT_OF_DATA; + return ""; + } + + advance(end + 1 - start); + return start; +} + +ptrlen BinarySource_get_pstring(BinarySource *src) +{ + const unsigned char *ucp; + size_t len; + + if (!avail(1)) + return make_ptrlen("", 0); + + ucp = consume(1); + len = *ucp; + + if (!avail(len)) + return make_ptrlen("", 0); + + return make_ptrlen(consume(len), len); +} diff --git a/marshal.h b/marshal.h index ca5a009f..e7603adb 100644 --- a/marshal.h +++ b/marshal.h @@ -138,4 +138,127 @@ void BinarySink_put_stringsb(BinarySink *, struct strbuf *); void BinarySink_put_asciz(BinarySink *, const char *str); int BinarySink_put_pstring(BinarySink *, const char *str); +/* ---------------------------------------------------------------------- */ + +/* + * A complementary trait structure for _un_-marshalling. + * + * This structure contains client-visible data fields rather than + * methods, because that seemed more useful than leaving it totally + * opaque. But it's still got the self-pointer system that will allow + * the set of get_* macros to target one of these itself or any other + * type that 'derives' from it. So, for example, an SSH packet + * structure can act as a BinarySource while also having additional + * fields like the packet type. + */ +typedef enum BinarySourceError { + BSE_NO_ERROR, + BSE_OUT_OF_DATA, + BSE_INVALID +} BinarySourceError; +struct BinarySource { + /* + * (data, len) is the data block being decoded. pos is the current + * position within the block. + */ + const void *data; + size_t pos, len; + + /* + * 'err' indicates whether a decoding error has happened at any + * point. Once this has been set to something other than + * BSE_NO_ERROR, it shouldn't be changed by any unmarshalling + * function. So you can safely do a long sequence of get_foo() + * operations and then test err just once at the end, rather than + * having to conditionalise every single get. + * + * The unmarshalling functions should always return some value, + * even if a decoding error occurs. Generally on error they'll + * return zero (if numeric) or the empty string (if string-based), + * or some other appropriate default value for more complicated + * types. + * + * If the usual return value is dynamically allocated (e.g. a + * Bignum, or a normal C 'char *' string), then the error value is + * also dynamic in the same way. So you have to free exactly the + * same set of things whether or not there was a decoding error, + * which simplifies exit paths - for example, you could call a big + * pile of get_foo functions, then put the actual handling of the + * results under 'if (!get_err(src))', and then free everything + * outside that if. + */ + BinarySourceError err; + + /* + * Self-pointer for the implicit derivation trick, same as + * BinarySink above. + */ + BinarySource *binarysource_; +}; + +/* + * Implementation macros, similar to BinarySink. + */ +#define BinarySource_IMPLEMENTATION BinarySource binarysource_[1] +#define BinarySource_INIT__(obj, data_, len_) \ + ((obj)->data = (data_), \ + (obj)->len = (len_), \ + (obj)->pos = 0, \ + (obj)->err = BSE_NO_ERROR, \ + (obj)->binarysource_ = (obj)) +#define BinarySource_BARE_INIT(obj, data_, len_) \ + TYPECHECK(&(obj)->binarysource_ == (BinarySource **)0, \ + BinarySource_INIT__(obj, data_, len_)) +#define BinarySource_INIT(obj, data_, len_) \ + TYPECHECK(&(obj)->binarysource_ == (BinarySource (*)[1])0, \ + BinarySource_INIT__(BinarySource_UPCAST(obj), data_, len_)) +#define BinarySource_DOWNCAST(object, type) \ + TYPECHECK((object) == ((type *)0)->binarysource_, \ + ((type *)(((char *)(object)) - offsetof(type, binarysource_)))) +#define BinarySource_UPCAST(object) \ + TYPECHECK((object)->binarysource_ == (BinarySource *)0, \ + (object)->binarysource_) +#define BinarySource_COPIED(obj) \ + ((obj)->binarysource_->binarysource_ = (obj)->binarysource_) + +#define get_data(src, len) \ + BinarySource_get_data(BinarySource_UPCAST(src), len) +#define get_byte(src) \ + BinarySource_get_byte(BinarySource_UPCAST(src)) +#define get_bool(src) \ + BinarySource_get_bool(BinarySource_UPCAST(src)) +#define get_uint16(src) \ + BinarySource_get_uint16(BinarySource_UPCAST(src)) +#define get_uint32(src) \ + BinarySource_get_uint32(BinarySource_UPCAST(src)) +#define get_uint64(src) \ + BinarySource_get_uint64(BinarySource_UPCAST(src)) +#define get_string(src) \ + BinarySource_get_string(BinarySource_UPCAST(src)) +#define get_asciz(src) \ + BinarySource_get_asciz(BinarySource_UPCAST(src)) +#define get_pstring(src) \ + BinarySource_get_pstring(BinarySource_UPCAST(src)) +#define get_mp_ssh1(src) \ + BinarySource_get_mp_ssh1(BinarySource_UPCAST(src)) +#define get_mp_ssh2(src) \ + BinarySource_get_mp_ssh2(BinarySource_UPCAST(src)) + +#define get_err(src) (BinarySource_UPCAST(src)->err) +#define get_avail(src) (BinarySource_UPCAST(src)->len - \ + BinarySource_UPCAST(src)->pos) +#define get_ptr(src) \ + ((const void *)( \ + (const unsigned char *)(BinarySource_UPCAST(src)->data) + \ + BinarySource_UPCAST(src)->pos)) + +ptrlen BinarySource_get_data(BinarySource *, size_t); +unsigned char BinarySource_get_byte(BinarySource *); +int BinarySource_get_bool(BinarySource *); +unsigned BinarySource_get_uint16(BinarySource *); +unsigned long BinarySource_get_uint32(BinarySource *); +ptrlen BinarySource_get_string(BinarySource *); +const char *BinarySource_get_asciz(BinarySource *); +ptrlen BinarySource_get_pstring(BinarySource *); + #endif /* PUTTY_MARSHAL_H */ diff --git a/ssh.h b/ssh.h index 7a1c5766..3ab3a526 100644 --- a/ssh.h +++ b/ssh.h @@ -695,6 +695,8 @@ Bignum bignum_from_decimal(const char *decimal); void BinarySink_put_mp_ssh1(BinarySink *, Bignum); void BinarySink_put_mp_ssh2(BinarySink *, Bignum); +Bignum BinarySource_get_mp_ssh1(BinarySource *); +Bignum BinarySource_get_mp_ssh2(BinarySource *); #ifdef DEBUG void diagbn(char *prefix, Bignum md); diff --git a/sshbn.c b/sshbn.c index 39f8dfd8..1f0213c0 100644 --- a/sshbn.c +++ b/sshbn.c @@ -1624,12 +1624,12 @@ int ssh1_write_bignum(void *data, Bignum bn) void BinarySink_put_mp_ssh1(BinarySink *bs, Bignum bn) { - int len = ssh1_bignum_length(bn); + int bits = bignum_bitcount(bn); + int bytes = (bits + 7) / 8; int i; - int bitc = bignum_bitcount(bn); - put_uint16(bs, bitc); - for (i = len - 2; i--;) + put_uint16(bs, bits); + for (i = bytes; i--;) put_byte(bs, bignum_byte(bn, i)); } @@ -1643,6 +1643,40 @@ void BinarySink_put_mp_ssh2(BinarySink *bs, Bignum bn) put_byte(bs, bignum_byte(bn, i)); } +Bignum BinarySource_get_mp_ssh1(BinarySource *src) +{ + unsigned bitc = get_uint16(src); + ptrlen bytes = get_data(src, (bitc + 7) / 8); + if (get_err(src)) { + return bignum_from_long(0); + } else { + Bignum toret = bignum_from_bytes(bytes.ptr, bytes.len); + if (bignum_bitcount(toret) != bitc) { + src->err = BSE_INVALID; + freebn(toret); + toret = bignum_from_long(0); + } + return toret; + } +} + +Bignum BinarySource_get_mp_ssh2(BinarySource *src) +{ + ptrlen bytes = get_string(src); + if (get_err(src)) { + return bignum_from_long(0); + } else { + const unsigned char *p = bytes.ptr; + if ((bytes.len > 0 && + ((p[0] & 0x80) || + (p[0] == 0 && (bytes.len <= 1 || !(p[1] & 0x80)))))) { + src->err = BSE_INVALID; + return bignum_from_long(0); + } + return bignum_from_bytes(bytes.ptr, bytes.len); + } +} + /* * Compare two bignums. Returns like strcmp. */