# Mirror of https://git.tartarus.org/simon/putty.git (synced 2025-01-25 01:02:24 +00:00)
# File: putty-source/test/numbertheory.py
# Last change: Simon Tatham, commit 7be2e16023, 2020-02-28 20:40:08 +00:00
#   "numbertheory.py: make the ModP class hashable.
#    That will let me keep them in sets."
# 182 lines, 5.7 KiB, Python

import numbers
import itertools
def invert(a, b):
    """Return the multiplicative inverse of a modulo b.

    Runs the extended Euclidean algorithm, tracking only the Bezout
    coefficient of a. a and b must be coprime; otherwise the final
    gcd is not +-1 and an AssertionError is raised.
    """
    prev_rem, rem = a, b
    prev_coef, coef = 1, 0
    while rem:
        quot = prev_rem // rem
        prev_rem, rem = rem, prev_rem - quot * rem
        prev_coef, coef = coef, prev_coef - quot * coef
    # prev_rem is now the gcd, up to sign.
    assert abs(prev_rem) == 1
    # Multiplying by the gcd's sign (+-1) turns coef*a == gcd into
    # coef*a == 1.
    return prev_coef * prev_rem % b
def jacobi(n, m):
    """Compute the Jacobi symbol (n/m).

    m should be a positive odd integer (only oddness is checked).

    The special case of this when m is prime is the Legendre symbol,
    which is 0 if n is congruent to 0 mod m; 1 if n is congruent to a
    non-zero square number mod m; -1 if n is not congruent to any
    square mod m.

    For m == 1 the result is 1 for every n, by the standard convention.
    More generally the result is 0 whenever gcd(n, m) > 1.
    """
    assert m & 1
    if m == 1:
        # (n/1) = 1 by definition. Without this, 'n %= m' below would
        # always give 0 and we would wrongly return 0. Inside the loop
        # m can never become 1: after a swap, m is an odd n > 1.
        return 1
    acc = 1
    while True:
        n %= m
        if n == 0:
            return 0
        # Remove factors of 2 from n: (2/m) is -1 exactly when
        # m is 3 or 5 mod 8.
        while not (n & 1):
            n >>= 1
            if (m & 7) not in {1, 7}:
                acc *= -1
        if n == 1:
            return acc
        # Quadratic reciprocity: swapping n and m flips the sign exactly
        # when both are 3 mod 4.
        if (n & 3) == 3 and (m & 3) == 3:
            acc *= -1
        n, m = m, n
