diff --git a/contrib/kh2reg.py b/contrib/kh2reg.py index 3c2a9c5f..6f4b79be 100755 --- a/contrib/kh2reg.py +++ b/contrib/kh2reg.py @@ -18,6 +18,8 @@ import string import re import sys import getopt +import itertools +import collections def winmungestr(s): "Duplicate of PuTTY's mungestr() in winstore.c:1.10 for Registry keys" @@ -59,6 +61,115 @@ def warn(s): output_type = 'windows' +def invert(n, p): + """Compute inverse mod p.""" + if n % p == 0: + raise ZeroDivisionError() + a = n, 1, 0 + b = p, 0, 1 + while b[0]: + q = a[0] // b[0] + a = a[0] - q*b[0], a[1] - q*b[1], a[2] - q*b[2] + b, a = a, b + assert abs(a[0]) == 1 + return a[1]*a[0] + +def jacobi(n,m): + """Compute the Jacobi symbol. + + The special case of this when m is prime is the Legendre symbol, + which is 0 if n is congruent to 0 mod m; 1 if n is congruent to a + non-zero square number mod m; -1 if n is not congruent to any + square mod m. + + """ + assert m & 1 + acc = 1 + while True: + n %= m + if n == 0: + return 0 + while not (n & 1): + n >>= 1 + if (m & 7) not in {1,7}: + acc *= -1 + if n == 1: + return acc + if (n & 3) == 3 and (m & 3) == 3: + acc *= -1 + n, m = m, n + +class SqrtModP(object): + """Class for finding square roots of numbers mod p. + + p must be an odd prime (but its primality is not checked).""" + + def __init__(self, p): + p = abs(p) + assert p & 1 + self.p = p + + # Decompose p as 2^e k + 1 for odd k. + self.k = p-1 + self.e = 0 + while not (self.k & 1): + self.k >>= 1 + self.e += 1 + + # Find a non-square mod p. + for self.z in itertools.count(1): + if jacobi(self.z, self.p) == -1: + break + self.zinv = invert(self.z, self.p) + + def sqrt_recurse(self, a): + ak = pow(a, self.k, self.p) + for i in range(self.e, -1, -1): + if ak == 1: + break + ak = ak*ak % self.p + assert i > 0 + if i == self.e: + return pow(a, (self.k+1) // 2, self.p) + r_prime = self.sqrt_recurse(a * pow(self.z, 2**i, self.p)) + return r_prime * pow(self.zinv, 2**(i-1), self.p) % self.p + + def sqrt(self, a): + j = jacobi(a, self.p) + if j == 0: + return 0 + if j < 0: + raise ValueError("{} has no square root mod {}".format(a, self.p)) + a %= self.p + r = self.sqrt_recurse(a) + assert r*r % self.p == a + # Normalise to the smaller (or 'positive') one of the two roots. + return min(r, self.p - r) + + def __str__(self): + return "{}({})".format(type(self).__name__, self.p) + def __repr__(self): + return self.__str__() + + instances = {} + + @classmethod + def make(cls, p): + if p not in cls.instances: + cls.instances[p] = cls(p) + return cls.instances[p] + + @classmethod + def root(cls, n, p): + return cls.make(p).sqrt(n) + +NistCurve = collections.namedtuple("NistCurve", "p a b") +nist_curves = { + "ecdsa-sha2-nistp256": NistCurve(0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff, 0xffffffff00000001000000000000000000000000fffffffffffffffffffffffc, 0x5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b), + "ecdsa-sha2-nistp384": NistCurve(0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000ffffffff, 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000fffffffc, 0xb3312fa7e23ee7e4988e056be3f82d19181d9c6efe8141120314088f5013875ac656398d8a2ed19d2a85c8edd3ec2aef), + "ecdsa-sha2-nistp521": NistCurve(0x01ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff, 0x01fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffc, 0x0051953eb9618e1c9a1f929a21a0b68540eea2da725b99b315f3b8b489918ef109e156193951ec7e937b1652c0bd3bb1bf073573df883d2c34f1ef451fd46b503f00), +} + try: optlist, args = getopt.getopt(sys.argv[1:], '', [ 'win', 'unix' ]) if filter(lambda x: x[0] == '--unix', optlist): @@ -151,9 +262,7 @@ for line in fileinput.input(args): # Same again. keyparams = map (strtolong, subfields[1:]) - elif sshkeytype == "ecdsa-sha2-nistp256" \ - or sshkeytype == "ecdsa-sha2-nistp384" \ - or sshkeytype == "ecdsa-sha2-nistp521": + elif sshkeytype in nist_curves: keytype = sshkeytype # Have to parse this a bit. if len(subfields) > 3: @@ -166,16 +275,28 @@ for line in fileinput.input(args): % (sshkeytype, curvename)) # Second contains key material X and Y (hopefully). # First a magic octet indicating point compression. - if struct.unpack("B", Q[0])[0] != 4: - # No-one seems to use this. - raise KeyFormatError("can't convert point-compressed ECDSA") - # Then two equal-length bignums (X and Y). - bnlen = len(Q)-1 - if (bnlen % 1) != 0: - raise KeyFormatError("odd-length X+Y") - bnlen = bnlen / 2 - (x,y) = Q[1:bnlen+1], Q[bnlen+1:2*bnlen+1] - keyparams = [curvename] + map (strtolong, [x,y]) + point_type = struct.unpack("B", Q[0])[0] + Qrest = Q[1:] + if point_type == 4: + # Then two equal-length bignums (X and Y). + bnlen = len(Qrest) + if (bnlen % 1) != 0: + raise KeyFormatError("odd-length X+Y") + bnlen = bnlen // 2 + x = strtolong(Qrest[:bnlen]) + y = strtolong(Qrest[bnlen:]) + elif 2 <= point_type <= 3: + # A compressed point just specifies X, and leaves + # Y implicit except for parity, so we have to + # recover it from the curve equation. + curve = nist_curves[sshkeytype] + x = strtolong(Qrest) + yy = (x*x*x + curve.a*x + curve.b) % curve.p + y = SqrtModP.root(yy, curve.p) + if y % 2 != point_type % 2: + y = curve.p - y + + keyparams = [curvename, x, y] elif sshkeytype == "ssh-ed25519": keytype = sshkeytype @@ -197,19 +318,10 @@ for line in fileinput.input(args): d = 0x52036cee2b6ffe738cc740797779e89800700a4d4141d8ab75eb4dca135978a3 # Recover x^2 = (y^2 - 1) / (d y^2 + 1). - # - # With no real time constraints here, it's easier to - # take the inverse of the denominator by raising it to - # the power p-2 (by Fermat's Little Theorem) than - # faffing about with the properly efficient Euclid - # method. - xx = (y*y - 1) * pow(d*y*y + 1, p-2, p) % p + xx = (y*y - 1) * invert(d*y*y + 1, p) % p - # Take the square root, which may require trying twice. - x = pow(xx, (p+3)/8, p) - if pow(x, 2, p) != xx: - x = x * pow(2, (p-1)/4, p) % p - assert pow(x, 2, p) == xx + # Take the square root. + x = SqrtModP.root(xx, p) # Pick the square root of the correct parity. if (x % 2) != x_parity: