import sys
import os
import numbers
import subprocess
import re
import string
import struct
from binascii import hexlify

assert sys.version_info[:2] >= (3,0), "This is Python 3 code"

# Expect to be run from the 'test' subdirectory, one level down from
# the main source
putty_srcdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


def coerce_to_bytes(arg):
    """Return *arg* UTF-8-encoded if it is a str, otherwise unchanged."""
    return arg.encode("UTF-8") if isinstance(arg, str) else arg


class ChildProcessFailure(Exception):
    """The testcrypt subprocess failed: EOF on its output, a non-zero
    exit status, or a reply with the wrong number of values."""
    pass


class ChildProcess(object):
    """Owner of the single testcrypt subprocess and its line protocol.

    The subprocess is started lazily by the first funcall(). Each call
    sends one line of the form "<cmd> <arg> <arg> ..." and reads back a
    line holding a count N, followed by N result lines.

    Setting the environment variable PUTTY_TESTCRYPT_DEBUG=stderr traces
    every line sent and received to stderr.
    """

    def __init__(self):
        self.sp = None          # subprocess.Popen while the child is running
        self.debug = None       # stream for protocol tracing, if enabled
        self.exitstatus = None  # child's exit status once it has been reaped
        self.exception = None   # sticky fatal error; see write_line()

        dbg = os.environ.get("PUTTY_TESTCRYPT_DEBUG")
        if dbg is not None:
            if dbg == "stderr":
                self.debug = sys.stderr
            else:
                sys.stderr.write("Unknown value '{}' for PUTTY_TESTCRYPT_DEBUG"
                                 " (try 'stderr')\n".format(dbg))

    def start(self):
        """Launch testcrypt.

        If PUTTY_TESTCRYPT is set, its value is run as a shell command
        line. Otherwise the 'testcrypt' binary in the source directory is
        run directly.
        """
        assert self.sp is None
        override_command = os.environ.get("PUTTY_TESTCRYPT")
        if override_command is None:
            cmd = [os.path.join(putty_srcdir, "testcrypt")]
            shell = False
        else:
            cmd = override_command
            shell = True
        self.sp = subprocess.Popen(
            cmd, shell=shell, stdin=subprocess.PIPE, stdout=subprocess.PIPE)

    def write_line(self, line):
        """Send one protocol line (bytes, without terminator) to testcrypt."""
        if self.exception is not None:
            # Re-raise our fatal-error exception, if it previously
            # occurred in a context where it couldn't be propagated (a
            # __del__ method).
            raise self.exception
        if self.debug is not None:
            self.debug.write("send: {}\n".format(line))
        self.sp.stdin.write(line + b"\n")
        self.sp.stdin.flush()

    def read_line(self):
        """Read one reply line from testcrypt with trailing CR/LF removed.

        On EOF, records a ChildProcessFailure in self.exception (so that
        later writes re-raise it) and raises it.
        """
        line = self.sp.stdout.readline()
        if len(line) == 0:
            self.exception = ChildProcessFailure("received EOF from testcrypt")
            raise self.exception
        line = line.rstrip(b"\r\n")
        if self.debug is not None:
            self.debug.write("recv: {}\n".format(line))
        return line

    def already_terminated(self):
        """True once wait_for_exit() has reaped the child."""
        return self.sp is None and self.exitstatus is not None

    def funcall(self, cmd, args):
        """Run testcrypt command *cmd* with the words *args*.

        Starts the child on first use. Returns the reply lines as a list
        of bytes.
        """
        if self.sp is None:
            assert self.exitstatus is None
            self.start()
        self.write_line(coerce_to_bytes(cmd) + b" " + b" ".join(
            coerce_to_bytes(arg) for arg in args))
        argcount = int(self.read_line())
        return [self.read_line() for _ in range(argcount)]

    def wait_for_exit(self):
        """Close testcrypt's stdin, reap it, and record its exit status."""
        if self.sp is not None:
            self.sp.stdin.close()
            self.exitstatus = self.sp.wait()
            # Also release our end of the stdout pipe, instead of leaving
            # it open until the Popen object happens to be collected.
            self.sp.stdout.close()
            self.sp = None

    def check_return_status(self):
        """Wait for testcrypt to exit; raise if it exited non-zero."""
        self.wait_for_exit()
        if self.exitstatus is not None and self.exitstatus != 0:
            raise ChildProcessFailure("testcrypt returned exit status {}"
                                      .format(self.exitstatus))


