1
0
mirror of https://git.tartarus.org/simon/putty.git synced 2025-01-25 09:12:24 +00:00
putty-source/test/agenttest.py

253 lines
7.6 KiB
Python
Raw Normal View History

#!/usr/bin/python3
import sys
import os
import socket
import base64
import itertools
import collections
from ssh import *
import agenttestdata
# Fixed session id that Key1.Use() embeds in every SSH-1 challenge request.
test_session_id = b'Test16ByteSessId'
# Guard against accidental edits changing the length of the id above.
assert len(test_session_id) == 16
# Payload that Key2.Use() asks the agent to sign in every SSH-2 sign request.
test_message_to_sign = b'test message to sign'
# One expected signing result: `flags` is the value passed to Key2.Use() and
# `sig` is the answer the test expects back for it.
TestSig2 = collections.namedtuple("TestSig2", "flags sig")
class Key2(collections.namedtuple("Key2", "comment public sigs openssh")):
    """An SSH-2 test key plus the agent requests that exercise it.

    Fields:
      comment -- key comment (bytes; sent to the agent as an SSH string)
      public  -- public key blob (sent to the agent as an SSH string)
      sigs    -- iterable of TestSig2-style records, or None for a
                 public-only key
      openssh -- data spliced verbatim into the add-identity request after
                 the algorithm name, or None for a public-only key
    """

    # Word used by TestRunner.sign_test when describing a Use() call.
    verb = "sign"

    def public_only(self):
        """Return a copy of this key with only comment and public set."""
        return Key2(comment=self.comment, public=self.public,
                    sigs=None, openssh=None)

    def Add(self):
        """Ask the agent to add this key; return the agent's raw reply."""
        # The algorithm name is the first string inside the public blob.
        algorithm = ssh_decode_string(self.public)
        request = ssh_byte(SSH2_AGENTC_ADD_IDENTITY)
        request += ssh_string(algorithm)
        request += self.openssh
        request += ssh_string(self.comment)
        return agent_query(request)

    def Use(self, flags):
        """Ask the agent to sign the test message; return the signature.

        When flags is None the request carries no flags field at all;
        otherwise flags is appended as a uint32.
        """
        request = ssh_byte(SSH2_AGENTC_SIGN_REQUEST)
        request += ssh_string(self.public)
        request += ssh_string(test_message_to_sign)
        if flags is not None:
            request += ssh_uint32(flags)
        reply = agent_query(request)
        reply_type, reply = ssh_decode_byte(reply, True)
        assert reply_type == SSH2_AGENT_SIGN_RESPONSE
        signature, trailing = ssh_decode_string(reply, True)
        # Nothing may follow the signature string in the reply.
        assert len(trailing) == 0
        return signature

    def Del(self):
        """Ask the agent to remove this key; return the agent's raw reply."""
        request = ssh_byte(SSH2_AGENTC_REMOVE_IDENTITY)
        request += ssh_string(self.public)
        return agent_query(request)

    @staticmethod
    def DelAll():
        """Ask the agent to remove every SSH-2 key; return its raw reply."""
        return agent_query(ssh_byte(SSH2_AGENTC_REMOVE_ALL_IDENTITIES))

    @staticmethod
    def List():
        """Return the agent's SSH-2 keys as a list of public-only Key2s."""
        reply = agent_query(ssh_byte(SSH2_AGENTC_REQUEST_IDENTITIES))
        reply_type, reply = ssh_decode_byte(reply, True)
        assert reply_type == SSH2_AGENT_IDENTITIES_ANSWER
        count, reply = ssh_decode_uint32(reply, True)
        found = []
        for _ in range(count):
            # Each entry is a public blob followed by its comment.
            blob, reply = ssh_decode_string(reply, True)
            comment, reply = ssh_decode_string(reply, True)
            found.append(Key2(comment, blob, None, None))
        # The reply must be fully consumed by the announced entries.
        assert len(reply) == 0
        return found

    @classmethod
    def make_examples(cls):
        """Load the SSH-2 example keys into cls.examples."""
        cls.examples = agenttestdata.key2examples(cls, TestSig2)

    def iter_testsigs(self):
        """Yield every signature record to test for this key.

        A record whose flags are 0 is preceded by a copy with flags=None,
        so the request is tried both without a flags field and with an
        explicit zero.
        """
        for record in self.sigs:
            variants = [record]
            if record.flags == 0:
                variants.insert(0, record._replace(flags=None))
            yield from variants

    def iter_tests(self):
        """Yield (Use() arguments, description suffix, expected signature)."""
        for record in self.iter_testsigs():
            suffix = " (flags={})".format(record.flags)
            yield [record.flags], suffix, record.sig
