mirror of
https://git.tartarus.org/simon/putty.git
synced 2025-01-25 01:02:24 +00:00
cryptsuite.py: a couple more helper functions.
I've moved the static method nbits up into a top-level function, so I can use it to implement Python marshalling functions for SSH mpints. I'm about to need one of these, and the other will surely come in useful as well sooner or later.
This commit is contained in:
parent
be779f988d
commit
e5e520d48e
@ -14,8 +14,31 @@ except ImportError:
|
||||
from eccref import *
|
||||
from testcrypt import *
|
||||
|
||||
def nbits(n):
|
||||
# Mimic mp_get_nbits for ordinary Python integers.
|
||||
assert 0 <= n
|
||||
smax = next(s for s in itertools.count() if (n >> (1 << s)) == 0)
|
||||
toret = 0
|
||||
for shift in reversed([1 << s for s in range(smax)]):
|
||||
if n >> shift != 0:
|
||||
n >>= shift
|
||||
toret += shift
|
||||
assert n <= 1
|
||||
if n == 1:
|
||||
toret += 1
|
||||
return toret
|
||||
|
||||
def ssh_uint32(n):
|
||||
return struct.pack(">L", n)
|
||||
def ssh_string(s):
|
||||
return struct.pack(">L", len(s)) + s
|
||||
return ssh_uint32(len(s)) + s
|
||||
def ssh1_mpint(x):
|
||||
bits = nbits(x)
|
||||
bytevals = [0xFF & (x >> (8*n)) for n in range((bits-1)//8, -1, -1)]
|
||||
return struct.pack(">H" + "B" * len(bytevals), bits, *bytevals)
|
||||
def ssh2_mpint(x):
|
||||
bytevals = [0xFF & (x >> (8*n)) for n in range(nbits(x)//8, -1, -1)]
|
||||
return struct.pack(">L" + "B" * len(bytevals), len(bytevals), *bytevals)
|
||||
|
||||
def find_non_square_mod(p):
|
||||
# Find a non-square mod p, using the Jacobi symbol
|
||||
@ -293,21 +316,6 @@ class mpint(unittest.TestCase):
|
||||
bm = mp_copy(bi)
|
||||
self.assertEqual(int(mp_mul(am, bm)), ai * bi)
|
||||
|
||||
@staticmethod
|
||||
def nbits(n):
|
||||
# Mimic mp_get_nbits for ordinary Python integers.
|
||||
assert 0 <= n
|
||||
smax = next(s for s in itertools.count() if (n >> (1 << s)) == 0)
|
||||
toret = 0
|
||||
for shift in reversed([1 << s for s in range(smax)]):
|
||||
if n >> shift != 0:
|
||||
n >>= shift
|
||||
toret += shift
|
||||
assert n <= 1
|
||||
if n == 1:
|
||||
toret += 1
|
||||
return toret
|
||||
|
||||
def testDivision(self):
|
||||
divisors = [1, 2, 3, 2**16+1, 2**32-1, 2**32+1, 2**128-159,
|
||||
141421356237309504880168872420969807856967187537694807]
|
||||
@ -319,8 +327,8 @@ class mpint(unittest.TestCase):
|
||||
if r >= d:
|
||||
continue # silly cases with tiny divisors
|
||||
n = q*d + r
|
||||
mq = mp_new(self.nbits(q))
|
||||
mr = mp_new(self.nbits(r))
|
||||
mq = mp_new(nbits(q))
|
||||
mr = mp_new(nbits(r))
|
||||
mp_divmod_into(n, d, mq, mr)
|
||||
self.assertEqual(int(mq), q)
|
||||
self.assertEqual(int(mr), r)
|
||||
@ -441,7 +449,7 @@ class mpint(unittest.TestCase):
|
||||
for p in moduli:
|
||||
# Count the factors of 2 in the group. (That is, we want
|
||||
# p-1 to be an odd multiple of 2^{factors_of_2}.)
|
||||
factors_of_2 = self.nbits((p-1) & (1-p)) - 1
|
||||
factors_of_2 = nbits((p-1) & (1-p)) - 1
|
||||
assert (p & ((2 << factors_of_2)-1)) == ((1 << factors_of_2)+1)
|
||||
|
||||
z = find_non_square_mod(p)
|
||||
@ -459,7 +467,7 @@ class mpint(unittest.TestCase):
|
||||
self.assertFalse(success)
|
||||
|
||||
# Make up some more or less random values mod p to square
|
||||
v1 = pow(3, self.nbits(p), p)
|
||||
v1 = pow(3, nbits(p), p)
|
||||
v2 = pow(5, v1, p)
|
||||
test_roots = [0, 1, 2, 3, 4, 3*p//4, v1, v2, v1+1, 12873*v1, v1*v2]
|
||||
known_squares = {r*r % p for r in test_roots}
|
||||
|
Loading…
Reference in New Issue
Block a user