1
0
mirror of https://git.tartarus.org/simon/putty.git synced 2025-03-23 15:09:24 -05:00

kh2reg.py: work with Python 3.

This commit is contained in:
Simon Tatham 2019-04-21 13:55:22 +01:00
parent 5a508a84a2
commit 33d4d223a5

View File

@ -8,8 +8,7 @@
# kh2reg.py --unix known_hosts1 2 3 4 ... > sshhostkeys # kh2reg.py --unix known_hosts1 2 3 4 ... > sshhostkeys
# Creates data suitable for storing in ~/.putty/sshhostkeys (Unix). # Creates data suitable for storing in ~/.putty/sshhostkeys (Unix).
# Line endings are someone else's problem as is traditional. # Line endings are someone else's problem as is traditional.
# Originally developed for Python 1.5.2, but probably won't run on that # Should run under either Python 2 or 3.
# any more.
import fileinput import fileinput
import base64 import base64
@ -20,6 +19,7 @@ import sys
import getopt import getopt
import itertools import itertools
import collections import collections
from functools import reduce
def winmungestr(s): def winmungestr(s):
"Duplicate of PuTTY's mungestr() in winstore.c:1.10 for Registry keys" "Duplicate of PuTTY's mungestr() in winstore.c:1.10 for Registry keys"
@ -33,26 +33,19 @@ def winmungestr(s):
candot = 1 candot = 1
return r return r
def strtolong(s): def strtoint(s):
"Convert arbitrary-length big-endian binary data to a Python long" "Convert arbitrary-length big-endian binary data to a Python int"
bytes = struct.unpack(">%luB" % len(s), s) bytes = struct.unpack(">{:d}B".format(len(s)), s)
return reduce ((lambda a, b: (long(a) << 8) + long(b)), bytes) return reduce ((lambda a, b: (int(a) << 8) + int(b)), bytes)
def strtolong_le(s): def strtoint_le(s):
"Convert arbitrary-length little-endian binary data to a Python long" "Convert arbitrary-length little-endian binary data to a Python int"
bytes = reversed(struct.unpack(">%luB" % len(s), s)) bytes = reversed(struct.unpack(">{:d}B".format(len(s)), s))
return reduce ((lambda a, b: (long(a) << 8) + long(b)), bytes) return reduce ((lambda a, b: (int(a) << 8) + int(b)), bytes)
def longtohex(n): def inttohex(n):
"""Convert long int to lower-case hex. "Convert int to lower-case hex."
return "0x{:x}".format(n)
Ick, Python (at least in 1.5.2) doesn't appear to have a way to
turn a long int into an unadorned hex string -- % gets upset if the
number is too big, and raw hex() uses uppercase (sometimes), and
adds unwanted "0x...L" around it."""
plain=string.lower(re.match(r"0x([0-9A-Fa-f]*)l?$", hex(n), re.I).group(1))
return "0x" + plain
def warn(s): def warn(s):
"Warning with file/line number" "Warning with file/line number"
@ -172,9 +165,9 @@ nist_curves = {
try: try:
optlist, args = getopt.getopt(sys.argv[1:], '', [ 'win', 'unix' ]) optlist, args = getopt.getopt(sys.argv[1:], '', [ 'win', 'unix' ])
if filter(lambda x: x[0] == '--unix', optlist): if any(x[0] == '--unix' for x in optlist):
output_type = 'unix' output_type = 'unix'
except getopt.error, e: except getopt.error as e:
sys.stderr.write(str(e) + "\n") sys.stderr.write(str(e) + "\n")
sys.exit(1) sys.exit(1)
@ -201,14 +194,14 @@ for line in fileinput.input(args):
try: try:
# Remove leading/trailing whitespace (should zap CR and LF) # Remove leading/trailing whitespace (should zap CR and LF)
line = string.strip (line) line = line.strip()
# Skip blanks and comments # Skip blanks and comments
if line == '' or line[0] == '#': if line == '' or line[0] == '#':
raise BlankInputLine raise BlankInputLine
# Split line on spaces. # Split line on spaces.
fields = string.split (line, ' ') fields = line.split(' ')
# Common fields # Common fields
hostpat = fields[0] hostpat = fields[0]
@ -222,14 +215,15 @@ for line in fileinput.input(args):
# Treat as SSH-1-type host key. # Treat as SSH-1-type host key.
# Format: hostpat bits10 exp10 mod10 comment... # Format: hostpat bits10 exp10 mod10 comment...
# (PuTTY doesn't store the number of bits.) # (PuTTY doesn't store the number of bits.)
keyparams = map (long, fields[2:4]) keyparams = map (int, fields[2:4])
keytype = "rsa" keytype = "rsa"
else: else:
# Treat as SSH-2-type host key. # Treat as SSH-2-type host key.
# Format: hostpat keytype keyblob64 comment... # Format: hostpat keytype keyblob64 comment...
sshkeytype, blob = fields[1], base64.decodestring (fields[2]) sshkeytype, blob = fields[1], base64.decodestring(
fields[2].encode("ASCII"))
# 'blob' consists of a number of # 'blob' consists of a number of
# uint32 N (big-endian) # uint32 N (big-endian)
@ -244,7 +238,7 @@ for line in fileinput.input(args):
blob = blob [struct.calcsize(sizefmt) + size : ] blob = blob [struct.calcsize(sizefmt) + size : ]
# The first field is keytype again. # The first field is keytype again.
if subfields[0] != sshkeytype: if subfields[0].decode("ASCII") != sshkeytype:
raise KeyFormatError(""" raise KeyFormatError("""
outer and embedded key types do not match: '%s', '%s' outer and embedded key types do not match: '%s', '%s'
""" % (sshkeytype, subfields[1])) """ % (sshkeytype, subfields[1]))
@ -255,12 +249,12 @@ for line in fileinput.input(args):
keytype = "rsa2" keytype = "rsa2"
# The rest of the subfields we can treat as an opaque list # The rest of the subfields we can treat as an opaque list
# of bignums (same numbers and order as stored by PuTTY). # of bignums (same numbers and order as stored by PuTTY).
keyparams = map (strtolong, subfields[1:]) keyparams = map (strtoint, subfields[1:])
elif sshkeytype == "ssh-dss": elif sshkeytype == "ssh-dss":
keytype = "dss" keytype = "dss"
# Same again. # Same again.
keyparams = map (strtolong, subfields[1:]) keyparams = map (strtoint, subfields[1:])
elif sshkeytype in nist_curves: elif sshkeytype in nist_curves:
keytype = sshkeytype keytype = sshkeytype
@ -269,13 +263,13 @@ for line in fileinput.input(args):
raise KeyFormatError("too many subfields in blob") raise KeyFormatError("too many subfields in blob")
(curvename, Q) = subfields[1:] (curvename, Q) = subfields[1:]
# First is yet another copy of the key name. # First is yet another copy of the key name.
if not re.match("ecdsa-sha2-" + re.escape(curvename), if not re.match("ecdsa-sha2-" + re.escape(
sshkeytype): curvename.decode("ASCII")), sshkeytype):
raise KeyFormatError("key type mismatch ('%s' vs '%s')" raise KeyFormatError("key type mismatch ('%s' vs '%s')"
% (sshkeytype, curvename)) % (sshkeytype, curvename))
# Second contains key material X and Y (hopefully). # Second contains key material X and Y (hopefully).
# First a magic octet indicating point compression. # First a magic octet indicating point compression.
point_type = struct.unpack("B", Q[0])[0] point_type = struct.unpack_from("B", Q, 0)[0]
Qrest = Q[1:] Qrest = Q[1:]
if point_type == 4: if point_type == 4:
# Then two equal-length bignums (X and Y). # Then two equal-length bignums (X and Y).
@ -283,14 +277,14 @@ for line in fileinput.input(args):
if (bnlen % 1) != 0: if (bnlen % 1) != 0:
raise KeyFormatError("odd-length X+Y") raise KeyFormatError("odd-length X+Y")
bnlen = bnlen // 2 bnlen = bnlen // 2
x = strtolong(Qrest[:bnlen]) x = strtoint(Qrest[:bnlen])
y = strtolong(Qrest[bnlen:]) y = strtoint(Qrest[bnlen:])
elif 2 <= point_type <= 3: elif 2 <= point_type <= 3:
# A compressed point just specifies X, and leaves # A compressed point just specifies X, and leaves
# Y implicit except for parity, so we have to # Y implicit except for parity, so we have to
# recover it from the curve equation. # recover it from the curve equation.
curve = nist_curves[sshkeytype] curve = nist_curves[sshkeytype]
x = strtolong(Qrest) x = strtoint(Qrest)
yy = (x*x*x + curve.a*x + curve.b) % curve.p yy = (x*x*x + curve.a*x + curve.b) % curve.p
y = SqrtModP.root(yy, curve.p) y = SqrtModP.root(yy, curve.p)
if y % 2 != point_type % 2: if y % 2 != point_type % 2:
@ -303,13 +297,10 @@ for line in fileinput.input(args):
if len(subfields) != 2: if len(subfields) != 2:
raise KeyFormatError("wrong number of subfields in blob") raise KeyFormatError("wrong number of subfields in blob")
if subfields[0] != sshkeytype:
raise KeyFormatError("key type mismatch ('%s' vs '%s')"
% (sshkeytype, subfields[0]))
# Key material y, with the top bit being repurposed as # Key material y, with the top bit being repurposed as
# the expected parity of the associated x (point # the expected parity of the associated x (point
# compression). # compression).
y = strtolong_le(subfields[1]) y = strtoint_le(subfields[1])
x_parity = y >> 255 x_parity = y >> 255
y &= ~(1 << 255) y &= ~(1 << 255)
@ -332,7 +323,7 @@ for line in fileinput.input(args):
raise UnknownKeyType(sshkeytype) raise UnknownKeyType(sshkeytype)
# Now print out one line per host pattern, discarding wildcards. # Now print out one line per host pattern, discarding wildcards.
for host in string.split (hostpat, ','): for host in hostpat.split(','):
if re.search (r"[*?!]", host): if re.search (r"[*?!]", host):
warn("skipping wildcard host pattern '%s'" % host) warn("skipping wildcard host pattern '%s'" % host)
continue continue
@ -351,9 +342,10 @@ for line in fileinput.input(args):
key = keytype + ("@%d:%s" % (port, host)) key = keytype + ("@%d:%s" % (port, host))
# Most of these are numbers, but there's the occasional # Most of these are numbers, but there's the occasional
# string that needs passing through # string that needs passing through
value = string.join (map ( value = ",".join(map(
lambda x: x if isinstance(x, basestring) else longtohex(x), lambda x: x if isinstance(x, str)
keyparams), ',') else x.decode('ASCII') if isinstance(x, bytes)
else inttohex(x), keyparams))
if output_type == 'unix': if output_type == 'unix':
# Unix format. # Unix format.
sys.stdout.write('%s %s\n' % (key, value)) sys.stdout.write('%s %s\n' % (key, value))
@ -363,9 +355,9 @@ for line in fileinput.input(args):
sys.stdout.write("\"%s\"=\"%s\"\n" sys.stdout.write("\"%s\"=\"%s\"\n"
% (winmungestr(key), value)) % (winmungestr(key), value))
except UnknownKeyType, k: except UnknownKeyType as k:
warn("unknown SSH key type '%s', skipping" % k.keytype) warn("unknown SSH key type '%s', skipping" % k.keytype)
except KeyFormatError, k: except KeyFormatError as k:
warn("trouble parsing key (%s), skipping" % k.msg) warn("trouble parsing key (%s), skipping" % k.msg)
except BlankInputLine: except BlankInputLine:
pass pass