1
0
mirror of https://git.tartarus.org/simon/putty.git synced 2025-01-10 01:48:00 +00:00

mpint.c: outlaw mp_ints with nw==0.

Some functions got confused if given one as input (particularly
mp_get_decimal, which assumed it could safely write at least one word
into the inv5 value it makes internally), and I've decided it's easier
to stop them ever being created than to teach everything to handle
them correctly. So now mp_make_sized enforces nw != 0 by assertion,
and I've added a max at any call site that looked as if it might
violate that precondition.

mp_from_hex("") could generate one of these, in particular, so now
I've fixed it, I've added a test to make sure it continues doing
something sensible.
This commit is contained in:
Simon Tatham 2019-01-29 20:03:28 +00:00
parent 9e6669d30a
commit 5017d0a6ca
2 changed files with 12 additions and 6 deletions

11
mpint.c
View File

@ -36,6 +36,7 @@ static inline BignumInt mp_word(mp_int *x, size_t i)
static mp_int *mp_make_sized(size_t nw) static mp_int *mp_make_sized(size_t nw)
{ {
mp_int *x = snew_plus(mp_int, nw * sizeof(BignumInt)); mp_int *x = snew_plus(mp_int, nw * sizeof(BignumInt));
assert(nw); /* we outlaw the zero-word mp_int */
x->nw = nw; x->nw = nw;
x->w = snew_plus_get_aux(x); x->w = snew_plus_get_aux(x);
mp_clear(x); mp_clear(x);
@ -140,8 +141,9 @@ void mp_cond_clear(mp_int *x, unsigned clear)
*/ */
static mp_int *mp_from_bytes_int(ptrlen bytes, size_t m, size_t c) static mp_int *mp_from_bytes_int(ptrlen bytes, size_t m, size_t c)
{ {
mp_int *n = mp_make_sized( size_t nw = (bytes.len + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES;
(bytes.len + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES); nw = size_t_max(nw, 1);
mp_int *n = mp_make_sized(nw);
for (size_t i = 0; i < bytes.len; i++) for (size_t i = 0; i < bytes.len; i++)
n->w[i / BIGNUM_INT_BYTES] |= n->w[i / BIGNUM_INT_BYTES] |=
(BignumInt)(((const unsigned char *)bytes.ptr)[m*i+c]) << (BignumInt)(((const unsigned char *)bytes.ptr)[m*i+c]) <<
@ -211,6 +213,7 @@ mp_int *mp_from_hex_pl(ptrlen hex)
assert(hex.len <= (~(size_t)0) / 4); assert(hex.len <= (~(size_t)0) / 4);
size_t bits = hex.len * 4; size_t bits = hex.len * 4;
size_t words = (bits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS; size_t words = (bits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS;
words = size_t_max(words, 1);
mp_int *x = mp_make_sized(words); mp_int *x = mp_make_sized(words);
for (size_t nibble = 0; nibble < hex.len; nibble++) { for (size_t nibble = 0; nibble < hex.len; nibble++) {
BignumInt digit = ((char *)hex.ptr)[hex.len-1 - nibble]; BignumInt digit = ((char *)hex.ptr)[hex.len-1 - nibble];
@ -1077,7 +1080,8 @@ void mp_rshift_fixed_into(mp_int *r, mp_int *a, size_t bits)
mp_int *mp_rshift_fixed(mp_int *x, size_t bits) mp_int *mp_rshift_fixed(mp_int *x, size_t bits)
{ {
size_t words = bits / BIGNUM_INT_BITS; size_t words = bits / BIGNUM_INT_BITS;
mp_int *r = mp_make_sized(x->nw - size_t_min(x->nw, words)); size_t nw = x->nw - size_t_min(x->nw, words);
mp_int *r = mp_make_sized(size_t_max(nw, 1));
mp_rshift_fixed_into(r, x, bits); mp_rshift_fixed_into(r, x, bits);
return r; return r;
} }
@ -1148,6 +1152,7 @@ mp_int *mp_invert_mod_2to(mp_int *x, size_t p)
assert(p > 0); assert(p > 0);
size_t rw = (p + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS; size_t rw = (p + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS;
rw = size_t_max(rw, 1);
mp_int *r = mp_make_sized(rw); mp_int *r = mp_make_sized(rw);
size_t mul_scratchsize = mp_mul_scratchspace(2*rw, rw, rw); size_t mul_scratchsize = mp_mul_scratchspace(2*rw, rw, rw);

View File

@ -159,6 +159,7 @@ class mpint(MyTestBase):
hexstr = 'ea7cb89f409ae845215822e37D32D0C63EC43E1381C2FF8094' hexstr = 'ea7cb89f409ae845215822e37D32D0C63EC43E1381C2FF8094'
self.assertEqual(int(mp_from_hex_pl(hexstr)), int(hexstr, 16)) self.assertEqual(int(mp_from_hex_pl(hexstr)), int(hexstr, 16))
self.assertEqual(int(mp_from_hex(hexstr)), int(hexstr, 16)) self.assertEqual(int(mp_from_hex(hexstr)), int(hexstr, 16))
self.assertEqual(int(mp_from_hex("")), 0)
p2 = mp_power_2(123) p2 = mp_power_2(123)
self.assertEqual(int(p2), 1 << 123) self.assertEqual(int(p2), 1 << 123)
p2c = mp_copy(p2) p2c = mp_copy(p2)
@ -319,7 +320,7 @@ class mpint(MyTestBase):
diff = mp_sub(am, bm) diff = mp_sub(am, bm)
self.assertEqual(int(diff), (ai - bi) & mp_mask(diff)) self.assertEqual(int(diff), (ai - bi) & mp_mask(diff))
for bits in range(0, 512, 64): for bits in range(64, 512, 64):
cm = mp_new(bits) cm = mp_new(bits)
mp_add_into(cm, am, bm) mp_add_into(cm, am, bm)
self.assertEqual(int(cm), (ai + bi) & mp_mask(cm)) self.assertEqual(int(cm), (ai + bi) & mp_mask(cm))
@ -357,8 +358,8 @@ class mpint(MyTestBase):
if r >= d: if r >= d:
continue # silly cases with tiny divisors continue # silly cases with tiny divisors
n = q*d + r n = q*d + r
mq = mp_new(nbits(q)) mq = mp_new(max(nbits(q), 1))
mr = mp_new(nbits(r)) mr = mp_new(max(nbits(r), 1))
mp_divmod_into(n, d, mq, mr) mp_divmod_into(n, d, mq, mr)
self.assertEqual(int(mq), q) self.assertEqual(int(mq), q)
self.assertEqual(int(mr), r) self.assertEqual(int(mr), r)