1
0
mirror of https://git.tartarus.org/simon/putty.git synced 2025-01-10 01:48:00 +00:00

Fix an array-size bug in modmul, and add some tests for it.

[originally from svn r9977]
This commit is contained in:
Simon Tatham 2013-08-02 06:27:54 +00:00
parent a777103fd9
commit e01104f899
2 changed files with 55 additions and 0 deletions

45
sshbn.c
View File

@ -1018,6 +1018,13 @@ Bignum modmul(Bignum p, Bignum q, Bignum mod)
pqlen = (p[0] > q[0] ? p[0] : q[0]); pqlen = (p[0] > q[0] ? p[0] : q[0]);
/*
* Make sure that we're allowing enough space. The shifting below
* will underflow the vectors we allocate if pqlen is too small.
*/
if (2*pqlen <= mlen)
pqlen = mlen/2 + 1;
/* Allocate n of size pqlen, copy p to n */ /* Allocate n of size pqlen, copy p to n */
n = snewn(pqlen, BignumInt); n = snewn(pqlen, BignumInt);
i = pqlen - p[0]; i = pqlen - p[0];
@ -1864,6 +1871,44 @@ int main(int argc, char **argv)
freebn(b); freebn(b);
freebn(c); freebn(c);
freebn(p); freebn(p);
} else if (!strcmp(buf, "modmul")) {
Bignum a, b, m, c, p;
if (ptrnum != 4) {
printf("%d: modmul with %d parameters, expected 4\n",
line, ptrnum);
exit(1);
}
a = bignum_from_bytes(ptrs[0], ptrs[1]-ptrs[0]);
b = bignum_from_bytes(ptrs[1], ptrs[2]-ptrs[1]);
m = bignum_from_bytes(ptrs[2], ptrs[3]-ptrs[2]);
c = bignum_from_bytes(ptrs[3], ptrs[4]-ptrs[3]);
p = modmul(a, b, m);
if (bignum_cmp(c, p) == 0) {
passes++;
} else {
char *as = bignum_decimal(a);
char *bs = bignum_decimal(b);
char *ms = bignum_decimal(m);
char *cs = bignum_decimal(c);
char *ps = bignum_decimal(p);
printf("%d: fail: %s * %s mod %s gave %s expected %s\n",
line, as, bs, ms, ps, cs);
fails++;
sfree(as);
sfree(bs);
sfree(ms);
sfree(cs);
sfree(ps);
}
freebn(a);
freebn(b);
freebn(m);
freebn(c);
freebn(p);
} else if (!strcmp(buf, "pow")) { } else if (!strcmp(buf, "pow")) {
Bignum base, expt, modulus, expected, answer; Bignum base, expt, modulus, expected, answer;

10
testdata/bignum.py vendored
View File

@ -103,6 +103,15 @@ for i in range(1,4200):
a, b, p = findprod((1<<i)+1, +1, (i, i+1)) a, b, p = findprod((1<<i)+1, +1, (i, i+1))
print "mul", hexstr(a), hexstr(b), hexstr(p) print "mul", hexstr(a), hexstr(b), hexstr(p)
# Simple tests of modmul.
for ai in range(20, 200, 60):
a = sqrt(3<<(2*ai-1))
for bi in range(20, 200, 60):
b = sqrt(5<<(2*bi-1))
for m in range(20, 600, 32):
m = sqrt(2**(m+1))
print "modmul", hexstr(a), hexstr(b), hexstr(m), hexstr((a*b) % m)
# Simple tests of modpow. # Simple tests of modpow.
for i in range(64, 4097, 63): for i in range(64, 4097, 63):
modulus = sqrt(1<<(2*i-1)) | 1 modulus = sqrt(1<<(2*i-1)) | 1
@ -113,3 +122,4 @@ for i in range(64, 4097, 63):
# Test even moduli, which can't be done by Montgomery. # Test even moduli, which can't be done by Montgomery.
modulus = modulus - 1 modulus = modulus - 1
print "pow", hexstr(base), hexstr(expt), hexstr(modulus), hexstr(pow(base, expt, modulus)) print "pow", hexstr(base), hexstr(expt), hexstr(modulus), hexstr(pow(base, expt, modulus))
print "pow", hexstr(i), hexstr(expt), hexstr(modulus), hexstr(pow(i, expt, modulus))