# The single testcrypt subprocess shared by everything in this module.
childprocess = ChildProcess()

# For each testcrypt value type, the function-name prefixes that make a
# function a method of that type (see _setup). Prefixes are tried in
# order, and the first match is stripped to form the method name.
method_prefixes = {
    'val_wpoint': ['ecc_weierstrass_'],
    'val_mpoint': ['ecc_montgomery_'],
    'val_epoint': ['ecc_edwards_'],
    'val_hash': ['ssh_hash_'],
    'val_mac': ['ssh2_mac_'],
    'val_key': ['ssh_key_'],
    'val_cipher': ['ssh_cipher_'],
    'val_dh': ['dh_'],
    'val_ecdh': ['ssh_ecdhkex_'],
    'val_rsakex': ['ssh_rsakex_'],
    'val_prng': ['prng_'],
    'val_pcs': ['pcs_'],
    'val_pockle': ['pockle_'],
    'val_ntruencodeschedule': ['ntru_encode_schedule_', 'ntru_'],
}
# typename -> list of (methodname, Function), filled in by _setup.
method_lists = {t: [] for t in method_prefixes}
# Cache of (enum typename, bytes word) -> whether testcrypt accepted it.
checked_enum_values = {}


class Value(object):
    """Handle for an object that lives inside the testcrypt process.

    *typename* is a testcrypt type name such as "val_mpint". *ident* is
    the identifier word testcrypt returned for the object. Every function
    registered in method_lists for this type is attached as a method that
    passes this value as its first argument. For example, for a val_hash,
    a function named ssh_hash_X becomes the method X.
    """

    def __init__(self, typename, ident):
        self._typename = typename
        self._ident = ident
        for methodname, function in method_lists.get(self._typename, []):
            # The outer lambda binds 'function' immediately, so each
            # method calls its own function rather than the loop's last.
            setattr(self, methodname,
                    (lambda f: lambda *args: f(self, *args))(function))

    def _consumed(self):
        # This value was passed to a function that consumes it, so it
        # must not be freed again by __del__.
        self._ident = None

    def __repr__(self):
        return "Value({!r}, {!r})".format(self._typename, self._ident)

    def __del__(self):
        if self._ident is not None and not childprocess.already_terminated():
            try:
                childprocess.funcall("free", [self._ident])
            except ChildProcessFailure:
                # If we see this exception now, we can't do anything
                # about it, because exceptions don't propagate out of
                # __del__ methods. Squelch it to prevent the annoying
                # runtime warning from Python, and the
                # 'self.exception' mechanism in the ChildProcess class
                # will raise it again at the next opportunity.
                #
                # (This covers both the case where testcrypt crashes
                # _during_ one of these free operations, and the
                # silencing of cascade failures when we try to send a
                # "free" command to testcrypt after it had already
                # crashed for some other reason.)
                pass

    def __long__(self):
        if self._typename != "val_mpint":
            raise TypeError("testcrypt values of types other than mpint"
                            " cannot be converted to integer")
        hexval = childprocess.funcall("mp_dump", [self._ident])[0]
        return 0 if len(hexval) == 0 else int(hexval, 16)

    def __int__(self):
        return int(self.__long__())


def marshal_string(val):
    """Encode a str/bytes value as a single testcrypt argument word.

    Bytes in the printable range 0x21-0x7E, other than '%', are sent
    literally. Everything else becomes %xx. That includes space, which
    must be escaped because argument words are space-separated (see
    ChildProcess.funcall).
    """
    val = coerce_to_bytes(val)
    assert isinstance(val, bytes), "Bad type for val_string input"
    return "".join(
        chr(b) if (0x20 < b < 0x7F and b != 0x25)
        else "%{:02x}".format(b) for b in val)


