diff --git a/mpint.c b/mpint.c index 203298c2..32f21c3f 100644 --- a/mpint.c +++ b/mpint.c @@ -2076,11 +2076,15 @@ mp_int *mp_modsub(mp_int *x, mp_int *y, mp_int *modulus) mp_sub_into(diff, x, y); unsigned negate = mp_cmp_hs(y, x); mp_cond_negate(diff, diff, negate); - mp_int *reduced = mp_mod(diff, modulus); - mp_cond_negate(reduced, reduced, negate); - mp_cond_add_into(reduced, reduced, modulus, negate); + mp_int *residue = mp_mod(diff, modulus); + mp_cond_negate(residue, residue, negate); + /* If we've just negated the residue, then it will be < 0 and need + * the modulus adding to it to make it positive - *except* if the + * residue was zero when we negated it. */ + unsigned make_positive = negate & ~mp_eq_integer(residue, 0); + mp_cond_add_into(residue, residue, modulus, make_positive); mp_free(diff); - return reduced; + return residue; } static mp_int *mp_modadd_in_range(mp_int *x, mp_int *y, mp_int *modulus)