[tls] Add support for Ephemeral Diffie-Hellman key exchange

Signed-off-by: Michael Brown <mcb30@ipxe.org>
diff --git a/src/include/ipxe/tls.h b/src/include/ipxe/tls.h
index 80cdd12..6d6c82d 100644
--- a/src/include/ipxe/tls.h
+++ b/src/include/ipxe/tls.h
@@ -403,6 +403,7 @@
 #define TLS_RX_ALIGN 16
 
 extern struct tls_key_exchange_algorithm tls_pubkey_exchange_algorithm;
+extern struct tls_key_exchange_algorithm tls_dhe_exchange_algorithm;
 
 extern int add_tls ( struct interface *xfer, const char *name,
 		     struct x509_root *root, struct private_key *key );
diff --git a/src/net/tls.c b/src/net/tls.c
index b209e0d..4aa4d9e 100644
--- a/src/net/tls.c
+++ b/src/net/tls.c
@@ -49,6 +49,7 @@
 #include <ipxe/rbg.h>
 #include <ipxe/validator.h>
 #include <ipxe/job.h>
+#include <ipxe/dhe.h>
 #include <ipxe/tls.h>
 #include <config/crypto.h>
 
@@ -109,6 +110,10 @@
 #define EINFO_EINVAL_TICKET						\
 	__einfo_uniqify ( EINFO_EINVAL, 0x0e,				\
 			  "Invalid New Session Ticket record")
+#define EINVAL_KEY_EXCHANGE						\
+	__einfo_error ( EINFO_EINVAL_KEY_EXCHANGE )
+#define EINFO_EINVAL_KEY_EXCHANGE __einfo_uniqify ( EINFO_EINVAL, 0x0f, \
+	"Invalid Server Key Exchange record" )
 #define EIO_ALERT __einfo_error ( EINFO_EIO_ALERT )
 #define EINFO_EIO_ALERT							\
 	__einfo_uniqify ( EINFO_EIO, 0x01,				\
@@ -177,6 +182,10 @@
 #define EINFO_EPERM_RENEG_VERIFY					\
 	__einfo_uniqify ( EINFO_EPERM, 0x05,				\
 			  "Secure renegotiation verification failed" )
+#define EPERM_KEY_EXCHANGE						\
+	__einfo_error ( EINFO_EPERM_KEY_EXCHANGE )
+#define EINFO_EPERM_KEY_EXCHANGE __einfo_uniqify ( EINFO_EPERM, 0x06, \
+	"ServerKeyExchange verification failed" )
 #define EPROTO_VERSION __einfo_error ( EINFO_EPROTO_VERSION )
 #define EINFO_EPROTO_VERSION						\
 	__einfo_uniqify ( EINFO_EPROTO, 0x01,				\
@@ -915,6 +924,44 @@
 	return NULL;
 }
 
