mirror of
https://git.tartarus.org/simon/putty.git
synced 2025-01-10 09:58:01 +00:00
8c7b0a787f
Well, actually, two new test programs. agenttest.py is the actual test; it depends on agenttestgen.py, which generates a collection of test private keys using the newly exposed testcrypt interface to our key generation code. In this commit I've also factored out some Python SSH marshalling code from cryptsuite and moved it into a module, ssh.py, which the agent tests can reuse.
253 lines
7.6 KiB
Python
Executable File
253 lines
7.6 KiB
Python
Executable File
#!/usr/bin/python3
|
|
|
|
import sys
|
|
import os
|
|
import socket
|
|
import base64
|
|
import itertools
|
|
import collections
|
|
|
|
from ssh import *
|
|
import agenttestdata
|
|
|
|
# Payload that every SSH-2 signing test asks the agent to sign
# (see Key2.Use).
test_message_to_sign = b'test message to sign'

# Session ID appended to every SSH-1 RSA challenge (see Key1.Use). It is
# concatenated into the request without a length prefix, so it must be
# exactly 16 bytes long.
test_session_id = b'Test16ByteSessId'
assert len(test_session_id) == 16

# One expected SSH-2 signature test case: the flags word to send with the
# sign request (None means the field is omitted) and the signature the
# agent is expected to return.
TestSig2 = collections.namedtuple("TestSig2", ["flags", "sig"])
|
|
|
|
class Key2(collections.namedtuple("Key2", "comment public sigs openssh")):
    """An SSH-2 test key, plus the agent operations that act on it.

    Fields:
      comment -- key comment (bytes)
      public  -- public key blob; its first string field is the algorithm
                 name (Add decodes it from there)
      sigs    -- iterable of TestSig2 expected signatures, or None
      openssh -- bytes placed between the algorithm name and the comment in
                 an ADD_IDENTITY request, or None (presumably the private key
                 fields in OpenSSH wire format -- TODO confirm)

    The agent-facing methods return the agent's raw response bytes as
    produced by agent_query, except where noted.
    """

    def public_only(self):
        """Return a copy keeping only the comment and public blob.

        This matches the shape of the entries that List() builds.
        """
        return Key2(comment=self.comment, public=self.public,
                    sigs=None, openssh=None)

    def Add(self):
        """Ask the agent to load this key; return its raw response."""
        algorithm = ssh_decode_string(self.public)
        request = b"".join([
            ssh_byte(SSH2_AGENTC_ADD_IDENTITY),
            ssh_string(algorithm),
            self.openssh,
            ssh_string(self.comment),
        ])
        return agent_query(request)

    # Describes what Use() does, for the test log.
    verb = "sign"

    def Use(self, flags):
        """Ask the agent to sign test_message_to_sign with this key.

        flags -- value for the trailing uint32 flags field, or None to
                 leave that field out of the request entirely.

        Returns the signature string from the agent's SIGN_RESPONSE. Fails
        an assertion if the reply has the wrong type or trailing data.
        """
        request = (ssh_byte(SSH2_AGENTC_SIGN_REQUEST)
                   + ssh_string(self.public)
                   + ssh_string(test_message_to_sign))
        if flags is not None:
            request = request + ssh_uint32(flags)
        reply = agent_query(request)
        msgtype, reply = ssh_decode_byte(reply, True)
        assert msgtype == SSH2_AGENT_SIGN_RESPONSE
        signature, leftover = ssh_decode_string(reply, True)
        assert len(leftover) == 0
        return signature

    def Del(self):
        """Ask the agent to unload this key; return its raw response."""
        return agent_query(ssh_byte(SSH2_AGENTC_REMOVE_IDENTITY)
                           + ssh_string(self.public))

    @staticmethod
    def DelAll():
        """Ask the agent to unload every SSH-2 key; return its response."""
        return agent_query(ssh_byte(SSH2_AGENTC_REMOVE_ALL_IDENTITIES))

    @staticmethod
    def List():
        """Return the agent's SSH-2 keys as public-only Key2 records."""
        reply = agent_query(ssh_byte(SSH2_AGENTC_REQUEST_IDENTITIES))
        msgtype, reply = ssh_decode_byte(reply, True)
        assert msgtype == SSH2_AGENT_IDENTITIES_ANSWER
        count, reply = ssh_decode_uint32(reply, True)
        keys = []
        for _index in range(count):
            blob, reply = ssh_decode_string(reply, True)
            comment, reply = ssh_decode_string(reply, True)
            keys.append(Key2(comment, blob, None, None))
        # The answer must contain exactly `count` entries and nothing more.
        assert len(reply) == 0
        return keys

    @classmethod
    def make_examples(cls):
        """Populate cls.examples with the generated SSH-2 test keys."""
        cls.examples = agenttestdata.key2examples(cls, TestSig2)

    def iter_testsigs(self):
        """Yield this key's expected signatures, in order.

        A test case with flags == 0 is yielded twice: first with flags
        replaced by None (so the request omits the flags field), then
        as-is. Both variants expect the same signature.
        """
        for expected in self.sigs:
            if expected.flags != 0:
                yield expected
                continue
            yield expected._replace(flags=None)
            yield expected

    def iter_tests(self):
        """Yield (Use() arguments, log suffix, expected signature) triples."""
        for expected in self.iter_testsigs():
            label = " (flags={})".format(expected.flags)
            yield [expected.flags], label, expected.sig
