1
0
mirror of https://git.tartarus.org/simon/putty.git synced 2025-01-10 01:48:00 +00:00

mpint: clean up handling of uintmax_t.

Functions like mp_copy_integer_into, mp_add_integer_into and
mp_hs_integer all take an ordinary C integer in the form of a
uintmax_t, and perform an operation between that and an mp_int. In
order to do that, they have to break it up into some number of
BignumInt, via bit shifts.

But in C, shifting by an amount equal to or greater than the width of
the type is undefined behaviour, and you risk the compiler generating
nonsense or complaining at compile time. I did various dodges in those
functions to try to avoid that, but didn't manage to use the same
idiom everywhere. Sometimes I'd leave the integer in its original form
and shift it right by increasing multiples of BIGNUM_INT_BITS;
sometimes I'd shift it down in place every time. And mostly I'd do the
conditional shift by checking against sizeof(n), but once I did it by
shifting by half the word and then the other half.

Now refactored so that there's a pair of functions to shift a
uintmax_t left or right by BIGNUM_INT_BITS in what I hope is a UB-safe
manner, and changed all the code I could find to use them.
This commit is contained in:
Simon Tatham 2020-03-02 18:34:52 +00:00
parent a085acbadf
commit 3ea69c290e

56
mpint.c
View File

@ -33,6 +33,35 @@ static inline BignumInt mp_word(mp_int *x, size_t i)
return i < x->nw ? x->w[i] : 0; return i < x->nw ? x->w[i] : 0;
} }
/*
* Shift an ordinary C integer by BIGNUM_INT_BITS, in a way that
* avoids writing a shift operator whose RHS is greater or equal to
* the size of the type, because that's undefined behaviour in C.
*
* In fact we must avoid even writing it in a definitely-untaken
* branch of an if, because compilers will sometimes warn about
* that. So you can't just write 'shift too big ? 0 : n >> shift',
* because even if 'shift too big' is a constant-expression
* evaluating to false, you can still get complaints about the
* else clause of the ?:.
*
* So we have to re-check _inside_ that clause, so that the shift
* count is reset to something nonsensical but safe in the case
* where the clause wasn't going to be taken anyway.
*/
static uintmax_t shift_right_by_one_word(uintmax_t n)
{
bool shift_too_big = BIGNUM_INT_BYTES >= sizeof(n);
return shift_too_big ? 0 :
n >> (shift_too_big ? 0 : BIGNUM_INT_BITS);
}
static uintmax_t shift_left_by_one_word(uintmax_t n)
{
bool shift_too_big = BIGNUM_INT_BYTES >= sizeof(n);
return shift_too_big ? 0 :
n << (shift_too_big ? 0 : BIGNUM_INT_BITS);
}
mp_int *mp_make_sized(size_t nw) mp_int *mp_make_sized(size_t nw)
{ {
mp_int *x = snew_plus(mp_int, nw * sizeof(BignumInt)); mp_int *x = snew_plus(mp_int, nw * sizeof(BignumInt));
@ -93,8 +122,8 @@ void mp_copy_into(mp_int *dest, mp_int *src)
void mp_copy_integer_into(mp_int *r, uintmax_t n) void mp_copy_integer_into(mp_int *r, uintmax_t n)
{ {
for (size_t i = 0; i < r->nw; i++) { for (size_t i = 0; i < r->nw; i++) {
r->w[i] = (BignumInt)n; r->w[i] = n;
n = (BIGNUM_INT_BYTES < sizeof(n)) ? n >> BIGNUM_INT_BITS : 0; n = shift_right_by_one_word(n);
} }
} }
@ -268,12 +297,8 @@ unsigned mp_get_bit(mp_int *x, size_t bit)
uintmax_t mp_get_integer(mp_int *x) uintmax_t mp_get_integer(mp_int *x)
{ {
uintmax_t toret = 0; uintmax_t toret = 0;
for (size_t i = x->nw; i-- > 0 ;) { for (size_t i = x->nw; i-- > 0 ;)
/* Shift in two stages to avoid undefined behaviour if the toret = shift_left_by_one_word(toret) | x->w[i];
* shift count equals the integer width */
toret = (toret << (BIGNUM_INT_BITS/2)) << (BIGNUM_INT_BITS/2);
toret |= x->w[i];
}
return toret; return toret;
} }
@ -765,8 +790,8 @@ static BignumCarry mp_add_masked_integer_into(
{ {
for (size_t i = 0; i < rw; i++) { for (size_t i = 0; i < rw; i++) {
BignumInt aword = mp_word(a, i); BignumInt aword = mp_word(a, i);
size_t shift = i * BIGNUM_INT_BITS; BignumInt bword = b;
BignumInt bword = shift < CHAR_BIT*sizeof(b) ? b >> shift : 0; b = shift_right_by_one_word(b);
BignumInt out; BignumInt out;
bword = (bword ^ b_xor) & b_and; bword = (bword ^ b_xor) & b_and;
BignumADC(out, carry, aword, bword, carry); BignumADC(out, carry, aword, bword, carry);
@ -807,8 +832,7 @@ static void mp_add_integer_into_shifted_by_words(
* shift n down. If it's 0, we add zero bits into r, and * shift n down. If it's 0, we add zero bits into r, and
* leave n alone. */ * leave n alone. */
BignumInt bword = n & -(BignumInt)indicator; BignumInt bword = n & -(BignumInt)indicator;
uintmax_t new_n = (BIGNUM_INT_BYTES < sizeof(n) ? uintmax_t new_n = shift_right_by_one_word(n);
n >> BIGNUM_INT_BITS : 0);
n ^= (n ^ new_n) & -(uintmax_t)indicator; n ^= (n ^ new_n) & -(uintmax_t)indicator;
BignumInt aword = mp_word(a, i); BignumInt aword = mp_word(a, i);
@ -854,8 +878,8 @@ unsigned mp_hs_integer(mp_int *x, uintmax_t n)
{ {
BignumInt carry = 1; BignumInt carry = 1;
for (size_t i = 0; i < x->nw; i++) { for (size_t i = 0; i < x->nw; i++) {
size_t shift = i * BIGNUM_INT_BITS; BignumInt nword = n;
BignumInt nword = shift < CHAR_BIT*sizeof(n) ? n >> shift : 0; n = shift_right_by_one_word(n);
BignumInt dummy_out; BignumInt dummy_out;
BignumADC(dummy_out, carry, x->w[i], ~nword, carry); BignumADC(dummy_out, carry, x->w[i], ~nword, carry);
(void)dummy_out; (void)dummy_out;
@ -880,8 +904,8 @@ unsigned mp_eq_integer(mp_int *x, uintmax_t n)
{ {
BignumInt diff = 0; BignumInt diff = 0;
for (size_t i = 0; i < x->nw; i++) { for (size_t i = 0; i < x->nw; i++) {
size_t shift = i * BIGNUM_INT_BITS; BignumInt nword = n;
BignumInt nword = shift < CHAR_BIT*sizeof(n) ? n >> shift : 0; n = shift_right_by_one_word(n);
diff |= x->w[i] ^ nword; diff |= x->w[i] ^ nword;
} }
return 1 ^ normalise_to_1(diff); /* return 1 if diff _is_ zero */ return 1 ^ normalise_to_1(diff); /* return 1 if diff _is_ zero */