From 18678ba9bc42c497355d5f53831e46246532a0a9 Mon Sep 17 00:00:00 2001 From: Simon Tatham Date: Tue, 18 Feb 2020 18:55:56 +0000 Subject: [PATCH] mpint: add mp_[lr]shift_safe_into functions. There was previously no safe left shift at all, which is an omission. And rshift_safe_into was an odd thing to be missing, so while I'm here, I've added it on the basis that it will probably be useful sooner or later. --- mpint.c | 54 +++++++++++++++++++++++++++++++++++++++++++--- mpint.h | 9 +++++++- test/cryptsuite.py | 10 +++++++++ testcrypt.h | 2 ++ 4 files changed, 71 insertions(+), 4 deletions(-) diff --git a/mpint.c b/mpint.c index 94060ae7..cb10bd5f 100644 --- a/mpint.c +++ b/mpint.c @@ -1132,13 +1132,11 @@ mp_int *mp_rshift_fixed(mp_int *x, size_t bits) * by a power of 2 words, using the usual bit twiddling to make the * whole shift conditional on the appropriate bit of n. */ -mp_int *mp_rshift_safe(mp_int *x, size_t bits) +static void mp_rshift_safe_in_place(mp_int *r, size_t bits) { size_t wordshift = bits / BIGNUM_INT_BITS; size_t bitshift = bits % BIGNUM_INT_BITS; - mp_int *r = mp_copy(x); - unsigned clear = (r->nw - wordshift) >> (CHAR_BIT * sizeof(size_t) - 1); mp_cond_clear(r, clear); @@ -1163,10 +1161,60 @@ mp_int *mp_rshift_safe(mp_int *x, size_t bits) r->w[i] ^= (r->w[i] ^ w) & mask; } } +} +mp_int *mp_rshift_safe(mp_int *x, size_t bits) +{ + mp_int *r = mp_copy(x); + mp_rshift_safe_in_place(r, bits); return r; } +void mp_rshift_safe_into(mp_int *r, mp_int *x, size_t bits) +{ + mp_copy_into(r, x); + mp_rshift_safe_in_place(r, bits); +} + +static void mp_lshift_safe_in_place(mp_int *r, size_t bits) +{ + size_t wordshift = bits / BIGNUM_INT_BITS; + size_t bitshift = bits % BIGNUM_INT_BITS; + + /* + * Same strategy as mp_rshift_safe_in_place, but of course the + * other way up. + */ + + unsigned clear = (r->nw - wordshift) >> (CHAR_BIT * sizeof(size_t) - 1); + mp_cond_clear(r, clear); + + for (unsigned bit = 0; r->nw >> bit; bit++) { + size_t word_offset = 1 << bit; + BignumInt mask = -(BignumInt)((wordshift >> bit) & 1); + for (size_t i = r->nw; i-- > 0 ;) { + BignumInt w = mp_word(r, i - word_offset); + r->w[i] ^= (r->w[i] ^ w) & mask; + } + } + + size_t downshift = BIGNUM_INT_BITS - bitshift; + size_t no_shift = (downshift >> BIGNUM_INT_BITS_BITS); + downshift &= ~-(size_t)no_shift; + BignumInt downshifted_mask = ~-(BignumInt)no_shift; + + for (size_t i = r->nw; i-- > 0 ;) { + r->w[i] = (r->w[i] << bitshift) | + ((mp_word(r, i-1) >> downshift) & downshifted_mask); + } +} + +void mp_lshift_safe_into(mp_int *r, mp_int *x, size_t bits) +{ + mp_copy_into(r, x); + mp_lshift_safe_in_place(r, bits); +} + void mp_reduce_mod_2to(mp_int *x, size_t p) { size_t word = p / BIGNUM_INT_BITS; diff --git a/mpint.h b/mpint.h index e15312b3..13ac9a51 100644 --- a/mpint.h +++ b/mpint.h @@ -360,10 +360,17 @@ mp_int *mp_modadd(mp_int *x, mp_int *y, mp_int *modulus); mp_int *mp_modsub(mp_int *x, mp_int *y, mp_int *modulus); /* - * Shift an mp_int right by a given number of bits. The shift count is + * Shift an mp_int by a given number of bits. The shift count is * considered to be secret data, and as a result, the algorithm takes * O(n log n) time instead of the obvious O(n). + * + * There's no mp_lshift_safe, because the size of mp_int to allocate + * would not be able to avoid depending on the shift count. So if you + * need to behave independently of the size of a left shift, you have + * to know a bound on the space you'll need by some other means. */ +void mp_lshift_safe_into(mp_int *r, mp_int *x, size_t shift); +void mp_rshift_safe_into(mp_int *r, mp_int *x, size_t shift); mp_int *mp_rshift_safe(mp_int *x, size_t shift); /* diff --git a/test/cryptsuite.py b/test/cryptsuite.py index 55269189..b25123c8 100755 --- a/test/cryptsuite.py +++ b/test/cryptsuite.py @@ -574,10 +574,20 @@ class mpint(MyTestBase): mp_lshift_fixed_into(mp, mp, i) self.assertEqual(int(mp), (x << i) & mp_mask(mp)) + mp_copy_into(mp, x) + mp_lshift_safe_into(mp, mp, i) + self.assertEqual(int(mp), (x << i) & mp_mask(mp)) + mp_copy_into(mp, x) mp_rshift_fixed_into(mp, mp, i) self.assertEqual(int(mp), x >> i) + + mp_copy_into(mp, x) + mp_rshift_safe_into(mp, mp, i) + self.assertEqual(int(mp), x >> i) + self.assertEqual(int(mp_rshift_fixed(x, i)), x >> i) + self.assertEqual(int(mp_rshift_safe(x, i)), x >> i) def testRandom(self): diff --git a/testcrypt.h b/testcrypt.h index 13f7011b..6832e017 100644 --- a/testcrypt.h +++ b/testcrypt.h @@ -75,6 +75,8 @@ FUNC3(val_mpint, mp_modpow, val_mpint, val_mpint, val_mpint) FUNC3(val_mpint, mp_modmul, val_mpint, val_mpint, val_mpint) FUNC3(val_mpint, mp_modadd, val_mpint, val_mpint, val_mpint) FUNC3(val_mpint, mp_modsub, val_mpint, val_mpint, val_mpint) +FUNC3(void, mp_lshift_safe_into, val_mpint, val_mpint, uint) +FUNC3(void, mp_rshift_safe_into, val_mpint, val_mpint, uint) FUNC2(val_mpint, mp_rshift_safe, val_mpint, uint) FUNC3(void, mp_lshift_fixed_into, val_mpint, val_mpint, uint) FUNC3(void, mp_rshift_fixed_into, val_mpint, val_mpint, uint)