mirror of
https://git.tartarus.org/simon/putty.git
synced 2025-01-09 17:38:00 +00:00
4948b79114
The class for general rth-root finding started off as a cube-root finder before I generalised it, and in one part of the top-level explanatory comment, I still referred to a subgroup having index 3 rather than index r. Also, in a later paragraph, I seem to have said 'index' several times where I meant the concept of 'rank' I defined in the previous paragraph.
639 lines
24 KiB
Python
639 lines
24 KiB
Python
import sys
|
|
import numbers
|
|
import itertools
|
|
import unittest
|
|
|
|
# Refuse to run under Python 2: the classes below rely on Python 3
# features such as zero-argument super().
assert sys.version_info >= (3,), "This is Python 3 code"
|
|
|
|
def invert(a, b):
    """Multiplicative inverse of a mod b. a,b must be coprime.

    Uses the extended Euclidean algorithm. Returns a value in [0, b).

    Raises ValueError if a and b are not coprime (so that no inverse
    exists), including the case a == 0 mod b. This used to be an
    assert, which vanishes under 'python -O' and would then silently
    return a wrong answer.
    """
    # Each triple (x, s, t) maintains the invariant x == s*a + t*b.
    A = (a, 1, 0)
    B = (b, 0, 1)
    while B[0]:
        q = A[0] // B[0]
        A, B = B, tuple(Ai - q*Bi for Ai, Bi in zip(A, B))
    # A[0] is now gcd(a, b), possibly negated by floor division.
    if abs(A[0]) != 1:
        raise ValueError("{} has no inverse mod {}".format(a, b))
    # s*a == A[0] == +-1 (mod b), so s*A[0] is the inverse.
    return A[1]*A[0] % b
|
|
|
|
def jacobi(n, m):
    """Compute the Jacobi symbol (n | m), for odd positive m.

    The special case of this when m is prime is the Legendre symbol,
    which is 0 if n is congruent to 0 mod m; 1 if n is congruent to a
    non-zero square number mod m; -1 if n is not congruent to any
    square mod m.

    For composite m it is the product of the Legendre symbols (n | q)
    over the prime factors q of m, counted with multiplicity. In
    particular (n | 1) = 1 for every n, being an empty product.

    Raises ValueError if m is not an odd positive integer.
    """
    if m <= 0 or not (m & 1):
        raise ValueError(
            "Jacobi symbol needs an odd positive modulus, got {}".format(m))
    acc = 1
    while True:
        if m == 1:
            # (n | 1) is the empty product. After the first pass m is
            # always an odd number >= 3, so this only fires initially.
            return acc
        n %= m
        if n == 0:
            # n and m share a factor.
            return 0
        # Pull out factors of 2: (2 | m) is -1 exactly when m is 3 or
        # 5 mod 8.
        while not (n & 1):
            n >>= 1
            if (m & 7) not in {1, 7}:
                acc *= -1
        if n == 1:
            return acc
        # Quadratic reciprocity: swapping two odd values flips the sign
        # iff both are 3 mod 4.
        if (n & 3) == 3 and (m & 3) == 3:
            acc *= -1
        n, m = m, n
