mirror of
https://git.tartarus.org/simon/putty.git
synced 2025-01-10 01:48:00 +00:00
kh2reg.py: support ECDSA point compression.
We support it in the ECC code proper these days, as of the bignum
rewrite in commit 25b034ee3
. So we should support it in this auxiliary
script too, and fortunately, there's no real difficulty in doing so
because I already had some Python code kicking around in
test/eccref.py for taking modular square roots.
This commit is contained in:
parent
81be535f67
commit
5a508a84a2
@ -18,6 +18,8 @@ import string
|
|||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import getopt
|
import getopt
|
||||||
|
import itertools
|
||||||
|
import collections
|
||||||
|
|
||||||
def winmungestr(s):
|
def winmungestr(s):
|
||||||
"Duplicate of PuTTY's mungestr() in winstore.c:1.10 for Registry keys"
|
"Duplicate of PuTTY's mungestr() in winstore.c:1.10 for Registry keys"
|
||||||
@ -59,6 +61,115 @@ def warn(s):
|
|||||||
|
|
||||||
output_type = 'windows'
|
output_type = 'windows'
|
||||||
|
|
||||||
|
def invert(n, p):
|
||||||
|
"""Compute inverse mod p."""
|
||||||
|
if n % p == 0:
|
||||||
|
raise ZeroDivisionError()
|
||||||
|
a = n, 1, 0
|
||||||
|
b = p, 0, 1
|
||||||
|
while b[0]:
|
||||||
|
q = a[0] // b[0]
|
||||||
|
a = a[0] - q*b[0], a[1] - q*b[1], a[2] - q*b[2]
|
||||||
|
b, a = a, b
|
||||||
|
assert abs(a[0]) == 1
|
||||||
|
return a[1]*a[0]
|
||||||
|
|
||||||
|
def jacobi(n,m):
|
||||||
|
"""Compute the Jacobi symbol.
|
||||||
|
|
||||||
|
The special case of this when m is prime is the Legendre symbol,
|
||||||
|
which is 0 if n is congruent to 0 mod m; 1 if n is congruent to a
|
||||||
|
non-zero square number mod m; -1 if n is not congruent to any
|
||||||
|
square mod m.
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert m & 1
|
||||||
|
acc = 1
|
||||||
|
while True:
|
||||||
|
n %= m
|
||||||
|
if n == 0:
|
||||||
|
return 0
|
||||||
|
while not (n & 1):
|
||||||
|
n >>= 1
|
||||||
|
if (m & 7) not in {1,7}:
|
||||||
|
acc *= -1
|
||||||
|
if n == 1:
|
||||||
|
return acc
|
||||||
|
if (n & 3) == 3 and (m & 3) == 3:
|
||||||
|
acc *= -1
|
||||||
|
n, m = m, n
|
||||||
|
|
||||||
|
class SqrtModP(object):
|
||||||
|
"""Class for finding square roots of numbers mod p.
|
||||||
|
|
||||||
|
p must be an odd prime (but its primality is not checked)."""
|
||||||
|
|
||||||
|
def __init__(self, p):
|
||||||
|
p = abs(p)
|
||||||
|
assert p & 1
|
||||||
|
self.p = p
|
||||||
|
|
||||||
|
# Decompose p as 2^e k + 1 for odd k.
|
||||||
|
self.k = p-1
|
||||||
|
self.e = 0
|
||||||
|
while not (self.k & 1):
|
||||||
|
self.k >>= 1
|
||||||
|
self.e += 1
|
||||||
|
|
||||||
|
# Find a non-square mod p.
|
||||||
|
for self.z in itertools.count(1):
|
||||||
|
if jacobi(self.z, self.p) == -1:
|
||||||
|
break
|
||||||
|
self.zinv = invert(self.z, self.p)
|
||||||
|
|
||||||
|
def sqrt_recurse(self, a):
|
||||||
|
ak = pow(a, self.k, self.p)
|
||||||
|
for i in range(self.e, -1, -1):
|
||||||
|
if ak == 1:
|
||||||
|
break
|
||||||
|
ak = ak*ak % self.p
|
||||||
|
assert i > 0
|
||||||
|
if i == self.e:
|
||||||
|
return pow(a, (self.k+1) // 2, self.p)
|
||||||
|
r_prime = self.sqrt_recurse(a * pow(self.z, 2**i, self.p))
|
||||||
|
return r_prime * pow(self.zinv, 2**(i-1), self.p) % self.p
|
||||||
|
|
||||||
|
def sqrt(self, a):
|
||||||
|
j = jacobi(a, self.p)
|
||||||
|
if j == 0:
|
||||||
|
return 0
|
||||||
|
if j < 0:
|
||||||
|
raise ValueError("{} has no square root mod {}".format(a, self.p))
|
||||||
|
a %= self.p
|
||||||
|
r = self.sqrt_recurse(a)
|
||||||
|
assert r*r % self.p == a
|
||||||
|
# Normalise to the smaller (or 'positive') one of the two roots.
|
||||||
|
return min(r, self.p - r)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "{}({})".format(type(self).__name__, self.p)
|
||||||
|
def __repr__(self):
|
||||||
|
return self.__str__()
|
||||||
|
|
||||||
|
instances = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def make(cls, p):
|
||||||
|
if p not in cls.instances:
|
||||||
|
cls.instances[p] = cls(p)
|
||||||
|
return cls.instances[p]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def root(cls, n, p):
|
||||||
|
return cls.make(p).sqrt(n)
|
||||||
|
|
||||||
|
NistCurve = collections.namedtuple("NistCurve", "p a b")
|
||||||
|
nist_curves = {
|
||||||
|
"ecdsa-sha2-nistp256": NistCurve(0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff, 0xffffffff00000001000000000000000000000000fffffffffffffffffffffffc, 0x5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b),
|
||||||
|
"ecdsa-sha2-nistp384": NistCurve(0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000ffffffff, 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000fffffffc, 0xb3312fa7e23ee7e4988e056be3f82d19181d9c6efe8141120314088f5013875ac656398d8a2ed19d2a85c8edd3ec2aef),
|
||||||
|
"ecdsa-sha2-nistp521": NistCurve(0x01ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff, 0x01fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffc, 0x0051953eb9618e1c9a1f929a21a0b68540eea2da725b99b315f3b8b489918ef109e156193951ec7e937b1652c0bd3bb1bf073573df883d2c34f1ef451fd46b503f00),
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
optlist, args = getopt.getopt(sys.argv[1:], '', [ 'win', 'unix' ])
|
optlist, args = getopt.getopt(sys.argv[1:], '', [ 'win', 'unix' ])
|
||||||
if filter(lambda x: x[0] == '--unix', optlist):
|
if filter(lambda x: x[0] == '--unix', optlist):
|
||||||
@ -151,9 +262,7 @@ for line in fileinput.input(args):
|
|||||||
# Same again.
|
# Same again.
|
||||||
keyparams = map (strtolong, subfields[1:])
|
keyparams = map (strtolong, subfields[1:])
|
||||||
|
|
||||||
elif sshkeytype == "ecdsa-sha2-nistp256" \
|
elif sshkeytype in nist_curves:
|
||||||
or sshkeytype == "ecdsa-sha2-nistp384" \
|
|
||||||
or sshkeytype == "ecdsa-sha2-nistp521":
|
|
||||||
keytype = sshkeytype
|
keytype = sshkeytype
|
||||||
# Have to parse this a bit.
|
# Have to parse this a bit.
|
||||||
if len(subfields) > 3:
|
if len(subfields) > 3:
|
||||||
@ -166,16 +275,28 @@ for line in fileinput.input(args):
|
|||||||
% (sshkeytype, curvename))
|
% (sshkeytype, curvename))
|
||||||
# Second contains key material X and Y (hopefully).
|
# Second contains key material X and Y (hopefully).
|
||||||
# First a magic octet indicating point compression.
|
# First a magic octet indicating point compression.
|
||||||
if struct.unpack("B", Q[0])[0] != 4:
|
point_type = struct.unpack("B", Q[0])[0]
|
||||||
# No-one seems to use this.
|
Qrest = Q[1:]
|
||||||
raise KeyFormatError("can't convert point-compressed ECDSA")
|
if point_type == 4:
|
||||||
# Then two equal-length bignums (X and Y).
|
# Then two equal-length bignums (X and Y).
|
||||||
bnlen = len(Q)-1
|
bnlen = len(Qrest)
|
||||||
if (bnlen % 1) != 0:
|
if (bnlen % 1) != 0:
|
||||||
raise KeyFormatError("odd-length X+Y")
|
raise KeyFormatError("odd-length X+Y")
|
||||||
bnlen = bnlen / 2
|
bnlen = bnlen // 2
|
||||||
(x,y) = Q[1:bnlen+1], Q[bnlen+1:2*bnlen+1]
|
x = strtolong(Qrest[:bnlen])
|
||||||
keyparams = [curvename] + map (strtolong, [x,y])
|
y = strtolong(Qrest[bnlen:])
|
||||||
|
elif 2 <= point_type <= 3:
|
||||||
|
# A compressed point just specifies X, and leaves
|
||||||
|
# Y implicit except for parity, so we have to
|
||||||
|
# recover it from the curve equation.
|
||||||
|
curve = nist_curves[sshkeytype]
|
||||||
|
x = strtolong(Qrest)
|
||||||
|
yy = (x*x*x + curve.a*x + curve.b) % curve.p
|
||||||
|
y = SqrtModP.root(yy, curve.p)
|
||||||
|
if y % 2 != point_type % 2:
|
||||||
|
y = curve.p - y
|
||||||
|
|
||||||
|
keyparams = [curvename, x, y]
|
||||||
|
|
||||||
elif sshkeytype == "ssh-ed25519":
|
elif sshkeytype == "ssh-ed25519":
|
||||||
keytype = sshkeytype
|
keytype = sshkeytype
|
||||||
@ -197,19 +318,10 @@ for line in fileinput.input(args):
|
|||||||
d = 0x52036cee2b6ffe738cc740797779e89800700a4d4141d8ab75eb4dca135978a3
|
d = 0x52036cee2b6ffe738cc740797779e89800700a4d4141d8ab75eb4dca135978a3
|
||||||
|
|
||||||
# Recover x^2 = (y^2 - 1) / (d y^2 + 1).
|
# Recover x^2 = (y^2 - 1) / (d y^2 + 1).
|
||||||
#
|
xx = (y*y - 1) * invert(d*y*y + 1, p) % p
|
||||||
# With no real time constraints here, it's easier to
|
|
||||||
# take the inverse of the denominator by raising it to
|
|
||||||
# the power p-2 (by Fermat's Little Theorem) than
|
|
||||||
# faffing about with the properly efficient Euclid
|
|
||||||
# method.
|
|
||||||
xx = (y*y - 1) * pow(d*y*y + 1, p-2, p) % p
|
|
||||||
|
|
||||||
# Take the square root, which may require trying twice.
|
# Take the square root.
|
||||||
x = pow(xx, (p+3)/8, p)
|
x = SqrtModP.root(xx, p)
|
||||||
if pow(x, 2, p) != xx:
|
|
||||||
x = x * pow(2, (p-1)/4, p) % p
|
|
||||||
assert pow(x, 2, p) == xx
|
|
||||||
|
|
||||||
# Pick the square root of the correct parity.
|
# Pick the square root of the correct parity.
|
||||||
if (x % 2) != x_parity:
|
if (x % 2) != x_parity:
|
||||||
|
Loading…
Reference in New Issue
Block a user