diff --git a/config.c b/config.c index 8114a586..1fbbb3e8 100644 --- a/config.c +++ b/config.c @@ -567,6 +567,8 @@ static void kexlist_handler(union control *ctrl, dlgparam *dlg, { "Diffie-Hellman group exchange", KEX_DHGEX }, { "RSA-based key exchange", KEX_RSA }, { "ECDH key exchange", KEX_ECDH }, + { "NTRU Prime / Curve25519 hybrid kex" + " (quantum-resistant)", KEX_NTRU_HYBRID }, { "-- warn below here --", KEX_WARN } }; diff --git a/crypto/CMakeLists.txt b/crypto/CMakeLists.txt index c6420361..71083107 100644 --- a/crypto/CMakeLists.txt +++ b/crypto/CMakeLists.txt @@ -20,6 +20,7 @@ add_sources_from_current_dir(crypto mac_simple.c md5.c mpint.c + ntru.c prng.c pubkey-pem.c pubkey-ppk.c diff --git a/crypto/ntru.c b/crypto/ntru.c new file mode 100644 index 00000000..88d4f084 --- /dev/null +++ b/crypto/ntru.c @@ -0,0 +1,1915 @@ +/* + * Implementation of OpenSSH 9.x's hybrid key exchange protocol + * sntrup761x25519-sha512@openssh.com . + * + * This consists of the 'Streamlined NTRU Prime' quantum-resistant + * cryptosystem, run in parallel with ordinary Curve25519 to generate + * a shared secret combining the output of both systems. + * + * (Hence, even if you don't trust this newfangled NTRU Prime thing at + * all, it's at least no _less_ secure than the kex you were using + * already.) + * + * References for the NTRU Prime cryptosystem, up to and including + * binary encodings of public and private keys and the exact preimages + * of the hashes used in key exchange: + * + * https://ntruprime.cr.yp.to/ + * https://ntruprime.cr.yp.to/nist/ntruprime-20201007.pdf + * + * The SSH protocol layer is not documented anywhere I could find (as + * of 2022-04-15, not even in OpenSSH's PROTOCOL.* files). I had to + * read OpenSSH's source code to find out how it worked, and the + * answer is as follows: + * + * This hybrid kex method is treated for SSH purposes as a form of + * elliptic-curve Diffie-Hellman, and shares the same SSH message + * sequence: client sends SSH2_MSG_KEX_ECDH_INIT containing its public + * half, server responds with SSH2_MSG_KEX_ECDH_REPLY containing _its_ + * public half plus the host key and signature on the shared secret. + * + * (This is a bit of a fudge, because unlike actual ECDH, this kex + * method is asymmetric: one side sends a public key, and the other + * side encrypts something with it and sends the ciphertext back. So + * while the normal ECDH implementations can compute the two sides + * independently in parallel, this system reusing the same messages + * has to be serial. But the order of the messages _is_ firmly + * specified in SSH ECDH, so it works anyway.) + * + * For this kex method, SSH2_MSG_KEX_ECDH_INIT still contains a single + * SSH 'string', which consists of the concatenation of a Streamlined + * NTRU Prime public key with the Curve25519 public value. (Both of + * these have fixed length in bytes, so there's no ambiguity in the + * concatenation.) + * + * SSH2_MSG_KEX_ECDH_REPLY is mostly the same as usual. The only + * string in the packet that varies is the second one, which would + * normally contain the server's public elliptic curve point. Instead, + * it now contains the concatenation of + * + * - a Streamlined NTRU Prime ciphertext + * - the 'confirmation hash' specified in ntruprime-20201007.pdf, + * hashing the plaintext of that ciphertext together with the + * public key + * - the Curve25519 public point as usual. + * + * Again, all three of those elements have fixed lengths. + * + * The client decrypts the ciphertext, checks the confirmation hash, + * and if successful, generates the 'session hash' specified in + * ntruprime-20201007.pdf, which is 32 bytes long and is the ultimate + * output of the Streamlined NTRU Prime key exchange. + * + * The output of the hybrid kex method as a whole is an SSH 'string' + * of length 64 containing the SHA-512 hash of the concatenatio of + * + * - the Streamlined NTRU Prime session hash (32 bytes) + * - the Curve25519 shared secret (32 bytes). + * + * That string is included directly into the SSH exchange hash and key + * derivation hashes, in place of the mpint that comes out of most + * other kex methods. + */ + +#include +#include +#include + +#include "putty.h" +#include "ssh.h" +#include "mpint.h" +#include "ntru.h" + +/* ---------------------------------------------------------------------- + * Preliminaries: we're going to need to do modular arithmetic on + * small values (considerably smaller than 2^16), and we need to do it + * without using integer division which might not be time-safe. + * + * The strategy for this is the same as I used in + * mp_mod_known_integer: see there for the proofs. The basic idea is + * that we precompute the reciprocal of our modulus as a fixed-point + * number, and use that to get an approximate quotient which we + * subtract off. For these integer sizes, precomputing a fixed-point + * reciprocal of the form (2^48 / modulus) leaves us at most off by 1 + * in the quotient, so there's a single (time-safe) trial subtraction + * at the end. + * + * (It's possible that some speed could be gained by not reducing + * fully at every step. But then you'd have to carefully identify all + * the places in the algorithm where things are compared to zero. This + * was the easiest way to get it all working in the first place.) + */ + +/* Precompute the reciprocal */ +static uint64_t reciprocal_for_reduction(uint16_t q) +{ + return ((uint64_t)1 << 48) / q; +} + +/* Reduce x mod q, assuming qrecip == reciprocal_for_reduction(q) */ +static uint16_t reduce(uint32_t x, uint16_t q, uint64_t qrecip) +{ + uint64_t unshifted_quot = x * qrecip; + uint64_t quot = unshifted_quot >> 48; + uint16_t reduced = x - quot * q; + reduced -= q * (1 & ((q-1 - reduced) >> 15)); + return reduced; +} + +/* Reduce x mod q as above, but also return the quotient */ +static uint16_t reduce_with_quot(uint32_t x, uint32_t *quot_out, + uint16_t q, uint64_t qrecip) +{ + uint64_t unshifted_quot = x * qrecip; + uint64_t quot = unshifted_quot >> 48; + uint16_t reduced = x - quot * q; + uint64_t extraquot = (1 & ((q-1 - reduced) >> 15)); + reduced -= extraquot * q; + *quot_out = quot + extraquot; + return reduced; +} + +/* Invert x mod q, assuming it's nonzero. (For time-safety, no check + * is made for zero; it just returns 0.) */ +static uint16_t invert(uint16_t x, uint16_t q, uint64_t qrecip) +{ + /* Fermat inversion: compute x^(q-2), since x^(q-1) == 1. */ + uint32_t sq = x, bit = 1, acc = 1, exp = q-2; + while (1) { + if (exp & bit) { + acc = reduce(acc * sq, q, qrecip); + exp &= ~bit; + if (!exp) + return acc; + } + sq = reduce(sq * sq, q, qrecip); + bit <<= 1; + } +} + +/* Check whether x == 0, time-safely, and return 1 if it is or 0 otherwise. */ +static unsigned iszero(uint16_t x) +{ + return 1 & ~((x + 0xFFFF) >> 16); +} + +/* + * Handy macros to cut down on all those extra function parameters. In + * the common case where a function is working mod the same modulus + * throughout (and has called it q), you can just write 'SETUP;' at + * the top and then call REDUCE(...) and INVERT(...) without having to + * write out q and qrecip every time. + */ +#define SETUP uint64_t qrecip = reciprocal_for_reduction(q) +#define REDUCE(x) reduce(x, q, qrecip) +#define INVERT(x) invert(x, q, qrecip) + +/* ---------------------------------------------------------------------- + * Quotient-ring functions. + * + * NTRU Prime works with two similar but different quotient rings: + * + * Z_q[x] / where p,q are the prime parameters of the system + * Z_3[x] / with the same p, but coefficients mod 3. + * + * The former is a field (every nonzero element is invertible), + * because the system parameters are chosen such that x^p-x-1 is + * invertible over Z_q. The latter is not a field (or not necessarily, + * and in particular, not for the value of p we use here). + * + * In these core functions, you pass in the modulus you want as the + * parameter q, which is either the 'real' q specified in the system + * parameters, or 3 if you're doing one of the mod-3 parts of the + * algorithm. + */ + +/* + * Multiply two elements of a quotient ring. + * + * 'a' and 'b' are arrays of exactly p coefficients, with constant + * term first. 'out' is an array the same size to write the inverse + * into. + */ +void ntru_ring_multiply(uint16_t *out, const uint16_t *a, const uint16_t *b, + unsigned p, unsigned q) +{ + SETUP; + + /* + * Strategy: just compute the full product with 2p coefficients, + * and then reduce it mod x^p-x-1 by working downwards from the + * top coefficient replacing x^{p+k} with (x+1)x^k for k = ...,1,0. + * + * Possibly some speed could be gained here by doing the recursive + * Karatsuba optimisation for the initial multiplication? But I + * haven't tried it. + */ + uint32_t *unreduced = snewn(2*p, uint32_t); + for (unsigned i = 0; i < 2*p; i++) + unreduced[i] = 0; + for (unsigned i = 0; i < p; i++) + for (unsigned j = 0; j < p; j++) + unreduced[i+j] = REDUCE(unreduced[i+j] + a[i] * b[j]); + + for (unsigned i = 2*p - 1; i >= p; i--) { + unreduced[i-p] += unreduced[i]; + unreduced[i-p+1] += unreduced[i]; + unreduced[i] = 0; + } + + for (unsigned i = 0; i < p; i++) + out[i] = REDUCE(unreduced[i]); + + smemclr(unreduced, 2*p * sizeof(*unreduced)); + sfree(unreduced); +} + +/* + * Invert an element of the quotient ring. + * + * 'in' is an array of exactly p coefficients, with constant term + * first. 'out' is an array the same size to write the inverse into. + * + * Method: essentially Stein's gcd algorithm, taking the gcd of the + * input (regarded as an element of Z_q[x] proper) and x^p-x-1. Given + * two polynomials over a field which are not both divisible by x, you + * can find their gcd by iterating the following procedure: + * + * - if one is divisible by x, divide off x + * - otherwise, subtract from the higher-degree one whatever scalar + * multiple of the lower-degree one will make it divisible by x, + * and _then_ divide off x + * + * Neither of these types of step changes the gcd of the two + * polynomials. + * + * Each step reduces the sum of the two polynomials' degree by at + * least one, as long as at least one of the degrees is positive. + * (Maybe more than one if all the stars align in the second case, if + * the subtraction cancels the leading term as well as the constant + * term.) So in at most deg A + deg B steps, we must have reached the + * situation where both polys are constants, and in one more step + * after that, one of them will be zero. Or rather, that's what + * happens in the case where A,B are coprime; if not, then one hits + * zero while the other is still nonzero. + * + * Then unwind all the transformations, to find a linear combination + * of the two original polynomials that yields the nonzero one of the + * two outputs. (In fact we only need the coefficient of 'in' in that + * linear combination, but we have to compute both halves, because + * they keep swapping round during the unwinding.) + */ +unsigned ntru_ring_invert(uint16_t *out, const uint16_t *in, + unsigned p, unsigned q) +{ + SETUP; + + /* Size of the polynomial arrays we'll work with */ + const size_t SIZE = p+1; + + /* Number of steps of the algorithm is the max possible value of + * deg A + deg B + 1, where deg A <= p-1 and deg B = p */ + const size_t STEPS = 2*p; + + /* Our two working polynomials */ + uint16_t *A = snewn(SIZE, uint16_t); + uint16_t *B = snewn(SIZE, uint16_t); + + /* History of what we did */ + uint16_t *multipliers = snewn(STEPS, uint16_t); + uint8_t *swaps = snewn(STEPS, uint8_t); + + /* Initialise A to the input */ + memcpy(A, in, p*sizeof(uint16_t)); + A[p] = 0; + + /* And initialise B to the quotient polynomial of the ring, x^p-x-1 */ + B[0] = B[1] = q-1; + for (size_t i = 2; i < p; i++) + B[i] = 0; + B[p] = 1; + + /* Run the gcd-finding algorithm. */ + for (size_t i = 0; i < STEPS; i++) { + /* + * First swap round so that A is the one we'll be dividing by 2. + * + * In the case where one of the two polys has a zero constant + * term, it's that one. In the other case, it's the one of + * smaller degree. We must compute both, and choose between + * them in a side-channel-safe way. + */ + unsigned x_divides_A = iszero(A[0]); + unsigned x_divides_B = iszero(B[0]); + unsigned B_is_bigger = 0; + { + unsigned not_seen_top_term_of_A = 1, not_seen_top_term_of_B = 1; + for (size_t j = SIZE; j-- > 0 ;) { + not_seen_top_term_of_A &= iszero(A[j]); + not_seen_top_term_of_B &= iszero(B[j]); + B_is_bigger |= (~not_seen_top_term_of_B & + not_seen_top_term_of_A); + } + } + unsigned need_swap = x_divides_B | (~x_divides_A & B_is_bigger); + uint16_t swap_mask = -need_swap; + for (size_t j = 0; j < SIZE; j++) { + uint16_t diff = (A[j] ^ B[j]) & swap_mask; + A[j] ^= diff; + B[j] ^= diff; + } + + /* + * Add a multiple of B to A to make A's constant term zero. In + * one of the two cases, A's constant term is already zero, so + * this will do nothing but take the same length of time as + * doing something, which is just what we want. + * + * Also, shift down by one in the course of doing this. + */ + uint16_t mult = REDUCE((q - A[0]) * INVERT(B[0])); + for (size_t j = 1; j < SIZE; j++) + A[j-1] = REDUCE(A[j] + mult * B[j]); + A[SIZE-1] = 0; + + /* + * Record what we just did. + */ + swaps[i] = need_swap; + multipliers[i] = mult; + } + + /* + * Now we expect that one of the polynomials is zero, and the + * other is zero except for the constant term. If so, then they + * are coprime, and we're going to return success. If not, they + * have a common factor. + */ + unsigned success = iszero(A[0]) ^ iszero(B[0]); + for (size_t j = 1; j < SIZE; j++) + success &= iszero(A[j]) & iszero(B[j]); + + /* + * Now unwind to make a linear combination of the two original + * polynomials that equals 1 (assuming we're going to return + * success). + * + * We make two polynomials Ac,Bc, with the intention that we'll + * preserve the invariant Ac*A + Bc*B = 1 as we rewind through the + * steps. + * + * Initially, we set the coefficient of the zero one of A,B to + * zero, and the coefficient of the constant one to be its + * inverse. + */ + uint16_t *Ac = snewn(SIZE, uint16_t); + uint16_t *Bc = snewn(SIZE, uint16_t); + for (size_t i = 1; i < SIZE; i++) + Ac[i] = Bc[i] = 0; + Ac[0] = INVERT(A[0]); + Bc[0] = INVERT(B[0]); + + for (size_t i = STEPS; i-- > 0 ;) { + /* + * The last thing we did in our step was always to divide A by + * x. That is, we currently have 1 as a linear combination of + * A and B, and now we need it as a linear combination of A*x + * and B. + * + * We have Ac*A + Bc*B = (Ac+k*B)*A + (Bc-k*A)*B for any k. + * So choose k such that Ac+k*B has zero constant term + * (possible since B has nonzero constant term), and then we + * have 1 = (Ac+k*B)/x * (A*x) + (Bc-k*A) * B. + */ + uint16_t minusk = REDUCE(Ac[0] * INVERT(B[0])); + uint16_t k = q - minusk; + for (size_t j = 1; j < SIZE; j++) + Ac[j-1] = REDUCE(Ac[j] + k * B[j]); + Ac[SIZE-1] = 0; + for (size_t j = 0; j < SIZE; j++) + Bc[j] = REDUCE(Bc[j] + minusk * A[j]); + + /* And unwind the shift of A itself. */ + memmove(A+1, A, (SIZE-1) * sizeof(*A)); + A[0] = 0; + + /* + * Before that, we added m*B to A. So our new A will be A-m*B. + * So we have 1 = Ac*A + Bc*B = Ac*(A-m*B) + (Bc+m*Ac)*B. + */ + uint16_t m = multipliers[i]; + uint16_t minusm = q - m; + for (size_t j = 0; j < SIZE; j++) + Bc[j] = REDUCE(Bc[j] + m * Ac[j]); + for (size_t j = 0; j < SIZE; j++) + A[j] = REDUCE(A[j] + minusm * B[j]); + + /* + * And before that, we conditionally swapped A,B. + */ + uint16_t swap_mask = -swaps[i]; + for (size_t j = 0; j < SIZE; j++) { + uint16_t diff; + diff = (A[j] ^ B[j]) & swap_mask; + A[j] ^= diff; + B[j] ^= diff; + diff = (Ac[j] ^ Bc[j]) & swap_mask; + Ac[j] ^= diff; + Bc[j] ^= diff; + } + } + + /* Done! Our coefficient Ac is the inverse, if one exists. */ + memcpy(out, Ac, p * sizeof(*out)); + + smemclr(A, SIZE * sizeof(*A)); + sfree(A); + smemclr(B, SIZE * sizeof(*B)); + sfree(B); + smemclr(Ac, SIZE * sizeof(*A)); + sfree(Ac); + smemclr(Bc, SIZE * sizeof(*B)); + sfree(Bc); + smemclr(multipliers, STEPS * sizeof(*multipliers)); + sfree(multipliers); + smemclr(swaps, STEPS * sizeof(*swaps)); + sfree(swaps); + + return success; +} + +/* + * Given an array of values mod q, convert each one to its + * minimum-absolute-value representative, and then reduce mod 3. + * + * Output values are 0, 1 and 0xFFFF, representing -1. + * + * (Normally our arrays of uint16_t are in 'minimal non-negative + * residue' form, so the output of this function is unusual. But it's + * useful to have it in this form so that it can be reused by + * ntru_round3. You can put it back to the usual representation using + * ntru_normalise, below.) + */ +void ntru_mod3(uint16_t *out, const uint16_t *in, unsigned p, unsigned q) +{ + uint64_t qrecip = reciprocal_for_reduction(q); + uint64_t recip3 = reciprocal_for_reduction(3); + + unsigned bias = q/2; + uint16_t adjust = 3 - reduce(bias-1, 3, recip3); + + for (unsigned i = 0; i < p; i++) { + uint16_t val = reduce(in[i] + bias, q, qrecip); + uint16_t residue = reduce(val + adjust, 3, recip3); + out[i] = residue - 1; + } +} + +/* + * Given an array of values mod q, round each one to the nearest + * multiple of 3 to its minimum-absolute-value representative. + * + * Output values are signed integers coerced to uint16_t, so again, + * use ntru_normalise afterwards to put them back to normal. + */ +void ntru_round3(uint16_t *out, const uint16_t *in, unsigned p, unsigned q) +{ + SETUP; + unsigned bias = q/2; + ntru_mod3(out, in, p, q); + for (unsigned i = 0; i < p; i++) + out[i] = REDUCE(in[i] + bias) - bias - out[i]; +} + +/* + * Given an array of signed integers coerced to uint16_t in the range + * [-q/2,+q/2], normalise them back to mod q values. + */ +static void ntru_normalise(uint16_t *out, const uint16_t *in, + unsigned p, unsigned q) +{ + for (unsigned i = 0; i < p; i++) + out[i] = in[i] + q * (in[i] >> 15); +} + +/* + * Given an array of values mod q, add a constant to each one. + */ +void ntru_bias(uint16_t *out, const uint16_t *in, unsigned bias, + unsigned p, unsigned q) +{ + SETUP; + for (unsigned i = 0; i < p; i++) + out[i] = REDUCE(in[i] + bias); +} + +/* + * Given an array of values mod q, multiply each one by a constant. + */ +void ntru_scale(uint16_t *out, const uint16_t *in, uint16_t scale, + unsigned p, unsigned q) +{ + SETUP; + for (unsigned i = 0; i < p; i++) + out[i] = REDUCE(in[i] * scale); +} + +/* + * Given an array of values mod 3, convert them to values mod q in a + * way that maps -1,0,+1 to -1,0,+1. + */ +void ntru_expand(uint16_t *out, const uint16_t *in, unsigned p, unsigned q) +{ + for (size_t i = 0; i < p; i++) { + uint16_t v = in[i]; + /* Map 2 to q-1, and leave 0 and 1 unchanged */ + v += (v >> 1) * (q-3); + out[i] = v; + } +} + +/* ---------------------------------------------------------------------- + * Implement the binary encoding from ntruprime-20201007.pdf, which is + * used to encode public keys and ciphertexts (though not plaintexts, + * which are done in a much simpler way). + * + * The general idea is that your encoder takes as input a list of + * small non-negative integers (r_i), and a sequence of limits (m_i) + * such that 0 <= r_i < m_i, and emits a sequence of bytes that encode + * all of these as tightly as reasonably possible. + * + * That's more general than is really needed, because in both the + * actual uses of this encoding, the input m_i are all the same! But + * the array of (r_i,m_i) pairs evolves during encoding, so they don't + * _stay_ all the same, so you still have to have all the generality. + * + * The encoding process makes a number of passes along the list of + * inputs. In each step, pairs of adjacent numbers are combined into + * one larger one by turning (r_i,m_i) and (r_{i+1},m_{i+1}) into the + * pair (r_i + m_i r_{i+1}, m_i m_{i+1}), i.e. so that the original + * numbers could be recovered by taking the quotient and remaiinder of + * the new r value by m_i. Then, if the new m_i is at least 2^14, we + * emit the low 8 bits of r_i to the output stream and reduce r_i and + * its limit correspondingly. So at the end of the pass, we've got + * half as many numbers still to encode, they're all still not too + * big, and we've emitted some amount of data into the output. Then do + * another pass, keep going until there's only one number left, and + * emit it little-endian. + * + * That's all very well, but how do you decode it again? DJB exhibits + * a pair of recursive functions that are supposed to be mutually + * inverse, but I didn't have any confidence that I'd be able to debug + * them sensibly if they turned out not to be (or rather, if I + * implemented one of them wrong). So I came up with my own strategy + * instead. + * + * In my strategy, we start by processing just the (m_i) into an + * 'encoding schedule' consisting of a sequence of simple + * instructions. The instructions operate on a FIFO queue of numbers, + * initialised to the original (r_i). The three instruction types are: + * + * - 'COMBINE': consume two numbers a,b from the head of the queue, + * combine them by calculating a + m*b for some specified m, and + * push the result on the tail of the queue. + * + * - 'BYTE': divide the tail element of the queue by 2^8 and emit the + * low bits into the output stream. + * + * - 'COPY': pop a number from the head of the queue and push it + * straight back on the tail. (Used for handling the leftover + * element at the end of a pass if the input to the pass was a list + * of odd length.) + * + * So we effectively implement DJB's encoding process in simulation, + * and instead of actually processing a set of (r_i), we 'compile' the + * process into a sequence of instructions that can be handed just the + * (r_i) later and encode them in the right way. At the end of the + * instructions, the queue is expected to have been reduced to length + * 1 and contain the single integer 0. + * + * The nice thing about this system is that each of those three + * instructions is easy to reverse. So you can also use the same + * instructions for decoding: start with a queue containing 0, and + * process the instructions in reverse order and reverse sense. So + * BYTE means to _consume_ a byte from the encoded data (starting from + * the rightmost end) and use it to make a queue element bigger; and + * COMBINE run in reverse pops a single element from one end of the + * queue, divides it by m, and pushes the quotient and remainder on + * the other end. + * + * (So it's easy to debug, because the queue passes through the exact + * same sequence of states during decoding that it did during + * encoding, just in reverse order.) + * + * Also, the encoding schedule comes with information about the + * expected size of the encoded data, because you can find that out + * easily by just counting the BYTE commands. + */ + +enum { + /* + * Command values appearing in the 'ops' array. ENC_COPY and + * ENC_BYTE are single values; values of the form + * (ENC_COMBINE_BASE + m) represent a COMBINE command with + * parameter m. + */ + ENC_COPY, ENC_BYTE, ENC_COMBINE_BASE +}; +struct NTRUEncodeSchedule { + /* + * Object representing a compiled set of encoding instructions. + * + * 'nvals' is the number of r_i we expect to encode. 'nops' is the + * number of encoding commands in the 'ops' list; 'opsize' is the + * physical size of the array, used during construction. + * + * 'endpos' is used to avoid a last-minute faff during decoding. + * We implement our FIFO of integers as a ring buffer of size + * 'nvals'. Encoding cycles round it some number of times, and the + * final 0 element ends up at some random location in the array. + * If we know _where_ the 0 ends up during encoding, we can put + * the initial 0 there at the start of decoding, and then when we + * finish reversing all the instructions, we'll end up with the + * output numbers already arranged at their correct positions, so + * that there's no need to rotate the array at the last minute. + */ + size_t nvals, endpos, nops, opsize; + uint32_t *ops; +}; +static inline void sched_append(NTRUEncodeSchedule *sched, uint16_t op) +{ + /* Helper function to append an operation to the schedule, and + * update endpos. */ + sgrowarray(sched->ops, sched->opsize, sched->nops); + sched->ops[sched->nops++] = op; + if (op != ENC_BYTE) + sched->endpos = (sched->endpos + 1) % sched->nvals; +} + +/* + * Take in the list of limit values (m_i) and compute the encoding + * schedule. + */ +NTRUEncodeSchedule *ntru_encode_schedule(const uint16_t *ms_in, size_t n) +{ + NTRUEncodeSchedule *sched = snew(NTRUEncodeSchedule); + sched->nvals = n; + sched->endpos = n-1; + sched->nops = sched->opsize = 0; + sched->ops = NULL; + + assert(n != 0); + + /* + * 'ms' is the list of (m_i) on input to the current pass. + * 'ms_new' is the list output from the current pass. After each + * pass we swap the arrays round. + */ + uint32_t *ms = snewn(n, uint32_t); + uint32_t *msnew = snewn(n, uint32_t); + for (size_t i = 0; i < n; i++) + ms[i] = ms_in[i]; + + while (n > 1) { + size_t nnew = 0; + for (size_t i = 0; i < n; i += 2) { + if (i+1 == n) { + /* + * Odd element at the end of the input list: just copy + * it unchanged to the output. + */ + sched_append(sched, ENC_COPY); + msnew[nnew++] = ms[i]; + break; + } + + /* + * Normal case: consume two elements from the input list + * and combine them. + */ + uint32_t m1 = ms[i], m2 = ms[i+1], m = m1*m2; + sched_append(sched, ENC_COMBINE_BASE + m1); + + /* + * And then, as long as the combined limit is big enough, + * emit an output byte from the bottom of it. + */ + while (m >= (1<<14)) { + sched_append(sched, ENC_BYTE); + m = (m + 0xFF) >> 8; + } + + /* + * Whatever is left after that, we emit into the output + * list and append to the fifo. + */ + msnew[nnew++] = m; + } + + /* + * End of pass. The output list of (m_i) now becomes the input + * list. + */ + uint32_t *tmp = ms; + ms = msnew; + n = nnew; + msnew = tmp; + } + + /* + * When that loop terminates, it's because there's exactly one + * number left to encode. (Or, technically, _at most_ one - but we + * don't support encoding a completely empty list in this + * implementation, because what would be the point?) That number + * is just emitted little-endian until its limit is 1 (meaning its + * only possible actual value is 0). + */ + assert(n == 1); + uint32_t m = ms[0]; + while (m > 1) { + sched_append(sched, ENC_BYTE); + m = (m + 0xFF) >> 8; + } + + sfree(ms); + sfree(msnew); + + return sched; +} + +void ntru_encode_schedule_free(NTRUEncodeSchedule *sched) +{ + sfree(sched->ops); + sfree(sched); +} + +/* + * Calculate the output length of the encoded data in bytes. + */ +size_t ntru_encode_schedule_length(NTRUEncodeSchedule *sched) +{ + size_t len = 0; + for (size_t i = 0; i < sched->nops; i++) + if (sched->ops[i] == ENC_BYTE) + len++; + return len; +} + +/* + * Retrieve the number of items encoded. (Used by testcrypt.) + */ +size_t ntru_encode_schedule_nvals(NTRUEncodeSchedule *sched) +{ + return sched->nvals; +} + +/* + * Actually encode a sequence of (r_i), emitting the output bytes to + * an arbitrary BinarySink. + */ +void ntru_encode(NTRUEncodeSchedule *sched, const uint16_t *rs_in, + BinarySink *bs) +{ + size_t n = sched->nvals; + uint32_t *rs = snewn(n, uint32_t); + for (size_t i = 0; i < n; i++) + rs[i] = rs_in[i]; + + /* + * The head and tail pointers of the queue are both 'full'. That + * is, rs[head] is the first element actually in the queue, and + * rs[tail] is the last element. + * + * So you append to the queue by first advancing 'tail' and then + * writing to rs[tail], whereas you consume from the queue by + * first reading rs[head] and _then_ advancing 'head'. + * + * The more normal thing would be to make 'tail' point to the + * first empty slot instead of the last full one. But then you'd + * have to faff about with modular arithmetic to find the last + * full slot for the BYTE command, so in this case, it's easier to + * do it the less usual way. + */ + size_t head = 0, tail = n-1; + + for (size_t i = 0; i < sched->nops; i++) { + uint16_t op = sched->ops[i]; + switch (op) { + case ENC_BYTE: + put_byte(bs, rs[tail] & 0xFF); + rs[tail] >>= 8; + break; + case ENC_COPY: { + uint32_t r = rs[head]; + head = (head + 1) % n; + tail = (tail + 1) % n; + rs[tail] = r; + break; + } + default: { + uint32_t r1 = rs[head]; + head = (head + 1) % n; + uint32_t r2 = rs[head]; + head = (head + 1) % n; + tail = (tail + 1) % n; + rs[tail] = r1 + (op - ENC_COMBINE_BASE) * r2; + break; + } + } + } + + /* + * Expect that we've ended up with a single zero in the queue, at + * exactly the position that the setup-time analysis predicted it. + */ + assert(head == sched->endpos); + assert(tail == sched->endpos); + assert(rs[head] == 0); + + smemclr(rs, n * sizeof(*rs)); + sfree(rs); +} + +/* + * Decode a ptrlen of binary data into a sequence of (r_i). The data + * is expected to be of exactly the right length (on pain of assertion + * failure). + */ +void ntru_decode(NTRUEncodeSchedule *sched, uint16_t *rs_out, ptrlen data) +{ + size_t n = sched->nvals; + const uint8_t *base = (const uint8_t *)data.ptr; + const uint8_t *pos = base + data.len; + + /* + * Initialise the queue to a single zero, at the 'endpos' position + * that will mean the final output is correctly aligned. + * + * 'head' and 'tail' have the same meanings as in encoding. So + * 'tail' is the location that BYTE modifies and COPY and COMBINE + * consume from, and 'head' is the location that COPY and COMBINE + * push on to. As in encoding, they both point at the extremal + * full slots in the array. + */ + uint32_t *rs = snewn(n, uint32_t); + size_t head = sched->endpos, tail = head; + rs[tail] = 0; + + for (size_t i = sched->nops; i-- > 0 ;) { + uint16_t op = sched->ops[i]; + switch (op) { + case ENC_BYTE: { + assert(pos > base); + uint8_t byte = *--pos; + rs[tail] = (rs[tail] << 8) | byte; + break; + } + case ENC_COPY: { + uint32_t r = rs[tail]; + tail = (tail + n - 1) % n; + head = (head + n - 1) % n; + rs[head] = r; + break; + } + default: { + uint32_t r = rs[tail]; + tail = (tail + n - 1) % n; + + uint32_t m = op - ENC_COMBINE_BASE; + uint64_t mrecip = reciprocal_for_reduction(m); + + uint32_t r1, r2; + r1 = reduce_with_quot(r, &r2, m, mrecip); + + head = (head + n - 1) % n; + rs[head] = r2; + head = (head + n - 1) % n; + rs[head] = r1; + break; + } + } + } + + assert(pos == base); + assert(head == 0); + assert(tail == n-1); + + for (size_t i = 0; i < n; i++) + rs_out[i] = rs[i]; + smemclr(rs, n * sizeof(*rs)); + sfree(rs); +} + +/* ---------------------------------------------------------------------- + * The actual public-key cryptosystem. + */ + +struct NTRUKeyPair { + unsigned p, q, w; + uint16_t *h; /* public key */ + uint16_t *f3, *ginv; /* private key */ + uint16_t *rho; /* for implicit rejection */ +}; + +/* Helper function to free an array of uint16_t containing a ring + * element, clearing it on the way since some of them are sensitive. */ +static void ring_free(uint16_t *val, unsigned p) +{ + smemclr(val, p*sizeof(*val)); + sfree(val); +} + +void ntru_keypair_free(NTRUKeyPair *keypair) +{ + ring_free(keypair->h, keypair->p); + ring_free(keypair->f3, keypair->p); + ring_free(keypair->ginv, keypair->p); + ring_free(keypair->rho, keypair->p); + sfree(keypair); +} + +/* Trivial accessors used by test programs. */ +unsigned ntru_keypair_p(NTRUKeyPair *keypair) { return keypair->p; } +const uint16_t *ntru_pubkey(NTRUKeyPair *keypair) { return keypair->h; } + +/* + * Generate a value of the class DJB describes as 'Short': it consists + * of p terms that are all either 0 or +1 or -1, and exactly w of them + * are not zero. + * + * Values of this kind are used for several purposes: part of the + * private key, a plaintext, and the 'rho' fake-plaintext value used + * for deliberately returning a duff but non-revealing session hash if + * things go wrong. + */ +void ntru_gen_short(uint16_t *v, unsigned p, unsigned w) +{ + /* + * Get enough random data to generate a polynomial all of whose p + * terms are in {0,+1,-1}, and exactly w of them are nonzero. + * + * We're going to need w * random bits to choose the nonzero + * values, and then (doing it the simplest way) log2(p!) bits to + * shuffle them, plus say 128 bits to ensure any fluctuations in + * uniformity are negligible. + * + * log2(p!) is a pain to calculate, so we'll bound it above by + * p*log2(p), which we bound in turn by p*16. + */ + size_t randbitpos = 16 * p + w + 128; + mp_int *randdata = mp_resize(mp_random_bits(randbitpos), randbitpos + 32); + + /* + * Initial value before shuffling: w randomly chosen values in + * {1,q-1}, plus zeroes to pad to length p. + */ + for (size_t i = 0; i < w; i++) + v[i] = 1 + mp_get_bit(randdata, --randbitpos); + for (size_t i = w; i < p; i++) + v[i] = 0; + + /* + * Hereafter we're going to extract random bits by multiplication, + * treating randdata as a large fixed-point number. + */ + mp_reduce_mod_2to(randdata, randbitpos); + + /* + * Shuffle. + */ + mp_int *x = mp_new(64); + for (size_t i = p-1; i > 0; i--) { + /* + * Decide which element to swap with v[i], potentially + * including i itself. + */ + mp_mul_integer_into(randdata, randdata, i+1); + mp_rshift_fixed_into(x, randdata, randbitpos); + mp_reduce_mod_2to(randdata, randbitpos); + size_t j = mp_get_integer(x); + + /* + * Swap it, which involves a constant-time selection loop over + * the whole eligible part of the array. This makes the + * shuffling quadratic-time overall. I'd be interested in a + * nicer algorithm, but this will do for now. + */ + for (size_t k = 0; k <= i; k++) { + uint16_t mask = -iszero(k ^ j); + uint16_t diff = mask & (v[k] ^ v[i]); + v[k] ^= diff; + v[i] ^= diff; + } + } + + mp_free(x); + mp_free(randdata); +} + +/* + * Make a single attempt at generating a key pair. This involves + * inventing random elements of both our quotient rings and hoping + * they're both invertible. + * + * They may not be, if you're unlucky. The element of Z_q/ + * will _almost_ certainly be invertible, because that is a field, so + * invertibility can only fail if you were so unlucky as to choose the + * all-0s element. But the element of Z_3/ may fail to be + * invertible because it has a common factor with x^p-x-1 (which, over + * Z_3, is not irreducible). + * + * So we can't guarantee to generate a key pair in constant time, + * because there's no predicting how many retries we'll need. However, + * this isn't a failure of side-channel safety, because we completely + * discard all the random numbers and state from each failed attempt. + * So if there were a side-channel leakage from a failure, the only + * thing it would give away would be a bunch of random numbers that + * turned out not to be used anyway. + * + * But a _successful_ call to this function should execute in a + * secret-independent manner, and this 'make a single attempt' + * function is exposed in the API so that 'testsc' can check that. + */ +NTRUKeyPair *ntru_keygen_attempt(unsigned p, unsigned q, unsigned w) +{ + /* + * First invent g, which is the one more likely to fail to invert. + * This is simply a uniformly random polynomial with p terms over + * Z_3. So we need p*log2(3) random bits for it, plus 128 for + * uniformity. It's easiest to bound log2(3) above by 2. + */ + size_t randbitpos = 2 * p + 128; + mp_int *randdata = mp_resize(mp_random_bits(randbitpos), randbitpos + 32); + + /* + * Select p random values from {0,1,2}. + */ + uint16_t *g = snewn(p, uint16_t); + mp_int *x = mp_new(64); + for (size_t i = 0; i < p; i++) { + mp_mul_integer_into(randdata, randdata, 3); + mp_rshift_fixed_into(x, randdata, randbitpos); + mp_reduce_mod_2to(randdata, randbitpos); + g[i] = mp_get_integer(x); + } + mp_free(x); + mp_free(randdata); + + /* + * Try to invert g over Z_3, and fail if it isn't invertible. + */ + uint16_t *ginv = snewn(p, uint16_t); + if (!ntru_ring_invert(ginv, g, p, 3)) { + ring_free(g, p); + ring_free(ginv, p); + return NULL; + } + + /* + * Fine; we have g. Now make up an f, and convert it to a + * polynomial over q. + */ + uint16_t *f = snewn(p, uint16_t); + ntru_gen_short(f, p, w); + ntru_expand(f, f, p, q); + + /* + * Multiply f by 3. + */ + uint16_t *f3 = snewn(p, uint16_t); + ntru_scale(f3, f, 3, p, q); + + /* + * Try to invert 3*f over Z_q. This should be _almost_ guaranteed + * to succeed, since Z_q/ is a field, so the only + * non-invertible value is 0. Even so, there _is_ one, so check + * the return value! + */ + uint16_t *f3inv = snewn(p, uint16_t); + if (!ntru_ring_invert(f3inv, f3, p, q)) { + ring_free(f, p); + ring_free(f3, p); + ring_free(f3inv, p); + ring_free(g, p); + ring_free(ginv, p); + return NULL; + } + + /* + * Make the public key, by converting g to a polynomial over q and + * then multiplying by f3inv. + */ + uint16_t *g_q = snewn(p, uint16_t); + ntru_expand(g_q, g, p, q); + uint16_t *h = snewn(p, uint16_t); + ntru_ring_multiply(h, g_q, f3inv, p, q); + + /* + * Make up rho, used to substitute for the plaintext in the + * session hash in case of confirmation failure. + */ + uint16_t *rho = snewn(p, uint16_t); + ntru_gen_short(rho, p, w); + + /* + * And we're done! Free everything except the pieces we're + * returning. + */ + NTRUKeyPair *keypair = snew(NTRUKeyPair); + keypair->p = p; + keypair->q = q; + keypair->w = w; + keypair->h = h; + keypair->f3 = f3; + keypair->ginv = ginv; + keypair->rho = rho; + ring_free(f, p); + ring_free(f3inv, p); + ring_free(g, p); + ring_free(g_q, p); + return keypair; +} + +/* + * The top-level key generation function for real use (as opposed to + * testsc): keep trying to make a key until you succeed. + */ +NTRUKeyPair *ntru_keygen(unsigned p, unsigned q, unsigned w) +{ + while (1) { + NTRUKeyPair *keypair = ntru_keygen_attempt(p, q, w); + if (keypair) + return keypair; + } +} + +/* + * Public-key encryption. + */ +void ntru_encrypt(uint16_t *ciphertext, const uint16_t *plaintext, + uint16_t *pubkey, unsigned p, unsigned q) +{ + uint16_t *r_q = snewn(p, uint16_t); + ntru_expand(r_q, plaintext, p, q); + + uint16_t *unrounded = snewn(p, uint16_t); + ntru_ring_multiply(unrounded, r_q, pubkey, p, q); + + ntru_round3(ciphertext, unrounded, p, q); + ntru_normalise(ciphertext, ciphertext, p, q); + + ring_free(r_q, p); + ring_free(unrounded, p); +} + +/* + * Public-key decryption. + */ +void ntru_decrypt(uint16_t *plaintext, const uint16_t *ciphertext, + NTRUKeyPair *keypair) +{ + unsigned p = keypair->p, q = keypair->q, w = keypair->w; + uint16_t *tmp = snewn(p, uint16_t); + + ntru_ring_multiply(tmp, ciphertext, keypair->f3, p, q); + + ntru_mod3(tmp, tmp, p, q); + ntru_normalise(tmp, tmp, p, 3); + + ntru_ring_multiply(plaintext, tmp, keypair->ginv, p, 3); + ring_free(tmp, p); + + /* + * With luck, this should have recovered exactly the original + * plaintext. But, as per the spec, we check whether it has + * exactly w nonzero coefficients, and if not, then something has + * gone wrong - and in that situation we time-safely substitute a + * different output. + * + * (I don't know exactly why we do this, but I assume it's because + * otherwise the mis-decoded output could be made to disgorge a + * secret about the private key in some way.) + */ + + unsigned weight = p; + for (size_t i = 0; i < p; i++) + weight -= iszero(plaintext[i]); + unsigned ok = iszero(weight ^ w); + + /* + * The default failure return value consists of w 1s followed by + * 0s. + */ + unsigned mask = ok - 1; + for (size_t i = 0; i < w; i++) { + uint16_t diff = (1 ^ plaintext[i]) & mask; + plaintext[i] ^= diff; + } + for (size_t i = w; i < p; i++) { + uint16_t diff = (0 ^ plaintext[i]) & mask; + plaintext[i] ^= diff; + } +} + +/* ---------------------------------------------------------------------- + * Encode and decode public keys, ciphertexts and plaintexts. + * + * Public keys and ciphertexts use the complicated binary encoding + * system implemented above. In both cases, the inputs are regarded as + * symmetric about zero, and are first biased to map their most + * negative permitted value to 0, so that they become non-negative and + * hence suitable as inputs to the encoding system. In the case of a + * ciphertext, where the input coefficients have also been coerced to + * be multiples of 3, we divide by 3 as well, saving space by reducing + * the upper bounds (m_i) on all the encoded numbers. + */ + +/* + * Compute the encoding schedule for a public key. + */ +static NTRUEncodeSchedule *ntru_encode_pubkey_schedule(unsigned p, unsigned q) +{ + uint16_t *ms = snewn(p, uint16_t); + for (size_t i = 0; i < p; i++) + ms[i] = q; + NTRUEncodeSchedule *sched = ntru_encode_schedule(ms, p); + sfree(ms); + return sched; +} + +/* + * Encode a public key. + */ +void ntru_encode_pubkey(const uint16_t *pubkey, unsigned p, unsigned q, + BinarySink *bs) +{ + /* Compute the biased version for encoding */ + uint16_t *biased_pubkey = snewn(p, uint16_t); + ntru_bias(biased_pubkey, pubkey, q / 2, p, q); + + /* Encode it */ + NTRUEncodeSchedule *sched = ntru_encode_pubkey_schedule(p, q); + ntru_encode(sched, biased_pubkey, bs); + ntru_encode_schedule_free(sched); + + ring_free(biased_pubkey, p); +} + +/* + * Decode a public key and write it into 'pubkey'. We also return a + * ptrlen pointing at the chunk of data we removed from the + * BinarySource. + */ +ptrlen ntru_decode_pubkey(uint16_t *pubkey, unsigned p, unsigned q, + BinarySource *src) +{ + NTRUEncodeSchedule *sched = ntru_encode_pubkey_schedule(p, q); + + /* Retrieve the right number of bytes from the source */ + size_t len = ntru_encode_schedule_length(sched); + ptrlen encoded = get_data(src, len); + if (get_err(src)) { + /* If there wasn't enough data, give up and return all-zeroes + * purely for determinism. But that value should never be + * used, because the caller will also check get_err(src). */ + memset(pubkey, 0, p*sizeof(*pubkey)); + } else { + /* Do the decoding */ + ntru_decode(sched, pubkey, encoded); + ntru_encode_schedule_free(sched); + + /* Unbias the coefficients */ + ntru_bias(pubkey, pubkey, q-q/2, p, q); + } + + return encoded; +} + +/* + * For ciphertext biasing: work out the largest absolute value a + * ciphertext element can take, which is given by taking q/2 and + * rounding it to the nearest multiple of 3. + */ +static inline unsigned ciphertext_bias(unsigned q) +{ + return (q/2+1) / 3; +} + +/* + * The number of possible values of a ciphertext coefficient (for use + * as the m_i in encoding) ranges from +ciphertext_bias(q) to + * -ciphertext_bias(q) inclusive. + */ +static inline unsigned ciphertext_m(unsigned q) +{ + return 1 + 2 * ciphertext_bias(q); +} + +/* + * Compute the encoding schedule for a ciphertext. + */ +static NTRUEncodeSchedule *ntru_encode_ciphertext_schedule( + unsigned p, unsigned q) +{ + unsigned m = ciphertext_m(q); + uint16_t *ms = snewn(p, uint16_t); + for (size_t i = 0; i < p; i++) + ms[i] = m; + NTRUEncodeSchedule *sched = ntru_encode_schedule(ms, p); + sfree(ms); + return sched; +} + +/* + * Encode a ciphertext. + */ +void ntru_encode_ciphertext(const uint16_t *ciphertext, unsigned p, unsigned q, + BinarySink *bs) +{ + SETUP; + + /* + * Bias the ciphertext, and scale down by 1/3, which we do by + * modular multiplication by the inverse of 3 mod q. (That only + * works if we know the inputs are all _exact_ multiples of 3 + * - but we do!) + */ + uint16_t *biased_ciphertext = snewn(p, uint16_t); + ntru_bias(biased_ciphertext, ciphertext, 3 * ciphertext_bias(q), p, q); + ntru_scale(biased_ciphertext, biased_ciphertext, INVERT(3), p, q); + + /* Encode. */ + NTRUEncodeSchedule *sched = ntru_encode_ciphertext_schedule(p, q); + ntru_encode(sched, biased_ciphertext, bs); + ntru_encode_schedule_free(sched); + + ring_free(biased_ciphertext, p); +} + +ptrlen ntru_decode_ciphertext(uint16_t *ct, NTRUKeyPair *keypair, + BinarySource *src) +{ + unsigned p = keypair->p, q = keypair->q; + + NTRUEncodeSchedule *sched = ntru_encode_ciphertext_schedule(p, q); + + /* Retrieve the right number of bytes from the source */ + size_t len = ntru_encode_schedule_length(sched); + ptrlen encoded = get_data(src, len); + if (get_err(src)) { + /* As above, return deterministic nonsense on failure */ + memset(ct, 0, p*sizeof(*ct)); + } else { + /* Do the decoding */ + ntru_decode(sched, ct, encoded); + ntru_encode_schedule_free(sched); + + /* Undo the scaling and bias */ + ntru_scale(ct, ct, 3, p, q); + ntru_bias(ct, ct, q - 3 * ciphertext_bias(q), p, q); + } + + return encoded; /* also useful to the caller, optionally */ +} + +/* + * Encode a plaintext. + * + * This is a much simpler encoding than the NTRUEncodeSchedule system: + * since elements of a plaintext are mod 3, we just encode each one in + * 2 bits, applying the usual bias so that {-1,0,+1} map to {0,1,2} + * respectively. + * + * There's no corresponding decode function, because plaintexts are + * never transmitted on the wire (the whole point is that they're too + * secret!). Plaintexts are only encoded in order to put them into + * hash preimages. + */ +void ntru_encode_plaintext(const uint16_t *plaintext, unsigned p, + BinarySink *bs) +{ + unsigned byte = 0, bitpos = 0; + for (size_t i = 0; i < p; i++) { + unsigned encoding = (plaintext[i] + 1) * iszero(plaintext[i] >> 1); + byte |= encoding << bitpos; + bitpos += 2; + if (bitpos == 8 || i+1 == p) { + put_byte(bs, byte); + byte = 0; + bitpos = 0; + } + } +} + +/* ---------------------------------------------------------------------- + * Compute the hashes required by the key exchange layer of NTRU Prime. + * + * There are two of these. The 'confirmation hash' is sent by the + * server along with the ciphertext, and the client can recalculate it + * to check whether the ciphertext was decrypted correctly. Then, the + * 'session hash' is the actual output of key exchange, and if the + * confirmation hash doesn't match, it gets deliberately corrupted. + */ + +/* + * Make the confirmation hash, whose inputs are the plaintext and the + * public key. + * + * This is defined as H(2 || H(3 || r) || H(4 || K)), where r is the + * plaintext and K is the public key (as encoded by the above + * functions), and the constants 2,3,4 are single bytes. The choice of + * hash function (H itself) is SHA-512 truncated to 256 bits. + * + * (To be clear: that is _not_ the thing that FIPS 180-4 6.7 defines + * as "SHA-512/256", which varies the initialisation vector of the + * SHA-512 algorithm as well as truncating the output. _This_ + * algorithm uses the standard SHA-512 IV, and _just_ truncates the + * output, in the manner suggested by FIPS 180-4 section 7.) + * + * 'out' should therefore expect to receive 32 bytes of data. + */ +void ntru_confirmation_hash(uint8_t *out, const uint16_t *plaintext, + const uint16_t *pubkey, unsigned p, unsigned q) +{ + /* The outer hash object */ + ssh_hash *hconfirm = ssh_hash_new(&ssh_sha512); + put_byte(hconfirm, 2); /* initial byte 2 */ + + uint8_t hashdata[64]; + + /* Compute H(3 || r) and add it to the main hash */ + ssh_hash *h3r = ssh_hash_new(&ssh_sha512); + put_byte(h3r, 3); + ntru_encode_plaintext(plaintext, p, BinarySink_UPCAST(h3r)); + ssh_hash_final(h3r, hashdata); + put_data(hconfirm, hashdata, 32); + + /* Compute H(4 || K) and add it to the main hash */ + ssh_hash *h4K = ssh_hash_new(&ssh_sha512); + put_byte(h4K, 4); + ntru_encode_pubkey(pubkey, p, q, BinarySink_UPCAST(h4K)); + ssh_hash_final(h4K, hashdata); + put_data(hconfirm, hashdata, 32); + + /* Compute the full output of the main SHA-512 hash */ + ssh_hash_final(hconfirm, hashdata); + + /* And copy the first 32 bytes into the caller's output array */ + memcpy(out, hashdata, 32); + smemclr(hashdata, sizeof(hashdata)); +} + +/* + * Make the session hash, whose inputs are the plaintext, the + * ciphertext, and the confirmation hash (hence, transitively, a + * dependence on the public key as well). + * + * As computed by the server, and by the client if the confirmation + * hash matched, this is defined as + * + * H(1 || H(3 || r) || ciphertext || confirmation hash) + * + * but if the confirmation hash _didn't_ match, then the plaintext r + * is replaced with the dummy plaintext-shaped value 'rho' we invented + * during key generation (presumably to avoid leaking any information + * about our secrets), and the initial byte 1 is replaced with 0 (to + * ensure that the resulting hash preimage can't match any legitimate + * preimage). So in that case, you instead get + * + * H(0 || H(3 || rho) || ciphertext || confirmation hash) + * + * The inputs to this function include 'ok', which is the value to use + * as the initial byte (1 on success, 0 on failure), and 'plaintext' + * which should already have been substituted with rho in case of + * failure. + * + * The ciphertext is provided in already-encoded form. + */ +void ntru_session_hash(uint8_t *out, unsigned ok, const uint16_t *plaintext, + unsigned p, ptrlen ciphertext, ptrlen confirmation_hash) +{ + /* The outer hash object */ + ssh_hash *hsession = ssh_hash_new(&ssh_sha512); + put_byte(hsession, ok); /* initial byte 1 or 0 */ + + uint8_t hashdata[64]; + + /* Compute H(3 || r), or maybe H(3 || rho), and add it to the main hash */ + ssh_hash *h3r = ssh_hash_new(&ssh_sha512); + put_byte(h3r, 3); + ntru_encode_plaintext(plaintext, p, BinarySink_UPCAST(h3r)); + ssh_hash_final(h3r, hashdata); + put_data(hsession, hashdata, 32); + + /* Put the ciphertext and confirmation hash in */ + put_datapl(hsession, ciphertext); + put_datapl(hsession, confirmation_hash); + + /* Compute the full output of the main SHA-512 hash */ + ssh_hash_final(hsession, hashdata); + + /* And copy the first 32 bytes into the caller's output array */ + memcpy(out, hashdata, 32); + smemclr(hashdata, sizeof(hashdata)); +} + +/* ---------------------------------------------------------------------- + * Top-level key exchange and SSH integration. + * + * Although this system borrows the ECDH packet structure, it's unlike + * true ECDH in that it is completely asymmetric between client and + * server. So we have two separate vtables of methods for the two + * sides of the system, and a third vtable containing only the class + * methods, in particular a constructor which chooses which one to + * instantiate. + */ + +/* + * The parameters p,q,w for the system. There are other choices of + * these, but OpenSSH only specifies this set. (If that ever changes, + * we'll need to turn these into elements of the state structures.) + */ +#define p_LIVE 761 +#define q_LIVE 4591 +#define w_LIVE 286 + +static char *ssh_ntru_description(const ssh_kex *kex) +{ + return dupprintf("NTRU Prime / Curve25519 hybrid key exchange"); +} + +/* + * State structure for the client, which takes the role of inventing a + * key pair and decrypting a secret plaintext sent to it by the server. + */ +typedef struct ntru_client_key { + NTRUKeyPair *keypair; + ecdh_key *curve25519; + + ecdh_key ek; +} ntru_client_key; + +static void ssh_ntru_client_free(ecdh_key *dh); +static void ssh_ntru_client_getpublic(ecdh_key *dh, BinarySink *bs); +static bool ssh_ntru_client_getkey(ecdh_key *dh, ptrlen remoteKey, + BinarySink *bs); + +static const ecdh_keyalg ssh_ntru_client_vt = { + /* This vtable has no 'new' method, because it's constructed via + * the selector vt below */ + .free = ssh_ntru_client_free, + .getpublic = ssh_ntru_client_getpublic, + .getkey = ssh_ntru_client_getkey, + .description = ssh_ntru_description, +}; + +static ecdh_key *ssh_ntru_client_new(void) +{ + ntru_client_key *nk = snew(ntru_client_key); + nk->ek.vt = &ssh_ntru_client_vt; + + nk->keypair = ntru_keygen(p_LIVE, q_LIVE, w_LIVE); + nk->curve25519 = ecdh_key_new(&ssh_ec_kex_curve25519, false); + + return &nk->ek; +} + +static void ssh_ntru_client_free(ecdh_key *dh) +{ + ntru_client_key *nk = container_of(dh, ntru_client_key, ek); + ntru_keypair_free(nk->keypair); + ecdh_key_free(nk->curve25519); + sfree(nk); +} + +static void ssh_ntru_client_getpublic(ecdh_key *dh, BinarySink *bs) +{ + ntru_client_key *nk = container_of(dh, ntru_client_key, ek); + + /* + * The client's public information is a single SSH string + * containing the NTRU public key and the Curve25519 public point + * concatenated. So write both of those into the output + * BinarySink. + */ + ntru_encode_pubkey(nk->keypair->h, p_LIVE, q_LIVE, bs); + ecdh_key_getpublic(nk->curve25519, bs); +} + +static bool ssh_ntru_client_getkey(ecdh_key *dh, ptrlen remoteKey, + BinarySink *bs) +{ + ntru_client_key *nk = container_of(dh, ntru_client_key, ek); + + /* + * We expect the server to have sent us a string containing a + * ciphertext, a confirmation hash, and a Curve25519 public point. + * Extract all three. + */ + BinarySource src[1]; + BinarySource_BARE_INIT_PL(src, remoteKey); + + uint16_t *ciphertext = snewn(p_LIVE, uint16_t); + ptrlen ciphertext_encoded = ntru_decode_ciphertext( + ciphertext, nk->keypair, src); + ptrlen confirmation_hash = get_data(src, 32); + ptrlen curve25519_remoteKey = get_data(src, 32); + + if (get_err(src) || get_avail(src)) { + /* Hard-fail if the input wasn't exactly the right length */ + ring_free(ciphertext, p_LIVE); + return false; + } + + /* + * Main hash object which will combine the NTRU and Curve25519 + * outputs. + */ + ssh_hash *h = ssh_hash_new(&ssh_sha512); + + /* Reusable buffer for storing various hash outputs. */ + uint8_t hashdata[64]; + + /* + * NTRU side. + */ + { + /* Decrypt the ciphertext to recover the server's plaintext */ + uint16_t *plaintext = snewn(p_LIVE, uint16_t); + ntru_decrypt(plaintext, ciphertext, nk->keypair); + + /* Make the confirmation hash */ + ntru_confirmation_hash(hashdata, plaintext, nk->keypair->h, + p_LIVE, q_LIVE); + + /* Check it matches the one the server sent */ + unsigned ok = smemeq(hashdata, confirmation_hash.ptr, 32); + + /* If not, substitute in rho for the plaintext in the session hash */ + unsigned mask = ok-1; + for (size_t i = 0; i < p_LIVE; i++) + plaintext[i] ^= mask & (plaintext[i] ^ nk->keypair->rho[i]); + + /* Compute the session hash, whether or not we did that */ + ntru_session_hash(hashdata, ok, plaintext, p_LIVE, ciphertext_encoded, + confirmation_hash); + + /* Free temporary values */ + ring_free(plaintext, p_LIVE); + ring_free(ciphertext, p_LIVE); + + /* And put the NTRU session hash into the main hash object. */ + put_data(h, hashdata, 32); + } + + /* + * Curve25519 side. + */ + { + strbuf *otherkey = strbuf_new_nm(); + + /* Call out to Curve25519 to compute the shared secret from that + * kex method */ + bool ok = ecdh_key_getkey(nk->curve25519, curve25519_remoteKey, + BinarySink_UPCAST(otherkey)); + + /* If that failed (which only happens if the other end does + * something wrong, like sending a low-order curve point + * outside the subgroup it's supposed to), we might as well + * just abort and return failure. That's what we'd have done + * in standalone Curve25519. */ + if (!ok) { + ssh_hash_free(h); + smemclr(hashdata, sizeof(hashdata)); + return false; + } + + /* + * ecdh_key_getkey will have returned us a chunk of data + * containing an encoded mpint, which is how the Curve25519 + * output normally goes into the exchange hash. But in this + * context we want to treat it as a fixed big-endian 32 bytes, + * so extract it from its encoding and put it into the main + * hash object in the new format. + */ + BinarySource src[1]; + BinarySource_BARE_INIT_PL(src, ptrlen_from_strbuf(otherkey)); + mp_int *curvekey = get_mp_ssh2(src); + + for (unsigned i = 32; i-- > 0 ;) + put_byte(h, mp_get_byte(curvekey, i)); + + mp_free(curvekey); + strbuf_free(otherkey); + } + + /* + * Finish up: compute the final output hash (full 64 bytes of + * SHA-512 this time), and return it encoded as a string. + */ + ssh_hash_final(h, hashdata); + put_stringpl(bs, make_ptrlen(hashdata, sizeof(hashdata))); + smemclr(hashdata, sizeof(hashdata)); + + return true; +} + +/* + * State structure for the server, which takes the role of inventing a + * secret plaintext and sending it to the client encrypted with the + * public key the client sent. + */ +typedef struct ntru_server_key { + uint16_t *plaintext; + strbuf *ciphertext_encoded, *confirmation_hash; + ecdh_key *curve25519; + + ecdh_key ek; +} ntru_server_key; + +static void ssh_ntru_server_free(ecdh_key *dh); +static void ssh_ntru_server_getpublic(ecdh_key *dh, BinarySink *bs); +static bool ssh_ntru_server_getkey(ecdh_key *dh, ptrlen remoteKey, + BinarySink *bs); + +static const ecdh_keyalg ssh_ntru_server_vt = { + /* This vtable has no 'new' method, because it's constructed via + * the selector vt below */ + .free = ssh_ntru_server_free, + .getpublic = ssh_ntru_server_getpublic, + .getkey = ssh_ntru_server_getkey, + .description = ssh_ntru_description, +}; + +static ecdh_key *ssh_ntru_server_new(void) +{ + ntru_server_key *nk = snew(ntru_server_key); + nk->ek.vt = &ssh_ntru_server_vt; + + nk->plaintext = snewn(p_LIVE, uint16_t); + nk->ciphertext_encoded = strbuf_new_nm(); + nk->confirmation_hash = strbuf_new_nm(); + ntru_gen_short(nk->plaintext, p_LIVE, w_LIVE); + + nk->curve25519 = ecdh_key_new(&ssh_ec_kex_curve25519, false); + + return &nk->ek; +} + +static void ssh_ntru_server_free(ecdh_key *dh) +{ + ntru_server_key *nk = container_of(dh, ntru_server_key, ek); + ring_free(nk->plaintext, p_LIVE); + strbuf_free(nk->ciphertext_encoded); + strbuf_free(nk->confirmation_hash); + ecdh_key_free(nk->curve25519); + sfree(nk); +} + +static bool ssh_ntru_server_getkey(ecdh_key *dh, ptrlen remoteKey, + BinarySink *bs) +{ + ntru_server_key *nk = container_of(dh, ntru_server_key, ek); + + /* + * In the server, getkey is called first, with the public + * information received from the client. We expect the client to + * have sent us a string containing a public key and a Curve25519 + * public point. + */ + BinarySource src[1]; + BinarySource_BARE_INIT_PL(src, remoteKey); + + uint16_t *pubkey = snewn(p_LIVE, uint16_t); + ntru_decode_pubkey(pubkey, p_LIVE, q_LIVE, src); + ptrlen curve25519_remoteKey = get_data(src, 32); + + if (get_err(src) || get_avail(src)) { + /* Hard-fail if the input wasn't exactly the right length */ + ring_free(pubkey, p_LIVE); + return false; + } + + /* + * Main hash object which will combine the NTRU and Curve25519 + * outputs. + */ + ssh_hash *h = ssh_hash_new(&ssh_sha512); + + /* Reusable buffer for storing various hash outputs. */ + uint8_t hashdata[64]; + + /* + * NTRU side. + */ + { + /* Encrypt the plaintext we generated at construction time, + * and encode the ciphertext into a strbuf so we can reuse it + * for both the session hash and sending to the client. */ + uint16_t *ciphertext = snewn(p_LIVE, uint16_t); + ntru_encrypt(ciphertext, nk->plaintext, pubkey, p_LIVE, q_LIVE); + ntru_encode_ciphertext(ciphertext, p_LIVE, q_LIVE, + BinarySink_UPCAST(nk->ciphertext_encoded)); + + /* Compute the confirmation hash, and write it into another + * strbuf. */ + ntru_confirmation_hash(hashdata, nk->plaintext, pubkey, + p_LIVE, q_LIVE); + put_data(nk->confirmation_hash, hashdata, 32); + + /* Compute the session hash (which is easy on the server side, + * requiring no conditional substitution). */ + ntru_session_hash(hashdata, 1, nk->plaintext, p_LIVE, + ptrlen_from_strbuf(nk->ciphertext_encoded), + ptrlen_from_strbuf(nk->confirmation_hash)); + + /* And put the NTRU session hash into the main hash object. */ + put_data(h, hashdata, 32); + } + + /* + * Curve25519 side. + */ + { + strbuf *otherkey = strbuf_new_nm(); + + /* Call out to Curve25519 to compute the shared secret from that + * kex method */ + bool ok = ecdh_key_getkey(nk->curve25519, curve25519_remoteKey, + BinarySink_UPCAST(otherkey)); + /* As on the client side, abort if Curve25519 reported failure */ + if (!ok) { + ssh_hash_free(h); + smemclr(hashdata, sizeof(hashdata)); + return false; + } + + /* As on the client side, decode Curve25519's mpint so we can + * re-encode it appropriately for our hash preimage */ + BinarySource src[1]; + BinarySource_BARE_INIT_PL(src, ptrlen_from_strbuf(otherkey)); + mp_int *curvekey = get_mp_ssh2(src); + + for (unsigned i = 32; i-- > 0 ;) + put_byte(h, mp_get_byte(curvekey, i)); + + mp_free(curvekey); + strbuf_free(otherkey); + } + + /* + * Finish up: compute the final output hash (full 64 bytes of + * SHA-512 this time), and return it encoded as a string. + */ + ssh_hash_final(h, hashdata); + put_stringpl(bs, make_ptrlen(hashdata, sizeof(hashdata))); + smemclr(hashdata, sizeof(hashdata)); + + return true; +} + +static void ssh_ntru_server_getpublic(ecdh_key *dh, BinarySink *bs) +{ + ntru_server_key *nk = container_of(dh, ntru_server_key, ek); + + /* + * In the server, this function is called after getkey, so we + * already have all our pieces prepared. Just concatenate them all + * into the 'server's public data' string to go in ECDH_REPLY. + */ + put_datapl(bs, ptrlen_from_strbuf(nk->ciphertext_encoded)); + put_datapl(bs, ptrlen_from_strbuf(nk->confirmation_hash)); + ecdh_key_getpublic(nk->curve25519, bs); +} + +/* ---------------------------------------------------------------------- + * Selector vtable that instantiates the appropriate one of the above, + * depending on is_server. + */ +static ecdh_key *ssh_ntru_new(const ssh_kex *kex, bool is_server) +{ + if (is_server) + return ssh_ntru_server_new(); + else + return ssh_ntru_client_new(); +} + +static const ecdh_keyalg ssh_ntru_selector_vt = { + /* This is a never-instantiated vtable which only implements the + * functions that don't require an instance. */ + .new = ssh_ntru_new, + .description = ssh_ntru_description, +}; + +const ssh_kex ssh_ntru_curve25519 = { + .name = "sntrup761x25519-sha512@openssh.com", + .main_type = KEXTYPE_ECDH, + .hash = &ssh_sha512, + .ecdh_vt = &ssh_ntru_selector_vt, +}; + +static const ssh_kex *const hybrid_list[] = { + &ssh_ntru_curve25519, +}; + +const ssh_kexes ssh_ntru_hybrid_kex = { lenof(hybrid_list), hybrid_list }; diff --git a/crypto/ntru.h b/crypto/ntru.h new file mode 100644 index 00000000..4789491b --- /dev/null +++ b/crypto/ntru.h @@ -0,0 +1,53 @@ +/* + * Internal functions for the NTRU cryptosystem, exposed in a header + * that is expected to be included only by ntru.c and test programs. + */ + +#ifndef PUTTY_CRYPTO_NTRU_H +#define PUTTY_CRYPTO_NTRU_H + +unsigned ntru_ring_invert(uint16_t *out, const uint16_t *in, + unsigned p, unsigned q); +void ntru_ring_multiply(uint16_t *out, const uint16_t *a, const uint16_t *b, + unsigned p, unsigned q); +void ntru_mod3(uint16_t *out, const uint16_t *in, unsigned p, unsigned q); +void ntru_round3(uint16_t *out, const uint16_t *in, unsigned p, unsigned q); +void ntru_bias(uint16_t *out, const uint16_t *in, unsigned bias, + unsigned p, unsigned q); +void ntru_scale(uint16_t *out, const uint16_t *in, uint16_t scale, + unsigned p, unsigned q); + +NTRUEncodeSchedule *ntru_encode_schedule(const uint16_t *ms_in, size_t n); +void ntru_encode_schedule_free(NTRUEncodeSchedule *sched); +size_t ntru_encode_schedule_length(NTRUEncodeSchedule *sched); +size_t ntru_encode_schedule_nvals(NTRUEncodeSchedule *sched); +void ntru_encode(NTRUEncodeSchedule *sched, const uint16_t *rs_in, + BinarySink *bs); +void ntru_decode(NTRUEncodeSchedule *sched, uint16_t *rs_out, ptrlen data); + +void ntru_gen_short(uint16_t *v, unsigned p, unsigned w); + +NTRUKeyPair *ntru_keygen_attempt(unsigned p, unsigned q, unsigned w); +NTRUKeyPair *ntru_keygen(unsigned p, unsigned q, unsigned w); +void ntru_keypair_free(NTRUKeyPair *keypair); + +void ntru_encrypt(uint16_t *ciphertext, const uint16_t *plaintext, + uint16_t *pubkey, unsigned p, unsigned q); +void ntru_decrypt(uint16_t *plaintext, const uint16_t *ciphertext, + NTRUKeyPair *keypair); + +void ntru_encode_pubkey(const uint16_t *pubkey, unsigned p, unsigned q, + BinarySink *bs); +ptrlen ntru_decode_pubkey(uint16_t *pubkey, unsigned p, unsigned q, + BinarySource *src); +void ntru_encode_ciphertext(const uint16_t *ciphertext, unsigned p, unsigned q, + BinarySink *bs); +ptrlen ntru_decode_ciphertext(uint16_t *ct, NTRUKeyPair *keypair, + BinarySource *src); +void ntru_encode_plaintext(const uint16_t *plaintext, unsigned p, + BinarySink *bs); + +unsigned ntru_keypair_p(NTRUKeyPair *keypair); +const uint16_t *ntru_pubkey(NTRUKeyPair *keypair); + +#endif /* PUTTY_CRYPTO_NTRU_H */ diff --git a/defs.h b/defs.h index 1edf3442..37cc7979 100644 --- a/defs.h +++ b/defs.h @@ -168,6 +168,8 @@ typedef struct ssh2_ciphers ssh2_ciphers; typedef struct dh_ctx dh_ctx; typedef struct ecdh_key ecdh_key; typedef struct ecdh_keyalg ecdh_keyalg; +typedef struct NTRUKeyPair NTRUKeyPair; +typedef struct NTRUEncodeSchedule NTRUEncodeSchedule; typedef struct dlgparam dlgparam; diff --git a/putty.h b/putty.h index fc5c2941..84a7df31 100644 --- a/putty.h +++ b/putty.h @@ -426,6 +426,7 @@ enum { KEX_DHGEX, KEX_RSA, KEX_ECDH, + KEX_NTRU_HYBRID, KEX_MAX }; diff --git a/settings.c b/settings.c index 98313d17..09701618 100644 --- a/settings.c +++ b/settings.c @@ -28,6 +28,7 @@ static const struct keyvalwhere ciphernames[] = { * compatibility warts in load_open_settings(), and should be kept * in sync with those. */ static const struct keyvalwhere kexnames[] = { + { "ntru-curve25519", KEX_NTRU_HYBRID, -1, +1 }, { "ecdh", KEX_ECDH, -1, +1 }, /* This name is misleading: it covers both SHA-256 and SHA-1 variants */ { "dh-gex-sha1", KEX_DHGEX, -1, -1 }, diff --git a/ssh.h b/ssh.h index 60880482..fb86f4f0 100644 --- a/ssh.h +++ b/ssh.h @@ -1058,6 +1058,7 @@ extern const ssh_kex ssh_ec_kex_nistp256; extern const ssh_kex ssh_ec_kex_nistp384; extern const ssh_kex ssh_ec_kex_nistp521; extern const ssh_kexes ssh_ecdh_kex; +extern const ssh_kexes ssh_ntru_hybrid_kex; extern const ssh_keyalg ssh_dsa; extern const ssh_keyalg ssh_rsa; extern const ssh_keyalg ssh_rsa_sha256; diff --git a/ssh/transport2.c b/ssh/transport2.c index e00507e1..02289747 100644 --- a/ssh/transport2.c +++ b/ssh/transport2.c @@ -533,6 +533,10 @@ static void ssh2_write_kexinit_lists( preferred_kex[n_preferred_kex++] = &ssh_ecdh_kex; break; + case KEX_NTRU_HYBRID: + preferred_kex[n_preferred_kex++] = + &ssh_ntru_hybrid_kex; + break; case KEX_WARN: /* Flag for later. Don't bother if it's the last in * the list. */ diff --git a/test/cryptsuite.py b/test/cryptsuite.py index 958980b9..ef2f79a4 100755 --- a/test/cryptsuite.py +++ b/test/cryptsuite.py @@ -1274,6 +1274,107 @@ class keygen(MyTestBase): mr = miller_rabin_new(n) self.assertEqual(miller_rabin_test(mr, 0x251), "failed") +class ntru(MyTestBase): + def testMultiply(self): + self.assertEqual( + ntru_ring_multiply([1,1,1,1,1,1], [1,1,1,1,1,1], 11, 59), + [1,2,3,4,5,6,5,4,3,2,1]) + self.assertEqual(ntru_ring_multiply( + [1,0,1,2,0,0,1,2,0,1,2], [2,0,0,1,0,1,2,2,2,0,2], 11, 3), + [1,0,0,0,0,0,0,0,0,0,0]) + + def testInvert(self): + # Over GF(3), x^11-x-1 factorises as + # (x^3+x^2+2) * (x^8+2*x^7+x^6+2*x^4+2*x^3+x^2+x+1) + # so we expect that 2,0,1,1 has no inverse, being one of those factors. + self.assertEqual(ntru_ring_invert([0], 11, 3), None) + self.assertEqual(ntru_ring_invert([1], 11, 3), + [1,0,0,0,0,0,0,0,0,0,0]) + self.assertEqual(ntru_ring_invert([2,0,1,1], 11, 3), None) + self.assertEqual(ntru_ring_invert([1,0,1,2,0,0,1,2,0,1,2], 11, 3), + [2,0,0,1,0,1,2,2,2,0,2]) + + self.assertEqual(ntru_ring_invert([1,0,1,2,0,0,1,2,0,1,2], 11, 59), + [1,26,10,1,38,48,34,37,53,3,53]) + + def testMod3Round3(self): + # Try a prime congruent to 1 mod 3 + self.assertEqual(ntru_mod3([4,5,6,0,1,2,3], 7, 7), + [0,1,-1,0,1,-1,0]) + self.assertEqual(ntru_round3([4,5,6,0,1,2,3], 7, 7), + [-3,-3,0,0,0,3,3]) + + # And one congruent to 2 mod 3 + self.assertEqual(ntru_mod3([6,7,8,9,10,0,1,2,3,4,5], 11, 11), + [1,-1,0,1,-1,0,1,-1,0,1,-1]) + self.assertEqual(ntru_round3([6,7,8,9,10,0,1,2,3,4,5], 11, 11), + [-6,-3,-3,-3,0,0,0,3,3,3,6]) + + def testBiasScale(self): + self.assertEqual(ntru_bias([0,1,2,3,4,5,6,7,8,9,10], 4, 11, 11), + [4,5,6,7,8,9,10,0,1,2,3]) + self.assertEqual(ntru_scale([0,1,2,3,4,5,6,7,8,9,10], 4, 11, 11), + [0,4,8,1,5,9,2,6,10,3,7]) + + def testEncode(self): + # Test a small case. Worked through in detail: + # + # Pass 1: + # Input list is (89:123, 90:234, 344:345, 432:456, 222:567) + # (89:123, 90:234) -> (89+123*90 : 123*234) = (11159:28782) + # Emit low byte of 11159 = 0x97, and get (43:113) + # (344:345, 432:456) -> (344+345*432 : 345*456) = (149384:157320) + # Emit low byte of 149384 = 0x88, and get (583:615) + # Odd pair (222:567) is copied to end of new list + # Final list is (43:113, 583:615, 222:567) + # Pass 2: + # Input list is (43:113, 583:615, 222:567) + # (43:113, 583:615) -> (43+113*583, 113*615) = (65922:69495) + # Emit low byte of 65922 = 0x82, and get (257:272) + # Odd pair (222:567) is copied to end of new list + # Final list is (257:272, 222:567) + # Pass 3: + # Input list is (257:272, 222:567) + # (257:272, 222:567) -> (257+272*222, 272*567) = (60641:154224) + # Emit low byte of 60641 = 0xe1, and get (236:603) + # Final list is (236:603) + # Cleanup: + # Emit low byte of 236 = 0xec, and get (0:3) + # Emit low byte of 0 = 0x00, and get (0:1) + + ms = [123,234,345,456,567] + rs = [89,90,344,432,222] + encoding = unhex('978882e1ec00') + sched = ntru_encode_schedule(ms) + self.assertEqual(sched.encode(rs), encoding) + self.assertEqual(sched.decode(encoding), rs) + + # Encode schedules for sntrup761 public keys and ciphertexts + pubsched = ntru_encode_schedule([4591]*761) + self.assertEqual(pubsched.length(), 1158) + ciphersched = ntru_encode_schedule([1531]*761) + self.assertEqual(ciphersched.length(), 1007) + + # Test round-trip encoding using those schedules + testlist = list(range(761)) + pubtext = pubsched.encode(testlist) + self.assertEqual(pubsched.decode(pubtext), testlist) + ciphertext = ciphersched.encode(testlist) + self.assertEqual(ciphersched.decode(ciphertext), testlist) + + def testCore(self): + # My own set of NTRU Prime parameters, satisfying all the + # requirements and tiny enough for convenient testing + p, q, w = 11, 59, 3 + + with random_prng('ntru keygen seed'): + keypair = ntru_keygen(p, q, w) + plaintext = ntru_gen_short(p, w) + + ciphertext = ntru_encrypt(plaintext, ntru_pubkey(keypair), p, q) + recovered = ntru_decrypt(ciphertext, keypair) + self.assertEqual(plaintext, recovered) + class crypt(MyTestBase): def testSSH1Fingerprint(self): # Example key and reference fingerprint value generated by diff --git a/test/testcrypt-func.h b/test/testcrypt-func.h index 1188d2c4..54de4b88 100644 --- a/test/testcrypt-func.h +++ b/test/testcrypt-func.h @@ -352,6 +352,35 @@ FUNC(void, ecdh_key_getpublic, ARG(val_ecdh, key), FUNC_WRAPPED(opt_val_string, ecdh_key_getkey, ARG(val_ecdh, key), ARG(val_string_ptrlen, pub)) +/* + * NTRU and its subroutines. + */ +FUNC_WRAPPED(int16_list, ntru_ring_multiply, ARG(int16_list, a), + ARG(int16_list, b), ARG(uint, p), ARG(uint, q)) +FUNC_WRAPPED(opt_int16_list, ntru_ring_invert, ARG(int16_list, r), + ARG(uint, p), ARG(uint, q)) +FUNC_WRAPPED(int16_list, ntru_mod3, ARG(int16_list, r), + ARG(uint, p), ARG(uint, q)) +FUNC_WRAPPED(int16_list, ntru_round3, ARG(int16_list, r), + ARG(uint, p), ARG(uint, q)) +FUNC_WRAPPED(int16_list, ntru_bias, ARG(int16_list, r), + ARG(uint, bias), ARG(uint, p), ARG(uint, q)) +FUNC_WRAPPED(int16_list, ntru_scale, ARG(int16_list, r), + ARG(uint, scale), ARG(uint, p), ARG(uint, q)) +FUNC_WRAPPED(val_ntruencodeschedule, ntru_encode_schedule, ARG(int16_list, ms)) +FUNC(uint, ntru_encode_schedule_length, ARG(val_ntruencodeschedule, sched)) +FUNC_WRAPPED(void, ntru_encode, ARG(val_ntruencodeschedule, sched), + ARG(int16_list, rs), ARG(out_val_string_binarysink, data)) +FUNC_WRAPPED(opt_int16_list, ntru_decode, ARG(val_ntruencodeschedule, sched), + ARG(val_string_ptrlen, data)) +FUNC_WRAPPED(int16_list, ntru_gen_short, ARG(uint, p), ARG(uint, w)) +FUNC(val_ntrukeypair, ntru_keygen, ARG(uint, p), ARG(uint, q), ARG(uint, w)) +FUNC_WRAPPED(int16_list, ntru_pubkey, ARG(val_ntrukeypair, keypair)) +FUNC_WRAPPED(int16_list, ntru_encrypt, ARG(int16_list, plaintext), + ARG(int16_list, pubkey), ARG(uint, p), ARG(uint, q)) +FUNC_WRAPPED(int16_list, ntru_decrypt, ARG(int16_list, ciphertext), + ARG(val_ntrukeypair, keypair)) + /* * RSA key exchange, and also the BinarySource get function * get_ssh1_rsa_priv_agent, which is a convenient way to make an diff --git a/test/testcrypt.c b/test/testcrypt.c index cfe28ae2..3157ed54 100644 --- a/test/testcrypt.c +++ b/test/testcrypt.c @@ -35,6 +35,7 @@ #include "misc.h" #include "mpint.h" #include "crypto/ecc.h" +#include "crypto/ntru.h" #include "proxy/cproxy.h" static NORETURN PRINTF_LIKE(1, 2) void fatal_error(const char *p, ...) @@ -96,6 +97,8 @@ uint64_t prng_reseed_time_ms(void) X(pgc, PrimeGenerationContext *, primegen_free_context(v)) \ X(pockle, Pockle *, pockle_free(v)) \ X(millerrabin, MillerRabin *, miller_rabin_free(v)) \ + X(ntrukeypair, NTRUKeyPair *, ntru_keypair_free(v)) \ + X(ntruencodeschedule, NTRUEncodeSchedule *, ntru_encode_schedule_free(v)) \ /* end of list */ typedef struct Value Value; @@ -221,6 +224,7 @@ typedef RsaSsh1Order TD_rsaorder; typedef key_components *TD_keycomponents; typedef const PrimeGenerationPolicy *TD_primegenpolicy; typedef struct mpint_list TD_mpint_list; +typedef struct int16_list *TD_int16_list; typedef PockleStatus TD_pocklestatus; typedef struct mr_result TD_mr_result; typedef Argon2Flavour TD_argon2flavour; @@ -385,6 +389,46 @@ static struct mpint_list get_mpint_list(BinarySource *in) return mpl; } +typedef struct int16_list { + size_t n; + uint16_t *integers; +} int16_list; + +static void finaliser_int16_list_free(strbuf *out, void *vlist) +{ + int16_list *list = (int16_list *)vlist; + sfree(list->integers); + sfree(list); +} + +static int16_list *make_int16_list(size_t n) +{ + int16_list *list = snew(int16_list); + list->n = n; + list->integers = snewn(n, uint16_t); + add_finaliser(finaliser_int16_list_free, list); + return list; +} + +static int16_list *get_int16_list(BinarySource *in) +{ + size_t n = get_uint(in); + int16_list *list = make_int16_list(n); + for (size_t i = 0; i < n; i++) + list->integers[i] = get_uint(in); + return list; +} + +static void return_int16_list(strbuf *out, int16_list *list) +{ + for (size_t i = 0; i < list->n; i++) { + if (i > 0) + put_byte(out, ','); + put_fmt(out, "%d", (int)(int16_t)list->integers[i]); + } + put_byte(out, '\n'); +} + static void finaliser_return_uint(strbuf *out, void *ctx) { unsigned *uval = (unsigned *)ctx; @@ -543,6 +587,7 @@ NULLABLE_RETURN_WRAPPER(val_cipher, ssh_cipher *) NULLABLE_RETURN_WRAPPER(val_hash, ssh_hash *) NULLABLE_RETURN_WRAPPER(val_key, ssh_key *) NULLABLE_RETURN_WRAPPER(val_mpint, mp_int *) +NULLABLE_RETURN_WRAPPER(int16_list, int16_list *) static void handle_hello(BinarySource *in, strbuf *out) { @@ -799,6 +844,130 @@ strbuf *ecdh_key_getkey_wrapper(ecdh_key *ek, ptrlen remoteKey) return sb; } +static void int16_list_resize(int16_list *list, unsigned p) +{ + list->integers = sresize(list->integers, p, uint16_t); + for (size_t i = list->n; i < p; i++) + list->integers[i] = 0; +} + +#if 0 +static int16_list ntru_ring_to_list_and_free(uint16_t *out, unsigned p) +{ + struct mpint_list mpl; + mpl.n = p; + mpl->integers = snewn(p, mp_int *); + for (unsigned i = 0; i < p; i++) + mpl->integers[i] = mp_from_integer((int16_t)out[i]); + sfree(out); + add_finaliser(finaliser_sfree, mpl->integers); + return mpl; +} +#endif + +int16_list *ntru_ring_multiply_wrapper( + int16_list *a, int16_list *b, unsigned p, unsigned q) +{ + int16_list_resize(a, p); + int16_list_resize(b, p); + int16_list *out = make_int16_list(p); + ntru_ring_multiply(out->integers, a->integers, b->integers, p, q); + return out; +} + +int16_list *ntru_ring_invert_wrapper(int16_list *in, unsigned p, unsigned q) +{ + int16_list_resize(in, p); + int16_list *out = make_int16_list(p); + unsigned success = ntru_ring_invert(out->integers, in->integers, p, q); + if (!success) + return NULL; + return out; +} + +int16_list *ntru_mod3_wrapper(int16_list *in, unsigned p, unsigned q) +{ + int16_list_resize(in, p); + int16_list *out = make_int16_list(p); + ntru_mod3(out->integers, in->integers, p, q); + return out; +} + +int16_list *ntru_round3_wrapper(int16_list *in, unsigned p, unsigned q) +{ + int16_list_resize(in, p); + int16_list *out = make_int16_list(p); + ntru_round3(out->integers, in->integers, p, q); + return out; +} + +int16_list *ntru_bias_wrapper(int16_list *in, unsigned bias, + unsigned p, unsigned q) +{ + int16_list_resize(in, p); + int16_list *out = make_int16_list(p); + ntru_bias(out->integers, in->integers, bias, p, q); + return out; +} + +int16_list *ntru_scale_wrapper(int16_list *in, unsigned scale, + unsigned p, unsigned q) +{ + int16_list_resize(in, p); + int16_list *out = make_int16_list(p); + ntru_scale(out->integers, in->integers, scale, p, q); + return out; +} + +NTRUEncodeSchedule *ntru_encode_schedule_wrapper(int16_list *in) +{ + return ntru_encode_schedule(in->integers, in->n); +} + +void ntru_encode_wrapper(NTRUEncodeSchedule *sched, int16_list *rs, + BinarySink *bs) +{ + ntru_encode(sched, rs->integers, bs); +} + +int16_list *ntru_decode_wrapper(NTRUEncodeSchedule *sched, ptrlen data) +{ + int16_list *out = make_int16_list(ntru_encode_schedule_nvals(sched)); + ntru_decode(sched, out->integers, data); + return out; +} + +int16_list *ntru_gen_short_wrapper(unsigned p, unsigned w) +{ + int16_list *out = make_int16_list(p); + ntru_gen_short(out->integers, p, w); + return out; +} + +int16_list *ntru_pubkey_wrapper(NTRUKeyPair *keypair) +{ + unsigned p = ntru_keypair_p(keypair); + int16_list *out = make_int16_list(p); + memcpy(out->integers, ntru_pubkey(keypair), p*sizeof(uint16_t)); + return out; +} + +int16_list *ntru_encrypt_wrapper(int16_list *plaintext, int16_list *pubkey, + unsigned p, unsigned q) +{ + int16_list *out = make_int16_list(p); + ntru_encrypt(out->integers, plaintext->integers, pubkey->integers, p, q); + return out; +} + +int16_list *ntru_decrypt_wrapper(int16_list *ciphertext, NTRUKeyPair *keypair) +{ + unsigned p = ntru_keypair_p(keypair); + int16_list *out = make_int16_list(p); + ntru_decrypt(out->integers, ciphertext->integers, keypair); + return out; +} + strbuf *rsa_ssh1_encrypt_wrapper(ptrlen input, RSAKey *key) { /* Fold the boolean return value in C into the string return value diff --git a/test/testcrypt.py b/test/testcrypt.py index 6c0e95ce..66f63d5c 100644 --- a/test/testcrypt.py +++ b/test/testcrypt.py @@ -199,6 +199,13 @@ def make_argword(arg, argtype, fnname, argindex, argname, to_preserve): sublist.append(make_argword(val, ("val_mpint", False), fnname, argindex, argname, to_preserve)) return b" ".join(coerce_to_bytes(sub) for sub in sublist) + if typename == "int16_list": + sublist = [make_argword(len(arg), ("uint", False), + fnname, argindex, argname, to_preserve)] + for val in arg: + sublist.append(make_argword(val & 0xFFFF, ("uint", False), + fnname, argindex, argname, to_preserve)) + return b" ".join(coerce_to_bytes(sub) for sub in sublist) raise TypeError( "Can't convert {}() argument #{:d} ({}) to {} (value was {!r})".format( fnname, argindex, argname, typename, arg)) @@ -247,6 +254,8 @@ def make_retval(rettype, word, unpack_strings): return word == b"true" elif rettype in {"pocklestatus", "mr_result"}: return word.decode("ASCII") + elif rettype == "int16_list": + return list(map(int, word.split(b','))) raise TypeError("Can't deal with return value {!r} of type {!r}" .format(word, rettype)) diff --git a/test/testsc.c b/test/testsc.c index 4d8b55a4..55a64aba 100644 --- a/test/testsc.c +++ b/test/testsc.c @@ -81,6 +81,7 @@ #include "misc.h" #include "mpint.h" #include "crypto/ecc.h" +#include "crypto/ntru.h" static NORETURN PRINTF_LIKE(1, 2) void fatal_error(const char *p, ...) { @@ -395,6 +396,7 @@ VOLATILE_WRAPPED_DEFN(static, size_t, looplimit, (size_t x)) HASHES(HASH_TESTLIST, X) \ X(argon2) \ X(primegen_probabilistic) \ + X(ntru) \ /* end of list */ static void test_mp_get_nbits(void) @@ -1556,6 +1558,74 @@ static void test_primegen_probabilistic(void) test_primegen(&primegen_probabilistic); } +static void test_ntru(void) +{ + unsigned p = 11, q = 59, w = 3; + uint16_t *pubkey_orig = snewn(p, uint16_t); + uint16_t *pubkey_check = snewn(p, uint16_t); + uint16_t *pubkey = snewn(p, uint16_t); + uint16_t *plaintext = snewn(p, uint16_t); + uint16_t *ciphertext = snewn(p, uint16_t); + + strbuf *buffer = strbuf_new(); + strbuf_append(buffer, 16384); + BinarySource src[1]; + + for (size_t i = 0; i < looplimit(32); i++) { + while (true) { + random_advance_counter(); + struct random_state st = random_get_state(); + + NTRUKeyPair *keypair = ntru_keygen_attempt(p, q, w); + + if (keypair) { + memcpy(pubkey_orig, ntru_pubkey(keypair), + p*sizeof(*pubkey_orig)); + ntru_keypair_free(keypair); + + random_set_state(st); + + log_start(); + NTRUKeyPair *keypair = ntru_keygen_attempt(p, q, w); + memcpy(pubkey_check, ntru_pubkey(keypair), + p*sizeof(*pubkey_check)); + + ntru_gen_short(plaintext, p, w); + ntru_encrypt(ciphertext, plaintext, pubkey, p, w); + ntru_decrypt(plaintext, ciphertext, keypair); + + strbuf_clear(buffer); + ntru_encode_pubkey(ntru_pubkey(keypair), p, q, + BinarySink_UPCAST(buffer)); + BinarySource_BARE_INIT_PL(src, ptrlen_from_strbuf(buffer)); + ntru_decode_pubkey(pubkey, p, q, src); + + strbuf_clear(buffer); + ntru_encode_ciphertext(ciphertext, p, q, + BinarySink_UPCAST(buffer)); + BinarySource_BARE_INIT_PL(src, ptrlen_from_strbuf(buffer)); + ntru_decode_ciphertext(ciphertext, keypair, src); + + strbuf_clear(buffer); + ntru_encode_plaintext(plaintext, p, BinarySink_UPCAST(buffer)); + log_end(); + + break; + } + + assert(!memcmp(pubkey_orig, pubkey_check, + p*sizeof(*pubkey_check))); + } + } + + sfree(pubkey_orig); + sfree(pubkey_check); + sfree(pubkey); + sfree(plaintext); + sfree(ciphertext); + strbuf_free(buffer); +} + static const struct test tests[] = { #define STRUCT_TEST(X) { #X, test_##X }, TESTLIST(STRUCT_TEST)