diff --git a/test/numbertheory.py b/test/numbertheory.py index 554140c8..5c22e936 100644 --- a/test/numbertheory.py +++ b/test/numbertheory.py @@ -272,6 +272,362 @@ class ModP(object): def __hash__(self): return hash((type(self).__name__, self.p, self.n)) +class QuadraticFieldExtensionModP(object): + """Class representing Z_p[sqrt(d)] for a given non-square d. + """ + def __init__(self, p, d, n=0, m=0): + self.p = p + self.d = d + if isinstance(n, ModP): + assert self.p == n.p + n = n.n + if isinstance(m, ModP): + assert self.p == m.p + m = m.n + if isinstance(n, type(self)): + self.check(n) + m += n.m + n = n.n + self.n = n % p + self.m = m % p + + @classmethod + def constructor(cls, p, d): + return lambda *args: cls(p, d, *args) + + def check(self, other): + assert isinstance(other, type(self)) + assert isinstance(self, type(other)) + assert self.p == other.p + assert self.d == other.d + def coerce_to(self, other): + if not isinstance(other, type(self)): + other = type(self)(self.p, self.d, other) + else: + self.check(other) + return other + def __int__(self): + if self.m != 0: + raise ValueError("Can't coerce a non-element of Z_{} to integer" + .format(self.p)) + return int(self.n) + def __add__(self, rhs): + rhs = self.coerce_to(rhs) + return type(self)(self.p, self.d, + (self.n + rhs.n) % self.p, + (self.m + rhs.m) % self.p) + def __neg__(self): + return type(self)(self.p, self.d, + -self.n % self.p, + -self.m % self.p) + def __radd__(self, rhs): + rhs = self.coerce_to(rhs) + return type(self)(self.p, self.d, + (self.n + rhs.n) % self.p, + (self.m + rhs.m) % self.p) + def __sub__(self, rhs): + rhs = self.coerce_to(rhs) + return type(self)(self.p, self.d, + (self.n - rhs.n) % self.p, + (self.m - rhs.m) % self.p) + def __rsub__(self, rhs): + rhs = self.coerce_to(rhs) + return type(self)(self.p, self.d, + (rhs.n - self.n) % self.p, + (rhs.m - self.m) % self.p) + def __mul__(self, rhs): + rhs = self.coerce_to(rhs) + n, m, N, M = self.n, self.m, rhs.n, rhs.m + return type(self)(self.p, self.d, + (n*N + self.d*m*M) % self.p, + (n*M + m*N) % self.p) + def __rmul__(self, rhs): + return self.__mul__(rhs) + def __div__(self, rhs): + rhs = self.coerce_to(rhs) + n, m, N, M = self.n, self.m, rhs.n, rhs.m + # (n+m sqrt d)/(N+M sqrt d) = (n+m sqrt d)(N-M sqrt d)/(N^2-dM^2) + denom = (N*N - self.d*M*M) % self.p + if denom == 0: + raise ValueError("division by zero") + recipdenom = invert(denom, self.p) + return type(self)(self.p, self.d, + (n*N - self.d*m*M) * recipdenom % self.p, + (m*N - n*M) * recipdenom % self.p) + def __rdiv__(self, rhs): + rhs = self.coerce_to(rhs) + return rhs.__div__(self) + def __truediv__(self, rhs): return self.__div__(rhs) + def __rtruediv__(self, rhs): return self.__rdiv__(rhs) + def __pow__(self, exponent): + assert exponent >= 0 + n, b_to_n = 1, self + total = type(self)(self.p, self.d, 1) + while True: + if exponent & n: + exponent -= n + total *= b_to_n + n *= 2 + if n > exponent: + break + b_to_n *= b_to_n + return total + def __cmp__(self, rhs): + rhs = self.coerce_to(rhs) + return cmp((self.n, self.m), (rhs.n, rhs.m)) + def __eq__(self, rhs): + rhs = self.coerce_to(rhs) + return self.n == rhs.n and self.m == rhs.m + def __ne__(self, rhs): + rhs = self.coerce_to(rhs) + return self.n != rhs.n or self.m != rhs.m + def __lt__(self, rhs): + raise ValueError("Elements of a modular ring have no ordering") + def __le__(self, rhs): + raise ValueError("Elements of a modular ring have no ordering") + def __gt__(self, rhs): + raise ValueError("Elements of a modular ring have no ordering") + def __ge__(self, rhs): + raise ValueError("Elements of a modular ring have no ordering") + def __str__(self): + if self.m == 0: + return "0x{:x}".format(self.n) + else: + return "0x{:x}+0x{:x}*sqrt({:d})".format(self.n, self.m, self.d) + def __repr__(self): + return "{}(0x{:x},0x{:x},0x{:x},0x{:x})".format( + type(self).__name__, self.p, self.d, self.n, self.m) + def __hash__(self): + return hash((type(self).__name__, self.p, self.d, self.n, self.m)) + +class RootInQuadraticExtension(CyclicGroupRootFinder): + """Take rth roots in the quadratic extension of Z_p.""" + + def __init__(self, r, p, d): + self.modulus = p + self.constructor = QuadraticFieldExtensionModP.constructor(p, d) + super().__init__(r, p*p-1) + + def mul(self, x, y): + return x * y + def pow(self, x, n): + return x ** n + def inverse(self, x): + return 1/x + def identity(self): + return self.constructor(1, 0) + def iter_elements(self): + p = self.modulus + for n_plus_m in range(1, 2*p-1): + n_min = max(0, n_plus_m-(p-1)) + n_max = min(p-1, n_plus_m) + for n in range(n_min, n_max + 1): + m = n_plus_m - n + assert(0 <= n < p) + assert(0 <= m < p) + assert(n != 0 or m != 0) + yield self.constructor(n, m) + + def root(self, g): + return 0 if g == 0 else super().root(g) + +class EquationSolverModP(object): + """Class that can solve quadratics, cubics and quartics over Z_p. + + p must be a nontrivial prime (bigger than 3). + """ + + # This is a port to Z_p of reasonably standard algorithms for + # solving quadratics, cubics and quartics over the reals. + # + # When you solve a cubic in R, you sometimes have to deal with + # intermediate results that are complex numbers. In particular, + # you have to solve a quadratic whose coefficients are in R but + # its roots may be complex, and then having solved that quadratic, + # you need to iterate over all three cube roots of the solution in + # order to recover all the roots of your cubic. (Even if the cubic + # ends up having three real roots, you can't calculate them + # without going through those complex intermediate values.) + # + # So over Z_p, the same thing applies: we're going to need to be + # able to solve any quadratic with coefficients in Z_p, even if + # its discriminant turns out not to be a quadratic residue mod p, + # and then we'll need to find _three_ cube roots of the result, + # even if p == 2 (mod 3) so that numbers only have one cube root + # each. + # + # Both of these problems can be solved at once if we work in the + # finite field GF(p^2), i.e. make a quadratic field extension of + # Z_p by adjoining a square root of some non-square d. The + # multiplicative group of GF(p^2) is cyclic and has order p^2-1 = + # (p-1)(p+1), with the mult group of Z_p forming the unique + # subgroup of order (p-1) within it. So we've multiplied the group + # order by p+1, which is even (since by assumption p > 3), and + # therefore a square root is now guaranteed to exist for every + # number in the Z_p subgroup. Moreover, no matter whether p itself + # was congruent to 1 or 2 mod 3, p^2 is always congruent to 1, + # which means that the mult group of GF(p^2) has order divisible + # by 3. So there are guaranteed to be three distinct cube roots of + # unity, and hence, three cube roots of any number that's a cube + # at all. + # + # Quartics don't introduce any additional problems. To solve a + # quartic, you factorise it into two quadratic factors, by solving + # a cubic to find one of the coefficients. So if you can already + # solve cubics, then you're more or less done. The only wrinkle is + # that the two quadratic factors will have coefficients in GF(p^2) + # but not necessarily in Z_p. But that doesn't stop us at least + # _trying_ to solve them by taking square roots in GF(p^2) - and + # if the discriminant of one of those quadratics has is not a + # square even in GF(p^2), then its solutions will only exist if + # you escalate further to GF(p^4), in which case the answer is + # simply that there aren't any solutions in Z_p to that quadratic. + + def __init__(self, p): + self.p = p + self.nonsquare_mod_p = d = RootModP(2, p).z + self.constructor = QuadraticFieldExtensionModP.constructor(p, d) + self.sqrt = RootInQuadraticExtension(2, p, d) + self.cbrt = RootInQuadraticExtension(3, p, d) + + def solve_quadratic(self, a, b, c): + "Solve ax^2 + bx + c = 0." + a, b, c = map(self.constructor, (a, b, c)) + assert a != 0 + return self.solve_monic_quadratic(b/a, c/a) + + def solve_monic_quadratic(self, b, c): + "Solve x^2 + bx + c = 0." + b, c = map(self.constructor, (b, c)) + s = b/2 + return [y - s for y in self.solve_depressed_quadratic(c - s*s)] + + def solve_depressed_quadratic(self, c): + "Solve x^2 + c = 0." + return self.sqrt.all_roots(-c) + + def solve_cubic(self, a, b, c, d): + "Solve ax^3 + bx^2 + cx + d = 0." + a, b, c, d = map(self.constructor, (a, b, c, d)) + assert a != 0 + return self.solve_monic_cubic(b/a, c/a, d/a) + + def solve_monic_cubic(self, b, c, d): + "Solve x^3 + bx^2 + cx + d = 0." + b, c, d = map(self.constructor, (b, c, d)) + s = b/3 + return [y - s for y in self.solve_depressed_cubic( + c - 3*s*s, 2*s*s*s - c*s + d)] + + def solve_depressed_cubic(self, c, d): + "Solve x^3 + cx + d = 0." + c, d = map(self.constructor, (c, d)) + solutions = set() + # To solve x^3 + cx + d = 0, set p = -c/3, then + # substitute x = z + p/z to get z^6 + d z^3 + p^3 = 0. + # Solve that quadratic for z^3, then take cube roots. + p = -c/3 + for z3 in self.solve_monic_quadratic(d, p**3): + # As I understand the theory, we _should_ only need to + # take cube roots of one root of that quadratic: the other + # one should give the same set of answers after you map + # each one through z |-> z+p/z. But speed isn't at a + # premium here, so I'll do this the way that must work. + for z in self.cbrt.all_roots(z3): + solutions.add(z + p/z) + return solutions + + def solve_quartic(self, a, b, c, d, e): + "Solve ax^4 + bx^3 + cx^2 + dx + e = 0." + a, b, c, d, e = map(self.constructor, (a, b, c, d, e)) + assert a != 0 + return self.solve_monic_quartic(b/a, c/a, d/a, e/a) + + def solve_monic_quartic(self, b, c, d, e): + "Solve x^4 + bx^3 + cx^2 + dx + e = 0." + b, c, d, e = map(self.constructor, (b, c, d, e)) + s = b/4 + return [y - s for y in self.solve_depressed_quartic( + c - 6*s*s, d - 2*c*s + 8*s*s*s, e - d*s + c*s*s - 3*s*s*s*s)] + + def solve_depressed_quartic(self, c, d, e): + "Solve x^4 + cx^2 + dx + e = 0." + c, d, e = map(self.constructor, (c, d, e)) + solutions = set() + # To solve an equation of this form, we search for a value y + # such that subtracting the original polynomial from (x^2+y)^2 + # yields a quadratic of the special form (ux+v)^2. + # + # Then our equation is rewritten as (x^2+y)^2 - (ux+v)^2 = 0 + # i.e. ((x^2+y) + (ux+v)) ((x^2+y) - (ux+v)) = 0 + # i.e. the product of two quadratics, each of which we then solve. + # + # To find y, we write down the discriminant of the quadratic + # (x^2+y)^2 - (x^4 + cx^2 + dx + e) and set it to 0, which + # gives a cubic in y. Maxima gives the coefficients as + # (-8)y^3 + (4c)y^2 + (8e)y + (d^2-4ce). + # + # As above, we _should_ only need one value of y. But I go + # through them all just in case, because I don't care about + # speed, and because checking the assertions inside this loop + # for every value is extra reassurance that I've done all of + # this right. + for y in self.solve_cubic(-8, 4*c, 8*e, d*d-4*c*e): + # Subtract the original equation from (x^2+y)^2 to get the + # coefficients of our quadratic residual. + A, B, C = 2*y-c, -d, y*y-e + # Expect that to have zero discriminant, i.e. a repeated root. + assert B*B - 4*A*C == 0 + # If (Ax^2+Bx+C) == (ux+v)^2 then we have u^2=A, 2uv=B, v^2=C. + # So we can either recover u as sqrt(A) or v as sqrt(C), and + # whichever we did, find the other from B by division. But + # either of the end coefficients might be zero, so we have + # to be prepared to try either option. + try: + if A != 0: + u = self.sqrt.root(A) + v = B/(2*u) + elif C != 0: + v = self.sqrt.root(C) + u = B/(2*v) + else: + # One last possibility is that all three coefficients + # of our residual quadratic are 0, in which case, + # obviously, u=v=0 as well. + u = v = 0 + except ValueError: + # If Ax^2+Bx+C looked like a perfect square going by + # its discriminant, but actually taking the square + # root of A or C threw an exception, that means that + # it's the square of a polynomial whose coefficients + # live in a yet-higher field extension of Z_p. In that + # case we're not going to end up with roots of the + # original quartic in Z_p if we start from here! + continue + # So now our quartic is factorised into the form + # (x^2 - ux - v + y) (x^2 + ux + v + y). + for x in self.solve_monic_quadratic(-u, y-v): + solutions.add(x) + for x in self.solve_monic_quadratic(u, y+v): + solutions.add(x) + return solutions + +class EquationSolverTest(unittest.TestCase): + def testQuadratic(self): + E = EquationSolverModP(11) + solns = E.solve_quadratic(3, 2, 6) + self.assertEqual(sorted(map(str, solns)), ["0x1", "0x2"]) + + def testCubic(self): + E = EquationSolverModP(11) + solns = E.solve_cubic(7, 2, 0, 2) + self.assertEqual(sorted(map(str, solns)), ["0x1", "0x2", "0x3"]) + + def testQuartic(self): + E = EquationSolverModP(11) + solns = E.solve_quartic(9, 9, 7, 1, 7) + self.assertEqual(sorted(map(str, solns)), ["0x1", "0x2", "0x3", "0x4"]) + if __name__ == "__main__": import sys if sys.argv[1:] == ["--test"]: