1
0
mirror of https://git.tartarus.org/simon/putty.git synced 2025-01-09 17:38:00 +00:00

numbertheory.py: cubic and quartic solver mod p.

I'm going to want to use this for finding special values in elliptic
curves' ground fields.

In order to solve cubics and quartics in F_p, you have to work in
F_{p^2}, for much the same reasons that you have to be willing to use
complex numbers if you want to solve general cubics over the reals
(even if all the eventual roots turn out to be real after all). So
I've also introduced another arithmetic class to work in that kind of
field, and a shim that glues that on to the cyclic-group root finder
from the previous commit.
This commit is contained in:
Simon Tatham 2020-02-28 20:17:15 +00:00
parent 072d3c665a
commit f82af9ffe2

View File

@ -272,6 +272,362 @@ class ModP(object):
def __hash__(self):
return hash((type(self).__name__, self.p, self.n))
class QuadraticFieldExtensionModP(object):
"""Class representing Z_p[sqrt(d)] for a given non-square d.
"""
def __init__(self, p, d, n=0, m=0):
self.p = p
self.d = d
if isinstance(n, ModP):
assert self.p == n.p
n = n.n
if isinstance(m, ModP):
assert self.p == m.p
m = m.n
if isinstance(n, type(self)):
self.check(n)
m += n.m
n = n.n
self.n = n % p
self.m = m % p
@classmethod
def constructor(cls, p, d):
return lambda *args: cls(p, d, *args)
def check(self, other):
assert isinstance(other, type(self))
assert isinstance(self, type(other))
assert self.p == other.p
assert self.d == other.d
def coerce_to(self, other):
if not isinstance(other, type(self)):
other = type(self)(self.p, self.d, other)
else:
self.check(other)
return other
def __int__(self):
if self.m != 0:
raise ValueError("Can't coerce a non-element of Z_{} to integer"
.format(self.p))
return int(self.n)
def __add__(self, rhs):
rhs = self.coerce_to(rhs)
return type(self)(self.p, self.d,
(self.n + rhs.n) % self.p,
(self.m + rhs.m) % self.p)
def __neg__(self):
return type(self)(self.p, self.d,
-self.n % self.p,
-self.m % self.p)
def __radd__(self, rhs):
rhs = self.coerce_to(rhs)
return type(self)(self.p, self.d,
(self.n + rhs.n) % self.p,
(self.m + rhs.m) % self.p)
def __sub__(self, rhs):
rhs = self.coerce_to(rhs)
return type(self)(self.p, self.d,
(self.n - rhs.n) % self.p,
(self.m - rhs.m) % self.p)
def __rsub__(self, rhs):
rhs = self.coerce_to(rhs)
return type(self)(self.p, self.d,
(rhs.n - self.n) % self.p,
(rhs.m - self.m) % self.p)
def __mul__(self, rhs):
rhs = self.coerce_to(rhs)
n, m, N, M = self.n, self.m, rhs.n, rhs.m
return type(self)(self.p, self.d,
(n*N + self.d*m*M) % self.p,
(n*M + m*N) % self.p)
def __rmul__(self, rhs):
return self.__mul__(rhs)
def __div__(self, rhs):
rhs = self.coerce_to(rhs)
n, m, N, M = self.n, self.m, rhs.n, rhs.m
# (n+m sqrt d)/(N+M sqrt d) = (n+m sqrt d)(N-M sqrt d)/(N^2-dM^2)
denom = (N*N - self.d*M*M) % self.p
if denom == 0:
raise ValueError("division by zero")
recipdenom = invert(denom, self.p)
return type(self)(self.p, self.d,
(n*N - self.d*m*M) * recipdenom % self.p,
(m*N - n*M) * recipdenom % self.p)
def __rdiv__(self, rhs):
rhs = self.coerce_to(rhs)
return rhs.__div__(self)
def __truediv__(self, rhs): return self.__div__(rhs)
def __rtruediv__(self, rhs): return self.__rdiv__(rhs)
def __pow__(self, exponent):
assert exponent >= 0
n, b_to_n = 1, self
total = type(self)(self.p, self.d, 1)
while True:
if exponent & n:
exponent -= n
total *= b_to_n
n *= 2
if n > exponent:
break
b_to_n *= b_to_n
return total
def __cmp__(self, rhs):
rhs = self.coerce_to(rhs)
return cmp((self.n, self.m), (rhs.n, rhs.m))
def __eq__(self, rhs):
rhs = self.coerce_to(rhs)
return self.n == rhs.n and self.m == rhs.m
def __ne__(self, rhs):
rhs = self.coerce_to(rhs)
return self.n != rhs.n or self.m != rhs.m
def __lt__(self, rhs):
raise ValueError("Elements of a modular ring have no ordering")
def __le__(self, rhs):
raise ValueError("Elements of a modular ring have no ordering")
def __gt__(self, rhs):
raise ValueError("Elements of a modular ring have no ordering")
def __ge__(self, rhs):
raise ValueError("Elements of a modular ring have no ordering")
def __str__(self):
if self.m == 0:
return "0x{:x}".format(self.n)
else:
return "0x{:x}+0x{:x}*sqrt({:d})".format(self.n, self.m, self.d)
def __repr__(self):
return "{}(0x{:x},0x{:x},0x{:x},0x{:x})".format(
type(self).__name__, self.p, self.d, self.n, self.m)
def __hash__(self):
return hash((type(self).__name__, self.p, self.d, self.n, self.m))
class RootInQuadraticExtension(CyclicGroupRootFinder):
"""Take rth roots in the quadratic extension of Z_p."""
def __init__(self, r, p, d):
self.modulus = p
self.constructor = QuadraticFieldExtensionModP.constructor(p, d)
super().__init__(r, p*p-1)
def mul(self, x, y):
return x * y
def pow(self, x, n):
return x ** n
def inverse(self, x):
return 1/x
def identity(self):
return self.constructor(1, 0)
def iter_elements(self):
p = self.modulus
for n_plus_m in range(1, 2*p-1):
n_min = max(0, n_plus_m-(p-1))
n_max = min(p-1, n_plus_m)
for n in range(n_min, n_max + 1):
m = n_plus_m - n
assert(0 <= n < p)
assert(0 <= m < p)
assert(n != 0 or m != 0)
yield self.constructor(n, m)
def root(self, g):
return 0 if g == 0 else super().root(g)
class EquationSolverModP(object):
"""Class that can solve quadratics, cubics and quartics over Z_p.
p must be a nontrivial prime (bigger than 3).
"""
# This is a port to Z_p of reasonably standard algorithms for
# solving quadratics, cubics and quartics over the reals.
#
# When you solve a cubic in R, you sometimes have to deal with
# intermediate results that are complex numbers. In particular,
# you have to solve a quadratic whose coefficients are in R but
# its roots may be complex, and then having solved that quadratic,
# you need to iterate over all three cube roots of the solution in
# order to recover all the roots of your cubic. (Even if the cubic
# ends up having three real roots, you can't calculate them
# without going through those complex intermediate values.)
#
# So over Z_p, the same thing applies: we're going to need to be
# able to solve any quadratic with coefficients in Z_p, even if
# its discriminant turns out not to be a quadratic residue mod p,
# and then we'll need to find _three_ cube roots of the result,
# even if p == 2 (mod 3) so that numbers only have one cube root
# each.
#
# Both of these problems can be solved at once if we work in the
# finite field GF(p^2), i.e. make a quadratic field extension of
# Z_p by adjoining a square root of some non-square d. The
# multiplicative group of GF(p^2) is cyclic and has order p^2-1 =
# (p-1)(p+1), with the mult group of Z_p forming the unique
# subgroup of order (p-1) within it. So we've multiplied the group
# order by p+1, which is even (since by assumption p > 3), and
# therefore a square root is now guaranteed to exist for every
# number in the Z_p subgroup. Moreover, no matter whether p itself
# was congruent to 1 or 2 mod 3, p^2 is always congruent to 1,
# which means that the mult group of GF(p^2) has order divisible
# by 3. So there are guaranteed to be three distinct cube roots of
# unity, and hence, three cube roots of any number that's a cube
# at all.
#
# Quartics don't introduce any additional problems. To solve a
# quartic, you factorise it into two quadratic factors, by solving
# a cubic to find one of the coefficients. So if you can already
# solve cubics, then you're more or less done. The only wrinkle is
# that the two quadratic factors will have coefficients in GF(p^2)
# but not necessarily in Z_p. But that doesn't stop us at least
# _trying_ to solve them by taking square roots in GF(p^2) - and
# if the discriminant of one of those quadratics has is not a
# square even in GF(p^2), then its solutions will only exist if
# you escalate further to GF(p^4), in which case the answer is
# simply that there aren't any solutions in Z_p to that quadratic.
def __init__(self, p):
self.p = p
self.nonsquare_mod_p = d = RootModP(2, p).z
self.constructor = QuadraticFieldExtensionModP.constructor(p, d)
self.sqrt = RootInQuadraticExtension(2, p, d)
self.cbrt = RootInQuadraticExtension(3, p, d)
def solve_quadratic(self, a, b, c):
"Solve ax^2 + bx + c = 0."
a, b, c = map(self.constructor, (a, b, c))
assert a != 0
return self.solve_monic_quadratic(b/a, c/a)
def solve_monic_quadratic(self, b, c):
"Solve x^2 + bx + c = 0."
b, c = map(self.constructor, (b, c))
s = b/2
return [y - s for y in self.solve_depressed_quadratic(c - s*s)]
def solve_depressed_quadratic(self, c):
"Solve x^2 + c = 0."
return self.sqrt.all_roots(-c)
def solve_cubic(self, a, b, c, d):
"Solve ax^3 + bx^2 + cx + d = 0."
a, b, c, d = map(self.constructor, (a, b, c, d))
assert a != 0
return self.solve_monic_cubic(b/a, c/a, d/a)
def solve_monic_cubic(self, b, c, d):
"Solve x^3 + bx^2 + cx + d = 0."
b, c, d = map(self.constructor, (b, c, d))
s = b/3
return [y - s for y in self.solve_depressed_cubic(
c - 3*s*s, 2*s*s*s - c*s + d)]
def solve_depressed_cubic(self, c, d):
"Solve x^3 + cx + d = 0."
c, d = map(self.constructor, (c, d))
solutions = set()
# To solve x^3 + cx + d = 0, set p = -c/3, then
# substitute x = z + p/z to get z^6 + d z^3 + p^3 = 0.
# Solve that quadratic for z^3, then take cube roots.
p = -c/3
for z3 in self.solve_monic_quadratic(d, p**3):
# As I understand the theory, we _should_ only need to
# take cube roots of one root of that quadratic: the other
# one should give the same set of answers after you map
# each one through z |-> z+p/z. But speed isn't at a
# premium here, so I'll do this the way that must work.
for z in self.cbrt.all_roots(z3):
solutions.add(z + p/z)
return solutions
def solve_quartic(self, a, b, c, d, e):
"Solve ax^4 + bx^3 + cx^2 + dx + e = 0."
a, b, c, d, e = map(self.constructor, (a, b, c, d, e))
assert a != 0
return self.solve_monic_quartic(b/a, c/a, d/a, e/a)
def solve_monic_quartic(self, b, c, d, e):
"Solve x^4 + bx^3 + cx^2 + dx + e = 0."
b, c, d, e = map(self.constructor, (b, c, d, e))
s = b/4
return [y - s for y in self.solve_depressed_quartic(
c - 6*s*s, d - 2*c*s + 8*s*s*s, e - d*s + c*s*s - 3*s*s*s*s)]
def solve_depressed_quartic(self, c, d, e):
"Solve x^4 + cx^2 + dx + e = 0."
c, d, e = map(self.constructor, (c, d, e))
solutions = set()
# To solve an equation of this form, we search for a value y
# such that subtracting the original polynomial from (x^2+y)^2
# yields a quadratic of the special form (ux+v)^2.
#
# Then our equation is rewritten as (x^2+y)^2 - (ux+v)^2 = 0
# i.e. ((x^2+y) + (ux+v)) ((x^2+y) - (ux+v)) = 0
# i.e. the product of two quadratics, each of which we then solve.
#
# To find y, we write down the discriminant of the quadratic
# (x^2+y)^2 - (x^4 + cx^2 + dx + e) and set it to 0, which
# gives a cubic in y. Maxima gives the coefficients as
# (-8)y^3 + (4c)y^2 + (8e)y + (d^2-4ce).
#
# As above, we _should_ only need one value of y. But I go
# through them all just in case, because I don't care about
# speed, and because checking the assertions inside this loop
# for every value is extra reassurance that I've done all of
# this right.
for y in self.solve_cubic(-8, 4*c, 8*e, d*d-4*c*e):
# Subtract the original equation from (x^2+y)^2 to get the
# coefficients of our quadratic residual.
A, B, C = 2*y-c, -d, y*y-e
# Expect that to have zero discriminant, i.e. a repeated root.
assert B*B - 4*A*C == 0
# If (Ax^2+Bx+C) == (ux+v)^2 then we have u^2=A, 2uv=B, v^2=C.
# So we can either recover u as sqrt(A) or v as sqrt(C), and
# whichever we did, find the other from B by division. But
# either of the end coefficients might be zero, so we have
# to be prepared to try either option.
try:
if A != 0:
u = self.sqrt.root(A)
v = B/(2*u)
elif C != 0:
v = self.sqrt.root(C)
u = B/(2*v)
else:
# One last possibility is that all three coefficients
# of our residual quadratic are 0, in which case,
# obviously, u=v=0 as well.
u = v = 0
except ValueError:
# If Ax^2+Bx+C looked like a perfect square going by
# its discriminant, but actually taking the square
# root of A or C threw an exception, that means that
# it's the square of a polynomial whose coefficients
# live in a yet-higher field extension of Z_p. In that
# case we're not going to end up with roots of the
# original quartic in Z_p if we start from here!
continue
# So now our quartic is factorised into the form
# (x^2 - ux - v + y) (x^2 + ux + v + y).
for x in self.solve_monic_quadratic(-u, y-v):
solutions.add(x)
for x in self.solve_monic_quadratic(u, y+v):
solutions.add(x)
return solutions
class EquationSolverTest(unittest.TestCase):
def testQuadratic(self):
E = EquationSolverModP(11)
solns = E.solve_quadratic(3, 2, 6)
self.assertEqual(sorted(map(str, solns)), ["0x1", "0x2"])
def testCubic(self):
E = EquationSolverModP(11)
solns = E.solve_cubic(7, 2, 0, 2)
self.assertEqual(sorted(map(str, solns)), ["0x1", "0x2", "0x3"])
def testQuartic(self):
E = EquationSolverModP(11)
solns = E.solve_quartic(9, 9, 7, 1, 7)
self.assertEqual(sorted(map(str, solns)), ["0x1", "0x2", "0x3", "0x4"])
if __name__ == "__main__":
import sys
if sys.argv[1:] == ["--test"]: