mirror of
https://git.tartarus.org/simon/putty.git
synced 2025-01-09 17:38:00 +00:00
numbertheory.py: generalise SqrtModP to do other roots.
I'm about to want to solve quartics mod a prime, which means I'll need to be able to take cube roots mod p as well as square roots. This commit introduces a more general class which can take rth roots for any prime r, and moreover, it can do it in a general cyclic group. (You have to tell it the group's order and give it some primitives for doing arithmetic, plus a way of iterating over the group elements that it can use to look for a non-rth-power and roots of unity.) That system makes it nicely easy to test, because you can give it a cyclic group represented as the integers under _addition_, and then you obviously know what all the right answers are. So I've also added a unit test system checking that.
This commit is contained in:
parent
7be2e16023
commit
072d3c665a
@ -111,9 +111,9 @@ class WeierstrassCurve(CurveBase):
|
||||
|
||||
def cpoint(self, x, yparity=0):
|
||||
if not hasattr(self, 'sqrtmodp'):
|
||||
self.sqrtmodp = SqrtModP(self.p)
|
||||
self.sqrtmodp = RootModP(2, self.p)
|
||||
rhs = x**3 + self.a.n * x + self.b.n
|
||||
y = self.sqrtmodp.sqrt(rhs)
|
||||
y = self.sqrtmodp.root(rhs)
|
||||
if (y - yparity) % 2:
|
||||
y = -y
|
||||
return self.point(x, y)
|
||||
@ -157,9 +157,9 @@ class MontgomeryCurve(CurveBase):
|
||||
|
||||
def cpoint(self, x, yparity=0):
|
||||
if not hasattr(self, 'sqrtmodp'):
|
||||
self.sqrtmodp = SqrtModP(self.p)
|
||||
self.sqrtmodp = RootModP(2, self.p)
|
||||
rhs = (x**3 + self.a.n * x**2 + x) / self.b
|
||||
y = self.sqrtmodp.sqrt(int(rhs))
|
||||
y = self.sqrtmodp.root(int(rhs))
|
||||
if (y - yparity) % 2:
|
||||
y = -y
|
||||
return self.point(x, y)
|
||||
@ -198,11 +198,11 @@ class TwistedEdwardsCurve(CurveBase):
|
||||
|
||||
def cpoint(self, y, xparity=0):
|
||||
if not hasattr(self, 'sqrtmodp'):
|
||||
self.sqrtmodp = SqrtModP(self.p)
|
||||
self.sqrtmodp = RootModP(self.p)
|
||||
y = ModP(self.p, y)
|
||||
y2 = y**2
|
||||
radicand = (y2 - 1) / (self.d * y2 - self.a)
|
||||
x = self.sqrtmodp.sqrt(radicand.n)
|
||||
x = self.sqrtmodp.root(radicand.n)
|
||||
if (x - xparity) % 2:
|
||||
x = -x
|
||||
return self.point(x, y)
|
||||
|
@ -1,5 +1,6 @@
|
||||
import numbers
|
||||
import itertools
|
||||
import unittest
|
||||
|
||||
def invert(a, b):
|
||||
"Multiplicative inverse of a mod b. a,b must be coprime."
|
||||
@ -36,57 +37,148 @@ def jacobi(n,m):
|
||||
acc *= -1
|
||||
n, m = m, n
|
||||
|
||||
class SqrtModP(object):
|
||||
"""Class for finding square roots of numbers mod p.
|
||||
class CyclicGroupRootFinder(object):
|
||||
"""Class for finding rth roots in a cyclic group. r must be prime."""
|
||||
|
||||
p must be an odd prime (but its primality is not checked)."""
|
||||
# Basic strategy:
|
||||
#
|
||||
# We write |G| = r^k u, with u coprime to r. This gives us a
|
||||
# nested sequence of subgroups G = G_0 > G_1 > ... > G_k, each
|
||||
# with index 3 in its predecessor. G_0 is the whole group, and the
|
||||
# innermost G_k has order u.
|
||||
#
|
||||
# Within G_k, you can take an rth root by raising an element to
|
||||
# the power of (r^{-1} mod u). If k=0 (so G = G_0 = G_k) then
|
||||
# that's all that's needed: every element has a unique rth root.
|
||||
# But if k>0, then things go differently.
|
||||
#
|
||||
# Define the 'rank' of an element g as the highest i such that
|
||||
# g \in G_i. Elements of rank 0 are the non-rth-powers: they don't
|
||||
# even _have_ an rth root. Elements of rank k are the easy ones to
|
||||
# take rth roots of, as above.
|
||||
#
|
||||
# In between, you can follow an inductive process, as long as you
|
||||
# know one element z of index 0. Suppose we're trying to take the
|
||||
# rth root of some g with index i. Repeatedly multiply g by
|
||||
# z^{r^i} until its index increases; then take the root of that
|
||||
# (recursively), and divide off z^{r^{i-1}} once you're done.
|
||||
|
||||
def __init__(self, p):
|
||||
p = abs(p)
|
||||
assert p & 1
|
||||
self.p = p
|
||||
def __init__(self, r, order):
|
||||
self.order = order # order of G
|
||||
self.r = r
|
||||
self.k = next(k for k in itertools.count()
|
||||
if self.order % (r**(k+1)) != 0)
|
||||
self.u = self.order // (r**self.k)
|
||||
self.z = next(z for z in self.iter_elements()
|
||||
if self.index(z) == 0)
|
||||
self.zinv = self.inverse(self.z)
|
||||
self.root_power = invert(self.r, self.u) if self.u > 1 else 0
|
||||
|
||||
# Decompose p as 2^e k + 1 for odd k.
|
||||
self.k = p-1
|
||||
self.e = 0
|
||||
while not (self.k & 1):
|
||||
self.k >>= 1
|
||||
self.e += 1
|
||||
self.roots_of_unity = {self.identity()}
|
||||
if self.k > 0:
|
||||
exponent = self.order // self.r
|
||||
for z in self.iter_elements():
|
||||
root_of_unity = self.pow(z, exponent)
|
||||
if root_of_unity not in self.roots_of_unity:
|
||||
self.roots_of_unity.add(root_of_unity)
|
||||
if len(self.roots_of_unity) == r:
|
||||
break
|
||||
|
||||
# Find a non-square mod p.
|
||||
for self.z in itertools.count(1):
|
||||
if jacobi(self.z, self.p) == -1:
|
||||
break
|
||||
self.zinv = ModP(self.p, self.z).invert()
|
||||
def index(self, g):
|
||||
h = self.pow(g, self.u)
|
||||
for i in range(self.k+1):
|
||||
if h == self.identity():
|
||||
return self.k - i
|
||||
h = self.pow(h, self.r)
|
||||
assert False, ("Not a cyclic group! Raising {} to u r^k should give e."
|
||||
.format(g))
|
||||
|
||||
def sqrt_recurse(self, a):
|
||||
ak = pow(a, self.k, self.p)
|
||||
for i in range(self.e, -1, -1):
|
||||
if ak == 1:
|
||||
break
|
||||
ak = ak*ak % self.p
|
||||
assert i > 0
|
||||
if i == self.e:
|
||||
return pow(a, (self.k+1) // 2, self.p)
|
||||
r_prime = self.sqrt_recurse(a * pow(self.z, 2**i, self.p))
|
||||
return r_prime * pow(self.zinv, 2**(i-1), self.p) % self.p
|
||||
def all_roots(self, g):
|
||||
try:
|
||||
r = self.root(g)
|
||||
except ValueError:
|
||||
return []
|
||||
return {r * rou for rou in self.roots_of_unity}
|
||||
|
||||
def sqrt(self, a):
|
||||
j = jacobi(a, self.p)
|
||||
if j == 0:
|
||||
return 0
|
||||
if j < 0:
|
||||
raise ValueError("{} has no square root mod {}".format(a, self.p))
|
||||
a %= self.p
|
||||
r = self.sqrt_recurse(a)
|
||||
assert r*r % self.p == a
|
||||
# Normalise to the smaller (or 'positive') one of the two roots.
|
||||
return min(r, self.p - r)
|
||||
def root(self, g):
|
||||
i = self.index(g)
|
||||
if i == 0 and self.k > 0:
|
||||
raise ValueError("{} has no {}th root".format(g, self.r))
|
||||
out = self.root_recurse(g, i)
|
||||
assert self.pow(out, self.r) == g
|
||||
return out
|
||||
|
||||
def __str__(self):
|
||||
return "{}({})".format(type(self).__name__, self.p)
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
def root_recurse(self, g, i):
|
||||
if i == self.k:
|
||||
return self.pow(g, self.root_power)
|
||||
z_in = self.pow(self.z, self.r**i)
|
||||
z_out = self.pow(self.zinv, self.r**(i-1))
|
||||
adjust = self.identity()
|
||||
while True:
|
||||
g = self.mul(g, z_in)
|
||||
adjust = self.mul(adjust, z_out)
|
||||
i2 = self.index(g)
|
||||
if i2 > i:
|
||||
return self.mul(self.root_recurse(g, i2), adjust)
|
||||
|
||||
class AdditiveGroupRootFinder(CyclicGroupRootFinder):
|
||||
"""Trivial test subclass for CyclicGroupRootFinder.
|
||||
|
||||
Represents a cyclic group of any order additively, as the integers
|
||||
mod n under addition. This makes root-finding trivial without
|
||||
having to use the complicated algorithm above, and therefore it's
|
||||
a good way to test the complicated algorithm under conditions
|
||||
where the right answers are obvious."""
|
||||
|
||||
def __init__(self, r, order):
|
||||
super().__init__(r, order)
|
||||
|
||||
def mul(self, x, y):
|
||||
return (x + y) % self.order
|
||||
def pow(self, x, n):
|
||||
return (x * n) % self.order
|
||||
def inverse(self, x):
|
||||
return (-x) % self.order
|
||||
def identity(self):
|
||||
return 0
|
||||
def iter_elements(self):
|
||||
return range(self.order)
|
||||
|
||||
class TestCyclicGroupRootFinder(unittest.TestCase):
|
||||
def testRootFinding(self):
|
||||
for order in 10, 11, 12, 18:
|
||||
grf = AdditiveGroupRootFinder(3, order)
|
||||
for i in range(order):
|
||||
try:
|
||||
r = grf.root(i)
|
||||
except ValueError:
|
||||
r = None
|
||||
|
||||
if order % 3 == 0 and i % 3 != 0:
|
||||
self.assertEqual(r, None)
|
||||
else:
|
||||
self.assertEqual(r*3 % order, i)
|
||||
|
||||
class RootModP(CyclicGroupRootFinder):
|
||||
"""The live class that can take rth roots mod a prime."""
|
||||
|
||||
def __init__(self, r, p):
|
||||
self.modulus = p
|
||||
super().__init__(r, p-1)
|
||||
|
||||
def mul(self, x, y):
|
||||
return (x * y) % self.modulus
|
||||
def pow(self, x, n):
|
||||
return pow(x, n, self.modulus)
|
||||
def inverse(self, x):
|
||||
return invert(x, self.modulus)
|
||||
def identity(self):
|
||||
return 1
|
||||
def iter_elements(self):
|
||||
return range(1, self.modulus)
|
||||
|
||||
def root(self, g):
|
||||
return 0 if g == 0 else super().root(g)
|
||||
|
||||
class ModP(object):
|
||||
"""Class that represents integers mod p as a field.
|
||||
@ -179,3 +271,9 @@ class ModP(object):
|
||||
return "{}(0x{:x},0x{:x})".format(type(self).__name__, self.p, self.n)
|
||||
def __hash__(self):
|
||||
return hash((type(self).__name__, self.p, self.n))
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
if sys.argv[1:] == ["--test"]:
|
||||
sys.argv[1:2] = []
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user