def make_argword(arg, argtype, fnname, argindex, argname, to_preserve):
    """Convert one Python argument into a testcrypt argument word.

    *argtype* is a (typename, consumed) pair from the function signature.
    *fnname*, *argindex* and *argname* are used only in error messages.
    Temporary testcrypt objects created during conversion (for string
    and integer arguments) are appended to *to_preserve*, so the caller
    can keep them alive until the call has been made.

    Raises TypeError if *arg* cannot be converted to the required type.
    """
    typename, consumed = argtype
    if typename.startswith("opt_"):
        if arg is None:
            return "NULL"
        typename = typename[4:]
    if typename == "val_string":
        retwords = childprocess.funcall("newstring", [marshal_string(arg)])
        arg = make_retvals([typename], retwords, unpack_strings=False)[0]
        to_preserve.append(arg)
    if typename == "val_mpint" and isinstance(arg, numbers.Integral):
        retwords = childprocess.funcall("mp_literal", ["0x{:x}".format(arg)])
        arg = make_retvals([typename], retwords)[0]
        to_preserve.append(arg)
    if isinstance(arg, Value):
        if arg._typename != typename:
            raise TypeError(
                "{}() argument #{:d} ({}) should be {} ({} given)".format(
                    fnname, argindex, argname, typename, arg._typename))
        ident = arg._ident
        if consumed:
            arg._consumed()
        return ident
    if typename == "uint" and isinstance(arg, numbers.Integral):
        return "0x{:x}".format(arg)
    if typename == "boolean":
        return "true" if arg else "false"
    if typename in {
            "hashalg", "macalg", "keyalg", "cipheralg",
            "dh_group", "ecdh_alg", "rsaorder", "primegenpolicy",
            "argon2flavour", "fptype", "httpdigesthash"}:
        arg = coerce_to_bytes(arg)
        # An enum value must be a single word; testcrypt is asked (once
        # per value) whether it recognises it.
        if isinstance(arg, bytes) and b" " not in arg:
            dictkey = (typename, arg)
            if dictkey not in checked_enum_values:
                retwords = childprocess.funcall("checkenum", [typename, arg])
                assert len(retwords) == 1
                checked_enum_values[dictkey] = (retwords[0] == b"ok")
            if checked_enum_values[dictkey]:
                return arg
    if typename == "mpint_list":
        # Sent as a count followed by that many mpint identifiers.
        sublist = [make_argword(len(arg), ("uint", False),
                                fnname, argindex, argname, to_preserve)]
        for val in arg:
            sublist.append(make_argword(val, ("val_mpint", False),
                                        fnname, argindex, argname,
                                        to_preserve))
        return b" ".join(coerce_to_bytes(sub) for sub in sublist)
    raise TypeError(
        "Can't convert {}() argument #{:d} ({}) to {} (value was {!r})".format(
            fnname, argindex, argname, typename, arg))


def unpack_string(identifier):
    """Fetch a val_string's contents from testcrypt and free it.

    Returns the contents as bytes, with uppercase %XX escapes decoded.
    """
    retwords = childprocess.funcall("getstring", [identifier])
    childprocess.funcall("free", [identifier])
    return re.sub(b"%[0-9A-F][0-9A-F]",
                  lambda m: bytes([int(m.group(0)[1:], 16)]), retwords[0])


def unpack_mp(identifier):
    """Fetch an mpint's value from testcrypt as a Python int, then free it."""
    retwords = childprocess.funcall("mp_dump", [identifier])
    childprocess.funcall("free", [identifier])
    hexval = retwords[0]
    # As in Value.__long__, an empty dump stands for zero.
    return 0 if len(hexval) == 0 else int(hexval, 16)


def make_retval(rettype, word, unpack_strings):
    """Convert one testcrypt reply word of type *rettype* to Python.

    Strings (when *unpack_strings* is true) are returned as bytes.
    Key components are returned as a dict. Other val_* types are
    returned as Value handles; int, uint and boolean as Python values;
    pocklestatus and mr_result as str. "opt_" types map NULL to None.

    Raises TypeError for an unknown return type.
    """
    if rettype.startswith("opt_"):
        if word == b"NULL":
            return None
        rettype = rettype[4:]
    if rettype == "val_string" and unpack_strings:
        return unpack_string(word)
    if rettype == "val_keycomponents":
        kc = {}
        retwords = childprocess.funcall("key_components_count", [word])
        for i in range(int(retwords[0], 0)):
            args = [word, "{:d}".format(i)]
            retwords = childprocess.funcall("key_components_nth_name", args)
            kc_key = unpack_string(retwords[0])
            retwords = childprocess.funcall("key_components_nth_str", args)
            if retwords[0] != b"NULL":
                kc_value = unpack_string(retwords[0]).decode("ASCII")
            else:
                # Not a string component, so fetch it as an mpint.
                retwords = childprocess.funcall("key_components_nth_mp", args)
                kc_value = unpack_mp(retwords[0])
            kc[kc_key.decode("ASCII")] = kc_value
        childprocess.funcall("free", [word])
        return kc
    if rettype.startswith("val_"):
        return Value(rettype, word)
    elif rettype == "int" or rettype == "uint":
        return int(word, 0)
    elif rettype == "boolean":
        assert word == b"true" or word == b"false"
        return word == b"true"
    elif rettype in {"pocklestatus", "mr_result"}:
        return word.decode("ASCII")
    raise TypeError("Can't deal with return value {!r} of type {!r}"
                    .format(word, rettype))