+/**
+ * Look up public-key algorithm for a TLS signature identifier
+ *
+ * @v code		Signature and hash algorithm identifier
+ * @ret pubkey		Public key algorithm, or NULL if unrecognised
+ */
+static struct pubkey_algorithm *
+tls_signature_hash_pubkey ( struct tls_signature_hash_id code ) {
+	struct tls_signature_hash_algorithm *entry, *match = NULL;
+
+	/* Use the first entry whose signature field matches */
+	for_each_table_entry ( entry, TLS_SIG_HASH_ALGORITHMS ) {
+		if ( ( ! match ) && ( entry->code.signature == code.signature ) )
+			match = entry;
+	}
+
+	return ( match ? match->pubkey : NULL );
+}
+
+/**
+ * Look up digest algorithm for a TLS hash identifier
+ *
+ * @v code		Signature and hash algorithm identifier
+ * @ret digest		Digest algorithm, or NULL if unrecognised
+ */
+static struct digest_algorithm *
+tls_signature_hash_digest ( struct tls_signature_hash_id code ) {
+	struct tls_signature_hash_algorithm *entry, *match = NULL;
+
+	/* Use the first entry whose hash field matches */
+	for_each_table_entry ( entry, TLS_SIG_HASH_ALGORITHMS ) {
+		if ( ( ! match ) && ( entry->code.hash == code.hash ) )
+			match = entry;
+	}
+
+	return ( match ? match->digest : NULL );
+}
+
 /******************************************************************************
  *
  * Handshake verification
@@ -1278,6 +1325,226 @@
 };

 /**
+ * Transmit Client Key Exchange record using DHE key exchange
+ *
+ * The server's Diffie-Hellman parameters (p, g and Ys) and their
+ * signature are parsed from the stored ServerKeyExchange record, and
+ * the signature is verified using the pending cipher suite's public
+ * key algorithm.  A freshly generated private value is then used to
+ * calculate both the pre-master secret (from which the master secret
+ * and session keys are derived) and the client public value, which is
+ * sent to the server in a ClientKeyExchange record.
+ *
+ * @v tls		TLS connection
+ * @ret rc		Return status code
+ */
+static int tls_send_client_key_exchange_dhe ( struct tls_connection *tls ) {
+	struct tls_cipherspec *cipherspec = &tls->tx_cipherspec_pending;
+	struct pubkey_algorithm *pubkey;
+	struct digest_algorithm *digest;
+	int use_sig_hash = tls_version ( tls, TLS_VERSION_TLS_1_2 );
+	uint8_t private[ sizeof ( tls->client_random.random ) ];
+	const struct {
+		uint16_t len;
+		uint8_t data[0];
+	} __attribute__ (( packed )) *dh_val[3];
+	const struct {
+		struct tls_signature_hash_id sig_hash[use_sig_hash];
+		uint16_t signature_len;
+		uint8_t signature[0];
+	} __attribute__ (( packed )) *sig;
+	const void *data;
+	size_t remaining;
+	size_t frag_len;
+	unsigned int i;
+	int rc;
+
+	/* Parse ServerKeyExchange */
+	data = tls->server_key;
+	remaining = tls->server_key_len;
+	for ( i = 0 ; i < ( sizeof ( dh_val ) / sizeof ( dh_val[0] ) ) ; i++ ){
+		dh_val[i] = data;
+		if ( ( sizeof ( *dh_val[i] ) > remaining ) ||
+		     ( ntohs ( dh_val[i]->len ) > ( remaining -
+						    sizeof ( *dh_val[i] ) ) )){
+			DBGC ( tls, "TLS %p received underlength "
+			       "ServerKeyExchange\n", tls );
+			DBGC_HDA ( tls, 0, tls->server_key,
+				   tls->server_key_len );
+			rc = -EINVAL_KEY_EXCHANGE;
+			goto err_header;
+		}
+		/* Values are opaque<1..2^16-1>, so reject empty values */
+		if ( ntohs ( dh_val[i]->len ) == 0 ) {
+			DBGC ( tls, "TLS %p received empty ServerKeyExchange "
+			       "parameter\n", tls );
+			DBGC_HDA ( tls, 0, tls->server_key,
+				   tls->server_key_len );
+			rc = -EINVAL_KEY_EXCHANGE;
+			goto err_header;
+		}
+		frag_len = ( sizeof ( *dh_val[i] ) + ntohs ( dh_val[i]->len ));
+		data += frag_len;
+		remaining -= frag_len;
+	}
+	sig = data;
+	if ( ( sizeof ( *sig ) > remaining ) ||
+	     ( ntohs ( sig->signature_len ) > ( remaining -
+						sizeof ( *sig ) ) ) ) {
+		DBGC ( tls, "TLS %p received underlength ServerKeyExchange\n",
+		       tls );
+		DBGC_HDA ( tls, 0, tls->server_key, tls->server_key_len );
+		rc = -EINVAL_KEY_EXCHANGE;
+		goto err_header;
+	}
+
+	/* Identify signature and hash algorithm */
+	if ( use_sig_hash ) {
+		pubkey = tls_signature_hash_pubkey ( sig->sig_hash[0] );
+		digest = tls_signature_hash_digest ( sig->sig_hash[0] );
+		if ( ( ! pubkey ) || ( ! digest ) ) {
+			DBGC ( tls, "TLS %p ServerKeyExchange unsupported "
+			       "signature and hash algorithm\n", tls );
+			rc = -ENOTSUP_SIG_HASH;
+			goto err_sig_hash;
+		}
+		if ( pubkey != cipherspec->suite->pubkey ) {
+			DBGC ( tls, "TLS %p ServerKeyExchange incorrect "
+			       "signature algorithm %s (expected %s)\n", tls,
+			       pubkey->name, cipherspec->suite->pubkey->name );
+			rc = -EPERM_KEY_EXCHANGE;
+			goto err_sig_hash;
+		}
+	} else {
+		pubkey = cipherspec->suite->pubkey;
+		digest = &md5_sha1_algorithm;
+	}
+
+	/* Verify signature */
+	{
+		const void *signature = sig->signature;
+		size_t signature_len = ntohs ( sig->signature_len );
+		uint8_t ctx[digest->ctxsize];
+		uint8_t hash[digest->digestsize];
+
+		/* Calculate digest over both randoms and the DH parameters */
+		digest_init ( digest, ctx );
+		digest_update ( digest, ctx, &tls->client_random,
+				sizeof ( tls->client_random ) );
+		digest_update ( digest, ctx, tls->server_random,
+				sizeof ( tls->server_random ) );
+		digest_update ( digest, ctx, tls->server_key,
+				( tls->server_key_len - remaining ) );
+		digest_final ( digest, ctx, hash );
+
+		/* Verify signature */
+		if ( ( rc = pubkey_verify ( pubkey, cipherspec->pubkey_ctx,
+					    digest, hash, signature,
+					    signature_len ) ) != 0 ) {
+			DBGC ( tls, "TLS %p ServerKeyExchange failed "
+			       "verification\n", tls );
+			DBGC_HDA ( tls, 0, tls->server_key,
+				   tls->server_key_len );
+			rc = -EPERM_KEY_EXCHANGE;
+			goto err_verify;
+		}
+	}
+
+	/* Generate Diffie-Hellman private key */
+	if ( ( rc = tls_generate_random ( tls, private,
+					  sizeof ( private ) ) ) != 0 ) {
+		goto err_random;
+	}
+
+	/* Construct pre-master secret and ClientKeyExchange record */
+	{
+		typeof ( dh_val[0] ) dh_p = dh_val[0];
+		typeof ( dh_val[1] ) dh_g = dh_val[1];
+		typeof ( dh_val[2] ) dh_ys = dh_val[2];
+		size_t len = ntohs ( dh_p->len );
+		struct {
+			uint32_t type_length;
+			uint16_t dh_xs_len;
+			uint8_t dh_xs[len];
+		} __attribute__ (( packed )) *key_xchg;
+		struct {
+			uint8_t pre_master_secret[len];
+			typeof ( *key_xchg ) key_xchg;
+		} *dynamic;
+		uint8_t *pre_master_secret;
+
+		/* Allocate space */
+		dynamic = malloc ( sizeof ( *dynamic ) );
+		if ( ! dynamic ) {
+			rc = -ENOMEM;
+			goto err_alloc;
+		}
+		pre_master_secret = dynamic->pre_master_secret;
+		key_xchg = &dynamic->key_xchg;
+		key_xchg->type_length =
+			( cpu_to_le32 ( TLS_CLIENT_KEY_EXCHANGE ) |
+			  htonl ( sizeof ( *key_xchg ) -
+				  sizeof ( key_xchg->type_length ) ) );
+		key_xchg->dh_xs_len = htons ( len );
+
+		/* Calculate pre-master secret and client public value */
+		if ( ( rc = dhe_key ( dh_p->data, len,
+				      dh_g->data, ntohs ( dh_g->len ),
+				      dh_ys->data, ntohs ( dh_ys->len ),
+				      private, sizeof ( private ),
+				      key_xchg->dh_xs,
+				      pre_master_secret ) ) != 0 ) {
+			DBGC ( tls, "TLS %p could not calculate DHE key: %s\n",
+			       tls, strerror ( rc ) );
+			goto err_dhe_key;
+		}
+
+		/* Strip leading zeroes from pre-master secret */
+		while ( len && ( ! *pre_master_secret ) ) {
+			pre_master_secret++;
+			len--;
+		}
+
+		/* Generate master secret */
+		tls_generate_master_secret ( tls, pre_master_secret, len );
+
+		/* Generate keys */
+		if ( ( rc = tls_generate_keys ( tls ) ) != 0 ) {
+			DBGC ( tls, "TLS %p could not generate keys: %s\n",
+			       tls, strerror ( rc ) );
+			goto err_generate_keys;
+		}
+
+		/* Transmit Client Key Exchange record */
+		if ( ( rc = tls_send_handshake ( tls, key_xchg,
+						 sizeof ( *key_xchg ) ) ) !=0){
+			goto err_send_handshake;
+		}
+
+	err_send_handshake:
+	err_generate_keys:
+	err_dhe_key:
+		/* Scrub pre-master secret before releasing the memory */
+		memset ( dynamic, 0, sizeof ( *dynamic ) );
+		free ( dynamic );
+	}
+ err_alloc:
+ err_random:
+	/* Scrub Diffie-Hellman private key */
+	memset ( private, 0, sizeof ( private ) );
+ err_verify:
+ err_sig_hash:
+ err_header:
+	return rc;
+}
+
+/** DHE (ephemeral Diffie-Hellman) TLS key exchange algorithm */
+struct tls_key_exchange_algorithm tls_dhe_exchange_algorithm = {
+	.exchange = tls_send_client_key_exchange_dhe,
+	.name = "dhe",
+};
+
+/**
  * Transmit Client Key Exchange record
  *
  * @v tls		TLS connection