diff options
author | Joe Birr-Pixton <jpixton@gmail.com> | 2024-09-12 16:33:44 +0100 |
---|---|---|
committer | Joe Birr-Pixton <jpixton@gmail.com> | 2024-09-26 13:50:09 +0000 |
commit | 39455e19fc3b5657f6a8c4fd22a0f981554ed5cb (patch) | |
tree | a49a9bbd555e3140ea7002d9197e9444a4394763 | |
parent | 163a241a25a2a6233cbc568b33fad8b93f900f93 (diff) |
Avoid excess copying of tickets
These can be large (hundreds of bytes), and even larger
(thousands of bytes) if the server decides to include
the client's identity.
Parse them into an Arc, and then maintain that on
the path to the session store.
-rw-r--r-- | rustls/src/client/handy.rs | 6 | ||||
-rw-r--r-- | rustls/src/client/hs.rs | 2 | ||||
-rw-r--r-- | rustls/src/client/tls12.rs | 10 | ||||
-rw-r--r-- | rustls/src/client/tls13.rs | 2 | ||||
-rw-r--r-- | rustls/src/msgs/handshake.rs | 16 | ||||
-rw-r--r-- | rustls/src/msgs/handshake_test.rs | 5 | ||||
-rw-r--r-- | rustls/src/msgs/persist.rs | 16 |
7 files changed, 31 insertions, 26 deletions
diff --git a/rustls/src/client/handy.rs b/rustls/src/client/handy.rs index 7eec4bd4..78d46751 100644 --- a/rustls/src/client/handy.rs +++ b/rustls/src/client/handy.rs @@ -243,6 +243,7 @@ impl client::ResolvesClientCert for AlwaysResolvesClientCert { test_for_each_provider! { use std::prelude::v1::*; + use alloc::sync::Arc; use super::NoClientSessionStorage; use crate::client::ClientSessionStore; use crate::msgs::enums::NamedGroup; @@ -251,6 +252,7 @@ test_for_each_provider! { use crate::msgs::handshake::SessionId; use crate::msgs::persist::Tls13ClientSessionValue; use crate::suites::SupportedCipherSuite; + use crate::msgs::base::PayloadU16; use provider::cipher_suite; use pki_types::{ServerName, UnixTime}; @@ -278,7 +280,7 @@ test_for_each_provider! { Tls12ClientSessionValue::new( tls12_suite, SessionId::empty(), - Vec::new(), + Arc::new(PayloadU16::empty()), &[], CertificateChain::default(), now, @@ -300,7 +302,7 @@ test_for_each_provider! { name.clone(), Tls13ClientSessionValue::new( tls13_suite, - Vec::new(), + Arc::new(PayloadU16::empty()), &[], CertificateChain::default(), now, diff --git a/rustls/src/client/hs.rs b/rustls/src/client/hs.rs index e610931f..c7d6ef3c 100644 --- a/rustls/src/client/hs.rs +++ b/rustls/src/client/hs.rs @@ -126,7 +126,7 @@ pub(super) fn start_handshake( // If we have a ticket, we use the sessionid as a signal that // we're doing an abbreviated handshake. See section 3.4 in // RFC5077. - if !inner.ticket().is_empty() { + if !inner.ticket().0.is_empty() { inner.session_id = SessionId::random(config.provider.secure_random)?; } Some(inner.session_id) diff --git a/rustls/src/client/tls12.rs b/rustls/src/client/tls12.rs index bef0f0f5..a0fcbdd7 100644 --- a/rustls/src/client/tls12.rs +++ b/rustls/src/client/tls12.rs @@ -1163,17 +1163,17 @@ impl ExpectFinished { // Save a ticket. If we got a new ticket, save that. Otherwise, save the // original ticket again. let (mut ticket, lifetime) = match self.ticket.take() { - Some(nst) => (nst.ticket.0, nst.lifetime_hint), - None => (Vec::new(), 0), + Some(nst) => (nst.ticket, nst.lifetime_hint), + None => (Arc::new(PayloadU16::empty()), 0), }; - if ticket.is_empty() { + if ticket.0.is_empty() { if let Some(resuming_session) = &mut self.resuming_session { - ticket = resuming_session.take_ticket(); + ticket = resuming_session.ticket(); } } - if self.session_id.is_empty() && ticket.is_empty() { + if self.session_id.is_empty() && ticket.0.is_empty() { debug!("Session not saved: server didn't allocate id or ticket"); return; } diff --git a/rustls/src/client/tls13.rs b/rustls/src/client/tls13.rs index 17b59f1c..066640c4 100644 --- a/rustls/src/client/tls13.rs +++ b/rustls/src/client/tls13.rs @@ -1428,7 +1428,7 @@ impl ExpectTraffic { #[allow(unused_mut)] let mut value = persist::Tls13ClientSessionValue::new( self.suite, - nst.ticket.0.clone(), + Arc::clone(&nst.ticket), secret.as_ref(), cx.common .peer_certificates diff --git a/rustls/src/msgs/handshake.rs b/rustls/src/msgs/handshake.rs index e17ef314..fee63461 100644 --- a/rustls/src/msgs/handshake.rs +++ b/rustls/src/msgs/handshake.rs @@ -1,6 +1,7 @@ use alloc::collections::BTreeSet; #[cfg(feature = "logging")] use alloc::string::String; +use alloc::sync::Arc; use alloc::vec; use alloc::vec::Vec; use core::ops::Deref; @@ -2263,7 +2264,10 @@ impl CertificateRequestPayloadTls13 { #[derive(Debug)] pub struct NewSessionTicketPayload { pub(crate) lifetime_hint: u32, - pub(crate) ticket: PayloadU16, + // Tickets can be large (KB), so we deserialise this straight + // into an Arc, so it can be passed directly into the client's + // session object without copying. + pub(crate) ticket: Arc<PayloadU16>, } impl NewSessionTicketPayload { @@ -2271,7 +2275,7 @@ impl NewSessionTicketPayload { pub(crate) fn new(lifetime_hint: u32, ticket: Vec<u8>) -> Self { Self { lifetime_hint, - ticket: PayloadU16::new(ticket), + ticket: Arc::new(PayloadU16::new(ticket)), } } } @@ -2284,7 +2288,7 @@ impl Codec<'_> for NewSessionTicketPayload { fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> { let lifetime = u32::read(r)?; - let ticket = PayloadU16::read(r)?; + let ticket = Arc::new(PayloadU16::read(r)?); Ok(Self { lifetime_hint: lifetime, @@ -2344,7 +2348,7 @@ pub struct NewSessionTicketPayloadTls13 { pub(crate) lifetime: u32, pub(crate) age_add: u32, pub(crate) nonce: PayloadU8, - pub(crate) ticket: PayloadU16, + pub(crate) ticket: Arc<PayloadU16>, pub(crate) exts: Vec<NewSessionTicketExtension>, } @@ -2354,7 +2358,7 @@ impl NewSessionTicketPayloadTls13 { lifetime, age_add, nonce: PayloadU8::new(nonce), - ticket: PayloadU16::new(ticket), + ticket: Arc::new(PayloadU16::new(ticket)), exts: vec![], } } @@ -2395,7 +2399,7 @@ impl Codec<'_> for NewSessionTicketPayloadTls13 { let lifetime = u32::read(r)?; let age_add = u32::read(r)?; let nonce = PayloadU8::read(r)?; - let ticket = PayloadU16::read(r)?; + let ticket = Arc::new(PayloadU16::read(r)?); let exts = Vec::read(r)?; Ok(Self { diff --git a/rustls/src/msgs/handshake_test.rs b/rustls/src/msgs/handshake_test.rs index ccd30e62..225b6b48 100644 --- a/rustls/src/msgs/handshake_test.rs +++ b/rustls/src/msgs/handshake_test.rs @@ -1,3 +1,4 @@ +use alloc::sync::Arc; use std::prelude::v1::*; use std::{format, println, vec}; @@ -1248,7 +1249,7 @@ fn sample_certificate_request_payload_tls13() -> CertificateRequestPayloadTls13 fn sample_new_session_ticket_payload() -> NewSessionTicketPayload { NewSessionTicketPayload { lifetime_hint: 1234, - ticket: PayloadU16(vec![1, 2, 3]), + ticket: Arc::new(PayloadU16(vec![1, 2, 3])), } } @@ -1257,7 +1258,7 @@ fn sample_new_session_ticket_payload_tls13() -> NewSessionTicketPayloadTls13 { lifetime: 123, age_add: 1234, nonce: PayloadU8(vec![1, 2, 3]), - ticket: PayloadU16(vec![4, 5, 6]), + ticket: Arc::new(PayloadU16(vec![4, 5, 6])), exts: vec![NewSessionTicketExtension::Unknown(UnknownExtension { typ: ExtensionType::Unknown(12345), payload: Payload::Borrowed(&[1, 2, 3]), diff --git a/rustls/src/msgs/persist.rs b/rustls/src/msgs/persist.rs index 785d51c6..33e7b66f 100644 --- a/rustls/src/msgs/persist.rs +++ b/rustls/src/msgs/persist.rs @@ -1,8 +1,6 @@ use alloc::sync::Arc; use alloc::vec::Vec; use core::cmp; -#[cfg(feature = "tls12")] -use core::mem; use pki_types::{DnsName, UnixTime}; use zeroize::Zeroizing; @@ -81,7 +79,7 @@ pub struct Tls13ClientSessionValue { impl Tls13ClientSessionValue { pub(crate) fn new( suite: &'static Tls13CipherSuite, - ticket: Vec<u8>, + ticket: Arc<PayloadU16>, secret: &[u8], server_cert_chain: CertificateChain<'static>, time_now: UnixTime, @@ -159,7 +157,7 @@ impl Tls12ClientSessionValue { pub(crate) fn new( suite: &'static Tls12CipherSuite, session_id: SessionId, - ticket: Vec<u8>, + ticket: Arc<PayloadU16>, master_secret: &[u8], server_cert_chain: CertificateChain<'static>, time_now: UnixTime, @@ -180,8 +178,8 @@ impl Tls12ClientSessionValue { } } - pub(crate) fn take_ticket(&mut self) -> Vec<u8> { - mem::take(&mut self.common.ticket.0) + pub(crate) fn ticket(&mut self) -> Arc<PayloadU16> { + Arc::clone(&self.common.ticket) } pub(crate) fn extended_ms(&self) -> bool { @@ -210,7 +208,7 @@ impl core::ops::Deref for Tls12ClientSessionValue { #[derive(Debug, Clone)] pub struct ClientSessionCommon { - ticket: PayloadU16, + ticket: Arc<PayloadU16>, secret: Zeroizing<PayloadU8>, epoch: u64, lifetime_secs: u32, @@ -219,14 +217,14 @@ pub struct ClientSessionCommon { impl ClientSessionCommon { fn new( - ticket: Vec<u8>, + ticket: Arc<PayloadU16>, secret: &[u8], time_now: UnixTime, lifetime_secs: u32, server_cert_chain: CertificateChain<'static>, ) -> Self { Self { - ticket: PayloadU16(ticket), + ticket, secret: Zeroizing::new(PayloadU8(secret.to_vec())), epoch: time_now.as_secs(), lifetime_secs: cmp::min(lifetime_secs, MAX_TICKET_LIFETIME), |