summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoe Birr-Pixton <jpixton@gmail.com>2024-09-12 16:33:44 +0100
committerJoe Birr-Pixton <jpixton@gmail.com>2024-09-26 13:50:09 +0000
commit39455e19fc3b5657f6a8c4fd22a0f981554ed5cb (patch)
treea49a9bbd555e3140ea7002d9197e9444a4394763
parent163a241a25a2a6233cbc568b33fad8b93f900f93 (diff)
Avoid excess copying of tickets
These can be large (hundreds of bytes), and even larger (thousands of bytes) if the server decides to include the client's identity. Parse them into an Arc, and then maintain that on the path to the session store.
-rw-r--r--rustls/src/client/handy.rs6
-rw-r--r--rustls/src/client/hs.rs2
-rw-r--r--rustls/src/client/tls12.rs10
-rw-r--r--rustls/src/client/tls13.rs2
-rw-r--r--rustls/src/msgs/handshake.rs16
-rw-r--r--rustls/src/msgs/handshake_test.rs5
-rw-r--r--rustls/src/msgs/persist.rs16
7 files changed, 31 insertions, 26 deletions
diff --git a/rustls/src/client/handy.rs b/rustls/src/client/handy.rs
index 7eec4bd4..78d46751 100644
--- a/rustls/src/client/handy.rs
+++ b/rustls/src/client/handy.rs
@@ -243,6 +243,7 @@ impl client::ResolvesClientCert for AlwaysResolvesClientCert {
test_for_each_provider! {
use std::prelude::v1::*;
+ use alloc::sync::Arc;
use super::NoClientSessionStorage;
use crate::client::ClientSessionStore;
use crate::msgs::enums::NamedGroup;
@@ -251,6 +252,7 @@ test_for_each_provider! {
use crate::msgs::handshake::SessionId;
use crate::msgs::persist::Tls13ClientSessionValue;
use crate::suites::SupportedCipherSuite;
+ use crate::msgs::base::PayloadU16;
use provider::cipher_suite;
use pki_types::{ServerName, UnixTime};
@@ -278,7 +280,7 @@ test_for_each_provider! {
Tls12ClientSessionValue::new(
tls12_suite,
SessionId::empty(),
- Vec::new(),
+ Arc::new(PayloadU16::empty()),
&[],
CertificateChain::default(),
now,
@@ -300,7 +302,7 @@ test_for_each_provider! {
name.clone(),
Tls13ClientSessionValue::new(
tls13_suite,
- Vec::new(),
+ Arc::new(PayloadU16::empty()),
&[],
CertificateChain::default(),
now,
diff --git a/rustls/src/client/hs.rs b/rustls/src/client/hs.rs
index e610931f..c7d6ef3c 100644
--- a/rustls/src/client/hs.rs
+++ b/rustls/src/client/hs.rs
@@ -126,7 +126,7 @@ pub(super) fn start_handshake(
// If we have a ticket, we use the sessionid as a signal that
// we're doing an abbreviated handshake. See section 3.4 in
// RFC5077.
- if !inner.ticket().is_empty() {
+ if !inner.ticket().0.is_empty() {
inner.session_id = SessionId::random(config.provider.secure_random)?;
}
Some(inner.session_id)
diff --git a/rustls/src/client/tls12.rs b/rustls/src/client/tls12.rs
index bef0f0f5..a0fcbdd7 100644
--- a/rustls/src/client/tls12.rs
+++ b/rustls/src/client/tls12.rs
@@ -1163,17 +1163,17 @@ impl ExpectFinished {
// Save a ticket. If we got a new ticket, save that. Otherwise, save the
// original ticket again.
let (mut ticket, lifetime) = match self.ticket.take() {
- Some(nst) => (nst.ticket.0, nst.lifetime_hint),
- None => (Vec::new(), 0),
+ Some(nst) => (nst.ticket, nst.lifetime_hint),
+ None => (Arc::new(PayloadU16::empty()), 0),
};
- if ticket.is_empty() {
+ if ticket.0.is_empty() {
if let Some(resuming_session) = &mut self.resuming_session {
- ticket = resuming_session.take_ticket();
+ ticket = resuming_session.ticket();
}
}
- if self.session_id.is_empty() && ticket.is_empty() {
+ if self.session_id.is_empty() && ticket.0.is_empty() {
debug!("Session not saved: server didn't allocate id or ticket");
return;
}
diff --git a/rustls/src/client/tls13.rs b/rustls/src/client/tls13.rs
index 17b59f1c..066640c4 100644
--- a/rustls/src/client/tls13.rs
+++ b/rustls/src/client/tls13.rs
@@ -1428,7 +1428,7 @@ impl ExpectTraffic {
#[allow(unused_mut)]
let mut value = persist::Tls13ClientSessionValue::new(
self.suite,
- nst.ticket.0.clone(),
+ Arc::clone(&nst.ticket),
secret.as_ref(),
cx.common
.peer_certificates
diff --git a/rustls/src/msgs/handshake.rs b/rustls/src/msgs/handshake.rs
index e17ef314..fee63461 100644
--- a/rustls/src/msgs/handshake.rs
+++ b/rustls/src/msgs/handshake.rs
@@ -1,6 +1,7 @@
use alloc::collections::BTreeSet;
#[cfg(feature = "logging")]
use alloc::string::String;
+use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::ops::Deref;
@@ -2263,7 +2264,10 @@ impl CertificateRequestPayloadTls13 {
#[derive(Debug)]
pub struct NewSessionTicketPayload {
pub(crate) lifetime_hint: u32,
- pub(crate) ticket: PayloadU16,
+ // Tickets can be large (KB), so we deserialise this straight
+ // into an Arc, so it can be passed directly into the client's
+ // session object without copying.
+ pub(crate) ticket: Arc<PayloadU16>,
}
impl NewSessionTicketPayload {
@@ -2271,7 +2275,7 @@ impl NewSessionTicketPayload {
pub(crate) fn new(lifetime_hint: u32, ticket: Vec<u8>) -> Self {
Self {
lifetime_hint,
- ticket: PayloadU16::new(ticket),
+ ticket: Arc::new(PayloadU16::new(ticket)),
}
}
}
@@ -2284,7 +2288,7 @@ impl Codec<'_> for NewSessionTicketPayload {
fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
let lifetime = u32::read(r)?;
- let ticket = PayloadU16::read(r)?;
+ let ticket = Arc::new(PayloadU16::read(r)?);
Ok(Self {
lifetime_hint: lifetime,
@@ -2344,7 +2348,7 @@ pub struct NewSessionTicketPayloadTls13 {
pub(crate) lifetime: u32,
pub(crate) age_add: u32,
pub(crate) nonce: PayloadU8,
- pub(crate) ticket: PayloadU16,
+ pub(crate) ticket: Arc<PayloadU16>,
pub(crate) exts: Vec<NewSessionTicketExtension>,
}
@@ -2354,7 +2358,7 @@ impl NewSessionTicketPayloadTls13 {
lifetime,
age_add,
nonce: PayloadU8::new(nonce),
- ticket: PayloadU16::new(ticket),
+ ticket: Arc::new(PayloadU16::new(ticket)),
exts: vec![],
}
}
@@ -2395,7 +2399,7 @@ impl Codec<'_> for NewSessionTicketPayloadTls13 {
let lifetime = u32::read(r)?;
let age_add = u32::read(r)?;
let nonce = PayloadU8::read(r)?;
- let ticket = PayloadU16::read(r)?;
+ let ticket = Arc::new(PayloadU16::read(r)?);
let exts = Vec::read(r)?;
Ok(Self {
diff --git a/rustls/src/msgs/handshake_test.rs b/rustls/src/msgs/handshake_test.rs
index ccd30e62..225b6b48 100644
--- a/rustls/src/msgs/handshake_test.rs
+++ b/rustls/src/msgs/handshake_test.rs
@@ -1,3 +1,4 @@
+use alloc::sync::Arc;
use std::prelude::v1::*;
use std::{format, println, vec};
@@ -1248,7 +1249,7 @@ fn sample_certificate_request_payload_tls13() -> CertificateRequestPayloadTls13
fn sample_new_session_ticket_payload() -> NewSessionTicketPayload {
NewSessionTicketPayload {
lifetime_hint: 1234,
- ticket: PayloadU16(vec![1, 2, 3]),
+ ticket: Arc::new(PayloadU16(vec![1, 2, 3])),
}
}
@@ -1257,7 +1258,7 @@ fn sample_new_session_ticket_payload_tls13() -> NewSessionTicketPayloadTls13 {
lifetime: 123,
age_add: 1234,
nonce: PayloadU8(vec![1, 2, 3]),
- ticket: PayloadU16(vec![4, 5, 6]),
+ ticket: Arc::new(PayloadU16(vec![4, 5, 6])),
exts: vec![NewSessionTicketExtension::Unknown(UnknownExtension {
typ: ExtensionType::Unknown(12345),
payload: Payload::Borrowed(&[1, 2, 3]),
diff --git a/rustls/src/msgs/persist.rs b/rustls/src/msgs/persist.rs
index 785d51c6..33e7b66f 100644
--- a/rustls/src/msgs/persist.rs
+++ b/rustls/src/msgs/persist.rs
@@ -1,8 +1,6 @@
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::cmp;
-#[cfg(feature = "tls12")]
-use core::mem;
use pki_types::{DnsName, UnixTime};
use zeroize::Zeroizing;
@@ -81,7 +79,7 @@ pub struct Tls13ClientSessionValue {
impl Tls13ClientSessionValue {
pub(crate) fn new(
suite: &'static Tls13CipherSuite,
- ticket: Vec<u8>,
+ ticket: Arc<PayloadU16>,
secret: &[u8],
server_cert_chain: CertificateChain<'static>,
time_now: UnixTime,
@@ -159,7 +157,7 @@ impl Tls12ClientSessionValue {
pub(crate) fn new(
suite: &'static Tls12CipherSuite,
session_id: SessionId,
- ticket: Vec<u8>,
+ ticket: Arc<PayloadU16>,
master_secret: &[u8],
server_cert_chain: CertificateChain<'static>,
time_now: UnixTime,
@@ -180,8 +178,8 @@ impl Tls12ClientSessionValue {
}
}
- pub(crate) fn take_ticket(&mut self) -> Vec<u8> {
- mem::take(&mut self.common.ticket.0)
+ pub(crate) fn ticket(&mut self) -> Arc<PayloadU16> {
+ Arc::clone(&self.common.ticket)
}
pub(crate) fn extended_ms(&self) -> bool {
@@ -210,7 +208,7 @@ impl core::ops::Deref for Tls12ClientSessionValue {
#[derive(Debug, Clone)]
pub struct ClientSessionCommon {
- ticket: PayloadU16,
+ ticket: Arc<PayloadU16>,
secret: Zeroizing<PayloadU8>,
epoch: u64,
lifetime_secs: u32,
@@ -219,14 +217,14 @@ pub struct ClientSessionCommon {
impl ClientSessionCommon {
fn new(
- ticket: Vec<u8>,
+ ticket: Arc<PayloadU16>,
secret: &[u8],
time_now: UnixTime,
lifetime_secs: u32,
server_cert_chain: CertificateChain<'static>,
) -> Self {
Self {
- ticket: PayloadU16(ticket),
+ ticket,
secret: Zeroizing::new(PayloadU8(secret.to_vec())),
epoch: time_now.as_secs(),
lifetime_secs: cmp::min(lifetime_secs, MAX_TICKET_LIFETIME),