mirror of
https://git.tartarus.org/simon/putty.git
synced 2025-01-09 17:38:00 +00:00
kh2reg.py: support ECDSA point compression.
We support it in the ECC code proper these days, as of the bignum
rewrite in commit 25b034ee3
. So we should support it in this auxiliary
script too, and fortunately, there's no real difficulty in doing so
because I already had some Python code kicking around in
test/eccref.py for taking modular square roots.
This commit is contained in:
parent
81be535f67
commit
5a508a84a2
@ -18,6 +18,8 @@ import string
|
||||
import re
|
||||
import sys
|
||||
import getopt
|
||||
import itertools
|
||||
import collections
|
||||
|
||||
def winmungestr(s):
|
||||
"Duplicate of PuTTY's mungestr() in winstore.c:1.10 for Registry keys"
|
||||
@ -59,6 +61,115 @@ def warn(s):
|
||||
|
||||
output_type = 'windows'
|
||||
|
||||
def invert(n, p):
|
||||
"""Compute inverse mod p."""
|
||||
if n % p == 0:
|
||||
raise ZeroDivisionError()
|
||||
a = n, 1, 0
|
||||
b = p, 0, 1
|
||||
while b[0]:
|
||||
q = a[0] // b[0]
|
||||
a = a[0] - q*b[0], a[1] - q*b[1], a[2] - q*b[2]
|
||||
b, a = a, b
|
||||
assert abs(a[0]) == 1
|
||||
return a[1]*a[0]
|
||||
|
||||
def jacobi(n,m):
|
||||
"""Compute the Jacobi symbol.
|
||||
|
||||
The special case of this when m is prime is the Legendre symbol,
|
||||
which is 0 if n is congruent to 0 mod m; 1 if n is congruent to a
|
||||
non-zero square number mod m; -1 if n is not congruent to any
|
||||
square mod m.
|
||||
|
||||
"""
|
||||
assert m & 1
|
||||
acc = 1
|
||||
while True:
|
||||
n %= m
|
||||
if n == 0:
|
||||
return 0
|
||||
while not (n & 1):
|
||||
n >>= 1
|
||||
if (m & 7) not in {1,7}:
|
||||
acc *= -1
|
||||
if n == 1:
|
||||
return acc
|
||||
if (n & 3) == 3 and (m & 3) == 3:
|
||||
acc *= -1
|
||||
n, m = m, n
|
||||
|
||||
class SqrtModP(object):
|
||||
"""Class for finding square roots of numbers mod p.
|
||||
|
||||
p must be an odd prime (but its primality is not checked)."""
|
||||
|
||||
def __init__(self, p):
|
||||
p = abs(p)
|
||||
assert p & 1
|
||||
self.p = p
|
||||
|
||||
# Decompose p as 2^e k + 1 for odd k.
|
||||
self.k = p-1
|
||||
self.e = 0
|
||||
while not (self.k & 1):
|
||||
self.k >>= 1
|
||||
self.e += 1
|
||||
|
||||
# Find a non-square mod p.
|
||||
for self.z in itertools.count(1):
|
||||
if jacobi(self.z, self.p) == -1:
|
||||
break
|
||||
self.zinv = invert(self.z, self.p)
|
||||
|
||||
def sqrt_recurse(self, a):
|
||||
ak = pow(a, self.k, self.p)
|
||||
for i in range(self.e, -1, -1):
|
||||
if ak == 1:
|
||||
break
|
||||
ak = ak*ak % self.p
|
||||
assert i > 0
|
||||
if i == self.e:
|
||||
return pow(a, (self.k+1) // 2, self.p)
|
||||
r_prime = self.sqrt_recurse(a * pow(self.z, 2**i, self.p))
|
||||
return r_prime * pow(self.zinv, 2**(i-1), self.p) % self.p
|
||||
|
||||
def sqrt(self, a):
|
||||
j = jacobi(a, self.p)
|
||||
if j == 0:
|
||||
return 0
|
||||
if j < 0:
|
||||
raise ValueError("{} has no square root mod {}".format(a, self.p))
|
||||
a %= self.p
|
||||
r = self.sqrt_recurse(a)
|
||||
assert r*r % self.p == a
|
||||
# Normalise to the smaller (or 'positive') one of the two roots.
|
||||
return min(r, self.p - r)
|
||||
|
||||
def __str__(self):
|
||||
return "{}({})".format(type(self).__name__, self.p)
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
instances = {}
|
||||
|
||||
@classmethod
|
||||
def make(cls, p):
|
||||
if p not in cls.instances:
|
||||
cls.instances[p] = cls(p)
|
||||
return cls.instances[p]
|
||||
|
||||
@classmethod
|
||||
def root(cls, n, p):
|
||||
return cls.make(p).sqrt(n)
|
||||
|
||||
NistCurve = collections.namedtuple("NistCurve", "p a b")
|
||||
nist_curves = {
|
||||
"ecdsa-sha2-nistp256": NistCurve(0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff, 0xffffffff00000001000000000000000000000000fffffffffffffffffffffffc, 0x5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b),
|
||||
"ecdsa-sha2-nistp384": NistCurve(0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000ffffffff, 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000fffffffc, 0xb3312fa7e23ee7e4988e056be3f82d19181d9c6efe8141120314088f5013875ac656398d8a2ed19d2a85c8edd3ec2aef),
|
||||
"ecdsa-sha2-nistp521": NistCurve(0x01ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff, 0x01fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffc, 0x0051953eb9618e1c9a1f929a21a0b68540eea2da725b99b315f3b8b489918ef109e156193951ec7e937b1652c0bd3bb1bf073573df883d2c34f1ef451fd46b503f00),
|
||||
}
|
||||
|
||||
try:
|
||||
optlist, args = getopt.getopt(sys.argv[1:], '', [ 'win', 'unix' ])
|
||||
if filter(lambda x: x[0] == '--unix', optlist):
|
||||
@ -151,9 +262,7 @@ for line in fileinput.input(args):
|
||||
# Same again.
|
||||
keyparams = map (strtolong, subfields[1:])
|
||||
|
||||
elif sshkeytype == "ecdsa-sha2-nistp256" \
|
||||
or sshkeytype == "ecdsa-sha2-nistp384" \
|
||||
or sshkeytype == "ecdsa-sha2-nistp521":
|
||||
elif sshkeytype in nist_curves:
|
||||
keytype = sshkeytype
|
||||
# Have to parse this a bit.
|
||||
if len(subfields) > 3:
|
||||
@ -166,16 +275,28 @@ for line in fileinput.input(args):
|
||||
% (sshkeytype, curvename))
|
||||
# Second contains key material X and Y (hopefully).
|
||||
# First a magic octet indicating point compression.
|
||||
if struct.unpack("B", Q[0])[0] != 4:
|
||||
# No-one seems to use this.
|
||||
raise KeyFormatError("can't convert point-compressed ECDSA")
|
||||
# Then two equal-length bignums (X and Y).
|
||||
bnlen = len(Q)-1
|
||||
if (bnlen % 1) != 0:
|
||||
raise KeyFormatError("odd-length X+Y")
|
||||
bnlen = bnlen / 2
|
||||
(x,y) = Q[1:bnlen+1], Q[bnlen+1:2*bnlen+1]
|
||||
keyparams = [curvename] + map (strtolong, [x,y])
|
||||
point_type = struct.unpack("B", Q[0])[0]
|
||||
Qrest = Q[1:]
|
||||
if point_type == 4:
|
||||
# Then two equal-length bignums (X and Y).
|
||||
bnlen = len(Qrest)
|
||||
if (bnlen % 1) != 0:
|
||||
raise KeyFormatError("odd-length X+Y")
|
||||
bnlen = bnlen // 2
|
||||
x = strtolong(Qrest[:bnlen])
|
||||
y = strtolong(Qrest[bnlen:])
|
||||
elif 2 <= point_type <= 3:
|
||||
# A compressed point just specifies X, and leaves
|
||||
# Y implicit except for parity, so we have to
|
||||
# recover it from the curve equation.
|
||||
curve = nist_curves[sshkeytype]
|
||||
x = strtolong(Qrest)
|
||||
yy = (x*x*x + curve.a*x + curve.b) % curve.p
|
||||
y = SqrtModP.root(yy, curve.p)
|
||||
if y % 2 != point_type % 2:
|
||||
y = curve.p - y
|
||||
|
||||
keyparams = [curvename, x, y]
|
||||
|
||||
elif sshkeytype == "ssh-ed25519":
|
||||
keytype = sshkeytype
|
||||
@ -197,19 +318,10 @@ for line in fileinput.input(args):
|
||||
d = 0x52036cee2b6ffe738cc740797779e89800700a4d4141d8ab75eb4dca135978a3
|
||||
|
||||
# Recover x^2 = (y^2 - 1) / (d y^2 + 1).
|
||||
#
|
||||
# With no real time constraints here, it's easier to
|
||||
# take the inverse of the denominator by raising it to
|
||||
# the power p-2 (by Fermat's Little Theorem) than
|
||||
# faffing about with the properly efficient Euclid
|
||||
# method.
|
||||
xx = (y*y - 1) * pow(d*y*y + 1, p-2, p) % p
|
||||
xx = (y*y - 1) * invert(d*y*y + 1, p) % p
|
||||
|
||||
# Take the square root, which may require trying twice.
|
||||
x = pow(xx, (p+3)/8, p)
|
||||
if pow(x, 2, p) != xx:
|
||||
x = x * pow(2, (p-1)/4, p) % p
|
||||
assert pow(x, 2, p) == xx
|
||||
# Take the square root.
|
||||
x = SqrtModP.root(xx, p)
|
||||
|
||||
# Pick the square root of the correct parity.
|
||||
if (x % 2) != x_parity:
|
||||
|
Loading…
Reference in New Issue
Block a user