From ccd7adf0401d1c7e0cb3c7f874777d7d7246679b Mon Sep 17 00:00:00 2001
From: Tatsuhiro Tsujikawa <404610+tatsuhiro-t@users.noreply.github.com>
Date: Thu, 11 Mar 2021 23:07:55 +0900
Subject: [PATCH 31/43] QUIC: Add early data support

* QUIC: Add early data support

This commit adds SSL_set_quic_early_data_enabled to add early data
support to QUIC.
---
 doc/man3/SSL_CTX_set_quic_method.pod |  10 +-
 include/openssl/ssl.h.in             |   2 +
 ssl/ssl_lib.c                        |  15 +++
 ssl/ssl_quic.c                       |  76 ++++++++++---
 ssl/statem/statem_srvr.c             |  10 ++
 ssl/tls13_enc.c                      |  90 +++++++++++++---
 test/sslapitest.c                    | 156 +++++++++++++++++++++++++++
 util/libssl.num                      |   1 +
 8 files changed, 327 insertions(+), 33 deletions(-)

--- a/doc/man3/SSL_CTX_set_quic_method.pod
+++ b/doc/man3/SSL_CTX_set_quic_method.pod
@@ -17,7 +17,8 @@ SSL_is_quic,
 SSL_get_peer_quic_transport_version,
 SSL_get_quic_transport_version,
 SSL_set_quic_transport_version,
-SSL_set_quic_use_legacy_codepoint
+SSL_set_quic_use_legacy_codepoint,
+SSL_set_quic_early_data_enabled
 - QUIC support
 
 =head1 SYNOPSIS
@@ -47,6 +48,7 @@ SSL_set_quic_use_legacy_codepoint
  void SSL_set_quic_transport_version(SSL *ssl, int version);
  int SSL_get_quic_transport_version(const SSL *ssl);
  int SSL_get_peer_quic_transport_version(const SSL *ssl);
+ void SSL_set_quic_early_data_enabled(SSL *ssl, int enabled);
 
 =head1 DESCRIPTION
 
@@ -106,6 +108,12 @@ SSL_set_quic_transport_version().
 SSL_get_peer_quic_transport_version() returns the version the that was 
 negotiated.
 
+SSL_set_quic_early_data_enabled() enables QUIC early data if a nonzero
+value is passed.  Client must set a resumed session before calling
+this function.  Server must set 0xffffffffu to
+SSL_CTX_set_max_early_data() or SSL_set_max_early_data() so that a
+session ticket indicates that server is able to accept early data.
+
 =head1 NOTES
 
 These APIs are implementations of BoringSSL's QUIC APIs.
--- a/include/openssl/ssl.h.in
+++ b/include/openssl/ssl.h.in
@@ -2577,6 +2577,8 @@ __owur int SSL_get_peer_quic_transport_v
 
 int SSL_CIPHER_get_prf_nid(const SSL_CIPHER *c);
 
+void SSL_set_quic_early_data_enabled(SSL *ssl, int enabled);
+
 #  endif
 
 # ifdef  __cplusplus
--- a/ssl/ssl_lib.c
+++ b/ssl/ssl_lib.c
@@ -4015,6 +4015,21 @@ int SSL_do_handshake(SSL *s)
             ret = s->handshake_func(s);
         }
     }
+#ifndef OPENSSL_NO_QUIC
+    if (SSL_IS_QUIC(s) && ret == 1) {
+        if (s->server) {
+            if (s->early_data_state == SSL_EARLY_DATA_ACCEPTING) {
+                s->early_data_state = SSL_EARLY_DATA_FINISHED_READING;
+                s->rwstate = SSL_READING;
+                ret = 0;
+            }
+        } else if (s->early_data_state == SSL_EARLY_DATA_CONNECTING) {
+            s->early_data_state = SSL_EARLY_DATA_WRITE_RETRY;
+            s->rwstate = SSL_READING;
+            ret = 0;
+        }
+    }
+#endif
     return ret;
 }
 