class SqrtModP(object):
    """Class for finding square roots of numbers mod p.

    p must be an odd prime (but its primality is not checked).

    The constructor writes p - 1 as 2^e * k with k odd and finds a
    quadratic non-residue z together with its inverse mod p. sqrt()
    then uses these in a recursion in the style of Tonelli-Shanks.
    """

    def __init__(self, p):
        p = abs(p)
        assert p & 1
        # p == 1 would make k == 0 below, which never becomes odd, so
        # the decomposition loop would spin forever.
        assert p > 1
        self.p = p

        # Decompose p as 2^e k + 1 for odd k.
        self.k = p-1
        self.e = 0
        while not (self.k & 1):
            self.k >>= 1
            self.e += 1

        # Find a non-square mod p.
        for self.z in itertools.count(1):
            if jacobi(self.z, self.p) == -1:
                break
        # Keep the inverse as a plain int: sqrt_recurse passes it to the
        # three-argument builtin pow(), which requires integers. (ModP
        # has no invert() method, and its __pow__ takes no modulus.)
        self.zinv = invert(self.z, self.p)

    def sqrt_recurse(self, a):
        """Return some square root of a mod p (not normalised).

        a must be a non-zero square mod p; otherwise the assertion
        below fails.
        """
        # Square a^k repeatedly until it becomes 1. When the loop breaks
        # at index i, e - i squarings were needed, so a^k has
        # multiplicative order 2^(e-i).
        ak = pow(a, self.k, self.p)
        for i in range(self.e, -1, -1):
            if ak == 1:
                break
            ak = ak*ak % self.p
        # For a square, a^(k*2^(e-1)) = a^((p-1)/2) == 1, so at most
        # e-1 squarings are needed and i >= 1.
        assert i > 0
        if i == self.e:
            # a^k == 1, so a^((k+1)/2) squares to a^(k+1) == a.
            return pow(a, (self.k+1) // 2, self.p)
        # Multiply a by the square z^(2^i), which reduces the 2-power
        # order of a^k. Recurse on that, then divide the root we get
        # back by z^(2^(i-1)).
        r_prime = self.sqrt_recurse(a * pow(self.z, 2**i, self.p))
        return r_prime * pow(self.zinv, 2**(i-1), self.p) % self.p

    def sqrt(self, a):
        """Return the smaller of the two square roots of a mod p.

        Returns 0 when jacobi(a, p) is 0 (for prime p: when p divides
        a). Raises ValueError if a has no square root mod p.
        """
        j = jacobi(a, self.p)
        if j == 0:
            return 0
        if j < 0:
            raise ValueError("{} has no square root mod {}".format(a, self.p))
        a %= self.p
        r = self.sqrt_recurse(a)
        assert r*r % self.p == a
        # Normalise to the smaller (or 'positive') one of the two roots.
        return min(r, self.p - r)

    def __str__(self):
        return "{}({})".format(type(self).__name__, self.p)

    def __repr__(self):
        return self.__str__()
class ModP(object):
    """Class that represents integers mod p as a field.

    All the usual arithmetic operations are supported directly,
    including division, so you can write formulas in a natural way
    without having to keep saying '% p' everywhere or call a
    cumbersome modular_inverse() function.

    A plain number on either side of an operator is first converted to
    a ModP with the same modulus. Combining two ModP values with
    different moduli trips an assertion in check(). Elements are
    deliberately unordered: <, <=, > and >= raise ValueError.
    Exponentiation accepts negative integer exponents, meaning powers
    of the multiplicative inverse.
    """

    def __init__(self, p, n=0):
        """Make the residue of n mod p.

        n may be a plain number or another ModP with the same modulus.
        """
        self.p = p
        if isinstance(n, type(self)):
            self.check(n)
            n = n.n
        self.n = n % p

    def check(self, other):
        # Both operands must be exactly the same class and share a modulus.
        assert isinstance(other, type(self))
        assert isinstance(self, type(other))
        assert self.p == other.p

    def coerce_to(self, other):
        # Turn a plain number into a ModP with our modulus, or validate an
        # existing ModP against ours.
        if not isinstance(other, type(self)):
            other = type(self)(self.p, other)
        else:
            self.check(other)
        return other

    def __int__(self):
        return self.n

    def __add__(self, rhs):
        rhs = self.coerce_to(rhs)
        return type(self)(self.p, (self.n + rhs.n) % self.p)

    def __neg__(self):
        return type(self)(self.p, -self.n % self.p)

    def __radd__(self, rhs):
        rhs = self.coerce_to(rhs)
        return type(self)(self.p, (self.n + rhs.n) % self.p)

    def __sub__(self, rhs):
        rhs = self.coerce_to(rhs)
        return type(self)(self.p, (self.n - rhs.n) % self.p)

    def __rsub__(self, rhs):
        rhs = self.coerce_to(rhs)
        return type(self)(self.p, (rhs.n - self.n) % self.p)

    def __mul__(self, rhs):
        rhs = self.coerce_to(rhs)
        return type(self)(self.p, (self.n * rhs.n) % self.p)

    def __rmul__(self, rhs):
        rhs = self.coerce_to(rhs)
        return type(self)(self.p, (self.n * rhs.n) % self.p)

    def __div__(self, rhs):
        # Python 2 '/' operator; __truediv__ delegates here for Python 3.
        # invert() asserts if rhs is not invertible mod p.
        rhs = self.coerce_to(rhs)
        return type(self)(self.p, (self.n * invert(rhs.n, self.p)) % self.p)

    def __rdiv__(self, rhs):
        rhs = self.coerce_to(rhs)
        return type(self)(self.p, (rhs.n * invert(self.n, self.p)) % self.p)

    def __truediv__(self, rhs): return self.__div__(rhs)

    def __rtruediv__(self, rhs): return self.__rdiv__(rhs)

    def __pow__(self, exponent):
        """Raise to an integer power by binary square-and-multiply.

        A negative exponent raises the multiplicative inverse to the
        corresponding positive power. That requires self to be
        invertible mod p; invert() asserts otherwise.
        """
        if exponent < 0:
            base = type(self)(self.p, invert(self.n, self.p))
            return base ** -exponent
        n, b_to_n = 1, self
        total = type(self)(self.p, 1)
        # n walks through the powers of two; b_to_n is self**n. Bits of
        # exponent are cleared as they are consumed, so the loop can stop
        # once n exceeds what remains.
        while True:
            if exponent & n:
                exponent -= n
                total *= b_to_n
            n *= 2
            if n > exponent:
                break
            b_to_n *= b_to_n
        return total

    def __cmp__(self, rhs):
        # Python 2 only: 'cmp' does not exist in Python 3, where this
        # method is never invoked.
        rhs = self.coerce_to(rhs)
        return cmp(self.n, rhs.n)

    def __eq__(self, rhs):
        rhs = self.coerce_to(rhs)
        return self.n == rhs.n

    def __ne__(self, rhs):
        rhs = self.coerce_to(rhs)
        return self.n != rhs.n

    # Residues mod p have no meaningful order, so refuse to compare them.
    def __lt__(self, rhs):
        raise ValueError("Elements of a modular ring have no ordering")

    def __le__(self, rhs):
        raise ValueError("Elements of a modular ring have no ordering")

    def __gt__(self, rhs):
        raise ValueError("Elements of a modular ring have no ordering")

    def __ge__(self, rhs):
        raise ValueError("Elements of a modular ring have no ordering")

    def __str__(self):
        return "0x{:x}".format(self.n)

    def __repr__(self):
        return "{}(0x{:x},0x{:x})".format(type(self).__name__, self.p, self.n)

    def __hash__(self):
        # Consistent with __eq__ among ModP values: the same modulus and
        # residue give the same hash. NOTE: ModP(p, 3) == 3 is True, but
        # the hashes generally differ, so don't rely on plain ints and
        # ModPs merging inside one set or dict.
        return hash((type(self).__name__, self.p, self.n))