From d2653e79aba277f5d8249c421a7c342e08548403 Mon Sep 17 00:00:00 2001 From: Simon Tatham Date: Sat, 8 Apr 2017 17:29:41 +0100 Subject: [PATCH] Fix bug in Poly1305 bigval_final_reduce(). Mark Wooding pointed out that my comment in make1305.py was completely wrong, and that the stated strategy for reducing a value mod 2^130-5 would not in fact completely reduce all inputs in the range - for the most obvious reason, namely that the numbers between 2^130-5 and 2^130 would never have anything subtracted at all. Implemented a replacement strategy which my tests suggest will do the right thing for all numbers in the expected range that are anywhere near an integer multiple of the modulus. --- contrib/make1305.py | 20 +++-- sshccp.c | 188 +++++++++++++++++++++++++------------------- 2 files changed, 117 insertions(+), 91 deletions(-) diff --git a/contrib/make1305.py b/contrib/make1305.py index c8597040..58dca67a 100755 --- a/contrib/make1305.py +++ b/contrib/make1305.py @@ -338,16 +338,20 @@ static void bigval_mul_mod_p(bigval *r, const bigval *a, const bigval *b) \n""" % target.text() def gen_final_reduce(target): - # We take our input number n, and compute k = n + 5*(n >> 130). - # Then k >> 130 is precisely the multiple of p that needs to be - # subtracted from n to reduce it to strictly less than p. + # Given our input number n, n >> 130 is usually precisely the + # multiple of p that needs to be subtracted from n to reduce it to + # strictly less than p, but it might be too low by 1 (but not more + # than 1, given the range of our input is nowhere near the square + # of the modulus). So we add another 5, which will push a carry + # into the 130th bit if and only if that has happened, and then + # use that to decide whether to subtract one more copy of p. a = target.bigval_input("n", 133) - a1 = a.extract_bits(130, 130) - k = a + target.const(5) * a1 - q = k.extract_bits(130) - adjusted = a + target.const(5) * q - ret = adjusted.extract_bits(0, 130) + q = a.extract_bits(130) + adjusted = a.extract_bits(0, 130) + target.const(5) * q + final_subtract = (adjusted + target.const(5)).extract_bits(130) + adjusted2 = adjusted + target.const(5) * final_subtract + ret = adjusted2.extract_bits(0, 130) target.write_bigval("n", ret) return """\ static void bigval_final_reduce(bigval *n) diff --git a/sshccp.c b/sshccp.c index d0a61775..835b8fe5 100644 --- a/sshccp.c +++ b/sshccp.c @@ -440,9 +440,10 @@ static void bigval_mul_mod_p(bigval *r, const bigval *a, const bigval *b) static void bigval_final_reduce(bigval *n) { - BignumInt v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v12, v13, v14, v15; - BignumInt v16, v17, v18, v19, v20, v21, v22, v24, v25, v26, v27, v28, v29; - BignumInt v30, v31, v32, v33; + BignumInt v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v13, v14, v15; + BignumInt v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28; + BignumInt v29, v30, v31, v32, v34, v35, v36, v37, v38, v39, v40, v41, v42; + BignumInt v43; BignumCarry carry; v0 = n->w[0]; @@ -455,45 +456,55 @@ static void bigval_final_reduce(bigval *n) v7 = n->w[7]; v8 = n->w[8]; v9 = (v8) >> 2; - v10 = 5 * v9; - BignumADC(v12, carry, v0, v10, 0); - (void)v12; - BignumADC(v13, carry, v1, 0, carry); - (void)v13; - BignumADC(v14, carry, v2, 0, carry); - (void)v14; - BignumADC(v15, carry, v3, 0, carry); - (void)v15; - BignumADC(v16, carry, v4, 0, carry); - (void)v16; - BignumADC(v17, carry, v5, 0, carry); - (void)v17; - BignumADC(v18, carry, v6, 0, carry); - (void)v18; - BignumADC(v19, carry, v7, 0, carry); - (void)v19; - v20 = v8 + 0 + carry; - v21 = (v20) >> 2; - v22 = 5 * v21; - BignumADC(v24, carry, v0, v22, 0); - BignumADC(v25, carry, v1, 0, carry); - BignumADC(v26, carry, v2, 0, carry); - BignumADC(v27, carry, v3, 0, carry); - BignumADC(v28, carry, v4, 0, carry); - BignumADC(v29, carry, v5, 0, carry); - BignumADC(v30, carry, v6, 0, carry); - BignumADC(v31, carry, v7, 0, carry); - v32 = v8 + 0 + carry; - v33 = (v32) & ((((BignumInt)1) << 2)-1); - n->w[0] = v24; - n->w[1] = v25; - n->w[2] = v26; - n->w[3] = v27; - n->w[4] = v28; - n->w[5] = v29; - n->w[6] = v30; - n->w[7] = v31; - n->w[8] = v33; + v10 = (v8) & ((((BignumInt)1) << 2)-1); + v11 = 5 * v9; + BignumADC(v13, carry, v0, v11, 0); + BignumADC(v14, carry, v1, 0, carry); + BignumADC(v15, carry, v2, 0, carry); + BignumADC(v16, carry, v3, 0, carry); + BignumADC(v17, carry, v4, 0, carry); + BignumADC(v18, carry, v5, 0, carry); + BignumADC(v19, carry, v6, 0, carry); + BignumADC(v20, carry, v7, 0, carry); + v21 = v10 + 0 + carry; + BignumADC(v22, carry, v13, 5, 0); + (void)v22; + BignumADC(v23, carry, v14, 0, carry); + (void)v23; + BignumADC(v24, carry, v15, 0, carry); + (void)v24; + BignumADC(v25, carry, v16, 0, carry); + (void)v25; + BignumADC(v26, carry, v17, 0, carry); + (void)v26; + BignumADC(v27, carry, v18, 0, carry); + (void)v27; + BignumADC(v28, carry, v19, 0, carry); + (void)v28; + BignumADC(v29, carry, v20, 0, carry); + (void)v29; + v30 = v21 + 0 + carry; + v31 = (v30) >> 2; + v32 = 5 * v31; + BignumADC(v34, carry, v13, v32, 0); + BignumADC(v35, carry, v14, 0, carry); + BignumADC(v36, carry, v15, 0, carry); + BignumADC(v37, carry, v16, 0, carry); + BignumADC(v38, carry, v17, 0, carry); + BignumADC(v39, carry, v18, 0, carry); + BignumADC(v40, carry, v19, 0, carry); + BignumADC(v41, carry, v20, 0, carry); + v42 = v21 + 0 + carry; + v43 = (v42) & ((((BignumInt)1) << 2)-1); + n->w[0] = v34; + n->w[1] = v35; + n->w[2] = v36; + n->w[3] = v37; + n->w[4] = v38; + n->w[5] = v39; + n->w[6] = v40; + n->w[7] = v41; + n->w[8] = v43; } #elif BIGNUM_INT_BITS == 32 @@ -604,8 +615,8 @@ static void bigval_mul_mod_p(bigval *r, const bigval *a, const bigval *b) static void bigval_final_reduce(bigval *n) { - BignumInt v0, v1, v2, v3, v4, v5, v6, v8, v9, v10, v11, v12, v13, v14; - BignumInt v16, v17, v18, v19, v20, v21; + BignumInt v0, v1, v2, v3, v4, v5, v6, v7, v9, v10, v11, v12, v13, v14; + BignumInt v15, v16, v17, v18, v19, v20, v22, v23, v24, v25, v26, v27; BignumCarry carry; v0 = n->w[0]; @@ -614,29 +625,35 @@ static void bigval_final_reduce(bigval *n) v3 = n->w[3]; v4 = n->w[4]; v5 = (v4) >> 2; - v6 = 5 * v5; - BignumADC(v8, carry, v0, v6, 0); - (void)v8; - BignumADC(v9, carry, v1, 0, carry); - (void)v9; - BignumADC(v10, carry, v2, 0, carry); - (void)v10; - BignumADC(v11, carry, v3, 0, carry); - (void)v11; - v12 = v4 + 0 + carry; - v13 = (v12) >> 2; - v14 = 5 * v13; - BignumADC(v16, carry, v0, v14, 0); - BignumADC(v17, carry, v1, 0, carry); - BignumADC(v18, carry, v2, 0, carry); - BignumADC(v19, carry, v3, 0, carry); - v20 = v4 + 0 + carry; - v21 = (v20) & ((((BignumInt)1) << 2)-1); - n->w[0] = v16; - n->w[1] = v17; - n->w[2] = v18; - n->w[3] = v19; - n->w[4] = v21; + v6 = (v4) & ((((BignumInt)1) << 2)-1); + v7 = 5 * v5; + BignumADC(v9, carry, v0, v7, 0); + BignumADC(v10, carry, v1, 0, carry); + BignumADC(v11, carry, v2, 0, carry); + BignumADC(v12, carry, v3, 0, carry); + v13 = v6 + 0 + carry; + BignumADC(v14, carry, v9, 5, 0); + (void)v14; + BignumADC(v15, carry, v10, 0, carry); + (void)v15; + BignumADC(v16, carry, v11, 0, carry); + (void)v16; + BignumADC(v17, carry, v12, 0, carry); + (void)v17; + v18 = v13 + 0 + carry; + v19 = (v18) >> 2; + v20 = 5 * v19; + BignumADC(v22, carry, v9, v20, 0); + BignumADC(v23, carry, v10, 0, carry); + BignumADC(v24, carry, v11, 0, carry); + BignumADC(v25, carry, v12, 0, carry); + v26 = v13 + 0 + carry; + v27 = (v26) & ((((BignumInt)1) << 2)-1); + n->w[0] = v22; + n->w[1] = v23; + n->w[2] = v24; + n->w[3] = v25; + n->w[4] = v27; } #elif BIGNUM_INT_BITS == 64 @@ -705,28 +722,33 @@ static void bigval_mul_mod_p(bigval *r, const bigval *a, const bigval *b) static void bigval_final_reduce(bigval *n) { - BignumInt v0, v1, v2, v3, v4, v6, v7, v8, v9, v10, v12, v13, v14, v15; + BignumInt v0, v1, v2, v3, v4, v5, v7, v8, v9, v10, v11, v12, v13, v14; + BignumInt v16, v17, v18, v19; BignumCarry carry; v0 = n->w[0]; v1 = n->w[1]; v2 = n->w[2]; v3 = (v2) >> 2; - v4 = 5 * v3; - BignumADC(v6, carry, v0, v4, 0); - (void)v6; - BignumADC(v7, carry, v1, 0, carry); - (void)v7; - v8 = v2 + 0 + carry; - v9 = (v8) >> 2; - v10 = 5 * v9; - BignumADC(v12, carry, v0, v10, 0); - BignumADC(v13, carry, v1, 0, carry); - v14 = v2 + 0 + carry; - v15 = (v14) & ((((BignumInt)1) << 2)-1); - n->w[0] = v12; - n->w[1] = v13; - n->w[2] = v15; + v4 = (v2) & ((((BignumInt)1) << 2)-1); + v5 = 5 * v3; + BignumADC(v7, carry, v0, v5, 0); + BignumADC(v8, carry, v1, 0, carry); + v9 = v4 + 0 + carry; + BignumADC(v10, carry, v7, 5, 0); + (void)v10; + BignumADC(v11, carry, v8, 0, carry); + (void)v11; + v12 = v9 + 0 + carry; + v13 = (v12) >> 2; + v14 = 5 * v13; + BignumADC(v16, carry, v7, v14, 0); + BignumADC(v17, carry, v8, 0, carry); + v18 = v9 + 0 + carry; + v19 = (v18) & ((((BignumInt)1) << 2)-1); + n->w[0] = v16; + n->w[1] = v17; + n->w[2] = v19; } #else