|
|
|
|
class CyclicGroupRootFinder(object):
    """Class for finding rth roots in a cyclic group. r must be prime.

    This is an abstract base class. A subclass represents one particular
    cyclic group G of known order by supplying the group operations
    mul(x, y), pow(x, n), inverse(x), identity() and iter_elements()
    (an iterable over the elements of G).
    """

    # Basic strategy:
    #
    # We write |G| = r^k u, with u coprime to r. This gives us a
    # nested sequence of subgroups G = G_0 > G_1 > ... > G_k, each
    # with index r in its predecessor. G_0 is the whole group, and the
    # innermost G_k has order u.
    #
    # Within G_k, you can take an rth root by raising an element to
    # the power of (r^{-1} mod u). If k=0 (so G = G_0 = G_k) then
    # that's all that's needed: every element has a unique rth root.
    # But if k>0, then things go differently.
    #
    # Define the 'rank' of an element g as the highest i such that
    # g \in G_i. When k>0, elements of rank 0 are the non-rth-powers:
    # they don't even _have_ an rth root. Elements of rank k are the
    # easy ones to take rth roots of, as above.
    #
    # In between, you can follow an inductive process, as long as you
    # know one element z of rank 0. Suppose we're trying to take the
    # rth root of some g with rank i. Repeatedly multiply g by z^{r^i}
    # until its rank increases; then take the root of that
    # (recursively), and divide off one factor of z^{r^{i-1}} (an rth
    # root of z^{r^i}) for each multiplication you did.

    def __init__(self, r, order):
        self.order = order # order of G
        self.r = r
        # k = exponent of the largest power of r dividing |G|; u = the
        # remaining cofactor, which is coprime to r.
        self.k = next(k for k in itertools.count()
                      if self.order % (r**(k+1)) != 0)
        self.u = self.order // (r**self.k)
        # Any element of rank 0 will do as the z of the strategy above.
        self.z = next(z for z in self.iter_elements()
                      if self.index(z) == 0)
        self.zinv = self.inverse(self.z)
        # Exponent that takes rth roots within G_k. If u == 1 then G_k
        # is trivial and the exponent is irrelevant.
        self.root_power = invert(self.r, self.u) if self.u > 1 else 0

        # Collect all the rth roots of unity. If k == 0, the identity is
        # the only one. Otherwise z is not an rth power, so
        # zeta = z^(|G|/r) is not the identity, yet zeta^r = z^|G| is.
        # So zeta has order exactly r, and its powers are all r of them.
        self.roots_of_unity = {self.identity()}
        if self.k > 0:
            zeta = self.pow(self.z, self.order // self.r)
            power = zeta
            while power not in self.roots_of_unity:
                self.roots_of_unity.add(power)
                power = self.mul(power, zeta)

    def index(self, g):
        """Return the rank of g: the largest i such that g is in G_i.

        Despite the method's name, this computes what the comment above
        calls the 'rank', not a subgroup index.
        """
        # g^u lies in the subgroup of order r^k. In a cyclic group, g is
        # in G_j iff g^(u r^(k-j)) = e, so count how many rth powerings
        # it takes to get from g^u to the identity.
        h = self.pow(g, self.u)
        for i in range(self.k+1):
            if h == self.identity():
                return self.k - i
            h = self.pow(h, self.r)
        assert False, ("Not a cyclic group! Raising {} to u r^k should give e."
                       .format(g))

    def all_roots(self, g):
        """Return the set of all rth roots of g (an empty set if none).

        Every root is one particular root times an rth root of unity.
        The combination must use the group operation self.mul, not
        Python's '*'. For integer representations like mod-p residues
        or an additive group, '*' is not (or not reduced to) the group
        operation.
        """
        try:
            root = self.root(g)
        except ValueError:
            return set()
        return {self.mul(root, rou) for rou in self.roots_of_unity}

    def root(self, g):
        """Return one rth root of g. Raises ValueError if g has none."""
        i = self.index(g)
        if i == 0 and self.k > 0:
            raise ValueError("{} has no {}th root".format(g, self.r))
        out = self.root_recurse(g, i)
        assert self.pow(out, self.r) == g
        return out

    def root_recurse(self, g, i):
        """Return an rth root of g, which has rank i (i >= 1 unless i == k)."""
        if i == self.k:
            return self.pow(g, self.root_power)
        # z_in has rank exactly i. Multiplying by it steps g through the
        # r cosets of G_{i+1} in G_i, so after at most r-1 steps the rank
        # rises. z_out is an rth root of z_in^{-1}. Accumulating one per
        # step cancels the extra factors once we take the root.
        z_in = self.pow(self.z, self.r**i)
        z_out = self.pow(self.zinv, self.r**(i-1))
        adjust = self.identity()
        while True:
            g = self.mul(g, z_in)
            adjust = self.mul(adjust, z_out)
            i2 = self.index(g)
            if i2 > i:
                return self.mul(self.root_recurse(g, i2), adjust)
|
|
|
|
class AdditiveGroupRootFinder(CyclicGroupRootFinder):
    """Toy CyclicGroupRootFinder subclass, intended for testing.

    The group is Z_n under addition, for any order n. Here an 'rth
    root' of g is just an x with r*x = g (mod n), which is trivial to
    check directly. That makes it a good testbed for the general
    root-finding algorithm: the right answers are obvious."""

    # No __init__ override is needed: the base-class constructor takes
    # (r, order) and performs all of the setup.

    def identity(self):
        return 0

    def mul(self, x, y):
        # The group operation is addition mod n.
        return (x + y) % self.order

    def pow(self, x, n):
        # Applying the group operation n times is multiplication by n.
        return (x * n) % self.order

    def inverse(self, x):
        return (-x) % self.order

    def iter_elements(self):
        return range(self.order)
|
|
|
|
class TestCyclicGroupRootFinder(unittest.TestCase):
    def testRootFinding(self):
        """Cube roots in additive Z_n, for orders with and without 3 | n."""
        for order in (10, 11, 12, 18):
            finder = AdditiveGroupRootFinder(3, order)
            # If 3 doesn't divide the order, every element is a cube.
            every_element_is_cube = order % 3 != 0
            for target in range(order):
                try:
                    found = finder.root(target)
                except ValueError:
                    found = None

                if not every_element_is_cube and target % 3 != 0:
                    self.assertIsNone(found)
                else:
                    self.assertEqual(found * 3 % order, target)
|
|
|
|
class RootModP(CyclicGroupRootFinder):
    """The live class that can take rth roots mod a prime.

    The group is the multiplicative group of Z_p (order p-1), with
    elements represented as plain ints in the range [1, p).
    """

    def __init__(self, r, p):
        self.modulus = p
        super().__init__(r, p-1)

    def mul(self, x, y):
        return (x * y) % self.modulus
    def pow(self, x, n):
        return pow(x, n, self.modulus)
    def inverse(self, x):
        return invert(x, self.modulus)
    def identity(self):
        return 1
    def iter_elements(self):
        return range(1, self.modulus)

    def root(self, g):
        """Return one rth root of g mod p, or raise ValueError if none.

        g may be any integer; it is reduced mod p first. 0 is not in the
        multiplicative group, so it is special-cased as its own root.
        """
        # Without this reduction, an input like p+4 passes the rank
        # computation but fails the base class's final check that
        # root^r == g. A multiple of p trips "Not a cyclic group!".
        g %= self.modulus
        return 0 if g == 0 else super().root(g)
|
|
|
|
class ModP(object):
    """Class that represents integers mod p as a field.

    All the usual arithmetic operations are supported directly,
    including division, so you can write formulas in a natural way
    without having to keep saying '% p' everywhere or call a
    cumbersome modular_inverse() function.

    A plain integer is accepted as the other operand of any arithmetic
    or equality operator, and is reduced mod p. Ordering comparisons
    raise ValueError, because Z_p has no natural ordering.
    """
    def __init__(self, p, n=0):
        self.p = p
        if isinstance(n, type(self)):
            self.check(n)
            n = n.n
        self.n = n % p
    def check(self, other):
        # Both operands must be the same class and share a modulus.
        assert isinstance(other, type(self))
        assert isinstance(self, type(other))
        assert self.p == other.p
    def coerce_to(self, other):
        # Promote a plain number to this class, or validate an element.
        if not isinstance(other, type(self)):
            other = type(self)(self.p, other)
        else:
            self.check(other)
        return other
    def __int__(self):
        return self.n
    def __add__(self, rhs):
        rhs = self.coerce_to(rhs)
        return type(self)(self.p, (self.n + rhs.n) % self.p)
    def __neg__(self):
        return type(self)(self.p, -self.n % self.p)
    def __radd__(self, rhs):
        rhs = self.coerce_to(rhs)
        return type(self)(self.p, (self.n + rhs.n) % self.p)
    def __sub__(self, rhs):
        rhs = self.coerce_to(rhs)
        return type(self)(self.p, (self.n - rhs.n) % self.p)
    def __rsub__(self, rhs):
        rhs = self.coerce_to(rhs)
        return type(self)(self.p, (rhs.n - self.n) % self.p)
    def __mul__(self, rhs):
        rhs = self.coerce_to(rhs)
        return type(self)(self.p, (self.n * rhs.n) % self.p)
    def __rmul__(self, rhs):
        rhs = self.coerce_to(rhs)
        return type(self)(self.p, (self.n * rhs.n) % self.p)
    def __div__(self, rhs):
        rhs = self.coerce_to(rhs)
        return type(self)(self.p, (self.n * invert(rhs.n, self.p)) % self.p)
    def __rdiv__(self, rhs):
        rhs = self.coerce_to(rhs)
        return type(self)(self.p, (rhs.n * invert(self.n, self.p)) % self.p)
    def __truediv__(self, rhs): return self.__div__(rhs)
    def __rtruediv__(self, rhs): return self.__rdiv__(rhs)
    def __pow__(self, exponent):
        # A negative power is a power of the inverse. invert() fails if
        # self is zero, which has no inverse.
        if exponent < 0:
            base, exponent = invert(self.n, self.p), -exponent
        else:
            base = self.n
        return type(self)(self.p, pow(base, exponent, self.p))
    def __cmp__(self, rhs):
        # Python 2 relic: Python 3 never calls __cmp__, and the builtin
        # cmp() it used to rely on no longer exists. Keep it working for
        # explicit callers by computing the sign directly.
        rhs = self.coerce_to(rhs)
        return (self.n > rhs.n) - (self.n < rhs.n)
    def __eq__(self, rhs):
        rhs = self.coerce_to(rhs)
        return self.n == rhs.n
    def __ne__(self, rhs):
        rhs = self.coerce_to(rhs)
        return self.n != rhs.n
    def __lt__(self, rhs):
        raise ValueError("Elements of a modular ring have no ordering")
    def __le__(self, rhs):
        raise ValueError("Elements of a modular ring have no ordering")
    def __gt__(self, rhs):
        raise ValueError("Elements of a modular ring have no ordering")
    def __ge__(self, rhs):
        raise ValueError("Elements of a modular ring have no ordering")
    def __str__(self):
        return "0x{:x}".format(self.n)
    def __repr__(self):
        return "{}(0x{:x},0x{:x})".format(type(self).__name__, self.p, self.n)
    def __hash__(self):
        return hash((type(self).__name__, self.p, self.n))
|
|
|
|
class QuadraticFieldExtensionModP(object):
|
|
"""Class representing Z_p[sqrt(d)] for a given non-square d.
|
|
"""
|
|
def __init__(self, p, d, n=0, m=0):
|
|
self.p = p
|
|
self.d = d
|
|
if isinstance(n, ModP):
|
|
assert self.p == n.p
|
|
n = n.n
|
|
if isinstance(m, ModP):
|
|
assert self.p == m.p
|
|
m = m.n
|
|
if isinstance(n, type(self)):
|
|
self.check(n)
|
|
m += n.m
|
|
n = n.n
|
|
self.n = n % p
|
|
self.m = m % p
|
|
|
|
@classmethod
|
|
def constructor(cls, p, d):
|
|
return lambda *args: cls(p, d, *args)
|
|
|
|
def check(self, other):
|
|
assert isinstance(other, type(self))
|
|
assert isinstance(self, type(other))
|
|
assert self.p == other.p
|
|
assert self.d == other.d
|
|
def coerce_to(self, other):
|
|
if not isinstance(other, type(self)):
|
|
other = type(self)(self.p, self.d, other)
|
|
else:
|
|
self.check(other)
|
|
return other
|
|
def __int__(self):
|
|
if self.m != 0:
|
|
raise ValueError("Can't coerce a non-element of Z_{} to integer"
|
|
.format(self.p))
|
|
return int(self.n)
|
|
def __add__(self, rhs):
|
|
rhs = self.coerce_to(rhs)
|
|
return type(self)(self.p, self.d,
|
|
(self.n + rhs.n) % self.p,
|
|
(self.m + rhs.m) % self.p)
|
|
def __neg__(self):
|
|
return type(self)(self.p, self.d,
|
|
-self.n % self.p,
|
|
-self.m % self.p)
|
|
def __radd__(self, rhs):
|
|
rhs = self.coerce_to(rhs)
|
|
return type(self)(self.p, self.d,
|
|
(self.n + rhs.n) % self.p,
|
|
(self.m + rhs.m) % self.p)
|
|
def __sub__(self, rhs):
|
|
rhs = self.coerce_to(rhs)
|
|
return type(self)(self.p, self.d,
|
|
(self.n - rhs.n) % self.p,
|
|
(self.m - rhs.m) % self.p)
|
|
def __rsub__(self, rhs):
|
|
rhs = self.coerce_to(rhs)
|
|
return type(self)(self.p, self.d,
|
|
(rhs.n - self.n) % self.p,
|
|
(rhs.m - self.m) % self.p)
|
|
def __mul__(self, rhs):
|
|
rhs = self.coerce_to(rhs)
|
|
n, m, N, M = self.n, self.m, rhs.n, rhs.m
|
|
return type(self)(self.p, self.d,
|
|
(n*N + self.d*m*M) % self.p,
|
|
(n*M + m*N) % self.p)
|
|
def __rmul__(self, rhs):
|
|
return self.__mul__(rhs)
|
|
def __div__(self, rhs):
|
|
rhs = self.coerce_to(rhs)
|
|
n, m, N, M = self.n, self.m, rhs.n, rhs.m
|
|
# (n+m sqrt d)/(N+M sqrt d) = (n+m sqrt d)(N-M sqrt d)/(N^2-dM^2)
|
|
denom = (N*N - self.d*M*M) % self.p
|
|
if denom == 0:
|
|
raise ValueError("division by zero")
|
|
recipdenom = invert(denom, self.p)
|
|
return type(self)(self.p, self.d,
|
|
(n*N - self.d*m*M) * recipdenom % self.p,
|
|
(m*N - n*M) * recipdenom % self.p)
|
|
def __rdiv__(self, rhs):
|
|
rhs = self.coerce_to(rhs)
|
|
return rhs.__div__(self)
|
|
def __truediv__(self, rhs): return self.__div__(rhs)
|
|
def __rtruediv__(self, rhs): return self.__rdiv__(rhs)
|
|
def __pow__(self, exponent):
|
|
assert exponent >= 0
|
|
n, b_to_n = 1, self
|
|
total = type(self)(self.p, self.d, 1)
|
|
while True:
|
|
if exponent & n:
|
|
exponent -= n
|
|
total *= b_to_n
|
|
n *= 2
|
|
if n > exponent:
|
|
break
|
|
b_to_n *= b_to_n
|
|
return total
|
|
def __cmp__(self, rhs):
|
|
rhs = self.coerce_to(rhs)
|
|
return cmp((self.n, self.m), (rhs.n, rhs.m))
|
|
def __eq__(self, rhs):
|
|
rhs = self.coerce_to(rhs)
|
|
return self.n == rhs.n and self.m == rhs.m
|
|
def __ne__(self, rhs):
|
|
rhs = self.coerce_to(rhs)
|
|
return self.n != rhs.n or self.m != rhs.m
|
|
def __lt__(self, rhs):
|
|
raise ValueError("Elements of a modular ring have no ordering")
|
|
def __le__(self, rhs):
|
|
raise ValueError("Elements of a modular ring have no ordering")
|
|
def __gt__(self, rhs):
|
|
raise ValueError("Elements of a modular ring have no ordering")
|
|
def __ge__(self, rhs):
|
|
raise ValueError("Elements of a modular ring have no ordering")
|
|
def __str__(self):
|
|
if self.m == 0:
|
|
return "0x{:x}".format(self.n)
|
|
else:
|
|
return "0x{:x}+0x{:x}*sqrt({:d})".format(self.n, self.m, self.d)
|
|
def __repr__(self):
|
|
return "{}(0x{:x},0x{:x},0x{:x},0x{:x})".format(
|
|
type(self).__name__, self.p, self.d, self.n, self.m)
|
|
def __hash__(self):
|
|
return hash((type(self).__name__, self.p, self.d, self.n, self.m))
|
|
|
|
class RootInQuadraticExtension(CyclicGroupRootFinder):
    """Take rth roots in the quadratic extension of Z_p.

    The group is the multiplicative group of Z_p[sqrt(d)] (order
    p^2-1), whose elements are QuadraticFieldExtensionModP values.
    """

    def __init__(self, r, p, d):
        self.modulus = p
        self.constructor = QuadraticFieldExtensionModP.constructor(p, d)
        super().__init__(r, p*p-1)

    def mul(self, x, y):
        return x * y
    def pow(self, x, n):
        return x ** n
    def inverse(self, x):
        return 1/x
    def identity(self):
        return self.constructor(1, 0)
    def iter_elements(self):
        # Yield every nonzero n + m sqrt(d), in order of increasing n+m.
        # The base-class constructor takes the first element of rank 0
        # it meets as z, so this order determines z.
        p = self.modulus
        for n_plus_m in range(1, 2*p-1):
            n_min = max(0, n_plus_m-(p-1))
            n_max = min(p-1, n_plus_m)
            for n in range(n_min, n_max + 1):
                m = n_plus_m - n
                assert(0 <= n < p)
                assert(0 <= m < p)
                assert(n != 0 or m != 0)
                yield self.constructor(n, m)

    def root(self, g):
        """Return one rth root of g, or raise ValueError if it has none.

        g may be an int, a ModP or a field element. It is converted to a
        field element first, so that the group operations never run on
        plain ints. Without this, pow() would do unreduced integer
        exponentiation. The root of 0, which lies outside the
        multiplicative group, is returned as the field element 0 rather
        than the int 0.
        """
        g = self.constructor(g)
        if g == 0:
            return self.constructor(0)
        return super().root(g)
|
|
|
|
class EquationSolverModP(object):
    """Class that can solve quadratics, cubics and quartics over Z_p.

    p must be a nontrivial prime (bigger than 3).

    The solvers work in GF(p^2) and return every root they find there,
    as QuadraticFieldExtensionModP values. Nothing in this class filters
    out roots that lie outside Z_p itself.
    """

    # This is a port to Z_p of reasonably standard algorithms for
    # solving quadratics, cubics and quartics over the reals.
    #
    # When you solve a cubic in R, you sometimes have to deal with
    # intermediate results that are complex numbers. In particular,
    # you have to solve a quadratic whose coefficients are in R but
    # its roots may be complex, and then having solved that quadratic,
    # you need to iterate over all three cube roots of the solution in
    # order to recover all the roots of your cubic. (Even if the cubic
    # ends up having three real roots, you can't calculate them
    # without going through those complex intermediate values.)
    #
    # So over Z_p, the same thing applies: we're going to need to be
    # able to solve any quadratic with coefficients in Z_p, even if
    # its discriminant turns out not to be a quadratic residue mod p,
    # and then we'll need to find _three_ cube roots of the result,
    # even if p == 2 (mod 3) so that numbers only have one cube root
    # each.
    #
    # Both of these problems can be solved at once if we work in the
    # finite field GF(p^2), i.e. make a quadratic field extension of
    # Z_p by adjoining a square root of some non-square d. The
    # multiplicative group of GF(p^2) is cyclic and has order p^2-1 =
    # (p-1)(p+1), with the mult group of Z_p forming the unique
    # subgroup of order (p-1) within it. So we've multiplied the group
    # order by p+1, which is even (since by assumption p > 3), and
    # therefore a square root is now guaranteed to exist for every
    # number in the Z_p subgroup. Moreover, no matter whether p itself
    # was congruent to 1 or 2 mod 3, p^2 is always congruent to 1,
    # which means that the mult group of GF(p^2) has order divisible
    # by 3. So there are guaranteed to be three distinct cube roots of
    # unity, and hence, three cube roots of any number that's a cube
    # at all.
    #
    # Quartics don't introduce any additional problems. To solve a
    # quartic, you factorise it into two quadratic factors, by solving
    # a cubic to find one of the coefficients. So if you can already
    # solve cubics, then you're more or less done. The only wrinkle is
    # that the two quadratic factors will have coefficients in GF(p^2)
    # but not necessarily in Z_p. But that doesn't stop us at least
    # _trying_ to solve them by taking square roots in GF(p^2) - and
    # if the discriminant of one of those quadratics is not a
    # square even in GF(p^2), then its solutions will only exist if
    # you escalate further to GF(p^4), in which case the answer is
    # simply that there aren't any solutions in Z_p to that quadratic.

    def __init__(self, p):
        self.p = p
        # A non-square mod p (an element of rank 0 for square roots) is
        # the d whose square root we adjoin to build GF(p^2).
        self.nonsquare_mod_p = d = RootModP(2, p).z
        self.constructor = QuadraticFieldExtensionModP.constructor(p, d)
        self.sqrt = RootInQuadraticExtension(2, p, d)
        self.cbrt = RootInQuadraticExtension(3, p, d)

    def solve_quadratic(self, a, b, c):
        "Solve ax^2 + bx + c = 0."
        a, b, c = map(self.constructor, (a, b, c))
        assert a != 0
        return self.solve_monic_quadratic(b/a, c/a)

    def solve_monic_quadratic(self, b, c):
        "Solve x^2 + bx + c = 0."
        b, c = map(self.constructor, (b, c))
        # Substitute x = y - b/2 to remove the linear term.
        s = b/2
        return [y - s for y in self.solve_depressed_quadratic(c - s*s)]

    def solve_depressed_quadratic(self, c):
        "Solve x^2 + c = 0."
        return self.sqrt.all_roots(-c)

    def solve_cubic(self, a, b, c, d):
        "Solve ax^3 + bx^2 + cx + d = 0."
        a, b, c, d = map(self.constructor, (a, b, c, d))
        assert a != 0
        return self.solve_monic_cubic(b/a, c/a, d/a)

    def solve_monic_cubic(self, b, c, d):
        "Solve x^3 + bx^2 + cx + d = 0."
        b, c, d = map(self.constructor, (b, c, d))
        # Substitute x = y - b/3 to remove the quadratic term.
        s = b/3
        return [y - s for y in self.solve_depressed_cubic(
            c - 3*s*s, 2*s*s*s - c*s + d)]

    def solve_depressed_cubic(self, c, d):
        "Solve x^3 + cx + d = 0."
        c, d = map(self.constructor, (c, d))
        # To solve x^3 + cx + d = 0, set p = -c/3, then
        # substitute x = z + p/z to get z^6 + d z^3 + p^3 = 0.
        # Solve that quadratic for z^3, then take cube roots.
        # (This p is unrelated to self.p, the prime modulus.)
        p = -c/3
        if p == 0:
            # Degenerate case: the quadratic becomes z^6 + d z^3 = 0,
            # with z^3 = 0 among its roots, and then p/z would divide
            # by zero. But with c = 0 the cubic is simply x^3 = -d, so
            # its roots are exactly the cube roots of -d.
            return set(self.cbrt.all_roots(-d))
        solutions = set()
        for z3 in self.solve_monic_quadratic(d, p**3):
            # As I understand the theory, we _should_ only need to
            # take cube roots of one root of that quadratic: the other
            # one should give the same set of answers after you map
            # each one through z |-> z+p/z. But speed isn't at a
            # premium here, so I'll do this the way that must work.
            # (Here p != 0, so the product of the two z3 roots, p^3,
            # is nonzero: neither z3 nor any cube root z is zero.)
            for z in self.cbrt.all_roots(z3):
                solutions.add(z + p/z)
        return solutions

    def solve_quartic(self, a, b, c, d, e):
        "Solve ax^4 + bx^3 + cx^2 + dx + e = 0."
        a, b, c, d, e = map(self.constructor, (a, b, c, d, e))
        assert a != 0
        return self.solve_monic_quartic(b/a, c/a, d/a, e/a)

    def solve_monic_quartic(self, b, c, d, e):
        "Solve x^4 + bx^3 + cx^2 + dx + e = 0."
        b, c, d, e = map(self.constructor, (b, c, d, e))
        # Substitute x = y - b/4 to remove the cubic term.
        s = b/4
        return [y - s for y in self.solve_depressed_quartic(
            c - 6*s*s, d - 2*c*s + 8*s*s*s, e - d*s + c*s*s - 3*s*s*s*s)]

    def solve_depressed_quartic(self, c, d, e):
        "Solve x^4 + cx^2 + dx + e = 0."
        c, d, e = map(self.constructor, (c, d, e))
        solutions = set()
        # To solve an equation of this form, we search for a value y
        # such that subtracting the original polynomial from (x^2+y)^2
        # yields a quadratic of the special form (ux+v)^2.
        #
        # Then our equation is rewritten as (x^2+y)^2 - (ux+v)^2 = 0
        # i.e. ((x^2+y) + (ux+v)) ((x^2+y) - (ux+v)) = 0
        # i.e. the product of two quadratics, each of which we then solve.
        #
        # To find y, we write down the discriminant of the quadratic
        # (x^2+y)^2 - (x^4 + cx^2 + dx + e) and set it to 0, which
        # gives a cubic in y. Maxima gives the coefficients as
        # (-8)y^3 + (4c)y^2 + (8e)y + (d^2-4ce).
        #
        # As above, we _should_ only need one value of y. But I go
        # through them all just in case, because I don't care about
        # speed, and because checking the assertions inside this loop
        # for every value is extra reassurance that I've done all of
        # this right.
        for y in self.solve_cubic(-8, 4*c, 8*e, d*d-4*c*e):
            # Subtract the original equation from (x^2+y)^2 to get the
            # coefficients of our quadratic residual.
            A, B, C = 2*y-c, -d, y*y-e
            # Expect that to have zero discriminant, i.e. a repeated root.
            assert B*B - 4*A*C == 0
            # If (Ax^2+Bx+C) == (ux+v)^2 then we have u^2=A, 2uv=B, v^2=C.
            # So we can either recover u as sqrt(A) or v as sqrt(C), and
            # whichever we did, find the other from B by division. But
            # either of the end coefficients might be zero, so we have
            # to be prepared to try either option.
            try:
                if A != 0:
                    u = self.sqrt.root(A)
                    v = B/(2*u)
                elif C != 0:
                    v = self.sqrt.root(C)
                    u = B/(2*v)
                else:
                    # One last possibility is that all three coefficients
                    # of our residual quadratic are 0, in which case,
                    # obviously, u=v=0 as well.
                    u = v = 0
            except ValueError:
                # If Ax^2+Bx+C looked like a perfect square going by
                # its discriminant, but actually taking the square
                # root of A or C threw an exception, that means that
                # it's the square of a polynomial whose coefficients
                # live in a yet-higher field extension of Z_p. In that
                # case we're not going to end up with roots of the
                # original quartic in Z_p if we start from here!
                continue
            # So now our quartic is factorised into the form
            # (x^2 - ux - v + y) (x^2 + ux + v + y).
            for x in self.solve_monic_quadratic(-u, y-v):
                solutions.add(x)
            for x in self.solve_monic_quadratic(u, y+v):
                solutions.add(x)
        return solutions
|
|
|
|
class EquationSolverTest(unittest.TestCase):
    # Each polynomial below is a scalar multiple of (x-1)(x-2)...
    # reduced mod 11, so its roots are the first few positive integers.

    def checkRoots(self, solutions, expected):
        """Compare solutions (as their str() forms) against expected."""
        self.assertEqual(sorted(str(x) for x in solutions), expected)

    def testQuadratic(self):
        # 3x^2 + 2x + 6 == 3(x-1)(x-2) mod 11
        solver = EquationSolverModP(11)
        self.checkRoots(solver.solve_quadratic(3, 2, 6), ["0x1", "0x2"])

    def testCubic(self):
        # 7x^3 + 2x^2 + 2 == 7(x-1)(x-2)(x-3) mod 11
        solver = EquationSolverModP(11)
        self.checkRoots(solver.solve_cubic(7, 2, 0, 2),
                        ["0x1", "0x2", "0x3"])

    def testQuartic(self):
        # 9x^4 + 9x^3 + 7x^2 + x + 7 == 9(x-1)(x-2)(x-3)(x-4) mod 11
        solver = EquationSolverModP(11)
        self.checkRoots(solver.solve_quartic(9, 9, 7, 1, 7),
                        ["0x1", "0x2", "0x3", "0x4"])
|
|
|
|
if __name__ == "__main__":
    # Only run the self-tests when invoked with exactly '--test'. Strip
    # that flag first so that unittest's own argument parser never sees
    # it. (sys is already imported at the top of the file.)
    if sys.argv[1:] == ["--test"]:
        del sys.argv[1]
        unittest.main()
|