From db7a314c3810666eec8589262431d7055ca57880 Mon Sep 17 00:00:00 2001 From: Simon Tatham Date: Sat, 29 Feb 2020 09:48:00 +0000 Subject: [PATCH] testcrypt.py: fake some OO syntax. When I'm writing Python using the testcrypt API, I keep finding that I instinctively try to call vtable methods as if they were actual methods of the object. For example, calling key.sign(msg, 0) instead of ssh_key_sign(key, msg, 0). So this change to the Python side of the testcrypt mechanism panders to my inappropriate finger-macros by making them work! The idea is that I define a set of pairs (type, prefix), such that any function whose name begins with the prefix and whose first argument is of that type will be automatically translated into a method on the Python object wrapping a testcrypt value of that type. For example, any function of the form ssh_key_foo(val_ssh_key, other args) will automatically be exposed as a method key.foo(other args), simply because (val_ssh_key, "ssh_key_") appears in the translation table. This is particularly nice for the Python 3 REPL, which will let me tab-complete the right set of method names by knowing the type I'm trying to invoke one on. I haven't decided yet whether I want to switch to using it throughout cryptsuite.py. For namespace-cleanness, I've also renamed all the existing attributes of the Python Value class wrapper so that they start with '_', to leave the space of sensible names clear for the new OOish methods. --- test/testcrypt.py | 54 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/test/testcrypt.py b/test/testcrypt.py index 85e0b044..9fb35d2d 100644 --- a/test/testcrypt.py +++ b/test/testcrypt.py @@ -89,18 +89,37 @@ class ChildProcess(object): childprocess = ChildProcess() +method_prefixes = { + 'val_wpoint': 'ecc_weierstrass_', + 'val_mpoint': 'ecc_montgomery_', + 'val_epoint': 'ecc_edwards_', + 'val_hash': 'ssh_hash_', + 'val_mac': 'ssh_mac_', + 'val_key': 'ssh_key_', + 'val_cipher': 'ssh_cipher_', + 'val_dh': 'dh_', + 'val_ecdh': 'ssh_ecdhkex_', + 'val_rsakex': 'ssh_rsakex_', + 'val_prng': 'prng_', + 'val_pcs': 'pcs_', +} +method_lists = {t: [] for t in method_prefixes} + class Value(object): def __init__(self, typename, ident): - self.typename = typename - self.ident = ident - def consumed(self): - self.ident = None + self._typename = typename + self._ident = ident + for methodname, function in method_lists.get(self._typename, []): + setattr(self, methodname, + (lambda f: lambda *args: f(self, *args))(function)) + def _consumed(self): + self._ident = None def __repr__(self): - return "Value({!r}, {!r})".format(self.typename, self.ident) + return "Value({!r}, {!r})".format(self._typename, self._ident) def __del__(self): - if self.ident is not None: + if self._ident is not None: try: - childprocess.funcall("free", [self.ident]) + childprocess.funcall("free", [self._ident]) except ChildProcessFailure: # If we see this exception now, we can't do anything # about it, because exceptions don't propagate out of @@ -116,10 +135,10 @@ class Value(object): # crashed for some other reason.) pass def __long__(self): - if self.typename != "val_mpint": + if self._typename != "val_mpint": raise TypeError("testcrypt values of types other than mpint" " cannot be converted to integer") - hexval = childprocess.funcall("mp_dump", [self.ident])[0] + hexval = childprocess.funcall("mp_dump", [self._ident])[0] return 0 if len(hexval) == 0 else int(hexval, 16) def __int__(self): return int(self.__long__()) @@ -143,13 +162,13 @@ def make_argword(arg, argtype, fnname, argindex, to_preserve): arg = make_retvals([typename], retwords)[0] to_preserve.append(arg) if isinstance(arg, Value): - if arg.typename != typename: + if arg._typename != typename: raise TypeError( "{}() argument {:d} should be {} ({} given)".format( - fnname, argindex, typename, arg.typename)) - ident = arg.ident + fnname, argindex, typename, arg._typename)) + ident = arg._ident if consumed: - arg.consumed() + arg._consumed() return ident if typename == "uint" and isinstance(arg, numbers.Integral): return "0x{:x}".format(arg) @@ -278,7 +297,14 @@ def _setup(scope): consumed = True arg = trim_argtype(arg) argtypes.append((arg, consumed)) - scope[function] = Function(function, rettypes, argtypes) + func = Function(function, rettypes, argtypes) + scope[function] = func + if len(argtypes) > 0: + t = argtypes[0][0] + if (t in method_prefixes and + function.startswith(method_prefixes[t])): + methodname = function[len(method_prefixes[t]):] + method_lists[t].append((methodname, func)) _setup(globals()) del _setup