def make_retvals(rettypes, retwords, unpack_strings=True):
    """Convert testcrypt's reply words into Python values, one per type
    in *rettypes*.

    Raises ChildProcessFailure if the number of reply words does not
    match the number of expected return types.
    """
    if len(rettypes) != len(retwords):
        raise ChildProcessFailure(
            "expected {:d} return values from testcrypt, got {:d}".format(
                len(rettypes), len(retwords)))
    return [make_retval(rettype, word, unpack_strings)
            for rettype, word in zip(rettypes, retwords)]


class Function(object):
    """A callable proxy for one function declared in testcrypt-func.h.

    *argtypes* is a list of (typename, consumed) pairs parallel to
    *argnames*. *rettypes* and *retnames* describe the return values,
    where a name of None marks the C function's own return value.
    """

    def __init__(self, fnname, rettypes, retnames, argtypes, argnames):
        self.fnname = fnname
        self.rettypes = rettypes
        self.retnames = retnames
        self.argtypes = argtypes
        self.argnames = argnames

    def __repr__(self):
        return "<{} {}({}) -> ({})>".format(
            type(self).__name__,
            self.fnname,
            ", ".join(("consumed " if c else "")+t+" "+n
                      for (t,c),n in zip(self.argtypes, self.argnames)),
            ", ".join((t+" "+n if n is not None else t)
                      for t,n in zip(self.rettypes, self.retnames)),
        )

    def __call__(self, *args):
        """Call the testcrypt function.

        Returns None for no return values, the single value for one,
        and a tuple for several.
        """
        if len(args) != len(self.argtypes):
            raise TypeError(
                "{}() takes exactly {} arguments ({} given)".format(
                    self.fnname, len(self.argtypes), len(args)))
        # Holds temporaries made by make_argword, so they are not freed
        # (via Value.__del__) before testcrypt has used them.
        to_preserve = []
        argwords = [
            make_argword(arg, argtype, self.fnname, i, argname, to_preserve)
            for i, (arg, argtype, argname) in enumerate(
                zip(args, self.argtypes, self.argnames))]
        retwords = childprocess.funcall(self.fnname, argwords)
        retvals = make_retvals(self.rettypes, retwords)
        if len(retvals) == 0:
            return None
        if len(retvals) == 1:
            return retvals[0]
        return tuple(retvals)


def _lex_testcrypt_header(header):
    """Tokenise the text of testcrypt-func.h.

    Yields (token, byte_offset) pairs, where a token is '(', ')', ',' or
    an identifier. Whitespace and C/C++ comments are skipped. Exactly one
    final ('', end) pair marks the end of input, even if the text ends
    directly on a token.
    """
    pat = re.compile(
        # Skip any combination of whitespace and comments
        '(?:{})*'.format('|'.join((
            r'[ \t\n]',           # whitespace
            r'/\*(?:.|\n)*?\*/',  # C90-style /* ... */ comment, ended eagerly
            r'//[^\n]*\n?',       # C99-style comment to end-of-line (or EOF)
        ))) +
        # And then match a token
        '({})'.format('|'.join((
            # Punctuation
            r'\(',
            r'\)',
            ',',
            # Identifier
            r'[A-Za-z_][A-Za-z0-9_]*',
            # End of string
            '$',
        )))
    )

    pos = 0
    end = len(header)
    while True:
        m = pat.match(header, pos)
        assert m is not None, (
            "Failed to lex testcrypt-func.h at byte position {:d}".format(pos))
        pos = m.end()
        tok = m.group(1)
        if len(tok) == 0:
            assert pos == end, (
                "Empty token should only be returned at end of string")
            yield tok, m.start(1)
            return
        yield tok, m.start(1)


