1
0
mirror of https://git.tartarus.org/simon/putty.git synced 2025-01-25 01:02:24 +00:00

cryptsuite.py: a couple more helper functions.

I've moved the static method nbits up into a top-level function, so I
can use it to implement Python marshalling functions for SSH mpints.
I'm about to need one of these, and the other will surely come in
useful as well sooner or later.
This commit is contained in:
Simon Tatham 2019-01-05 08:21:30 +00:00
parent be779f988d
commit e5e520d48e

View File

@ -14,8 +14,31 @@ except ImportError:
from eccref import *
from testcrypt import *
def nbits(n):
# Mimic mp_get_nbits for ordinary Python integers.
assert 0 <= n
smax = next(s for s in itertools.count() if (n >> (1 << s)) == 0)
toret = 0
for shift in reversed([1 << s for s in range(smax)]):
if n >> shift != 0:
n >>= shift
toret += shift
assert n <= 1
if n == 1:
toret += 1
return toret
def ssh_uint32(n):
return struct.pack(">L", n)
def ssh_string(s):
return struct.pack(">L", len(s)) + s
return ssh_uint32(len(s)) + s
def ssh1_mpint(x):
bits = nbits(x)
bytevals = [0xFF & (x >> (8*n)) for n in range((bits-1)//8, -1, -1)]
return struct.pack(">H" + "B" * len(bytevals), bits, *bytevals)
def ssh2_mpint(x):
bytevals = [0xFF & (x >> (8*n)) for n in range(nbits(x)//8, -1, -1)]
return struct.pack(">L" + "B" * len(bytevals), len(bytevals), *bytevals)
def find_non_square_mod(p):
# Find a non-square mod p, using the Jacobi symbol
@ -293,21 +316,6 @@ class mpint(unittest.TestCase):
bm = mp_copy(bi)
self.assertEqual(int(mp_mul(am, bm)), ai * bi)
@staticmethod
def nbits(n):
# Mimic mp_get_nbits for ordinary Python integers.
assert 0 <= n
smax = next(s for s in itertools.count() if (n >> (1 << s)) == 0)
toret = 0
for shift in reversed([1 << s for s in range(smax)]):
if n >> shift != 0:
n >>= shift
toret += shift
assert n <= 1
if n == 1:
toret += 1
return toret
def testDivision(self):
divisors = [1, 2, 3, 2**16+1, 2**32-1, 2**32+1, 2**128-159,
141421356237309504880168872420969807856967187537694807]
@ -319,8 +327,8 @@ class mpint(unittest.TestCase):
if r >= d:
continue # silly cases with tiny divisors
n = q*d + r
mq = mp_new(self.nbits(q))
mr = mp_new(self.nbits(r))
mq = mp_new(nbits(q))
mr = mp_new(nbits(r))
mp_divmod_into(n, d, mq, mr)
self.assertEqual(int(mq), q)
self.assertEqual(int(mr), r)
@ -441,7 +449,7 @@ class mpint(unittest.TestCase):
for p in moduli:
# Count the factors of 2 in the group. (That is, we want
# p-1 to be an odd multiple of 2^{factors_of_2}.)
factors_of_2 = self.nbits((p-1) & (1-p)) - 1
factors_of_2 = nbits((p-1) & (1-p)) - 1
assert (p & ((2 << factors_of_2)-1)) == ((1 << factors_of_2)+1)
z = find_non_square_mod(p)
@ -459,7 +467,7 @@ class mpint(unittest.TestCase):
self.assertFalse(success)
# Make up some more or less random values mod p to square
v1 = pow(3, self.nbits(p), p)
v1 = pow(3, nbits(p), p)
v2 = pow(5, v1, p)
test_roots = [0, 1, 2, 3, 4, 3*p//4, v1, v2, v1+1, 12873*v1, v1*v2]
known_squares = {r*r % p for r in test_roots}