diff options
author | mdecimus <mauro@stalw.art> | 2024-05-24 10:32:41 +0200 |
---|---|---|
committer | mdecimus <mauro@stalw.art> | 2024-05-24 10:32:41 +0200 |
commit | 4e7087d33528d1dc145b02e7f195808a4e7ffc01 (patch) | |
tree | c9e62f81ffed6543ddd5e3bddda9cb408245f031 | |
parent | ffdb7d766ac56f86e849682d41520e23b3d4b6a5 (diff) |
Run Sieve scripts in async context
-rw-r--r-- | crates/common/src/scripts/plugins/bayes.rs | 116 | ||||
-rw-r--r-- | crates/common/src/scripts/plugins/dns.rs | 90 | ||||
-rw-r--r-- | crates/common/src/scripts/plugins/exec.rs | 58 | ||||
-rw-r--r-- | crates/common/src/scripts/plugins/http.rs | 9 | ||||
-rw-r--r-- | crates/common/src/scripts/plugins/lookup.rs | 274 | ||||
-rw-r--r-- | crates/common/src/scripts/plugins/mod.rs | 56 | ||||
-rw-r--r-- | crates/common/src/scripts/plugins/pyzor.rs | 7 | ||||
-rw-r--r-- | crates/common/src/scripts/plugins/query.rs | 9 | ||||
-rw-r--r-- | crates/smtp/src/core/mod.rs | 6 | ||||
-rw-r--r-- | crates/smtp/src/core/throttle.rs | 15 | ||||
-rw-r--r-- | crates/smtp/src/core/worker.rs | 69 | ||||
-rw-r--r-- | crates/smtp/src/lib.rs | 10 | ||||
-rw-r--r-- | crates/smtp/src/reporting/analysis.rs | 23 | ||||
-rw-r--r-- | crates/smtp/src/reporting/scheduler.rs | 2 | ||||
-rw-r--r-- | crates/smtp/src/scripts/event_loop.rs | 54 | ||||
-rw-r--r-- | crates/smtp/src/scripts/exec.rs | 9 | ||||
-rw-r--r-- | crates/store/src/write/key.rs | 4 | ||||
-rw-r--r-- | tests/src/smtp/inbound/antispam.rs | 8 | ||||
-rw-r--r-- | tests/src/smtp/inbound/scripts.rs | 8 |
19 files changed, 366 insertions, 461 deletions
diff --git a/crates/common/src/scripts/plugins/bayes.rs b/crates/common/src/scripts/plugins/bayes.rs index c4159763..d96285d4 100644 --- a/crates/common/src/scripts/plugins/bayes.rs +++ b/crates/common/src/scripts/plugins/bayes.rs @@ -30,7 +30,6 @@ use nlp::{ }; use sieve::{runtime::Variable, FunctionMap}; use store::{write::key::KeySerializer, LookupStore, U64_LEN}; -use tokio::runtime::Handle; use super::PluginContext; @@ -50,15 +49,15 @@ pub fn register_is_balanced(plugin_id: u32, fnc_map: &mut FunctionMap) { fnc_map.set_external_function("bayes_is_balanced", plugin_id, 3); } -pub fn exec_train(ctx: PluginContext<'_>) -> Variable { - train(ctx, true) +pub async fn exec_train(ctx: PluginContext<'_>) -> Variable { + train(ctx, true).await } -pub fn exec_untrain(ctx: PluginContext<'_>) -> Variable { - train(ctx, false) +pub async fn exec_untrain(ctx: PluginContext<'_>) -> Variable { + train(ctx, false).await } -fn train(ctx: PluginContext<'_>, is_train: bool) -> Variable { +async fn train(ctx: PluginContext<'_>, is_train: bool) -> Variable { let span: &tracing::Span = ctx.span; let store = match &ctx.arguments[0] { Variable::String(v) if !v.is_empty() => ctx.core.storage.lookups.get(v.as_ref()), @@ -82,7 +81,6 @@ fn train(ctx: PluginContext<'_>, is_train: bool) -> Variable { if text.is_empty() { return false.into(); } - let handle = ctx.handle; // Train the model let mut model = BayesModel::default(); @@ -109,18 +107,17 @@ fn train(ctx: PluginContext<'_>, is_train: bool) -> Variable { let bayes_cache = &ctx.core.sieve.bayes_cache; if is_train { for (hash, weights) in model.weights { - if handle - .block_on( - store.counter_incr( - KeySerializer::new(U64_LEN) - .write(hash.h1) - .write(hash.h2) - .finalize(), - weights.into(), - None, - false, - ), + if store + .counter_incr( + KeySerializer::new(U64_LEN) + .write(hash.h1) + .write(hash.h2) + .finalize(), + weights.into(), + None, + false, ) + .await .is_err() { return false.into(); @@ -134,18 +131,17 @@ fn train(ctx: PluginContext<'_>, is_train: bool) -> Variable { } else { Weights { spam: 0, ham: 1 } }; - if handle - .block_on( - store.counter_incr( - KeySerializer::new(U64_LEN) - .write(0u64) - .write(0u64) - .finalize(), - weights.into(), - None, - false, - ), + if store + .counter_incr( + KeySerializer::new(U64_LEN) + .write(0u64) + .write(0u64) + .finalize(), + weights.into(), + None, + false, ) + .await .is_err() { return false.into(); @@ -160,7 +156,7 @@ fn train(ctx: PluginContext<'_>, is_train: bool) -> Variable { true.into() } -pub fn exec_classify(ctx: PluginContext<'_>) -> Variable { +pub async fn exec_classify(ctx: PluginContext<'_>) -> Variable { let span = ctx.span; let store = match &ctx.arguments[0] { Variable::String(v) if !v.is_empty() => ctx.core.storage.lookups.get(v.as_ref()), @@ -200,12 +196,10 @@ pub fn exec_classify(ctx: PluginContext<'_>) -> Variable { } } - let handle = ctx.handle; - // Obtain training counts let bayes_cache = &ctx.core.sieve.bayes_cache; let (spam_learns, ham_learns) = - if let Some(weights) = bayes_cache.get_or_update(TokenHash::default(), handle, store) { + if let Some(weights) = bayes_cache.get_or_update(TokenHash::default(), store).await { (weights.spam, weights.ham) } else { tracing::warn!( @@ -231,27 +225,25 @@ pub fn exec_classify(ctx: PluginContext<'_>) -> Variable { } // Classify the text + let mut tokens = Vec::new(); + for token in OsbTokenizer::<_, TokenHash>::new( + BayesTokenizer::new(text.as_ref(), &ctx.core.smtp.resolvers.psl), + 5, + ) { + if let Some(weights) = bayes_cache.get_or_update(token.inner, store).await { + tokens.push(OsbToken { + inner: weights, + idx: token.idx, + }); + } + } classifier - .classify( - OsbTokenizer::<_, TokenHash>::new( - BayesTokenizer::new(text.as_ref(), &ctx.core.smtp.resolvers.psl), - 5, - ) - .filter_map(|t| { - OsbToken { - inner: bayes_cache.get_or_update(t.inner, handle, store)?, - idx: t.idx, - } - .into() - }), - ham_learns, - spam_learns, - ) + .classify(tokens.into_iter(), ham_learns, spam_learns) .map(Variable::from) .unwrap_or_default() } -pub fn exec_is_balanced(ctx: PluginContext<'_>) -> Variable { +pub async fn exec_is_balanced(ctx: PluginContext<'_>) -> Variable { let min_balance = match &ctx.arguments[2] { Variable::Float(n) => *n, Variable::Integer(n) => *n as f64, @@ -282,10 +274,9 @@ pub fn exec_is_balanced(ctx: PluginContext<'_>) -> Variable { let learn_spam = ctx.arguments[1].to_bool(); // Obtain training counts - let handle = ctx.handle; let bayes_cache = &ctx.core.sieve.bayes_cache; let (spam_learns, ham_learns) = - if let Some(weights) = bayes_cache.get_or_update(TokenHash::default(), handle, store) { + if let Some(weights) = bayes_cache.get_or_update(TokenHash::default(), store).await { (weights.spam as f64, weights.ham as f64) } else { tracing::warn!( @@ -321,31 +312,22 @@ pub fn exec_is_balanced(ctx: PluginContext<'_>) -> Variable { } trait LookupOrInsert { - fn get_or_update( - &self, - hash: TokenHash, - handle: &Handle, - get_token: &LookupStore, - ) -> Option<Weights>; + async fn get_or_update(&self, hash: TokenHash, get_token: &LookupStore) -> Option<Weights>; } impl LookupOrInsert for BayesTokenCache { - fn get_or_update( - &self, - hash: TokenHash, - handle: &Handle, - get_token: &LookupStore, - ) -> Option<Weights> { + async fn get_or_update(&self, hash: TokenHash, get_token: &LookupStore) -> Option<Weights> { if let Some(weights) = self.get(&hash) { weights.unwrap_or_default().into() - } else if let Ok(num) = handle.block_on( - get_token.counter_get( + } else if let Ok(num) = get_token + .counter_get( KeySerializer::new(U64_LEN) .write(hash.h1) .write(hash.h2) .finalize(), - ), - ) { + ) + .await + { if num != 0 { let weights = Weights::from(num); self.insert_positive(hash, weights); diff --git a/crates/common/src/scripts/plugins/dns.rs b/crates/common/src/scripts/plugins/dns.rs index 0e588fda..e3c06e80 100644 --- a/crates/common/src/scripts/plugins/dns.rs +++ b/crates/common/src/scripts/plugins/dns.rs @@ -36,16 +36,19 @@ pub fn register_exists(plugin_id: u32, fnc_map: &mut FunctionMap) { fnc_map.set_external_function("dns_exists", plugin_id, 2); } -pub fn exec(ctx: PluginContext<'_>) -> Variable { +pub async fn exec(ctx: PluginContext<'_>) -> Variable { let entry = ctx.arguments[0].to_string(); let record_type = ctx.arguments[1].to_string(); if record_type.eq_ignore_ascii_case("ip") { - match ctx.handle.block_on(ctx.core.smtp.resolvers.dns.ip_lookup( - entry.as_ref(), - IpLookupStrategy::Ipv4thenIpv6, - 10, - )) { + match ctx + .core + .smtp + .resolvers + .dns + .ip_lookup(entry.as_ref(), IpLookupStrategy::Ipv4thenIpv6, 10) + .await + { Ok(result) => result .iter() .map(|ip| Variable::from(ip.to_string())) @@ -54,10 +57,7 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable { Err(err) => err.short_error().into(), } } else if record_type.eq_ignore_ascii_case("mx") { - match ctx - .handle - .block_on(ctx.core.smtp.resolvers.dns.mx_lookup(entry.as_ref())) - { + match ctx.core.smtp.resolvers.dns.mx_lookup(entry.as_ref()).await { Ok(result) => result .iter() .flat_map(|mx| { @@ -78,18 +78,19 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable { } match ctx - .handle - .block_on(ctx.core.smtp.resolvers.dns.txt_raw_lookup(entry.as_ref())) + .core + .smtp + .resolvers + .dns + .txt_raw_lookup(entry.as_ref()) + .await { Ok(result) => Variable::from(String::from_utf8(result).unwrap_or_default()), Err(err) => err.short_error().into(), } } else if record_type.eq_ignore_ascii_case("ptr") { if let Ok(addr) = entry.parse::<IpAddr>() { - match ctx - .handle - .block_on(ctx.core.smtp.resolvers.dns.ptr_lookup(addr)) - { + match ctx.core.smtp.resolvers.dns.ptr_lookup(addr).await { Ok(result) => result .iter() .map(|host| Variable::from(host.to_string())) @@ -110,8 +111,12 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable { } match ctx - .handle - .block_on(ctx.core.smtp.resolvers.dns.ipv4_lookup(entry.as_ref())) + .core + .smtp + .resolvers + .dns + .ipv4_lookup(entry.as_ref()) + .await { Ok(result) => result .iter() @@ -122,8 +127,12 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable { } } else if record_type.eq_ignore_ascii_case("ipv6") { match ctx - .handle - .block_on(ctx.core.smtp.resolvers.dns.ipv6_lookup(entry.as_ref())) + .core + .smtp + .resolvers + .dns + .ipv6_lookup(entry.as_ref()) + .await { Ok(result) => result .iter() @@ -137,35 +146,32 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable { } } -pub fn exec_exists(ctx: PluginContext<'_>) -> Variable { +pub async fn exec_exists(ctx: PluginContext<'_>) -> Variable { let entry = ctx.arguments[0].to_string(); let record_type = ctx.arguments[1].to_string(); if record_type.eq_ignore_ascii_case("ip") { - match ctx.handle.block_on(ctx.core.smtp.resolvers.dns.ip_lookup( - entry.as_ref(), - IpLookupStrategy::Ipv4thenIpv6, - 10, - )) { + match ctx + .core + .smtp + .resolvers + .dns + .ip_lookup(entry.as_ref(), IpLookupStrategy::Ipv4thenIpv6, 10) + .await + { Ok(result) => i64::from(!result.is_empty()), Err(Error::DnsRecordNotFound(_)) => 0, Err(_) => -1, } } else if record_type.eq_ignore_ascii_case("mx") { - match ctx - .handle - .block_on(ctx.core.smtp.resolvers.dns.mx_lookup(entry.as_ref())) - { + match ctx.core.smtp.resolvers.dns.mx_lookup(entry.as_ref()).await { Ok(result) => i64::from(result.iter().any(|mx| !mx.exchanges.is_empty())), Err(Error::DnsRecordNotFound(_)) => 0, Err(_) => -1, } } else if record_type.eq_ignore_ascii_case("ptr") { if let Ok(addr) = entry.parse::<IpAddr>() { - match ctx - .handle - .block_on(ctx.core.smtp.resolvers.dns.ptr_lookup(addr)) - { + match ctx.core.smtp.resolvers.dns.ptr_lookup(addr).await { Ok(result) => i64::from(!result.is_empty()), Err(Error::DnsRecordNotFound(_)) => 0, Err(_) => -1, @@ -182,8 +188,12 @@ pub fn exec_exists(ctx: PluginContext<'_>) -> Variable { } match ctx - .handle - .block_on(ctx.core.smtp.resolvers.dns.ipv4_lookup(entry.as_ref())) + .core + .smtp + .resolvers + .dns + .ipv4_lookup(entry.as_ref()) + .await { Ok(result) => i64::from(!result.is_empty()), Err(Error::DnsRecordNotFound(_)) => 0, @@ -191,8 +201,12 @@ pub fn exec_exists(ctx: PluginContext<'_>) -> Variable { } } else if record_type.eq_ignore_ascii_case("ipv6") { match ctx - .handle - .block_on(ctx.core.smtp.resolvers.dns.ipv6_lookup(entry.as_ref())) + .core + .smtp + .resolvers + .dns + .ipv6_lookup(entry.as_ref()) + .await { Ok(result) => i64::from(!result.is_empty()), Err(Error::DnsRecordNotFound(_)) => 0, diff --git a/crates/common/src/scripts/plugins/exec.rs b/crates/common/src/scripts/plugins/exec.rs index d15ba260..2e35484a 100644 --- a/crates/common/src/scripts/plugins/exec.rs +++ b/crates/common/src/scripts/plugins/exec.rs @@ -31,32 +31,38 @@ pub fn register(plugin_id: u32, fnc_map: &mut FunctionMap) { fnc_map.set_external_function("exec", plugin_id, 2); } -pub fn exec(ctx: PluginContext<'_>) -> Variable { - let span = ctx.span; +pub async fn exec(ctx: PluginContext<'_>) -> Variable { + let span = ctx.span.clone(); let mut arguments = ctx.arguments.into_iter(); - match Command::new( - arguments - .next() - .map(|a| a.to_string().into_owned()) - .unwrap_or_default(), - ) - .args( - arguments - .next() - .map(|a| a.into_string_array()) - .unwrap_or_default(), - ) - .output() - { - Ok(result) => result.status.success().into(), - Err(err) => { - tracing::warn!( - parent: span, - context = "sieve", - event = "execute-failed", - reason = %err, - ); - false.into() + + tokio::task::spawn_blocking(move || { + match Command::new( + arguments + .next() + .map(|a| a.to_string().into_owned()) + .unwrap_or_default(), + ) + .args( + arguments + .next() + .map(|a| a.into_string_array()) + .unwrap_or_default(), + ) + .output() + { + Ok(result) => result.status.success(), + Err(err) => { + tracing::warn!( + parent: span, + context = "sieve", + event = "execute-failed", + reason = %err, + ); + false + } } - } + }) + .await + .unwrap_or_default() + .into() } diff --git a/crates/common/src/scripts/plugins/http.rs b/crates/common/src/scripts/plugins/http.rs index c4bdf1ea..4161500c 100644 --- a/crates/common/src/scripts/plugins/http.rs +++ b/crates/common/src/scripts/plugins/http.rs @@ -32,7 +32,7 @@ pub fn register_header(plugin_id: u32, fnc_map: &mut FunctionMap) { fnc_map.set_external_function("http_header", plugin_id, 4); } -pub fn exec_header(ctx: PluginContext<'_>) -> Variable { +pub async fn exec_header(ctx: PluginContext<'_>) -> Variable { let url = ctx.arguments[0].to_string(); let header = ctx.arguments[1].to_string(); let agent = ctx.arguments[2].to_string(); @@ -50,9 +50,10 @@ pub fn exec_header(ctx: PluginContext<'_>) -> Variable { .danger_accept_invalid_certs(true) .build() { - let _enter = ctx.handle.enter(); - ctx.handle - .block_on(client.get(url.as_ref()).send()) + client + .get(url.as_ref()) + .send() + .await .ok() .and_then(|response| { response diff --git a/crates/common/src/scripts/plugins/lookup.rs b/crates/common/src/scripts/plugins/lookup.rs index d12c71e4..7f098eb0 100644 --- a/crates/common/src/scripts/plugins/lookup.rs +++ b/crates/common/src/scripts/plugins/lookup.rs @@ -55,7 +55,7 @@ pub fn register_local_domain(plugin_id: u32, fnc_map: &mut FunctionMap) { fnc_map.set_external_function("is_local_domain", plugin_id, 2); } -pub fn exec(ctx: PluginContext<'_>) -> Variable { +pub async fn exec(ctx: PluginContext<'_>) -> Variable { let store = match &ctx.arguments[0] { Variable::String(v) if !v.is_empty() => ctx.core.storage.lookups.get(v.as_ref()), _ => Some(&ctx.core.storage.lookup), @@ -66,9 +66,9 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable { Variable::Array(items) => { for item in items.iter() { if !item.is_empty() - && ctx - .handle - .block_on(store.key_exists(item.to_string().into_owned().into_bytes())) + && store + .key_exists(item.to_string().into_owned().into_bytes()) + .await .unwrap_or(false) { return true.into(); @@ -76,9 +76,9 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable { } false } - v if !v.is_empty() => ctx - .handle - .block_on(store.key_exists(v.to_string().into_owned().into_bytes())) + v if !v.is_empty() => store + .key_exists(v.to_string().into_owned().into_bytes()) + .await .unwrap_or(false), _ => false, } @@ -95,19 +95,16 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable { .into() } -pub fn exec_get(ctx: PluginContext<'_>) -> Variable { +pub async fn exec_get(ctx: PluginContext<'_>) -> Variable { let store = match &ctx.arguments[0] { Variable::String(v) if !v.is_empty() => ctx.core.storage.lookups.get(v.as_ref()), _ => Some(&ctx.core.storage.lookup), }; if let Some(store) = store { - ctx.handle - .block_on( - store.key_get::<VariableWrapper>( - ctx.arguments[1].to_string().into_owned().into_bytes(), - ), - ) + store + .key_get::<VariableWrapper>(ctx.arguments[1].to_string().into_owned().into_bytes()) + .await .unwrap_or_default() .map(|v| v.into_inner()) .unwrap_or_default() @@ -123,7 +120,7 @@ pub fn exec_get(ctx: PluginContext<'_>) -> Variable { } } -pub fn exec_set(ctx: PluginContext<'_>) -> Variable { +pub async fn exec_set(ctx: PluginContext<'_>) -> Variable { let store = match &ctx.arguments[0] { Variable::String(v) if !v.is_empty() => ctx.core.storage.lookups.get(v.as_ref()), _ => Some(&ctx.core.storage.lookup), @@ -136,8 +133,8 @@ pub fn exec_set(ctx: PluginContext<'_>) -> Variable { _ => None, }; - ctx.handle - .block_on(store.key_set( + store + .key_set( ctx.arguments[1].to_string().into_owned().into_bytes(), if !ctx.arguments[2].is_empty() { bincode::serialize(&ctx.arguments[2]).unwrap_or_default() @@ -145,7 +142,8 @@ pub fn exec_set(ctx: PluginContext<'_>) -> Variable { vec![] }, expires, - )) + ) + .await .is_ok() .into() } else { @@ -160,7 +158,7 @@ pub fn exec_set(ctx: PluginContext<'_>) -> Variable { } } -pub fn exec_remote(ctx: PluginContext<'_>) -> Variable { +pub async fn exec_remote(ctx: PluginContext<'_>) -> Variable { let resource = ctx.arguments[0].to_string(); let item = ctx.arguments[1].to_string(); @@ -228,126 +226,129 @@ pub fn exec_remote(ctx: PluginContext<'_>) -> Variable { } } - // Lock remote list for writing - let mut _lock = ctx.core.sieve.remote_lists.write(); - let list = _lock - .entry(resource.to_string()) - .or_insert_with(|| RemoteList { - entries: HashSet::new(), - expires: Instant::now(), - }); - - // Make sure that the list is still expired - if list.expires > Instant::now() { - return list.entries.contains(item.as_ref()).into(); - } - - let _enter = ctx.handle.enter(); - match ctx - .handle - .block_on( - reqwest::Client::builder() - .timeout(TIMEOUT) - .user_agent(USER_AGENT) - .build() - .unwrap_or_default() - .get(resource.as_ref()) - .send(), - ) - .and_then(|r| { - if r.status().is_success() { - ctx.handle.block_on(r.bytes()).map(Ok) - } else { - Ok(Err(r)) - } - }) { - Ok(Ok(bytes)) => { - let reader: Box<dyn std::io::Read> = if resource.ends_with(".gz") { - Box::new(flate2::read::GzDecoder::new(&bytes[..])) - } else { - Box::new(&bytes[..]) - }; - - for (pos, line) in BufReader::new(reader).lines().enumerate() { - match line { - Ok(line_) => { - // Clear list once the first entry has been successfully fetched, decompressed and UTF8-decoded - if pos == 0 { - list.entries.clear(); - } + match reqwest::Client::builder() + .timeout(TIMEOUT) + .user_agent(USER_AGENT) + .build() + .unwrap_or_default() + .get(resource.as_ref()) + .send() + .await + { + Ok(response) if response.status().is_success() => { + match response.bytes().await { + Ok(bytes) => { + let reader: Box<dyn std::io::Read> = if resource.ends_with(".gz") { + Box::new(flate2::read::GzDecoder::new(&bytes[..])) + } else { + Box::new(&bytes[..]) + }; + + // Lock remote list for writing + let mut _lock = ctx.core.sieve.remote_lists.write(); + let list = _lock + .entry(resource.to_string()) + .or_insert_with(|| RemoteList { + entries: HashSet::new(), + expires: Instant::now(), + }); + + // Make sure that the list is still expired + if list.expires > Instant::now() { + return list.entries.contains(item.as_ref()).into(); + } - match &format { - Format::List => { - let line = line_.trim(); - if !line.is_empty() { - list.entries.insert(line.to_string()); + for (pos, line) in BufReader::new(reader).lines().enumerate() { + match line { + Ok(line_) => { + // Clear list once the first entry has been successfully fetched, decompressed and UTF8-decoded + if pos == 0 { + list.entries.clear(); } - } - Format::Csv { - column, - separator, - skip_first, - } if pos > 0 || !*skip_first => { - let mut in_quote = false; - let mut col_num = 0; - let mut entry = String::new(); - - for ch in line_.chars() { - if ch != '"' { - if ch == *separator && !in_quote { - if col_num == *column { - break; + + match &format { + Format::List => { + let line = line_.trim(); + if !line.is_empty() { + list.entries.insert(line.to_string()); + } + } + Format::Csv { + column, + separator, + skip_first, + } if pos > 0 || !*skip_first => { + let mut in_quote = false; + let mut col_num = 0; + let mut entry = String::new(); + + for ch in line_.chars() { + if ch != '"' { + if ch == *separator && !in_quote { + if col_num == *column { + break; + } else { + col_num += 1; + } + } else if col_num == *column { + entry.push(ch); + if entry.len() > MAX_ENTRY_SIZE { + break; + } + } } else { - col_num += 1; - } - } else if col_num == *column { - entry.push(ch); - if entry.len() > MAX_ENTRY_SIZE { - break; + in_quote = !in_quote; } } - } else { - in_quote = !in_quote; - } - } - if !entry.is_empty() { - list.entries.insert(entry); + if !entry.is_empty() { + list.entries.insert(entry); + } + } + _ => (), } } - _ => (), + Err(err) => { + tracing::warn!( + parent: ctx.span, + context = "sieve:key_exists_http", + event = "failed", + resource = resource.as_ref(), + reason = %err, + ); + break; + } + } + + if list.entries.len() == MAX_ENTRIES { + break; } } - Err(err) => { - tracing::warn!( - parent: ctx.span, - context = "sieve:key_exists_http", - event = "failed", - resource = resource.as_ref(), - reason = %err, - ); - break; - } - } - if list.entries.len() == MAX_ENTRIES { - break; + tracing::debug!( + parent: ctx.span, + context = "sieve:key_exists_http", + event = "fetch", + resource = resource.as_ref(), + num_entries = list.entries.len(), + ); + + // Update expiration + list.expires = Instant::now() + expires; + return list.entries.contains(item.as_ref()).into(); + } + Err(err) => { + tracing::warn!( + parent: ctx.span, + context = "sieve:key_exists_http", + event = "failed", + resource = resource.as_ref(), + reason = %err, + ); } } - - tracing::debug!( - parent: ctx.span, - context = "sieve:key_exists_http", - event = "fetch", - resource = resource.as_ref(), - num_entries = list.entries.len(), - ); - - // Update expiration - list.expires = Instant::now() + expires; - return list.entries.contains(item.as_ref()).into(); } - Ok(Err(response)) => { + Ok(response) => { tracing::warn!( parent: ctx.span, context = "sieve:key_exists_http", @@ -368,11 +369,22 @@ pub fn exec_remote(ctx: PluginContext<'_>) -> Variable { } // Something went wrong, try again in one hour - list.expires = Instant::now() + RETRY; - false.into() + let mut _lock = ctx.core.sieve.remote_lists.write(); + let list = _lock + .entry(resource.to_string()) + .or_insert_with(|| RemoteList { + entries: HashSet::new(), + expires: Instant::now(), + }); + if list.expires > Instant::now() { + list.entries.contains(item.as_ref()).into() + } else { + list.expires = Instant::now() + RETRY; + false.into() + } } -pub fn exec_local_domain(ctx: PluginContext<'_>) -> Variable { +pub async fn exec_local_domain(ctx: PluginContext<'_>) -> Variable { let domain = ctx.arguments[0].to_string(); if !domain.is_empty() { @@ -382,9 +394,9 @@ pub fn exec_local_domain(ctx: PluginContext<'_>) -> Variable { }; if let Some(directory) = directory { - return ctx - .handle - .block_on(directory.is_local_domain(domain.as_ref())) + return directory + .is_local_domain(domain.as_ref()) + .await .unwrap_or_default() .into(); } else { diff --git a/crates/common/src/scripts/plugins/mod.rs b/crates/common/src/scripts/plugins/mod.rs index a3ad8225..aab661a9 100644 --- a/crates/common/src/scripts/plugins/mod.rs +++ b/crates/common/src/scripts/plugins/mod.rs @@ -33,44 +33,21 @@ pub mod text; use mail_parser::Message; use sieve::{runtime::Variable, FunctionMap, Input}; -use tokio::runtime::Handle; use crate::Core; use super::ScriptModification; type RegisterPluginFnc = fn(u32, &mut FunctionMap) -> (); -type ExecPluginFnc = fn(PluginContext<'_>) -> Variable; pub struct PluginContext<'x> { pub span: &'x tracing::Span, - pub handle: &'x Handle, pub core: &'x Core, pub message: &'x Message<'x>, pub modifications: &'x mut Vec<ScriptModification>, pub arguments: Vec<Variable>, } -const PLUGINS_EXEC: [ExecPluginFnc; 18] = [ - query::exec, - exec::exec, - lookup::exec, - lookup::exec_get, - lookup::exec_set, - lookup::exec_remote, - lookup::exec_local_domain, - dns::exec, - dns::exec_exists, - http::exec_header, - bayes::exec_train, - bayes::exec_untrain, - bayes::exec_classify, - bayes::exec_is_balanced, - pyzor::exec, - headers::exec, - text::exec_tokenize, - text::exec_domain_part, -]; const PLUGINS_REGISTER: [RegisterPluginFnc; 18] = [ query::register, exec::register, @@ -100,7 +77,7 @@ impl RegisterSievePlugins for FunctionMap { fn register_plugins(mut self) -> Self { #[cfg(feature = "test_mode")] { - self.set_external_function("print", PLUGINS_EXEC.len() as u32, 1) + self.set_external_function("print", PLUGINS_REGISTER.len() as u32, 1) } for (i, fnc) in PLUGINS_REGISTER.iter().enumerate() { @@ -111,17 +88,34 @@ impl RegisterSievePlugins for FunctionMap { } impl Core { - pub fn run_plugin_blocking(&self, id: u32, ctx: PluginContext<'_>) -> Input { + pub async fn run_plugin(&self, id: u32, ctx: PluginContext<'_>) -> Input { #[cfg(feature = "test_mode")] - if id == PLUGINS_EXEC.len() as u32 { + if id == PLUGINS_REGISTER.len() as u32 { return test_print(ctx); } - PLUGINS_EXEC - .get(id as usize) - .map(|fnc| fnc(ctx)) - .unwrap_or_default() - .into() + match id { + 0 => query::exec(ctx).await, + 1 => exec::exec(ctx).await, + 2 => lookup::exec(ctx).await, + 3 => lookup::exec_get(ctx).await, + 4 => lookup::exec_set(ctx).await, + 5 => lookup::exec_remote(ctx).await, + 6 => lookup::exec_local_domain(ctx).await, + 7 => dns::exec(ctx).await, + 8 => dns::exec_exists(ctx).await, + 9 => http::exec_header(ctx).await, + 10 => bayes::exec_train(ctx).await, + 11 => bayes::exec_untrain(ctx).await, + 12 => bayes::exec_classify(ctx).await, + 13 => bayes::exec_is_balanced(ctx).await, + 14 => pyzor::exec(ctx).await, + 15 => headers::exec(ctx), + 16 => text::exec_tokenize(ctx), + 17 => text::exec_domain_part(ctx), + _ => unreachable!(), + } + .into() } } diff --git a/crates/common/src/scripts/plugins/pyzor.rs b/crates/common/src/scripts/plugins/pyzor.rs index c455317d..4a756060 100644 --- a/crates/common/src/scripts/plugins/pyzor.rs +++ b/crates/common/src/scripts/plugins/pyzor.rs @@ -52,7 +52,7 @@ pub fn register(plugin_id: u32, fnc_map: &mut FunctionMap) { fnc_map.set_external_function("pyzor_check", plugin_id, 2); } -pub fn exec(ctx: PluginContext<'_>) -> Variable { +pub async fn exec(ctx: PluginContext<'_>) -> Variable { // Make sure there is at least one text part if !ctx .message @@ -101,10 +101,7 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable { 5, )); // Send message to address - match ctx - .handle - .block_on(pyzor_send_message(address.as_ref(), timeout, &request)) - { + match pyzor_send_message(address.as_ref(), timeout, &request).await { Ok(response) => response.into(), Err(err) => { tracing::debug!( diff --git a/crates/common/src/scripts/plugins/query.rs b/crates/common/src/scripts/plugins/query.rs index 6102c9df..b3522917 100644 --- a/crates/common/src/scripts/plugins/query.rs +++ b/crates/common/src/scripts/plugins/query.rs @@ -33,7 +33,7 @@ pub fn register(plugin_id: u32, fnc_map: &mut FunctionMap) { fnc_map.set_external_function("query", plugin_id, 3); } -pub fn exec(ctx: PluginContext<'_>) -> Variable { +pub async fn exec(ctx: PluginContext<'_>) -> Variable { let span = ctx.span; // Obtain store name @@ -79,7 +79,7 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable { .get(..6) .map_or(false, |q| q.eq_ignore_ascii_case(b"SELECT")) { - if let Ok(mut rows) = ctx.handle.block_on(store.query::<Rows>(&query, arguments)) { + if let Ok(mut rows) = store.query::<Rows>(&query, arguments).await { match rows.rows.len().cmp(&1) { Ordering::Equal => { let mut row = rows.rows.pop().unwrap().values; @@ -116,9 +116,6 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable { false.into() } } else { - ctx.handle - .block_on(store.query::<usize>(&query, arguments)) - .is_ok() - .into() + store.query::<usize>(&query, arguments).await.is_ok().into() } } diff --git a/crates/smtp/src/core/mod.rs b/crates/smtp/src/core/mod.rs index 76d0e6ec..99afd5e0 100644 --- a/crates/smtp/src/core/mod.rs +++ b/crates/smtp/src/core/mod.rs @@ -60,7 +60,6 @@ use self::throttle::{ThrottleKey, ThrottleKeyHasherBuilder}; pub mod params; pub mod throttle; -pub mod worker; #[derive(Clone)] pub struct SmtpInstance { @@ -95,7 +94,6 @@ pub struct SMTP { } pub struct Inner { - pub worker_pool: rayon::ThreadPool, pub session_throttle: DashMap<ThrottleKey, ConcurrencyLimiter, ThrottleKeyHasherBuilder>, pub queue_throttle: DashMap<ThrottleKey, ConcurrencyLimiter, ThrottleKeyHasherBuilder>, pub queue_tx: mpsc::Sender<queue::Event>, @@ -431,10 +429,6 @@ impl SessionAddress { impl Default for Inner { fn default() -> Self { Self { - worker_pool: rayon::ThreadPoolBuilder::new() - .num_threads(num_cpus::get()) - .build() - .unwrap(), session_throttle: Default::default(), queue_throttle: Default::default(), queue_tx: mpsc::channel(1).0, diff --git a/crates/smtp/src/core/throttle.rs b/crates/smtp/src/core/throttle.rs index 6da347c2..cf9059ab 100644 --- a/crates/smtp/src/core/throttle.rs +++ b/crates/smtp/src/core/throttle.rs @@ -29,9 +29,12 @@ use common::{ use dashmap::mapref::entry::Entry; use utils::config::Rate; -use std::hash::{BuildHasher, Hash, Hasher}; +use std::{ + hash::{BuildHasher, Hash, Hasher}, + sync::atomic::Ordering, +}; -use super::Session; +use super::{Session, SMTP}; #[derive(Debug, Clone, Eq)] pub struct ThrottleKey { @@ -318,3 +321,11 @@ impl<T: SessionStream> Session<T> { .is_none() } } + +impl SMTP { + pub fn cleanup(&self) { + for throttle in [&self.inner.session_throttle, &self.inner.queue_throttle] { + throttle.retain(|_, v| v.concurrent.load(Ordering::Relaxed) > 0); + } + } +} diff --git a/crates/smtp/src/core/worker.rs b/crates/smtp/src/core/worker.rs deleted file mode 100644 index a862e221..00000000 --- a/crates/smtp/src/core/worker.rs +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright (c) 2023 Stalwart Labs Ltd. - * - * This file is part of Stalwart Mail Server. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as - * published by the Free Software Foundation, either version 3 of - * the License, or (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * in the LICENSE file at the top-level directory of this distribution. - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see <http://www.gnu.org/licenses/>. - * - * You can be released from the requirements of the AGPLv3 license by - * purchasing a commercial license. Please contact licensing@stalw.art - * for more details. -*/ - -use std::sync::atomic::Ordering; - -use tokio::sync::oneshot; - -use super::SMTP; - -impl SMTP { - pub async fn spawn_worker<U, V>(&self, f: U) -> Option<V> - where - U: FnOnce() -> V + Send, - V: Sync + Send + 'static, - { - let (tx, rx) = oneshot::channel(); - - self.inner.worker_pool.scope(|s| { - s.spawn(|_| { - tx.send(f()).ok(); - }); - }); - - match rx.await { - Ok(result) => Some(result), - Err(err) => { - tracing::warn!( - context = "worker-pool", - event = "error", - reason = %err, - ); - None - } - } - } - - fn cleanup(&self) { - for throttle in [&self.inner.session_throttle, &self.inner.queue_throttle] { - throttle.retain(|_, v| v.concurrent.load(Ordering::Relaxed) > 0); - } - } - - pub fn spawn_cleanup(&self) { - let core = self.clone(); - self.inner.worker_pool.spawn(move || { - core.cleanup(); - }); - } -} diff --git a/crates/smtp/src/lib.rs b/crates/smtp/src/lib.rs index 2fa99e33..33476016 100644 --- a/crates/smtp/src/lib.rs +++ b/crates/smtp/src/lib.rs @@ -57,16 +57,6 @@ impl SMTP { let (queue_tx, queue_rx) = mpsc::channel(1024); let (report_tx, report_rx) = mpsc::channel(1024); let inner = Inner { - worker_pool: rayon::ThreadPoolBuilder::new() - .num_threads(std::cmp::max( - config - .property::<usize>("global.thread-pool") - .filter(|v| *v > 0) - .unwrap_or_else(num_cpus::get), - 4, - )) - .build() - .unwrap(), session_throttle: DashMap::with_capacity_and_hasher_and_shard_amount( capacity, ThrottleKeyHasherBuilder::default(), diff --git a/crates/smtp/src/reporting/analysis.rs b/crates/smtp/src/reporting/analysis.rs index e11b09cd..c96e2faa 100644 --- a/crates/smtp/src/reporting/analysis.rs +++ b/crates/smtp/src/reporting/analysis.rs @@ -41,7 +41,6 @@ use store::{ write::{now, BatchBuilder, Bincode, ReportClass, ValueClass}, Serialize, }; -use tokio::runtime::Handle; use crate::core::SMTP; @@ -74,8 +73,7 @@ pub struct IncomingReport<T> { impl SMTP { pub fn analyze_report(&self, message: Arc<Vec<u8>>) { let core = self.clone(); - let handle = Handle::current(); - self.inner.worker_pool.spawn(move || { + tokio::spawn(async move { let message = if let Some(message) = MessageParser::default().parse(message.as_ref()) { message } else { @@ -324,17 +322,14 @@ impl SMTP { } } let batch = batch.build(); - let _enter = handle.enter(); - handle.spawn(async move { - if let Err(err) = core.core.storage.data.write(batch).await { - tracing::warn!( - context = "report", - event = "error", - "Failed to write incoming report: {}", - err - ); - } - }); + if let Err(err) = core.core.storage.data.write(batch).await { + tracing::warn!( + context = "report", + event = "error", + "Failed to write incoming report: {}", + err + ); + } } return; } diff --git a/crates/smtp/src/reporting/scheduler.rs b/crates/smtp/src/reporting/scheduler.rs index eba65074..e6e526df 100644 --- a/crates/smtp/src/reporting/scheduler.rs +++ b/crates/smtp/src/reporting/scheduler.rs @@ -107,7 +107,7 @@ impl SpawnReport for mpsc::Receiver<Event> { // Cleanup expired throttles if last_cleanup.elapsed().as_secs() >= 86400 { last_cleanup = Instant::now(); - core.spawn_cleanup(); + core.cleanup(); } } } diff --git a/crates/smtp/src/scripts/event_loop.rs b/crates/smtp/src/scripts/event_loop.rs index e3e9ce8a..218b612b 100644 --- a/crates/smtp/src/scripts/event_loop.rs +++ b/crates/smtp/src/scripts/event_loop.rs @@ -33,18 +33,16 @@ use smtp_proto::{ MAIL_BY_TRACE, MAIL_RET_FULL, MAIL_RET_HDRS, RCPT_NOTIFY_DELAY, RCPT_NOTIFY_FAILURE, RCPT_NOTIFY_NEVER, RCPT_NOTIFY_SUCCESS, }; -use tokio::runtime::Handle; use crate::{core::SMTP, inbound::DkimSign, queue::DomainPart}; use super::{ScriptModification, ScriptParameters, ScriptResult}; impl SMTP { - pub fn run_script_blocking( + pub async fn run_script( &self, script: Arc<Sieve>, - params: ScriptParameters, - handle: Handle, + params: ScriptParameters<'_>, span: tracing::Span, ) -> ScriptResult { // Create filter instance @@ -92,16 +90,17 @@ impl SMTP { 'outer: for list in lists { if let Some(store) = self.core.storage.lookups.get(&list) { for value in &values { - if let Ok(true) = handle.block_on( - store.key_exists( + if let Ok(true) = store + .key_exists( if !matches!(match_as, MatchAs::Lowercase) { value.clone() } else { value.to_lowercase() } .into_bytes(), - ), - ) { + ) + .await + { input = true.into(); break 'outer; } @@ -117,17 +116,19 @@ impl SMTP { } } Event::Function { id, arguments } => { - input = self.core.run_plugin_blocking( - id, - PluginContext { - span: &span, - handle: &handle, - core: &self.core, - message: instance.message(), - modifications: &mut modifications, - arguments, - }, - ); + input = self + .core + .run_plugin( + id, + PluginContext { + span: &span, + core: &self.core, + message: instance.message(), + modifications: &mut modifications, + arguments, + }, + ) + .await; } Event::Keep { message_id, .. } => { keep_id = message_id; @@ -158,11 +159,11 @@ impl SMTP { ); match recipient { Recipient::Address(rcpt) => { - handle.block_on(message.add_recipient(rcpt, self)); + message.add_recipient(rcpt, self).await; } Recipient::Group(rcpt_list) => { for rcpt in rcpt_list { - handle.block_on(message.add_recipient(rcpt, self)); + message.add_recipient(rcpt, self).await; } } Recipient::List(list) => { @@ -296,13 +297,10 @@ impl SMTP { None }; - if handle.block_on(self.has_quota(&mut message)) { - handle.block_on(message.queue( - headers.as_deref(), - raw_message, - self, - &span, - )); + if self.has_quota(&mut message).await { + message + .queue(headers.as_deref(), raw_message, self, &span) + .await; } else { tracing::warn!( parent: &span, diff --git a/crates/smtp/src/scripts/exec.rs b/crates/smtp/src/scripts/exec.rs index 7e6d8702..ff9fa602 100644 --- a/crates/smtp/src/scripts/exec.rs +++ b/crates/smtp/src/scripts/exec.rs @@ -27,7 +27,6 @@ use common::listener::SessionStream; use mail_auth::common::resolver::ToReverseName; use sieve::{runtime::Variable, Envelope, Sieve}; use smtp_proto::*; -use tokio::runtime::Handle; use crate::{core::Session, inbound::AuthResult}; @@ -145,12 +144,6 @@ impl<T: SessionStream> Session<T> { let span = self.span.clone(); let params = params.with_envelope(&self.core.core, self).await; - let handle = Handle::current(); - self.core - .spawn_worker(move || core.run_script_blocking(script, params, handle, span)) - .await - .unwrap_or(ScriptResult::Accept { - modifications: vec![], - }) + core.run_script(script, params, span).await } } diff --git a/crates/store/src/write/key.rs b/crates/store/src/write/key.rs index f5de104c..fd78dda4 100644 --- a/crates/store/src/write/key.rs +++ b/crates/store/src/write/key.rs @@ -661,7 +661,9 @@ impl Deserialize for ReportEvent { .and_then(|domain| std::str::from_utf8(domain).ok()) .map(|s| s.to_string()) .ok_or_else(|| { - crate::Error::InternalError("Failed to deserialize report domain".into()) + crate::Error::InternalError(format!( + "Failed to deserialize report domain: {key:?}" + )) })?, }) } diff --git a/tests/src/smtp/inbound/antispam.rs b/tests/src/smtp/inbound/antispam.rs index b9fb9c7c..b49148f0 100644 --- a/tests/src/smtp/inbound/antispam.rs +++ b/tests/src/smtp/inbound/antispam.rs @@ -22,7 +22,6 @@ use smtp::{ scripts::ScriptResult, }; use store::Stores; -use tokio::runtime::Handle; use utils::config::Config; use crate::smtp::{build_smtp, session::TestSession, TempDir}; @@ -419,15 +418,10 @@ async fn antispam() { } // Run script - let handle = Handle::current(); let span = span.clone(); let core_ = core.clone(); let script = script.clone(); - match core - .spawn_worker(move || core_.run_script_blocking(script, params, handle, span)) - .await - .unwrap() - { + match core_.run_script(script, params, span).await { ScriptResult::Accept { modifications } => { if modifications.len() != expected_headers.len() { panic!( diff --git a/tests/src/smtp/inbound/scripts.rs b/tests/src/smtp/inbound/scripts.rs index 9409778c..788ed341 100644 --- a/tests/src/smtp/inbound/scripts.rs +++ b/tests/src/smtp/inbound/scripts.rs @@ -37,7 +37,6 @@ use smtp::{ scripts::ScriptResult, }; use store::Stores; -use tokio::runtime::Handle; use utils::config::Config; const CONFIG: &str = r#" @@ -182,14 +181,9 @@ async fn sieve_scripts() { .set_variable("from", "john.doe@example.org") .with_envelope(&core.core, &session) .await; - let handle = Handle::current(); let span = span.clone(); let core_ = core.clone(); - match core - .spawn_worker(move || core_.run_script_blocking(script, params, handle, span)) - .await - .unwrap() - { + match core_.run_script(script, params, span).await { ScriptResult::Accept { .. } => (), ScriptResult::Reject(message) => panic!("{}", message), err => { |