[crypto] Allow for relaxed Montgomery reduction

Classic Montgomery reduction involves a single conditional subtraction
to ensure that the result is strictly less than the modulus.

When performing chains of Montgomery multiplications (potentially
interspersed with additions and subtractions), it can be useful to
work with values that are stored modulo some small multiple of the
modulus, thereby allowing some reductions to be elided.  Each addition
or subtraction stage increases this running multiple.  A subsequent
multiplication stage can bring the running multiple back down, since
the reduction carried out for a multiplication product is generally
strong enough to absorb some additional bits in the inputs.  This
approach is already used in the x25519 code, where multiplication
takes two 258-bit inputs and produces a 257-bit output.

Split out the conditional subtraction from bigint_montgomery() and
provide a separate bigint_montgomery_relaxed() for callers who do not
require immediate reduction to within the range of the modulus.
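
As a rough illustration of this split, here is a minimal sketch that
uses a single 32-bit element (so R = 2^32).  It is not code from this
patch: the toy_* names, the fixed element width and the branch-based
subtraction are all assumptions made to keep the example short.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper: compute -N^-1 mod 2^32 for an odd modulus N */
static uint32_t toy_negmodinv ( uint32_t n ) {
	uint32_t inv = n;	/* n*n == 1 mod 8, so correct to 3 bits */
	int i;

	assert ( n & 1 );
	/* Each Newton iteration doubles the number of correct bits */
	for ( i = 0 ; i < 5 ; i++ )
		inv *= ( 2 - ( n * inv ) );
	return ( 0 - inv );
}

/* Relaxed reduction: returns t where tR = x + mN, plus the carry out.
 * No attempt is made to bring t into the range t < N.
 */
static uint32_t toy_montgomery_relaxed ( uint32_t n, uint64_t x,
					 int *carry ) {
	uint32_t m = ( ( ( uint32_t ) x ) * toy_negmodinv ( n ) );
	uint64_t sum = ( x + ( ( ( uint64_t ) m ) * n ) );

	/* Low half is now zero; report overflow of the 64-bit sum */
	assert ( ( ( uint32_t ) sum ) == 0 );
	*carry = ( sum < x );
	return ( sum >> 32 );
}

/* Classic reduction: relaxed reduction plus one conditional
 * subtraction.  Requires x < RN.  (This toy version branches; it is
 * not constant time.)
 */
static uint32_t toy_montgomery ( uint32_t n, uint64_t x ) {
	int carry;
	uint32_t t = toy_montgomery_relaxed ( n, x, &carry );

	if ( carry || ( t >= n ) )
		t -= n;
	return t;
}
```

In this sketch, if N < 2^30 (so that R > 4N) and both multiplication
inputs are below 2N, then toy_montgomery_relaxed() returns t < 2N with
a zero carry, and t can be fed straight into the next multiplication.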

Modular exponentiation could potentially make use of relaxed
Montgomery multiplication, but this would require R>4N, i.e. that the
two most significant bits of the modulus be zero.  For both RSA and
DHE, this would necessitate extending the modulus size by one element,
which would negate any speed increase from omitting the conditional
subtractions.  We therefore retain the use of classic Montgomery
reduction for modular exponentiation, apart from the final conversion
out of Montgomery form.

Signed-off-by: Michael Brown <mcb30@ipxe.org>
diff --git a/src/crypto/bigint.c b/src/crypto/bigint.c
index 3ef96d1..9274798 100644
--- a/src/crypto/bigint.c
+++ b/src/crypto/bigint.c
@@ -351,18 +351,113 @@
 }
 
 /**
- * Perform Montgomery reduction (REDC) of a big integer product
+ * Perform relaxed Montgomery reduction (REDC) of a big integer
  *
- * @v modulus0		Element 0 of big integer modulus
- * @v mont0		Element 0 of big integer Montgomery product
+ * @v modulus0		Element 0 of big integer odd modulus
+ * @v value0		Element 0 of big integer to be reduced
  * @v result0		Element 0 of big integer to hold result
  * @v size		Number of elements in modulus and result
+ * @ret carry		Carry out
  *
- * Note that the Montgomery product will be overwritten.
+ * The value to be reduced will be made divisible by R=2^n (see below)
+ * while retaining its residue class (i.e. multiples of the modulus
+ * will be added until the low half of the value is zero).
+ *
+ * The result may be expressed as
+ *
+ *    tR = x + mN
+ *
+ * where x is the input value, N is the modulus, R=2^n (where n is the
+ * number of bits in the representation of the modulus, including any
+ * leading zero bits), and m is the number of multiples of the modulus
+ * added to make the result tR divisible by R.
+ *
+ * The maximum addend is mN <= (R-1)*N (and such an m can be proven to
+ * exist since N is limited to being odd and therefore coprime to R).
+ *
+ * Since the result of this addition is one bit larger than the input
+ * value, a carry out bit is also returned.  The caller may be able to
+ * prove that the carry out is always zero, in which case it may be
+ * safely ignored.
+ *
+ * The upper half of the output value (i.e. t) will also be copied to
+ * the result pointer.  It is permissible for the result pointer to
+ * overlap the lower half of the input value.
+ *
+ * External knowledge of constraints on the modulus and the input
+ * value may be used to prove constraints on the result.  The
+ * constraint on the modulus may be generally expressed as
+ *
+ *    R > kN
+ *
+ * for some positive integer k.  The value k=1 is allowed, and simply
+ * expresses that the modulus fits within the number of bits in its
+ * own representation.
+ *
+ * For classic Montgomery reduction, we have k=1, i.e. R > N and a
+ * separate constraint that the input value is in the range x < RN.
+ * This gives the result constraint
+ *
+ *    tR < RN + (R-1)N
+ *       = 2RN - N
+ *       < 2RN
+ *     t < 2N
+ *
+ * A single subtraction of the modulus may therefore be required to
+ * bring it into the range t < N.
+ *
+ * When the input value is known to be a product of two integers A and
+ * B, with A < aN and B < bN, we get the result constraint
+ *
+ *    tR < abN^2 + (R-1)N
+ *       < (ab/k)RN + RN - N
+ *       < (1 + ab/k)RN
+ *     t < (1 + ab/k)N
+ *
+ * If we have k=a=b=1, i.e. R > N with A < N and B < N, then the
+ * result is in the range t < 2N and may require a single subtraction
+ * of the modulus to bring it into the range t < N so that it may be
+ * used as an input on a subsequent iteration.
+ *
+ * If we have k=4 and a=b=2, i.e. R > 4N with A < 2N and B < 2N, then
+ * the result is in the range t < 2N and may immediately be used as an
+ * input on a subsequent iteration, without requiring a subtraction.
+ *
+ * Larger values of k may be used to allow for larger values of a and
+ * b, which can be useful to elide intermediate reductions in a
+ * calculation chain that involves additions and subtractions between
+ * multiplications (as used in elliptic curve point addition, for
+ * example).  As a general rule: each intermediate addition or
+ * subtraction will require k to be doubled.
+ *
+ * When the input value is known to be a single integer A, with A < aN
+ * (as used when converting out of Montgomery form), we get the result
+ * constraint
+ *
+ *    tR < aN + (R-1)N
+ *       = RN + (a-1)N
+ *
+ * If we have a=1, i.e. A < N, then the constraint becomes
+ *
+ *    tR < RN
+ *     t < N
+ *
+ * and so the result is immediately in the range t < N with no
+ * subtraction of the modulus required.
+ *
+ * For any larger value of a, the result value t=N becomes possible.
+ * Additional external knowledge may potentially be used to prove that
+ * t=N cannot occur.  For example: if the caller is performing modular
+ * exponentiation with a prime modulus (or, more generally, a modulus
+ * that is coprime to the base), then there is no way for a non-zero
+ * base value to end up producing an exact multiple of the modulus.
+ * If t=N cannot be disproved, then conversion out of Montgomery form
+ * may require an additional subtraction of the modulus.
  */
-void bigint_montgomery_raw ( const bigint_element_t *modulus0,
-			     bigint_element_t *mont0,
-			     bigint_element_t *result0, unsigned int size ) {
+int bigint_montgomery_relaxed_raw ( const bigint_element_t *modulus0,
+				    bigint_element_t *value0,
+				    bigint_element_t *result0,
+				    unsigned int size ) {
 	const bigint_t ( size ) __attribute__ (( may_alias ))
 		*modulus = ( ( const void * ) modulus0 );
 	union {
@@ -371,7 +466,7 @@
 			bigint_t ( size ) low;
 			bigint_t ( size ) high;
 		} __attribute__ (( packed ));
-	} __attribute__ (( may_alias )) *mont = ( ( void * ) mont0 );
+	} __attribute__ (( may_alias )) *value = ( ( void * ) value0 );
 	bigint_t ( size ) __attribute__ (( may_alias ))
 		*result = ( ( void * ) result0 );
 	static bigint_t ( 1 ) cached;
@@ -381,7 +476,6 @@
 	unsigned int i;
 	unsigned int j;
 	int overflow;
-	int underflow;
 
 	/* Sanity checks */
 	assert ( bigint_bit_is_set ( modulus, 0 ) );
@@ -397,33 +491,73 @@
 	for ( i = 0 ; i < size ; i++ ) {
 
 		/* Determine scalar multiple for this round */
-		multiple = ( mont->low.element[i] * negmodinv.element[0] );
+		multiple = ( value->low.element[i] * negmodinv.element[0] );
 
 		/* Multiply value to make it divisible by 2^(width*i) */
 		carry = 0;
 		for ( j = 0 ; j < size ; j++ ) {
 			bigint_multiply_one ( multiple, modulus->element[j],
-					      &mont->full.element[ i + j ],
+					      &value->full.element[ i + j ],
 					      &carry );
 		}
 
 		/* Since value is now divisible by 2^(width*i), we
 		 * know that the current low element must have been
-		 * zeroed.  We can store the multiplication carry out
-		 * in this element, avoiding the need to immediately
-		 * propagate the carry through the remaining elements.
+		 * zeroed.
 		 */
-		assert ( mont->low.element[i] == 0 );
-		mont->low.element[i] = carry;
+		assert ( value->low.element[i] == 0 );
+
+		/* Store the multiplication carry out in the result,
+		 * avoiding the need to immediately propagate the
+		 * carry through the remaining elements.
+		 */
+		result->element[i] = carry;
 	}
 
 	/* Add the accumulated carries */
-	overflow = bigint_add ( &mont->low, &mont->high );
+	overflow = bigint_add ( result, &value->high );
+
+	/* Copy to result buffer */
+	bigint_copy ( &value->high, result );
+
+	return overflow;
+}
+
+/**
+ * Perform classic Montgomery reduction (REDC) of a big integer
+ *
+ * @v modulus0		Element 0 of big integer odd modulus
+ * @v value0		Element 0 of big integer to be reduced
+ * @v result0		Element 0 of big integer to hold result
+ * @v size		Number of elements in modulus and result
+ */
+void bigint_montgomery_raw ( const bigint_element_t *modulus0,
+			     bigint_element_t *value0,
+			     bigint_element_t *result0,
+			     unsigned int size ) {
+	const bigint_t ( size ) __attribute__ (( may_alias ))
+		*modulus = ( ( const void * ) modulus0 );
+	union {
+		bigint_t ( size * 2 ) full;
+		struct {
+			bigint_t ( size ) low;
+			bigint_t ( size ) high;
+		} __attribute__ (( packed ));
+	} __attribute__ (( may_alias )) *value = ( ( void * ) value0 );
+	bigint_t ( size ) __attribute__ (( may_alias ))
+		*result = ( ( void * ) result0 );
+	int overflow;
+	int underflow;
+
+	/* Sanity check */
+	assert ( ! bigint_is_geq ( &value->high, modulus ) );
+
+	/* Perform relaxed Montgomery reduction */
+	overflow = bigint_montgomery_relaxed ( modulus, &value->full, result );
 
 	/* Conditionally subtract the modulus once */
-	memcpy ( result, &mont->high, sizeof ( *result ) );
 	underflow = bigint_subtract ( modulus, result );
-	bigint_swap ( result, &mont->high, ( underflow & ~overflow ) );
+	bigint_swap ( result, &value->high, ( underflow & ~overflow ) );
 
 	/* Sanity check */
 	assert ( ! bigint_is_geq ( result, modulus ) );
@@ -530,7 +664,8 @@
 
 	/* Convert back out of Montgomery form */
 	bigint_grow ( result, &temp->product.full );
-	bigint_montgomery ( &temp->modulus, &temp->product.full, result );
+	bigint_montgomery_relaxed ( &temp->modulus, &temp->product.full,
+				    result );
 
 	/* Handle even moduli via Garner's algorithm */
 	if ( subsize ) {
diff --git a/src/include/ipxe/bigint.h b/src/include/ipxe/bigint.h
index 90e212b..db907f1 100644
--- a/src/include/ipxe/bigint.h
+++ b/src/include/ipxe/bigint.h
@@ -235,7 +235,7 @@
  * @v modulus		Big integer modulus
  * @v value		Big integer to be reduced
  */
-#define bigint_reduce( modulus, value ) do {	\
+#define bigint_reduce( modulus, value ) do {				\
 		unsigned int size = bigint_size (modulus);		\
 		bigint_reduce_raw ( (modulus)->element,			\
 				    (value)->element, size );		\
@@ -254,19 +254,31 @@
 	} while ( 0 )
 
 /**
- * Perform Montgomery reduction (REDC) of a big integer product
+ * Perform relaxed Montgomery reduction (REDC) of a big integer
  *
- * @v modulus		Big integer modulus
- * @v mont		Big integer Montgomery product
+ * @v modulus		Big integer odd modulus
+ * @v value		Big integer to be reduced
  * @v result		Big integer to hold result
- *
- * Note that the Montgomery product will be overwritten.
+ * @ret carry		Carry out
  */
-#define bigint_montgomery( modulus, mont, result ) do {			\
+#define bigint_montgomery_relaxed( modulus, value, result ) ( {		\
 	unsigned int size = bigint_size (modulus);			\
-	bigint_montgomery_raw ( (modulus)->element, (mont)->element,	\
-				(result)->element,			\
-				size );					\
+	bigint_montgomery_relaxed_raw ( (modulus)->element,		\
+					(value)->element,		\
+					(result)->element, size );	\
+	} )
+
+/**
+ * Perform classic Montgomery reduction (REDC) of a big integer
+ *
+ * @v modulus		Big integer odd modulus
+ * @v value		Big integer to be reduced
+ * @v result		Big integer to hold result
+ */
+#define bigint_montgomery( modulus, value, result ) do {		\
+	unsigned int size = bigint_size (modulus);			\
+	bigint_montgomery_raw ( (modulus)->element, (value)->element,	\
+				(result)->element, size );		\
 	} while ( 0 )
 
 /**
@@ -375,8 +387,12 @@
 			 unsigned int size );
 void bigint_mod_invert_raw ( const bigint_element_t *invertend0,
 			     bigint_element_t *inverse0, unsigned int size );
+int bigint_montgomery_relaxed_raw ( const bigint_element_t *modulus0,
+				    bigint_element_t *value0,
+				    bigint_element_t *result0,
+				    unsigned int size );
 void bigint_montgomery_raw ( const bigint_element_t *modulus0,
-			     bigint_element_t *mont0,
+			     bigint_element_t *value0,
 			     bigint_element_t *result0, unsigned int size );
 void bigint_mod_exp_raw ( const bigint_element_t *base0,
 			  const bigint_element_t *modulus0,
diff --git a/src/tests/bigint_test.c b/src/tests/bigint_test.c
index 07ba13b..fce5f5c 100644
--- a/src/tests/bigint_test.c
+++ b/src/tests/bigint_test.c
@@ -207,17 +207,17 @@
 }
 
 void bigint_montgomery_sample ( const bigint_element_t *modulus0,
-				bigint_element_t *mont0,
+				bigint_element_t *value0,
 				bigint_element_t *result0,
 				unsigned int size ) {
 	const bigint_t ( size ) __attribute__ (( may_alias ))
 		*modulus = ( ( const void * ) modulus0 );
 	bigint_t ( 2 * size ) __attribute__ (( may_alias ))
-		*mont = ( ( void * ) mont0 );
+		*value = ( ( void * ) value0 );
 	bigint_t ( size ) __attribute__ (( may_alias ))
 		*result = ( ( void * ) result0 );
 
-	bigint_montgomery ( modulus, mont, result );
+	bigint_montgomery ( modulus, value, result );
 }
 
 void bigint_mod_exp_sample ( const bigint_element_t *base0,