|
|
|
|
class Key1(collections.namedtuple(
        "Key1", "comment public challenge response private")):
    """An SSH-1 RSA test key, plus the agent operations that act on it.

    Fields:
      comment   -- key comment (bytes)
      public    -- public key in the form List() rebuilds it: a uint32 bit
                   count followed by the two values read with ssh1_get_mpint
                   (presumably exponent then modulus -- TODO confirm)
      challenge -- challenge value passed to ssh1_mpint by Use(), or None
      response  -- 16-byte answer the agent should return, or None
      private   -- bytes placed between the message type and the comment in
                   an ADD_RSA_IDENTITY request, or None

    The agent-facing methods return the agent's raw response bytes as
    produced by agent_query, except where noted.
    """

    def public_only(self):
        """Return a copy keeping only the comment and public key.

        This matches the shape of the entries that List() builds.
        """
        return Key1(comment=self.comment, public=self.public,
                    challenge=None, response=None, private=None)

    def Add(self):
        """Ask the agent to load this key; return its raw response."""
        request = b"".join([
            ssh_byte(SSH1_AGENTC_ADD_RSA_IDENTITY),
            self.private,
            ssh_string(self.comment),
        ])
        return agent_query(request)

    # Describes what Use() does, for the test log.
    verb = "decrypt"

    def Use(self, challenge):
        """Send an RSA challenge for this key; return the 16-byte answer.

        Fails an assertion if the reply has the wrong type or length.
        """
        request = b"".join([
            ssh_byte(SSH1_AGENTC_RSA_CHALLENGE),
            self.public,
            ssh1_mpint(challenge),
            test_session_id,
            # Trailing uint32 of 1; presumably the response-type field of
            # the SSH-1 challenge request -- TODO confirm against the spec.
            ssh_uint32(1),
        ])
        reply = agent_query(request)
        msgtype, body = ssh_decode_byte(reply, True)
        assert msgtype == SSH1_AGENT_RSA_RESPONSE
        assert len(body) == 16
        return body

    def Del(self):
        """Ask the agent to unload this key; return its raw response."""
        return agent_query(ssh_byte(SSH1_AGENTC_REMOVE_RSA_IDENTITY)
                           + self.public)

    @staticmethod
    def DelAll():
        """Ask the agent to unload every SSH-1 key; return its response."""
        return agent_query(ssh_byte(SSH1_AGENTC_REMOVE_ALL_RSA_IDENTITIES))

    @staticmethod
    def List():
        """Return the agent's SSH-1 keys as public-only Key1 records."""
        reply = agent_query(ssh_byte(SSH1_AGENTC_REQUEST_RSA_IDENTITIES))
        msgtype, reply = ssh_decode_byte(reply, True)
        assert msgtype == SSH1_AGENT_RSA_IDENTITIES_ANSWER
        count, reply = ssh_decode_uint32(reply, True)
        keys = []
        for _index in range(count):
            bits, reply = ssh_decode_uint32(reply, True)
            exponent, reply = ssh1_get_mpint(reply, True)
            modulus, reply = ssh1_get_mpint(reply, True)
            comment, reply = ssh_decode_string(reply, True)
            # Reassemble the same byte layout the `public` field uses.
            public = ssh_uint32(bits) + exponent + modulus
            keys.append(Key1(comment, public, None, None, None))
        # The answer must contain exactly `count` entries and nothing more.
        assert len(reply) == 0
        return keys

    @classmethod
    def make_examples(cls):
        """Populate cls.examples with the generated SSH-1 test keys."""
        cls.examples = agenttestdata.key1examples(cls)

    def iter_tests(self):
        """Yield the single (Use() arguments, log suffix, answer) triple."""
        arguments = [self.challenge]
        yield arguments, "", self.response
