1
0
mirror of https://git.tartarus.org/simon/putty.git synced 2025-01-10 01:48:00 +00:00
putty-source/contrib/kh2reg.py
Simon Tatham fade8e81bf kh2reg: stop using deprecated base64.decodestring.
Python 3 gave me a warning that I should have been using decodebytes
instead.

(cherry picked from commit 1efded20a1)
2020-06-14 15:49:36 +01:00

441 lines
16 KiB
Python
Executable File

#! /usr/bin/env python
# Convert OpenSSH known_hosts and known_hosts2 files to "new format" PuTTY
# host keys.
# usage:
# kh2reg.py [ --win ] known_hosts1 2 3 4 ... > hosts.reg
# Creates a Windows .REG file (double-click to install).
# kh2reg.py --unix known_hosts1 2 3 4 ... > sshhostkeys
# Creates data suitable for storing in ~/.putty/sshhostkeys (Unix).
# Line endings are someone else's problem as is traditional.
# Should run under either Python 2 or 3.
import fileinput
import base64
import struct
import string
import re
import sys
import argparse
import itertools
import collections
import hashlib
from functools import reduce
def winmungestr(s):
"Duplicate of PuTTY's mungestr() in winstore.c:1.10 for Registry keys"
candot = 0
r = ""
for c in s:
if c in ' \*?%~' or ord(c)<ord(' ') or (c == '.' and not candot):
r = r + ("%%%02X" % ord(c))
else:
r = r + c
candot = 1
return r
def strtoint(s):
"Convert arbitrary-length big-endian binary data to a Python int"
bytes = struct.unpack(">{:d}B".format(len(s)), s)
return reduce ((lambda a, b: (int(a) << 8) + int(b)), bytes)
def strtoint_le(s):
"Convert arbitrary-length little-endian binary data to a Python int"
bytes = reversed(struct.unpack(">{:d}B".format(len(s)), s))
return reduce ((lambda a, b: (int(a) << 8) + int(b)), bytes)
def inttohex(n):
"Convert int to lower-case hex."
return "0x{:x}".format(n)
def warn(s):
"Warning with file/line number"
sys.stderr.write("%s:%d: %s\n"
% (fileinput.filename(), fileinput.filelineno(), s))
class HMAC(object):
def __init__(self, hashclass, blocksize):
self.hashclass = hashclass
self.blocksize = blocksize
self.struct = struct.Struct(">{:d}B".format(self.blocksize))
def pad_key(self, key):
return key + b'\0' * (self.blocksize - len(key))
def xor_key(self, key, xor):
return self.struct.pack(*[b ^ xor for b in self.struct.unpack(key)])
def keyed_hash(self, key, padbyte, string):
return self.hashclass(self.xor_key(key, padbyte) + string).digest()
def compute(self, key, string):
if len(key) > self.blocksize:
key = self.hashclass(key).digest()
key = self.pad_key(key)
return self.keyed_hash(key, 0x5C, self.keyed_hash(key, 0x36, string))
def openssh_hashed_host_match(hashed_host, try_host):
if hashed_host.startswith(b'|1|'):
salt, expected = hashed_host[3:].split(b'|')
salt = base64.decodebytes(salt)
expected = base64.decodebytes(expected)
mac = HMAC(hashlib.sha1, 64)
else:
return False # unrecognised magic number prefix
return mac.compute(salt, try_host) == expected
def invert(n, p):
"""Compute inverse mod p."""
if n % p == 0:
raise ZeroDivisionError()
a = n, 1, 0
b = p, 0, 1
while b[0]:
q = a[0] // b[0]
a = a[0] - q*b[0], a[1] - q*b[1], a[2] - q*b[2]
b, a = a, b
assert abs(a[0]) == 1
return a[1]*a[0]
def jacobi(n,m):
"""Compute the Jacobi symbol.
The special case of this when m is prime is the Legendre symbol,
which is 0 if n is congruent to 0 mod m; 1 if n is congruent to a
non-zero square number mod m; -1 if n is not congruent to any
square mod m.
"""
assert m & 1
acc = 1
while True:
n %= m
if n == 0:
return 0
while not (n & 1):
n >>= 1
if (m & 7) not in {1,7}:
acc *= -1
if n == 1:
return acc
if (n & 3) == 3 and (m & 3) == 3:
acc *= -1
n, m = m, n
class SqrtModP(object):
"""Class for finding square roots of numbers mod p.
p must be an odd prime (but its primality is not checked)."""
def __init__(self, p):
p = abs(p)
assert p & 1
self.p = p
# Decompose p as 2^e k + 1 for odd k.
self.k = p-1
self.e = 0
while not (self.k & 1):
self.k >>= 1
self.e += 1
# Find a non-square mod p.
for self.z in itertools.count(1):
if jacobi(self.z, self.p) == -1:
break
self.zinv = invert(self.z, self.p)
def sqrt_recurse(self, a):
ak = pow(a, self.k, self.p)
for i in range(self.e, -1, -1):
if ak == 1:
break
ak = ak*ak % self.p
assert i > 0
if i == self.e:
return pow(a, (self.k+1) // 2, self.p)
r_prime = self.sqrt_recurse(a * pow(self.z, 2**i, self.p))
return r_prime * pow(self.zinv, 2**(i-1), self.p) % self.p
def sqrt(self, a):
j = jacobi(a, self.p)
if j == 0:
return 0
if j < 0:
raise ValueError("{} has no square root mod {}".format(a, self.p))
a %= self.p
r = self.sqrt_recurse(a)
assert r*r % self.p == a
# Normalise to the smaller (or 'positive') one of the two roots.
return min(r, self.p - r)
def __str__(self):
return "{}({})".format(type(self).__name__, self.p)
def __repr__(self):
return self.__str__()
instances = {}
@classmethod
def make(cls, p):
if p not in cls.instances:
cls.instances[p] = cls(p)
return cls.instances[p]
@classmethod
def root(cls, n, p):
return cls.make(p).sqrt(n)
NistCurve = collections.namedtuple("NistCurve", "p a b")
nist_curves = {
"ecdsa-sha2-nistp256": NistCurve(0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff, 0xffffffff00000001000000000000000000000000fffffffffffffffffffffffc, 0x5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b),
"ecdsa-sha2-nistp384": NistCurve(0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000ffffffff, 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000fffffffc, 0xb3312fa7e23ee7e4988e056be3f82d19181d9c6efe8141120314088f5013875ac656398d8a2ed19d2a85c8edd3ec2aef),
"ecdsa-sha2-nistp521": NistCurve(0x01ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff, 0x01fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffc, 0x0051953eb9618e1c9a1f929a21a0b68540eea2da725b99b315f3b8b489918ef109e156193951ec7e937b1652c0bd3bb1bf073573df883d2c34f1ef451fd46b503f00),
}
class BlankInputLine(Exception):
pass
class UnknownKeyType(Exception):
def __init__(self, keytype):
self.keytype = keytype
class KeyFormatError(Exception):
def __init__(self, msg):
self.msg = msg
def handle_line(line, output_formatter, try_hosts):
try:
# Remove leading/trailing whitespace (should zap CR and LF)
line = line.strip()
# Skip blanks and comments
if line == '' or line[0] == '#':
raise BlankInputLine
# Split line on spaces.
fields = line.split(' ')
# Common fields
hostpat = fields[0]
keyparams = [] # placeholder
keytype = "" # placeholder
# Grotty heuristic to distinguish known_hosts from known_hosts2:
# is second field entirely decimal digits?
if re.match (r"\d*$", fields[1]):
# Treat as SSH-1-type host key.
# Format: hostpat bits10 exp10 mod10 comment...
# (PuTTY doesn't store the number of bits.)
keyparams = list(map(int, fields[2:4]))
keytype = "rsa"
else:
# Treat as SSH-2-type host key.
# Format: hostpat keytype keyblob64 comment...
sshkeytype, blob = fields[1], base64.decodebytes(
fields[2].encode("ASCII"))
# 'blob' consists of a number of
# uint32 N (big-endian)
# uint8[N] field_data
subfields = []
while blob:
sizefmt = ">L"
(size,) = struct.unpack (sizefmt, blob[0:4])
size = int(size) # req'd for slicage
(data,) = struct.unpack (">%lus" % size, blob[4:size+4])
subfields.append(data)
blob = blob [struct.calcsize(sizefmt) + size : ]
# The first field is keytype again.
if subfields[0].decode("ASCII") != sshkeytype:
raise KeyFormatError("""
outer and embedded key types do not match: '%s', '%s'
""" % (sshkeytype, subfields[1]))
# Translate key type string into something PuTTY can use, and
# munge the rest of the data.
if sshkeytype == "ssh-rsa":
keytype = "rsa2"
# The rest of the subfields we can treat as an opaque list
# of bignums (same numbers and order as stored by PuTTY).
keyparams = list(map(strtoint, subfields[1:]))
elif sshkeytype == "ssh-dss":
keytype = "dss"
# Same again.
keyparams = list(map(strtoint, subfields[1:]))
elif sshkeytype in nist_curves:
keytype = sshkeytype
# Have to parse this a bit.
if len(subfields) > 3:
raise KeyFormatError("too many subfields in blob")
(curvename, Q) = subfields[1:]
# First is yet another copy of the key name.
if not re.match("ecdsa-sha2-" + re.escape(
curvename.decode("ASCII")), sshkeytype):
raise KeyFormatError("key type mismatch ('%s' vs '%s')"
% (sshkeytype, curvename))
# Second contains key material X and Y (hopefully).
# First a magic octet indicating point compression.
point_type = struct.unpack_from("B", Q, 0)[0]
Qrest = Q[1:]
if point_type == 4:
# Then two equal-length bignums (X and Y).
bnlen = len(Qrest)
if (bnlen % 1) != 0:
raise KeyFormatError("odd-length X+Y")
bnlen = bnlen // 2
x = strtoint(Qrest[:bnlen])
y = strtoint(Qrest[bnlen:])
elif 2 <= point_type <= 3:
# A compressed point just specifies X, and leaves
# Y implicit except for parity, so we have to
# recover it from the curve equation.
curve = nist_curves[sshkeytype]
x = strtoint(Qrest)
yy = (x*x*x + curve.a*x + curve.b) % curve.p
y = SqrtModP.root(yy, curve.p)
if y % 2 != point_type % 2:
y = curve.p - y
keyparams = [curvename, x, y]
elif sshkeytype == "ssh-ed25519":
keytype = sshkeytype
if len(subfields) != 2:
raise KeyFormatError("wrong number of subfields in blob")
# Key material y, with the top bit being repurposed as
# the expected parity of the associated x (point
# compression).
y = strtoint_le(subfields[1])
x_parity = y >> 255
y &= ~(1 << 255)
# Standard Ed25519 parameters.
p = 2**255 - 19
d = 0x52036cee2b6ffe738cc740797779e89800700a4d4141d8ab75eb4dca135978a3
# Recover x^2 = (y^2 - 1) / (d y^2 + 1).
xx = (y*y - 1) * invert(d*y*y + 1, p) % p
# Take the square root.
x = SqrtModP.root(xx, p)
# Pick the square root of the correct parity.
if (x % 2) != x_parity:
x = p - x
keyparams = [x, y]
else:
raise UnknownKeyType(sshkeytype)
# Now print out one line per host pattern, discarding wildcards.
for host in hostpat.split(','):
if re.search (r"[*?!]", host):
warn("skipping wildcard host pattern '%s'" % host)
continue
if re.match (r"\|", host):
for try_host in try_hosts:
if openssh_hashed_host_match(host.encode('ASCII'),
try_host.encode('UTF-8')):
host = try_host
break
else:
warn("unable to match hashed hostname '%s'" % host)
continue
m = re.match (r"\[([^]]*)\]:(\d*)$", host)
if m:
(host, port) = m.group(1,2)
port = int(port)
else:
port = 22
# Slightly bizarre output key format: 'type@port:hostname'
# XXX: does PuTTY do anything useful with literal IP[v4]s?
key = keytype + ("@%d:%s" % (port, host))
# Most of these are numbers, but there's the occasional
# string that needs passing through
value = ",".join(map(
lambda x: x if isinstance(x, str)
else x.decode('ASCII') if isinstance(x, bytes)
else inttohex(x), keyparams))
output_formatter.key(key, value)
except UnknownKeyType as k:
warn("unknown SSH key type '%s', skipping" % k.keytype)
except KeyFormatError as k:
warn("trouble parsing key (%s), skipping" % k.msg)
except BlankInputLine:
pass
class OutputFormatter(object):
def __init__(self, fh):
self.fh = fh
def header(self):
pass
def trailer(self):
pass
class WindowsOutputFormatter(OutputFormatter):
def header(self):
# Output REG file header.
self.fh.write("""REGEDIT4
[HKEY_CURRENT_USER\Software\SimonTatham\PuTTY\SshHostKeys]
""")
def key(self, key, value):
# XXX: worry about double quotes?
self.fh.write("\"%s\"=\"%s\"\n" % (winmungestr(key), value))
def trailer(self):
# The spec at http://support.microsoft.com/kb/310516 says we need
# a blank line at the end of the reg file:
#
# Note the registry file should contain a blank line at the
# bottom of the file.
#
self.fh.write("\n")
class UnixOutputFormatter(OutputFormatter):
def key(self, key, value):
self.fh.write('%s %s\n' % (key, value))
def main():
parser = argparse.ArgumentParser(
description="Convert OpenSSH known hosts files to PuTTY's format.")
group = parser.add_mutually_exclusive_group()
group.add_argument(
"--windows", "--win", action='store_const',
dest="output_formatter_class", const=WindowsOutputFormatter,
help="Produce Windows .reg file output that regedit.exe can import"
" (default).")
group.add_argument(
"--unix", action='store_const',
dest="output_formatter_class", const=UnixOutputFormatter,
help="Produce a file suitable for use as ~/.putty/sshhostkeys.")
parser.add_argument("-o", "--output", type=argparse.FileType("w"),
default=argparse.FileType("w")("-"),
help="Output file to write to (default stdout).")
parser.add_argument("--hostname", action="append",
help="Host name(s) to try matching against hashed "
"host entries in input.")
parser.add_argument("infile", nargs="*",
help="Input file(s) to read from (default stdin).")
parser.set_defaults(output_formatter_class=WindowsOutputFormatter,
hostname=[])
args = parser.parse_args()
output_formatter = args.output_formatter_class(args.output)
output_formatter.header()
for line in fileinput.input(args.infile):
handle_line(line, output_formatter, args.hostname)
output_formatter.trailer()
if __name__ == "__main__":
main()