mirror of
https://git.tartarus.org/simon/putty.git
synced 2025-01-10 09:58:01 +00:00
c2ec13c7e9
As I mentioned in the previous commit, I'm going to want PuTTY to be able to run sensibly when compiled with 64-bit Visual Studio, including handling bignums in 64-bit chunks for speed. Unfortunately, 64-bit VS does not provide any type we can use as BignumDblInt in that situation (unlike 64-bit gcc and clang, which give us __uint128_t). The only facilities it provides are compiler intrinsics to access an add-with-carry operation and a 64x64->128 multiplication (the latter delivering its product in two separate 64-bit output chunks). Hence, here's a substantial rework of the bignum code to make it implement everything in terms of _those_ primitives, rather than depending throughout on having BignumDblInt available to use ad-hoc. BignumDblInt does still exist, for the moment, but now it's an internal implementation detail of sshbn.h, only declared inside a new set of macros implementing arithmetic primitives, and not accessible to any code outside sshbn.h (which confirms that I really did catch all uses of it and remove them). The resulting code is surprisingly nice-looking, actually. You'd expect more hassle and roundabout circumlocutions when you drop down to using a more basic set of primitive operations, but actually, in many cases it's turned out shorter to write things in terms of the new BignumADC and BignumMUL macros - because almost all my uses of BignumDblInt were implementing those operations anyway, taking several lines at a time, and now they can do each thing in just one line. The biggest headache was Poly1305: I wasn't able to find any sensible way to adapt the existing Python script that generates the various per-int-size implementations of arithmetic mod 2^130-5, and so I had to rewrite it from scratch instead, with nothing in common with the old version beyond a handful of comments. 
But even that seems to have worked out nicely: the new version has much more legible descriptions of the high-level algorithms, by virtue of having a 'Multiprecision' type which wraps up the division into words, and yet Multiprecision's range analysis allows it to automatically drop out special cases such as multiplication by 5 being much easier than multiplication by another multi-word integer.
369 lines
14 KiB
Python
Executable File
369 lines
14 KiB
Python
Executable File
#!/usr/bin/env python
|
|
|
|
import string
import sys
from collections import deque, namedtuple
|
|
|
|
class Multiprecision(object):
    """A multi-word unsigned integer held in generated C variables.

    'words' lists C expressions (generated variable names or decimal
    constants), least significant word first, each one BignumInt wide.
    'minval' and 'maxval' bound the integer's value; that range analysis
    decides how many words each result needs, so words that are provably
    always zero are never generated. The arithmetic operators delegate
    the word-level work (add/adc/muladd/new_value) to 'target' and
    return a new Multiprecision naming the resulting words.
    """

    def __init__(self, target, minval, maxval, words):
        """Wrap 'words' as a value known to lie in [minval, maxval].

        target: the code generator (provides bits, nwords and the
            word-level emitters).
        minval, maxval: inclusive bounds on the value.
        words: C expressions for each word, low word first; there must
            be exactly target.nwords(maxval) of them.
        """
        self.target = target
        self.minval = minval
        self.maxval = maxval
        self.words = words
        assert 0 <= self.minval
        assert self.minval <= self.maxval
        assert self.target.nwords(self.maxval) == len(words)

    def getword(self, n):
        """Return the C expression for word n, or "0" above the top word."""
        return self.words[n] if n < len(self.words) else "0"

    def __add__(self, rhs):
        """Emit a multi-word addition of self and rhs."""
        newmin = self.minval + rhs.minval
        newmax = self.maxval + rhs.maxval
        nwords = self.target.nwords(newmax)
        words = []

        # The bottom word has no carry-in; every higher word consumes
        # the carry produced by the word below it.
        addfn = self.target.add
        for i in range(nwords):
            words.append(addfn(self.getword(i), rhs.getword(i)))
            addfn = self.target.adc

        return Multiprecision(self.target, newmin, newmax, words)

    def __mul__(self, rhs):
        """Emit a schoolbook multiplication of self by rhs."""
        newmin = self.minval * rhs.minval
        newmax = self.maxval * rhs.maxval
        nwords = self.target.nwords(newmax)

        # There are basically two strategies we could take for
        # multiplying two multiprecision integers. One is to enumerate
        # the space of pairs of word indices in lexicographic order,
        # essentially computing a*b[i] for each i and adding them
        # together; the other is to enumerate in diagonal order,
        # computing everything together that belongs at a particular
        # output word index.
        #
        # For the moment, I've gone for the former.

        sprev = []
        for i, sword in enumerate(self.words):
            rprev = None
            # Words below index i are final after the previous row.
            sthis = sprev[:i]
            for j, rword in enumerate(rhs.words):
                prevwords = []
                if i+j < len(sprev):
                    # Accumulate onto the previous row's partial sum.
                    prevwords.append(sprev[i+j])
                if rprev is not None:
                    # Carry the high half of the previous product along.
                    prevwords.append(rprev)
                vhi, vlo = self.target.muladd(sword, rword, *prevwords)
                sthis.append(vlo)
                rprev = vhi
            sthis.append(rprev)
            sprev = sthis

        # Remove unneeded words from the top of the output, if we can
        # prove by range analysis that they'll always be zero.
        sprev = sprev[:nwords]

        return Multiprecision(self.target, newmin, newmax, sprev)

    def extract_bits(self, start, bits=None):
        """Return bits [start, start+bits) of this value.

        If bits is None, everything from bit 'start' up to the top of
        maxval is taken.
        """
        wordbits = self.target.bits
        if bits is None:
            bits = (self.maxval >> start).bit_length()

        # Overly thorough range analysis: if min and max have the same
        # *quotient* by 2^bits, then the result of reducing anything
        # in the range [min,max] mod 2^bits has to fall within the
        # obvious range [min mod 2^bits, max mod 2^bits]. But if they
        # have different quotients, then you can wrap round the modulus
        # and so any value mod 2^bits is possible.
        mask = (1 << bits) - 1
        newmin = self.minval >> start
        newmax = self.maxval >> start
        if (newmin >> bits) != (newmax >> bits):
            newmin = 0
            newmax = mask
        else:
            # Reduce the bounds too; otherwise a nonzero common
            # quotient would overstate the range (and the word count).
            newmin &= mask
            newmax &= mask

        nwords = self.target.nwords(newmax)
        words = []
        for i in range(nwords):
            srcpos = i * wordbits + start
            maxbits = min(wordbits, start + bits - srcpos)
            # Floor division: on Python 3 '/' would give a float index.
            wordindex = srcpos // wordbits
            shift = srcpos % wordbits
            if shift == 0:
                word = self.getword(wordindex)
            elif (wordindex+1 >= len(self.words) or
                  shift + maxbits < wordbits):
                # Everything wanted lives in a single source word.
                word = self.target.new_value(
                    "(%%s) >> %d" % shift,
                    self.getword(wordindex))
            else:
                # Straddles two source words: stitch them together.
                word = self.target.new_value(
                    "((%%s) >> %d) | ((%%s) << %d)" % (
                        shift, wordbits - shift),
                    self.getword(wordindex),
                    self.getword(wordindex + 1))
            if maxbits < wordbits and maxbits < bits:
                word = self.target.new_value(
                    "(%%s) & ((((BignumInt)1) << %d)-1)" % maxbits,
                    word)
            words.append(word)

        return Multiprecision(self.target, newmin, newmax, words)
