diff options
author | Mauro D <mauro@stalw.art> | 2024-02-08 20:03:57 -0300 |
---|---|---|
committer | Mauro D <mauro@stalw.art> | 2024-02-08 20:03:57 -0300 |
commit | d16119f54ba73a6cf33620fb78bfbd4beedf5dc4 (patch) | |
tree | 8ff76ef576e44eb9237f3dfb06ed1476cc167f5e /crates/smtp | |
parent | d15f59846078540577d33622accecc4f0669b661 (diff) |
Distributed SMTP queues (untested)
Diffstat (limited to 'crates/smtp')
29 files changed, 1915 insertions, 3006 deletions
diff --git a/crates/smtp/Cargo.toml b/crates/smtp/Cargo.toml index 2e689b8c..53580d1f 100644 --- a/crates/smtp/Cargo.toml +++ b/crates/smtp/Cargo.toml @@ -20,7 +20,7 @@ mail-auth = { version = "0.3" } mail-send = { version = "0.4", default-features = false, features = ["cram-md5"] } mail-parser = { version = "0.9", features = ["full_encoding", "ludicrous_mode"] } mail-builder = { version = "0.3", features = ["ludicrous_mode"] } -smtp-proto = { version = "0.1" } +smtp-proto = { version = "0.1", features = ["serde_support"] } sieve-rs = { version = "0.4" } ahash = { version = "0.8" } rustls = "0.22" diff --git a/crates/smtp/src/config/shared.rs b/crates/smtp/src/config/shared.rs index 80d42825..c0134e23 100644 --- a/crates/smtp/src/config/shared.rs +++ b/crates/smtp/src/config/shared.rs @@ -72,6 +72,17 @@ impl ConfigShared for Config { ) })? .clone(), + default_blob_store: self + .value_or_default("storage.blob", "storage.data") + .and_then(|id| ctx.stores.blob_stores.get(id)) + .ok_or_else(|| { + format!( + "Lookup store {:?} not found for key \"storage.blob\".", + self.value_or_default("storage.blob", "storage.data") + .unwrap() + ) + })? + .clone(), }) } diff --git a/crates/smtp/src/core/eval.rs b/crates/smtp/src/core/eval.rs index a7b8d2c8..1b90750a 100644 --- a/crates/smtp/src/core/eval.rs +++ b/crates/smtp/src/core/eval.rs @@ -3,16 +3,13 @@ use std::{borrow::Cow, net::IpAddr, sync::Arc, vec::IntoIter}; use directory::Directory; use mail_auth::IpLookupStrategy; use sieve::Sieve; -use store::{LookupKey, LookupStore, LookupValue}; +use store::{Deserialize, LookupStore}; use utils::{ config::if_block::IfBlock, expr::{Expression, Variable}, }; -use crate::{ - config::{ArcSealer, DkimSigner, RelayHost}, - scripts::plugins::lookup::VariableExists, -}; +use crate::config::{ArcSealer, DkimSigner, RelayHost}; use super::{ResolveVariable, SMTP}; @@ -165,15 +162,9 @@ impl SMTP { let key = params.next_as_string(); self.get_lookup_store(store.as_ref()) - .key_get::<String>(LookupKey::Key(key.into_owned().into_bytes())) + .key_get::<VariableWrapper>(key.into_owned().into_bytes()) .await - .map(|value| { - if let LookupValue::Value { value, .. } = value { - Variable::from(value) - } else { - Variable::default() - } - }) + .map(|value| value.map(|v| v.into_inner()).unwrap_or_default()) .unwrap_or_else(|err| { tracing::warn!( context = "eval_if", @@ -191,9 +182,8 @@ impl SMTP { let key = params.next_as_string(); self.get_lookup_store(store.as_ref()) - .key_get::<VariableExists>(LookupKey::Key(key.into_owned().into_bytes())) + .key_exists(key.into_owned().into_bytes()) .await - .map(|value| matches!(value, LookupValue::Value { .. })) .unwrap_or_else(|err| { tracing::warn!( context = "eval_if", @@ -395,3 +385,30 @@ impl<'x> FncParams<'x> { self.params.next().unwrap().into_string() } } + +#[derive(Debug)] +struct VariableWrapper(Variable<'static>); + +impl From<i64> for VariableWrapper { + fn from(value: i64) -> Self { + VariableWrapper(Variable::Integer(value)) + } +} + +impl Deserialize for VariableWrapper { + fn deserialize(bytes: &[u8]) -> store::Result<Self> { + String::deserialize(bytes).map(|v| VariableWrapper(Variable::String(v.into()))) + } +} + +impl From<store::Value<'static>> for VariableWrapper { + fn from(value: store::Value<'static>) -> Self { + VariableWrapper(value.into()) + } +} + +impl VariableWrapper { + pub fn into_inner(self) -> Variable<'static> { + self.0 + } +} diff --git a/crates/smtp/src/core/management.rs b/crates/smtp/src/core/management.rs index 246156e7..125e6102 100644 --- a/crates/smtp/src/core/management.rs +++ b/crates/smtp/src/core/management.rs @@ -21,7 +21,7 @@ * for more details. */ -use std::{borrow::Cow, fmt::Display, net::IpAddr, sync::Arc, time::Instant}; +use std::{borrow::Cow, net::IpAddr, sync::Arc}; use directory::{AuthResult, Type}; use http_body_util::{combinators::BoxBody, BodyExt, Empty, Full}; @@ -35,70 +35,24 @@ use hyper::{ use hyper_util::rt::TokioIo; use mail_parser::{decoders::base64::base64_decode, DateTime}; use mail_send::Credentials; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; -use tokio::sync::oneshot; +use serde::{Deserializer, Serializer}; +use store::{ + write::{key::DeserializeBigEndian, now, Bincode, QueueClass, ReportEvent, ValueClass}, + Deserialize, IterateParams, ValueKey, +}; use utils::listener::{limiter::InFlight, SessionData, SessionManager, SessionStream}; -use crate::{ - queue::{self, instant_to_timestamp, InstantFromTimestamp, QueueId, Status}, - reporting::{ - self, - scheduler::{ReportKey, ReportPolicy, ReportType, ReportValue}, - }, -}; +use crate::queue::{self, HostResponse, QueueId, Status}; use super::{SmtpAdminSessionManager, SMTP}; -#[derive(Debug)] -pub enum QueueRequest { - List { - from: Option<String>, - to: Option<String>, - before: Option<Instant>, - after: Option<Instant>, - result_tx: oneshot::Sender<Vec<u64>>, - }, - Status { - queue_ids: Vec<QueueId>, - result_tx: oneshot::Sender<Vec<Option<Message>>>, - }, - Cancel { - queue_ids: Vec<QueueId>, - item: Option<String>, - result_tx: oneshot::Sender<Vec<bool>>, - }, - Retry { - queue_ids: Vec<QueueId>, - item: Option<String>, - time: Instant, - result_tx: oneshot::Sender<Vec<bool>>, - }, -} - -#[derive(Debug)] -pub enum ReportRequest { - List { - type_: Option<ReportType<(), ()>>, - domain: Option<String>, - result_tx: oneshot::Sender<Vec<String>>, - }, - Status { - report_ids: Vec<ReportKey>, - result_tx: oneshot::Sender<Vec<Option<Report>>>, - }, - Cancel { - report_ids: Vec<ReportKey>, - result_tx: oneshot::Sender<Vec<bool>>, - }, -} - -#[derive(Debug, Serialize)] +#[derive(Debug, serde::Serialize)] pub struct Response<T> { data: T, } -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq, Eq)] pub struct Message { pub return_path: String, pub domains: Vec<Domain>, @@ -113,7 +67,7 @@ pub struct Message { pub env_id: Option<String>, } -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq, Eq)] pub struct Domain { pub name: String, pub status: Status<String, String>, @@ -131,7 +85,7 @@ pub struct Domain { pub expires: DateTime, } -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq, Eq)] pub struct Recipient { pub address: String, pub status: Status<String, String>, @@ -139,7 +93,7 @@ pub struct Recipient { pub orcpt: Option<String>, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, serde::Serialize, serde::Deserialize)] pub struct Report { pub domain: String, #[serde(rename = "type")] @@ -362,18 +316,48 @@ impl SMTP { match error { None => { - let (result_tx, result_rx) = oneshot::channel(); - self.send_queue_event( - QueueRequest::List { - from, - to, - before, - after, - result_tx, - }, - result_rx, + let mut result = Vec::new(); + let from_key = ValueKey::from(ValueClass::Queue(QueueClass::Message(0))); + let to_key = + ValueKey::from(ValueClass::Queue(QueueClass::Message(u64::MAX))); + let has_filters = + from.is_some() || to.is_some() || before.is_some() || after.is_some(); + let _ = + self.shared + .default_data_store + .iterate( + IterateParams::new(from_key, to_key).ascending(), + |key, value| { + if has_filters { + let message = + Bincode::<queue::Message>::deserialize(value)? + .inner; + if from.as_ref().map_or(true, |from| { + message.return_path.contains(from) + }) && to.as_ref().map_or(true, |to| { + message + .recipients + .iter() + .any(|r| r.address_lcase.contains(to)) + }) && before.as_ref().map_or(true, |before| { + message.next_delivery_event() < *before + }) && after.as_ref().map_or(true, |after| { + message.next_delivery_event() > *after + }) { + result.push(key.deserialize_be_u64(1)?); + } + } else { + result.push(key.deserialize_be_u64(1)?); + } + Ok(true) + }, + ) + .await; + + ( + StatusCode::OK, + serde_json::to_string(&Response { data: result }).unwrap_or_default(), ) - .await } Some(error) => error.into_bad_request(), } @@ -404,22 +388,24 @@ impl SMTP { match error { None => { - let (result_tx, result_rx) = oneshot::channel(); - self.send_queue_event( - QueueRequest::Status { - queue_ids, - result_tx, - }, - result_rx, + let mut result = Vec::with_capacity(queue_ids.len()); + for queue_id in queue_ids { + if let Some(message) = self.read_message(queue_id).await { + result.push(Message::from(&message)); + } + } + + ( + StatusCode::OK, + serde_json::to_string(&Response { data: result }).unwrap_or_default(), ) - .await } Some(error) => error.into_bad_request(), } } (&Method::GET, "queue", "retry") => { let mut queue_ids = Vec::new(); - let mut time = Instant::now(); + let mut time = now(); let mut item = None; let mut error = None; @@ -457,17 +443,49 @@ impl SMTP { match error { None => { - let (result_tx, result_rx) = oneshot::channel(); - self.send_queue_event( - QueueRequest::Retry { - queue_ids, - item, - time, - result_tx, - }, - result_rx, + let mut result = Vec::with_capacity(queue_ids.len()); + + for queue_id in queue_ids { + let mut found = false; + + if let Some(mut message) = self.read_message(queue_id).await { + let prev_event = message.next_event().unwrap_or_default(); + + for domain in &mut message.domains { + if matches!( + domain.status, + Status::Scheduled | Status::TemporaryFailure(_) + ) && item + .as_ref() + .map_or(true, |item| domain.domain.contains(item)) + { + domain.retry.due = time; + if domain.expires > time { + domain.expires = time + 10; + } + found = true; + } + } + + if found { + let next_event = message.next_event().unwrap_or_default(); + message + .save_changes(self, prev_event.into(), next_event.into()) + .await; + } + } + + result.push(found); + } + + if result.iter().any(|r| *r) { + let _ = self.queue.tx.send(queue::Event::Reload).await; + } + + ( + StatusCode::OK, + serde_json::to_string(&Response { data: result }).unwrap_or_default(), ) - .await } Some(error) => error.into_bad_request(), } @@ -502,16 +520,93 @@ impl SMTP { match error { None => { - let (result_tx, result_rx) = oneshot::channel(); - self.send_queue_event( - QueueRequest::Cancel { - queue_ids, - item, - result_tx, - }, - result_rx, + let mut result = Vec::with_capacity(queue_ids.len()); + + for queue_id in queue_ids { + let mut found = false; + + if let Some(mut message) = self.read_message(queue_id).await { + let prev_event = message.next_event().unwrap_or_default(); + + if let Some(item) = &item { + // Cancel delivery for all recipients that match + for rcpt in &mut message.recipients { + if rcpt.address_lcase.contains(item) { + rcpt.status = Status::Completed(HostResponse { + hostname: String::new(), + response: smtp_proto::Response { + code: 0, + esc: [0, 0, 0], + message: "Delivery canceled.".to_string(), + }, + }); + found = true; + } + } + if found { + // Mark as completed domains without any pending deliveries + for (domain_idx, domain) in + message.domains.iter_mut().enumerate() + { + if matches!( + domain.status, + Status::TemporaryFailure(_) | Status::Scheduled + ) { + let mut total_rcpt = 0; + let mut total_completed = 0; + + for rcpt in &message.recipients { + if rcpt.domain_idx == domain_idx { + total_rcpt += 1; + if matches!( + rcpt.status, + Status::PermanentFailure(_) + | Status::Completed(_) + ) { + total_completed += 1; + } + } + } + + if total_rcpt == total_completed { + domain.status = Status::Completed(()); + } + } + } + + // Delete message if there are no pending deliveries + if message.domains.iter().any(|domain| { + matches!( + domain.status, + Status::TemporaryFailure(_) | Status::Scheduled + ) + }) { + let next_event = + message.next_event().unwrap_or_default(); + message + .save_changes( + self, + next_event.into(), + prev_event.into(), + ) + .await; + } else { + message.remove(self, prev_event).await; + } + } + } else { + message.remove(self, prev_event).await; + found = true; + } + } + + result.push(found); + } + + ( + StatusCode::OK, + serde_json::to_string(&Response { data: result }).unwrap_or_default(), ) - .await } Some(error) => error.into_bad_request(), } @@ -526,10 +621,10 @@ impl SMTP { match key.as_ref() { "type" => match value.as_ref() { "dmarc" => { - type_ = ReportType::Dmarc(()).into(); + type_ = 0u8.into(); } "tls" => { - type_ = ReportType::Tls(()).into(); + type_ = 1u8.into(); } _ => { error = format!("Invalid report type {value:?}.").into(); @@ -549,16 +644,54 @@ impl SMTP { match error { None => { - let (result_tx, result_rx) = oneshot::channel(); - self.send_report_event( - ReportRequest::List { - type_, - domain, - result_tx, - }, - result_rx, + let mut result = Vec::new(); + let from_key = ValueKey::from(ValueClass::Queue( + QueueClass::DmarcReportHeader(ReportEvent { + due: 0, + policy_hash: 0, + seq_id: 0, + domain: String::new(), + }), + )); + let to_key = ValueKey::from(ValueClass::Queue( + QueueClass::TlsReportHeader(ReportEvent { + due: u64::MAX, + policy_hash: 0, + seq_id: 0, + domain: String::new(), + }), + )); + let _ = + self.shared + .default_data_store + .iterate( + IterateParams::new(from_key, to_key).ascending().no_values(), + |key, _| { + if type_.map_or(true, |t| t == *key.last().unwrap()) { + let event = ReportEvent::deserialize(key)?; + if domain.as_ref().map_or(true, |d| { + d.eq_ignore_ascii_case(&event.domain) + }) { + result.push( + if *key.last().unwrap() == 0 { + QueueClass::DmarcReportHeader(event) + } else { + QueueClass::TlsReportHeader(event) + } + .queue_id(), + ); + } + } + + Ok(true) + }, + ) + .await; + + ( + StatusCode::OK, + serde_json::to_string(&Response { data: result }).unwrap_or_default(), ) - .await } Some(error) => error.into_bad_request(), } @@ -588,17 +721,13 @@ impl SMTP { } match error { - None => { - let (result_tx, result_rx) = oneshot::channel(); - self.send_report_event( - ReportRequest::Status { - report_ids, - result_tx, - }, - result_rx, - ) - .await - } + None => ( + StatusCode::OK, + serde_json::to_string(&Response { + data: report_ids.into_iter().map(Report::from).collect::<Vec<_>>(), + }) + .unwrap_or_default(), + ), Some(error) => error.into_bad_request(), } } @@ -628,15 +757,26 @@ impl SMTP { match error { None => { - let (result_tx, result_rx) = oneshot::channel(); - self.send_report_event( - ReportRequest::Cancel { - report_ids, - result_tx, - }, - result_rx, + let mut result = Vec::with_capacity(report_ids.len()); + + for report_id in report_ids { + match report_id { + QueueClass::DmarcReportHeader(event) => { + self.delete_dmarc_report(event).await; + } + QueueClass::TlsReportHeader(event) => { + self.delete_tls_report(vec![event]).await; + } + _ => (), + } + + result.push(true); + } + + ( + StatusCode::OK, + serde_json::to_string(&Response { data: result }).unwrap_or_default(), ) - .await } Some(error) => error.into_bad_request(), } @@ -660,85 +800,11 @@ impl SMTP { ) .unwrap() } - - async fn send_queue_event<T: Serialize>( - &self, - request: QueueRequest, - rx: oneshot::Receiver<T>, - ) -> (StatusCode, String) { - match self.queue.tx.send(queue::Event::Manage(request)).await { - Ok(_) => match rx.await { - Ok(result) => { - return ( - StatusCode::OK, - serde_json::to_string(&Response { data: result }).unwrap_or_default(), - ) - } - Err(_) => { - tracing::debug!( - context = "queue", - event = "recv-error", - reason = "Failed to receive manage request response." - ); - } - }, - Err(_) => { - tracing::debug!( - context = "queue", - event = "send-error", - reason = "Failed to send manage request event." - ); - } - } - - ( - StatusCode::INTERNAL_SERVER_ERROR, - "{\"error\": \"internal-error\", \"details\": \"Resource unavailable, try again later.\"}" - .to_string(), - ) - } - - async fn send_report_event<T: Serialize>( - &self, - request: ReportRequest, - rx: oneshot::Receiver<T>, - ) -> (StatusCode, String) { - match self.report.tx.send(reporting::Event::Manage(request)).await { - Ok(_) => match rx.await { - Ok(result) => { - return ( - StatusCode::OK, - serde_json::to_string(&Response { data: result }).unwrap_or_default(), - ) - } - Err(_) => { - tracing::debug!( - context = "queue", - event = "recv-error", - reason = "Failed to receive manage request response." - ); - } - }, - Err(_) => { - tracing::debug!( - context = "queue", - event = "send-error", - reason = "Failed to send manage request event." - ); - } - } - - ( - StatusCode::INTERNAL_SERVER_ERROR, - "{\"error\": \"internal-error\", \"details\": \"Resource unavailable, try again later.\"}" - .to_string(), - ) - } } impl From<&queue::Message> for Message { fn from(message: &queue::Message) -> Self { - let now = Instant::now(); + let now = now(); Message { return_path: message.return_path.clone(), @@ -764,20 +830,12 @@ impl From<&queue::Message> for Message { }, retry_num: domain.retry.inner, next_retry: if domain.retry.due > now { - DateTime::from_timestamp(instant_to_timestamp(now, domain.retry.due) as i64) - .into() + DateTime::from_timestamp(domain.retry.due as i64).into() } else { None }, next_notify: if domain.notify.due > now { - DateTime::from_timestamp( - instant_to_timestamp( - now, - domain.notify.due, - ) - as i64, - ) - .into() + DateTime::from_timestamp(domain.notify.due as i64).into() } else { None }, @@ -802,61 +860,64 @@ impl From<&queue::Message> for Message { orcpt: rcpt.orcpt.clone(), }) .collect(), - expires: DateTime::from_timestamp( - instant_to_timestamp(now, domain.expires) as i64 - ), + expires: DateTime::from_timestamp(domain.expires as i64), }) .collect(), } } } -impl From<(&ReportKey, &ReportValue)> for Report { - fn from((key, value): (&ReportKey, &ReportValue)) -> Self { - match (key, value) { - (ReportType::Dmarc(domain), ReportType::Dmarc(value)) => Report { - domain: domain.inner.clone(), - range_from: DateTime::from_timestamp(value.created as i64), - range_to: DateTime::from_timestamp( - (value.created + value.deliver_at.as_secs()) as i64, - ), - size: value.size, +impl From<QueueClass> for Report { + fn from(value: QueueClass) -> Self { + match value { + QueueClass::DmarcReportHeader(event) => Report { + domain: event.domain, type_: "dmarc".to_string(), + range_from: DateTime::from_timestamp(event.due as i64), + range_to: DateTime::from_timestamp(event.due as i64), + size: 0, }, - (ReportType::Tls(domain), ReportType::Tls(value)) => Report { - domain: domain.clone(), - range_from: DateTime::from_timestamp(value.created as i64), - range_to: DateTime::from_timestamp( - (value.created + value.deliver_at.as_secs()) as i64, - ), - size: value.size, + QueueClass::TlsReportHeader(event) => Report { + domain: event.domain, type_: "tls".to_string(), + range_from: DateTime::from_timestamp(event.due as i64), + range_to: DateTime::from_timestamp(event.due as i64), + size: 0, }, _ => unreachable!(), } } } -impl Display for ReportKey { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +trait GenerateQueueId { + fn queue_id(&self) -> String; +} + +impl GenerateQueueId for QueueClass { + fn queue_id(&self) -> String { match self { - ReportType::Dmarc(policy) => write!(f, "d!{}!{}", policy.inner, policy.policy), - ReportType::Tls(domain) => write!(f, "t!{domain}"), + QueueClass::DmarcReportHeader(h) => { + format!("d!{}!{}!{}!{}", h.domain, h.policy_hash, h.seq_id, h.due) + } + QueueClass::TlsReportHeader(h) => { + format!("t!{}!{}!{}!{}", h.domain, h.policy_hash, h.seq_id, h.due) + } + _ => unreachable!(), } } } trait ParseValues { - fn parse_timestamp(&self) -> Result<Instant, String>; + fn parse_timestamp(&self) -> Result<u64, String>; fn parse_queue_ids(&self) -> Result<Vec<QueueId>, String>; - fn parse_report_ids(&self) -> Result<Vec<ReportKey>, String>; + fn parse_report_ids(&self) -> Result<Vec<QueueClass>, String>; } impl ParseValues for Cow<'_, str> { - fn parse_timestamp(&self) -> Result<Instant, String> { + fn parse_timestamp(&self) -> Result<u64, String> { if let Some(dt) = DateTime::parse_rfc3339(self.as_ref()) { - let instant = (dt.to_timestamp() as u64).to_instant(); - if instant >= Instant::now() { + let instant = dt.to_timestamp() as u64; + if instant >= now() { return Ok(instant); } } @@ -881,29 +942,42 @@ impl ParseValues for Cow<'_, str> { Ok(ids) } - fn parse_report_ids(&self) -> Result<Vec<ReportKey>, String> { + fn parse_report_ids(&self) -> Result<Vec<QueueClass>, String> { let mut ids = Vec::new(); for id in self.split(',') { if !id.is_empty() { let mut parts = id.split('!'); - match (parts.next(), parts.next()) { - (Some("d"), Some(domain)) if !domain.is_empty() => { - if let Some(policy) = parts.next().and_then(|policy| policy.parse().ok()) { - ids.push(ReportType::Dmarc(ReportPolicy { - inner: domain.to_string(), - policy, - })); - continue; - } + match ( + parts.next(), + parts.next(), + parts.next().and_then(|p| p.parse::<u64>().ok()), + parts.next().and_then(|p| p.parse::<u64>().ok()), + parts.next().and_then(|p| p.parse::<u64>().ok()), + ) { + (Some("d"), Some(domain), Some(policy), Some(seq_id), Some(due)) + if !domain.is_empty() => + { + ids.push(QueueClass::DmarcReportHeader(ReportEvent { + due, + policy_hash: policy, + seq_id, + domain: domain.to_string(), + })); + } + (Some("t"), Some(domain), Some(policy), Some(seq_id), Some(due)) + if !domain.is_empty() => + { + ids.push(QueueClass::TlsReportHeader(ReportEvent { + due, + policy_hash: policy, + seq_id, + domain: domain.to_string(), + })); } - (Some("t"), Some(domain)) if !domain.is_empty() => { - ids.push(ReportType::Tls(domain.to_string())); - continue; + _ => { + return Err(format!("Failed to parse id {id:?}.")); } - _ => (), } - - return Err(format!("Failed to parse id {id:?}.")); } } Ok(ids) @@ -944,7 +1018,7 @@ fn deserialize_maybe_datetime<'de, D>(deserializer: D) -> Result<Option<DateTime where D: Deserializer<'de>, { - if let Some(value) = Option::<&str>::deserialize(deserializer)? { + if let Some(value) = <Option<&str> as serde::Deserialize>::deserialize(deserializer)? { if let Some(value) = DateTime::parse_rfc3339(value) { Ok(Some(value)) } else { @@ -968,6 +1042,8 @@ fn deserialize_datetime<'de, D>(deserializer: D) -> Result<DateTime, D::Error> where D: Deserializer<'de>, { + use serde::Deserialize; + if let Some(value) = DateTime::parse_rfc3339(<&str>::deserialize(deserializer)?) { Ok(value) } else { diff --git a/crates/smtp/src/core/mod.rs b/crates/smtp/src/core/mod.rs index 28be3737..7808333c 100644 --- a/crates/smtp/src/core/mod.rs +++ b/crates/smtp/src/core/mod.rs @@ -24,7 +24,7 @@ use std::{ hash::Hash, net::IpAddr, - sync::{atomic::AtomicU32, Arc}, + sync::Arc, time::{Duration, Instant}, }; @@ -40,7 +40,7 @@ use smtp_proto::{ }, IntoString, }; -use store::{LookupStore, Store, Value}; +use store::{BlobStore, LookupStore, Store, Value}; use tokio::{ io::{AsyncRead, AsyncWrite}, sync::mpsc, @@ -50,7 +50,12 @@ use tracing::Span; use utils::{ expr, ipc::DeliveryEvent, - listener::{limiter::InFlight, stream::NullIo, ServerInstance, TcpAcceptor}, + listener::{ + limiter::{ConcurrencyLimiter, InFlight}, + stream::NullIo, + ServerInstance, TcpAcceptor, + }, + snowflake::SnowflakeIdGenerator, }; use crate::{ @@ -63,11 +68,11 @@ use crate::{ dane::{DnssecResolver, Tlsa}, mta_sts, }, - queue::{self, DomainPart, QueueId, QuotaLimiter}, + queue::{self, DomainPart, QueueId}, reporting, }; -use self::throttle::{Limiter, ThrottleKey, ThrottleKeyHasherBuilder}; +use self::throttle::{ThrottleKey, ThrottleKeyHasherBuilder}; pub mod eval; pub mod management; @@ -121,6 +126,7 @@ pub struct Shared { // Default store and directory pub default_directory: Arc<Directory>, pub default_data_store: Store, + pub default_blob_store: BlobStore, pub default_lookup_store: LookupStore, } @@ -145,15 +151,14 @@ pub struct DnsCache { pub struct SessionCore { pub config: SessionConfig, - pub throttle: DashMap<ThrottleKey, Limiter, ThrottleKeyHasherBuilder>, + pub throttle: DashMap<ThrottleKey, ConcurrencyLimiter, ThrottleKeyHasherBuilder>, } pub struct QueueCore { pub config: QueueConfig, - pub throttle: DashMap<ThrottleKey, Limiter, ThrottleKeyHasherBuilder>, - pub quota: DashMap<ThrottleKey, Arc<QuotaLimiter>, ThrottleKeyHasherBuilder>, + pub throttle: DashMap<ThrottleKey, ConcurrencyLimiter, ThrottleKeyHasherBuilder>, pub tx: mpsc::Sender<queue::Event>, - pub id_seq: AtomicU32, + pub snowflake_id: SnowflakeIdGenerator, pub connectors: TlsConnectors, } diff --git a/crates/smtp/src/core/throttle.rs b/crates/smtp/src/core/throttle.rs index 6cb3aa03..f2a336dd 100644 --- a/crates/smtp/src/core/throttle.rs +++ b/crates/smtp/src/core/throttle.rs @@ -21,7 +21,7 @@ * for more details. */ -use ::utils::listener::limiter::{ConcurrencyLimiter, RateLimiter}; +use ::utils::listener::limiter::ConcurrencyLimiter; use dashmap::mapref::entry::Entry; use tokio::io::{AsyncRead, AsyncWrite}; use utils::config::Rate; @@ -32,12 +32,6 @@ use crate::config::*; use super::{eval::*, ResolveVariable, Session}; -#[derive(Debug)] -pub struct Limiter { - pub rate: Option<RateLimiter>, - pub concurrency: Option<ConcurrencyLimiter>, -} - #[derive(Debug, Clone, Eq)] pub struct ThrottleKey { hash: [u8; 32], @@ -55,6 +49,12 @@ impl Hash for ThrottleKey { } } +impl AsRef<[u8]> for ThrottleKey { + fn as_ref(&self) -> &[u8] { + &self.hash + } +} + #[derive(Default)] pub struct ThrottleKeyHasher { hash: u64, @@ -236,10 +236,36 @@ impl<T: AsyncRead + AsyncWrite> Session<T> { } // Build throttle key - match self.core.session.throttle.entry(t.new_key(self)) { - Entry::Occupied(mut e) => { - let limiter = e.get_mut(); - if let Some(limiter) = &limiter.concurrency { + let key = t.new_key(self); + + // Check rate + if let Some(rate) = &t.rate { + if self + .core + .shared + .default_lookup_store + .is_rate_allowed(key.hash.as_slice(), rate, false) + .await + .unwrap_or_default() + .is_some() + { + tracing::debug!( + parent: &self.span, + context = "throttle", + event = "rate-limit-exceeded", + max_requests = rate.requests, + max_interval = rate.period.as_secs(), + "Rate limit exceeded." + ); + return false; + } + } + + // Check concurrency + if let Some(concurrency) = &t.concurrency { + match self.core.session.throttle.entry(key) { + Entry::Occupied(mut e) => { + let limiter = e.get_mut(); if let Some(inflight) = limiter.is_allowed() { self.in_flight.push(inflight); } else { @@ -253,35 +279,13 @@ impl<T: AsyncRead + AsyncWrite> Session<T> { return false; } } - if let (Some(limiter), Some(rate)) = (&mut limiter.rate, &t.rate) { - if !limiter.is_allowed(rate) { - tracing::debug!( - parent: &self.span, - context = "throttle", - event = "rate-limit-exceeded", - max_requests = rate.requests, - max_interval = rate.period.as_secs(), - "Rate limit exceeded." - ); - return false; - } - } - } - Entry::Vacant(e) => { - let concurrency = t.concurrency.map(|concurrency| { - let limiter = ConcurrencyLimiter::new(concurrency); + Entry::Vacant(e) => { + let limiter = ConcurrencyLimiter::new(*concurrency); if let Some(inflight) = limiter.is_allowed() { self.in_flight.push(inflight); } - limiter - }); - let rate = t.rate.as_ref().map(|rate| { - let r = RateLimiter::new(rate); - r.is_allowed(rate); - r - }); - - e.insert(Limiter { rate, concurrency }); + e.insert(limiter); + } } } } @@ -290,33 +294,19 @@ impl<T: AsyncRead + AsyncWrite> Session<T> { true } - pub fn throttle_rcpt(&self, rcpt: &str, rate: &Rate, ctx: &str) -> bool { + pub async fn throttle_rcpt(&self, rcpt: &str, rate: &Rate, ctx: &str) -> bool { let mut hasher = blake3::Hasher::new(); hasher.update(rcpt.as_bytes()); hasher.update(ctx.as_bytes()); hasher.update(&rate.period.as_secs().to_ne_bytes()[..]); hasher.update(&rate.requests.to_ne_bytes()[..]); - let key = ThrottleKey { - hash: hasher.finalize().into(), - }; - match self.core.session.throttle.entry(key) { - Entry::Occupied(mut e) => { - if let Some(limiter) = &mut e.get_mut().rate { - limiter.is_allowed(rate) - } else { - false - } - } - Entry::Vacant(e) => { - let limiter = RateLimiter::new(rate); - limiter.is_allowed(rate); - e.insert(Limiter { - rate: limiter.into(), - concurrency: None, - }); - true - } - } + self.core + .shared + .default_lookup_store + .is_rate_allowed(hasher.finalize().as_bytes(), rate, false) + .await + .unwrap_or_default() + .is_none() } } diff --git a/crates/smtp/src/core/worker.rs b/crates/smtp/src/core/worker.rs index 4dafc1e5..56a030f1 100644 --- a/crates/smtp/src/core/worker.rs +++ b/crates/smtp/src/core/worker.rs @@ -54,16 +54,8 @@ impl SMTP { fn cleanup(&self) { for throttle in [&self.session.throttle, &self.queue.throttle] { - throttle.retain(|_, v| { - v.concurrency - .as_ref() - .map_or(false, |c| c.concurrent.load(Ordering::Relaxed) > 0) - || v.rate.as_ref().map_or(false, |r| r.is_active()) - }); + throttle.retain(|_, v| v.concurrent.load(Ordering::Relaxed) > 0); } - self.queue.quota.retain(|_, v| { - v.messages.load(Ordering::Relaxed) > 0 || v.size.load(Ordering::Relaxed) > 0 - }); } } diff --git a/crates/smtp/src/inbound/data.rs b/crates/smtp/src/inbound/data.rs index 72e8f8ca..7a363bca 100644 --- a/crates/smtp/src/inbound/data.rs +++ b/crates/smtp/src/inbound/data.rs @@ -23,10 +23,9 @@ use std::{ borrow::Cow, - path::PathBuf, process::Stdio, sync::Arc, - time::{Duration, Instant, SystemTime}, + time::{Duration, SystemTime}, }; use mail_auth::{ @@ -38,6 +37,7 @@ use sieve::runtime::Variable; use smtp_proto::{ MAIL_BY_RETURN, RCPT_NOTIFY_DELAY, RCPT_NOTIFY_FAILURE, RCPT_NOTIFY_NEVER, RCPT_NOTIFY_SUCCESS, }; +use store::write::now; use tokio::{io::AsyncWriteExt, process::Command}; use utils::{config::Rate, listener::SessionStream}; @@ -654,10 +654,8 @@ impl<T: SessionStream> Session<T> { // Verify queue quota if self.core.has_quota(&mut message).await { let queue_id = message.id; - if self - .core - .queue - .queue_message(message, Some(&headers), &raw_message, &self.span) + if message + .queue(Some(&headers), &raw_message, &self.core, &self.span) .await { self.state = State::Accepted(queue_id); @@ -682,14 +680,14 @@ impl<T: SessionStream> Session<T> { &self, mail_from: SessionAddress, mut rcpt_to: Vec<SessionAddress>, - ) -> Box<Message> { + ) -> Message { // Build message - let mut message = Box::new(Message { - id: self.core.queue.queue_id(), - path: PathBuf::new(), - created: SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .map_or(0, |d| d.as_secs()), + let created = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map_or(0, |d| d.as_secs()); + let mut message = Message { + id: self.core.queue.snowflake_id.generate().unwrap_or(created), + created, return_path: mail_from.address, return_path_lcase: mail_from.address_lcase, return_path_domain: mail_from.domain, @@ -699,8 +697,9 @@ impl<T: SessionStream> Session<T> { priority: self.data.priority, size: 0, env_id: mail_from.dsn_info, - queue_refs: Vec::with_capacity(0), - }); + blob_hash: Default::default(), + quota_keys: Vec::new(), + }; // Add recipients let future_release = Duration::from_secs(self.data.future_release); @@ -711,7 +710,7 @@ impl<T: SessionStream> Session<T> { .last() .map_or(true, |d| d.domain != rcpt.domain) { - let envelope = SimpleEnvelope::new(message.as_ref(), &rcpt.domain); + let envelope = SimpleEnvelope::new(&message, &rcpt.domain); // Set next retry time let retry = if self.data.future_release == 0 { @@ -731,18 +730,19 @@ impl<T: SessionStream> Session<T> { let (notify, expires) = if self.data.delivery_by == 0 { ( queue::Schedule::later(future_release + next_notify), - Instant::now() - + future_release + now() + + future_release.as_secs() + self .core .eval_if(&config.expire, &envelope) .await - .unwrap_or_else(|| Duration::from_secs(5 * 86400)), + .unwrap_or_else(|| Duration::from_secs(5 * 86400)) + .as_secs(), ) } else if (message.flags & MAIL_BY_RETURN) != 0 { ( queue::Schedule::later(future_release + next_notify), - Instant::now() + Duration::from_secs(self.data.delivery_by as u64), + now() + self.data.delivery_by as u64, ) } else { let expire = self @@ -769,7 +769,7 @@ impl<T: SessionStream> Session<T> { let mut notify = queue::Schedule::later(future_release + notify); notify.inner = (num_intervals - 1) as u32; // Disable further notification attempts - (notify, Instant::now() + expire) + (notify, now() + expire_secs) }; message.domains.push(queue::Domain { @@ -779,7 +779,6 @@ impl<T: SessionStream> Session<T> { status: queue::Status::Scheduled, domain: rcpt.domain, disable_tls: false, - changed: false, }); } diff --git a/crates/smtp/src/lib.rs b/crates/smtp/src/lib.rs index 922d8347..bdb63e51 100644 --- a/crates/smtp/src/lib.rs +++ b/crates/smtp/src/lib.rs @@ -42,6 +42,7 @@ use store::Stores; use tokio::sync::mpsc; use utils::{ config::{Config, ServerProtocol, Servers}, + snowflake::SnowflakeIdGenerator, UnwrapFailure, }; @@ -129,15 +130,10 @@ impl SMTP { .unwrap_or(32) .next_power_of_two() as usize, ), - id_seq: 0.into(), - quota: DashMap::with_capacity_and_hasher_and_shard_amount( - config.property("global.shared-map.capacity")?.unwrap_or(2), - ThrottleKeyHasherBuilder::default(), - config - .property::<u64>("global.shared-map.shard")? - .unwrap_or(32) - .next_power_of_two() as usize, - ), + snowflake_id: config + .property::<u64>("storage.cluster.node-id")? + .map(SnowflakeIdGenerator::with_node_id) + .unwrap_or_else(SnowflakeIdGenerator::new), tx: queue_tx, connectors: TlsConnectors { pki_verify: build_tls_connector(false), @@ -156,10 +152,10 @@ impl SMTP { }); // Spawn queue manager - queue_rx.spawn(core.clone(), core.queue.read_queue().await); + queue_rx.spawn(core.clone()); // Spawn report manager - report_rx.spawn(core.clone(), core.report.read_reports().await); + report_rx.spawn(core.clone()); Ok(core) } diff --git a/crates/smtp/src/outbound/delivery.rs b/crates/smtp/src/outbound/delivery.rs index 14da36b2..c7496a0f 100644 --- a/crates/smtp/src/outbound/delivery.rs +++ b/crates/smtp/src/outbound/delivery.rs @@ -24,7 +24,7 @@ use std::{ net::{IpAddr, Ipv4Addr, SocketAddr}, sync::Arc, - time::{Duration, Instant}, + time::Duration, }; use mail_auth::{ @@ -33,6 +33,7 @@ use mail_auth::{ }; use mail_send::SmtpClient; use smtp_proto::MAIL_REQUIRETLS; +use store::write::now; use utils::config::ServerProtocol; use crate::{ @@ -49,71 +50,80 @@ use super::{ NextHop, }; use crate::queue::{ - manager::Queue, throttle, DeliveryAttempt, Domain, Error, Event, OnHold, QueueEnvelope, - Schedule, Status, WorkerResult, + throttle, DeliveryAttempt, Domain, Error, Event, OnHold, QueueEnvelope, Status, }; impl DeliveryAttempt { - pub async fn try_deliver(mut self, core: Arc<SMTP>, queue: &mut Queue) { - // Check that the message still has recipients to be delivered - let has_pending_delivery = self.has_pending_delivery(); + pub async fn try_deliver(mut self, core: Arc<SMTP>) { + tokio::spawn(async move { + // Check that the message still has recipients to be delivered + let has_pending_delivery = self.has_pending_delivery(); - // Send any due Delivery Status Notifications - core.send_dsn(&mut self).await; + // Send any due Delivery Status Notifications + core.send_dsn(&mut self).await; - if has_pending_delivery { - // Re-queue the message if its not yet due for delivery - let due = self.message.next_delivery_event(); - if due > Instant::now() { - // Save changes to disk - self.message.save_changes().await; + if has_pending_delivery { + // Re-queue the message if its not yet due for delivery + let due = self.message.next_delivery_event(); + if due > now() { + // Save changes + self.message + .save_changes(&core, self.event.due.into(), due.into()) + .await; + if core.queue.tx.send(Event::Reload).await.is_err() { + tracing::warn!("Channel closed while trying to notify queue manager."); + } + return; + } + } else { + // All message recipients expired, do not re-queue. (DSN has been already sent) + self.message.remove(&core, self.event.due).await; + if core.queue.tx.send(Event::Reload).await.is_err() { + tracing::warn!("Channel closed while trying to notify queue manager."); + } - queue.schedule(Schedule { - due, - inner: self.message, - }); return; } - } else { - // All message recipients expired, do not re-queue. (DSN has been already sent) - self.message.remove().await; - return; - } - // Throttle sender - for throttle in &core.queue.config.throttle.sender { - if let Err(err) = core - .is_allowed( - throttle, - self.message.as_ref(), - &mut self.in_flight, - &self.span, - ) - .await - { - // Save changes to disk - self.message.save_changes().await; - - match err { - throttle::Error::Concurrency { limiter } => { - queue.on_hold(OnHold { - next_due: self.message.next_event_after(Instant::now()), - limiters: vec![limiter], - message: self.message, - }); - } - throttle::Error::Rate { retry_at } => { - queue.schedule(Schedule { - due: retry_at, - inner: self.message, - }); + // Throttle sender + for throttle in &core.queue.config.throttle.sender { + if let Err(err) = core + .is_allowed(throttle, &self.message, &mut self.in_flight, &self.span) + .await + { + let event = match err { + throttle::Error::Concurrency { limiter } => { + // Save changes to disk + let next_due = self.message.next_event_after(now()); + self.message.save_changes(&core, None, None).await; + + Event::OnHold(OnHold { + next_due, + limiters: vec![limiter], + message: self.event, + }) + } + throttle::Error::Rate { retry_at } => { + // Save changes to disk + let next_event = std::cmp::min( + retry_at, + self.message.next_event_after(now()).unwrap_or(u64::MAX), + ); + self.message + .save_changes(&core, self.event.due.into(), next_event.into()) + .await; + + Event::Reload + } + }; + + if core.queue.tx.send(event).await.is_err() { + tracing::warn!("Channel closed while trying to notify queue manager."); } + return; } - return; } - } - tokio::spawn(async move { let queue_config = &core.queue.config; let mut on_hold = Vec::new(); let no_ip = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)); @@ -123,7 +133,7 @@ impl DeliveryAttempt { 'next_domain: for (domain_idx, domain) in domains.iter_mut().enumerate() { // Only process domains due for delivery if !matches!(&domain.status, Status::Scheduled | Status::TemporaryFailure(_) - if domain.retry.due <= Instant::now()) + if domain.retry.due <= now()) { continue; } @@ -138,7 +148,7 @@ impl DeliveryAttempt { // Build envelope let mut envelope = QueueEnvelope { - message: self.message.as_ref(), + message: &self.message, domain: &domain.domain, mx: "", remote_ip: no_ip, @@ -672,6 +682,7 @@ impl DeliveryAttempt { .unwrap_or_else(|| "localhost".to_string()); let params = SessionParams { span: &span, + core: &core, credentials: remote_host.credentials(), is_smtp: remote_host.is_smtp(), hostname: envelope.mx, @@ -1018,11 +1029,9 @@ impl DeliveryAttempt { // Notify queue manager let span = self.span; let result = if !on_hold.is_empty() { - // Release quota for completed deliveries - self.message.release_quota(); - // Save changes to disk - self.message.save_changes().await; + let next_due = self.message.next_event_after(now()); + self.message.save_changes(&core, None, None).await; tracing::info!( parent: &span, @@ -1032,17 +1041,16 @@ impl DeliveryAttempt { "Too many outbound concurrent connections, message moved to on-hold queue." ); - WorkerResult::OnHold(OnHold { - next_due: self.message.next_event_after(Instant::now()), + Event::OnHold(OnHold { + next_due, limiters: on_hold, - message: self.message, + message: self.event, }) } else if let Some(due) = self.message.next_event() { - // Release quota for completed deliveries - self.message.release_quota(); - // Save changes to disk - self.message.save_changes().await; + self.message + .save_changes(&core, self.event.due.into(), due.into()) + .await; tracing::info!( parent: &span, @@ -1052,13 +1060,10 @@ impl DeliveryAttempt { "Delivery was not possible, message re-queued for delivery." ); - WorkerResult::Retry(Schedule { - due, - inner: self.message, - }) + Event::Reload } else { // Delete message from queue - self.message.remove().await; + self.message.remove(&core, self.event.due).await; tracing::info!( parent: &span, @@ -1067,9 +1072,9 @@ impl DeliveryAttempt { "Delivery completed." ); - WorkerResult::Done + Event::Reload }; - if core.queue.tx.send(Event::Done(result)).await.is_err() { + if core.queue.tx.send(result).await.is_err() { tracing::warn!( parent: &span, "Channel closed while trying to notify queue manager." @@ -1080,7 +1085,7 @@ impl DeliveryAttempt { /// Marks as failed all domains that reached their expiration time pub fn has_pending_delivery(&mut self) -> bool { - let now = Instant::now(); + let now = now(); let mut has_pending_delivery = false; let span = self.span.clone(); @@ -1103,7 +1108,6 @@ impl DeliveryAttempt { domain.status = std::mem::replace(&mut domain.status, Status::Scheduled).into_permanent(); - domain.changed = true; } Status::Scheduled if domain.expires <= now => { tracing::info!( @@ -1123,7 +1127,6 @@ impl DeliveryAttempt { domain.status = Status::PermanentFailure(Error::Io( "Queue rate limit exceeded.".to_string(), )); - domain.changed = true; } Status::Completed(_) | Status::PermanentFailure(_) => (), _ => { @@ -1139,7 +1142,6 @@ impl DeliveryAttempt { impl Domain { pub fn set_status(&mut self, status: impl Into<Status<(), Error>>, schedule: &[Duration]) { self.status = status.into(); - self.changed = true; if matches!( &self.status, Status::TemporaryFailure(_) | Status::Scheduled @@ -1149,8 +1151,8 @@ impl Domain { } pub fn retry(&mut self, schedule: &[Duration]) { - self.retry.due = - Instant::now() + schedule[std::cmp::min(self.retry.inner as usize, schedule.len() - 1)]; + self.retry.due = now() + + schedule[std::cmp::min(self.retry.inner as usize, schedule.len() - 1)].as_secs(); self.retry.inner += 1; } } diff --git a/crates/smtp/src/outbound/local.rs b/crates/smtp/src/outbound/local.rs index 17bf09c2..a8962deb 100644 --- a/crates/smtp/src/outbound/local.rs +++ b/crates/smtp/src/outbound/local.rs @@ -63,7 +63,7 @@ impl Message { message: IngestMessage { sender_address: self.return_path_lcase.clone(), recipients: recipient_addresses, - message_path: self.path.clone(), + message_blob: self.blob_hash.clone(), message_size: self.size, }, result_tx, diff --git a/crates/smtp/src/outbound/mod.rs b/crates/smtp/src/outbound/mod.rs index d3164015..b5f3398c 100644 --- a/crates/smtp/src/outbound/mod.rs +++ b/crates/smtp/src/outbound/mod.rs @@ -25,6 +25,7 @@ use std::borrow::Cow; use mail_send::Credentials; use smtp_proto::{Response, Severity}; +use store::write::QueueEvent; use utils::config::ServerProtocol; use crate::{ @@ -211,8 +212,8 @@ impl From<mta_sts::Error> for Status<(), Error> { } } -impl From<Box<Message>> for DeliveryAttempt { - fn from(message: Box<Message>) -> Self { +impl DeliveryAttempt { + pub fn new(message: Message, event: QueueEvent) -> Self { DeliveryAttempt { span: tracing::info_span!( "delivery", @@ -227,6 +228,7 @@ impl From<Box<Message>> for DeliveryAttempt { ), in_flight: Vec::new(), message, + event, } } } diff --git a/crates/smtp/src/outbound/session.rs b/crates/smtp/src/outbound/session.rs index eb8c4420..74c55418 100644 --- a/crates/smtp/src/outbound/session.rs +++ b/crates/smtp/src/outbound/session.rs @@ -30,7 +30,6 @@ use smtp_proto::{ use std::fmt::Write; use std::time::Duration; use tokio::{ - fs, io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, net::TcpStream, }; @@ -38,6 +37,7 @@ use tokio_rustls::{client::TlsStream, TlsConnector}; use crate::{ config::{RequireOptional, TlsStrategy}, + core::SMTP, queue::{ErrorDetails, HostResponse, RCPT_STATUS_CHANGED}, }; @@ -45,6 +45,7 @@ use crate::queue::{Error, Message, Recipient, Status}; pub struct SessionParams<'x> { pub span: &'x tracing::Span, + pub core: &'x SMTP, pub hostname: &'x str, pub credentials: Option<&'x Credentials<String>>, pub is_smtp: bool, @@ -532,43 +533,53 @@ pub async fn send_message<T: AsyncRead + AsyncWrite + Unpin>( bdat_cmd: &Option<String>, params: &SessionParams<'_>, ) -> Result<(), Status<(), Error>> { - let mut raw_message = vec![0u8; message.size]; - let mut file = fs::File::open(&message.path).await.map_err(|err| { - tracing::error!(parent: params.span, - context = "queue", - event = "error", - "Failed to open message file {}: {}", - message.path.display(), - err); - Status::TemporaryFailure(Error::Io("Queue system error.".to_string())) - })?; - file.read_exact(&mut raw_message).await.map_err(|err| { - tracing::error!(parent: params.span, - context = "queue", - event = "error", - "Failed to read {} bytes file {} from disk: {}", - message.size, - message.path.display(), - err); - Status::TemporaryFailure(Error::Io("Queue system error.".to_string())) - })?; - tokio::time::timeout(params.timeout_data, async { - if let Some(bdat_cmd) = bdat_cmd { - write_chunks(smtp_client, &[bdat_cmd.as_bytes(), &raw_message]).await - } else { - write_chunks(smtp_client, &[b"DATA\r\n"]).await?; - smtp_client.read().await?.assert_code(354)?; - smtp_client - .write_message(&raw_message) - .await - .map_err(mail_send::Error::from) + match params + .core + .shared + .default_blob_store + .get_blob(message.blob_hash.as_slice(), 0..u32::MAX) + .await + { + Ok(Some(raw_message)) => tokio::time::timeout(params.timeout_data, async { + if let Some(bdat_cmd) = bdat_cmd { + write_chunks(smtp_client, &[bdat_cmd.as_bytes(), &raw_message]).await + } else { + write_chunks(smtp_client, &[b"DATA\r\n"]).await?; + smtp_client.read().await?.assert_code(354)?; + smtp_client + .write_message(&raw_message) + .await + .map_err(mail_send::Error::from) + } + }) + .await + .map_err(|_| Status::timeout(params.hostname, "sending message"))? + .map_err(|err| { + Status::from_smtp_error(params.hostname, bdat_cmd.as_deref().unwrap_or("DATA"), err) + }), + Ok(None) => { + tracing::error!(parent: params.span, + context = "queue", + event = "error", + "BlobHash {:?} does not exist.", + message.blob_hash, + ); + Err(Status::TemporaryFailure(Error::Io( + "Queue system error.".to_string(), + ))) } - }) - .await - .map_err(|_| Status::timeout(params.hostname, "sending message"))? - .map_err(|err| { - Status::from_smtp_error(params.hostname, bdat_cmd.as_deref().unwrap_or("DATA"), err) - }) + Err(err) => { + tracing::error!(parent: params.span, + context = "queue", + event = "error", + "Failed to fetch blobId {:?}: {}", + message.blob_hash, + err); + Err(Status::TemporaryFailure(Error::Io( + "Queue system error.".to_string(), + ))) + } + } } pub async fn say_helo<T: AsyncRead + AsyncWrite + Unpin>( diff --git a/crates/smtp/src/queue/dsn.rs b/crates/smtp/src/queue/dsn.rs index 4027edb3..aef875ac 100644 --- a/crates/smtp/src/queue/dsn.rs +++ b/crates/smtp/src/queue/dsn.rs @@ -30,22 +30,21 @@ use smtp_proto::{ Response, RCPT_NOTIFY_DELAY, RCPT_NOTIFY_FAILURE, RCPT_NOTIFY_NEVER, RCPT_NOTIFY_SUCCESS, }; use std::fmt::Write; -use std::time::{Duration, Instant}; -use tokio::fs::File; -use tokio::io::AsyncReadExt; +use std::time::Duration; +use store::write::now; use crate::core::SMTP; use super::{ - instant_to_timestamp, DeliveryAttempt, Domain, Error, ErrorDetails, HostResponse, Message, - Recipient, SimpleEnvelope, Status, RCPT_DSN_SENT, RCPT_STATUS_CHANGED, + DeliveryAttempt, Domain, Error, ErrorDetails, HostResponse, Message, Recipient, SimpleEnvelope, + Status, RCPT_DSN_SENT, RCPT_STATUS_CHANGED, }; impl SMTP { pub async fn send_dsn(&self, attempt: &mut DeliveryAttempt) { if !attempt.message.return_path.is_empty() { if let Some(dsn) = attempt.build_dsn(self).await { - let mut dsn_message = Message::new_boxed("", "", ""); + let mut dsn_message = self.queue.new_message("", "", ""); dsn_message .add_recipient_parts( &attempt.message.return_path, @@ -64,8 +63,8 @@ impl SMTP { &attempt.span, ) .await; - self.queue - .queue_message(dsn_message, signature.as_deref(), &dsn, &attempt.span) + dsn_message + .queue(signature.as_deref(), &dsn, self, &attempt.span) .await; } } else { @@ -77,7 +76,7 @@ impl SMTP { impl DeliveryAttempt { pub async fn build_dsn(&mut self, core: &SMTP) -> Option<Vec<u8>> { let config = &core.queue.config; - let now = Instant::now(); + let now = now(); let mut txt_success = String::new(); let mut txt_delay = String::new(); @@ -245,11 +244,10 @@ impl DeliveryAttempt { }) { domain.notify.inner += 1; - domain.notify.due = Instant::now() + next_notify; + domain.notify.due = now + next_notify.as_secs(); } else { - domain.notify.due = domain.expires + Duration::from_secs(10); + domain.notify.due = domain.expires + 10; } - domain.changed = true; } } self.message.domains = domains; @@ -257,15 +255,15 @@ impl DeliveryAttempt { // Obtain hostname and sender addresses let from_name = core - .eval_if(&config.dsn.name, self.message.as_ref()) + .eval_if(&config.dsn.name, &self.message) .await .unwrap_or_else(|| String::from("Mail Delivery Subsystem")); let from_addr = core - .eval_if(&config.dsn.address, self.message.as_ref()) + .eval_if(&config.dsn.address, &self.message) .await .unwrap_or_else(|| String::from("MAILER-DAEMON@localhost")); let reporting_mta = core - .eval_if(&config.hostname, self.message.as_ref()) + .eval_if(&config.hostname, &self.message) .await .unwrap_or_else(|| String::from("localhost")); @@ -276,55 +274,54 @@ impl DeliveryAttempt { let dsn = dsn_header + &dsn; // Fetch up to 1024 bytes of message headers - let headers = match File::open(&self.message.path).await { - Ok(mut file) => { - let mut buf = vec![0u8; std::cmp::min(self.message.size, 1024)]; - match file.read(&mut buf).await { - Ok(br) => { - let mut prev_ch = 0; - let mut last_lf = br; - for (pos, &ch) in buf.iter().enumerate() { - match ch { - b'\n' => { - last_lf = pos + 1; - if prev_ch != b'\n' { - prev_ch = ch; - } else { - break; - } - } - b'\r' => (), - 0 => break, - _ => { - prev_ch = ch; - } + let headers = match core + .shared + .default_blob_store + .get_blob(self.message.blob_hash.as_slice(), 0..1024) + .await + { + Ok(Some(mut buf)) => { + let mut prev_ch = 0; + let mut last_lf = buf.len(); + for (pos, &ch) in buf.iter().enumerate() { + match ch { + b'\n' => { + last_lf = pos + 1; + if prev_ch != b'\n' { + prev_ch = ch; + } else { + break; } } - if last_lf < 1024 { - buf.truncate(last_lf); + b'\r' => (), + 0 => break, + _ => { + prev_ch = ch; } - String::from_utf8(buf).unwrap_or_default() - } - Err(err) => { - tracing::error!( - parent: &self.span, - context = "queue", - event = "error", - "Failed to read from {}: {}", - self.message.path.display(), - err - ); - String::new() } } + if last_lf < 1024 { + buf.truncate(last_lf); + } + String::from_utf8(buf).unwrap_or_default() + } + Ok(None) => { + tracing::error!( + parent: &self.span, + context = "queue", + event = "error", + "Failed to open blob {:?}: not found", + self.message.blob_hash + ); + String::new() } Err(err) => { tracing::error!( parent: &self.span, context = "queue", event = "error", - "Failed to open file {}: {}", - self.message.path.display(), + "Failed to open blob {:?}: {}", + self.message.blob_hash, err ); String::new() @@ -387,10 +384,10 @@ impl DeliveryAttempt { } } - let now = Instant::now(); + let now = now(); for domain in &mut message.domains { if domain.notify.due <= now { - domain.notify.due = domain.expires + Duration::from_secs(10); + domain.notify.due = domain.expires + 10; } } @@ -520,13 +517,10 @@ impl Recipient { impl Domain { fn write_dsn_will_retry_until(&self, dsn: &mut String) { - let now = Instant::now(); + let now = now(); if self.expires > now { dsn.push_str("Will-Retry-Until: "); - dsn.push_str( - &DateTime::from_timestamp(instant_to_timestamp(now, self.expires) as i64) - .to_rfc822(), - ); + dsn.push_str(&DateTime::from_timestamp(self.expires as i64).to_rfc822()); dsn.push_str("\r\n"); } } diff --git a/crates/smtp/src/queue/manager.rs b/crates/smtp/src/queue/manager.rs index d44eaa35..4dfdc9a5 100644 --- a/crates/smtp/src/queue/manager.rs +++ b/crates/smtp/src/queue/manager.rs @@ -22,274 +22,84 @@ */ use std::{ - collections::BinaryHeap, sync::{atomic::Ordering, Arc}, - time::{Duration, Instant}, + time::Duration, }; -use ahash::AHashMap; -use smtp_proto::Response; +use store::write::{now, BatchBuilder, QueueClass, QueueEvent, ValueClass}; use tokio::sync::mpsc; -use crate::core::{ - management::{self}, - QueueCore, SMTP, -}; +use crate::core::SMTP; -use super::{ - DeliveryAttempt, Event, HostResponse, Message, OnHold, QueueId, Schedule, Status, WorkerResult, - RCPT_STATUS_CHANGED, -}; +use super::{DeliveryAttempt, Event, Message, OnHold, Status}; + +pub(crate) const SHORT_WAIT: Duration = Duration::from_millis(1); +pub(crate) const LONG_WAIT: Duration = Duration::from_secs(86400 * 365); #[derive(Debug)] pub struct Queue { - short_wait: Duration, - long_wait: Duration, - pub scheduled: BinaryHeap<Schedule<QueueId>>, - pub on_hold: Vec<OnHold<QueueId>>, - pub messages: AHashMap<QueueId, Box<Message>>, + pub on_hold: Vec<OnHold<QueueEvent>>, } impl SpawnQueue for mpsc::Receiver<Event> { - fn spawn(mut self, core: Arc<SMTP>, mut queue: Queue) { + fn spawn(mut self, core: Arc<SMTP>) { tokio::spawn(async move { + let mut queue = Queue::default(); + let mut next_wake_up = SHORT_WAIT; + loop { - let result = tokio::time::timeout(queue.wake_up_time(), self.recv()).await; + let on_hold = match tokio::time::timeout(next_wake_up, self.recv()).await { + Ok(Some(Event::OnHold(on_hold))) => on_hold.into(), + Ok(Some(Event::Stop)) | Ok(None) => { + break; + } + _ => None, + }; + + // Deliver any concurrency limited messages + let mut delete_events = Vec::new(); + while let Some(queue_event) = queue.next_on_hold() { + if let Some(message) = core.read_message(queue_event.queue_id).await { + DeliveryAttempt::new(message, queue_event) + .try_deliver(core.clone()) + .await; + } else { + delete_events.push(queue_event); + } + } // Deliver scheduled messages - while let Some(message) = queue.next_due() { - DeliveryAttempt::from(message) - .try_deliver(core.clone(), &mut queue) - .await; + let now = now(); + next_wake_up = LONG_WAIT; + for queue_event in core.next_event().await { + if queue_event.due <= now { + if let Some(message) = core.read_message(queue_event.queue_id).await { + DeliveryAttempt::new(message, queue_event) + .try_deliver(core.clone()) + .await; + } else { + delete_events.push(queue_event); + } + } else { + next_wake_up = Duration::from_secs(queue_event.due - now); + } } - match result { - Ok(Some(event)) => match event { - Event::Queue(item) => { - // Deliver any concurrency limited messages - while let Some(message) = queue.next_on_hold() { - DeliveryAttempt::from(message) - .try_deliver(core.clone(), &mut queue) - .await; - } - - if item.due <= Instant::now() { - DeliveryAttempt::from(item.inner) - .try_deliver(core.clone(), &mut queue) - .await; - } else { - queue.schedule(item); - } + // Delete unlinked events + if !delete_events.is_empty() { + let core = core.clone(); + tokio::spawn(async move { + let mut batch = BatchBuilder::new(); + for queue_event in delete_events { + batch.clear(ValueClass::Queue(QueueClass::MessageEvent(queue_event))); } - Event::Done(result) => { - // A worker is done, try delivering concurrency limited messages - while let Some(message) = queue.next_on_hold() { - DeliveryAttempt::from(message) - .try_deliver(core.clone(), &mut queue) - .await; - } - match result { - WorkerResult::Done => (), - WorkerResult::Retry(schedule) => { - queue.schedule(schedule); - } - WorkerResult::OnHold(on_hold) => { - queue.on_hold(on_hold); - } - } - } - Event::Manage(request) => match request { - management::QueueRequest::List { - from, - to, - before, - after, - result_tx, - } => { - let mut result = Vec::with_capacity(queue.messages.len()); - for message in queue.messages.values() { - if from.as_ref().map_or(false, |from| { - !message.return_path_lcase.contains(from) - }) { - continue; - } - if to.as_ref().map_or(false, |to| { - !message - .recipients - .iter() - .any(|rcpt| rcpt.address_lcase.contains(to)) - }) { - continue; - } - - if (before.is_some() || after.is_some()) - && !message.domains.iter().any(|domain| { - matches!( - &domain.status, - Status::Scheduled | Status::TemporaryFailure(_) - ) && match (&before, &after) { - (Some(before), Some(after)) => { - domain.retry.due.lt(before) - && domain.retry.due.gt(after) - } - (Some(before), None) => domain.retry.due.lt(before), - (None, Some(after)) => domain.retry.due.gt(after), - (None, None) => false, - } - }) - { - continue; - } - - result.push(message.id); - } - result.sort_unstable_by_key(|id| *id & 0xFFFFFFFF); - let _ = result_tx.send(result); - } - management::QueueRequest::Status { - queue_ids, - result_tx, - } => { - let mut result = Vec::with_capacity(queue_ids.len()); - for queue_id in queue_ids { - result.push( - queue - .messages - .get(&queue_id) - .map(|message| message.as_ref().into()), - ); - } - let _ = result_tx.send(result); - } - management::QueueRequest::Cancel { - queue_ids, - item, - result_tx, - } => { - let mut result = Vec::with_capacity(queue_ids.len()); - for queue_id in &queue_ids { - let mut found = false; - if let Some(item) = &item { - if let Some(message) = queue.messages.get_mut(queue_id) { - // Cancel delivery for all recipients that match - for rcpt in &mut message.recipients { - if rcpt.address_lcase.contains(item) { - rcpt.flags |= RCPT_STATUS_CHANGED; - rcpt.status = Status::Completed(HostResponse { - hostname: String::new(), - response: Response { - code: 0, - esc: [0, 0, 0], - message: "Delivery canceled." - .to_string(), - }, - }); - found = true; - } - } - if found { - // Mark as completed domains without any pending deliveries - for (domain_idx, domain) in - message.domains.iter_mut().enumerate() - { - if matches!( - domain.status, - Status::TemporaryFailure(_) - | Status::Scheduled - ) { - let mut total_rcpt = 0; - let mut total_completed = 0; - - for rcpt in &message.recipients { - if rcpt.domain_idx == domain_idx { - total_rcpt += 1; - if matches!( - rcpt.status, - Status::PermanentFailure(_) - | Status::Completed(_) - ) { - total_completed += 1; - } - } - } - - if total_rcpt == total_completed { - domain.status = Status::Completed(()); - domain.changed = true; - } - } - } - - // Delete message if there are no pending deliveries - if message.domains.iter().any(|domain| { - matches!( - domain.status, - Status::TemporaryFailure(_) - | Status::Scheduled - ) - }) { - message.save_changes().await; - } else { - message.remove().await; - queue.messages.remove(queue_id); - } - } - } - } else if let Some(message) = queue.messages.remove(queue_id) { - message.remove().await; - found = true; - } - result.push(found); - } - let _ = result_tx.send(result); - } - management::QueueRequest::Retry { - queue_ids, - item, - time, - result_tx, - } => { - let mut result = Vec::with_capacity(queue_ids.len()); - for queue_id in &queue_ids { - let mut found = false; - if let Some(message) = queue.messages.get_mut(queue_id) { - for domain in &mut message.domains { - if matches!( - domain.status, - Status::Scheduled | Status::TemporaryFailure(_) - ) && item - .as_ref() - .map_or(true, |item| domain.domain.contains(item)) - { - domain.retry.due = time; - if domain.expires > time { - domain.expires = time + Duration::from_secs(10); - } - domain.changed = true; - found = true; - } - } + let _ = core.shared.default_data_store.write(batch.build()).await; + }); + } - if found { - queue.on_hold.retain(|oh| &oh.message != queue_id); - message.save_changes().await; - if let Some(next_event) = message.next_event() { - queue.scheduled.push(Schedule { - due: next_event, - inner: *queue_id, - }); - } - } - } - result.push(found); - } - let _ = result_tx.send(result); - } - }, - Event::Stop => break, - }, - Ok(None) => break, - Err(_) => (), + // Add message on hold + if let Some(on_hold) = on_hold { + queue.on_hold(on_hold); } } }); @@ -297,36 +107,16 @@ impl SpawnQueue for mpsc::Receiver<Event> { } impl Queue { - pub fn schedule(&mut self, message: Schedule<Box<Message>>) { - self.scheduled.push(Schedule { - due: message.due, - inner: message.inner.id, - }); - self.messages.insert(message.inner.id, message.inner); - } - - pub fn on_hold(&mut self, message: OnHold<Box<Message>>) { + pub fn on_hold(&mut self, message: OnHold<QueueEvent>) { self.on_hold.push(OnHold { next_due: message.next_due, limiters: message.limiters, - message: message.message.id, + message: message.message, }); - self.messages.insert(message.message.id, message.message); } - pub fn next_due(&mut self) -> Option<Box<Message>> { - let item = self.scheduled.peek()?; - if item.due <= Instant::now() { - self.scheduled - .pop() - .and_then(|i| self.messages.remove(&i.inner)) - } else { - None - } - } - - pub fn next_on_hold(&mut self) -> Option<Box<Message>> { - let now = Instant::now(); + pub fn next_on_hold(&mut self) -> Option<QueueEvent> { + let now = now(); self.on_hold .iter() .position(|o| { @@ -335,24 +125,13 @@ impl Queue { .any(|l| l.concurrent.load(Ordering::Relaxed) < l.max_concurrent) || o.next_due.map_or(false, |due| due <= now) }) - .and_then(|pos| self.messages.remove(&self.on_hold.remove(pos).message)) - } - - pub fn wake_up_time(&self) -> Duration { - self.scheduled - .peek() - .map(|item| { - item.due - .checked_duration_since(Instant::now()) - .unwrap_or(self.short_wait) - }) - .unwrap_or(self.long_wait) + .map(|pos| self.on_hold.remove(pos).message) } } impl Message { - pub fn next_event(&self) -> Option<Instant> { - let mut next_event = Instant::now(); + pub fn next_event(&self) -> Option<u64> { + let mut next_event = now(); let mut has_events = false; for domain in &self.domains { @@ -380,8 +159,8 @@ impl Message { } } - pub fn next_delivery_event(&self) -> Instant { - let mut next_delivery = Instant::now(); + pub fn next_delivery_event(&self) -> u64 { + let mut next_delivery = now(); for (pos, domain) in self .domains @@ -397,7 +176,7 @@ impl Message { next_delivery } - pub fn next_event_after(&self, instant: Instant) -> Option<Instant> { + pub fn next_event_after(&self, instant: u64) -> Option<u64> { let mut next_event = None; for domain in &self.domains { @@ -431,129 +210,14 @@ impl Message { } } -impl QueueCore { - pub async fn read_queue(&self) -> Queue { - let mut queue = Queue::default(); - let mut messages = Vec::new(); - - let mut dir = match tokio::fs::read_dir(&self.config.path).await { - Ok(dir) => dir, - Err(err) => { - tracing::warn!( - "Failed to read queue directory {}: {}", - self.config.path.display(), - err - ); - return queue; - } - }; - loop { - match dir.next_entry().await { - Ok(Some(file)) => { - let file = file.path(); - if file.is_dir() { - match tokio::fs::read_dir(&file).await { - Ok(mut dir) => { - let file_ = file; - loop { - match dir.next_entry().await { - Ok(Some(file)) => { - let file = file.path(); - if file.extension().map_or(false, |e| e == "msg") { - messages - .push(tokio::spawn(Message::from_path(file))); - } - } - Ok(None) => break, - Err(err) => { - tracing::warn!( - "Failed to read queue directory {}: {}", - file_.display(), - err - ); - break; - } - } - } - } - Err(err) => { - tracing::warn!( - "Failed to read queue directory {}: {}", - file.display(), - err - ) - } - }; - } else if file.extension().map_or(false, |e| e == "msg") { - messages.push(tokio::spawn(Message::from_path(file))); - } - } - Ok(None) => { - break; - } - Err(err) => { - tracing::warn!( - "Failed to read queue directory {}: {}", - self.config.path.display(), - err - ); - break; - } - } - } - - // Join all futures - for message in messages { - match message.await { - Ok(Ok(mut message)) => { - // Reserve quota - let todo = true; - //self.has_quota(&mut message).await; - - // Schedule message - queue.schedule(Schedule { - due: message.next_event().unwrap_or_else(|| { - tracing::warn!( - context = "queue", - event = "warn", - "No due events found for message {}", - message.path.display() - ); - Instant::now() - }), - inner: Box::new(message), - }); - } - Ok(Err(err)) => { - tracing::warn!( - context = "queue", - event = "error", - "Queue startup error: {}", - err - ); - } - Err(err) => { - tracing::error!("Join error while starting queue: {}", err); - } - } - } - - queue - } -} - impl Default for Queue { fn default() -> Self { Queue { - short_wait: Duration::from_millis(1), - long_wait: Duration::from_secs(86400 * 365), - scheduled: BinaryHeap::with_capacity(128), on_hold: Vec::with_capacity(128), - messages: AHashMap::with_capacity(128), } } } pub trait SpawnQueue { - fn spawn(self, core: Arc<SMTP>, queue: Queue); + fn spawn(self, core: Arc<SMTP>); } diff --git a/crates/smtp/src/queue/mod.rs b/crates/smtp/src/queue/mod.rs index 230e1892..eb621dc9 100644 --- a/crates/smtp/src/queue/mod.rs +++ b/crates/smtp/src/queue/mod.rs @@ -24,21 +24,22 @@ use std::{ fmt::Display, net::IpAddr, - path::PathBuf, - sync::{atomic::AtomicUsize, Arc}, time::{Duration, Instant, SystemTime}, }; use serde::{Deserialize, Serialize}; use smtp_proto::Response; -use utils::listener::limiter::{ConcurrencyLimiter, InFlight}; +use store::write::{now, QueueEvent}; +use utils::{ + listener::limiter::{ConcurrencyLimiter, InFlight}, + BlobHash, +}; -use crate::core::{eval::*, management, ResolveVariable}; +use crate::core::{eval::*, ResolveVariable}; pub mod dsn; pub mod manager; pub mod quota; -pub mod serialize; pub mod spool; pub mod throttle; @@ -46,37 +47,29 @@ pub type QueueId = u64; #[derive(Debug)] pub enum Event { - Queue(Schedule<Box<Message>>), - Manage(management::QueueRequest), - Done(WorkerResult), + Reload, + OnHold(OnHold<QueueEvent>), Stop, } #[derive(Debug)] -pub enum WorkerResult { - Done, - Retry(Schedule<Box<Message>>), - OnHold(OnHold<Box<Message>>), -} - -#[derive(Debug)] pub struct OnHold<T> { - pub next_due: Option<Instant>, + pub next_due: Option<u64>, pub limiters: Vec<ConcurrencyLimiter>, pub message: T, } -#[derive(Debug)] +#[derive(Debug, Serialize, Deserialize)] pub struct Schedule<T> { - pub due: Instant, + pub due: u64, pub inner: T, } -#[derive(Debug)] +#[derive(Debug, serde::Serialize, serde::Deserialize)] pub struct Message { pub id: QueueId, pub created: u64, - pub path: PathBuf, + pub blob_hash: BlobHash, pub return_path: String, pub return_path_lcase: String, @@ -89,21 +82,26 @@ pub struct Message { pub priority: i16, pub size: usize, - pub queue_refs: Vec<UsedQuota>, + pub quota_keys: Vec<QuotaKey>, } -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, serde::Serialize, serde::Deserialize)] +pub enum QuotaKey { + Size { key: Vec<u8>, id: u64 }, + Count { key: Vec<u8>, id: u64 }, +} + +#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct Domain { pub domain: String, pub retry: Schedule<u32>, pub notify: Schedule<u32>, - pub expires: Instant, + pub expires: u64, pub status: Status<(), Error>, pub disable_tls: bool, - pub changed: bool, } -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct Recipient { pub domain_idx: usize, pub address: String, @@ -128,13 +126,13 @@ pub enum Status<T, E> { PermanentFailure(E), } -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct HostResponse<T> { pub hostname: T, pub response: Response<String>, } -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum Error { DnsError(String), UnexpectedResponse(HostResponse<ErrorDetails>), @@ -147,7 +145,7 @@ pub enum Error { Io(String), } -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct ErrorDetails { pub entity: String, pub details: String, @@ -156,32 +154,10 @@ pub struct ErrorDetails { pub struct DeliveryAttempt { pub span: tracing::Span, pub in_flight: Vec<InFlight>, - pub message: Box<Message>, -} - -#[derive(Debug)] -pub struct QuotaLimiter { - pub max_size: usize, - pub max_messages: usize, - pub size: AtomicUsize, - pub messages: AtomicUsize, -} - -#[derive(Debug)] -pub struct UsedQuota { - id: u64, - size: usize, - limiter: Arc<QuotaLimiter>, -} - -impl PartialEq for UsedQuota { - fn eq(&self, other: &Self) -> bool { - self.id == other.id && self.size == other.size - } + pub message: Message, + pub event: QueueEvent, } -impl Eq for UsedQuota {} - impl<T> Ord for Schedule<T> { fn cmp(&self, other: &Self) -> std::cmp::Ordering { other.due.cmp(&self.due) @@ -205,14 +181,14 @@ impl<T> Eq for Schedule<T> {} impl<T: Default> Schedule<T> { pub fn now() -> Self { Schedule { - due: Instant::now(), + due: now(), inner: T::default(), } } pub fn later(duration: Duration) -> Self { Schedule { - due: Instant::now() + duration, + due: now() + duration.as_secs(), inner: T::default(), } } diff --git a/crates/smtp/src/queue/quota.rs b/crates/smtp/src/queue/quota.rs index 8b6b6412..e3fa2268 100644 --- a/crates/smtp/src/queue/quota.rs +++ b/crates/smtp/src/queue/quota.rs @@ -21,25 +21,26 @@ * for more details. */ -use std::sync::{atomic::Ordering, Arc}; - -use dashmap::mapref::entry::Entry; +use store::{ + write::{BatchBuilder, QueueClass, ValueClass}, + ValueKey, +}; use crate::{ config::QueueQuota, core::{ResolveVariable, SMTP}, }; -use super::{Message, QuotaLimiter, SimpleEnvelope, Status, UsedQuota}; +use super::{Message, QuotaKey, SimpleEnvelope, Status}; impl SMTP { pub async fn has_quota(&self, message: &mut Message) -> bool { - let mut queue_refs = Vec::new(); + let mut quota_keys = Vec::new(); if !self.queue.config.quota.sender.is_empty() { for quota in &self.queue.config.quota.sender { if !self - .reserve_quota(quota, message, message.size, 0, &mut queue_refs) + .check_quota(quota, message, message.size, 0, &mut quota_keys) .await { return false; @@ -50,12 +51,12 @@ impl SMTP { for quota in &self.queue.config.quota.rcpt_domain { for (pos, domain) in message.domains.iter().enumerate() { if !self - .reserve_quota( + .check_quota( quota, &SimpleEnvelope::new(message, &domain.domain), message.size, ((pos + 1) << 32) as u64, - &mut queue_refs, + &mut quota_keys, ) .await { @@ -67,7 +68,7 @@ impl SMTP { for quota in &self.queue.config.quota.rcpt { for (pos, rcpt) in message.recipients.iter().enumerate() { if !self - .reserve_quota( + .check_quota( quota, &SimpleEnvelope::new_rcpt( message, @@ -76,7 +77,7 @@ impl SMTP { ), message.size, (pos + 1) as u64, - &mut queue_refs, + &mut quota_keys, ) .await { @@ -85,47 +86,65 @@ impl SMTP { } } - message.queue_refs = queue_refs; + message.quota_keys = quota_keys; true } - async fn reserve_quota( + async fn check_quota( &self, quota: &QueueQuota, envelope: &impl ResolveVariable, size: usize, id: u64, - refs: &mut Vec<UsedQuota>, + refs: &mut Vec<QuotaKey>, ) -> bool { if !quota.expr.is_empty() && self - .eval_expr("a.expr, envelope, "reserve_quota") + .eval_expr("a.expr, envelope, "check_quota") .await .unwrap_or(false) { - match self.queue.quota.entry(quota.new_key(envelope)) { - Entry::Occupied(e) => { - if let Some(qref) = e.get().is_allowed(id, size) { - refs.push(qref); - } else { - return false; - } - } - Entry::Vacant(e) => { - let limiter = Arc::new(QuotaLimiter { - max_size: quota.size.unwrap_or(0), - max_messages: quota.messages.unwrap_or(0), - size: 0.into(), - messages: 0.into(), + let key = quota.new_key(envelope); + if let Some(max_size) = quota.size { + if self + .shared + .default_data_store + .get_counter(ValueKey::from(ValueClass::Queue(QueueClass::QuotaSize( + key.as_ref().to_vec(), + )))) + .await + .unwrap_or(0) as usize + + size + > max_size + { + return false; + } else { + refs.push(QuotaKey::Size { + key: key.as_ref().to_vec(), + id, }); + } + } - if let Some(qref) = limiter.is_allowed(id, size) { - refs.push(qref); - e.insert(limiter); - } else { - return false; - } + if let Some(max_messages) = quota.messages { + if self + .shared + .default_data_store + .get_counter(ValueKey::from(ValueClass::Queue(QueueClass::QuotaCount( + key.as_ref().to_vec(), + )))) + .await + .unwrap_or(0) as usize + + 1 + > max_messages + { + return false; + } else { + refs.push(QuotaKey::Count { + key: key.as_ref().to_vec(), + id, + }); } } } @@ -134,7 +153,10 @@ impl SMTP { } impl Message { - pub fn release_quota(&mut self) { + pub fn release_quota(&mut self, batch: &mut BatchBuilder) { + if self.quota_keys.is_empty() { + return; + } let mut quota_ids = Vec::with_capacity(self.domains.len() + self.recipients.len()); for (pos, domain) in self.domains.iter().enumerate() { if matches!( @@ -153,48 +175,21 @@ impl Message { } } if !quota_ids.is_empty() { - self.queue_refs.retain(|q| !quota_ids.contains(&q.id)); - } - } -} - -trait QuotaLimiterAllowed { - fn is_allowed(&self, id: u64, size: usize) -> Option<UsedQuota>; -} - -impl QuotaLimiterAllowed for Arc<QuotaLimiter> { - fn is_allowed(&self, id: u64, size: usize) -> Option<UsedQuota> { - if self.max_messages > 0 { - if self.messages.load(Ordering::Relaxed) < self.max_messages { - self.messages.fetch_add(1, Ordering::Relaxed); - } else { - return None; - } - } - - if self.max_size > 0 { - if self.size.load(Ordering::Relaxed) + size < self.max_size { - self.size.fetch_add(size, Ordering::Relaxed); - } else { - return None; + let mut quota_keys = Vec::new(); + for quota_key in std::mem::take(&mut self.quota_keys) { + match quota_key { + QuotaKey::Count { id, key } if quota_ids.contains(&id) => { + batch.clear(ValueClass::Queue(QueueClass::QuotaCount(key))); + } + QuotaKey::Size { id, key } if quota_ids.contains(&id) => { + batch.clear(ValueClass::Queue(QueueClass::QuotaSize(key))); + } + _ => { + quota_keys.push(quota_key); + } + } } - } - - Some(UsedQuota { - id, - size, - limiter: self.clone(), - }) - } -} - -impl Drop for UsedQuota { - fn drop(&mut self) { - if self.limiter.max_messages > 0 { - self.limiter.messages.fetch_sub(1, Ordering::Relaxed); - } - if self.limiter.max_size > 0 { - self.limiter.size.fetch_sub(self.size, Ordering::Relaxed); + self.quota_keys = quota_keys; } } } diff --git a/crates/smtp/src/queue/serialize.rs b/crates/smtp/src/queue/serialize.rs deleted file mode 100644 index e056888b..00000000 --- a/crates/smtp/src/queue/serialize.rs +++ /dev/null @@ -1,565 +0,0 @@ -/* - * Copyright (c) 2023 Stalwart Labs Ltd. - * - * This file is part of Stalwart Mail Server. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as - * published by the Free Software Foundation, either version 3 of - * the License, or (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * in the LICENSE file at the top-level directory of this distribution. - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see <http://www.gnu.org/licenses/>. - * - * You can be released from the requirements of the AGPLv3 license by - * purchasing a commercial license. Please contact licensing@stalw.art - * for more details. -*/ - -use mail_auth::common::base32::Base32Reader; -use smtp_proto::Response; -use std::io::SeekFrom; -use std::path::PathBuf; -use std::slice::Iter; -use std::{fmt::Write, time::Instant}; -use tokio::fs; -use tokio::fs::File; -use tokio::io::{AsyncReadExt, AsyncSeekExt}; - -use super::{ - instant_to_timestamp, Domain, DomainPart, Error, ErrorDetails, HostResponse, - InstantFromTimestamp, Message, Recipient, Schedule, Status, RCPT_STATUS_CHANGED, -}; - -pub trait QueueSerializer: Sized { - fn serialize(&self, buf: &mut String); - fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self>; -} - -impl Message { - pub fn serialize(&self) -> Vec<u8> { - let mut buf = String::with_capacity( - self.return_path.len() - + self.env_id.as_ref().map_or(0, |e| e.len()) - + (self.domains.len() * 64) - + (self.recipients.len() * 64) - + 50, - ); - - // Serialize message properties - (self.created as usize).serialize(&mut buf); - self.return_path.serialize(&mut buf); - (self.env_id.as_deref().unwrap_or_default()).serialize(&mut buf); - (self.flags as usize).serialize(&mut buf); - self.priority.serialize(&mut buf); - - // Serialize domains - let now = Instant::now(); - self.domains.len().serialize(&mut buf); - for domain in &self.domains { - domain.domain.serialize(&mut buf); - (instant_to_timestamp(now, domain.expires) as usize).serialize(&mut buf); - } - - // Serialize recipients - self.recipients.len().serialize(&mut buf); - for rcpt in &self.recipients { - rcpt.domain_idx.serialize(&mut buf); - rcpt.address.serialize(&mut buf); - (rcpt.orcpt.as_deref().unwrap_or_default()).serialize(&mut buf); - } - - // Serialize domain status - for (idx, domain) in self.domains.iter().enumerate() { - domain.serialize(idx, now, &mut buf); - } - - // Serialize recipient status - for (idx, rcpt) in self.recipients.iter().enumerate() { - rcpt.serialize(idx, &mut buf); - } - - buf.into_bytes() - } - - pub fn serialize_changes(&mut self) -> Vec<u8> { - let now = Instant::now(); - let mut buf = String::with_capacity(128); - - for (idx, domain) in self.domains.iter_mut().enumerate() { - if domain.changed { - domain.changed = false; - domain.serialize(idx, now, &mut buf); - } - } - - for (idx, rcpt) in self.recipients.iter_mut().enumerate() { - if rcpt.has_flag(RCPT_STATUS_CHANGED) { - rcpt.flags &= !RCPT_STATUS_CHANGED; - rcpt.serialize(idx, &mut buf); - } - } - - buf.into_bytes() - } - - pub async fn from_path(path: PathBuf) -> Result<Self, String> { - let filename = path - .file_name() - .and_then(|f| f.to_str()) - .and_then(|f| f.rsplit_once('.')) - .map(|(f, _)| f) - .ok_or_else(|| format!("Invalid queue file name {}", path.display()))?; - - // Decode file name - let mut id = [0u8; std::mem::size_of::<u64>()]; - let mut size = [0u8; std::mem::size_of::<u32>()]; - - for (pos, byte) in Base32Reader::new(filename.as_bytes()).enumerate() { - match pos { - 0..=7 => { - id[pos] = byte; - } - 8..=11 => { - size[pos - 8] = byte; - } - _ => { - return Err(format!("Invalid queue file name {}", path.display())); - } - } - } - - let id = u64::from_le_bytes(id); - let size = u32::from_le_bytes(size) as u64; - - // Obtail file size - let file_size = fs::metadata(&path) - .await - .map_err(|err| { - format!( - "Failed to obtain file metadata for {}: {}", - path.display(), - err - ) - })? - .len(); - if size == 0 || size >= file_size { - return Err(format!( - "Invalid queue file name size {} for {}", - size, - path.display() - )); - } - let mut buf = Vec::with_capacity((file_size - size) as usize); - let mut file = File::open(&path) - .await - .map_err(|err| format!("Failed to open queue file {}: {}", path.display(), err))?; - file.seek(SeekFrom::Start(size)) - .await - .map_err(|err| format!("Failed to seek queue file {}: {}", path.display(), err))?; - file.read_to_end(&mut buf) - .await - .map_err(|err| format!("Failed to read queue file {}: {}", path.display(), err))?; - - let mut message = Self::deserialize(&buf) - .ok_or_else(|| format!("Failed to deserialize metadata for file {}", path.display()))?; - message.path = path; - message.size = size as usize; - message.id = id; - Ok(message) - } - - pub fn deserialize(bytes: &[u8]) -> Option<Self> { - let mut bytes = bytes.iter(); - let created = usize::deserialize(&mut bytes)? as u64; - let return_path = String::deserialize(&mut bytes)?; - let return_path_lcase = return_path.to_lowercase(); - let env_id = String::deserialize(&mut bytes)?; - - let mut message = Message { - id: 0, - path: PathBuf::new(), - created, - return_path_domain: return_path_lcase.domain_part().to_string(), - return_path_lcase, - return_path, - env_id: if !env_id.is_empty() { - env_id.into() - } else { - None - }, - flags: usize::deserialize(&mut bytes)? as u64, - priority: i16::deserialize(&mut bytes)?, - size: 0, - recipients: vec![], - domains: vec![], - queue_refs: vec![], - }; - - // Deserialize domains - let num_domains = usize::deserialize(&mut bytes)?; - message.domains = Vec::with_capacity(num_domains); - for _ in 0..num_domains { - message.domains.push(Domain { - domain: String::deserialize(&mut bytes)?, - expires: Instant::deserialize(&mut bytes)?, - retry: Schedule::now(), - notify: Schedule::now(), - status: Status::Scheduled, - disable_tls: false, - changed: false, - }); - } - - // Deserialize recipients - let num_recipients = usize::deserialize(&mut bytes)?; - message.recipients = Vec::with_capacity(num_recipients); - for _ in 0..num_recipients { - let domain_idx = usize::deserialize(&mut bytes)?; - let address = String::deserialize(&mut bytes)?; - let orcpt = String::deserialize(&mut bytes)?; - message.recipients.push(Recipient { - domain_idx, - address_lcase: address.to_lowercase(), - address, - status: Status::Scheduled, - flags: 0, - orcpt: if !orcpt.is_empty() { - orcpt.into() - } else { - None - }, - }); - } - - // Deserialize status - while let Some((ch, idx)) = bytes - .next() - .and_then(|ch| (ch, usize::deserialize(&mut bytes)?).into()) - { - match ch { - b'D' => { - if let (Some(domain), Some(retry), Some(notify), Some(status)) = ( - message.domains.get_mut(idx), - Schedule::deserialize(&mut bytes), - Schedule::deserialize(&mut bytes), - Status::deserialize(&mut bytes), - ) { - domain.retry = retry; - domain.notify = notify; - domain.status = status; - } else { - break; - } - } - b'R' => { - if let (Some(rcpt), Some(flags), Some(status)) = ( - message.recipients.get_mut(idx), - usize::deserialize(&mut bytes), - Status::deserialize(&mut bytes), - ) { - rcpt.flags = flags as u64; - rcpt.status = status; - } else { - break; - } - } - _ => break, - } - } - - message.into() - } -} - -impl<T: QueueSerializer, E: QueueSerializer> QueueSerializer for Status<T, E> { - fn serialize(&self, buf: &mut String) { - match self { - Status::Scheduled => buf.push('S'), - Status::Completed(s) => { - buf.push('C'); - s.serialize(buf); - } - Status::TemporaryFailure(s) => { - buf.push('T'); - s.serialize(buf); - } - Status::PermanentFailure(s) => { - buf.push('F'); - s.serialize(buf); - } - } - } - - fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> { - match bytes.next()? { - b'S' => Self::Scheduled.into(), - b'C' => Self::Completed(T::deserialize(bytes)?).into(), - b'T' => Self::TemporaryFailure(E::deserialize(bytes)?).into(), - b'F' => Self::PermanentFailure(E::deserialize(bytes)?).into(), - _ => None, - } - } -} - -impl QueueSerializer for Response<String> { - fn serialize(&self, buf: &mut String) { - let _ = write!( - buf, - "{} {} {} {} {} {}", - self.code, - self.esc[0], - self.esc[1], - self.esc[2], - self.message.len(), - self.message - ); - } - - fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> { - Response { - code: usize::deserialize(bytes)? as u16, - esc: [ - usize::deserialize(bytes)? as u8, - usize::deserialize(bytes)? as u8, - usize::deserialize(bytes)? as u8, - ], - message: String::deserialize(bytes)?, - } - .into() - } -} - -impl QueueSerializer for usize { - fn serialize(&self, buf: &mut String) { - let _ = write!(buf, "{self} "); - } - - fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> { - let mut num = 0; - loop { - match bytes.next()? { - ch @ (b'0'..=b'9') => { - num = (num * 10) + (*ch - b'0') as usize; - } - b' ' => { - return num.into(); - } - _ => { - return None; - } - } - } - } -} - -impl QueueSerializer for i16 { - fn serialize(&self, buf: &mut String) { - let _ = write!(buf, "{self} "); - } - - fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> { - let mut num = 0; - let mut mul = 1; - loop { - match bytes.next()? { - ch @ (b'0'..=b'9') => { - num = (num * 10) + (*ch - b'0') as i16; - } - b' ' => { - return (num * mul).into(); - } - b'-' => { - mul = -1; - } - _ => { - return None; - } - } - } - } -} - -impl QueueSerializer for ErrorDetails { - fn serialize(&self, buf: &mut String) { - self.entity.serialize(buf); - self.details.serialize(buf); - } - - fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> { - ErrorDetails { - entity: String::deserialize(bytes)?, - details: String::deserialize(bytes)?, - } - .into() - } -} - -impl<T: QueueSerializer> QueueSerializer for HostResponse<T> { - fn serialize(&self, buf: &mut String) { - self.hostname.serialize(buf); - self.response.serialize(buf); - } - - fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> { - HostResponse { - hostname: T::deserialize(bytes)?, - response: Response::deserialize(bytes)?, - } - .into() - } -} - -impl QueueSerializer for String { - fn serialize(&self, buf: &mut String) { - if !self.is_empty() { - let _ = write!(buf, "{} {}", self.len(), self); - } else { - buf.push_str("0 "); - } - } - - fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> { - match usize::deserialize(bytes)? { - len @ (1..=4096) => { - String::from_utf8(bytes.take(len).copied().collect::<Vec<_>>()).ok() - } - 0 => String::new().into(), - _ => None, - } - } -} - -impl QueueSerializer for &str { - fn serialize(&self, buf: &mut String) { - if !self.is_empty() { - let _ = write!(buf, "{} {}", self.len(), self); - } else { - buf.push_str("0 "); - } - } - - fn deserialize(_bytes: &mut Iter<'_, u8>) -> Option<Self> { - unimplemented!() - } -} - -impl QueueSerializer for Instant { - fn serialize(&self, buf: &mut String) { - let _ = write!(buf, "{} ", instant_to_timestamp(Instant::now(), *self),); - } - - fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> { - (usize::deserialize(bytes)? as u64).to_instant().into() - } -} - -impl QueueSerializer for Schedule<u32> { - fn serialize(&self, buf: &mut String) { - let _ = write!( - buf, - "{} {} ", - self.inner, - instant_to_timestamp(Instant::now(), self.due), - ); - } - - fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> { - Schedule { - inner: usize::deserialize(bytes)? as u32, - due: Instant::deserialize(bytes)?, - } - .into() - } -} - -impl QueueSerializer for Error { - fn serialize(&self, buf: &mut String) { - match self { - Error::DnsError(e) => { - buf.push('0'); - e.serialize(buf); - } - Error::UnexpectedResponse(e) => { - buf.push('1'); - e.serialize(buf); - } - Error::ConnectionError(e) => { - buf.push('2'); - e.serialize(buf); - } - Error::TlsError(e) => { - buf.push('3'); - e.serialize(buf); - } - Error::DaneError(e) => { - buf.push('4'); - e.serialize(buf); - } - Error::MtaStsError(e) => { - buf.push('5'); - e.serialize(buf); - } - Error::RateLimited => { - buf.push('6'); - } - Error::ConcurrencyLimited => { - buf.push('7'); - } - Error::Io(e) => { - buf.push('8'); - e.serialize(buf); - } - } - } - - fn deserialize(bytes: &mut Iter<'_, u8>) -> Option<Self> { - match bytes.next()? { - b'0' => Error::DnsError(String::deserialize(bytes)?).into(), - b'1' => Error::UnexpectedResponse(HostResponse::deserialize(bytes)?).into(), - b'2' => Error::ConnectionError(ErrorDetails::deserialize(bytes)?).into(), - b'3' => Error::TlsError(ErrorDetails::deserialize(bytes)?).into(), - b'4' => Error::DaneError(ErrorDetails::deserialize(bytes)?).into(), - b'5' => Error::MtaStsError(String::deserialize(bytes)?).into(), - b'6' => Error::RateLimited.into(), - b'7' => Error::ConcurrencyLimited.into(), - b'8' => Error::Io(String::deserialize(bytes)?).into(), - _ => None, - } - } -} - -impl QueueSerializer for () { - fn serialize(&self, _buf: &mut String) {} - - fn deserialize(_bytes: &mut Iter<'_, u8>) -> Option<Self> { - Some(()) - } -} - -impl Domain { - fn serialize(&self, idx: usize, now: Instant, buf: &mut String) { - let _ = write!( - buf, - "D{} {} {} {} {} ", - idx, - self.retry.inner, - instant_to_timestamp(now, self.retry.due), - self.notify.inner, - instant_to_timestamp(now, self.notify.due) - ); - self.status.serialize(buf); - } -} - -impl Recipient { - fn serialize(&self, idx: usize, buf: &mut String) { - let _ = write!(buf, "R{} {} ", idx, self.flags); - self.status.serialize(buf); - } -} diff --git a/crates/smtp/src/queue/spool.rs b/crates/smtp/src/queue/spool.rs index fde2171c..b21371c7 100644 --- a/crates/smtp/src/queue/spool.rs +++ b/crates/smtp/src/queue/spool.rs @@ -22,99 +22,167 @@ */ use crate::queue::DomainPart; -use mail_auth::common::base32::Base32Writer; -use mail_auth::common::headers::Writer; -use std::path::PathBuf; -use std::sync::atomic::Ordering; -use std::time::Instant; +use std::borrow::Cow; use std::time::{Duration, SystemTime}; -use tokio::fs::OpenOptions; -use tokio::{fs, io::AsyncWriteExt}; +use store::write::key::DeserializeBigEndian; +use store::write::{now, BatchBuilder, Bincode, BlobOp, QueueClass, QueueEvent, ValueClass}; +use store::{IterateParams, Serialize, ValueKey, U64_LEN}; +use utils::BlobHash; use crate::core::{QueueCore, SMTP}; -use super::{Domain, Event, Message, Recipient, Schedule, SimpleEnvelope, Status}; +use super::{ + Domain, Event, Message, QueueId, QuotaKey, Recipient, Schedule, SimpleEnvelope, Status, +}; impl QueueCore { - pub async fn queue_message( + pub fn new_message( &self, - mut message: Box<Message>, - raw_headers: Option<&[u8]>, - raw_message: &[u8], - span: &tracing::Span, - ) -> bool { - // Generate id - if message.id == 0 { - message.id = self.queue_id(); - } - if message.size == 0 { - message.size = raw_message.len() + raw_headers.as_ref().map_or(0, |h| h.len()); + return_path: impl Into<String>, + return_path_lcase: impl Into<String>, + return_path_domain: impl Into<String>, + ) -> Message { + let created = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map_or(0, |d| d.as_secs()); + Message { + id: self.snowflake_id.generate().unwrap_or(created), + created, + return_path: return_path.into(), + return_path_lcase: return_path_lcase.into(), + return_path_domain: return_path_domain.into(), + recipients: Vec::with_capacity(1), + domains: Vec::with_capacity(1), + flags: 0, + env_id: None, + priority: 0, + size: 0, + blob_hash: Default::default(), + quota_keys: Vec::new(), } + } +} - // Build path - let todo = 1; - message.path = self.config.path.clone(); - let hash = 1; - if hash > 0 { - message.path.push((message.id % hash).to_string()); - } - let _ = fs::create_dir(&message.path).await; +impl SMTP { + pub async fn next_event(&self) -> Vec<QueueEvent> { + let from_key = ValueKey::from(ValueClass::Queue(QueueClass::MessageEvent(QueueEvent { + due: 0, + queue_id: 0, + }))); + let to_key = ValueKey::from(ValueClass::Queue(QueueClass::MessageEvent(QueueEvent { + due: u64::MAX, + queue_id: u64::MAX, + }))); + + let mut events = Vec::new(); + let now = now(); + let result = self + .shared + .default_data_store + .iterate( + IterateParams::new(from_key, to_key).ascending().no_values(), + |key, _| { + let event = QueueEvent { + due: key.deserialize_be_u64(1)?, + queue_id: key.deserialize_be_u64(U64_LEN + 1)?, + }; + let do_continue = event.due <= now; + events.push(event); + Ok(do_continue) + }, + ) + .await; - // Encode file name - let mut encoder = Base32Writer::with_capacity(20); - encoder.write(&message.id.to_le_bytes()[..]); - encoder.write(&(message.size as u32).to_le_bytes()[..]); - let mut file = encoder.finalize(); - file.push_str(".msg"); - message.path.push(file); + if let Err(err) = result { + tracing::error!( + context = "queue", + event = "error", + "Failed to read from store: {}", + err + ); + } - // Serialize metadata - let metadata = message.serialize(); + events + } - // Save message - let mut file = match fs::File::create(&message.path).await { - Ok(file) => file, + pub async fn read_message(&self, id: QueueId) -> Option<Message> { + match self + .shared + .default_data_store + .get_value::<Bincode<Message>>(ValueKey::from(ValueClass::Queue(QueueClass::Message( + id, + )))) + .await + { + Ok(Some(message)) => Some(message.inner), + Ok(None) => None, Err(err) => { tracing::error!( - parent: span, context = "queue", event = "error", - "Failed to create file {}: {}", - message.path.display(), + "Failed to read message from store: {}", err ); - return false; + None } - }; + } + } +} - let iter = if let Some(raw_headers) = raw_headers { - [raw_headers, raw_message, &metadata].into_iter() +impl Message { + pub async fn queue( + mut self, + raw_headers: Option<&[u8]>, + raw_message: &[u8], + core: &SMTP, + span: &tracing::Span, + ) -> bool { + // Write blob + let message = if let Some(raw_headers) = raw_headers { + let mut message = Vec::with_capacity(raw_headers.len() + raw_message.len()); + message.extend_from_slice(raw_headers); + message.extend_from_slice(raw_message); + Cow::Owned(message) } else { - [raw_message, &metadata, b""].into_iter() + raw_message.into() }; + self.blob_hash = BlobHash::from(message.as_ref()); - for bytes in iter { - if !bytes.is_empty() { - if let Err(err) = file.write_all(bytes).await { - tracing::error!( - parent: span, - context = "queue", - event = "error", - "Failed to write to file {}: {}", - message.path.display(), - err - ); - return false; - } - } + // Generate id + if self.size == 0 { + self.size = message.len(); + } + + // Reserve and write blob + let mut batch = BatchBuilder::new(); + batch.with_account_id(u32::MAX).set( + BlobOp::Reserve { + hash: self.blob_hash.clone(), + until: self.next_delivery_event() + 3600, + }, + 0u32.serialize(), + ); + if let Err(err) = core.shared.default_data_store.write(batch.build()).await { + tracing::error!( + parent: span, + context = "queue", + event = "error", + "Failed to write to data store: {}", + err + ); + return false; } - if let Err(err) = file.flush().await { + if let Err(err) = core + .shared + .default_blob_store + .put_blob(self.blob_hash.as_slice(), message.as_ref()) + .await + { tracing::error!( parent: span, context = "queue", event = "error", - "Failed to flush file {}: {}", - message.path.display(), + "Failed to write to blob store: {}", err ); return false; @@ -124,27 +192,45 @@ impl QueueCore { parent: span, context = "queue", event = "scheduled", - id = message.id, - from = if !message.return_path.is_empty() { - message.return_path.as_str() + id = self.id, + from = if !self.return_path.is_empty() { + self.return_path.as_str() } else { "<>" }, - nrcpts = message.recipients.len(), - size = message.size, + nrcpts = self.recipients.len(), + size = self.size, "Message queued for delivery." ); + // Write message to queue + let mut batch = BatchBuilder::new(); + batch + .set( + ValueClass::Queue(QueueClass::MessageEvent(QueueEvent { + due: self.next_event().unwrap_or_default(), + queue_id: self.id, + })), + vec![], + ) + .set( + ValueClass::Queue(QueueClass::Message(self.id)), + Bincode::new(self).serialize(), + ); + + if let Err(err) = core.shared.default_data_store.write(batch.build()).await { + tracing::error!( + parent: span, + context = "queue", + event = "error", + "Failed to write to store: {}", + err + ); + return false; + } + // Queue the message - if self - .tx - .send(Event::Queue(Schedule { - due: message.next_event().unwrap(), - inner: message, - })) - .await - .is_err() - { + if core.queue.tx.send(Event::Reload).await.is_err() { tracing::warn!( parent: span, context = "queue", @@ -156,42 +242,6 @@ impl QueueCore { true } - pub fn queue_id(&self) -> u64 { - (SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .map_or(0, |d| d.as_secs()) - .saturating_sub(946684800) - & 0xFFFFFFFF) - | (self.id_seq.fetch_add(1, Ordering::Relaxed) as u64) << 32 - } -} - -impl Message { - pub fn new_boxed( - return_path: impl Into<String>, - return_path_lcase: impl Into<String>, - return_path_domain: impl Into<String>, - ) -> Box<Message> { - Box::new(Message { - id: 0, - path: PathBuf::new(), - created: SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .map(|d| d.as_secs()) - .unwrap_or(0), - return_path: return_path.into(), - return_path_lcase: return_path_lcase.into(), - return_path_domain: return_path_domain.into(), - recipients: Vec::with_capacity(1), - domains: Vec::with_capacity(1), - flags: 0, - env_id: None, - priority: 0, - size: 0, - queue_refs: vec![], - }) - } - pub async fn add_recipient_parts( &mut self, rcpt: impl Into<String>, @@ -216,10 +266,9 @@ impl Message { domain: rcpt_domain, retry: Schedule::now(), notify: Schedule::later(expires + Duration::from_secs(10)), - expires: Instant::now() + expires, + expires: now() + expires.as_secs(), status: Status::Scheduled, disable_tls: false, - changed: false, }); idx }; @@ -241,35 +290,95 @@ impl Message { .await; } - pub async fn save_changes(&mut self) { - let buf = self.serialize_changes(); - if !buf.is_empty() { - let err = match OpenOptions::new().append(true).open(&self.path).await { - Ok(mut file) => match file.write_all(&buf).await { - Ok(_) => return, - Err(err) => err, + pub async fn save_changes( + mut self, + core: &SMTP, + prev_event: Option<u64>, + next_event: Option<u64>, + ) -> bool { + debug_assert!(prev_event.is_some() == next_event.is_some()); + + let mut batch = BatchBuilder::new(); + + // Release quota for completed deliveries + self.release_quota(&mut batch); + + // Update message queue + let mut batch = BatchBuilder::new(); + if let Some(prev_event) = prev_event { + batch.clear(ValueClass::Queue(QueueClass::MessageEvent(QueueEvent { + due: prev_event, + queue_id: self.id, + }))); + } + if let Some(next_event) = next_event { + batch.set( + ValueClass::Queue(QueueClass::MessageEvent(QueueEvent { + due: next_event, + queue_id: self.id, + })), + vec![], + ); + } + batch + .with_account_id(u32::MAX) + .set( + BlobOp::Reserve { + hash: self.blob_hash.clone(), + until: self.next_delivery_event() + 3600, }, - Err(err) => err, - }; + 0u32.serialize(), + ) + .set( + ValueClass::Queue(QueueClass::Message(self.id)), + Bincode::new(self).serialize(), + ); + + if let Err(err) = core.shared.default_data_store.write(batch.build()).await { tracing::error!( context = "queue", event = "error", - "Failed to write to {}: {}", - self.path.display(), + "Failed to update queued message: {}", err ); + false + } else { + true } } - pub async fn remove(&self) { - if let Err(err) = fs::remove_file(&self.path).await { + pub async fn remove(self, core: &SMTP, prev_event: u64) -> bool { + let mut batch = BatchBuilder::new(); + + // Release all quotas + for quota_key in self.quota_keys { + match quota_key { + QuotaKey::Count { key, .. } => { + batch.clear(ValueClass::Queue(QueueClass::QuotaCount(key))); + } + QuotaKey::Size { key, .. } => { + batch.clear(ValueClass::Queue(QueueClass::QuotaSize(key))); + } + } + } + + batch + .clear(ValueClass::Queue(QueueClass::MessageEvent(QueueEvent { + due: prev_event, + queue_id: self.id, + }))) + .clear(ValueClass::Queue(QueueClass::Message(self.id))); + + if let Err(err) = core.shared.default_data_store.write(batch.build()).await { tracing::error!( context = "queue", event = "error", - "Failed to delete queued message {}: {}", - self.path.display(), + "Failed to update queued message: {}", err ); + false + } else { + true } } } diff --git a/crates/smtp/src/queue/throttle.rs b/crates/smtp/src/queue/throttle.rs index 5f624719..17130cbf 100644 --- a/crates/smtp/src/queue/throttle.rs +++ b/crates/smtp/src/queue/throttle.rs @@ -21,14 +21,13 @@ * for more details. */ -use std::time::{Duration, Instant}; - use dashmap::mapref::entry::Entry; -use utils::listener::limiter::{ConcurrencyLimiter, InFlight, RateLimiter}; +use store::write::now; +use utils::listener::limiter::{ConcurrencyLimiter, InFlight}; use crate::{ config::Throttle, - core::{throttle::Limiter, ResolveVariable, SMTP}, + core::{ResolveVariable, SMTP}, }; use super::{Domain, Status}; @@ -36,7 +35,7 @@ use super::{Domain, Status}; #[derive(Debug)] pub enum Error { Concurrency { limiter: ConcurrencyLimiter }, - Rate { retry_at: Instant }, + Rate { retry_at: u64 }, } impl SMTP { @@ -53,10 +52,33 @@ impl SMTP { .await .unwrap_or(false) { - match self.queue.throttle.entry(throttle.new_key(envelope)) { - Entry::Occupied(mut e) => { - let limiter = e.get_mut(); - if let Some(limiter) = &limiter.concurrency { + let key = throttle.new_key(envelope); + + if let Some(rate) = &throttle.rate { + if let Ok(Some(next_refill)) = self + .shared + .default_lookup_store + .is_rate_allowed(key.as_ref(), rate, false) + .await + { + tracing::info!( + parent: span, + context = "throttle", + event = "rate-limit-exceeded", + max_requests = rate.requests, + max_interval = rate.period.as_secs(), + "Queue rate limit exceeded." + ); + return Err(Error::Rate { + retry_at: now() + next_refill, + }); + } + } + + if let Some(concurrency) = &throttle.concurrency { + match self.queue.throttle.entry(key) { + Entry::Occupied(mut e) => { + let limiter = e.get_mut(); if let Some(inflight) = limiter.is_allowed() { in_flight.push(inflight); } else { @@ -72,38 +94,13 @@ impl SMTP { }); } } - if let (Some(limiter), Some(rate)) = (&mut limiter.rate, &throttle.rate) { - if !limiter.is_allowed(rate) { - tracing::info!( - parent: span, - context = "throttle", - event = "rate-limit-exceeded", - max_requests = rate.requests, - max_interval = rate.period.as_secs(), - "Queue rate limit exceeded." - ); - return Err(Error::Rate { - retry_at: Instant::now() - + Duration::from_secs(limiter.secs_to_refill()), - }); - } - } - } - Entry::Vacant(e) => { - let concurrency = throttle.concurrency.map(|concurrency| { - let limiter = ConcurrencyLimiter::new(concurrency); + Entry::Vacant(e) => { + let limiter = ConcurrencyLimiter::new(*concurrency); if let Some(inflight) = limiter.is_allowed() { in_flight.push(inflight); } - limiter - }); - let rate = throttle.rate.as_ref().map(|rate| { - let r = RateLimiter::new(rate); - r.is_allowed(rate); - r - }); - - e.insert(Limiter { rate, concurrency }); + e.insert(limiter); + } } } } @@ -124,6 +121,5 @@ impl Domain { self.status = Status::TemporaryFailure(super::Error::RateLimited); } } - self.changed = true; } } diff --git a/crates/smtp/src/reporting/dkim.rs b/crates/smtp/src/reporting/dkim.rs index 8415eff8..23a78170 100644 --- a/crates/smtp/src/reporting/dkim.rs +++ b/crates/smtp/src/reporting/dkim.rs @@ -46,7 +46,7 @@ impl<T: AsyncWrite + AsyncRead + Unpin> Session<T> { }; // Throttle recipient - if !self.throttle_rcpt(rcpt, rate, "dkim") { + if !self.throttle_rcpt(rcpt, rate, "dkim").await { tracing::debug!( parent: &self.span, context = "report", diff --git a/crates/smtp/src/reporting/dmarc.rs b/crates/smtp/src/reporting/dmarc.rs index 2df9b32c..c1b7d626 100644 --- a/crates/smtp/src/reporting/dmarc.rs +++ b/crates/smtp/src/reporting/dmarc.rs @@ -21,7 +21,7 @@ * for more details. */ -use std::{collections::hash_map::Entry, path::PathBuf, sync::Arc}; +use std::collections::hash_map::Entry; use ahash::AHashMap; use mail_auth::{ @@ -31,28 +31,22 @@ use mail_auth::{ ArcOutput, AuthenticatedMessage, AuthenticationResults, DkimOutput, DkimResult, DmarcOutput, SpfResult, }; -use serde::{Deserialize, Serialize}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - runtime::Handle, +use store::{ + write::{now, BatchBuilder, Bincode, QueueClass, ReportEvent, ValueClass}, + Deserialize, IterateParams, Serialize, ValueKey, }; +use tokio::io::{AsyncRead, AsyncWrite}; use utils::config::Rate; use crate::{ config::AggregateFrequency, core::{Session, SMTP}, - queue::{DomainPart, InstantFromTimestamp, RecipientDomain, Schedule}, + queue::{DomainPart, RecipientDomain}, }; -use super::{ - scheduler::{ - json_append, json_read_blocking, json_write, ReportPath, ReportPolicy, ReportType, - Scheduler, ToHash, - }, - DmarcEvent, -}; +use super::{scheduler::ToHash, DmarcEvent, SerializedSize}; -#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct DmarcFormat { pub rua: Vec<URI>, pub policy: PolicyPublished, @@ -88,16 +82,15 @@ impl<T: AsyncWrite + AsyncRead + Unpin> Session<T> { { Some(rcpts) => { if !rcpts.is_empty() { - rcpts - .into_iter() - .filter_map(|rcpt| { - if self.throttle_rcpt(rcpt.uri(), &failure_rate, "dmarc") { - rcpt.uri().into() - } else { - None - } - }) - .collect() + let mut new_rcpts = Vec::with_capacity(rcpts.len()); + + for rcpt in rcpts { + if self.throttle_rcpt(rcpt.uri(), &failure_rate, "dmarc").await { + new_rcpts.push(rcpt.uri()); + } + } + + new_rcpts } else { if !dmarc_record.ruf().is_empty() { tracing::debug!( @@ -306,224 +299,307 @@ impl<T: AsyncWrite + AsyncRead + Unpin> Session<T> { } } -pub trait GenerateDmarcReport { - fn generate_dmarc_report(&self, domain: ReportPolicy<String>, path: ReportPath<PathBuf>); -} - -impl GenerateDmarcReport for Arc<SMTP> { - fn generate_dmarc_report(&self, domain: ReportPolicy<String>, path: ReportPath<PathBuf>) { - let core = self.clone(); - let handle = Handle::current(); - - self.worker_pool.spawn(move || { - let deliver_at = path.created + path.deliver_at.as_secs(); - let span = tracing::info_span!( - "dmarc-report", - domain = domain.inner, - range_from = path.created, - range_to = deliver_at, - size = path.size, - ); - - // Deserialize report - let dmarc = if let Some(dmarc) = json_read_blocking::<DmarcFormat>(&path.path, &span) { - dmarc - } else { +impl SMTP { + pub async fn generate_dmarc_report(&self, event: ReportEvent) { + let span = tracing::info_span!( + "dmarc-report", + domain = event.domain, + range_from = event.seq_id, + range_to = event.due, + ); + + // Deserialize report + let dmarc = match self + .shared + .default_data_store + .get_value::<Bincode<DmarcFormat>>(ValueKey::from(ValueClass::Queue( + QueueClass::DmarcReportHeader(event.clone()), + ))) + .await + { + Ok(Some(dmarc)) => dmarc.inner, + Ok(None) => { + tracing::warn!( + parent: &span, + event = "missing", + "Failed to read DMARC report: Report not found" + ); return; - }; + } + Err(err) => { + tracing::warn!( + parent: &span, + event = "error", + "Failed to read DMARC report: {}", + err + ); + return; + } + }; - // Verify external reporting addresses - let rua = match handle.block_on( - core.resolvers - .dns - .verify_dmarc_report_address(&domain.inner, &dmarc.rua), - ) { - Some(rcpts) => { - if !rcpts.is_empty() { - rcpts - .into_iter() - .map(|u| u.uri().to_string()) - .collect::<Vec<_>>() - } else { - tracing::info!( - parent: &span, - event = "failed", - reason = "unauthorized-rua", - rua = ?dmarc.rua, - "Unauthorized external reporting addresses" - ); - let _ = std::fs::remove_file(&path.path); - return; - } - } - None => { + // Verify external reporting addresses + let rua = match self + .resolvers + .dns + .verify_dmarc_report_address(&event.domain, &dmarc.rua) + .await + { + Some(rcpts) => { + if !rcpts.is_empty() { + rcpts + .into_iter() + .map(|u| u.uri().to_string()) + .collect::<Vec<_>>() + } else { tracing::info!( parent: &span, event = "failed", - reason = "dns-failure", + reason = "unauthorized-rua", rua = ?dmarc.rua, - "Failed to validate external report addresses", + "Unauthorized external reporting addresses" ); - let _ = std::fs::remove_file(&path.path); + self.delete_dmarc_report(event).await; return; } - }; - - let config = &core.report.config.dmarc_aggregate; + } + None => { + tracing::info!( + parent: &span, + event = "failed", + reason = "dns-failure", + rua = ?dmarc.rua, + "Failed to validate external report addresses", + ); + self.delete_dmarc_report(event).await; + return; + } + }; - // Group duplicates - let mut record_map = AHashMap::with_capacity(dmarc.records.len()); - for record in dmarc.records { - match record_map.entry(record) { + let mut serialized_size = serde_json::Serializer::new(SerializedSize::new( + self.eval_if( + &self.report.config.dmarc_aggregate.max_size, + &RecipientDomain::new(event.domain.as_str()), + ) + .await + .unwrap_or(25 * 1024 * 1024), + )); + let _ = serde::Serialize::serialize(&dmarc, &mut serialized_size); + let config = &self.report.config.dmarc_aggregate; + + // Group duplicates + let from_key = ValueKey::from(ValueClass::Queue(QueueClass::DmarcReportEvent( + ReportEvent { + due: event.due, + policy_hash: event.policy_hash, + seq_id: 0, + domain: event.domain.clone(), + }, + ))); + let to_key = ValueKey::from(ValueClass::Queue(QueueClass::DmarcReportEvent( + ReportEvent { + due: event.due, + policy_hash: event.policy_hash, + seq_id: u64::MAX, + domain: event.domain.clone(), + }, + ))); + let mut record_map = AHashMap::with_capacity(dmarc.records.len()); + if let Err(err) = self + .shared + .default_data_store + .iterate( + IterateParams::new(from_key, to_key).ascending(), + |_, v| match record_map.entry(Bincode::<Record>::deserialize(v)?.inner) { Entry::Occupied(mut e) => { *e.get_mut() += 1; + Ok(true) } Entry::Vacant(e) => { - e.insert(1u32); + if serde::Serialize::serialize(e.key(), &mut serialized_size).is_ok() { + e.insert(1u32); + Ok(true) + } else { + Ok(false) + } } - } - } + }, + ) + .await + { + tracing::warn!( + parent: &span, + event = "error", + "Failed to read DMARC report: {}", + err + ); + } - // Create report - let mut report = Report::new() - .with_policy_published(dmarc.policy) - .with_date_range_begin(path.created) - .with_date_range_end(deliver_at) - .with_report_id(format!("{}_{}", domain.policy, path.created)) - .with_email( - handle - .block_on(core.eval_if( - &config.address, - &RecipientDomain::new(domain.inner.as_str()), - )) - .unwrap_or_else(|| "MAILER-DAEMON@localhost".to_string()), - ); - if let Some(org_name) = handle.block_on(core.eval_if::<String, _>( + // Create report + let mut report = Report::new() + .with_policy_published(dmarc.policy) + .with_date_range_begin(event.seq_id) + .with_date_range_end(event.due) + .with_report_id(format!("{}_{}", event.policy_hash, event.seq_id)) + .with_email( + self.eval_if( + &config.address, + &RecipientDomain::new(event.domain.as_str()), + ) + .await + .unwrap_or_else(|| "MAILER-DAEMON@localhost".to_string()), + ); + if let Some(org_name) = self + .eval_if::<String, _>( &config.org_name, - &RecipientDomain::new(domain.inner.as_str()), - )) { - report = report.with_org_name(org_name); - } - if let Some(contact_info) = handle.block_on(core.eval_if::<String, _>( + &RecipientDomain::new(event.domain.as_str()), + ) + .await + { + report = report.with_org_name(org_name); + } + if let Some(contact_info) = self + .eval_if::<String, _>( &config.contact_info, - &RecipientDomain::new(domain.inner.as_str()), - )) { - report = report.with_extra_contact_info(contact_info); - } - for (record, count) in record_map { - report.add_record(record.with_count(count)); - } - let from_addr = handle - .block_on(core.eval_if( - &config.address, - &RecipientDomain::new(domain.inner.as_str()), - )) - .unwrap_or_else(|| "MAILER-DAEMON@localhost".to_string()); - let mut message = Vec::with_capacity(path.size); - let _ = - report.write_rfc5322( - &handle - .block_on(core.eval_if( - &core.report.config.submitter, - &RecipientDomain::new(domain.inner.as_str()), - )) - .unwrap_or_else(|| "localhost".to_string()), - ( - handle - .block_on(core.eval_if( - &config.name, - &RecipientDomain::new(domain.inner.as_str()), - )) - .unwrap_or_else(|| "Mail Delivery Subsystem".to_string()) - .as_str(), - from_addr.as_str(), - ), - rua.iter().map(|a| a.as_str()), - &mut message, - ); + &RecipientDomain::new(event.domain.as_str()), + ) + .await + { + report = report.with_extra_contact_info(contact_info); + } + for (record, count) in record_map { + report.add_record(record.with_count(count)); + } + let from_addr = self + .eval_if( + &config.address, + &RecipientDomain::new(event.domain.as_str()), + ) + .await + .unwrap_or_else(|| "MAILER-DAEMON@localhost".to_string()); + let mut message = Vec::with_capacity(2048); + let _ = report.write_rfc5322( + &self + .eval_if( + &self.report.config.submitter, + &RecipientDomain::new(event.domain.as_str()), + ) + .await + .unwrap_or_else(|| "localhost".to_string()), + ( + self.eval_if(&config.name, &RecipientDomain::new(event.domain.as_str())) + .await + .unwrap_or_else(|| "Mail Delivery Subsystem".to_string()) + .as_str(), + from_addr.as_str(), + ), + rua.iter().map(|a| a.as_str()), + &mut message, + ); + + // Send report + self.send_report(&from_addr, rua.iter(), message, &config.sign, &span, false) + .await; - // Send report - handle.block_on(core.send_report( - &from_addr, - rua.iter(), - message, - &config.sign, - &span, - false, - )); - - if let Err(err) = std::fs::remove_file(&path.path) { - tracing::warn!( - context = "report", - event = "error", - "Failed to remove report file {}: {}", - path.path.display(), - err - ); - } - }); + self.delete_dmarc_report(event).await; } -} -impl Scheduler { - pub async fn schedule_dmarc(&mut self, event: Box<DmarcEvent>, core: &SMTP) { - let max_size = core - .eval_if( - &core.report.config.dmarc_aggregate.max_size, - &RecipientDomain::new(event.domain.as_str()), + pub async fn delete_dmarc_report(&self, event: ReportEvent) { + let from_key = ReportEvent { + due: event.due, + policy_hash: event.policy_hash, + seq_id: 0, + domain: event.domain.clone(), + }; + let to_key = ReportEvent { + due: event.due, + policy_hash: event.policy_hash, + seq_id: u64::MAX, + domain: event.domain.clone(), + }; + + if let Err(err) = self + .shared + .default_data_store + .delete_range( + ValueKey::from(ValueClass::Queue(QueueClass::DmarcReportEvent(from_key))), + ValueKey::from(ValueClass::Queue(QueueClass::DmarcReportEvent(to_key))), ) .await - .unwrap_or(25 * 1024 * 1024); - - let policy = event.dmarc_record.to_hash(); - let (create, path) = match self.reports.entry(ReportType::Dmarc(ReportPolicy { - inner: event.domain, - policy, - })) { - Entry::Occupied(e) => (None, e.into_mut().dmarc_path()), - Entry::Vacant(e) => { - let domain = e.key().domain_name().to_string(); - let created = event.interval.to_timestamp(); - let deliver_at = created + event.interval.as_secs(); - - self.main.push(Schedule { - due: deliver_at.to_instant(), - inner: e.key().clone(), - }); - let path = core - .build_report_path(ReportType::Dmarc(&domain), policy, created, event.interval) - .await; - let v = e.insert(ReportType::Dmarc(ReportPath { - path, - deliver_at: event.interval, - created, - size: 0, - })); - (domain.into(), v.dmarc_path()) - } + { + tracing::warn!( + context = "report", + event = "error", + "Failed to remove repors: {}", + err + ); + return; + } + + let mut batch = BatchBuilder::new(); + batch.clear(ValueClass::Queue(QueueClass::DmarcReportHeader(event))); + if let Err(err) = self.shared.default_data_store.write(batch.build()).await { + tracing::warn!( + context = "report", + event = "error", + "Failed to remove repors: {}", + err + ); + } + } + + pub async fn schedule_dmarc(&self, event: Box<DmarcEvent>) { + let created = event.interval.to_timestamp(); + let deliver_at = created + event.interval.as_secs(); + let mut report_event = ReportEvent { + due: deliver_at, + policy_hash: event.dmarc_record.to_hash(), + seq_id: created, + domain: event.domain, }; - if let Some(domain) = create { + // Write policy if missing + let mut builder = BatchBuilder::new(); + if self + .shared + .default_data_store + .get_value::<()>(ValueKey::from(ValueClass::Queue( + QueueClass::DmarcReportHeader(report_event.clone()), + ))) + .await + .unwrap_or_default() + .is_none() + { // Serialize report let entry = DmarcFormat { rua: event.dmarc_record.rua().to_vec(), - policy: PolicyPublished::from_record(domain, &event.dmarc_record), - records: vec![event.report_record], + policy: PolicyPublished::from_record( + report_event.domain.to_string(), + &event.dmarc_record, + ), + records: vec![], }; - let bytes_written = json_write(&path.path, &entry).await; - if bytes_written > 0 { - path.size += bytes_written; - } else { - // Something went wrong, remove record - self.reports.remove(&ReportType::Dmarc(ReportPolicy { - inner: entry.policy.domain, - policy, - })); - } - } else if path.size < max_size { - // Append to existing report - path.size += json_append(&path.path, &event.report_record, max_size - path.size).await; + // Write report + builder.set( + ValueClass::Queue(QueueClass::DmarcReportHeader(report_event.clone())), + Bincode::new(entry).serialize(), + ); + } + + // Write entry + report_event.seq_id = self.queue.snowflake_id.generate().unwrap_or_else(now); + builder.set( + ValueClass::Queue(QueueClass::DmarcReportEvent(report_event)), + Bincode::new(event.report_record).serialize(), + ); + + if let Err(err) = self.shared.default_data_store.write(builder.build()).await { + tracing::error!( + context = "report", + event = "error", + "Failed to write DMARC report event: {}", + err + ); } } } diff --git a/crates/smtp/src/reporting/mod.rs b/crates/smtp/src/reporting/mod.rs index d2dd6b73..4dd2c6f8 100644 --- a/crates/smtp/src/reporting/mod.rs +++ b/crates/smtp/src/reporting/mod.rs @@ -21,7 +21,7 @@ * for more details. */ -use std::{sync::Arc, time::SystemTime}; +use std::{io, sync::Arc, time::SystemTime}; use mail_auth::{ common::headers::HeaderWriter, @@ -37,15 +37,13 @@ use tokio::io::{AsyncRead, AsyncWrite}; use utils::config::if_block::IfBlock; use crate::{ - config::{AddressMatch, AggregateFrequency, DkimSigner}, - core::{management, Session, SMTP}, + config::{AddressMatch, AggregateFrequency}, + core::{Session, SMTP}, outbound::{dane::Tlsa, mta_sts::Policy}, queue::{DomainPart, Message}, USER_AGENT, }; -use self::scheduler::{ReportKey, ReportValue}; - pub mod analysis; pub mod dkim; pub mod dmarc; @@ -57,7 +55,6 @@ pub mod tls; pub enum Event { Dmarc(Box<DmarcEvent>), Tls(Box<TlsEvent>), - Manage(management::ReportRequest), Stop, } @@ -137,9 +134,11 @@ impl SMTP { // Build message let from_addr_lcase = from_addr.to_lowercase(); let from_addr_domain = from_addr_lcase.domain_part().to_string(); - let mut message = Message::new_boxed(from_addr, from_addr_lcase, from_addr_domain); + let mut message = self + .queue + .new_message(from_addr, from_addr_lcase, from_addr_domain); for rcpt_ in rcpts { - message.add_recipient(rcpt_.as_ref(), &self).await; + message.add_recipient(rcpt_.as_ref(), self).await; } // Sign message @@ -164,8 +163,8 @@ impl SMTP { } // Queue message - self.queue - .queue_message(message, signature.as_deref(), &report, span) + message + .queue(signature.as_deref(), &report, self, span) .await; } @@ -300,42 +299,28 @@ impl From<(&Option<Arc<Policy>>, &Option<Arc<Tlsa>>)> for PolicyType { } } -impl ReportKey { - pub fn domain(&self) -> &str { - match self { - scheduler::ReportType::Dmarc(p) => &p.inner, - scheduler::ReportType::Tls(d) => d, - } +pub(crate) struct SerializedSize { + bytes_left: usize, +} + +impl SerializedSize { + pub fn new(bytes_left: usize) -> Self { + Self { bytes_left } } } -impl ReportValue { - pub async fn delete(&self) { - match self { - scheduler::ReportType::Dmarc(path) => { - if let Err(err) = tokio::fs::remove_file(&path.path).await { - tracing::warn!( - context = "report", - event = "error", - "Failed to remove report file {}: {}", - path.path.display(), - err - ); - } - } - scheduler::ReportType::Tls(path) => { - for path in &path.path { - if let Err(err) = tokio::fs::remove_file(&path.inner).await { - tracing::warn!( - context = "report", - event = "error", - "Failed to remove report file {}: {}", - path.inner.display(), - err - ); - } - } - } +impl io::Write for SerializedSize { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + let buf_len = buf.len(); + if buf_len <= self.bytes_left { + self.bytes_left -= buf_len; + Ok(buf_len) + } else { + Err(io::Error::new(io::ErrorKind::Other, "Size exceeded")) } } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } } diff --git a/crates/smtp/src/reporting/scheduler.rs b/crates/smtp/src/reporting/scheduler.rs index 30772c67..5cc29268 100644 --- a/crates/smtp/src/reporting/scheduler.rs +++ b/crates/smtp/src/reporting/scheduler.rs @@ -22,156 +22,82 @@ */ use ahash::{AHashMap, RandomState}; -use mail_auth::{ - common::{ - base32::{Base32Reader, Base32Writer}, - headers::Writer, - }, - dmarc::Dmarc, -}; +use mail_auth::dmarc::Dmarc; -use serde::{de::DeserializeOwned, Serialize}; use std::{ - collections::{hash_map::Entry, BinaryHeap}, - hash::Hash, - path::PathBuf, sync::Arc, time::{Duration, Instant, SystemTime}, }; -use tokio::{ - fs::{self, OpenOptions}, - io::AsyncWriteExt, - sync::mpsc, +use store::{ + write::{now, QueueClass, ReportEvent, ValueClass}, + Deserialize, IterateParams, ValueKey, }; +use tokio::sync::mpsc; use crate::{ - config::AggregateFrequency, - core::{management::ReportRequest, worker::SpawnCleanup, ReportCore, SMTP}, - queue::{InstantFromTimestamp, Schedule}, + core::{worker::SpawnCleanup, SMTP}, + queue::manager::LONG_WAIT, }; -use super::{dmarc::GenerateDmarcReport, tls::GenerateTlsReport, Event}; - -pub type ReportKey = ReportType<ReportPolicy<String>, String>; -pub type ReportValue = ReportType<ReportPath<PathBuf>, ReportPath<Vec<ReportPolicy<PathBuf>>>>; - -pub struct Scheduler { - short_wait: Duration, - long_wait: Duration, - pub main: BinaryHeap<Schedule<ReportKey>>, - pub reports: AHashMap<ReportKey, ReportValue>, -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize)] -pub enum ReportType<T, U> { - Dmarc(T), - Tls(U), -} - -#[derive(Debug, PartialEq, Eq)] -pub struct ReportPath<T> { - pub path: T, - pub size: usize, - pub created: u64, - pub deliver_at: AggregateFrequency, -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct ReportPolicy<T> { - pub inner: T, - pub policy: u64, -} +use super::Event; impl SpawnReport for mpsc::Receiver<Event> { - fn spawn(mut self, core: Arc<SMTP>, mut scheduler: Scheduler) { + fn spawn(mut self, core: Arc<SMTP>) { tokio::spawn(async move { let mut last_cleanup = Instant::now(); + let mut next_wake_up; loop { - match tokio::time::timeout(scheduler.wake_up_time(), self.recv()).await { + // Read events + let now = now(); + let events = core.next_report_event().await; + next_wake_up = events + .last() + .and_then(|e| match e { + QueueClass::DmarcReportHeader(e) | QueueClass::TlsReportHeader(e) + if e.due > now => + { + Duration::from_secs(e.due - now).into() + } + _ => None, + }) + .unwrap_or(LONG_WAIT); + + let core_ = core.clone(); + tokio::spawn(async move { + let mut tls_reports = AHashMap::new(); + for report_event in events { + match report_event { + QueueClass::DmarcReportHeader(event) if event.due <= now => { + core_.generate_dmarc_report(event).await; + } + QueueClass::TlsReportHeader(event) if event.due <= now => { + tls_reports + .entry(event.domain.clone()) + .or_insert_with(Vec::new) + .push(event); + } + _ => (), + } + } + + for (domain_name, tls_report) in tls_reports { + core_.generate_tls_report(domain_name, tls_report).await; + } + }); + + match tokio::time::timeout(next_wake_up, self.recv()).await { Ok(Some(event)) => match event { Event::Dmarc(event) => { - scheduler.schedule_dmarc(event, &core).await; + core.schedule_dmarc(event).await; } Event::Tls(event) => { - scheduler.schedule_tls(event, &core).await; + core.schedule_tls(event).await; } - Event::Manage(request) => match request { - ReportRequest::List { - type_, - domain, - result_tx, - } => { - let mut result = Vec::new(); - for key in scheduler.reports.keys() { - if domain - .as_ref() - .map_or(false, |domain| domain != key.domain()) - { - continue; - } - if let Some(type_) = &type_ { - if !matches!( - (key, type_), - (ReportType::Dmarc(_), ReportType::Dmarc(_)) - | (ReportType::Tls(_), ReportType::Tls(_)) - ) { - continue; - } - } - result.push(key.to_string()); - } - let _ = result_tx.send(result); - } - ReportRequest::Status { - report_ids, - result_tx, - } => { - let mut result = Vec::with_capacity(report_ids.len()); - for report_id in &report_ids { - result.push( - scheduler - .reports - .get(report_id) - .map(|report_value| (report_id, report_value).into()), - ); - } - let _ = result_tx.send(result); - } - ReportRequest::Cancel { - report_ids, - result_tx, - } => { - let mut result = Vec::with_capacity(report_ids.len()); - for report_id in &report_ids { - result.push( - if let Some(report) = scheduler.reports.remove(report_id) { - report.delete().await; - true - } else { - false - }, - ); - } - let _ = result_tx.send(result); - } - }, Event::Stop => break, }, Ok(None) => break, Err(_) => { - while let Some(report) = scheduler.next_due() { - match report { - (ReportType::Dmarc(domain), ReportType::Dmarc(path)) => { - core.generate_dmarc_report(domain, path); - } - (ReportType::Tls(domain), ReportType::Tls(path)) => { - core.generate_tls_report(domain, path); - } - _ => unreachable!(), - } - } - // Cleanup expired throttles if last_cleanup.elapsed().as_secs() >= 86400 { last_cleanup = Instant::now(); @@ -185,429 +111,54 @@ impl SpawnReport for mpsc::Receiver<Event> { } impl SMTP { - pub async fn build_report_path( - &self, - domain: ReportType<&str, &str>, - policy: u64, - created: u64, - interval: AggregateFrequency, - ) -> PathBuf { - let (ext, domain) = match domain { - ReportType::Dmarc(domain) => ("d", domain), - ReportType::Tls(domain) => ("t", domain), - }; - - // Build base path - let mut path = self.report.config.path.clone(); - let todo = "fix"; - let hash = 1; - if hash > 0 { - path.push((policy % hash).to_string()); - } - let _ = fs::create_dir(&path).await; - - // Build filename - let mut w = Base32Writer::with_capacity(domain.len() + 13); - w.write(&policy.to_le_bytes()[..]); - w.write(&(created.saturating_sub(946684800) as u32).to_le_bytes()[..]); - w.push_byte( - match interval { - AggregateFrequency::Hourly => 0, - AggregateFrequency::Daily => 1, - AggregateFrequency::Weekly => 2, - AggregateFrequency::Never => 3, + pub async fn next_report_event(&self) -> Vec<QueueClass> { + let from_key = ValueKey::from(ValueClass::Queue(QueueClass::DmarcReportHeader( + ReportEvent { + due: 0, + policy_hash: 0, + seq_id: 0, + domain: String::new(), }, - false, - ); - w.write(domain.as_bytes()); - let mut file = w.finalize(); - file.push('.'); - file.push_str(ext); - path.push(file); - path - } -} - -impl ReportCore { - pub async fn read_reports(&self) -> Scheduler { - let mut scheduler = Scheduler::default(); - - let mut dir = match tokio::fs::read_dir(&self.config.path).await { - Ok(dir) => dir, - Err(_) => { - return scheduler; - } - }; - loop { - match dir.next_entry().await { - Ok(Some(file)) => { - let file = file.path(); - if file.is_dir() { - match tokio::fs::read_dir(&file).await { - Ok(mut dir) => { - let file_ = file; - loop { - match dir.next_entry().await { - Ok(Some(file)) => { - let file = file.path(); - if file - .extension() - .map_or(false, |e| e == "t" || e == "d") - { - if let Err(err) = scheduler.add_path(file).await { - tracing::warn!("{}", err); - } - } - } - Ok(None) => break, - Err(err) => { - tracing::warn!( - "Failed to read report directory {}: {}", - file_.display(), - err - ); - break; - } - } - } - } - Err(err) => { - tracing::warn!( - "Failed to read report directory {}: {}", - file.display(), - err - ) - } - }; - } else if file.extension().map_or(false, |e| e == "t" || e == "d") { - if let Err(err) = scheduler.add_path(file).await { - tracing::warn!("{}", err); - } - } - } - Ok(None) => { - break; - } - Err(err) => { - tracing::warn!( - "Failed to read report directory {}: {}", - self.config.path.display(), - err - ); - break; - } - } - } - - scheduler - } -} - -impl Scheduler { - pub fn next_due(&mut self) -> Option<(ReportKey, ReportValue)> { - let item = self.main.peek()?; - if item.due <= Instant::now() { - let item = self.main.pop().unwrap(); - self.reports - .remove(&item.inner) - .map(|policy| (item.inner, policy)) - } else { - None - } - } - - pub fn wake_up_time(&self) -> Duration { - self.main - .peek() - .map(|item| { - item.due - .checked_duration_since(Instant::now()) - .unwrap_or(self.short_wait) - }) - .unwrap_or(self.long_wait) - } - - pub async fn add_path(&mut self, path: PathBuf) -> Result<(), String> { - let (file, ext) = path - .file_name() - .and_then(|f| f.to_str()) - .and_then(|f| f.rsplit_once('.')) - .ok_or_else(|| format!("Invalid queue file name {}", path.display()))?; - let file_size = fs::metadata(&path) - .await - .map_err(|err| { - format!( - "Failed to obtain file metadata for {}: {}", - path.display(), - err - ) - })? - .len(); - if file_size == 0 { - let _ = fs::remove_file(&path).await; - return Err(format!( - "Removed zero length report file {}", - path.display() - )); - } - - // Decode domain name - let mut policy = [0u8; std::mem::size_of::<u64>()]; - let mut created = [0u8; std::mem::size_of::<u32>()]; - let mut deliver_at = AggregateFrequency::Never; - let mut domain = Vec::new(); - for (pos, byte) in Base32Reader::new(file.as_bytes()).enumerate() { - match pos { - 0..=7 => { - policy[pos] = byte; - } - 8..=11 => { - created[pos - 8] = byte; - } - 12 => { - deliver_at = match byte { - 0 => AggregateFrequency::Hourly, - 1 => AggregateFrequency::Daily, - 2 => AggregateFrequency::Weekly, - _ => { - return Err(format!( - "Failed to base32 decode report file {}", - path.display() - )); - } - }; - } - _ => { - domain.push(byte); - } - } - } - if domain.is_empty() { - return Err(format!( - "Failed to base32 decode report file {}", - path.display() - )); - } - let domain = String::from_utf8(domain).map_err(|err| { - format!( - "Failed to base32 decode report file {}: {}", - path.display(), - err - ) - })?; - - // Rebuild parts - let policy = u64::from_le_bytes(policy); - let created = u32::from_le_bytes(created) as u64 + 946684800; - - match ext { - "d" => { - let key = ReportType::Dmarc(ReportPolicy { - inner: domain, - policy, - }); - self.reports.insert( - key.clone(), - ReportType::Dmarc(ReportPath { - path, - size: file_size as usize, - created, - deliver_at, - }), - ); - self.main.push(Schedule { - due: (created + deliver_at.as_secs()).to_instant(), - inner: key, - }); - } - "t" => match self.reports.entry(ReportType::Tls(domain)) { - Entry::Occupied(mut e) => { - if let ReportType::Tls(tls) = e.get_mut() { - tls.size += file_size as usize; - tls.path.push(ReportPolicy { - inner: path, - policy, - }); - } - } - Entry::Vacant(e) => { - self.main.push(Schedule { - due: (created + deliver_at.as_secs()).to_instant(), - inner: e.key().clone(), - }); - e.insert(ReportType::Tls(ReportPath { - path: vec![ReportPolicy { - inner: path, - policy, - }], - size: file_size as usize, - created, - deliver_at, - })); - } + ))); + let to_key = ValueKey::from(ValueClass::Queue(QueueClass::TlsReportHeader( + ReportEvent { + due: u64::MAX, + policy_hash: 0, + seq_id: 0, + domain: String::new(), }, - _ => unreachable!(), - } - - Ok(()) - } -} - -pub async fn json_write(path: &PathBuf, entry: &impl Serialize) -> usize { - if let Ok(bytes) = serde_json::to_vec(entry) { - // Save serialized report - let bytes_written = bytes.len() - 2; - match fs::File::create(&path).await { - Ok(mut file) => match file.write_all(&bytes[..bytes_written]).await { - Ok(_) => bytes_written, - Err(err) => { - tracing::error!( - context = "report", - event = "error", - "Failed to write to report file {}: {}", - path.display(), - err - ); - 0 - } - }, - Err(err) => { - tracing::error!( - context = "report", - event = "error", - "Failed to create report file {}: {}", - path.display(), - err - ); - 0 - } - } - } else { - 0 - } -} - -pub async fn json_append(path: &PathBuf, entry: &impl Serialize, bytes_left: usize) -> usize { - let mut bytes = Vec::with_capacity(128); - bytes.push(b','); - if serde_json::to_writer(&mut bytes, entry).is_ok() && bytes.len() <= bytes_left { - let err = match OpenOptions::new().append(true).open(&path).await { - Ok(mut file) => match file.write_all(&bytes).await { - Ok(_) => return bytes.len(), - Err(err) => err, - }, - Err(err) => err, - }; - tracing::error!( - context = "report", - event = "error", - "Failed to append report to {}: {}", - path.display(), - err - ); - } - 0 -} - -pub async fn json_read<T: DeserializeOwned>(path: &PathBuf, span: &tracing::Span) -> Option<T> { - match fs::read_to_string(&path).await { - Ok(mut json) => { - json.push_str("]}"); - match serde_json::from_str(&json) { - Ok(report) => Some(report), - Err(err) => { - tracing::error!( - parent: span, - context = "deserialize", - event = "error", - "Failed to deserialize report file {}: {}", - path.display(), - err - ); - None - } - } - } - Err(err) => { - tracing::error!( - parent: span, - context = "io", - event = "error", - "Failed to read report file {}: {}", - path.display(), - err - ); - None - } - } -} + ))); + + let mut events = Vec::new(); + let now = now(); + let result = self + .shared + .default_data_store + .iterate( + IterateParams::new(from_key, to_key).ascending().no_values(), + |key, _| { + let event = ReportEvent::deserialize(key)?; + let do_continue = event.due <= now; + events.push(if *key.last().unwrap() == 0 { + QueueClass::DmarcReportHeader(event) + } else { + QueueClass::TlsReportHeader(event) + }); + Ok(do_continue) + }, + ) + .await; -pub fn json_read_blocking<T: DeserializeOwned>(path: &PathBuf, span: &tracing::Span) -> Option<T> { - match std::fs::read_to_string(path) { - Ok(mut json) => { - json.push_str("]}"); - match serde_json::from_str(&json) { - Ok(report) => Some(report), - Err(err) => { - tracing::error!( - parent: span, - context = "deserialize", - event = "error", - "Failed to deserialize report file {}: {}", - path.display(), - err - ); - None - } - } - } - Err(err) => { + if let Err(err) = result { tracing::error!( - parent: span, - context = "io", + context = "queue", event = "error", - "Failed to read report file {}: {}", - path.display(), + "Failed to read from store: {}", err ); - None } - } -} - -impl Default for Scheduler { - fn default() -> Self { - Self { - short_wait: Duration::from_millis(1), - long_wait: Duration::from_secs(86400 * 365), - main: BinaryHeap::with_capacity(128), - reports: AHashMap::with_capacity(128), - } - } -} - -impl ReportKey { - pub fn domain_name(&self) -> &str { - match self { - ReportType::Dmarc(domain) => domain.inner.as_str(), - ReportType::Tls(domain) => domain.as_str(), - } - } -} -impl ReportValue { - pub fn dmarc_path(&mut self) -> &mut ReportPath<PathBuf> { - match self { - ReportType::Dmarc(path) => path, - ReportType::Tls(_) => unreachable!(), - } - } - - pub fn tls_path(&mut self) -> &mut ReportPath<Vec<ReportPolicy<PathBuf>>> { - match self { - ReportType::Tls(path) => path, - ReportType::Dmarc(_) => unreachable!(), - } + events } } @@ -641,5 +192,5 @@ impl ToTimestamp for Duration { } pub trait SpawnReport { - fn spawn(self, core: Arc<SMTP>, scheduler: Scheduler); + fn spawn(self, core: Arc<SMTP>); } diff --git a/crates/smtp/src/reporting/spf.rs b/crates/smtp/src/reporting/spf.rs index 1f5c37e5..ac2398d1 100644 --- a/crates/smtp/src/reporting/spf.rs +++ b/crates/smtp/src/reporting/spf.rs @@ -36,7 +36,7 @@ impl<T: AsyncWrite + AsyncRead + Unpin> Session<T> { output: &SpfOutput, ) { // Throttle recipient - if !self.throttle_rcpt(rcpt, rate, "spf") { + if !self.throttle_rcpt(rcpt, rate, "spf").await { tracing::debug!( parent: &self.span, context = "report", diff --git a/crates/smtp/src/reporting/tls.rs b/crates/smtp/src/reporting/tls.rs index 052d2757..a65e7e0e 100644 --- a/crates/smtp/src/reporting/tls.rs +++ b/crates/smtp/src/reporting/tls.rs @@ -21,7 +21,7 @@ * for more details. */ -use std::{collections::hash_map::Entry, path::PathBuf, sync::Arc, time::Duration}; +use std::{collections::hash_map::Entry, sync::Arc, time::Duration}; use ahash::AHashMap; use mail_auth::{ @@ -34,25 +34,21 @@ use mail_auth::{ use mail_parser::DateTime; use reqwest::header::CONTENT_TYPE; -use serde::{Deserialize, Serialize}; use std::fmt::Write; -use tokio::runtime::Handle; +use store::{ + write::{now, BatchBuilder, Bincode, QueueClass, ReportEvent, ValueClass}, + Deserialize, IterateParams, Serialize, ValueKey, +}; use crate::{ config::AggregateFrequency, core::SMTP, outbound::mta_sts::{Mode, MxPattern}, - queue::{InstantFromTimestamp, RecipientDomain, Schedule}, + queue::RecipientDomain, USER_AGENT, }; -use super::{ - scheduler::{ - json_append, json_read_blocking, json_write, ReportPath, ReportPolicy, ReportType, - Scheduler, ToHash, - }, - TlsEvent, -}; +use super::{scheduler::ToHash, SerializedSize, TlsEvent}; #[derive(Debug, Clone)] pub struct TlsRptOptions { @@ -60,310 +56,338 @@ pub struct TlsRptOptions { pub interval: AggregateFrequency, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, serde::Serialize, serde::Deserialize)] struct TlsFormat { rua: Vec<ReportUri>, policy: PolicyDetails, records: Vec<Option<FailureDetails>>, } -pub trait GenerateTlsReport { - fn generate_tls_report(&self, domain: String, paths: ReportPath<Vec<ReportPolicy<PathBuf>>>); -} - #[cfg(feature = "test_mode")] pub static TLS_HTTP_REPORT: parking_lot::Mutex<Vec<u8>> = parking_lot::Mutex::new(Vec::new()); -impl GenerateTlsReport for Arc<SMTP> { - fn generate_tls_report(&self, domain: String, path: ReportPath<Vec<ReportPolicy<PathBuf>>>) { - let core = self.clone(); - let handle = Handle::current(); - - self.worker_pool.spawn(move || { - let deliver_at = path.created + path.deliver_at.as_secs(); - let span = tracing::info_span!( - "tls-report", - domain = domain, - range_from = path.created, - range_to = deliver_at, - size = path.size, - ); +impl SMTP { + pub async fn generate_tls_report(&self, domain_name: String, events: Vec<ReportEvent>) { + let (event_from, event_to, policy) = events + .first() + .map(|e| (e.seq_id, e.due, e.policy_hash)) + .unwrap(); + + let span = tracing::info_span!( + "tls-report", + domain = domain_name, + range_from = event_from, + range_to = event_to, + ); + // Deserialize report + let config = &self.report.config.tls; + let mut report = TlsReport { + organization_name: self + .eval_if( + &config.org_name, + &RecipientDomain::new(domain_name.as_str()), + ) + .await + .clone(), + date_range: DateRange { + start_datetime: DateTime::from_timestamp(event_from as i64), + end_datetime: DateTime::from_timestamp(event_to as i64), + }, + contact_info: self + .eval_if( + &config.contact_info, + &RecipientDomain::new(domain_name.as_str()), + ) + .await + .clone(), + report_id: format!("{}_{}", event_from, policy), + policies: Vec::with_capacity(events.len()), + }; + let mut rua = Vec::new(); + let mut serialized_size = serde_json::Serializer::new(SerializedSize::new( + self.eval_if( + &self.report.config.tls.max_size, + &RecipientDomain::new(domain_name.as_str()), + ) + .await + .unwrap_or(25 * 1024 * 1024), + )); + let _ = serde::Serialize::serialize(&report, &mut serialized_size); + + for event in &events { // Deserialize report - let config = &core.report.config.tls; - let mut report = TlsReport { - organization_name: handle - .block_on( - core.eval_if(&config.org_name, &RecipientDomain::new(domain.as_str())), - ) - .clone(), - date_range: DateRange { - start_datetime: DateTime::from_timestamp(path.created as i64), - end_datetime: DateTime::from_timestamp(deliver_at as i64), - }, - contact_info: handle - .block_on( - core.eval_if(&config.contact_info, &RecipientDomain::new(domain.as_str())), - ) - .clone(), - report_id: format!( - "{}_{}", - path.created, - path.path.first().map_or(0, |p| p.policy) - ), - policies: Vec::with_capacity(path.path.len()), + let tls = match self + .shared + .default_data_store + .get_value::<Bincode<TlsFormat>>(ValueKey::from(ValueClass::Queue( + QueueClass::TlsReportHeader(event.clone()), + ))) + .await + { + Ok(Some(dmarc)) => dmarc.inner, + Ok(None) => { + tracing::warn!( + parent: &span, + event = "missing", + "Failed to read DMARC report: Report not found" + ); + continue; + } + Err(err) => { + tracing::warn!( + parent: &span, + event = "error", + "Failed to read DMARC report: {}", + err + ); + continue; + } }; - let mut rua = Vec::new(); - for path in &path.path { - if let Some(tls) = json_read_blocking::<TlsFormat>(&path.inner, &span) { - // Group duplicates - let mut total_success = 0; - let mut total_failure = 0; - let mut record_map = AHashMap::with_capacity(tls.records.len()); - for record in tls.records { - if let Some(record) = record { - match record_map.entry(record) { - Entry::Occupied(mut e) => { - *e.get_mut() += 1; - } - Entry::Vacant(e) => { + let _ = serde::Serialize::serialize(&tls, &mut serialized_size); + + // Group duplicates + let mut total_success = 0; + let mut total_failure = 0; + + let from_key = + ValueKey::from(ValueClass::Queue(QueueClass::TlsReportEvent(ReportEvent { + due: event.due, + policy_hash: event.policy_hash, + seq_id: 0, + domain: event.domain.clone(), + }))); + let to_key = + ValueKey::from(ValueClass::Queue(QueueClass::TlsReportEvent(ReportEvent { + due: event.due, + policy_hash: event.policy_hash, + seq_id: u64::MAX, + domain: event.domain.clone(), + }))); + let mut record_map = AHashMap::with_capacity(tls.records.len()); + if let Err(err) = self + .shared + .default_data_store + .iterate(IterateParams::new(from_key, to_key).ascending(), |_, v| { + if let Some(failure_details) = + Bincode::<Option<FailureDetails>>::deserialize(v)?.inner + { + total_failure += 1; + + match record_map.entry(failure_details) { + Entry::Occupied(mut e) => { + *e.get_mut() += 1; + Ok(true) + } + Entry::Vacant(e) => { + if serde::Serialize::serialize(e.key(), &mut serialized_size) + .is_ok() + { e.insert(1u32); + Ok(true) + } else { + Ok(false) } } - total_failure += 1; - } else { - total_success += 1; } + } else { + total_success += 1; + Ok(true) } - report.policies.push(Policy { - policy: tls.policy, - summary: Summary { - total_success, - total_failure, - }, - failure_details: record_map - .into_iter() - .map(|(mut r, count)| { - r.failed_session_count = count; - r - }) - .collect(), - }); - - rua = tls.rua; - } + }) + .await + { + tracing::warn!( + parent: &span, + event = "error", + "Failed to read TLS report: {}", + err + ); } - if report.policies.is_empty() { - // This should not happen - tracing::warn!( + report.policies.push(Policy { + policy: tls.policy, + summary: Summary { + total_success, + total_failure, + }, + failure_details: record_map + .into_iter() + .map(|(mut r, count)| { + r.failed_session_count = count; + r + }) + .collect(), + }); + + rua = tls.rua; + } + + if report.policies.is_empty() { + // This should not happen + tracing::warn!( + parent: &span, + event = "empty-report", + "No policies found in report" + ); + self.delete_tls_report(events).await; + return; + } + + // Compress and serialize report + let json = report.to_json(); + let mut e = GzEncoder::new(Vec::with_capacity(json.len()), Compression::default()); + let json = match std::io::Write::write_all(&mut e, json.as_bytes()).and_then(|_| e.finish()) + { + Ok(report) => report, + Err(err) => { + tracing::error!( parent: &span, - event = "empty-report", - "No policies found in report" + event = "error", + "Failed to compress report: {}", + err ); - path.cleanup_blocking(); + self.delete_tls_report(events).await; return; } + }; - // Compress and serialize report - let json = report.to_json(); - let mut e = GzEncoder::new(Vec::with_capacity(json.len()), Compression::default()); - let json = - match std::io::Write::write_all(&mut e, json.as_bytes()).and_then(|_| e.finish()) { - Ok(report) => report, - Err(err) => { - tracing::error!( - parent: &span, - event = "error", - "Failed to compress report: {}", - err - ); - return; - } - }; - - // Try delivering report over HTTP - let mut rcpts = Vec::with_capacity(rua.len()); - for uri in &rua { - match uri { - ReportUri::Http(uri) => { - if let Ok(client) = reqwest::blocking::Client::builder() - .user_agent(USER_AGENT) - .timeout(Duration::from_secs(2 * 60)) - .build() - { - #[cfg(feature = "test_mode")] - if uri == "https://127.0.0.1/tls" { - TLS_HTTP_REPORT.lock().extend_from_slice(&json); - path.cleanup_blocking(); - return; - } + // Try delivering report over HTTP + let mut rcpts = Vec::with_capacity(rua.len()); + for uri in &rua { + match uri { + ReportUri::Http(uri) => { + if let Ok(client) = reqwest::blocking::Client::builder() + .user_agent(USER_AGENT) + .timeout(Duration::from_secs(2 * 60)) + .build() + { + #[cfg(feature = "test_mode")] + if uri == "https://127.0.0.1/tls" { + TLS_HTTP_REPORT.lock().extend_from_slice(&json); + self.delete_tls_report(events).await; + return; + } - match client - .post(uri) - .header(CONTENT_TYPE, "application/tlsrpt+gzip") - .body(json.to_vec()) - .send() - { - Ok(response) => { - if response.status().is_success() { - tracing::info!( - parent: &span, - context = "http", - event = "success", - url = uri, - ); - path.cleanup_blocking(); - return; - } else { - tracing::debug!( - parent: &span, - context = "http", - event = "invalid-response", - url = uri, - status = %response.status() - ); - } - } - Err(err) => { + match client + .post(uri) + .header(CONTENT_TYPE, "application/tlsrpt+gzip") + .body(json.to_vec()) + .send() + { + Ok(response) => { + if response.status().is_success() { + tracing::info!( + parent: &span, + context = "http", + event = "success", + url = uri, + ); + self.delete_tls_report(events).await; + return; + } else { tracing::debug!( parent: &span, context = "http", - event = "error", + event = "invalid-response", url = uri, - reason = %err + status = %response.status() ); } } + Err(err) => { + tracing::debug!( + parent: &span, + context = "http", + event = "error", + url = uri, + reason = %err + ); + } } } - ReportUri::Mail(mailto) => { - rcpts.push(mailto.as_str()); - } + } + ReportUri::Mail(mailto) => { + rcpts.push(mailto.as_str()); } } + } - // Deliver report over SMTP - if !rcpts.is_empty() { - let from_addr = handle - .block_on(core.eval_if(&config.address, &RecipientDomain::new(domain.as_str()))) - .unwrap_or_else(|| "MAILER-DAEMON@localhost".to_string()); - let mut message = Vec::with_capacity(path.size); - let _ = report.write_rfc5322_from_bytes( - &domain, - &handle - .block_on(core.eval_if( - &core.report.config.submitter, - &RecipientDomain::new(domain.as_str()), - )) - .unwrap_or_else(|| "localhost".to_string()), - ( - handle - .block_on( - core.eval_if(&config.name, &RecipientDomain::new(domain.as_str())), - ) - .unwrap_or_else(|| "Mail Delivery Subsystem".to_string()) - .as_str(), - from_addr.as_str(), - ), - rcpts.iter().copied(), - &json, - &mut message, - ); + // Deliver report over SMTP + if !rcpts.is_empty() { + let from_addr = self + .eval_if(&config.address, &RecipientDomain::new(domain_name.as_str())) + .await + .unwrap_or_else(|| "MAILER-DAEMON@localhost".to_string()); + let mut message = Vec::with_capacity(2048); + let _ = report.write_rfc5322_from_bytes( + &domain_name, + &self + .eval_if( + &self.report.config.submitter, + &RecipientDomain::new(domain_name.as_str()), + ) + .await + .unwrap_or_else(|| "localhost".to_string()), + ( + self.eval_if(&config.name, &RecipientDomain::new(domain_name.as_str())) + .await + .unwrap_or_else(|| "Mail Delivery Subsystem".to_string()) + .as_str(), + from_addr.as_str(), + ), + rcpts.iter().copied(), + &json, + &mut message, + ); - // Send report - handle.block_on(core.send_report( - &from_addr, - rcpts.iter(), - message, - &config.sign, - &span, - false, - )); - } else { - tracing::info!( - parent: &span, - event = "delivery-failed", - "No valid recipients found to deliver report to." - ); - } - path.cleanup_blocking(); - }); + // Send report + self.send_report( + &from_addr, + rcpts.iter(), + message, + &config.sign, + &span, + false, + ) + .await; + } else { + tracing::info!( + parent: &span, + event = "delivery-failed", + "No valid recipients found to deliver report to." + ); + } + self.delete_tls_report(events).await; } -} -impl Scheduler { - pub async fn schedule_tls(&mut self, event: Box<TlsEvent>, core: &SMTP) { - let max_size = core - .eval_if( - &core.report.config.tls.max_size, - &RecipientDomain::new(event.domain.as_str()), - ) - .await - .unwrap_or(25 * 1024 * 1024); - let policy_hash = event.policy.to_hash(); - - let (path, pos, create) = match self.reports.entry(ReportType::Tls(event.domain)) { - Entry::Occupied(e) => { - if let ReportType::Tls(path) = e.get() { - if let Some(pos) = path.path.iter().position(|p| p.policy == policy_hash) { - (e.into_mut().tls_path(), pos, None) - } else { - let pos = path.path.len(); - let domain = e.key().domain_name().to_string(); - let path = e.into_mut().tls_path(); - path.path.push(ReportPolicy { - inner: core - .build_report_path( - ReportType::Tls(&domain), - policy_hash, - path.created, - path.deliver_at, - ) - .await, - policy: policy_hash, - }); - (path, pos, domain.into()) - } - } else { - unreachable!() - } - } - Entry::Vacant(e) => { - let created = event.interval.to_timestamp(); - let deliver_at = created + event.interval.as_secs(); - - self.main.push(Schedule { - due: deliver_at.to_instant(), - inner: e.key().clone(), - }); - let domain = e.key().domain_name().to_string(); - let path = core - .build_report_path( - ReportType::Tls(&domain), - policy_hash, - created, - event.interval, - ) - .await; - let v = e.insert(ReportType::Tls(ReportPath { - path: vec![ReportPolicy { - inner: path, - policy: policy_hash, - }], - size: 0, - created, - deliver_at: event.interval, - })); - (v.tls_path(), 0, domain.into()) - } + pub async fn schedule_tls(&self, event: Box<TlsEvent>) { + let created = event.interval.to_timestamp(); + let deliver_at = created + event.interval.as_secs(); + let mut report_event = ReportEvent { + due: deliver_at, + policy_hash: event.policy.to_hash(), + seq_id: created, + domain: event.domain, }; - if let Some(domain) = create { + // Write policy if missing + let mut builder = BatchBuilder::new(); + if self + .shared + .default_data_store + .get_value::<()>(ValueKey::from(ValueClass::Queue( + QueueClass::TlsReportHeader(report_event.clone()), + ))) + .await + .unwrap_or_default() + .is_none() + { + // Serialize report let mut policy = PolicyDetails { policy_type: PolicyType::NoPolicyFound, policy_string: vec![], - policy_domain: domain, + policy_domain: report_event.domain.clone(), mx_host: vec![], }; @@ -420,47 +444,78 @@ impl Scheduler { let entry = TlsFormat { rua: event.tls_record.rua.clone(), policy, - records: vec![event.failure], + records: vec![], }; - let bytes_written = json_write(&path.path[pos].inner, &entry).await; - - if bytes_written > 0 { - path.size += bytes_written; - } else { - // Something went wrong, remove record - if let Entry::Occupied(mut e) = self - .reports - .entry(ReportType::Tls(entry.policy.policy_domain)) - { - if let ReportType::Tls(path) = e.get_mut() { - path.path.retain(|p| p.policy != policy_hash); - if path.path.is_empty() { - e.remove_entry(); - } - } - } - } - } else if path.size < max_size { - // Append to existing report - path.size += - json_append(&path.path[pos].inner, &event.failure, max_size - path.size).await; + + // Write report + builder.set( + ValueClass::Queue(QueueClass::TlsReportHeader(report_event.clone())), + Bincode::new(entry).serialize(), + ); + } + + // Write entry + report_event.seq_id = self.queue.snowflake_id.generate().unwrap_or_else(now); + builder.set( + ValueClass::Queue(QueueClass::TlsReportEvent(report_event)), + Bincode::new(event.failure).serialize(), + ); + + if let Err(err) = self.shared.default_data_store.write(builder.build()).await { + tracing::error!( + context = "report", + event = "error", + "Failed to write DMARC report event: {}", + err + ); } } -} -impl ReportPath<Vec<ReportPolicy<PathBuf>>> { - fn cleanup_blocking(&self) { - for path in &self.path { - if let Err(err) = std::fs::remove_file(&path.inner) { - tracing::error!( + pub async fn delete_tls_report(&self, events: Vec<ReportEvent>) { + let mut batch = BatchBuilder::new(); + + for event in events { + let from_key = ReportEvent { + due: event.due, + policy_hash: event.policy_hash, + seq_id: 0, + domain: event.domain.clone(), + }; + let to_key = ReportEvent { + due: event.due, + policy_hash: event.policy_hash, + seq_id: u64::MAX, + domain: event.domain.clone(), + }; + + if let Err(err) = self + .shared + .default_data_store + .delete_range( + ValueKey::from(ValueClass::Queue(QueueClass::TlsReportEvent(from_key))), + ValueKey::from(ValueClass::Queue(QueueClass::TlsReportEvent(to_key))), + ) + .await + { + tracing::warn!( context = "report", - report = "tls", event = "error", - "Failed to delete file {}: {}", - path.inner.display(), + "Failed to remove repors: {}", err ); + return; } + + batch.clear(ValueClass::Queue(QueueClass::TlsReportHeader(event))); + } + + if let Err(err) = self.shared.default_data_store.write(batch.build()).await { + tracing::warn!( + context = "report", + event = "error", + "Failed to remove repors: {}", + err + ); } } } diff --git a/crates/smtp/src/scripts/event_loop.rs b/crates/smtp/src/scripts/event_loop.rs index 845430f7..4bfaefaf 100644 --- a/crates/smtp/src/scripts/event_loop.rs +++ b/crates/smtp/src/scripts/event_loop.rs @@ -21,7 +21,7 @@ * for more details. */ -use std::{sync::Arc, time::Duration}; +use std::sync::Arc; use mail_auth::common::headers::HeaderWriter; use sieve::{ @@ -32,18 +32,12 @@ use smtp_proto::{ MAIL_BY_TRACE, MAIL_RET_FULL, MAIL_RET_HDRS, RCPT_NOTIFY_DELAY, RCPT_NOTIFY_FAILURE, RCPT_NOTIFY_NEVER, RCPT_NOTIFY_SUCCESS, }; -use store::{backend::memory::MemoryStore, LookupKey, LookupStore, LookupValue}; +use store::{backend::memory::MemoryStore, LookupStore}; use tokio::runtime::Handle; -use crate::{ - core::SMTP, - queue::{DomainPart, InstantFromTimestamp, Message}, -}; +use crate::{core::SMTP, queue::DomainPart}; -use super::{ - plugins::{lookup::VariableExists, PluginContext}, - ScriptModification, ScriptParameters, ScriptResult, -}; +use super::{plugins::PluginContext, ScriptModification, ScriptParameters, ScriptResult}; impl SMTP { pub fn run_script_blocking( @@ -97,15 +91,15 @@ impl SMTP { 'outer: for list in lists { if let Some(store) = self.shared.lookup_stores.get(&list) { for value in &values { - if let Ok(LookupValue::Value { .. }) = handle.block_on( - store.key_get::<VariableExists>(LookupKey::Key( + if let Ok(true) = handle.block_on( + store.key_exists( if !matches!(match_as, MatchAs::Lowercase) { value.clone() } else { value.to_lowercase() } .into_bytes(), - )), + ), ) { input = true.into(); break 'outer; @@ -156,7 +150,7 @@ impl SMTP { // Build message let return_path_lcase = self.sieve.return_path.to_lowercase(); let return_path_domain = return_path_lcase.domain_part().to_string(); - let mut message = Message::new_boxed( + let mut message = self.queue.new_message( self.sieve.return_path.clone(), return_path_lcase, return_path_domain, @@ -223,7 +217,6 @@ impl SMTP { if trace { message.flags |= MAIL_BY_TRACE; } - let rlimit = Duration::from_secs(rlimit); match mode { ByMode::Notify => { for domain in &mut message.domains { @@ -246,16 +239,15 @@ impl SMTP { if trace { message.flags |= MAIL_BY_TRACE; } - let alimit = (alimit as u64).to_instant(); match mode { ByMode::Notify => { for domain in &mut message.domains { - domain.notify.due = alimit; + domain.notify.due = alimit as u64; } } ByMode::Return => { for domain in &mut message.domains { - domain.expires = alimit; + domain.expires = alimit as u64; } } ByMode::Default => (), @@ -302,10 +294,10 @@ impl SMTP { None }; - handle.block_on(self.queue.queue_message( - message, + handle.block_on(message.queue( headers.as_deref(), raw_message, + self, &span, )); } diff --git a/crates/smtp/src/scripts/plugins/bayes.rs b/crates/smtp/src/scripts/plugins/bayes.rs index ab5a147b..8b035ceb 100644 --- a/crates/smtp/src/scripts/plugins/bayes.rs +++ b/crates/smtp/src/scripts/plugins/bayes.rs @@ -29,12 +29,12 @@ use nlp::{ tokenizers::osb::{OsbToken, OsbTokenizer}, }; use sieve::{runtime::Variable, FunctionMap}; -use store::{write::key::KeySerializer, LookupKey, LookupStore, LookupValue, U64_LEN}; +use store::{write::key::KeySerializer, LookupStore, U64_LEN}; use tokio::runtime::Handle; use crate::config::scripts::SieveContext; -use super::{lookup::VariableExists, PluginContext}; +use super::PluginContext; pub fn register_train(plugin_id: u32, fnc_map: &mut FunctionMap<SieveContext>) { fnc_map.set_external_function("bayes_train", plugin_id, 3); @@ -110,14 +110,13 @@ fn train(ctx: PluginContext<'_>, is_train: bool) -> Variable { for (hash, weights) in model.weights { if handle .block_on( - store.key_set( + store.counter_incr( KeySerializer::new(U64_LEN) .write(hash.h1) .write(hash.h2) .finalize(), - LookupValue::Counter { - num: weights.into(), - }, + weights.into(), + None, ), ) .is_err() @@ -135,14 +134,13 @@ fn train(ctx: PluginContext<'_>, is_train: bool) -> Variable { }; if handle .block_on( - store.key_set( + store.counter_incr( KeySerializer::new(U64_LEN) .write(0u64) .write(0u64) .finalize(), - LookupValue::Counter { - num: weights.into(), - }, + weights.into(), + None, ), ) .is_err() @@ -337,15 +335,15 @@ impl LookupOrInsert for BayesTokenCache { ) -> Option<Weights> { if let Some(weights) = self.get(&hash) { weights.unwrap_or_default().into() - } else if let Ok(result) = handle.block_on( - get_token.key_get::<VariableExists>(LookupKey::Counter( + } else if let Ok(num) = handle.block_on( + get_token.counter_get( KeySerializer::new(U64_LEN) .write(hash.h1) .write(hash.h2) .finalize(), - )), + ), ) { - if let LookupValue::Counter { num } = result { + if num != 0 { let weights = Weights::from(num); self.insert_positive(hash, weights); weights diff --git a/crates/smtp/src/scripts/plugins/lookup.rs b/crates/smtp/src/scripts/plugins/lookup.rs index 27978d8a..913a761b 100644 --- a/crates/smtp/src/scripts/plugins/lookup.rs +++ b/crates/smtp/src/scripts/plugins/lookup.rs @@ -29,7 +29,7 @@ use std::{ use mail_auth::flate2; use sieve::{runtime::Variable, FunctionMap}; -use store::{Deserialize, LookupKey, LookupValue, Value}; +use store::{Deserialize, Value}; use crate::{ config::scripts::{RemoteList, SieveContext}, @@ -72,10 +72,7 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable { if !item.is_empty() && ctx .handle - .block_on(store.key_get::<VariableExists>(LookupKey::Key( - item.to_string().into_owned().into_bytes(), - ))) - .map(|v| v != LookupValue::None) + .block_on(store.key_exists(item.to_string().into_owned().into_bytes())) .unwrap_or(false) { return true.into(); @@ -85,10 +82,7 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable { } v if !v.is_empty() => ctx .handle - .block_on(store.key_get::<VariableExists>(LookupKey::Key( - v.to_string().into_owned().into_bytes(), - ))) - .map(|v| v != LookupValue::None) + .block_on(store.key_exists(v.to_string().into_owned().into_bytes())) .unwrap_or(false), _ => false, } @@ -113,14 +107,13 @@ pub fn exec_get(ctx: PluginContext<'_>) -> Variable { if let Some(store) = store { ctx.handle - .block_on(store.key_get::<VariableWrapper>(LookupKey::Key( - ctx.arguments[1].to_string().into_owned().into_bytes(), - ))) - .map(|v| match v { - LookupValue::Value { value, .. } => value.into_inner(), - LookupValue::Counter { num } => num.into(), - LookupValue::None => Variable::default(), - }) + .block_on( + store.key_get::<VariableWrapper>( + ctx.arguments[1].to_string().into_owned().into_bytes(), + ), + ) + .unwrap_or_default() + .map(|v| v.into_inner()) .unwrap_or_default() } else { tracing::warn!( @@ -142,22 +135,20 @@ pub fn exec_set(ctx: PluginContext<'_>) -> Variable { if let Some(store) = store { let expires = match &ctx.arguments[3] { - Variable::Integer(v) => *v as u64, - Variable::Float(v) => *v as u64, - _ => 0, + Variable::Integer(v) => Some(*v as u64), + Variable::Float(v) => Some(*v as u64), + _ => None, }; ctx.handle .block_on(store.key_set( ctx.arguments[1].to_string().into_owned().into_bytes(), - LookupValue::Value { - value: if !ctx.arguments[2].is_empty() { - bincode::serialize(&ctx.arguments[2]).unwrap_or_default() - } else { - vec![] - }, - expires, + if !ctx.arguments[2].is_empty() { + bincode::serialize(&ctx.arguments[2]).unwrap_or_default() + } else { + vec![] }, + expires, )) .is_ok() .into() @@ -426,9 +417,6 @@ pub fn exec_local_domain(ctx: PluginContext<'_>) -> Variable { #[derive(Debug, PartialEq, Eq)] pub struct VariableWrapper(Variable); -#[derive(Debug, PartialEq, Eq)] -pub struct VariableExists; - impl Deserialize for VariableWrapper { fn deserialize(bytes: &[u8]) -> store::Result<Self> { Ok(VariableWrapper( @@ -439,9 +427,9 @@ impl Deserialize for VariableWrapper { } } -impl Deserialize for VariableExists { - fn deserialize(_: &[u8]) -> store::Result<Self> { - Ok(VariableExists) +impl From<i64> for VariableWrapper { + fn from(value: i64) -> Self { + VariableWrapper(value.into()) } } @@ -451,12 +439,6 @@ impl VariableWrapper { } } -impl From<Value<'static>> for VariableExists { - fn from(_: Value<'static>) -> Self { - VariableExists - } -} - impl From<Value<'static>> for VariableWrapper { fn from(value: Value<'static>) -> Self { VariableWrapper(into_sieve_value(value)) |