--- a/ssl/ssl_quic.c
+++ b/ssl/ssl_quic.c
@@ -257,24 +257,46 @@ int quic_set_encryption_secrets(SSL *ssl
         return 1;
     }
 
-    md = ssl_handshake_md(ssl);
-    if (md == NULL) {
-        /* May not have selected cipher, yet */
-        const SSL_CIPHER *c = NULL;
-
-        /*
-         * It probably doesn't make sense to use an (external) PSK session,
-         * but in theory some kinds of external session caches could be
-         * implemented using it, so allow psksession to be used as well as
-         * the regular session.
-         */
-        if (ssl->session != NULL)
-            c = SSL_SESSION_get0_cipher(ssl->session);
-        else if (ssl->psksession != NULL)
+    if (level == ssl_encryption_early_data) {
+        const SSL_CIPHER *c = SSL_SESSION_get0_cipher(ssl->session);
+        if (ssl->early_data_state == SSL_EARLY_DATA_CONNECTING
+                && ssl->max_early_data > 0
+                && ssl->session->ext.max_early_data == 0) {
+            if (!ossl_assert(ssl->psksession != NULL
+                    && ssl->max_early_data ==
+                       ssl->psksession->ext.max_early_data)) {
+                SSLfatal(ssl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+                return 0;
+            }
             c = SSL_SESSION_get0_cipher(ssl->psksession);
+        }
+
+        if (c == NULL) {
+            SSLfatal(ssl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+            return 0;
+        }
 
-        if (c != NULL)
-            md = SSL_CIPHER_get_handshake_digest(c);
+        md = ssl_md(ssl->ctx, c->algorithm2);
+    } else {
+        md = ssl_handshake_md(ssl);
+        if (md == NULL) {
+            /* May not have selected cipher, yet */
+            const SSL_CIPHER *c = NULL;
+
+            /*
+             * It probably doesn't make sense to use an (external) PSK session,
+             * but in theory some kinds of external session caches could be
+             * implemented using it, so allow psksession to be used as well as
+             * the regular session.
+             */
+            if (ssl->session != NULL)
+                c = SSL_SESSION_get0_cipher(ssl->session);
+            else if (ssl->psksession != NULL)
+                c = SSL_SESSION_get0_cipher(ssl->psksession);
+
+            if (c != NULL)
+                md = SSL_CIPHER_get_handshake_digest(c);
+        }
     }
 
     if ((len = EVP_MD_size(md)) <= 0) {
@@ -330,3 +352,25 @@ int SSL_is_quic(SSL* ssl)
 {
     return SSL_IS_QUIC(ssl);
 }
+
+void SSL_set_quic_early_data_enabled(SSL *ssl, int enabled)
+{
+    if (!SSL_is_quic(ssl) || !SSL_in_before(ssl))
+        return;
+
+    if (!enabled) {
+      ssl->early_data_state = SSL_EARLY_DATA_NONE;
+      return;
+    }
+
+    if (ssl->server) {
+        ssl->early_data_state = SSL_EARLY_DATA_ACCEPTING;
+        return;
+    }
+
+    if ((ssl->session == NULL || ssl->session->ext.max_early_data == 0)
+            && ssl->psk_use_session_cb == NULL)
+        return;
+
+    ssl->early_data_state = SSL_EARLY_DATA_CONNECTING;
+}
--- a/ssl/statem/statem_srvr.c
+++ b/ssl/statem/statem_srvr.c
@@ -964,6 +964,16 @@ WORK_STATE ossl_statem_server_post_work(
                         SSL3_CC_APPLICATION | SSL3_CHANGE_CIPHER_SERVER_WRITE))
             /* SSLfatal() already called */
             return WORK_ERROR;
+
+#ifndef OPENSSL_NO_QUIC
+            if (SSL_IS_QUIC(s) && s->ext.early_data == SSL_EARLY_DATA_ACCEPTED) {
+                s->early_data_state = SSL_EARLY_DATA_FINISHED_READING;
+                if (!s->method->ssl3_enc->change_cipher_state(
+                        s, SSL3_CC_HANDSHAKE | SSL3_CHANGE_CIPHER_SERVER_READ))
+                    /* SSLfatal() already called */
+                    return WORK_ERROR;
+            }
+#endif
         }
         break;
 
--- a/ssl/tls13_enc.c
+++ b/ssl/tls13_enc.c
@@ -434,20 +434,76 @@ static int quic_change_cipher_state(SSL
     int is_server_write = ((which & SSL3_CHANGE_CIPHER_SERVER_WRITE) == SSL3_CHANGE_CIPHER_SERVER_WRITE);
     int is_early = (which & SSL3_CC_EARLY);
 
-    md = ssl_handshake_md(s);
-    if (!ssl3_digest_cached_records(s, 1)
-        || !ssl_handshake_hash(s, hash, sizeof(hash), &hashlen)) {
-        /* SSLfatal() already called */;
-        goto err;
-    }
+    if (is_early) {
+        EVP_MD_CTX *mdctx = NULL;
+        long handlen;
+        void *hdata;
+        unsigned int hashlenui;
+        const SSL_CIPHER *sslcipher = SSL_SESSION_get0_cipher(s->session);
+
+        handlen = BIO_get_mem_data(s->s3.handshake_buffer, &hdata);
+        if (handlen <= 0) {
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_R_BAD_HANDSHAKE_LENGTH);
+            goto err;
+        }
+
+        if (s->early_data_state == SSL_EARLY_DATA_CONNECTING
+                && s->max_early_data > 0
+                && s->session->ext.max_early_data == 0) {
+            /*
+             * If we are attempting to send early data, and we've decided to
+             * actually do it but max_early_data in s->session is 0 then we
+             * must be using an external PSK.
+             */
+            if (!ossl_assert(s->psksession != NULL
+                    && s->max_early_data ==
+                       s->psksession->ext.max_early_data)) {
+                SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+                goto err;
+            }
+            sslcipher = SSL_SESSION_get0_cipher(s->psksession);
+        }
+        if (sslcipher == NULL) {
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_R_BAD_PSK);
+            goto err;
+        }
+
+        /*
+         * We need to calculate the handshake digest using the digest from
+         * the session. We haven't yet selected our ciphersuite so we can't
+         * use ssl_handshake_md().
+         */
+        mdctx = EVP_MD_CTX_new();
+        if (mdctx == NULL) {
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_MALLOC_FAILURE);
+            goto err;
+        }
+        md = ssl_md(s->ctx, sslcipher->algorithm2);
+        if (md == NULL || !EVP_DigestInit_ex(mdctx, md, NULL)
+                || !EVP_DigestUpdate(mdctx, hdata, handlen)
+                || !EVP_DigestFinal_ex(mdctx, hash, &hashlenui)) {
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+            EVP_MD_CTX_free(mdctx);
+            goto err;
+        }
+        hashlen = hashlenui;
+        EVP_MD_CTX_free(mdctx);
+    } else {
+        md = ssl_handshake_md(s);
+        if (!ssl3_digest_cached_records(s, 1)
+                || !ssl_handshake_hash(s, hash, sizeof(hash), &hashlen)) {
+            /* SSLfatal() already called */;
+            goto err;
+        }
 
-    /* Ensure cast to size_t is safe */
-    hashleni = EVP_MD_size(md);
-    if (!ossl_assert(hashleni >= 0)) {
-        SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_EVP_LIB);
-        goto err;
+        /* Ensure cast to size_t is safe */
+        hashleni = EVP_MD_size(md);
+        if (!ossl_assert(hashleni >= 0)) {
+            SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_EVP_LIB);
+            goto err;
+        }
+        hashlen = (size_t)hashleni;
     }