|
|
|
|
# A Statement describes one generated C statement: 'rvars' are the
# variables it reads and 'wvars' the ones it writes. 'forms' holds the
# alternative C texts it can be emitted as, chosen according to which
# of its outputs are actually consumed (there is no point calling
# BignumADC when the carry it produces goes unused, or BignumMUL when
# nobody wants the top half of the product). 'forms' is indexed by a
# bitmap carrying one bit per entry of wvars: wvars[0] supplies the
# most significant bit and wvars[-1] the least significant.
Statement = namedtuple("Statement", ["rvars", "wvars", "forms"])
|
|
|
|
class CodegenTarget(object):
    """Collects the C statements making up one generated function.

    'bits' is the BignumInt width being generated for. Values are named
    C variables "v0", "v1", ... (allocated by new_value) and carries
    "carry1", "carry2", ... (allocated by add/adc). Every Statement
    records what it reads and writes, so compute_needed() can drop
    statements whose outputs nobody uses and choose the cheapest form of
    each statement that remains.
    """

    def __init__(self, bits):
        self.bits = bits
        self.valindex = 0       # number of "v" names allocated so far
        self.stmts = []         # [needed, Statement] pairs, emission order
        self.generators = {}    # variable name -> index of its writer
        # Words in a bigval as read by bigval_input and written by
        # write_bigval: enough to hold 130 bits. Floor division keeps
        # this an int on Python 3.
        self.bv_words = (130 + self.bits - 1) // self.bits
        self.carry_index = 0    # suffix of the most recent carry name

    def nwords(self, maxval):
        """Return how many BignumInt words are needed to hold maxval."""
        return (maxval.bit_length() + self.bits - 1) // self.bits

    def stmt(self, stmt, needed=False):
        """Append a Statement and record it as the writer of its wvars.

        needed=True makes it a root of the liveness analysis in
        compute_needed(), i.e. it is emitted unconditionally.
        """
        index = len(self.stmts)
        self.stmts.append([needed, stmt])
        for val in stmt.wvars:
            self.generators[val] = index

    def new_value(self, formatstr=None, *deps):
        """Allocate and return a fresh "v%d" variable name.

        If formatstr is given, also emit 'name = formatstr % deps' as a
        statement reading deps. Otherwise the caller must emit the
        statement that writes the new variable.
        """
        name = "v%d" % self.valindex
        self.valindex += 1
        if formatstr is not None:
            self.stmt(Statement(
                rvars=deps, wvars=[name],
                forms=[None, name + " = " + formatstr % deps]))
        return name

    def bigval_input(self, name, bits):
        """Return a Multiprecision reading the words of bigval *name.

        'bits' is the maximum bit length assumed for the input; it must
        occupy exactly bv_words words.
        """
        words = (bits + self.bits - 1) // self.bits
        # Expect not to require an entire extra word
        assert words == self.bv_words

        return Multiprecision(self, 0, (1<<bits)-1, [
            self.new_value("%s->w[%d]" % (name, i)) for i in range(words)])

    def const(self, value):
        """Return a one-word Multiprecision holding the constant value."""
        # We only support constants small enough to both fit in a
        # BignumInt (of any size supported) _and_ be expressible in C
        # with no weird integer literal syntax like a trailing LL.
        #
        # Supporting larger constants would be possible - you could
        # break 'value' up into word-sized pieces on the Python side,
        # and generate a legal C expression for each piece by
        # splitting it further into pieces within the
        # standards-guaranteed 'unsigned long' limit of 32 bits and
        # then casting those to BignumInt before combining them with
        # shifts. But it would be a lot of effort, and since the
        # application for this code doesn't even need it, there's no
        # point in bothering.
        assert value < 2**16
        return Multiprecision(self, value, value, ["%d" % value])

    def current_carry(self):
        """Return the name of the most recently produced carry."""
        return "carry%d" % self.carry_index

    def add(self, a1, a2):
        """Emit a1 + a2 with no carry-in; return the sum's variable name.

        A new carry name becomes the current carry. The generated C uses
        a single 'carry' variable; the numbered names exist only so the
        dependency analysis can tell successive carries apart. BignumADC
        is emitted only if the carry-out is used.
        """
        ret = self.new_value()
        adcform = "BignumADC(%s, carry, %s, %s, 0)" % (ret, a1, a2)
        plainform = "%s = %s + %s" % (ret, a1, a2)
        self.carry_index += 1
        carryout = self.current_carry()
        self.stmt(Statement(
            rvars=[a1,a2], wvars=[ret,carryout],
            forms=[None, adcform, plainform, adcform]))
        return ret

    def adc(self, a1, a2):
        """Emit a1 + a2 + current carry; return the sum's variable name.

        Reads the current carry and makes a new carry name current.
        """
        ret = self.new_value()
        adcform = "BignumADC(%s, carry, %s, %s, carry)" % (ret, a1, a2)
        plainform = "%s = %s + %s + carry" % (ret, a1, a2)
        carryin = self.current_carry()
        self.carry_index += 1
        carryout = self.current_carry()
        self.stmt(Statement(
            rvars=[a1,a2,carryin], wvars=[ret,carryout],
            forms=[None, adcform, plainform, adcform]))
        return ret

    def muladd(self, m1, m2, *addends):
        """Emit m1*m2 plus 0-2 addends; return (hi, lo) variable names.

        If only the low half is used, a plain single-width expression is
        emitted instead of BignumMUL/BignumMULADD/BignumMULADD2.
        """
        rlo = self.new_value()
        rhi = self.new_value()
        wideform = "BignumMUL%s(%s)" % (
            { 0:"", 1:"ADD", 2:"ADD2" }[len(addends)],
            ", ".join([rhi, rlo, m1, m2] + list(addends)))
        narrowform = " + ".join(["%s = %s * %s" % (rlo, m1, m2)] +
                                list(addends))
        self.stmt(Statement(
            rvars=[m1,m2]+list(addends), wvars=[rhi,rlo],
            forms=[None, narrowform, wideform, wideform]))
        return rhi, rlo

    def write_bigval(self, name, val):
        """Emit name->w[i] = word for all bv_words words of val.

        Words above val's top word are written as "0". These statements
        are marked needed, so they drive the liveness analysis.
        """
        for i in range(self.bv_words):
            word = val.getword(i)
            self.stmt(Statement(
                rvars=[word], wvars=[],
                forms=["%s->w[%d] = %s" % (name, i, word)]),
                needed=True)

    def compute_needed(self):
        """Decide which statements to emit and in which form.

        Starting from the statements marked needed, walk the dependency
        graph marking every statement whose output is read. Then, for
        each needed statement, pick the form matching the set of its
        outputs that are actually read.

        Returns (used_carry, used_vars, forms): whether any carry is
        used, the "v" variables to declare (in numeric order), and the
        C statements to emit, in order.
        """
        used_vars = set()

        queue = deque(stmt for (needed, stmt) in self.stmts if needed)
        while queue:
            stmt = queue.popleft()
            deps = []
            for var in stmt.rvars:
                if var[0] in string.digits:
                    continue # constant
                deps.append(self.generators[var])
                used_vars.add(var)
            for index in deps:
                if not self.stmts[index][0]:
                    self.stmts[index][0] = True
                    queue.append(self.stmts[index][1])

        forms = []
        for needed, stmt in self.stmts:
            if not needed:
                continue

            # Bitmap of which outputs are used, wvars[0] as the MSB.
            formindex = 0
            for var in stmt.wvars:
                formindex *= 2
                if var in used_vars:
                    formindex += 1
            chosen = stmt.forms[formindex]
            forms.append(chosen)

            # Now we must check whether this form of the statement
            # also writes some variables we _don't_ actually need
            # (e.g. if you only wanted the top half from a mul, or
            # only the carry from an adc, you'd be forced to
            # generate the other output too). Easiest way to do
            # this is to look for an identical statement form
            # later in the array.
            maxindex = max(k for k, form in enumerate(stmt.forms)
                           if form == chosen)
            extra_vars = maxindex & ~formindex
            bitpos = 0
            while extra_vars != 0:
                if extra_vars & (1 << bitpos):
                    extra_vars &= ~(1 << bitpos)
                    var = stmt.wvars[-1-bitpos]
                    # It is still written, so it must be declared.
                    used_vars.add(var)
                    # Also, write out a cast-to-void for each
                    # subsequently unused value, to prevent gcc
                    # warnings when the output code is compiled.
                    forms.append("(void)" + var)
                bitpos += 1

        used_carry = any(v.startswith("carry") for v in used_vars)
        used_vars = sorted((v for v in used_vars if v.startswith("v")),
                           key=lambda v: int(v[1:]))

        return used_carry, used_vars, forms

    def text(self):
        """Return the function body: declarations, a blank line, code.

        Variable declarations are packed several per line, keeping each
        line shorter than 79 characters.
        """
        used_carry, values, forms = self.compute_needed()

        ret = ""
        while values:
            prefix, sep, suffix = " BignumInt ", ", ", ";"
            currline = values.pop(0)
            while (values and
                   len(prefix+currline+sep+values[0]+suffix) < 79):
                currline += sep + values.pop(0)
            ret += prefix + currline + suffix + "\n"
        if used_carry:
            ret += " BignumCarry carry;\n"
        if ret != "":
            ret += "\n"
        for stmtform in forms:
            ret += " %s;\n" % stmtform
        return ret
