diff --git a/contrib/kh2reg.py b/contrib/kh2reg.py index 32b871d0..272124d5 100755 --- a/contrib/kh2reg.py +++ b/contrib/kh2reg.py @@ -19,6 +19,7 @@ import sys import argparse import itertools import collections +import hashlib from functools import reduce def winmungestr(s): @@ -52,6 +53,34 @@ def warn(s): sys.stderr.write("%s:%d: %s\n" % (fileinput.filename(), fileinput.filelineno(), s)) +class HMAC(object): + def __init__(self, hashclass, blocksize): + self.hashclass = hashclass + self.blocksize = blocksize + self.struct = struct.Struct(">{:d}B".format(self.blocksize)) + def pad_key(self, key): + return key + b'\0' * (self.blocksize - len(key)) + def xor_key(self, key, xor): + return self.struct.pack(*[b ^ xor for b in self.struct.unpack(key)]) + def keyed_hash(self, key, padbyte, string): + return self.hashclass(self.xor_key(key, padbyte) + string).digest() + def compute(self, key, string): + if len(key) > self.blocksize: + key = self.hashclass(key).digest() + key = self.pad_key(key) + return self.keyed_hash(key, 0x5C, self.keyed_hash(key, 0x36, string)) + +def openssh_hashed_host_match(hashed_host, try_host): + if hashed_host.startswith(b'|1|'): + salt, expected = hashed_host[3:].split(b'|') + salt = base64.decodestring(salt) + expected = base64.decodestring(expected) + mac = HMAC(hashlib.sha1, 64) + else: + return False # unrecognised magic number prefix + + return mac.compute(salt, try_host) == expected + def invert(n, p): """Compute inverse mod p.""" if n % p == 0: @@ -172,7 +201,7 @@ class KeyFormatError(Exception): def __init__(self, msg): self.msg = msg -def handle_line(line, output_formatter): +def handle_line(line, output_formatter, try_hosts): try: # Remove leading/trailing whitespace (should zap CR and LF) line = line.strip() @@ -308,26 +337,33 @@ def handle_line(line, output_formatter): if re.search (r"[*?!]", host): warn("skipping wildcard host pattern '%s'" % host) continue - elif re.match (r"\|", host): - warn("skipping hashed hostname '%s'" % host) - continue - else: - m = re.match (r"\[([^]]*)\]:(\d*)$", host) - if m: - (host, port) = m.group(1,2) - port = int(port) + + if re.match (r"\|", host): + for try_host in try_hosts: + if openssh_hashed_host_match(host.encode('ASCII'), + try_host.encode('UTF-8')): + host = try_host + break else: - port = 22 - # Slightly bizarre output key format: 'type@port:hostname' - # XXX: does PuTTY do anything useful with literal IP[v4]s? - key = keytype + ("@%d:%s" % (port, host)) - # Most of these are numbers, but there's the occasional - # string that needs passing through - value = ",".join(map( - lambda x: x if isinstance(x, str) - else x.decode('ASCII') if isinstance(x, bytes) - else inttohex(x), keyparams)) - output_formatter.key(key, value) + warn("unable to match hashed hostname '%s'" % host) + continue + + m = re.match (r"\[([^]]*)\]:(\d*)$", host) + if m: + (host, port) = m.group(1,2) + port = int(port) + else: + port = 22 + # Slightly bizarre output key format: 'type@port:hostname' + # XXX: does PuTTY do anything useful with literal IP[v4]s? + key = keytype + ("@%d:%s" % (port, host)) + # Most of these are numbers, but there's the occasional + # string that needs passing through + value = ",".join(map( + lambda x: x if isinstance(x, str) + else x.decode('ASCII') if isinstance(x, bytes) + else inttohex(x), keyparams)) + output_formatter.key(key, value) except UnknownKeyType as k: warn("unknown SSH key type '%s', skipping" % k.keytype) @@ -385,15 +421,19 @@ def main(): parser.add_argument("-o", "--output", type=argparse.FileType("w"), default=argparse.FileType("w")("-"), help="Output file to write to (default stdout).") + parser.add_argument("--hostname", action="append", + help="Host name(s) to try matching against hashed " + "host entries in input.") parser.add_argument("infile", nargs="*", help="Input file(s) to read from (default stdin).") - parser.set_defaults(output_formatter_class=WindowsOutputFormatter) + parser.set_defaults(output_formatter_class=WindowsOutputFormatter, + hostname=[]) args = parser.parse_args() output_formatter = args.output_formatter_class(args.output) output_formatter.header() for line in fileinput.input(args.infile): - handle_line(line, output_formatter) + handle_line(line, output_formatter, args.hostname) output_formatter.trailer() if __name__ == "__main__":