-    hashlen = (size_t)hashleni;
 
     if (is_client_read || is_server_write) {
         if (is_handshake) {
@@ -553,10 +609,12 @@ static int quic_change_cipher_state(SSL
             }
         }
 
-        if (s->server)
-            s->quic_read_level = level;
-        else
-            s->quic_write_level = level;
+        if (level != ssl_encryption_early_data) {
+            if (s->server)
+                s->quic_read_level = level;
+            else
+                s->quic_write_level = level;
+        }
     }
 
     ret = 1;
--- a/test/sslapitest.c
+++ b/test/sslapitest.c
@@ -10966,6 +10966,159 @@ end:
     serverssl = NULL;
     return testresult;
 }
+
+# ifndef OSSL_NO_USABLE_TLS1_3
+/*
+ * Helper method to setup objects for QUIC early data test. Caller
+ * frees objects on error.
+ */
+static int quic_setupearly_data_test(SSL_CTX **cctx, SSL_CTX **sctx,
+                                     SSL **clientssl, SSL **serverssl,
+                                     SSL_SESSION **sess, int idx)
+{
+    static const char *server_str = "SERVER";
+    static const char *client_str = "CLIENT";
+
+    if (*sctx == NULL
+            && (!TEST_true(create_ssl_ctx_pair(libctx, TLS_server_method(),
+                                               TLS_client_method(),
+                                               TLS1_3_VERSION, 0,
+                                               sctx, cctx, cert, privkey))
+                || !TEST_true(SSL_CTX_set_quic_method(*sctx, &quic_method))
+                || !TEST_true(SSL_CTX_set_quic_method(*cctx, &quic_method))
+                || !TEST_true(SSL_CTX_set_max_early_data(*sctx, 0xffffffffu))))
+        return 0;
+
+    if (idx == 1) {
+        /* When idx == 1 we repeat the tests with read_ahead set */
+        SSL_CTX_set_read_ahead(*cctx, 1);
+        SSL_CTX_set_read_ahead(*sctx, 1);
+    } else if (idx == 2) {
+        /* When idx == 2 we are doing early_data with a PSK. Set up callbacks */
+        SSL_CTX_set_psk_use_session_callback(*cctx, use_session_cb);
+        SSL_CTX_set_psk_find_session_callback(*sctx, find_session_cb);
+        use_session_cb_cnt = 0;
+        find_session_cb_cnt = 0;
+        srvid = pskid;
+    }
+
+    if (!TEST_true(create_ssl_objects(*sctx, *cctx, serverssl, clientssl,
+                                      NULL, NULL))
+            || !TEST_true(SSL_set_quic_transport_params(*serverssl,
+                                                        (unsigned char*)server_str,
+                                                        strlen(server_str)+1))
+            || !TEST_true(SSL_set_quic_transport_params(*clientssl,
+                                                        (unsigned char*)client_str,
+                                                        strlen(client_str)+1))
+            || !TEST_true(SSL_set_app_data(*serverssl, *clientssl))
+            || !TEST_true(SSL_set_app_data(*clientssl, *serverssl)))
+        return 0;
+
+    /*
+     * For one of the run throughs (doesn't matter which one), we'll try sending
+     * some SNI data in the initial ClientHello. This will be ignored (because
+     * there is no SNI cb set up by the server), so it should not impact
+     * early_data.
+     */
+    if (idx == 1
+            && !TEST_true(SSL_set_tlsext_host_name(*clientssl, "localhost")))
+        return 0;
+
+    if (idx == 2) {
+        clientpsk = create_a_psk(*clientssl, SHA256_DIGEST_LENGTH);
+        if (!TEST_ptr(clientpsk)
+                || !TEST_true(SSL_SESSION_set_max_early_data(clientpsk,
+                                                             0xffffffffu))
+                || !TEST_true(SSL_SESSION_up_ref(clientpsk))) {
+            SSL_SESSION_free(clientpsk);
+            clientpsk = NULL;
+            return 0;
+        }
+        serverpsk = clientpsk;
+
+        if (sess != NULL) {
+            if (!TEST_true(SSL_SESSION_up_ref(clientpsk))) {
+                SSL_SESSION_free(clientpsk);
+                SSL_SESSION_free(serverpsk);
+                clientpsk = serverpsk = NULL;
+                return 0;
+            }
+            *sess = clientpsk;
+        }
+
+        SSL_set_quic_early_data_enabled(*serverssl, 1);
+        SSL_set_quic_early_data_enabled(*clientssl, 1);
+
+        return 1;
+    }
+
+    if (sess == NULL)
+        return 1;
+
+    if (!TEST_true(create_ssl_connection(*serverssl, *clientssl,
+                                         SSL_ERROR_NONE)))
+        return 0;
+
+    /* Deal with two NewSessionTickets */
+    if (!TEST_true(SSL_process_quic_post_handshake(*clientssl))
+            || !TEST_true(SSL_process_quic_post_handshake(*clientssl)))
+        return 0;
+
+    *sess = SSL_get1_session(*clientssl);
+    SSL_shutdown(*clientssl);
+    SSL_shutdown(*serverssl);
+    SSL_free(*serverssl);
+    SSL_free(*clientssl);
+    *serverssl = *clientssl = NULL;
+
+    if (!TEST_true(create_ssl_objects(*sctx, *cctx, serverssl,
+                                      clientssl, NULL, NULL))
+            || !TEST_true(SSL_set_session(*clientssl, *sess))
+            || !TEST_true(SSL_set_quic_transport_params(*serverssl,
+                                                        (unsigned char*)server_str,
+                                                        strlen(server_str)+1))
+            || !TEST_true(SSL_set_quic_transport_params(*clientssl,
+                                                        (unsigned char*)client_str,
+                                                        strlen(client_str)+1))
+            || !TEST_true(SSL_set_app_data(*serverssl, *clientssl))
+            || !TEST_true(SSL_set_app_data(*clientssl, *serverssl)))
+        return 0;
+
+    SSL_set_quic_early_data_enabled(*serverssl, 1);
+    SSL_set_quic_early_data_enabled(*clientssl, 1);
+
+    return 1;
+}
+
+static int test_quic_early_data(int tst)
+{
+    SSL_CTX *cctx = NULL, *sctx = NULL;
+    SSL *clientssl = NULL, *serverssl = NULL;
+    int testresult = 0;
+    SSL_SESSION *sess = NULL;
+
+    if (!TEST_true(quic_setupearly_data_test(&cctx, &sctx, &clientssl,
+                                             &serverssl, &sess, tst)))
+        goto end;
+
+    if (!TEST_true(create_ssl_connection(serverssl, clientssl, SSL_ERROR_NONE))
+            || !TEST_true(SSL_get_early_data_status(serverssl)))
+        goto end;
+
+    testresult = 1;
+
+ end:
+    SSL_SESSION_free(sess);
+    SSL_SESSION_free(clientpsk);
+    SSL_SESSION_free(serverpsk);
+    clientpsk = serverpsk = NULL;
+    SSL_free(serverssl);
+    SSL_free(clientssl);
+    SSL_CTX_free(sctx);
+    SSL_CTX_free(cctx);
+    return testresult;
+}
+# endif /* OSSL_NO_USABLE_TLS1_3 */
 #endif /* OPENSSL_NO_QUIC */
 
 static struct next_proto_st {
@@ -11687,6 +11840,9 @@ int setup_tests(void)
     ADD_ALL_TESTS(test_no_renegotiation, 2);
 #ifndef OPENSSL_NO_QUIC
     ADD_ALL_TESTS(test_quic_api, 9);
+# ifndef OSSL_NO_USABLE_TLS1_3
+    ADD_ALL_TESTS(test_quic_early_data, 3);
+# endif
 #endif
     return 1;
 
--- a/util/libssl.num
+++ b/util/libssl.num
@@ -535,3 +535,4 @@ SSL_set_quic_use_legacy_codepoint
 SSL_set_quic_transport_version          20012	3_0_0	EXIST::FUNCTION:QUIC
 SSL_get_peer_quic_transport_version     20013	3_0_0	EXIST::FUNCTION:QUIC
 SSL_get_quic_transport_version          20014	3_0_0	EXIST::FUNCTION:QUIC
+SSL_set_quic_early_data_enabled         20015	3_0_0	EXIST::FUNCTION:QUIC
