diff --git a/sshbn.c b/sshbn.c index c9859093..455aa57a 100644 --- a/sshbn.c +++ b/sshbn.c @@ -532,107 +532,419 @@ static void internal_add_shifted(BignumInt *number, } } +static int bn_clz(BignumInt x) +{ + /* + * Count the leading zero bits in x. Equivalently, how far left + * would we need to shift x to make its top bit set? + * + * Precondition: x != 0. + */ + + /* FIXME: would be nice to put in some compiler intrinsics under + * ifdef here */ + int i, ret = 0; + for (i = BIGNUM_INT_BITS / 2; i != 0; i >>= 1) { + if ((x >> (BIGNUM_INT_BITS-i)) == 0) { + x <<= i; + ret += i; + } + } + return ret; +} + +static BignumInt reciprocal_word(BignumInt d) +{ + BignumInt dshort, recip; + BignumDblInt product; + int corrections; + + /* + * Input: a BignumInt value d, with its top bit set. + */ + assert(d >> (BIGNUM_INT_BITS-1) == 1); + + /* + * Output: a value, shifted to fill a BignumInt, which is strictly + * less than 1/(d+1), i.e. is an *under*-estimate (but by as + * little as possible within the constraints) of the reciprocal of + * any number whose first BIGNUM_INT_BITS bits match d. + * + * Ideally we'd like to _totally_ fill BignumInt, i.e. always + * return a value with the top bit set. Unfortunately we can't + * quite guarantee that for all inputs and also return a fixed + * exponent. So instead we take our reciprocal to be + * 2^(BIGNUM_INT_BITS*2-1) / d, so that it has the top bit clear + * only in the exceptional case where d takes exactly the maximum + * value BIGNUM_INT_MASK; in that case, the top bit is clear and + * the next bit down is set. + */ + + /* + * Start by computing a half-length version of the answer, by + * straightforward division within a BignumInt. + */ + dshort = (d >> (BIGNUM_INT_BITS/2)) + 1; + recip = (BIGNUM_TOP_BIT + dshort - 1) / dshort; + recip <<= BIGNUM_INT_BITS - BIGNUM_INT_BITS/2; + + /* + * Newton-Raphson iteration to improve that starting reciprocal + * estimate: take f(x) = d - 1/x, and then the N-R formula gives + * x_new = x - f(x)/f'(x) = x - (d-1/x)/(1/x^2) = x(2-d*x). Or, + * taking our fixed-point representation into account, take f(x) + * to be d - K/x (where K = 2^(BIGNUM_INT_BITS*2-1) as discussed + * above) and then we get (2K - d*x) * x/K. + * + * Newton-Raphson doubles the number of correct bits at every + * iteration, and the initial division above already gave us half + * the output word, so it's only worth doing one iteration. + */ + product = MUL_WORD(recip, d); + product += recip; + product = -product; /* the 2K shifts just off the top */ + product &= (((BignumDblInt)BIGNUM_INT_MASK << BIGNUM_INT_BITS) + + BIGNUM_INT_MASK); + product >>= BIGNUM_INT_BITS; + product = MUL_WORD(product, recip); + product >>= (BIGNUM_INT_BITS-1); + recip = (BignumInt)product; + + /* + * Now make sure we have the best possible reciprocal estimate, + * before we return it. We might have been off by a handful either + * way - not enough to bother with any better-thought-out kind of + * correction loop. + */ + product = MUL_WORD(recip, d); + product += recip; + corrections = 0; + if (product >= ((BignumDblInt)1 << (2*BIGNUM_INT_BITS-1))) { + do { + product -= d; + recip--; + corrections++; + } while (product >= ((BignumDblInt)1 << (2*BIGNUM_INT_BITS-1))); + } else { + while (product < ((BignumDblInt)1 << (2*BIGNUM_INT_BITS-1)) - d) { + product += d; + recip++; + corrections++; + } + } + + return recip; +} + /* * Compute a = a % m. * Input in first alen words of a and first mlen words of m. * Output in first alen words of a * (of which first alen-mlen words will be zero). - * The MSW of m MUST have its high bit set. * Quotient is accumulated in the `quotient' array, which is a Bignum - * rather than the internal bigendian format. Quotient parts are shifted - * left by `qshift' before adding into quot. + * rather than the internal bigendian format. + * + * 'recip' must be the result of calling reciprocal_word() on the top + * BIGNUM_INT_BITS of the modulus (denoted m0 in comments below), with + * the topmost set bit normalised to the MSB of the input to + * reciprocal_word. 'rshift' is how far left the top nonzero word of + * the modulus had to be shifted to set that top bit. */ static void internal_mod(BignumInt *a, int alen, BignumInt *m, int mlen, - BignumInt *quot, int qshift) + BignumInt *quot, BignumInt recip, int rshift) { - BignumInt m0, m1, h; int i, k; - m0 = m[0]; - assert(m0 >> (BIGNUM_INT_BITS-1) == 1); - if (mlen > 1) - m1 = m[1]; - else - m1 = 0; - - for (i = 0; i <= alen - mlen; i++) { - BignumDblInt t; - BignumInt q, r, c, ai1; - - if (i == 0) { - h = 0; - } else { - h = a[i - 1]; - a[i - 1] = 0; - } - - if (i == alen - 1) - ai1 = 0; - else - ai1 = a[i + 1]; - - /* Find q = h:a[i] / m0 */ - if (h >= m0) { - /* - * Special case. - * - * To illustrate it, suppose a BignumInt is 8 bits, and - * we are dividing (say) A1:23:45:67 by A1:B2:C3. Then - * our initial division will be 0xA123 / 0xA1, which - * will give a quotient of 0x100 and a divide overflow. - * However, the invariants in this division algorithm - * are not violated, since the full number A1:23:... is - * _less_ than the quotient prefix A1:B2:... and so the - * following correction loop would have sorted it out. - * - * In this situation we set q to be the largest - * quotient we _can_ stomach (0xFF, of course). - */ - q = BIGNUM_INT_MASK; - } else { - /* Macro doesn't want an array subscript expression passed - * into it (see definition), so use a temporary. */ - BignumInt tmplo = a[i]; - DIVMOD_WORD(q, r, h, tmplo, m0); - - /* Refine our estimate of q by looking at - h:a[i]:a[i+1] / m0:m1 */ - t = MUL_WORD(m1, q); - if (t > ((BignumDblInt) r << BIGNUM_INT_BITS) + ai1) { - q--; - t -= m1; - r = (r + m0) & BIGNUM_INT_MASK; /* overflow? */ - if (r >= (BignumDblInt) m0 && - t > ((BignumDblInt) r << BIGNUM_INT_BITS) + ai1) q--; - } - } - - /* Subtract q * m from a[i...] */ - c = 0; - for (k = mlen - 1; k >= 0; k--) { - t = MUL_WORD(q, m[k]); - t += c; - c = (BignumInt)(t >> BIGNUM_INT_BITS); - if ((BignumInt) t > a[i + k]) - c++; - a[i + k] -= (BignumInt) t; - } - - /* Add back m in case of borrow */ - if (c != h) { - t = 0; - for (k = mlen - 1; k >= 0; k--) { - t += m[k]; - t += a[i + k]; - a[i + k] = (BignumInt) t; - t = t >> BIGNUM_INT_BITS; - } - q--; - } - if (quot) - internal_add_shifted(quot, q, qshift + BIGNUM_INT_BITS * (alen - mlen - i)); +#ifdef DIVISION_DEBUG + { + int d; + printf("start division, m=0x"); + for (d = 0; d < mlen; d++) + printf("%0*llx", BIGNUM_INT_BITS/4, (unsigned long long)m[d]); + printf(", recip=%#0*llx, rshift=%d\n", + BIGNUM_INT_BITS/4, (unsigned long long)recip, rshift); } +#endif + + /* + * Repeatedly use that reciprocal estimate to get a decent number + * of quotient bits, and subtract off the resulting multiple of m. + * + * Normally we expect to terminate this loop by means of finding + * out q=0 part way through, but one way in which we might not get + * that far in the first place is if the input a is actually zero, + * in which case we'll discard zero words from the front of a + * until we reach the termination condition in the for statement + * here. + */ + for (i = 0; i <= alen - mlen ;) { + BignumDblInt product, subtmp, t; + BignumInt aword, q; + int shift, full_bitoffset, bitoffset, wordoffset; + +#ifdef DIVISION_DEBUG + { + int d; + printf("main loop, a=0x"); + for (d = 0; d < alen; d++) + printf("%0*llx", BIGNUM_INT_BITS/4, (unsigned long long)a[d]); + printf("\n"); + } +#endif + + if (a[i] == 0) { +#ifdef DIVISION_DEBUG + printf("zero word at i=%d\n", i); +#endif + i++; + continue; + } + + aword = a[i]; + shift = bn_clz(aword); + aword <<= shift; + if (shift > 0 && i+1 < alen) + aword |= a[i+1] >> (BIGNUM_INT_BITS - shift); + + t = MUL_WORD(recip, aword); + q = (BignumInt)(t >> BIGNUM_INT_BITS); + +#ifdef DIVISION_DEBUG + printf("i=%d, aword=%#0*llx, shift=%d, q=%#0*llx\n", + i, BIGNUM_INT_BITS/4, (unsigned long long)aword, + shift, BIGNUM_INT_BITS/4, (unsigned long long)q); +#endif + + /* + * Work out the right bit and word offsets to use when + * subtracting q*m from a. + * + * aword was taken from a[i], which means its LSB was at bit + * position (alen-1-i) * BIGNUM_INT_BITS. But then we shifted + * it left by 'shift', so now the low bit of aword corresponds + * to bit position (alen-1-i) * BIGNUM_INT_BITS - shift, i.e. + * aword is approximately equal to a / 2^(that). + * + * m0 comes from the top word of mod, so its LSB is at bit + * position (mlen-1) * BIGNUM_INT_BITS - rshift, i.e. it can + * be considered to be m / 2^(that power). 'recip' is the + * reciprocal of m0, times 2^(BIGNUM_INT_BITS*2-1), i.e. it's + * about 2^((mlen+1) * BIGNUM_INT_BITS - rshift - 1) / m. + * + * Hence, recip * aword is approximately equal to the product + * of those, which simplifies to + * + * a/m * 2^((mlen+2+i-alen)*BIGNUM_INT_BITS + shift - rshift - 1) + * + * But we've also shifted recip*aword down by BIGNUM_INT_BITS + * to form q, so we have + * + * q ~= a/m * 2^((mlen+1+i-alen)*BIGNUM_INT_BITS + shift - rshift - 1) + * + * and hence, when we now compute q*m, it will be about + * a*2^(all that lot), i.e. the negation of that expression is + * how far left we have to shift the product q*m to make it + * approximately equal to a. + */ + full_bitoffset = -((mlen+1+i-alen)*BIGNUM_INT_BITS + shift-rshift-1); +#ifdef DIVISION_DEBUG + printf("full_bitoffset=%d\n", full_bitoffset); +#endif + + if (full_bitoffset < 0) { + /* + * If we find ourselves needing to shift q*m _right_, that + * means we've reached the bottom of the quotient. Clip q + * so that its right shift becomes zero, and if that means + * q becomes _actually_ zero, this loop is done. + */ + if (full_bitoffset <= -BIGNUM_INT_BITS) + break; + q >>= -full_bitoffset; + full_bitoffset = 0; + if (!q) + break; +#ifdef DIVISION_DEBUG + printf("now full_bitoffset=%d, q=%#0*llx\n", + full_bitoffset, BIGNUM_INT_BITS/4, (unsigned long long)q); +#endif + } + + wordoffset = full_bitoffset / BIGNUM_INT_BITS; + bitoffset = full_bitoffset % BIGNUM_INT_BITS; +#ifdef DIVISION_DEBUG + printf("wordoffset=%d, bitoffset=%d\n", wordoffset, bitoffset); +#endif + + /* wordoffset as computed above is the offset between the LSWs + * of m and a. But in fact m and a are stored MSW-first, so we + * need to adjust it to be the offset between the actual array + * indices, and flip the sign too. */ + wordoffset = alen - mlen - wordoffset; + + if (bitoffset == 0) { + BignumInt c = 1; + BignumInt prev_hi_word = 0; + for (k = mlen - 1; wordoffset+k >= i; k--) { + BignumInt mword = k<0 ? 0 : m[k]; + product = MUL_WORD(q, mword); + product += prev_hi_word; + prev_hi_word = product >> BIGNUM_INT_BITS; +#ifdef DIVISION_DEBUG + printf(" aligned sub: product word for m[%d] = %#0*llx\n", + k, BIGNUM_INT_BITS/4, + (unsigned long long)(BignumInt)product); +#endif +#ifdef DIVISION_DEBUG + printf(" aligned sub: subtrahend for a[%d] = %#0*llx\n", + wordoffset+k, BIGNUM_INT_BITS/4, + (unsigned long long)(BignumInt)product); +#endif + subtmp = (BignumDblInt)a[wordoffset+k] + + ((BignumInt)product ^ BIGNUM_INT_MASK) + c; + a[wordoffset+k] = (BignumInt)subtmp; + c = subtmp >> BIGNUM_INT_BITS; + } + } else { + BignumInt add_word = 0; + BignumInt c = 1; + BignumInt prev_hi_word = 0; + for (k = mlen - 1; wordoffset+k >= i; k--) { + BignumInt mword = k<0 ? 0 : m[k]; + product = MUL_WORD(q, mword); + product += prev_hi_word; + prev_hi_word = product >> BIGNUM_INT_BITS; +#ifdef DIVISION_DEBUG + printf(" unaligned sub: product word for m[%d] = %#0*llx\n", + k, BIGNUM_INT_BITS/4, + (unsigned long long)(BignumInt)product); +#endif + + add_word |= (BignumInt)product << bitoffset; + +#ifdef DIVISION_DEBUG + printf(" unaligned sub: subtrahend for a[%d] = %#0*llx\n", + wordoffset+k, + BIGNUM_INT_BITS/4, (unsigned long long)add_word); +#endif + subtmp = (BignumDblInt)a[wordoffset+k] + + (add_word ^ BIGNUM_INT_MASK) + c; + a[wordoffset+k] = (BignumInt)subtmp; + c = subtmp >> BIGNUM_INT_BITS; + + add_word = (BignumInt)product >> (BIGNUM_INT_BITS - bitoffset); + } + } + + if (quot) { +#ifdef DIVISION_DEBUG + printf("adding quotient word %#0*llx << %d\n", + BIGNUM_INT_BITS/4, (unsigned long long)q, full_bitoffset); +#endif + internal_add_shifted(quot, q, full_bitoffset); +#ifdef DIVISION_DEBUG + { + int d; + printf("now quot=0x"); + for (d = quot[0]; d > 0; d--) + printf("%0*llx", BIGNUM_INT_BITS/4, + (unsigned long long)quot[d]); + printf("\n"); + } +#endif + } + } + +#ifdef DIVISION_DEBUG + { + int d; + printf("end main loop, a=0x"); + for (d = 0; d < alen; d++) + printf("%0*llx", BIGNUM_INT_BITS/4, (unsigned long long)a[d]); + if (quot) { + printf(", quot=0x"); + for (d = quot[0]; d > 0; d--) + printf("%0*llx", BIGNUM_INT_BITS/4, + (unsigned long long)quot[d]); + } + printf("\n"); + } +#endif + + /* + * The above loop should terminate with the remaining value in a + * being strictly less than 2*m (if a >= 2*m then we should always + * have managed to get a nonzero q word), but we can't guarantee + * that it will be strictly less than m: consider a case where the + * remainder is 1, and another where the remainder is m-1. By the + * time a contains a value that's _about m_, you clearly can't + * distinguish those cases by looking at only the top word of a - + * you have to go all the way down to the bottom before you find + * out whether it's just less or just more than m. + * + * Hence, we now do a final fixup in which we subtract one last + * copy of m, or don't, accordingly. We should never have to + * subtract more than one copy of m here. + */ + for (i = 0; i < alen; i++) { + /* Compare a with m, word by word, from the MSW down. As soon + * as we encounter a difference, we know whether we need the + * fixup. */ + int mindex = mlen-alen+i; + BignumInt mword = mindex < 0 ? 0 : m[mindex]; + if (a[i] < mword) { +#ifdef DIVISION_DEBUG + printf("final fixup not needed, a < m\n"); +#endif + return; + } else if (a[i] > mword) { +#ifdef DIVISION_DEBUG + printf("final fixup is needed, a > m\n"); +#endif + break; + } + /* If neither of those cases happened, the words are the same, + * so keep going and look at the next one. */ + } +#ifdef DIVISION_DEBUG + if (i == mlen) /* if we printed neither of the above diagnostics */ + printf("final fixup is needed, a == m\n"); +#endif + + /* + * If we got here without returning, then a >= m, so we must + * subtract m, and increment the quotient. + */ + { + BignumInt c = 1; + for (i = alen - 1; i >= 0; i--) { + int mindex = mlen-alen+i; + BignumInt mword = mindex < 0 ? 0 : m[mindex]; + BignumDblInt subtmp = (BignumDblInt)a[i] + + ((BignumInt)mword ^ BIGNUM_INT_MASK) + c; + a[i] = (BignumInt)subtmp; + c = subtmp >> BIGNUM_INT_BITS; + } + } + if (quot) + internal_add_shifted(quot, 1, 0); + +#ifdef DIVISION_DEBUG + { + int d; + printf("after final fixup, a=0x"); + for (d = 0; d < alen; d++) + printf("%0*llx", BIGNUM_INT_BITS/4, (unsigned long long)a[d]); + if (quot) { + printf(", quot=0x"); + for (d = quot[0]; d > 0; d--) + printf("%0*llx", BIGNUM_INT_BITS/4, + (unsigned long long)quot[d]); + } + printf("\n"); + } +#endif } /* @@ -641,7 +953,8 @@ static void internal_mod(BignumInt *a, int alen, Bignum modpow_simple(Bignum base_in, Bignum exp, Bignum mod) { BignumInt *a, *b, *n, *m, *scratch; - int mshift; + BignumInt recip; + int rshift; int mlen, scratchlen, i, j; Bignum base, result; @@ -664,16 +977,6 @@ Bignum modpow_simple(Bignum base_in, Bignum exp, Bignum mod) for (j = 0; j < mlen; j++) m[j] = mod[mod[0] - j]; - /* Shift m left to make msb bit set */ - for (mshift = 0; mshift < BIGNUM_INT_BITS-1; mshift++) - if ((m[0] << mshift) & BIGNUM_TOP_BIT) - break; - if (mshift) { - for (i = 0; i < mlen - 1; i++) - m[i] = (m[i] << mshift) | (m[i + 1] >> (BIGNUM_INT_BITS - mshift)); - m[mlen - 1] = m[mlen - 1] << mshift; - } - /* Allocate n of size mlen, copy base to n */ n = snewn(mlen, BignumInt); i = mlen - base[0]; @@ -704,14 +1007,26 @@ Bignum modpow_simple(Bignum base_in, Bignum exp, Bignum mod) } } + /* Compute reciprocal of the top full word of the modulus */ + { + BignumInt m0 = m[0]; + rshift = bn_clz(m0); + if (rshift) { + m0 <<= rshift; + if (mlen > 1) + m0 |= m[1] >> (BIGNUM_INT_BITS - rshift); + } + recip = reciprocal_word(m0); + } + /* Main computation */ while (i < (int)exp[0]) { while (j >= 0) { internal_mul(a + mlen, a + mlen, b, mlen, scratch); - internal_mod(b, mlen * 2, m, mlen, NULL, 0); + internal_mod(b, mlen * 2, m, mlen, NULL, recip, rshift); if ((exp[exp[0] - i] & ((BignumInt)1 << j)) != 0) { internal_mul(b + mlen, n, a, mlen, scratch); - internal_mod(a, mlen * 2, m, mlen, NULL, 0); + internal_mod(a, mlen * 2, m, mlen, NULL, recip, rshift); } else { BignumInt *t; t = a; @@ -724,16 +1039,6 @@ Bignum modpow_simple(Bignum base_in, Bignum exp, Bignum mod) j = BIGNUM_INT_BITS-1; } - /* Fixup result in case the modulus was shifted */ - if (mshift) { - for (i = mlen - 1; i < 2 * mlen - 1; i++) - a[i] = (a[i] << mshift) | (a[i + 1] >> (BIGNUM_INT_BITS - mshift)); - a[2 * mlen - 1] = a[2 * mlen - 1] << mshift; - internal_mod(a, mlen * 2, m, mlen, NULL, 0); - for (i = 2 * mlen - 1; i >= mlen; i--) - a[i] = (a[i] >> mshift) | (a[i - 1] << (BIGNUM_INT_BITS - mshift)); - } - /* Copy result to buffer */ result = newbn(mod[0]); for (i = 0; i < mlen; i++) @@ -912,7 +1217,8 @@ Bignum modpow(Bignum base_in, Bignum exp, Bignum mod) Bignum modmul(Bignum p, Bignum q, Bignum mod) { BignumInt *a, *n, *m, *o, *scratch; - int mshift, scratchlen; + BignumInt recip; + int rshift, scratchlen; int pqlen, mlen, rlen, i, j; Bignum result; @@ -929,16 +1235,6 @@ Bignum modmul(Bignum p, Bignum q, Bignum mod) for (j = 0; j < mlen; j++) m[j] = mod[mod[0] - j]; - /* Shift m left to make msb bit set */ - for (mshift = 0; mshift < BIGNUM_INT_BITS-1; mshift++) - if ((m[0] << mshift) & BIGNUM_TOP_BIT) - break; - if (mshift) { - for (i = 0; i < mlen - 1; i++) - m[i] = (m[i] << mshift) | (m[i + 1] >> (BIGNUM_INT_BITS - mshift)); - m[mlen - 1] = m[mlen - 1] << mshift; - } - pqlen = (p[0] > q[0] ? p[0] : q[0]); /* @@ -971,19 +1267,21 @@ Bignum modmul(Bignum p, Bignum q, Bignum mod) scratchlen = mul_compute_scratch(pqlen); scratch = snewn(scratchlen, BignumInt); + /* Compute reciprocal of the top full word of the modulus */ + { + BignumInt m0 = m[0]; + rshift = bn_clz(m0); + if (rshift) { + m0 <<= rshift; + if (mlen > 1) + m0 |= m[1] >> (BIGNUM_INT_BITS - rshift); + } + recip = reciprocal_word(m0); + } + /* Main computation */ internal_mul(n, o, a, pqlen, scratch); - internal_mod(a, pqlen * 2, m, mlen, NULL, 0); - - /* Fixup result in case the modulus was shifted */ - if (mshift) { - for (i = 2 * pqlen - mlen - 1; i < 2 * pqlen - 1; i++) - a[i] = (a[i] << mshift) | (a[i + 1] >> (BIGNUM_INT_BITS - mshift)); - a[2 * pqlen - 1] = a[2 * pqlen - 1] << mshift; - internal_mod(a, pqlen * 2, m, mlen, NULL, 0); - for (i = 2 * pqlen - 1; i >= 2 * pqlen - mlen; i--) - a[i] = (a[i] >> mshift) | (a[i - 1] << (BIGNUM_INT_BITS - mshift)); - } + internal_mod(a, pqlen * 2, m, mlen, NULL, recip, rshift); /* Copy result to buffer */ rlen = (mlen < pqlen * 2 ? mlen : pqlen * 2); @@ -1047,7 +1345,8 @@ Bignum modsub(const Bignum a, const Bignum b, const Bignum n) static void bigdivmod(Bignum p, Bignum mod, Bignum result, Bignum quotient) { BignumInt *n, *m; - int mshift; + BignumInt recip; + int rshift; int plen, mlen, i, j; /* @@ -1063,16 +1362,6 @@ static void bigdivmod(Bignum p, Bignum mod, Bignum result, Bignum quotient) for (j = 0; j < mlen; j++) m[j] = mod[mod[0] - j]; - /* Shift m left to make msb bit set */ - for (mshift = 0; mshift < BIGNUM_INT_BITS-1; mshift++) - if ((m[0] << mshift) & BIGNUM_TOP_BIT) - break; - if (mshift) { - for (i = 0; i < mlen - 1; i++) - m[i] = (m[i] << mshift) | (m[i + 1] >> (BIGNUM_INT_BITS - mshift)); - m[mlen - 1] = m[mlen - 1] << mshift; - } - plen = p[0]; /* Ensure plen > mlen */ if (plen <= mlen) @@ -1085,19 +1374,21 @@ static void bigdivmod(Bignum p, Bignum mod, Bignum result, Bignum quotient) for (j = 1; j <= (int)p[0]; j++) n[plen - j] = p[j]; - /* Main computation */ - internal_mod(n, plen, m, mlen, quotient, mshift); - - /* Fixup result in case the modulus was shifted */ - if (mshift) { - for (i = plen - mlen - 1; i < plen - 1; i++) - n[i] = (n[i] << mshift) | (n[i + 1] >> (BIGNUM_INT_BITS - mshift)); - n[plen - 1] = n[plen - 1] << mshift; - internal_mod(n, plen, m, mlen, quotient, 0); - for (i = plen - 1; i >= plen - mlen; i--) - n[i] = (n[i] >> mshift) | (n[i - 1] << (BIGNUM_INT_BITS - mshift)); + /* Compute reciprocal of the top full word of the modulus */ + { + BignumInt m0 = m[0]; + rshift = bn_clz(m0); + if (rshift) { + m0 <<= rshift; + if (mlen > 1) + m0 |= m[1] >> (BIGNUM_INT_BITS - rshift); + } + recip = reciprocal_word(m0); } + /* Main computation */ + internal_mod(n, plen, m, mlen, quotient, recip, rshift); + /* Copy result to buffer */ if (result) { for (i = 1; i <= (int)result[0]; i++) { diff --git a/sshbn.h b/sshbn.h index 9366f614..fc0e6b5a 100644 --- a/sshbn.h +++ b/sshbn.h @@ -1,23 +1,8 @@ /* * sshbn.h: the assorted conditional definitions of BignumInt and - * multiply/divide macros used throughout the bignum code to treat - * numbers as arrays of the most conveniently sized word for the - * target machine. Exported so that other code (e.g. poly1305) can use - * it too. - */ - -/* - * Usage notes: - * * Do not call the DIVMOD_WORD macro with expressions such as array - * subscripts, as some implementations object to this (see below). - * * Note that none of the division methods below will cope if the - * quotient won't fit into BIGNUM_INT_BITS. Callers should be careful - * to avoid this case. - * If this condition occurs, in the case of the x86 DIV instruction, - * an overflow exception will occur, which (according to a correspondent) - * will manifest on Windows as something like - * 0xC0000095: Integer overflow - * The C variant won't give the right answer, either. + * multiply macros used throughout the bignum code to treat numbers as + * arrays of the most conveniently sized word for the target machine. + * Exported so that other code (e.g. poly1305) can use it too. */ #if defined __SIZEOF_INT128__ @@ -32,11 +17,6 @@ typedef __uint128_t BignumDblInt; #define BIGNUM_TOP_BIT 0x8000000000000000ULL #define BIGNUM_INT_BITS 64 #define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2) -#define DIVMOD_WORD(q, r, hi, lo, w) do { \ - BignumDblInt n = (((BignumDblInt)hi) << BIGNUM_INT_BITS) | lo; \ - q = n / w; \ - r = n % w; \ -} while (0) #elif defined __GNUC__ && defined __i386__ typedef unsigned long BignumInt; typedef unsigned long long BignumDblInt; @@ -44,10 +24,6 @@ typedef unsigned long long BignumDblInt; #define BIGNUM_TOP_BIT 0x80000000UL #define BIGNUM_INT_BITS 32 #define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2) -#define DIVMOD_WORD(q, r, hi, lo, w) \ - __asm__("div %2" : \ - "=d" (r), "=a" (q) : \ - "r" (w), "d" (hi), "a" (lo)) #elif defined _MSC_VER && defined _M_IX86 typedef unsigned __int32 BignumInt; typedef unsigned __int64 BignumDblInt; @@ -55,16 +31,6 @@ typedef unsigned __int64 BignumDblInt; #define BIGNUM_TOP_BIT 0x80000000UL #define BIGNUM_INT_BITS 32 #define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2) -/* Note: MASM interprets array subscripts in the macro arguments as - * assembler syntax, which gives the wrong answer. Don't supply them. - * */ -#define DIVMOD_WORD(q, r, hi, lo, w) do { \ - __asm mov edx, hi \ - __asm mov eax, lo \ - __asm div w \ - __asm mov r, edx \ - __asm mov q, eax \ -} while(0) #elif defined _LP64 /* 64-bit architectures can do 32x32->64 chunks at a time */ typedef unsigned int BignumInt; @@ -73,11 +39,6 @@ typedef unsigned long BignumDblInt; #define BIGNUM_TOP_BIT 0x80000000U #define BIGNUM_INT_BITS 32 #define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2) -#define DIVMOD_WORD(q, r, hi, lo, w) do { \ - BignumDblInt n = (((BignumDblInt)hi) << BIGNUM_INT_BITS) | lo; \ - q = n / w; \ - r = n % w; \ -} while (0) #elif defined _LLP64 /* 64-bit architectures in which unsigned long is 32 bits, not 64 */ typedef unsigned long BignumInt; @@ -86,11 +47,6 @@ typedef unsigned long long BignumDblInt; #define BIGNUM_TOP_BIT 0x80000000UL #define BIGNUM_INT_BITS 32 #define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2) -#define DIVMOD_WORD(q, r, hi, lo, w) do { \ - BignumDblInt n = (((BignumDblInt)hi) << BIGNUM_INT_BITS) | lo; \ - q = n / w; \ - r = n % w; \ -} while (0) #else /* Fallback for all other cases */ typedef unsigned short BignumInt; @@ -99,11 +55,6 @@ typedef unsigned long BignumDblInt; #define BIGNUM_TOP_BIT 0x8000U #define BIGNUM_INT_BITS 16 #define MUL_WORD(w1, w2) ((BignumDblInt)w1 * w2) -#define DIVMOD_WORD(q, r, hi, lo, w) do { \ - BignumDblInt n = (((BignumDblInt)hi) << BIGNUM_INT_BITS) | lo; \ - q = n / w; \ - r = n % w; \ -} while (0) #endif #define BIGNUM_INT_BYTES (BIGNUM_INT_BITS / 8)