From 3ea69c290e8764282d3c52f351dfdc166e926590 Mon Sep 17 00:00:00 2001 From: Simon Tatham Date: Mon, 2 Mar 2020 18:34:52 +0000 Subject: [PATCH] mpint: clean up handling of uintmax_t. Functions like mp_copy_integer_into, mp_add_integer_into and mp_hs_integer all take an ordinary C integer in the form of a uintmax_t, and perform an operation between that and an mp_int. In order to do that, they have to break it up into some number of BignumInt, via bit shifts. But in C, shifting by an amount equal to or greater than the width of the type is undefined behaviour, and you risk the compiler generating nonsense or complaining at compile time. I did various dodges in those functions to try to avoid that, but didn't manage to use the same idiom everywhere. Sometimes I'd leave the integer in its original form and shift it right by increasing multiples of BIGNUM_INT_BITS; sometimes I'd shift it down in place every time. And mostly I'd do the conditional shift by checking against sizeof(n), but once I did it by shifting by half the word and then the other half. Now refactored so that there's a pair of functions to shift a uintmax_t left or right by BIGNUM_INT_BITS in what I hope is a UB-safe manner, and changed all the code I could find to use them. --- mpint.c | 56 ++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/mpint.c b/mpint.c index 1c464c44..6d826e0d 100644 --- a/mpint.c +++ b/mpint.c @@ -33,6 +33,35 @@ static inline BignumInt mp_word(mp_int *x, size_t i) return i < x->nw ? x->w[i] : 0; } +/* + * Shift an ordinary C integer by BIGNUM_INT_BITS, in a way that + * avoids writing a shift operator whose RHS is greater or equal to + * the size of the type, because that's undefined behaviour in C. + * + * In fact we must avoid even writing it in a definitely-untaken + * branch of an if, because compilers will sometimes warn about + * that. So you can't just write 'shift too big ? 0 : n >> shift', + * because even if 'shift too big' is a constant-expression + * evaluating to false, you can still get complaints about the + * else clause of the ?:. + * + * So we have to re-check _inside_ that clause, so that the shift + * count is reset to something nonsensical but safe in the case + * where the clause wasn't going to be taken anyway. + */ +static uintmax_t shift_right_by_one_word(uintmax_t n) +{ + bool shift_too_big = BIGNUM_INT_BYTES >= sizeof(n); + return shift_too_big ? 0 : + n >> (shift_too_big ? 0 : BIGNUM_INT_BITS); +} +static uintmax_t shift_left_by_one_word(uintmax_t n) +{ + bool shift_too_big = BIGNUM_INT_BYTES >= sizeof(n); + return shift_too_big ? 0 : + n << (shift_too_big ? 0 : BIGNUM_INT_BITS); +} + mp_int *mp_make_sized(size_t nw) { mp_int *x = snew_plus(mp_int, nw * sizeof(BignumInt)); @@ -93,8 +122,8 @@ void mp_copy_into(mp_int *dest, mp_int *src) void mp_copy_integer_into(mp_int *r, uintmax_t n) { for (size_t i = 0; i < r->nw; i++) { - r->w[i] = (BignumInt)n; - n = (BIGNUM_INT_BYTES < sizeof(n)) ? n >> BIGNUM_INT_BITS : 0; + r->w[i] = n; + n = shift_right_by_one_word(n); } } @@ -268,12 +297,8 @@ unsigned mp_get_bit(mp_int *x, size_t bit) uintmax_t mp_get_integer(mp_int *x) { uintmax_t toret = 0; - for (size_t i = x->nw; i-- > 0 ;) { - /* Shift in two stages to avoid undefined behaviour if the - * shift count equals the integer width */ - toret = (toret << (BIGNUM_INT_BITS/2)) << (BIGNUM_INT_BITS/2); - toret |= x->w[i]; - } + for (size_t i = x->nw; i-- > 0 ;) + toret = shift_left_by_one_word(toret) | x->w[i]; return toret; } @@ -765,8 +790,8 @@ static BignumCarry mp_add_masked_integer_into( { for (size_t i = 0; i < rw; i++) { BignumInt aword = mp_word(a, i); - size_t shift = i * BIGNUM_INT_BITS; - BignumInt bword = shift < CHAR_BIT*sizeof(b) ? b >> shift : 0; + BignumInt bword = b; + b = shift_right_by_one_word(b); BignumInt out; bword = (bword ^ b_xor) & b_and; BignumADC(out, carry, aword, bword, carry); @@ -807,8 +832,7 @@ static void mp_add_integer_into_shifted_by_words( * shift n down. If it's 0, we add zero bits into r, and * leave n alone. */ BignumInt bword = n & -(BignumInt)indicator; - uintmax_t new_n = (BIGNUM_INT_BYTES < sizeof(n) ? - n >> BIGNUM_INT_BITS : 0); + uintmax_t new_n = shift_right_by_one_word(n); n ^= (n ^ new_n) & -(uintmax_t)indicator; BignumInt aword = mp_word(a, i); @@ -854,8 +878,8 @@ unsigned mp_hs_integer(mp_int *x, uintmax_t n) { BignumInt carry = 1; for (size_t i = 0; i < x->nw; i++) { - size_t shift = i * BIGNUM_INT_BITS; - BignumInt nword = shift < CHAR_BIT*sizeof(n) ? n >> shift : 0; + BignumInt nword = n; + n = shift_right_by_one_word(n); BignumInt dummy_out; BignumADC(dummy_out, carry, x->w[i], ~nword, carry); (void)dummy_out; @@ -880,8 +904,8 @@ unsigned mp_eq_integer(mp_int *x, uintmax_t n) { BignumInt diff = 0; for (size_t i = 0; i < x->nw; i++) { - size_t shift = i * BIGNUM_INT_BITS; - BignumInt nword = shift < CHAR_BIT*sizeof(n) ? n >> shift : 0; + BignumInt nword = n; + n = shift_right_by_one_word(n); diff |= x->w[i] ^ nword; } return 1 ^ normalise_to_1(diff); /* return 1 if diff _is_ zero */