class Key1(collections.namedtuple(
        "Key1", "comment public challenge response private")):
    """An SSH-1 RSA test key plus the agent requests that exercise it.

    Fields:
      comment   -- key comment (bytes; sent to the agent as an SSH string)
      public    -- public key data, spliced verbatim into requests
      challenge -- value passed to Use() by iter_tests, or None
      response  -- answer expected from Use(challenge), or None
      private   -- data spliced verbatim into the add-identity request,
                   or None for a public-only key
    """

    # Word used by TestRunner.sign_test when describing a Use() call.
    verb = "decrypt"

    def public_only(self):
        """Return a copy of this key with only comment and public set."""
        return Key1(comment=self.comment, public=self.public,
                    challenge=None, response=None, private=None)

    def Add(self):
        """Ask the agent to add this key; return the agent's raw reply."""
        request = ssh_byte(SSH1_AGENTC_ADD_RSA_IDENTITY)
        request += self.private
        request += ssh_string(self.comment)
        return agent_query(request)

    def Use(self, challenge):
        """Send an RSA challenge to the agent and return its 16-byte answer."""
        request = ssh_byte(SSH1_AGENTC_RSA_CHALLENGE)
        request += self.public
        request += ssh1_mpint(challenge)
        request += test_session_id
        # Trailing uint32 is always 1 -- presumably the response-type field
        # of the challenge request; confirm against the agent protocol.
        request += ssh_uint32(1)
        reply = agent_query(request)
        reply_type, answer = ssh_decode_byte(reply, True)
        assert reply_type == SSH1_AGENT_RSA_RESPONSE
        # Everything after the type byte is the answer: exactly 16 bytes.
        assert len(answer) == 16
        return answer

    def Del(self):
        """Ask the agent to remove this key; return the agent's raw reply."""
        request = ssh_byte(SSH1_AGENTC_REMOVE_RSA_IDENTITY)
        request += self.public
        return agent_query(request)

    @staticmethod
    def DelAll():
        """Ask the agent to remove every SSH-1 key; return its raw reply."""
        return agent_query(ssh_byte(SSH1_AGENTC_REMOVE_ALL_RSA_IDENTITIES))

    @staticmethod
    def List():
        """Return the agent's SSH-1 keys as a list of public-only Key1s."""
        reply = agent_query(ssh_byte(SSH1_AGENTC_REQUEST_RSA_IDENTITIES))
        reply_type, reply = ssh_decode_byte(reply, True)
        assert reply_type == SSH1_AGENT_RSA_IDENTITIES_ANSWER
        count, reply = ssh_decode_uint32(reply, True)
        found = []
        for _ in range(count):
            # Each entry: uint32, two mpints, then the comment string.
            bits, reply = ssh_decode_uint32(reply, True)
            exponent, reply = ssh1_get_mpint(reply, True)
            modulus, reply = ssh1_get_mpint(reply, True)
            comment, reply = ssh_decode_string(reply, True)
            # Rebuild the public data in the same layout as Key1.public.
            public = ssh_uint32(bits) + exponent + modulus
            found.append(Key1(comment, public, None, None, None))
        # The reply must be fully consumed by the announced entries.
        assert len(reply) == 0
        return found

    @classmethod
    def make_examples(cls):
        """Load the SSH-1 example keys into cls.examples."""
        cls.examples = agenttestdata.key1examples(cls)

    def iter_tests(self):
        """Yield the single (Use() arguments, suffix, expected answer) test."""
        arguments = [self.challenge]
        yield arguments, "", self.response
def agent_query(msg):
    """Send one request to the SSH agent and return the body of its reply.

    msg is the request without its length prefix; it is wrapped with
    ssh_string() before being sent. A fresh connection to the Unix socket
    named by the SSH_AUTH_SOCK environment variable is opened for the
    query and closed again before returning.

    Returns the reply body (the bytes following the 4-byte length prefix).
    Raises KeyError if SSH_AUTH_SOCK is unset and OSError if the socket
    cannot be connected.
    """
    request = ssh_string(msg)
    # The context manager closes the socket even if an assertion or a
    # socket error interrupts the exchange.
    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
        s.connect(os.environ["SSH_AUTH_SOCK"])
        # sendall() keeps writing until the whole request has gone out;
        # a bare send() may transmit only a prefix of it.
        s.sendall(request)
        length = ssh_decode_uint32(_recv_until(s, 4))
        assert length < AGENT_MAX_MSGLEN
        return _recv_until(s, length)


def _recv_until(sock, nbytes):
    """Read nbytes from sock, returning fewer only if the peer closes."""
    chunks = []
    remaining = nbytes
    while remaining > 0:
        # recv() may legitimately return fewer bytes than requested.
        chunk = sock.recv(remaining)
        if not chunk:
            # EOF: hand back what arrived and let the caller's decoding
            # or assertions report the truncated reply.
            break
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)
def enumerate_bits(iterable):
    """Pair each item of iterable with its own single-bit mask.

    Lazily yields (1 << j, item) tuples, where j is the item's position
    in iterable counting from zero.
    """
    masks = (1 << position for position in itertools.count())
    return zip(masks, iterable)
