diff --git a/src/include/ipxe/http.h b/src/include/ipxe/http.h index d8f4ca5a..cf8c0c7f 100644 --- a/src/include/ipxe/http.h +++ b/src/include/ipxe/http.h @@ -18,6 +18,7 @@ FILE_LICENCE ( GPL2_OR_LATER ); extern int http_open_filter ( struct interface *xfer, struct uri *uri, unsigned int default_port, int ( * filter ) ( struct interface *, + const char *, struct interface ** ) ); #endif /* _IPXE_HTTP_H */ diff --git a/src/include/ipxe/tls.h b/src/include/ipxe/tls.h index c14e9210..90833781 100644 --- a/src/include/ipxe/tls.h +++ b/src/include/ipxe/tls.h @@ -74,6 +74,10 @@ struct tls_header { #define TLS_RSA_WITH_AES_128_CBC_SHA 0x002f #define TLS_RSA_WITH_AES_256_CBC_SHA 0x0035 +/* TLS extension types */ +#define TLS_SERVER_NAME 0 +#define TLS_SERVER_NAME_HOST_NAME 0 + /** TLS RX state machine state */ enum tls_rx_state { TLS_RX_HEADER = 0, @@ -133,6 +137,8 @@ struct tls_session { /** Reference counter */ struct refcnt refcnt; + /** Server name */ + const char *name; /** Plaintext stream */ struct interface plainstream; /** Ciphertext stream */ @@ -183,7 +189,7 @@ struct tls_session { void *rx_data; }; -extern int add_tls ( struct interface *xfer, +extern int add_tls ( struct interface *xfer, const char *name, struct interface **next ); #endif /* _IPXE_TLS_H */ diff --git a/src/net/tcp/httpcore.c b/src/net/tcp/httpcore.c index 69d27389..617f49b0 100644 --- a/src/net/tcp/httpcore.c +++ b/src/net/tcp/httpcore.c @@ -838,6 +838,7 @@ static struct process_descriptor http_process_desc = int http_open_filter ( struct interface *xfer, struct uri *uri, unsigned int default_port, int ( * filter ) ( struct interface *xfer, + const char *name, struct interface **next ) ) { struct http_request *http; struct sockaddr_tcpip server; @@ -865,7 +866,7 @@ int http_open_filter ( struct interface *xfer, struct uri *uri, server.st_port = htons ( uri_port ( http->uri, default_port ) ); socket = &http->socket; if ( filter ) { - if ( ( rc = filter ( socket, &socket ) ) != 0 ) + if ( ( rc = filter ( socket, uri->host, &socket ) ) != 0 ) goto err; } if ( ( rc = xfer_open_named_socket ( socket, SOCK_STREAM, diff --git a/src/net/tls.c b/src/net/tls.c index cbba0003..919025e7 100644 --- a/src/net/tls.c +++ b/src/net/tls.c @@ -691,6 +691,19 @@ static int tls_send_client_hello ( struct tls_session *tls ) { uint16_t cipher_suites[2]; uint8_t compression_methods_len; uint8_t compression_methods[1]; + uint16_t extensions_len; + struct { + uint16_t server_name_type; + uint16_t server_name_len; + struct { + uint16_t len; + struct { + uint8_t type; + uint16_t len; + uint8_t name[ strlen ( tls->name ) ]; + } __attribute__ (( packed )) list[1]; + } __attribute__ (( packed )) server_name; + } __attribute__ (( packed )) extensions; } __attribute__ (( packed )) hello; memset ( &hello, 0, sizeof ( hello ) ); @@ -703,6 +716,17 @@ static int tls_send_client_hello ( struct tls_session *tls ) { hello.cipher_suites[0] = htons ( TLS_RSA_WITH_AES_128_CBC_SHA ); hello.cipher_suites[1] = htons ( TLS_RSA_WITH_AES_256_CBC_SHA ); hello.compression_methods_len = sizeof ( hello.compression_methods ); + hello.extensions_len = htons ( sizeof ( hello.extensions ) ); + hello.extensions.server_name_type = htons ( TLS_SERVER_NAME ); + hello.extensions.server_name_len + = htons ( sizeof ( hello.extensions.server_name ) ); + hello.extensions.server_name.len + = htons ( sizeof ( hello.extensions.server_name.list ) ); + hello.extensions.server_name.list[0].type = TLS_SERVER_NAME_HOST_NAME; + hello.extensions.server_name.list[0].len + = htons ( sizeof ( hello.extensions.server_name.list[0].name )); + memcpy ( hello.extensions.server_name.list[0].name, tls->name, + sizeof ( hello.extensions.server_name.list[0].name ) ); return tls_send_handshake ( tls, &hello, sizeof ( hello ) ); } @@ -881,8 +905,8 @@ static int tls_new_server_hello ( struct tls_session *tls, int rc; /* Sanity check */ - if ( end != ( data + len ) ) { - DBGC ( tls, "TLS %p received overlength Server Hello\n", tls ); + if ( end > ( data + len ) ) { + DBGC ( tls, "TLS %p received underlength Server Hello\n", tls ); DBGC_HD ( tls, data, len ); return -EINVAL; } @@ -1805,7 +1829,8 @@ static struct process_descriptor tls_process_desc = ****************************************************************************** */ -int add_tls ( struct interface *xfer, struct interface **next ) { +int add_tls ( struct interface *xfer, const char *name, + struct interface **next ) { struct tls_session *tls; int rc; @@ -1817,6 +1842,7 @@ int add_tls ( struct interface *xfer, struct interface **next ) { } memset ( tls, 0, sizeof ( *tls ) ); ref_init ( &tls->refcnt, free_tls ); + tls->name = name; intf_init ( &tls->plainstream, &tls_plainstream_desc, &tls->refcnt ); intf_init ( &tls->cipherstream, &tls_cipherstream_desc, &tls->refcnt ); tls->version = TLS_VERSION_TLS_1_1;