diff --git a/src/include/ipxe/tls.h b/src/include/ipxe/tls.h index a491b795..0d1f2d85 100644 --- a/src/include/ipxe/tls.h +++ b/src/include/ipxe/tls.h @@ -201,6 +201,10 @@ struct tls_session { uint8_t handshake_md5_sha1_ctx[MD5_SHA1_CTX_SIZE]; /** SHA256 context for handshake verification */ uint8_t handshake_sha256_ctx[SHA256_CTX_SIZE]; + /** Digest algorithm used for handshake verification */ + struct digest_algorithm *handshake_digest; + /** Digest algorithm context used for handshake verification */ + uint8_t *handshake_ctx; /** TX sequence number */ uint64_t tx_seq; diff --git a/src/net/tls.c b/src/net/tls.c index 2580008d..56b01f23 100644 --- a/src/net/tls.c +++ b/src/net/tls.c @@ -697,22 +697,6 @@ static void tls_add_handshake ( struct tls_session *tls, data, len ); } -/** - * Size of handshake output buffer - * - * @v tls TLS session - */ -static size_t tls_verify_handshake_len ( struct tls_session *tls ) { - - if ( tls->version >= TLS_VERSION_TLS_1_2 ) { - /* Use SHA-256 for TLSv1.2 and later */ - return SHA256_DIGEST_SIZE; - } else { - /* Use MD5+SHA1 for TLSv1.1 and earlier */ - return MD5_SHA1_DIGEST_SIZE; - } -} - /** * Calculate handshake verification hash * @@ -723,22 +707,11 @@ static size_t tls_verify_handshake_len ( struct tls_session *tls ) { * messages seen so far. */ static void tls_verify_handshake ( struct tls_session *tls, void *out ) { - union { - uint8_t md5_sha1[MD5_SHA1_CTX_SIZE]; - uint8_t sha256[SHA256_CTX_SIZE]; - } ctx; + struct digest_algorithm *digest = tls->handshake_digest; + uint8_t ctx[ digest->ctxsize ]; - if ( tls->version >= TLS_VERSION_TLS_1_2 ) { - /* Use SHA-256 for TLSv1.2 and later */ - memcpy ( ctx.sha256, tls->handshake_sha256_ctx, - sizeof ( ctx.sha256 ) ); - digest_final ( &sha256_algorithm, ctx.sha256, out ); - } else { - /* Use MD5+SHA1 for TLSv1.1 and earlier */ - memcpy ( ctx.md5_sha1, tls->handshake_md5_sha1_ctx, - sizeof ( ctx.md5_sha1 ) ); - digest_final ( &md5_sha1_algorithm, ctx.md5_sha1, out ); - } + memcpy ( ctx, tls->handshake_ctx, sizeof ( ctx ) ); + digest_final ( digest, ctx, out ); } /****************************************************************************** @@ -915,20 +888,21 @@ static int tls_send_change_cipher ( struct tls_session *tls ) { * @ret rc Return status code */ static int tls_send_finished ( struct tls_session *tls ) { + struct digest_algorithm *digest = tls->handshake_digest; struct { uint32_t type_length; uint8_t verify_data[12]; } __attribute__ (( packed )) finished; - uint8_t digest[ tls_verify_handshake_len ( tls ) ]; + uint8_t digest_out[ digest->digestsize ]; memset ( &finished, 0, sizeof ( finished ) ); finished.type_length = ( cpu_to_le32 ( TLS_FINISHED ) | htonl ( sizeof ( finished ) - sizeof ( finished.type_length ) ) ); - tls_verify_handshake ( tls, digest ); + tls_verify_handshake ( tls, digest_out ); tls_prf_label ( tls, &tls->master_secret, sizeof ( tls->master_secret ), finished.verify_data, sizeof ( finished.verify_data ), - "client finished", digest, sizeof ( digest ) ); + "client finished", digest_out, sizeof ( digest_out ) ); return tls_send_handshake ( tls, &finished, sizeof ( finished ) ); } @@ -1052,6 +1026,14 @@ static int tls_new_server_hello ( struct tls_session *tls, DBGC ( tls, "TLS %p using protocol version %d.%d\n", tls, ( version >> 8 ), ( version & 0xff ) ); + /* Use MD5+SHA1 digest algorithm for handshake verification + * for versions earlier than TLSv1.2. + */ + if ( tls->version < TLS_VERSION_TLS_1_2 ) { + tls->handshake_digest = &md5_sha1_algorithm; + tls->handshake_ctx = tls->handshake_md5_sha1_ctx; + } + /* Copy out server random bytes */ memcpy ( &tls->server_random, &hello_a->random, sizeof ( tls->server_random ) ); @@ -1254,12 +1236,13 @@ static int tls_new_server_hello_done ( struct tls_session *tls, */ static int tls_new_finished ( struct tls_session *tls, const void *data, size_t len ) { + struct digest_algorithm *digest = tls->handshake_digest; const struct { uint8_t verify_data[12]; char next[0]; } __attribute__ (( packed )) *finished = data; const void *end = finished->next; - uint8_t digest[ tls_verify_handshake_len ( tls ) ]; + uint8_t digest_out[ digest->digestsize ]; uint8_t verify_data[ sizeof ( finished->verify_data ) ]; /* Sanity check */ @@ -1270,10 +1253,10 @@ static int tls_new_finished ( struct tls_session *tls, } /* Verify data */ - tls_verify_handshake ( tls, digest ); + tls_verify_handshake ( tls, digest_out ); tls_prf_label ( tls, &tls->master_secret, sizeof ( tls->master_secret ), verify_data, sizeof ( verify_data ), "server finished", - digest, sizeof ( digest ) ); + digest_out, sizeof ( digest_out ) ); if ( memcmp ( verify_data, finished->verify_data, sizeof ( verify_data ) ) != 0 ) { DBGC ( tls, "TLS %p verification failed\n", tls ); @@ -2014,8 +1997,8 @@ static void tls_tx_step ( struct tls_session *tls ) { } else if ( tls->tx_pending & TLS_TX_CLIENT_KEY_EXCHANGE ) { /* Send Client Key Exchange */ if ( ( rc = tls_send_client_key_exchange ( tls ) ) != 0 ) { - DBGC ( tls, "TLS %p could send Client Key Exchange: " - "%s\n", tls, strerror ( rc ) ); + DBGC ( tls, "TLS %p could not send Client Key " + "Exchange: %s\n", tls, strerror ( rc ) ); goto err; } tls->tx_pending &= ~TLS_TX_CLIENT_KEY_EXCHANGE; @@ -2099,6 +2082,8 @@ int add_tls ( struct interface *xfer, const char *name, } digest_init ( &md5_sha1_algorithm, tls->handshake_md5_sha1_ctx ); digest_init ( &sha256_algorithm, tls->handshake_sha256_ctx ); + tls->handshake_digest = &sha256_algorithm; + tls->handshake_ctx = tls->handshake_sha256_ctx; tls->tx_pending = TLS_TX_CLIENT_HELLO; process_init ( &tls->process, &tls_process_desc, &tls->refcnt );