|
|
|
|
def _recv_exactly(sock, nbytes):
    """Read exactly nbytes from sock, looping over short reads.

    socket.recv may legitimately return fewer bytes than requested, so a
    single call is not enough to read a whole message.

    Raises EOFError if the peer closes the connection before nbytes have
    arrived.
    """
    chunks = []
    remaining = nbytes
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            raise EOFError(
                "agent closed connection with {} of {} bytes unread".format(
                    remaining, nbytes))
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def agent_query(msg):
    """Send one request to the SSH agent and return its response body.

    msg -- the unframed request bytes. It is wrapped with ssh_string
           (presumably a uint32 length prefix, matching how the reply's
           length is read below) before being sent.

    Connects to the Unix-domain socket named by $SSH_AUTH_SOCK, using a
    fresh connection for each query. Returns the reply with its 4-byte
    length field stripped. Raises KeyError if SSH_AUTH_SOCK is unset and
    EOFError if the agent hangs up mid-reply. Asserts that the announced
    reply length is below AGENT_MAX_MSGLEN.
    """
    packet = ssh_string(msg)
    # The context manager closes the connection even if an assertion or
    # socket error escapes.
    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
        s.connect(os.environ["SSH_AUTH_SOCK"])
        # sendall retries until the whole packet is written.
        s.sendall(packet)
        length = ssh_decode_uint32(_recv_exactly(s, 4))
        assert length < AGENT_MAX_MSGLEN
        return _recv_exactly(s, length)
|
|
|
|
def enumerate_bits(iterable):
    """Pair each item with a distinct single-bit mask.

    The n-th item (counting from 0) is paired with 1 << n, giving
    (1, first), (2, second), (4, third), ... as a lazy iterator of
    (bit, item) tuples.
    """
    powers_of_two = (1 << position for position in itertools.count())
    return zip(powers_of_two, iterable)
|
|
|
|
def gray_code(nbits):
    """Walk a cyclic Gray code over nbits bits, flipping one bit per step.

    Yields (old, new, diff) triples, where old and new are consecutive
    Gray-code values and diff == old ^ new is the single bit that changed.
    The walk starts at 0, passes through every one of the 2**nbits values,
    and finishes by stepping back to 0, so it yields 2**nbits triples.
    With nbits == 0 there is no bit to flip, and nothing is yielded.
    """
    if nbits == 0:
        # Without this guard, the closing [0] step would produce
        # diff == 0 and fail the one-bit assertion below.
        return
    old = 0
    for i in itertools.chain(range(1, 1 << nbits), [0]):
        new = i ^ (i >> 1)  # binary-reflected Gray code of i
        diff = new ^ old
        # Consecutive codes must differ in exactly one bit.
        assert diff != 0 and (diff & (diff - 1)) == 0
        yield old, new, diff
        old = new
    assert old == 0