def _parse_testcrypt_header(tokens):
    """Parse FUNC/FUNC_WRAPPED declarations from a token stream.

    Yields (funcname, rettype, args) where args is a list of
    (argtype, argname) pairs. Exits the program with a diagnostic on a
    syntax error.
    """
    def is_id(tok):
        # The empty end-of-input token is not an identifier.
        return len(tok) > 0 and tok[0] in string.ascii_letters+"_"

    def expect(what, why, eof_ok=False):
        # *what* is a predicate, a set of acceptable tokens, or a single
        # token. Returns the token, or None at end of input if *eof_ok*.
        tok, pos = next(tokens)
        if tok == '' and eof_ok:
            return None
        if callable(what):
            description = lambda: ""
            ok = what(tok)
        elif isinstance(what, set):
            description = lambda: " or ".join(
                "'"+x+"'" for x in sorted(what)) + " "
            ok = tok in what
        else:
            description = lambda: "'"+what+"' "
            ok = tok == what
        if not ok:
            # NOTE: pos is a byte offset into the header, not a line number.
            sys.exit("testcrypt-func.h:{:d}: expected {}{}".format(
                pos, description(), why))
        return tok

    while True:
        tok = expect({"FUNC", "FUNC_WRAPPED"},
                     "at start of function specification", eof_ok=True)
        if tok is None:
            break

        expect("(", "after FUNC")
        rettype = expect(is_id, "return type")
        expect(",", "after return type")
        funcname = expect(is_id, "function name")
        expect(",", "after function name")
        args = []
        firstargkind = expect({"ARG", "VOID"}, "at start of argument list")
        if firstargkind == "VOID":
            expect(")", "after VOID")
        else:
            while True:
                # Every time we come back to the top of this loop, we've
                # just seen 'ARG'
                expect("(", "after ARG")
                argtype = expect(is_id, "argument type")
                expect(",", "after argument type")
                argname = expect(is_id, "argument name")
                args.append((argtype, argname))
                expect(")", "at end of ARG")
                punct = expect({",", ")"}, "after argument")
                if punct == ")":
                    break
                expect("ARG", "to begin next argument")
        yield funcname, rettype, args


def _setup(scope):
    """Create a Function for every declaration in testcrypt-func.h.

    Each Function is stored in *scope* under its C name. It is also
    registered in method_lists when its first argument's type has a
    matching prefix in method_prefixes.
    """
    valprefix = "val_"
    outprefix = "out_"
    optprefix = "opt_"
    consprefix = "consumed_"

    def trim_argtype(arg):
        if arg.startswith(optprefix):
            return optprefix + trim_argtype(arg[len(optprefix):])

        if (arg.startswith(valprefix) and
                "_" in arg[len(valprefix):]):
            # Strip suffixes like val_string_asciz
            arg = arg[:arg.index("_", len(valprefix))]
        return arg

    with open(os.path.join(putty_srcdir, "test", "testcrypt-func.h")) as f:
        header = f.read()
    tokens = _lex_testcrypt_header(header)
    for function, rettype, arglist in _parse_testcrypt_header(tokens):
        rettypes = []
        retnames = []
        if rettype != "void":
            rettypes.append(trim_argtype(rettype))
            retnames.append(None)

        argtypes = []
        argnames = []
        for arg, argname in arglist:
            if arg.startswith(outprefix):
                # Output parameters become extra return values.
                rettypes.append(trim_argtype(arg[len(outprefix):]))
                retnames.append(argname)
            else:
                consumed = False
                if arg.startswith(consprefix):
                    arg = arg[len(consprefix):]
                    consumed = True
                arg = trim_argtype(arg)
                argtypes.append((arg, consumed))
                argnames.append(argname)

        func = Function(function, rettypes, retnames,
                        argtypes, argnames)
        scope[function] = func
        if len(argtypes) > 0:
            t = argtypes[0][0]
            if t in method_prefixes:
                for prefix in method_prefixes[t]:
                    if function.startswith(prefix):
                        methodname = function[len(prefix):]
                        method_lists[t].append((methodname, func))
                        break

_setup(globals())
del _setup