def gray_code(nbits):
    """Walk a closed Gray-code cycle over all nbits-bit values.

    Yields (old, new, diff) tuples: the value before a step, the value
    after it, and the single bit that changed (old ^ new). The walk
    starts at 0, visits every nbits-bit value, and its last step returns
    to 0. For nbits == 0 there is only one state and no step between
    states, so nothing is yielded.
    """
    if nbits == 0:
        # Without this guard the only step would be 0 -> 0, which the
        # one-bit-changed assertion below rejects.
        return
    old = 0
    # Counting 1 .. 2**nbits - 1 and then 0 closes the cycle back at zero.
    for i in itertools.chain(range(1, 1 << nbits), [0]):
        new = i ^ (i >> 1)
        diff = new ^ old
        # Consecutive values must differ in exactly one bit.
        assert diff != 0 and (diff & (diff - 1)) == 0
        yield old, new, diff
        old = new
    assert old == 0
class TestRunner:
    """Run the agent tests and remember whether any of them failed.

    Each check prints one line describing what was attempted and how it
    turned out. A failed check prints a line starting with "FAIL!" and
    sets self.ok to False; the run carries on so that later checks are
    still reported.
    """

    def __init__(self):
        # Overall verdict: stays True until some check fails.
        self.ok = True

    @staticmethod
    def fmt_response(response):
        """Render a bytes response as single-quoted, single-line base64."""
        return "'{}'".format(
            base64.encodebytes(response).decode("ASCII").replace("\n", ""))

    @staticmethod
    def fmt_keylist(keys):
        """Render keys as "{c1,c2,...}" using their comments, sorted by key."""
        return "{{{}}}".format(
            ",".join(key.comment.decode("ASCII") for key in sorted(keys)))

    def expect_success(self, text, response):
        """Check that an agent reply is exactly the success byte.

        text describes the operation for the log line. A reply that is the
        failure byte, or anything else, is reported and marks the run as
        failed.
        """
        if response == ssh_byte(SSH_AGENT_SUCCESS):
            print(text, "=> success")
        elif response == ssh_byte(SSH_AGENT_FAILURE):
            print("FAIL!", text, "=> failure")
            self.ok = False
        else:
            print("FAIL!", text, "=>", self.fmt_response(response))
            self.ok = False

    def check_keylist(self, K, expected_keys):
        """List the agent's keys of class K and compare with expected_keys.

        The comparison ignores order; a mismatch is reported and marks the
        run as failed.
        """
        keys = K.List()
        print("list keys =>", self.fmt_keylist(keys))
        if set(keys) != set(expected_keys):
            print("FAIL! Should have been", self.fmt_keylist(expected_keys))
            self.ok = False

    def gray_code_test(self, K):
        """Add and delete K's example keys one at a time in Gray-code order.

        Every subset of K.examples is loaded into the agent at some point,
        and after each single insertion or deletion the agent's key list is
        checked against the subset that should now be present.
        """
        bks = list(enumerate_bits(K.examples))

        # The agent is expected to hold no keys of this class to begin with.
        self.check_keylist(K, [])

        for old, new, diff in gray_code(len(K.examples)):
            # diff has exactly one bit set; find the key owning that bit.
            bit, key = next((bit, key) for bit, key in bks if diff & bit)

            if new & bit:
                self.expect_success("insert " + key.comment.decode("ASCII"),
                                    key.Add())
            else:
                self.expect_success("delete " + key.comment.decode("ASCII"),
                                    key.Del())

            self.check_keylist(K, [key.public_only() for bit, key in bks
                                   if new & bit])

    def sign_test(self, K):
        """Use each example key of class K and compare against expectations.

        For every test of every key, the key is added to the agent, used,
        and removed again; the answer is compared with the expected one and
        a mismatch marks the run as failed.
        """
        for key in K.examples:
            for params, message, expected_answer in key.iter_tests():
                key.Add()
                try:
                    actual_answer = key.Use(*params)
                finally:
                    # Remove the key even when Use() raises, so that an
                    # aborted run does not leave it loaded in the agent
                    # for whatever runs next.
                    key.Del()
                record = "{} with {}{}".format(
                    K.verb, key.comment.decode("ASCII"), message)
                if actual_answer == expected_answer:
                    print(record, "=> success")
                else:
                    print("FAIL!", record, "=> {} but expected {}".format(
                        self.fmt_response(actual_answer),
                        self.fmt_response(expected_answer)))
                    self.ok = False

    def run(self):
        """Run every test against the agent, for SSH-2 keys then SSH-1."""
        self.expect_success("init: delete all ssh2 keys", Key2.DelAll())

        for K in [Key2, Key1]:
            self.gray_code_test(K)
            self.sign_test(K)

        # TODO: negative tests of all kinds.
def main():
    """Load the example keys, run every test, and report the verdict.

    Prints "Test run passed" when no check failed; otherwise exits through
    sys.exit() with a failure message.
    """
    for key_class in (Key2, Key1):
        key_class.make_examples()

    runner = TestRunner()
    runner.run()

    if not runner.ok:
        sys.exit("Test run failed!")
    print("Test run passed")


if __name__ == "__main__":
    main()