From 8771ca8598c700037896929db7deeef3687bd897 Mon Sep 17 00:00:00 2001
From: Todd Short <tshort@akamai.com>
Date: Thu, 6 May 2021 16:24:03 -0400
Subject: [PATCH 35/43] QUIC: Break up header/body processing

As DTLS has changed, so too must QUIC.
---
 ssl/statem/statem.c       |  9 ++++++---
 ssl/statem/statem_local.h |  3 ++-
 ssl/statem/statem_quic.c  | 25 +++++++++++++++----------
 3 files changed, 23 insertions(+), 14 deletions(-)

--- a/ssl/statem/statem.c
+++ b/ssl/statem/statem.c
@@ -586,7 +586,7 @@ static SUB_STATE_RETURN read_state_machi
 #ifndef OPENSSL_NO_QUIC
             } else if (SSL_IS_QUIC(s)) {
                 /* QUIC behaves like DTLS -- all in one go. */
-                ret = quic_get_message(s, &mt, &len);
+                ret = quic_get_message(s, &mt);
 #endif
             } else {
                 ret = tls_get_message_header(s, &mt);
@@ -636,8 +636,11 @@ static SUB_STATE_RETURN read_state_machi
                  * opportunity to do any further processing.
                  */
                 ret = dtls_get_message_body(s, &len);
-            } else if (!SSL_IS_QUIC(s)) {
-                /* We already got this above for QUIC */
+#ifndef OPENSSL_NO_QUIC
+            } else if (SSL_IS_QUIC(s)) {
+                ret = quic_get_message_body(s, &len);
+#endif
+            } else {
                 ret = tls_get_message_body(s, &len);
             }
             if (ret == 0) {
--- a/ssl/statem/statem_local.h
+++ b/ssl/statem/statem_local.h
@@ -105,7 +105,8 @@ __owur int tls_get_message_body(SSL *s,
 __owur int dtls_get_message(SSL *s, int *mt);
 __owur int dtls_get_message_body(SSL *s, size_t *len);
 #ifndef OPENSSL_NO_QUIC
-__owur int quic_get_message(SSL *s, int *mt, size_t *len);
+__owur int quic_get_message(SSL *s, int *mt);
+__owur int quic_get_message_body(SSL *s, size_t *len);
 #endif
 
 /* Message construction and processing functions */
--- a/ssl/statem/statem_quic.c
+++ b/ssl/statem/statem_quic.c
@@ -11,7 +11,7 @@
 #include "statem_local.h"
 #include "internal/cryptlib.h"
 
-int quic_get_message(SSL *s, int *mt, size_t *len)
+int quic_get_message(SSL *s, int *mt)
 {
     size_t l;
     QUIC_DATA *qd = s->quic_input_data_head;
@@ -19,26 +19,26 @@ int quic_get_message(SSL *s, int *mt, si
 
     if (qd == NULL) {
         s->rwstate = SSL_READING;
-        *mt = *len = 0;
+        *mt = 0;
         return 0;
     }
 
     if (!ossl_assert(qd->length >= SSL3_HM_HEADER_LENGTH)) {
         SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_R_BAD_LENGTH);
-        *mt = *len = 0;
+        *mt = 0;
         return 0;
     }
 
     /* This is where we check for the proper level, not when data is given */
     if (qd->level != s->quic_read_level) {
         SSLfatal(s, SSL_AD_INTERNAL_ERROR, SSL_R_WRONG_ENCRYPTION_LEVEL_RECEIVED);
-        *mt = *len = 0;
+        *mt = 0;
         return 0;
     }
 
     if (!BUF_MEM_grow_clean(s->init_buf, (int)qd->length)) {
         SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_BUF_LIB);
-        *mt = *len = 0;
+        *mt = 0;
         return 0;
     }
 
@@ -53,28 +53,32 @@ int quic_get_message(SSL *s, int *mt, si
     s->s3.tmp.message_type = *mt = *(s->init_buf->data);
     p = (uint8_t*)s->init_buf->data + 1;
     n2l3(p, l);
-    s->init_num = s->s3.tmp.message_size = *len = l;
+    s->init_num = s->s3.tmp.message_size = l;
     s->init_msg = s->init_buf->data + SSL3_HM_HEADER_LENGTH;
 
+    return 1;
+}
+
+int quic_get_message_body(SSL *s, size_t *len)
+{
     /* No CCS in QUIC/TLSv1.3? */
-    if (*mt == SSL3_MT_CHANGE_CIPHER_SPEC) {
+    if (s->s3.tmp.message_type == SSL3_MT_CHANGE_CIPHER_SPEC) {
         SSLfatal(s, SSL_AD_UNEXPECTED_MESSAGE, SSL_R_CCS_RECEIVED_EARLY);
         *len = 0;
         return 0;
     }
     /* No KeyUpdate in QUIC */
-    if (*mt == SSL3_MT_KEY_UPDATE) {
+    if (s->s3.tmp.message_type == SSL3_MT_KEY_UPDATE) {
         SSLfatal(s, SSL_AD_UNEXPECTED_MESSAGE, SSL_R_UNEXPECTED_MESSAGE);
         *len = 0;
         return 0;
     }
 
-
     /*
      * If receiving Finished, record MAC of prior handshake messages for
      * Finished verification.
      */
-    if (*mt == SSL3_MT_FINISHED && !ssl3_take_mac(s)) {
+    if (s->s3.tmp.message_type == SSL3_MT_FINISHED && !ssl3_take_mac(s)) {
         /* SSLfatal() already called */
         *len = 0;
         return 0;
@@ -108,5 +112,6 @@ int quic_get_message(SSL *s, int *mt, si
                         (size_t)s->init_num + SSL3_HM_HEADER_LENGTH, s,
                         s->msg_callback_arg);
 
+    *len = s->init_num;
     return 1;
 }
