diff --git a/mpint.c b/mpint.c index f9ba20ea..8b72382f 100644 --- a/mpint.c +++ b/mpint.c @@ -1373,21 +1373,19 @@ mp_int *monty_invert(MontyContext *mc, mp_int *x) /* * Importing a number into Montgomery representation involves - * multiplying it by r and reducing mod m. We could do this using the - * straightforward mp_modmul, but since we have the machinery to avoid - * division, why don't we use it? If we multiply the number not by r - * itself, but by the residue of r^2 mod m, then we can do an actual - * Montgomery reduction to reduce the result and remove the extra - * factor of r. + * multiplying it by r and reducing mod m. We use the general-purpose + * mp_modmul for this, in case the input number is out of range. */ -void monty_import_into(MontyContext *mc, mp_int *r, mp_int *x) -{ - monty_mul_into(mc, r, x, mc->powers_of_r_mod_m[1]); -} - mp_int *monty_import(MontyContext *mc, mp_int *x) { - return monty_mul(mc, x, mc->powers_of_r_mod_m[1]); + return mp_modmul(x, mc->powers_of_r_mod_m[0], mc->m); +} + +void monty_import_into(MontyContext *mc, mp_int *r, mp_int *x) +{ + mp_int *imported = monty_import(mc, x); + mp_copy_into(r, imported); + mp_free(imported); } /* @@ -1450,7 +1448,6 @@ mp_int *monty_pow(MontyContext *mc, mp_int *base, mp_int *exponent) mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus) { - assert(base->nw <= modulus->nw); assert(modulus->nw > 0); assert(modulus->w[0] & 1); diff --git a/test/cryptsuite.py b/test/cryptsuite.py index 958f2ce4..1b9d78d2 100755 --- a/test/cryptsuite.py +++ b/test/cryptsuite.py @@ -473,6 +473,10 @@ class mpint(MyTestBase): b, e, m = 0x2B5B93812F253FF91F56B3B4DAD01CA2884B6A80719B0DA4E2159A230C6009EDA97C5C8FD4636B324F9594706EE3AD444831571BA5E17B1B2DFA92DEA8B7E, 0x25, 0xC8FCFD0FD7371F4FE8D0150EFC124E220581569587CCD8E50423FA8D41E0B2A0127E100E92501E5EE3228D12EA422A568C17E0AD2E5C5FCC2AE9159D2B7FB8CB assert(int(mp_modpow(b, e, m)) == pow(b, e, m)) + # Make sure mp_modpow can handle a base larger than the + # modulus, by pre-reducing it + assert(int(mp_modpow(1<<877, 907, 999979)) == pow(2, 877*907, 999979)) + def testModsqrt(self): moduli = [ 5, 19, 2**16+1, 2**31-1, 2**128-159, 2**255-19,