From ad33923b6596437c1e63954decd6cf261bfa5dc7 Mon Sep 17 00:00:00 2001
From: Todd Short <tshort@akamai.com>
Date: Fri, 30 Aug 2019 11:17:58 -0400
Subject: [PATCH 14/43] QUIC: Move QUIC code out of tls13_change_cipher_state()

Create quic_change_cipher_state() that does the minimal required
to generate the QUIC secrets. (e.g. encryption contexts are not
initialized).
---
 ssl/ssl_quic.c  |   7 --
 ssl/tls13_enc.c | 217 ++++++++++++++++++++++++++++++------------------
 2 files changed, 138 insertions(+), 86 deletions(-)

--- a/ssl/ssl_quic.c
+++ b/ssl/ssl_quic.c
@@ -199,7 +199,6 @@ int quic_set_encryption_secrets(SSL *ssl
     uint8_t *s2c_secret = NULL;
     size_t len;
     const EVP_MD *md;
-    static const unsigned char zeros[EVP_MAX_MD_SIZE];
 
     if (!SSL_IS_QUIC(ssl))
         return 1;
@@ -240,12 +239,6 @@ int quic_set_encryption_secrets(SSL *ssl
         return 0;
     }
 
-    /* In some cases, we want to set the secret only when BOTH are non-zero */
-    if (c2s_secret != NULL && s2c_secret != NULL
-            && !memcmp(c2s_secret, zeros, len)
-            && !memcmp(s2c_secret, zeros, len))
-        return 1;
-
     if (ssl->server) {
         if (!ssl->quic_method->set_encryption_secrets(ssl, level, c2s_secret,
                                                       s2c_secret, len)) {
--- a/ssl/tls13_enc.c
+++ b/ssl/tls13_enc.c
@@ -400,27 +400,144 @@ static int derive_secret_key_and_iv(SSL
     return 1;
 }
 
-int tls13_change_cipher_state(SSL *s, int which)
-{
 #ifdef CHARSET_EBCDIC
-  static const unsigned char client_early_traffic[]       = {0x63, 0x20, 0x65, 0x20,       /*traffic*/0x74, 0x72, 0x61, 0x66, 0x66, 0x69, 0x63, 0x00};
-  static const unsigned char client_handshake_traffic[]   = {0x63, 0x20, 0x68, 0x73, 0x20, /*traffic*/0x74, 0x72, 0x61, 0x66, 0x66, 0x69, 0x63, 0x00};
-  static const unsigned char client_application_traffic[] = {0x63, 0x20, 0x61, 0x70, 0x20, /*traffic*/0x74, 0x72, 0x61, 0x66, 0x66, 0x69, 0x63, 0x00};
-  static const unsigned char server_handshake_traffic[]   = {0x73, 0x20, 0x68, 0x73, 0x20, /*traffic*/0x74, 0x72, 0x61, 0x66, 0x66, 0x69, 0x63, 0x00};
-  static const unsigned char server_application_traffic[] = {0x73, 0x20, 0x61, 0x70, 0x20, /*traffic*/0x74, 0x72, 0x61, 0x66, 0x66, 0x69, 0x63, 0x00};
-  static const unsigned char exporter_master_secret[] = {0x65, 0x78, 0x70, 0x20,                    /* master*/  0x6D, 0x61, 0x73, 0x74, 0x65, 0x72, 0x00};
-  static const unsigned char resumption_master_secret[] = {0x72, 0x65, 0x73, 0x20,                  /* master*/  0x6D, 0x61, 0x73, 0x74, 0x65, 0x72, 0x00};
-  static const unsigned char early_exporter_master_secret[] = {0x65, 0x20, 0x65, 0x78, 0x70, 0x20,  /* master*/  0x6D, 0x61, 0x73, 0x74, 0x65, 0x72, 0x00};
+static const unsigned char client_early_traffic[]       = {0x63, 0x20, 0x65, 0x20,       /*traffic*/0x74, 0x72, 0x61, 0x66, 0x66, 0x69, 0x63, 0x00};
+static const unsigned char client_handshake_traffic[]   = {0x63, 0x20, 0x68, 0x73, 0x20, /*traffic*/0x74, 0x72, 0x61, 0x66, 0x66, 0x69, 0x63, 0x00};
+static const unsigned char client_application_traffic[] = {0x63, 0x20, 0x61, 0x70, 0x20, /*traffic*/0x74, 0x72, 0x61, 0x66, 0x66, 0x69, 0x63, 0x00};
+static const unsigned char server_handshake_traffic[]   = {0x73, 0x20, 0x68, 0x73, 0x20, /*traffic*/0x74, 0x72, 0x61, 0x66, 0x66, 0x69, 0x63, 0x00};
+static const unsigned char server_application_traffic[] = {0x73, 0x20, 0x61, 0x70, 0x20, /*traffic*/0x74, 0x72, 0x61, 0x66, 0x66, 0x69, 0x63, 0x00};
+static const unsigned char exporter_master_secret[] = {0x65, 0x78, 0x70, 0x20,                    /* master*/  0x6D, 0x61, 0x73, 0x74, 0x65, 0x72, 0x00};
+static const unsigned char resumption_master_secret[] = {0x72, 0x65, 0x73, 0x20,                  /* master*/  0x6D, 0x61, 0x73, 0x74, 0x65, 0x72, 0x00};
+static const unsigned char early_exporter_master_secret[] = {0x65, 0x20, 0x65, 0x78, 0x70, 0x20,  /* master*/  0x6D, 0x61, 0x73, 0x74, 0x65, 0x72, 0x00};
 #else
-    static const unsigned char client_early_traffic[] = "c e traffic";
-    static const unsigned char client_handshake_traffic[] = "c hs traffic";
-    static const unsigned char client_application_traffic[] = "c ap traffic";
-    static const unsigned char server_handshake_traffic[] = "s hs traffic";
-    static const unsigned char server_application_traffic[] = "s ap traffic";
-    static const unsigned char exporter_master_secret[] = "exp master";
-    static const unsigned char resumption_master_secret[] = "res master";
-    static const unsigned char early_exporter_master_secret[] = "e exp master";
+static const unsigned char client_early_traffic[] = "c e traffic";
+static const unsigned char client_handshake_traffic[] = "c hs traffic";
+static const unsigned char client_application_traffic[] = "c ap traffic";
+static const unsigned char server_handshake_traffic[] = "s hs traffic";
+static const unsigned char server_application_traffic[] = "s ap traffic";
+static const unsigned char exporter_master_secret[] = "exp master";
+static const unsigned char resumption_master_secret[] = "res master";
+static const unsigned char early_exporter_master_secret[] = "e exp master";
 #endif
+#ifndef OPENSSL_NO_QUIC
+static int quic_change_cipher_state(SSL *s, int which)
+{
+    unsigned char hash[EVP_MAX_MD_SIZE];
+    size_t hashlen = 0;
+    int hashleni;
+    int ret = 0;
+    const EVP_MD *md = NULL;
+    OSSL_ENCRYPTION_LEVEL level = ssl_encryption_initial;
+    int is_handshake = ((which & SSL3_CC_HANDSHAKE) == SSL3_CC_HANDSHAKE);
+    int is_client_read = ((which & SSL3_CHANGE_CIPHER_CLIENT_READ) == SSL3_CHANGE_CIPHER_CLIENT_READ);
+    int is_server_write = ((which & SSL3_CHANGE_CIPHER_SERVER_WRITE) == SSL3_CHANGE_CIPHER_SERVER_WRITE);
+    int is_early = (which & SSL3_CC_EARLY);
+
+    md = ssl_handshake_md(s);
+    if (!ssl3_digest_cached_records(s, 1)
+        || !ssl_handshake_hash(s, hash, sizeof(hash), &hashlen)) {
+        /* SSLfatal() already called */;
+        goto err;
+    }
+
+    /* Ensure cast to size_t is safe */
+    hashleni = EVP_MD_size(md);
+    if (!ossl_assert(hashleni >= 0)) {
+        SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_EVP_LIB);
+        goto err;
+    }
+    hashlen = (size_t)hashleni;
+
+    if (is_handshake)
+        level = ssl_encryption_handshake;
+    else
+        level = ssl_encryption_application;
+
+    if (is_client_read || is_server_write) {
+        if (is_handshake) {
+            level = ssl_encryption_handshake;
+
+            if (!tls13_hkdf_expand(s, md, s->handshake_secret, client_handshake_traffic,
+                                   sizeof(client_handshake_traffic)-1, hash, hashlen,
+                                   s->client_hand_traffic_secret, hashlen, 1)) {
+                /* SSLfatal() already called */
+                goto err;
+            }
+            if (!ssl_log_secret(s, CLIENT_HANDSHAKE_LABEL, s->client_hand_traffic_secret, hashlen)) {
+                /* SSLfatal() already called */
+                goto err;
+            }
+
+            if (!tls13_hkdf_expand(s, md, s->handshake_secret, server_handshake_traffic,
+                                   sizeof(server_handshake_traffic)-1, hash, hashlen,
+                                   s->server_hand_traffic_secret, hashlen, 1)) {
+                /* SSLfatal() already called */
+                goto err;
+            }
+            if (!ssl_log_secret(s, SERVER_HANDSHAKE_LABEL, s->server_hand_traffic_secret, hashlen)) {
+                /* SSLfatal() already called */
+                goto err;
+            }
+        } else {
+            level = ssl_encryption_application;
+
+            if (!tls13_hkdf_expand(s, md, s->master_secret, client_application_traffic,
+                                   sizeof(client_application_traffic)-1, hash, hashlen,
+                                   s->client_app_traffic_secret, hashlen, 1)) {
+                /* SSLfatal() already called */
+                goto err;
+            }
+            if (!ssl_log_secret(s, CLIENT_APPLICATION_LABEL, s->client_app_traffic_secret, hashlen)) {
+                /* SSLfatal() already called */
+                goto err;
+            }
+
+            if (!tls13_hkdf_expand(s, md, s->master_secret, server_application_traffic,
+                                   sizeof(server_application_traffic)-1, hash, hashlen,
+                                   s->server_app_traffic_secret, hashlen, 1)) {
+                /* SSLfatal() already called */
+                goto err;
+            }
+            if (!ssl_log_secret(s, SERVER_APPLICATION_LABEL, s->server_app_traffic_secret, hashlen)) {
+                /* SSLfatal() already called */
+                goto err;
+            }
+        }
+        if (s->server)
+            s->quic_write_level = level;
+        else
+            s->quic_read_level = level;
+    } else {
+        if (is_early) {
+            level = ssl_encryption_early_data;
+
+            if (!tls13_hkdf_expand(s, md, s->early_secret, client_early_traffic,
+                                   sizeof(client_early_traffic)-1, hash, hashlen,
+                                   s->client_early_traffic_secret, hashlen, 1)) {
+                /* SSLfatal() already called */
+                goto err;
+            }
+            if (!ssl_log_secret(s, CLIENT_EARLY_LABEL, s->client_early_traffic_secret, hashlen)) {
+                /* SSLfatal() already called */
+                goto err;
+            }
+        }
+        if (s->server)
+            s->quic_read_level = level;
+        else
+            s->quic_write_level = level;
+    }
+
+    if (level != ssl_encryption_initial && !quic_set_encryption_secrets(s, level))
+        goto err;
+
+    ret = 1;
+ err:
+    return ret;
+}
+#endif /* OPENSSL_NO_QUIC */
+int tls13_change_cipher_state(SSL *s, int which)
+{
     unsigned char *iv;
     unsigned char key[EVP_MAX_KEY_LENGTH];
     unsigned char secret[EVP_MAX_MD_SIZE];
@@ -440,8 +557,10 @@ int tls13_change_cipher_state(SSL *s, in
     ktls_crypto_info_t crypto_info;
     BIO *bio;
 #endif
+
 #ifndef OPENSSL_NO_QUIC
-    OSSL_ENCRYPTION_LEVEL level = ssl_encryption_initial;
+    if (SSL_IS_QUIC(s))
+        return quic_change_cipher_state(s, which);
 #endif
 
     if (which & SSL3_CC_READ) {
@@ -488,9 +607,6 @@ int tls13_change_cipher_state(SSL *s, in
             label = client_early_traffic;
             labellen = sizeof(client_early_traffic) - 1;
             log_label = CLIENT_EARLY_LABEL;
-#ifndef OPENSSL_NO_QUIC
-            level = ssl_encryption_early_data;
-#endif
 
             handlen = BIO_get_mem_data(s->s3.handshake_buffer, &hdata);
             if (handlen <= 0) {
@@ -567,14 +683,6 @@ int tls13_change_cipher_state(SSL *s, in
                 /* SSLfatal() already called */
                 goto err;
             }
-#ifndef OPENSSL_NO_QUIC
-            if (SSL_IS_QUIC(s)) {
-                if (s->server)
-                    s->quic_read_level = ssl_encryption_early_data;
-                else
-                    s->quic_write_level = ssl_encryption_early_data;
-            }
-#endif
         } else if (which & SSL3_CC_HANDSHAKE) {
             insecret = s->handshake_secret;
             finsecret = s->client_finished_secret;
@@ -582,15 +690,6 @@ int tls13_change_cipher_state(SSL *s, in
             label = client_handshake_traffic;
             labellen = sizeof(client_handshake_traffic) - 1;
             log_label = CLIENT_HANDSHAKE_LABEL;
-#ifndef OPENSSL_NO_QUIC
-            if (SSL_IS_QUIC(s)) {
-                level = ssl_encryption_handshake;
-                if (s->server)
-                    s->quic_read_level = ssl_encryption_handshake;
-                else
-                    s->quic_write_level = ssl_encryption_handshake;
-            }
-#endif
             /*
              * The handshake hash used for the server read/client write handshake
              * traffic secret is the same as the hash for the server
@@ -613,15 +712,6 @@ int tls13_change_cipher_state(SSL *s, in
              * previously saved value.
              */
             hash = s->server_finished_hash;
-#ifndef OPENSSL_NO_QUIC
-            if (SSL_IS_QUIC(s)) {
-                level = ssl_encryption_application; /* ??? */
-                if (s->server)
-                    s->quic_read_level = ssl_encryption_application;
-                else
-                    s->quic_write_level = ssl_encryption_application;
-            }
-#endif
         }
     } else {
         /* Early data never applies to client-read/server-write */
@@ -632,29 +722,11 @@ int tls13_change_cipher_state(SSL *s, in
             label = server_handshake_traffic;
             labellen = sizeof(server_handshake_traffic) - 1;
             log_label = SERVER_HANDSHAKE_LABEL;
-#ifndef OPENSSL_NO_QUIC
-            if (SSL_IS_QUIC(s)) {
-                level = ssl_encryption_handshake;
-                if (s->server)
-                    s->quic_write_level = ssl_encryption_handshake;
-                else
-                    s->quic_read_level = ssl_encryption_handshake;
-            }
-#endif
         } else {
             insecret = s->master_secret;
             label = server_application_traffic;
             labellen = sizeof(server_application_traffic) - 1;
             log_label = SERVER_APPLICATION_LABEL;
-#ifndef OPENSSL_NO_QUIC
-            if (SSL_IS_QUIC(s)) {
-                level = ssl_encryption_application;
-                if (s->server)
-                    s->quic_write_level = ssl_encryption_application;
-                else
-                    s->quic_read_level = ssl_encryption_application;
-            }
-#endif
         }
     }
 
@@ -723,14 +795,6 @@ int tls13_change_cipher_state(SSL *s, in
         }
     } else if (label == client_application_traffic)
         memcpy(s->client_app_traffic_secret, secret, hashlen);
-#ifndef OPENSSL_NO_QUIC
-    else if (label == client_handshake_traffic)
-        memcpy(s->client_hand_traffic_secret, secret, hashlen);
-    else if (label == server_handshake_traffic)
-        memcpy(s->server_hand_traffic_secret, secret, hashlen);
-    else if (label == client_early_traffic)
-        memcpy(s->client_early_traffic_secret, secret, hashlen);
-#endif
 
     if (!ssl_log_secret(s, log_label, secret, hashlen)) {
         /* SSLfatal() already called */
@@ -791,11 +855,6 @@ skip_ktls:
 # endif
 #endif
 
-#ifndef OPENSSL_NO_QUIC
-    if (!quic_set_encryption_secrets(s, level))
-        goto err;
-#endif
-
     ret = 1;
  err:
     if ((which & SSL3_CC_EARLY) != 0) {
