[crypto] Calculate inverse of modulus on demand in bigint_montgomery()

Reduce the number of parameters passed to bigint_montgomery() by
calculating the inverse of the modulus modulo 2^k (where k is the
width in bits of a single element) on demand.  Cache the result,
since Montgomery reduction will be used repeatedly with the same
modulus value.

The cache makes the execution time depend on the value of the modulus.
In all currently supported algorithms, the modulus is a public value
(or a fixed value defined by specification) and so this non-constant
timing does not leak any private information.

Signed-off-by: Michael Brown <mcb30@ipxe.org>
diff --git a/src/crypto/bigint.c b/src/crypto/bigint.c
index b357ea2..3ef96d1 100644
--- a/src/crypto/bigint.c
+++ b/src/crypto/bigint.c
@@ -354,23 +354,17 @@
  * Perform Montgomery reduction (REDC) of a big integer product
  *
  * @v modulus0		Element 0 of big integer modulus
- * @v modinv0		Element 0 of the inverse of the modulus modulo 2^k
  * @v mont0		Element 0 of big integer Montgomery product
  * @v result0		Element 0 of big integer to hold result
  * @v size		Number of elements in modulus and result
  *
- * Note that only the least significant element of the inverse modulo
- * 2^k is required, and that the Montgomery product will be
- * overwritten.
+ * Note that the Montgomery product will be overwritten.
  */
 void bigint_montgomery_raw ( const bigint_element_t *modulus0,
-			     const bigint_element_t *modinv0,
 			     bigint_element_t *mont0,
 			     bigint_element_t *result0, unsigned int size ) {
 	const bigint_t ( size ) __attribute__ (( may_alias ))
 		*modulus = ( ( const void * ) modulus0 );
-	const bigint_t ( 1 ) __attribute__ (( may_alias ))
-		*modinv = ( ( const void * ) modinv0 );
 	union {
 		bigint_t ( size * 2 ) full;
 		struct {
@@ -380,7 +374,8 @@
 	} __attribute__ (( may_alias )) *mont = ( ( void * ) mont0 );
 	bigint_t ( size ) __attribute__ (( may_alias ))
 		*result = ( ( void * ) result0 );
-	bigint_element_t negmodinv = -modinv->element[0];
+	static bigint_t ( 1 ) cached;
+	static bigint_t ( 1 ) negmodinv;
 	bigint_element_t multiple;
 	bigint_element_t carry;
 	unsigned int i;
@@ -391,11 +386,18 @@
 	/* Sanity checks */
 	assert ( bigint_bit_is_set ( modulus, 0 ) );
 
+	/* Calculate inverse (or use cached version) */
+	if ( cached.element[0] != modulus->element[0] ) {
+		bigint_mod_invert ( modulus, &negmodinv );
+		negmodinv.element[0] = -negmodinv.element[0];
+		cached.element[0] = modulus->element[0];
+	}
+
 	/* Perform multiprecision Montgomery reduction */
 	for ( i = 0 ; i < size ; i++ ) {
 
 		/* Determine scalar multiple for this round */
-		multiple = ( mont->low.element[i] * negmodinv );
+		multiple = ( mont->low.element[i] * negmodinv.element[0] );
 
 		/* Multiply value to make it divisible by 2^(width*i) */
 		carry = 0;
@@ -467,7 +469,6 @@
 		} product;
 	} *temp = tmp;
 	const uint8_t one[1] = { 1 };
-	bigint_t ( 1 ) modinv;
 	bigint_element_t submask;
 	unsigned int subsize;
 	unsigned int scale;
@@ -494,9 +495,6 @@
 	if ( ! submask )
 		submask = ~submask;
 
-	/* Calculate inverse of (scaled) modulus N modulo element size */
-	bigint_mod_invert ( &temp->modulus, &modinv );
-
 	/* Calculate (R^2 mod N) via direct reduction of (R^2 - N) */
 	memset ( &temp->product.full, 0, sizeof ( temp->product.full ) );
 	bigint_subtract ( &temp->padded_modulus, &temp->product.full );
@@ -504,12 +502,11 @@
 	bigint_copy ( &temp->product.low, &temp->stash );
 
 	/* Initialise result = Montgomery(1, R^2 mod N) */
-	bigint_montgomery ( &temp->modulus, &modinv,
-			    &temp->product.full, result );
+	bigint_montgomery ( &temp->modulus, &temp->product.full, result );
 
 	/* Convert base into Montgomery form */
 	bigint_multiply ( base, &temp->stash, &temp->product.full );
-	bigint_montgomery ( &temp->modulus, &modinv, &temp->product.full,
+	bigint_montgomery ( &temp->modulus, &temp->product.full,
 			    &temp->stash );
 
 	/* Calculate x1 = base^exponent modulo N */
@@ -518,13 +515,13 @@
 
 		/* Square (and reduce) */
 		bigint_multiply ( result, result, &temp->product.full );
-		bigint_montgomery ( &temp->modulus, &modinv,
-				    &temp->product.full, result );
+		bigint_montgomery ( &temp->modulus, &temp->product.full,
+				    result );
 
 		/* Multiply (and reduce) */
 		bigint_multiply ( &temp->stash, result, &temp->product.full );
-		bigint_montgomery ( &temp->modulus, &modinv,
-				    &temp->product.full, &temp->product.low );
+		bigint_montgomery ( &temp->modulus, &temp->product.full,
+				    &temp->product.low );
 
 		/* Conditionally swap the multiplied result */
 		bigint_swap ( result, &temp->product.low,
@@ -533,8 +530,7 @@
 
 	/* Convert back out of Montgomery form */
 	bigint_grow ( result, &temp->product.full );
-	bigint_montgomery ( &temp->modulus, &modinv, &temp->product.full,
-			    result );
+	bigint_montgomery ( &temp->modulus, &temp->product.full, result );
 
 	/* Handle even moduli via Garner's algorithm */
 	if ( subsize ) {
diff --git a/src/include/ipxe/bigint.h b/src/include/ipxe/bigint.h
index 3058547..90e212b 100644
--- a/src/include/ipxe/bigint.h
+++ b/src/include/ipxe/bigint.h
@@ -257,16 +257,15 @@
  * Perform Montgomery reduction (REDC) of a big integer product
  *
  * @v modulus		Big integer modulus
- * @v modinv		Big integer inverse of the modulus modulo 2^k
  * @v mont		Big integer Montgomery product
  * @v result		Big integer to hold result
  *
  * Note that the Montgomery product will be overwritten.
  */
-#define bigint_montgomery( modulus, modinv, mont, result ) do {		\
+#define bigint_montgomery( modulus, mont, result ) do {			\
 	unsigned int size = bigint_size (modulus);			\
-	bigint_montgomery_raw ( (modulus)->element, (modinv)->element,	\
-				(mont)->element, (result)->element,	\
+	bigint_montgomery_raw ( (modulus)->element, (mont)->element,	\
+				(result)->element,			\
 				size );					\
 	} while ( 0 )
 
@@ -377,7 +376,6 @@
 void bigint_mod_invert_raw ( const bigint_element_t *invertend0,
 			     bigint_element_t *inverse0, unsigned int size );
 void bigint_montgomery_raw ( const bigint_element_t *modulus0,
-			     const bigint_element_t *modinv0,
 			     bigint_element_t *mont0,
 			     bigint_element_t *result0, unsigned int size );
 void bigint_mod_exp_raw ( const bigint_element_t *base0,
diff --git a/src/tests/bigint_test.c b/src/tests/bigint_test.c
index dc74740..07ba13b 100644
--- a/src/tests/bigint_test.c
+++ b/src/tests/bigint_test.c
@@ -207,20 +207,17 @@
 }
 
 void bigint_montgomery_sample ( const bigint_element_t *modulus0,
-				const bigint_element_t *modinv0,
 				bigint_element_t *mont0,
 				bigint_element_t *result0,
 				unsigned int size ) {
 	const bigint_t ( size ) __attribute__ (( may_alias ))
 		*modulus = ( ( const void * ) modulus0 );
-	const bigint_t ( 1 ) __attribute__ (( may_alias ))
-		*modinv = ( ( const void * ) modinv0 );
 	bigint_t ( 2 * size ) __attribute__ (( may_alias ))
 		*mont = ( ( void * ) mont0 );
 	bigint_t ( size ) __attribute__ (( may_alias ))
 		*result = ( ( void * ) result0 );
 
-	bigint_montgomery ( modulus, modinv, mont, result );
+	bigint_montgomery ( modulus, mont, result );
 }
 
 void bigint_mod_exp_sample ( const bigint_element_t *base0,
@@ -631,7 +628,6 @@
 	unsigned int size =						\
 		bigint_required_size ( sizeof ( modulus_raw ) );	\
 	bigint_t ( size ) modulus_temp;					\
-	bigint_t ( 1 ) modinv_temp;					\
 	bigint_t ( 2 * size ) mont_temp;				\
 	bigint_t ( size ) result_temp;					\
 	{} /* Fix emacs alignment */					\
@@ -641,13 +637,10 @@
 	bigint_init ( &modulus_temp, modulus_raw,			\
 		      sizeof ( modulus_raw ) );				\
 	bigint_init ( &mont_temp, mont_raw, sizeof ( mont_raw ) );	\
-	bigint_mod_invert ( &modulus_temp, &modinv_temp );		\
 	DBG ( "Montgomery:\n" );					\
 	DBG_HDA ( 0, &modulus_temp, sizeof ( modulus_temp ) );		\
-	DBG_HDA ( 0, &modinv_temp, sizeof ( modinv_temp ) );		\
 	DBG_HDA ( 0, &mont_temp, sizeof ( mont_temp ) );		\
-	bigint_montgomery ( &modulus_temp, &modinv_temp, &mont_temp,	\
-			    &result_temp );				\
+	bigint_montgomery ( &modulus_temp, &mont_temp, &result_temp );	\
 	DBG_HDA ( 0, &result_temp, sizeof ( result_temp ) );		\
 	bigint_done ( &result_temp, result_raw,				\
 		      sizeof ( result_raw ) );				\