mirror of
https://git.tartarus.org/simon/putty.git
synced 2025-01-10 01:48:00 +00:00
mpint.c: outlaw mp_ints with nw==0.
Some functions got confused if given one as input (particularly mp_get_decimal, which assumed it could safely write at least one word into the inv5 value it makes internally), and I've decided it's easier to stop them ever being created than to teach everything to handle them correctly. So now mp_make_sized enforces nw != 0 by assertion, and I've added a max at any call site that looked as if it might violate that precondition. mp_from_hex("") could generate one of these, in particular, so now I've fixed it, I've added a test to make sure it continues doing something sensible.
This commit is contained in:
parent
9e6669d30a
commit
5017d0a6ca
11
mpint.c
11
mpint.c
@ -36,6 +36,7 @@ static inline BignumInt mp_word(mp_int *x, size_t i)
|
|||||||
static mp_int *mp_make_sized(size_t nw)
|
static mp_int *mp_make_sized(size_t nw)
|
||||||
{
|
{
|
||||||
mp_int *x = snew_plus(mp_int, nw * sizeof(BignumInt));
|
mp_int *x = snew_plus(mp_int, nw * sizeof(BignumInt));
|
||||||
|
assert(nw); /* we outlaw the zero-word mp_int */
|
||||||
x->nw = nw;
|
x->nw = nw;
|
||||||
x->w = snew_plus_get_aux(x);
|
x->w = snew_plus_get_aux(x);
|
||||||
mp_clear(x);
|
mp_clear(x);
|
||||||
@ -140,8 +141,9 @@ void mp_cond_clear(mp_int *x, unsigned clear)
|
|||||||
*/
|
*/
|
||||||
static mp_int *mp_from_bytes_int(ptrlen bytes, size_t m, size_t c)
|
static mp_int *mp_from_bytes_int(ptrlen bytes, size_t m, size_t c)
|
||||||
{
|
{
|
||||||
mp_int *n = mp_make_sized(
|
size_t nw = (bytes.len + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES;
|
||||||
(bytes.len + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES);
|
nw = size_t_max(nw, 1);
|
||||||
|
mp_int *n = mp_make_sized(nw);
|
||||||
for (size_t i = 0; i < bytes.len; i++)
|
for (size_t i = 0; i < bytes.len; i++)
|
||||||
n->w[i / BIGNUM_INT_BYTES] |=
|
n->w[i / BIGNUM_INT_BYTES] |=
|
||||||
(BignumInt)(((const unsigned char *)bytes.ptr)[m*i+c]) <<
|
(BignumInt)(((const unsigned char *)bytes.ptr)[m*i+c]) <<
|
||||||
@ -211,6 +213,7 @@ mp_int *mp_from_hex_pl(ptrlen hex)
|
|||||||
assert(hex.len <= (~(size_t)0) / 4);
|
assert(hex.len <= (~(size_t)0) / 4);
|
||||||
size_t bits = hex.len * 4;
|
size_t bits = hex.len * 4;
|
||||||
size_t words = (bits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS;
|
size_t words = (bits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS;
|
||||||
|
words = size_t_max(words, 1);
|
||||||
mp_int *x = mp_make_sized(words);
|
mp_int *x = mp_make_sized(words);
|
||||||
for (size_t nibble = 0; nibble < hex.len; nibble++) {
|
for (size_t nibble = 0; nibble < hex.len; nibble++) {
|
||||||
BignumInt digit = ((char *)hex.ptr)[hex.len-1 - nibble];
|
BignumInt digit = ((char *)hex.ptr)[hex.len-1 - nibble];
|
||||||
@ -1077,7 +1080,8 @@ void mp_rshift_fixed_into(mp_int *r, mp_int *a, size_t bits)
|
|||||||
mp_int *mp_rshift_fixed(mp_int *x, size_t bits)
|
mp_int *mp_rshift_fixed(mp_int *x, size_t bits)
|
||||||
{
|
{
|
||||||
size_t words = bits / BIGNUM_INT_BITS;
|
size_t words = bits / BIGNUM_INT_BITS;
|
||||||
mp_int *r = mp_make_sized(x->nw - size_t_min(x->nw, words));
|
size_t nw = x->nw - size_t_min(x->nw, words);
|
||||||
|
mp_int *r = mp_make_sized(size_t_max(nw, 1));
|
||||||
mp_rshift_fixed_into(r, x, bits);
|
mp_rshift_fixed_into(r, x, bits);
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
@ -1148,6 +1152,7 @@ mp_int *mp_invert_mod_2to(mp_int *x, size_t p)
|
|||||||
assert(p > 0);
|
assert(p > 0);
|
||||||
|
|
||||||
size_t rw = (p + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS;
|
size_t rw = (p + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS;
|
||||||
|
rw = size_t_max(rw, 1);
|
||||||
mp_int *r = mp_make_sized(rw);
|
mp_int *r = mp_make_sized(rw);
|
||||||
|
|
||||||
size_t mul_scratchsize = mp_mul_scratchspace(2*rw, rw, rw);
|
size_t mul_scratchsize = mp_mul_scratchspace(2*rw, rw, rw);
|
||||||
|
@ -159,6 +159,7 @@ class mpint(MyTestBase):
|
|||||||
hexstr = 'ea7cb89f409ae845215822e37D32D0C63EC43E1381C2FF8094'
|
hexstr = 'ea7cb89f409ae845215822e37D32D0C63EC43E1381C2FF8094'
|
||||||
self.assertEqual(int(mp_from_hex_pl(hexstr)), int(hexstr, 16))
|
self.assertEqual(int(mp_from_hex_pl(hexstr)), int(hexstr, 16))
|
||||||
self.assertEqual(int(mp_from_hex(hexstr)), int(hexstr, 16))
|
self.assertEqual(int(mp_from_hex(hexstr)), int(hexstr, 16))
|
||||||
|
self.assertEqual(int(mp_from_hex("")), 0)
|
||||||
p2 = mp_power_2(123)
|
p2 = mp_power_2(123)
|
||||||
self.assertEqual(int(p2), 1 << 123)
|
self.assertEqual(int(p2), 1 << 123)
|
||||||
p2c = mp_copy(p2)
|
p2c = mp_copy(p2)
|
||||||
@ -319,7 +320,7 @@ class mpint(MyTestBase):
|
|||||||
diff = mp_sub(am, bm)
|
diff = mp_sub(am, bm)
|
||||||
self.assertEqual(int(diff), (ai - bi) & mp_mask(diff))
|
self.assertEqual(int(diff), (ai - bi) & mp_mask(diff))
|
||||||
|
|
||||||
for bits in range(0, 512, 64):
|
for bits in range(64, 512, 64):
|
||||||
cm = mp_new(bits)
|
cm = mp_new(bits)
|
||||||
mp_add_into(cm, am, bm)
|
mp_add_into(cm, am, bm)
|
||||||
self.assertEqual(int(cm), (ai + bi) & mp_mask(cm))
|
self.assertEqual(int(cm), (ai + bi) & mp_mask(cm))
|
||||||
@ -357,8 +358,8 @@ class mpint(MyTestBase):
|
|||||||
if r >= d:
|
if r >= d:
|
||||||
continue # silly cases with tiny divisors
|
continue # silly cases with tiny divisors
|
||||||
n = q*d + r
|
n = q*d + r
|
||||||
mq = mp_new(nbits(q))
|
mq = mp_new(max(nbits(q), 1))
|
||||||
mr = mp_new(nbits(r))
|
mr = mp_new(max(nbits(r), 1))
|
||||||
mp_divmod_into(n, d, mq, mr)
|
mp_divmod_into(n, d, mq, mr)
|
||||||
self.assertEqual(int(mq), q)
|
self.assertEqual(int(mq), q)
|
||||||
self.assertEqual(int(mr), r)
|
self.assertEqual(int(mr), r)
|
||||||
|
Loading…
Reference in New Issue
Block a user