|
|
|
|
def gen_add(target):
    """Generate the C function bigval_add(): r = a + b.

    The sum is deliberately NOT reduced mod p. That lets it serve both
    for accumulating the polynomial and for adding on the encrypted
    nonce at the end, where the arithmetic is mod 2^128 rather than
    mod p.
    """
    # One operand may come straight out of the not-completely-reducing
    # multiplication, so allow 3 bits beyond 130 on each input.
    lhs = target.bigval_input("a", 133)
    rhs = target.bigval_input("b", 133)
    target.write_bigval("r", lhs + rhs)

    body = target.text()
    template = ("static void bigval_add(bigval *r, const bigval *a, "
                "const bigval *b)\n"
                "{\n"
                "%s}\n"
                "\n")
    return template % body
|
|
|
|
def gen_mul(target):
    """Generate bigval_mul_mod_p(): r == a*b (mod p), p = 2^130-5.

    The full product is split at bit positions 130 and 260, and the
    upper parts are folded down using 2^130 == 5 and 2^260 == 25
    (mod p). The result is congruent to a*b but not necessarily < p.
    """
    # Neither input is guaranteed fully reduced: the pow5==0 pass can
    # yield a full 130-bit number, and the pow5==1 pass a 130-bit number
    # times 5 plus a possible carry. All of that stays below 2^130 * 8,
    # so both operands are treated as 133-bit values.
    lhs = target.bigval_input("a", 133)
    rhs = target.bigval_input("b", 133)

    product = lhs * rhs
    low = product.extract_bits(0, 130)
    middle = product.extract_bits(130, 130)
    high = product.extract_bits(260)

    # Both scalings are emitted before either addition: the order of
    # these operations fixes the numbering of the generated variables.
    middle_times_5 = target.const(5) * middle
    high_times_25 = target.const(25) * high
    target.write_bigval("r", low + middle_times_5 + high_times_25)

    template = ("static void bigval_mul_mod_p(bigval *r, const bigval *a, "
                "const bigval *b)\n"
                "{\n"
                "%s}\n"
                "\n")
    return template % target.text()