|
|
|
|
class TestRunner:
    """Drive the agent tests and record whether any of them failed.

    Every check prints one progress line. A failed check prints a line
    starting with "FAIL!" and clears self.ok, but the run carries on so
    that all failures get reported.
    """

    def __init__(self):
        # Cleared by any failed check; main() turns it into the exit status.
        self.ok = True

    @staticmethod
    def fmt_response(response):
        """Format raw response bytes as single-quoted, single-line base64."""
        # b64encode never inserts line breaks, so no newline stripping is
        # needed (unlike base64.encodebytes).
        return "'{}'".format(base64.b64encode(response).decode("ASCII"))

    @staticmethod
    def fmt_keylist(keys):
        """Format keys as "{comment1,comment2,...}", sorted by key."""
        return "{{{}}}".format(
            ",".join(key.comment.decode("ASCII") for key in sorted(keys)))

    def expect_success(self, text, response):
        """Log an operation that should have returned SSH_AGENT_SUCCESS.

        text     -- description of the operation, for the log line
        response -- the raw agent response to check
        """
        if response == ssh_byte(SSH_AGENT_SUCCESS):
            print(text, "=> success")
        elif response == ssh_byte(SSH_AGENT_FAILURE):
            print("FAIL!", text, "=> failure")
            self.ok = False
        else:
            print("FAIL!", text, "=>", self.fmt_response(response))
            self.ok = False

    def check_keylist(self, K, expected_keys):
        """Compare the agent's K-type key list with expected_keys.

        The comparison is as sets, so order does not matter.
        """
        keys = K.List()
        print("list keys =>", self.fmt_keylist(keys))
        if set(keys) != set(expected_keys):
            print("FAIL! Should have been", self.fmt_keylist(expected_keys))
            self.ok = False

    def gray_code_test(self, K):
        """Add and remove K's example keys in Gray-code order.

        Each step inserts or deletes exactly one key. After each step, the
        agent's key list is compared with the keys that should currently
        be loaded. The walk ends with every key deleted again.
        """
        bks = list(enumerate_bits(K.examples))

        self.check_keylist(K, [])

        for old, new, diff in gray_code(len(K.examples)):
            # diff has exactly one bit set; find the key that bit stands for.
            bit, key = next((bit, key) for bit, key in bks if diff & bit)

            if new & bit:
                self.expect_success("insert " + key.comment.decode("ASCII"),
                                    key.Add())
            else:
                self.expect_success("delete " + key.comment.decode("ASCII"),
                                    key.Del())

            self.check_keylist(K, [k.public_only() for b, k in bks
                                   if new & b])

    def sign_test(self, K):
        """Run every example key's test cases through K.Use.

        For each case, the key is loaded, used, unloaded, and the result is
        compared with the expected answer.
        """
        for key in K.examples:
            for params, message, expected_answer in key.iter_tests():
                # Add/Del results are not checked here; gray_code_test
                # exercises those operations.
                key.Add()
                actual_answer = key.Use(*params)
                key.Del()
                record = "{} with {}{}".format(
                    K.verb, key.comment.decode("ASCII"), message)
                if actual_answer == expected_answer:
                    print(record, "=> success")
                else:
                    print("FAIL!", record, "=> {} but expected {}".format(
                        self.fmt_response(actual_answer),
                        self.fmt_response(expected_answer)))
                    self.ok = False

    def run(self):
        """Run the whole suite for both SSH-2 and SSH-1 keys."""
        # gray_code_test starts by checking that the agent holds no keys of
        # the type under test, so clear out both protocol versions first.
        self.expect_success("init: delete all ssh2 keys", Key2.DelAll())
        self.expect_success("init: delete all ssh1 keys", Key1.DelAll())

        for K in [Key2, Key1]:
            self.gray_code_test(K)
            self.sign_test(K)

        # TODO: negative tests of all kinds.
|
|
|
|
def main():
    """Generate the example keys, run the suite, and set the exit status.

    Prints "Test run passed" on success; otherwise exits via sys.exit
    with a failure message.
    """
    for key_class in (Key2, Key1):
        key_class.make_examples()

    runner = TestRunner()
    runner.run()
    if not runner.ok:
        sys.exit("Test run failed!")
    print("Test run passed")


if __name__ == "__main__":
    main()
|