#!/usr/bin/env python3

# Convert OpenSSH known_hosts and known_hosts2 files to "new format" PuTTY
# host keys.
#   usage:
#     kh2reg.py [ --win ] known_hosts1 2 3 4 ... > hosts.reg
#       Creates a Windows .REG file (double-click to install).
#     kh2reg.py --unix    known_hosts1 2 3 4 ... > sshhostkeys
#       Creates data suitable for storing in ~/.putty/sshhostkeys (Unix).
# Line endings are someone else's problem as is traditional.
# Requires Python 3 (it uses base64.decodebytes, which Python 2 lacks).

import fileinput
import base64
import struct
import string
import re
import sys
import argparse
import itertools
import collections
import hashlib
from functools import reduce

def winmungestr(s):
    """Duplicate of PuTTY's mungestr() in winstore.c:1.10 for Registry keys.

    Returns a copy of s in which each space, backslash, '*', '?', '%',
    '~' and control character (anything below ' ') is replaced by
    '%XX', XX being the character code in upper-case hex. A '.' is
    escaped the same way, but only when it is the first character.
    """
    pieces = []
    candot = False
    for c in s:
        # The backslash is written as '\\' so that the literal contains
        # no invalid escape sequence.
        if c in ' \\*?%~' or ord(c) < ord(' ') or (c == '.' and not candot):
            pieces.append("%%%02X" % ord(c))
        else:
            pieces.append(c)
        # Only a leading dot needs escaping.
        candot = True
    return "".join(pieces)

def strtoint(s):
    """Convert arbitrary-length big-endian binary data to a Python int.

    s is a bytes-like object. Empty input yields 0 rather than raising.
    """
    return int.from_bytes(s, "big")

def strtoint_le(s):
    """Convert arbitrary-length little-endian binary data to a Python int.

    s is a bytes-like object. Empty input yields 0 rather than raising.
    """
    return int.from_bytes(s, "little")

def inttohex(n):
    """Render an int as lower-case hexadecimal with a '0x' prefix."""
    return "0x" + format(n, "x")

def warn(s):
    """Write a warning to stderr, prefixed with file name and line number.

    The position is whatever the fileinput module currently reports,
    so this is meant to be called while an input file is being read.
    """
    position = (fileinput.filename(), fileinput.filelineno())
    message = "%s:%d: %s\n" % (position + (s,))
    sys.stderr.write(message)

class HMAC(object):
    """HMAC built on a hashlib-style hash constructor.

    hashclass is called with bytes and must return an object with a
    digest() method; blocksize is the length in bytes that keys are
    padded to.
    """
    def __init__(self, hashclass, blocksize):
        self.hashclass = hashclass
        self.blocksize = blocksize
        # Packs/unpacks exactly one block as individual unsigned bytes.
        self.struct = struct.Struct(">" + format(self.blocksize, "d") + "B")

    def pad_key(self, key):
        """Right-pad key with zero bytes up to the block size."""
        return key.ljust(self.blocksize, b'\0')

    def xor_key(self, key, xor):
        """XOR every byte of a block-sized key with the constant xor."""
        masked = (octet ^ xor for octet in self.struct.unpack(key))
        return self.struct.pack(*masked)

    def keyed_hash(self, key, padbyte, string):
        """Hash string prefixed by the key XORed with padbyte."""
        prefix = self.xor_key(key, padbyte)
        return self.hashclass(prefix + string).digest()

    def compute(self, key, string):
        """Return the HMAC of string under key, as raw digest bytes."""
        # Over-long keys are replaced by their hash before padding.
        if len(key) > self.blocksize:
            key = self.hashclass(key).digest()
        key = self.pad_key(key)
        inner = self.keyed_hash(key, 0x36, string)
        return self.keyed_hash(key, 0x5C, inner)

