diff --git a/mpint.c b/mpint.c index 3d7a3987..8ac7a50e 100644 --- a/mpint.c +++ b/mpint.c @@ -33,6 +33,35 @@ static inline BignumInt mp_word(mp_int *x, size_t i) return i < x->nw ? x->w[i] : 0; } +/* + * Shift an ordinary C integer by BIGNUM_INT_BITS, in a way that + * avoids writing a shift operator whose RHS is greater or equal to + * the size of the type, because that's undefined behaviour in C. + * + * In fact we must avoid even writing it in a definitely-untaken + * branch of an if, because compilers will sometimes warn about + * that. So you can't just write 'shift too big ? 0 : n >> shift', + * because even if 'shift too big' is a constant-expression + * evaluating to false, you can still get complaints about the + * else clause of the ?:. + * + * So we have to re-check _inside_ that clause, so that the shift + * count is reset to something nonsensical but safe in the case + * where the clause wasn't going to be taken anyway. + */ +static uintmax_t shift_right_by_one_word(uintmax_t n) +{ + bool shift_too_big = BIGNUM_INT_BYTES >= sizeof(n); + return shift_too_big ? 0 : + n >> (shift_too_big ? 0 : BIGNUM_INT_BITS); +} +static uintmax_t shift_left_by_one_word(uintmax_t n) +{ + bool shift_too_big = BIGNUM_INT_BYTES >= sizeof(n); + return shift_too_big ? 0 : + n << (shift_too_big ? 0 : BIGNUM_INT_BITS); +} + static mp_int *mp_make_sized(size_t nw) { mp_int *x = snew_plus(mp_int, nw * sizeof(BignumInt)); @@ -260,12 +289,8 @@ unsigned mp_get_bit(mp_int *x, size_t bit) uintmax_t mp_get_integer(mp_int *x) { uintmax_t toret = 0; - for (size_t i = x->nw; i-- > 0 ;) { - /* Shift in two stages to avoid undefined behaviour if the - * shift count equals the integer width */ - toret = (toret << (BIGNUM_INT_BITS/2)) << (BIGNUM_INT_BITS/2); - toret |= x->w[i]; - } + for (size_t i = x->nw; i-- > 0 ;) + toret = shift_left_by_one_word(toret) | x->w[i]; return toret; } @@ -757,8 +782,8 @@ static BignumCarry mp_add_masked_integer_into( { for (size_t i = 0; i < rw; i++) { BignumInt aword = mp_word(a, i); - size_t shift = i * BIGNUM_INT_BITS; - BignumInt bword = shift < CHAR_BIT*sizeof(b) ? b >> shift : 0; + BignumInt bword = b; + b = shift_right_by_one_word(b); BignumInt out; bword = (bword ^ b_xor) & b_and; BignumADC(out, carry, aword, bword, carry); @@ -799,8 +824,7 @@ static void mp_add_integer_into_shifted_by_words( * shift n down. If it's 0, we add zero bits into r, and * leave n alone. */ BignumInt bword = n & -(BignumInt)indicator; - uintmax_t new_n = (BIGNUM_INT_BYTES < sizeof(n) ? - n >> BIGNUM_INT_BITS : 0); + uintmax_t new_n = shift_right_by_one_word(n); n ^= (n ^ new_n) & -(uintmax_t)indicator; BignumInt aword = mp_word(a, i); @@ -846,8 +870,8 @@ unsigned mp_hs_integer(mp_int *x, uintmax_t n) { BignumInt carry = 1; for (size_t i = 0; i < x->nw; i++) { - size_t shift = i * BIGNUM_INT_BITS; - BignumInt nword = shift < CHAR_BIT*sizeof(n) ? n >> shift : 0; + BignumInt nword = n; + n = shift_right_by_one_word(n); BignumInt dummy_out; BignumADC(dummy_out, carry, x->w[i], ~nword, carry); (void)dummy_out; @@ -872,8 +896,8 @@ unsigned mp_eq_integer(mp_int *x, uintmax_t n) { BignumInt diff = 0; for (size_t i = 0; i < x->nw; i++) { - size_t shift = i * BIGNUM_INT_BITS; - BignumInt nword = shift < CHAR_BIT*sizeof(n) ? n >> shift : 0; + BignumInt nword = n; + n = shift_right_by_one_word(n); diff |= x->w[i] ^ nword; } return 1 ^ normalise_to_1(diff); /* return 1 if diff _is_ zero */