From 072d3c665a5943f2ff1d396af5a1457c462afa92 Mon Sep 17 00:00:00 2001 From: Simon Tatham Date: Fri, 28 Feb 2020 20:14:28 +0000 Subject: [PATCH] numbertheory.py: generalise SqrtModP to do other roots. I'm about to want to solve quartics mod a prime, which means I'll need to be able to take cube roots mod p as well as square roots. This commit introduces a more general class which can take rth roots for any prime r, and moreover, it can do it in a general cyclic group. (You have to tell it the group's order and give it some primitives for doing arithmetic, plus a way of iterating over the group elements that it can use to look for a non-rth-power and roots of unity.) That system makes it nicely easy to test, because you can give it a cyclic group represented as the integers under _addition_, and then you obviously know what all the right answers are. So I've also added a unit test system checking that. --- test/eccref.py | 12 +-- test/numbertheory.py | 186 +++++++++++++++++++++++++++++++++---------- 2 files changed, 148 insertions(+), 50 deletions(-) diff --git a/test/eccref.py b/test/eccref.py index b5158ef5..bc64095f 100644 --- a/test/eccref.py +++ b/test/eccref.py @@ -111,9 +111,9 @@ class WeierstrassCurve(CurveBase): def cpoint(self, x, yparity=0): if not hasattr(self, 'sqrtmodp'): - self.sqrtmodp = SqrtModP(self.p) + self.sqrtmodp = RootModP(2, self.p) rhs = x**3 + self.a.n * x + self.b.n - y = self.sqrtmodp.sqrt(rhs) + y = self.sqrtmodp.root(rhs) if (y - yparity) % 2: y = -y return self.point(x, y) @@ -157,9 +157,9 @@ class MontgomeryCurve(CurveBase): def cpoint(self, x, yparity=0): if not hasattr(self, 'sqrtmodp'): - self.sqrtmodp = SqrtModP(self.p) + self.sqrtmodp = RootModP(2, self.p) rhs = (x**3 + self.a.n * x**2 + x) / self.b - y = self.sqrtmodp.sqrt(int(rhs)) + y = self.sqrtmodp.root(int(rhs)) if (y - yparity) % 2: y = -y return self.point(x, y) @@ -198,11 +198,11 @@ class TwistedEdwardsCurve(CurveBase): def cpoint(self, y, xparity=0): if not hasattr(self, 'sqrtmodp'): - self.sqrtmodp = SqrtModP(self.p) + self.sqrtmodp = RootModP(self.p) y = ModP(self.p, y) y2 = y**2 radicand = (y2 - 1) / (self.d * y2 - self.a) - x = self.sqrtmodp.sqrt(radicand.n) + x = self.sqrtmodp.root(radicand.n) if (x - xparity) % 2: x = -x return self.point(x, y) diff --git a/test/numbertheory.py b/test/numbertheory.py index ae3ba324..554140c8 100644 --- a/test/numbertheory.py +++ b/test/numbertheory.py @@ -1,5 +1,6 @@ import numbers import itertools +import unittest def invert(a, b): "Multiplicative inverse of a mod b. a,b must be coprime." @@ -36,57 +37,148 @@ def jacobi(n,m): acc *= -1 n, m = m, n -class SqrtModP(object): - """Class for finding square roots of numbers mod p. +class CyclicGroupRootFinder(object): + """Class for finding rth roots in a cyclic group. r must be prime.""" - p must be an odd prime (but its primality is not checked).""" + # Basic strategy: + # + # We write |G| = r^k u, with u coprime to r. This gives us a + # nested sequence of subgroups G = G_0 > G_1 > ... > G_k, each + # with index 3 in its predecessor. G_0 is the whole group, and the + # innermost G_k has order u. + # + # Within G_k, you can take an rth root by raising an element to + # the power of (r^{-1} mod u). If k=0 (so G = G_0 = G_k) then + # that's all that's needed: every element has a unique rth root. + # But if k>0, then things go differently. + # + # Define the 'rank' of an element g as the highest i such that + # g \in G_i. Elements of rank 0 are the non-rth-powers: they don't + # even _have_ an rth root. Elements of rank k are the easy ones to + # take rth roots of, as above. + # + # In between, you can follow an inductive process, as long as you + # know one element z of index 0. Suppose we're trying to take the + # rth root of some g with index i. Repeatedly multiply g by + # z^{r^i} until its index increases; then take the root of that + # (recursively), and divide off z^{r^{i-1}} once you're done. - def __init__(self, p): - p = abs(p) - assert p & 1 - self.p = p + def __init__(self, r, order): + self.order = order # order of G + self.r = r + self.k = next(k for k in itertools.count() + if self.order % (r**(k+1)) != 0) + self.u = self.order // (r**self.k) + self.z = next(z for z in self.iter_elements() + if self.index(z) == 0) + self.zinv = self.inverse(self.z) + self.root_power = invert(self.r, self.u) if self.u > 1 else 0 - # Decompose p as 2^e k + 1 for odd k. - self.k = p-1 - self.e = 0 - while not (self.k & 1): - self.k >>= 1 - self.e += 1 + self.roots_of_unity = {self.identity()} + if self.k > 0: + exponent = self.order // self.r + for z in self.iter_elements(): + root_of_unity = self.pow(z, exponent) + if root_of_unity not in self.roots_of_unity: + self.roots_of_unity.add(root_of_unity) + if len(self.roots_of_unity) == r: + break - # Find a non-square mod p. - for self.z in itertools.count(1): - if jacobi(self.z, self.p) == -1: - break - self.zinv = ModP(self.p, self.z).invert() + def index(self, g): + h = self.pow(g, self.u) + for i in range(self.k+1): + if h == self.identity(): + return self.k - i + h = self.pow(h, self.r) + assert False, ("Not a cyclic group! Raising {} to u r^k should give e." + .format(g)) - def sqrt_recurse(self, a): - ak = pow(a, self.k, self.p) - for i in range(self.e, -1, -1): - if ak == 1: - break - ak = ak*ak % self.p - assert i > 0 - if i == self.e: - return pow(a, (self.k+1) // 2, self.p) - r_prime = self.sqrt_recurse(a * pow(self.z, 2**i, self.p)) - return r_prime * pow(self.zinv, 2**(i-1), self.p) % self.p + def all_roots(self, g): + try: + r = self.root(g) + except ValueError: + return [] + return {r * rou for rou in self.roots_of_unity} - def sqrt(self, a): - j = jacobi(a, self.p) - if j == 0: - return 0 - if j < 0: - raise ValueError("{} has no square root mod {}".format(a, self.p)) - a %= self.p - r = self.sqrt_recurse(a) - assert r*r % self.p == a - # Normalise to the smaller (or 'positive') one of the two roots. - return min(r, self.p - r) + def root(self, g): + i = self.index(g) + if i == 0 and self.k > 0: + raise ValueError("{} has no {}th root".format(g, self.r)) + out = self.root_recurse(g, i) + assert self.pow(out, self.r) == g + return out - def __str__(self): - return "{}({})".format(type(self).__name__, self.p) - def __repr__(self): - return self.__str__() + def root_recurse(self, g, i): + if i == self.k: + return self.pow(g, self.root_power) + z_in = self.pow(self.z, self.r**i) + z_out = self.pow(self.zinv, self.r**(i-1)) + adjust = self.identity() + while True: + g = self.mul(g, z_in) + adjust = self.mul(adjust, z_out) + i2 = self.index(g) + if i2 > i: + return self.mul(self.root_recurse(g, i2), adjust) + +class AdditiveGroupRootFinder(CyclicGroupRootFinder): + """Trivial test subclass for CyclicGroupRootFinder. + + Represents a cyclic group of any order additively, as the integers + mod n under addition. This makes root-finding trivial without + having to use the complicated algorithm above, and therefore it's + a good way to test the complicated algorithm under conditions + where the right answers are obvious.""" + + def __init__(self, r, order): + super().__init__(r, order) + + def mul(self, x, y): + return (x + y) % self.order + def pow(self, x, n): + return (x * n) % self.order + def inverse(self, x): + return (-x) % self.order + def identity(self): + return 0 + def iter_elements(self): + return range(self.order) + +class TestCyclicGroupRootFinder(unittest.TestCase): + def testRootFinding(self): + for order in 10, 11, 12, 18: + grf = AdditiveGroupRootFinder(3, order) + for i in range(order): + try: + r = grf.root(i) + except ValueError: + r = None + + if order % 3 == 0 and i % 3 != 0: + self.assertEqual(r, None) + else: + self.assertEqual(r*3 % order, i) + +class RootModP(CyclicGroupRootFinder): + """The live class that can take rth roots mod a prime.""" + + def __init__(self, r, p): + self.modulus = p + super().__init__(r, p-1) + + def mul(self, x, y): + return (x * y) % self.modulus + def pow(self, x, n): + return pow(x, n, self.modulus) + def inverse(self, x): + return invert(x, self.modulus) + def identity(self): + return 1 + def iter_elements(self): + return range(1, self.modulus) + + def root(self, g): + return 0 if g == 0 else super().root(g) class ModP(object): """Class that represents integers mod p as a field. @@ -179,3 +271,9 @@ class ModP(object): return "{}(0x{:x},0x{:x})".format(type(self).__name__, self.p, self.n) def __hash__(self): return hash((type(self).__name__, self.p, self.n)) + +if __name__ == "__main__": + import sys + if sys.argv[1:] == ["--test"]: + sys.argv[1:2] = [] + unittest.main()