def openssh_hashed_host_match(hashed_host, try_host):
    """Test whether try_host is the name behind a hashed host entry.

    hashed_host is the first field of a known_hosts line, as bytes, in
    the form b'|1|<base64 salt>|<base64 hash>'; the hash is the
    HMAC-SHA1 of the host name keyed with the salt. try_host is the
    candidate host name, as bytes.

    Returns True on a match. Returns False on a mismatch, on an
    unrecognised magic-number prefix, and on an entry that is
    malformed (wrong number of '|' separators or undecodable base64),
    so that one bad line cannot abort the whole conversion.
    """
    # Local import: hmac is not among this file's top-level imports.
    import hmac

    if not hashed_host.startswith(b'|1|'):
        return False # unrecognised magic number prefix

    try:
        salt64, expected64 = hashed_host[3:].split(b'|')
        salt = base64.decodebytes(salt64)
        expected = base64.decodebytes(expected64)
    except ValueError:
        # Unpacking the wrong number of parts raises ValueError, and
        # base64's binascii.Error is a subclass of ValueError.
        return False

    actual = hmac.new(salt, try_host, hashlib.sha1).digest()
    return hmac.compare_digest(actual, expected)

def invert(n, p):
    """Compute inverse mod p.

    Runs the extended Euclidean algorithm on (n, p). The result is a
    multiplier for n and is not reduced into the range [0, p), so it
    may be negative. Raises ZeroDivisionError if n is a multiple of p.
    """
    if n % p == 0:
        raise ZeroDivisionError()
    # Invariant: rem_* is congruent to coef_* times n, modulo p.
    rem_old, rem_new = n, p
    coef_old, coef_new = 1, 0
    while rem_new:
        quotient = rem_old // rem_new
        rem_old, rem_new = rem_new, rem_old - quotient * rem_new
        coef_old, coef_new = coef_new, coef_old - quotient * coef_new
    # The last non-zero remainder is +-gcd(n, p), which must be +-1.
    assert abs(rem_old) == 1
    return coef_old * rem_old

def jacobi(n,m):
    """Compute the Jacobi symbol of n over the odd modulus m.

    When m is prime this is the Legendre symbol: 0 if n is congruent
    to 0 mod m, 1 if n is congruent to a non-zero square mod m, and
    -1 if n is not congruent to any square mod m.

    """
    assert m & 1
    sign = 1
    while True:
        n %= m
        if n == 0:
            return 0
        # Remove factors of two from n one at a time; each one flips
        # the sign unless m is 1 or 7 mod 8.
        while (n & 1) == 0:
            n >>= 1
            if (m & 7) not in (1, 7):
                sign = -sign
        if n == 1:
            return sign
        # Swap the two (now odd) arguments; the sign flips when both
        # are 3 mod 4.
        if (n & 3) == 3 == (m & 3):
            sign = -sign
        n, m = m, n