|
|
|
|
def gen_final_reduce(target):
    """Generate bigval_final_reduce(): reduce n fully mod p = 2^130-5.

    The input may have up to 3 bits above bit 130 (it is read as a
    133-bit bigval). The output is the unique residue of n in [0, p).
    """
    # Write n = h*2^130 + l, so n == l + 5h (mod p), and compute
    # k = n + 5*(h+1). Then q = k >> 130 is precisely the multiple of p
    # to subtract from n: it is h+1 when l + 5h >= p and h otherwise.
    # The extra '+5' is essential. Without it, q stays at h whenever
    # l + 5h lands in [p, 2^130), leaving the result >= p; for example,
    # n = p itself would come out as p instead of 0.
    #
    # Subtracting q*p = q*2^130 - 5q is then done by adding 5q and
    # keeping only the low 130 bits.

    a = target.bigval_input("n", 133)
    a1 = a.extract_bits(130, 130)
    k = a + target.const(5) * a1 + target.const(5)
    q = k.extract_bits(130)
    adjusted = a + target.const(5) * q
    ret = adjusted.extract_bits(0, 130)
    target.write_bigval("n", ret)
    return """\
static void bigval_final_reduce(bigval *n)
{
%s}
\n""" % target.text()
|
|
|
|
# Write the generated C to stdout. There is one #if/#elif arm per
# supported BIGNUM_INT_BITS value, each holding bigval_add,
# bigval_mul_mod_p and bigval_final_reduce generated for that word size
# (every function gets its own fresh CodegenTarget). A final #else arm
# makes compilation fail for any other word size.
for position, bits in enumerate([16, 32, 64]):
    if position == 0:
        directive = "#if"
    else:
        directive = "#elif"
    sys.stdout.write("%s BIGNUM_INT_BITS == %d\n\n" % (directive, bits))
    for generate in (gen_add, gen_mul, gen_final_reduce):
        sys.stdout.write(generate(CodegenTarget(bits)))
sys.stdout.write("#else\n"
                 "#error Add another bit count to contrib/make1305.py"
                 " and rerun it\n"
                 "#endif\n")
|