class SqrtModP(object):
    """Class for finding square roots of numbers mod p.

    p must be an odd prime (but its primality is not checked).

    Constructing an instance does some precomputation that depends
    only on p, so instances are cached by modulus in the class
    attribute 'instances'; make() and root() go through that cache."""

    def __init__(self, p):
        """Precompute the values that depend only on the modulus p."""
        p = abs(p)
        assert p & 1
        self.p = p

        # Decompose p as 2^e k + 1 for odd k.
        self.k = p-1
        self.e = 0
        while not (self.k & 1):
            self.k >>= 1
            self.e += 1

        # Find a non-square mod p.
        # (Candidates 1, 2, 3, ... are tried in turn; the loop variable
        # is the attribute self.z itself, so it keeps the first value
        # whose Jacobi symbol is -1.)
        for self.z in itertools.count(1):
            if jacobi(self.z, self.p) == -1:
                break
        # Inverse of z, exactly as returned by invert().
        self.zinv = invert(self.z, self.p)

    def sqrt_recurse(self, a):
        """Return a square root of a mod p.

        Worker for sqrt(), which only calls it once it has checked
        that a is a non-zero square mod p; the assertion below fails
        for any other a.
        """
        # Count how many squarings turn a^k into 1. i runs downwards
        # from e, so leaving the loop with i == e means a^k was already
        # 1, and each squaring that was needed lowers i by one.
        ak = pow(a, self.k, self.p)
        for i in range(self.e, -1, -1):
            if ak == 1:
                break
            ak = ak*ak % self.p
        # i == 0 means a^(k 2^(e-1)) = a^((p-1)/2) was not 1.
        assert i > 0
        if i == self.e:
            # a^k == 1, so a^(k+1) == a and a^((k+1)/2) squares to a.
            return pow(a, (self.k+1) // 2, self.p)
        # Otherwise multiply a by z^(2^i), take a root of the product
        # recursively, then divide out z^(2^(i-1)), which is a root of
        # the extra factor.
        # NOTE(review): presumably the product needs fewer squarings
        # than a did, which is what ends the recursion -- confirm.
        r_prime = self.sqrt_recurse(a * pow(self.z, 2**i, self.p))
        return r_prime * pow(self.zinv, 2**(i-1), self.p) % self.p

    def sqrt(self, a):
        """Return a square root of a mod p, choosing the smaller one.

        Returns 0 when the Jacobi symbol of a is 0, and raises
        ValueError when it is negative (a has no square root mod p).
        """
        j = jacobi(a, self.p)
        if j == 0:
            return 0
        if j < 0:
            raise ValueError("{} has no square root mod {}".format(a, self.p))
        a %= self.p
        r = self.sqrt_recurse(a)
        assert r*r % self.p == a
        # Normalise to the smaller (or 'positive') one of the two roots.
        return min(r, self.p - r)

    def __str__(self):
        """Return a string of the form 'SqrtModP(<p>)'."""
        return "{}({})".format(type(self).__name__, self.p)
    def __repr__(self):
        """Same text as __str__()."""
        return self.__str__()

    # Cache of instances, keyed by the p that was passed to make().
    instances = {}

    @classmethod
    def make(cls, p):
        """Return the cached instance for modulus p, creating it if needed."""
        if p not in cls.instances:
            cls.instances[p] = cls(p)
        return cls.instances[p]

    @classmethod
    def root(cls, n, p):
        """Return the square root of n mod p, via the cached instance for p."""
        return cls.make(p).sqrt(n)

# Parameters of an elliptic curve y^2 = x^3 + a*x + b over the integers
# mod p, which is the form in which handle_line() evaluates them when
# it reconstructs the y coordinate of a compressed point.
NistCurve = collections.namedtuple("NistCurve", "p a b")
# Curve parameters indexed by SSH key type name. Membership in this
# dict is also how handle_line() recognises the ECDSA key types.
nist_curves = {
    "ecdsa-sha2-nistp256": NistCurve(0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff, 0xffffffff00000001000000000000000000000000fffffffffffffffffffffffc, 0x5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b),
    "ecdsa-sha2-nistp384": NistCurve(0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000ffffffff, 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000fffffffc, 0xb3312fa7e23ee7e4988e056be3f82d19181d9c6efe8141120314088f5013875ac656398d8a2ed19d2a85c8edd3ec2aef),
    "ecdsa-sha2-nistp521": NistCurve(0x01ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff, 0x01fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffc, 0x0051953eb9618e1c9a1f929a21a0b68540eea2da725b99b315f3b8b489918ef109e156193951ec7e937b1652c0bd3bb1bf073573df883d2c34f1ef451fd46b503f00),
}

class BlankInputLine(Exception):
    """Signals an input line that is blank or only a comment."""

class UnknownKeyType(Exception):
    """Signals an SSH key type that this script does not convert.

    The offending type name is kept in the 'keytype' attribute.
    """
    def __init__(self, keytype):
        super(UnknownKeyType, self).__init__(keytype)
        self.keytype = keytype

class KeyFormatError(Exception):
    """Signals key data that could not be parsed.

    A human-readable explanation is kept in the 'msg' attribute.
    """
    def __init__(self, msg):
        super(KeyFormatError, self).__init__(msg)
        self.msg = msg

# Parameters for decompressing EdDSA public keys, indexed by SSH key
# type: (encoded key length in bytes, field prime p, curve constant d,
# curve constant a), for the curve  a x^2 + y^2 = 1 + d x^2 y^2  mod p.
_EDDSA_PARAMS = {
    "ssh-ed25519": (32, 2**255 - 19, 0x52036cee2b6ffe738cc740797779e89800700a4d4141d8ab75eb4dca135978a3, -1),
    "ssh-ed448": (57, 2**448-2**224-1, -39081, +1),
}

def _split_blob(blob):
    """Split an SSH-2 key blob into its length-prefixed subfields.

    'blob' consists of a number of
      uint32    N (big-endian)
      uint8[N]  field_data
    Returns the list of field_data byte strings. Raises KeyFormatError
    if the blob ends in the middle of a field.
    """
    subfields = []
    pos = 0
    while pos < len(blob):
        if len(blob) - pos < 4:
            raise KeyFormatError("truncated length field in key blob")
        (size,) = struct.unpack_from(">L", blob, pos)
        pos += 4
        if len(blob) - pos < size:
            raise KeyFormatError("truncated data field in key blob")
        subfields.append(blob[pos:pos + size])
        pos += size
    return subfields

def _ecdsa_keyparams(sshkeytype, subfields):
    """Return [curve name, x, y] for an ECDSA key over a NIST curve.

    subfields is the full list from the key blob (key type, curve
    name, encoded point). Raises KeyFormatError on malformed data.
    """
    if len(subfields) > 3:
        raise KeyFormatError("too many subfields in blob")
    if len(subfields) < 3:
        raise KeyFormatError("too few subfields in blob")
    # First is yet another copy of the key name (minus its prefix).
    curvename = subfields[1].decode("ASCII", "replace")
    Q = subfields[2]
    if sshkeytype != "ecdsa-sha2-" + curvename:
        raise KeyFormatError("key type mismatch ('%s' vs '%s')"
                             % (sshkeytype, curvename))
    # Second contains key material X and Y (hopefully).
    # First a magic octet indicating point compression.
    if len(Q) < 2:
        raise KeyFormatError("missing point coordinates")
    point_type = struct.unpack_from("B", Q, 0)[0]
    Qrest = Q[1:]
    if point_type == 4:
        # Then two equal-length bignums (X and Y).
        if len(Qrest) % 2 != 0:
            raise KeyFormatError("odd-length X+Y")
        bnlen = len(Qrest) // 2
        x = strtoint(Qrest[:bnlen])
        y = strtoint(Qrest[bnlen:])
    elif 2 <= point_type <= 3:
        # A compressed point just specifies X, and leaves Y implicit
        # except for parity, so we have to recover it from the curve
        # equation.
        curve = nist_curves[sshkeytype]
        x = strtoint(Qrest)
        yy = (x*x*x + curve.a*x + curve.b) % curve.p
        try:
            y = SqrtModP.root(yy, curve.p)
        except ValueError:
            raise KeyFormatError("compressed point is not on the curve")
        if y % 2 != point_type % 2:
            y = curve.p - y
    else:
        raise KeyFormatError("unsupported point format %d" % point_type)
    return [curvename, x, y]

def _eddsa_keyparams(sshkeytype, subfields):
    """Return [x, y] for an ssh-ed25519 or ssh-ed448 key.

    subfields is the full list from the key blob (key type, encoded
    point). Raises KeyFormatError on malformed data.
    """
    if len(subfields) != 2:
        raise KeyFormatError("wrong number of subfields in blob")
    nbytes, p, d, a = _EDDSA_PARAMS[sshkeytype]
    encoded = subfields[1]
    if len(encoded) != nbytes:
        raise KeyFormatError("wrong key length (%d bytes, expected %d)"
                             % (len(encoded), nbytes))

    # Key material y, with the top bit of the last byte being
    # repurposed as the expected parity of the associated x (point
    # compression). Which bit that is depends on the encoded length:
    # bit 255 for the 32-byte Ed25519 keys, bit 455 for the 57-byte
    # Ed448 keys.
    top_bit = 8 * nbytes - 1
    y = strtoint_le(encoded)
    x_parity = y >> top_bit
    y &= (1 << top_bit) - 1

    try:
        # Recover x^2 = (y^2 - 1) / (d y^2 - a).
        xx = (y*y - 1) * invert(d*y*y - a, p) % p
        # Take the square root.
        x = SqrtModP.root(xx, p)
    except (ZeroDivisionError, ValueError):
        raise KeyFormatError("encoded point is not on the curve")

    # Pick the square root of the correct parity. Zero has no odd
    # counterpart below p, so x = 0 with the parity bit set is invalid.
    if (x % 2) != x_parity:
        if x == 0:
            raise KeyFormatError("encoded point has impossible x parity")
        x = p - x
    return [x, y]

def _parse_ssh2_key(sshkeytype, keyblob64):
    """Decode an SSH-2 public key into (PuTTY key type, key parameters).

    sshkeytype is the key type field of the known_hosts line and
    keyblob64 the base64 field that follows it. Raises KeyFormatError
    for malformed data and UnknownKeyType for unsupported key types.
    """
    try:
        blob = base64.decodebytes(keyblob64.encode("ASCII"))
    except ValueError:
        # Covers both non-ASCII text and invalid base64
        # (binascii.Error is a subclass of ValueError).
        raise KeyFormatError("invalid base64 key data")

    subfields = _split_blob(blob)
    if not subfields:
        raise KeyFormatError("empty key blob")

    # The first field is keytype again.
    embedded = subfields[0].decode("ASCII", "replace")
    if embedded != sshkeytype:
        raise KeyFormatError(
            "outer and embedded key types do not match: '%s', '%s'"
            % (sshkeytype, embedded))

    # Translate key type string into something PuTTY can use, and
    # munge the rest of the data.
    if sshkeytype == "ssh-rsa":
        # The rest of the subfields we can treat as an opaque list
        # of bignums (same numbers and order as stored by PuTTY).
        return "rsa2", list(map(strtoint, subfields[1:]))
    if sshkeytype == "ssh-dss":
        # Same again.
        return "dss", list(map(strtoint, subfields[1:]))
    if sshkeytype in nist_curves:
        return sshkeytype, _ecdsa_keyparams(sshkeytype, subfields)
    if sshkeytype in _EDDSA_PARAMS:
        return sshkeytype, _eddsa_keyparams(sshkeytype, subfields)
    raise UnknownKeyType(sshkeytype)

def _format_keyparam(param):
    """Render one key parameter as text: strings as-is, ints as hex."""
    if isinstance(param, str):
        return param
    if isinstance(param, bytes):
        return param.decode('ASCII')
    return inttohex(param)

def _emit_host_keys(hostpat, keytype, keyparams, output_formatter, try_hosts):
    """Output one entry per usable host name in the pattern list hostpat."""
    # Most of these are numbers, but there's the occasional
    # string that needs passing through
    value = ",".join(_format_keyparam(param) for param in keyparams)

    # Now print out one line per host pattern, discarding wildcards.
    for host in hostpat.split(','):
        if re.search (r"[*?!]", host):
            warn("skipping wildcard host pattern '%s'" % host)
            continue

        if host.startswith('|'):
            # Hashed host name: see whether it is one of the names
            # we were told to try.
            try:
                hashed = host.encode('ASCII')
                matched = next(
                    (try_host for try_host in try_hosts
                     if openssh_hashed_host_match(
                         hashed, try_host.encode('UTF-8'))),
                    None)
            except ValueError:
                # Entry is not ASCII or is otherwise malformed.
                matched = None
            if matched is None:
                warn("unable to match hashed hostname '%s'" % host)
                continue
            host = matched

        m = re.match (r"\[([^]]*)\]:(\d*)$", host)
        if m:
            if not m.group(2):
                warn("skipping host pattern '%s' with empty port" % host)
                continue
            (host, port) = m.group(1,2)
            port = int(port)
        else:
            port = 22
        # Slightly bizarre output key format: 'type@port:hostname'
        # XXX: does PuTTY do anything useful with literal IP[v4]s?
        key = keytype + ("@%d:%s" % (port, host))
        output_formatter.key(key, value)

def handle_line(line, output_formatter, try_hosts):
    """Convert one line of a known_hosts file and output the result.

    line is the raw input line. output_formatter receives one
    key(name, value) call per usable host name on the line. try_hosts
    is the list of host names to test hashed host entries against.

    Blank lines and comments are ignored. Lines that cannot be
    converted (unknown key type, malformed key data, lines carrying an
    '@' marker such as @cert-authority or @revoked) are reported via
    warn() and skipped; nothing is raised for them.
    """
    # Remove leading/trailing whitespace (should zap CR and LF)
    line = line.strip()

    # Skip blanks and comments
    if line == '' or line[0] == '#':
        return

    try:
        # Split line on whitespace, tolerating runs of spaces or tabs.
        fields = line.split()

        # A leading '@' marks a CA or revoked key; neither is a plain
        # trusted host key, so neither can go into PuTTY's cache.
        if fields[0].startswith('@'):
            warn("skipping line with marker '%s'" % fields[0])
            return

        if len(fields) < 3:
            raise KeyFormatError("too few fields")
        hostpat = fields[0]

        # Grotty heuristic to distinguish known_hosts from known_hosts2:
        # is second field entirely decimal digits?
        if re.match (r"\d+$", fields[1]):

            # Treat as SSH-1-type host key.
            # Format: hostpat bits10 exp10 mod10 comment...
            # (PuTTY doesn't store the number of bits.)
            if len(fields) < 4:
                raise KeyFormatError("too few fields for SSH-1 key")
            try:
                keyparams = [int(field) for field in fields[2:4]]
            except ValueError:
                raise KeyFormatError("non-numeric SSH-1 key parameter")
            keytype = "rsa"

        else:

            # Treat as SSH-2-type host key.
            # Format: hostpat keytype keyblob64 comment...
            keytype, keyparams = _parse_ssh2_key(fields[1], fields[2])

        _emit_host_keys(hostpat, keytype, keyparams,
                        output_formatter, try_hosts)

    except UnknownKeyType as k:
        warn("unknown SSH key type '%s', skipping" % k.keytype)
    except KeyFormatError as k:
        warn("trouble parsing key (%s), skipping" % k.msg)

class OutputFormatter(object):
    """Base class for the output writers.

    Stores the destination file object as 'fh' and provides header()
    and trailer() hooks that do nothing.
    """
    def __init__(self, fh):
        self.fh = fh

    def header(self):
        """Hook called before any key is written; does nothing here."""
        return None

    def trailer(self):
        """Hook called after the last key is written; does nothing here."""
        return None

class WindowsOutputFormatter(OutputFormatter):
    """Writes host keys as a Windows .REG file."""

    def header(self):
        # Output REG file header.
        # Backslashes are spelled '\\' so the literal contains no
        # invalid escape sequences; the text written is unchanged.
        self.fh.write(
            "REGEDIT4\n"
            "\n"
            "[HKEY_CURRENT_USER\\Software\\SimonTatham\\PuTTY\\SshHostKeys]\n")

    def key(self, key, value):
        # The value name is escaped by winmungestr(); the value itself
        # is written as given.
        # XXX: worry about double quotes?
        self.fh.write("\"%s\"=\"%s\"\n" % (winmungestr(key), value))

    def trailer(self):
        # The spec at http://support.microsoft.com/kb/310516 says we need
        # a blank line at the end of the reg file:
        #
        #   Note the registry file should contain a blank line at the
        #   bottom of the file.
        #
        self.fh.write("\n")

class UnixOutputFormatter(OutputFormatter):
    """Writes host keys one per line, as the key, a space, then the value."""
    def key(self, key, value):
        entry = '%s %s\n' % (key, value)
        self.fh.write(entry)

def main():
    """Parse the command line, then convert every line of the input.

    The chosen formatter writes its header, one entry per converted
    host key, and its trailer to the selected output file.
    """
    parser = argparse.ArgumentParser(
        description="Convert OpenSSH known hosts files to PuTTY's format.")

    # The two output formats are mutually exclusive; both store the
    # formatter class to use in the same destination.
    format_group = parser.add_mutually_exclusive_group()
    format_group.add_argument(
        "--windows", "--win", action='store_const',
        dest="output_formatter_class", const=WindowsOutputFormatter,
        help="Produce Windows .reg file output that regedit.exe can import"
        " (default).")
    format_group.add_argument(
        "--unix", action='store_const',
        dest="output_formatter_class", const=UnixOutputFormatter,
        help="Produce a file suitable for use as ~/.putty/sshhostkeys.")

    parser.add_argument("-o", "--output", type=argparse.FileType("w"),
                        default=argparse.FileType("w")("-"),
                        help="Output file to write to (default stdout).")
    parser.add_argument("--hostname", action="append",
                        help="Host name(s) to try matching against hashed "
                        "host entries in input.")
    parser.add_argument("infile", nargs="*",
                        help="Input file(s) to read from (default stdin).")
    parser.set_defaults(output_formatter_class=WindowsOutputFormatter,
                        hostname=[])
    options = parser.parse_args()

    formatter = options.output_formatter_class(options.output)
    formatter.header()
    for input_line in fileinput.input(options.infile):
        handle_line(input_line, formatter, options.hostname)
    formatter.trailer()

# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()