diff options
Diffstat (limited to 'crates/common/src')
-rw-r--r-- | crates/common/src/auth/mod.rs | 9 | ||||
-rw-r--r-- | crates/common/src/auth/oauth/introspect.rs | 97 | ||||
-rw-r--r-- | crates/common/src/auth/oauth/mod.rs | 89 | ||||
-rw-r--r-- | crates/common/src/auth/oauth/token.rs | 163 | ||||
-rw-r--r-- | crates/common/src/auth/sasl.rs | 131 |
5 files changed, 373 insertions, 116 deletions
diff --git a/crates/common/src/auth/mod.rs b/crates/common/src/auth/mod.rs index 5a67a0e0..7b74cff8 100644 --- a/crates/common/src/auth/mod.rs +++ b/crates/common/src/auth/mod.rs @@ -11,6 +11,7 @@ use directory::{ }; use jmap_proto::types::collection::Collection; use mail_send::Credentials; +use oauth::GrantType; use utils::map::{bitmap::Bitmap, ttl_dashmap::TtlMap, vec_map::VecMap}; use crate::Server; @@ -18,6 +19,7 @@ use crate::Server; pub mod access_token; pub mod oauth; pub mod roles; +pub mod sasl; #[derive(Debug, Clone, Default)] pub struct AccessToken { @@ -58,8 +60,11 @@ impl Server { // Validate credentials match &req.credentials { Credentials::OAuthBearer { token } => { - match self.validate_access_token("access_token", token).await { - Ok((account_id, _, _)) => self.get_cached_access_token(account_id).await, + match self + .validate_access_token(GrantType::AccessToken.into(), token) + .await + { + Ok(token_into) => self.get_cached_access_token(token_into.account_id).await, Err(err) => Err(err), } } diff --git a/crates/common/src/auth/oauth/introspect.rs b/crates/common/src/auth/oauth/introspect.rs new file mode 100644 index 00000000..f4e931b5 --- /dev/null +++ b/crates/common/src/auth/oauth/introspect.rs @@ -0,0 +1,97 @@ +/* + * SPDX-FileCopyrightText: 2020 Stalwart Labs Ltd <hello@stalw.art> + * + * SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-SEL + */ + +use serde::{Deserialize, Serialize}; +use trc::{AddContext, AuthEvent, EventType}; + +use crate::{auth::AccessToken, Server}; + +#[derive(Debug, Default, Clone, Eq, PartialEq, Deserialize, Serialize)] +pub struct OAuthIntrospect { + #[serde(default)] + pub active: bool, + + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub scope: Option<String>, + + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub client_id: Option<String>, + + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub username: Option<String>, + + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub token_type: Option<String>, + + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub exp: Option<i64>, + + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub iat: Option<i64>, + + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub nbf: Option<i64>, + + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub sub: Option<String>, + /*#[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub aud: Option<String>, + + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub iss: Option<String>, + + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub jti: Option<String>,*/ +} + +impl Server { + pub async fn introspect_access_token( + &self, + token: &str, + access_token: &AccessToken, + ) -> trc::Result<OAuthIntrospect> { + match self.validate_access_token(None, token).await { + Ok(token_info) => Ok(OAuthIntrospect { + active: true, + client_id: Some(token_info.client_id), + username: if access_token.primary_id() == token_info.account_id { + access_token.name.clone() + } else { + self.get_cached_access_token(token_info.account_id) + .await + .caused_by(trc::location!())? + .name + .clone() + } + .into(), + token_type: "bearer".to_string().into(), + exp: Some(token_info.expiry as i64), + iat: Some(token_info.issued_at as i64), + ..Default::default() + }), + Err(err) + if matches!( + err.inner, + EventType::Auth(AuthEvent::Error) | EventType::Auth(AuthEvent::TokenExpired) + ) => + { + Ok(OAuthIntrospect::default()) + } + Err(err) => Err(err), + } + } +} diff --git a/crates/common/src/auth/oauth/mod.rs b/crates/common/src/auth/oauth/mod.rs index b58474e1..20078868 100644 --- a/crates/common/src/auth/oauth/mod.rs +++ b/crates/common/src/auth/oauth/mod.rs @@ -5,6 +5,7 @@ */ pub mod crypto; +pub mod introspect; pub mod token; pub const DEVICE_CODE_LEN: usize = 40; @@ -14,68 +15,40 @@ pub const CLIENT_ID_MAX_LEN: usize = 20; pub const USER_CODE_ALPHABET: &[u8] = b"ABCDEFGHJKLMNPQRSTUVWXYZ23456789"; // No 0, O, I, 1 -pub fn extract_oauth_bearer(bytes: &[u8]) -> Option<&str> { - let mut start_pos = 0; - let eof = bytes.len().saturating_sub(1); - - for (pos, ch) in bytes.iter().enumerate() { - let is_separator = *ch == 1; - if is_separator || pos == eof { - if bytes - .get(start_pos..start_pos + 12) - .map_or(false, |s| s.eq_ignore_ascii_case(b"auth=Bearer ")) - { - return bytes - .get(start_pos + 12..if is_separator { pos } else { bytes.len() }) - .and_then(|s| std::str::from_utf8(s).ok()); - } +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum GrantType { + AccessToken, + RefreshToken, + LiveTracing, + LiveMetrics, +} - start_pos = pos + 1; +impl GrantType { + pub fn as_str(&self) -> &'static str { + match self { + GrantType::AccessToken => "access_token", + GrantType::RefreshToken => "refresh_token", + GrantType::LiveTracing => "live_tracing", + GrantType::LiveMetrics => "live_metrics", } } - None -} -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_extract_oauth_bearer() { - let input = b"auth=Bearer validtoken"; - let result = extract_oauth_bearer(input); - assert_eq!(result, Some("validtoken")); - - let input = b"auth=Invalid validtoken"; - let result = extract_oauth_bearer(input); - assert_eq!(result, None); - - let input = b"auth=Bearer"; - let result = extract_oauth_bearer(input); - assert_eq!(result, None); - - let input = b""; - let result = extract_oauth_bearer(input); - assert_eq!(result, None); - - let input = b"auth=Bearer token1\x01auth=Bearer token2"; - let result = extract_oauth_bearer(input); - assert_eq!(result, Some("token1")); - - let input = b"auth=Bearer VALIDTOKEN"; - let result = extract_oauth_bearer(input); - assert_eq!(result, Some("VALIDTOKEN")); - - let input = b"auth=Bearer token with spaces"; - let result = extract_oauth_bearer(input); - assert_eq!(result, Some("token with spaces")); - - let input = b"auth=Bearer token_with_special_chars!@#"; - let result = extract_oauth_bearer(input); - assert_eq!(result, Some("token_with_special_chars!@#")); + pub fn id(&self) -> u8 { + match self { + GrantType::AccessToken => 0, + GrantType::RefreshToken => 1, + GrantType::LiveTracing => 2, + GrantType::LiveMetrics => 3, + } + } - let input = "n,a=user@example.com,\x01host=server.example.com\x01port=143\x01auth=Bearer vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg==\x01\x01"; - let result = extract_oauth_bearer(input.as_bytes()); - assert_eq!(result, Some("vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg==")); + pub fn from_id(id: u8) -> Option<Self> { + match id { + 0 => Some(GrantType::AccessToken), + 1 => Some(GrantType::RefreshToken), + 2 => Some(GrantType::LiveTracing), + 3 => Some(GrantType::LiveMetrics), + _ => None, + } } } diff --git a/crates/common/src/auth/oauth/token.rs b/crates/common/src/auth/oauth/token.rs index 08904983..ec6cb946 100644 --- a/crates/common/src/auth/oauth/token.rs +++ b/crates/common/src/auth/oauth/token.rs @@ -13,63 +13,71 @@ use store::{ blake3, rand::{thread_rng, Rng}, }; +use trc::AddContext; use utils::codec::leb128::{Leb128Iterator, Leb128Vec}; use crate::Server; -use super::{crypto::SymmetricEncrypt, CLIENT_ID_MAX_LEN, RANDOM_CODE_LEN}; +use super::{crypto::SymmetricEncrypt, GrantType, CLIENT_ID_MAX_LEN, RANDOM_CODE_LEN}; + +pub struct TokenInfo { + pub grant_type: GrantType, + pub account_id: u32, + pub client_id: String, + pub expiry: u64, + pub issued_at: u64, + pub expires_in: u64, +} + +const OAUTH_EPOCH: u64 = 946684800; // Jan 1, 2000 impl Server { - pub async fn issue_custom_token( + pub async fn encode_access_token( &self, + grant_type: GrantType, account_id: u32, - grant_type: &str, client_id: &str, expiry_in: u64, ) -> trc::Result<String> { - self.encode_access_token( - grant_type, - account_id, - &self - .password_hash(account_id) - .await - .map_err(|err| trc::StoreEvent::UnexpectedError.into_err().details(err))?, - client_id, - expiry_in, - ) - .map_err(|err| trc::StoreEvent::UnexpectedError.into_err().details(err)) - } - - pub fn encode_access_token( - &self, - grant_type: &str, - account_id: u32, - password_hash: &str, - client_id: &str, - expiry_in: u64, - ) -> Result<String, &'static str> { // Build context if client_id.len() > CLIENT_ID_MAX_LEN { - return Err("ClientId is too long"); + return Err(trc::AuthEvent::Error + .into_err() + .details("Client id too long")); } - let key = self.core.jmap.oauth_key.clone(); + + // Include password hash if expiration is over 1 hour + let password_hash = if expiry_in > 3600 { + self.password_hash(account_id) + .await + .caused_by(trc::location!())? + } else { + String::new() + }; + + let key = &self.core.jmap.oauth_key; let context = format!( "{} {} {} {}", - grant_type, client_id, account_id, password_hash + grant_type.as_str(), + client_id, + account_id, + password_hash ); - let context_nonce = format!("{} nonce {}", grant_type, password_hash); // Set expiration time - let expiry = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .map(|d| d.as_secs()) - .unwrap_or(0) - .saturating_sub(946684800) // Jan 1, 2000 - + expiry_in; + let issued_at = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map_or(0, |d| d.as_secs()) + .saturating_sub(OAUTH_EPOCH); // Jan 1, 2000 + let expiry = issued_at + expiry_in; // Calculate nonce let mut hasher = blake3::Hasher::new(); - hasher.update(context_nonce.as_bytes()); + if !password_hash.is_empty() { + hasher.update(password_hash.as_bytes()); + } + hasher.update(grant_type.as_str().as_bytes()); + hasher.update(issued_at.to_be_bytes().as_slice()); hasher.update(expiry.to_be_bytes().as_slice()); let nonce = hasher .finalize() @@ -82,8 +90,15 @@ impl Server { // Encrypt random bytes let mut token = SymmetricEncrypt::new(key.as_bytes(), &context) .encrypt(&thread_rng().gen::<[u8; RANDOM_CODE_LEN]>(), &nonce) - .map_err(|_| "Failed to encrypt token.")?; + .map_err(|_| { + trc::AuthEvent::Error + .into_err() + .ctx(trc::Key::Reason, "Failed to encrypt token") + .caused_by(trc::location!()) + })?; token.push_leb128(account_id); + token.push(grant_type.id()); + token.push_leb128(issued_at); token.push_leb128(expiry); token.extend_from_slice(client_id.as_bytes()); @@ -92,9 +107,9 @@ impl Server { pub async fn validate_access_token( &self, - grant_type: &str, + expected_grant_type: Option<GrantType>, token_: &str, - ) -> trc::Result<(u32, String, u64)> { + ) -> trc::Result<TokenInfo> { // Base64 decode token let token = base64_decode(token_.as_bytes()).ok_or_else(|| { trc::AuthEvent::Error @@ -103,12 +118,14 @@ impl Server { .caused_by(trc::location!()) .details(token_.to_string()) })?; - let (account_id, expiry, client_id) = token + let (account_id, grant_type, issued_at, expiry, client_id) = token .get((RANDOM_CODE_LEN + SymmetricEncrypt::ENCRYPT_TAG_LEN)..) .and_then(|bytes| { let mut bytes = bytes.iter(); ( bytes.next_leb128()?, + GrantType::from_id(bytes.next().copied()?)?, + bytes.next_leb128::<u64>()?, bytes.next_leb128::<u64>()?, bytes.copied().map(char::from).collect::<String>(), ) @@ -125,30 +142,45 @@ impl Server { // Validate expiration let now = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) - .map(|d| d.as_secs()) - .unwrap_or(0) - .saturating_sub(946684800); // Jan 1, 2000 - if expiry <= now { + .map_or(0, |d| d.as_secs()) + .saturating_sub(OAUTH_EPOCH); // Jan 1, 2000 + if expiry <= now || issued_at > now { return Err(trc::AuthEvent::TokenExpired.into_err()); } + // Validate grant type + if expected_grant_type.map_or(false, |g| g != grant_type) { + return Err(trc::AuthEvent::Error + .into_err() + .details("Invalid grant type")); + } + // Obtain password hash - let password_hash = self - .password_hash(account_id) - .await - .map_err(|err| trc::AuthEvent::Error.into_err().ctx(trc::Key::Details, err))?; + let password_hash = if expiry - issued_at > 3600 { + self.password_hash(account_id) + .await + .map_err(|err| trc::AuthEvent::Error.into_err().ctx(trc::Key::Details, err))? + } else { + String::new() + }; // Build context let key = self.core.jmap.oauth_key.clone(); let context = format!( "{} {} {} {}", - grant_type, client_id, account_id, password_hash + grant_type.as_str(), + client_id, + account_id, + password_hash ); - let context_nonce = format!("{} nonce {}", grant_type, password_hash); // Calculate nonce let mut hasher = blake3::Hasher::new(); - hasher.update(context_nonce.as_bytes()); + if !password_hash.is_empty() { + hasher.update(password_hash.as_bytes()); + } + hasher.update(grant_type.as_str().as_bytes()); + hasher.update(issued_at.to_be_bytes().as_slice()); hasher.update(expiry.to_be_bytes().as_slice()); let nonce = hasher .finalize() @@ -173,27 +205,46 @@ impl Server { })?; // Success - Ok((account_id, client_id, expiry - now)) + Ok(TokenInfo { + grant_type, + account_id, + client_id, + expiry: expiry + OAUTH_EPOCH, + issued_at: issued_at + OAUTH_EPOCH, + expires_in: expiry - now, + }) } - pub async fn password_hash(&self, account_id: u32) -> Result<String, &'static str> { + pub async fn password_hash(&self, account_id: u32) -> trc::Result<String> { if account_id != u32::MAX { self.core .storage .directory .query(QueryBy::Id(account_id), false) .await - .map_err(|_| "Temporary lookup error")? - .ok_or("Account no longer exists")? + .caused_by(trc::location!())? + .ok_or_else(|| { + trc::AuthEvent::Error + .into_err() + .details("Account no longer exists") + })? .take_str_array(PrincipalField::Secrets) .unwrap_or_default() .into_iter() .next() - .ok_or("Failed to obtain password hash") + .ok_or( + trc::AuthEvent::Error + .into_err() + .details("Account does not contain secrets") + .caused_by(trc::location!()), + ) } else if let Some((_, secret)) = &self.core.jmap.fallback_admin { Ok(secret.clone()) } else { - Err("Invalid account id.") + Err(trc::AuthEvent::Error + .into_err() + .details("Invalid account ID") + .caused_by(trc::location!())) } } } diff --git a/crates/common/src/auth/sasl.rs b/crates/common/src/auth/sasl.rs new file mode 100644 index 00000000..92ba84b6 --- /dev/null +++ b/crates/common/src/auth/sasl.rs @@ -0,0 +1,131 @@ +/* + * SPDX-FileCopyrightText: 2020 Stalwart Labs Ltd <hello@stalw.art> + * + * SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-SEL + */ + +use mail_send::Credentials; + +pub fn sasl_decode_challenge_plain(challenge: &[u8]) -> Option<Credentials<String>> { + let mut username = Vec::new(); + let mut secret = Vec::new(); + let mut arg_num = 0; + for &ch in challenge { + if ch != 0 { + if arg_num == 1 { + username.push(ch); + } else if arg_num == 2 { + secret.push(ch); + } + } else { + arg_num += 1; + } + } + + match (String::from_utf8(username), String::from_utf8(secret)) { + (Ok(username), Ok(secret)) if !username.is_empty() && !secret.is_empty() => { + Some((username, secret).into()) + } + _ => None, + } +} + +pub fn sasl_decode_challenge_xoauth(challenge: &[u8]) -> Option<Credentials<String>> { + let mut b_username = Vec::new(); + let mut b_secret = Vec::new(); + let mut arg_num = 0; + let mut in_arg = false; + + for &ch in challenge { + if in_arg { + if ch != 1 { + if arg_num == 1 { + b_username.push(ch); + } else if arg_num == 2 { + b_secret.push(ch); + } + } else { + in_arg = false; + } + } else if ch == b'=' { + arg_num += 1; + in_arg = true; + } + } + match (String::from_utf8(b_username), String::from_utf8(b_secret)) { + (Ok(s_username), Ok(s_secret)) if !s_username.is_empty() => { + Some((s_username, s_secret).into()) + } + _ => None, + } +} + +pub fn sasl_decode_challenge_oauth(challenge: &[u8]) -> Option<Credentials<String>> { + extract_oauth_bearer(challenge).map(|s| Credentials::OAuthBearer { token: s.into() }) +} + +fn extract_oauth_bearer(bytes: &[u8]) -> Option<&str> { + let mut start_pos = 0; + let eof = bytes.len().saturating_sub(1); + + for (pos, ch) in bytes.iter().enumerate() { + let is_separator = *ch == 1; + if is_separator || pos == eof { + if bytes + .get(start_pos..start_pos + 12) + .map_or(false, |s| s.eq_ignore_ascii_case(b"auth=Bearer ")) + { + return bytes + .get(start_pos + 12..if is_separator { pos } else { bytes.len() }) + .and_then(|s| std::str::from_utf8(s).ok()); + } + + start_pos = pos + 1; + } + } + + None +} +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_oauth_bearer() { + let input = b"auth=Bearer validtoken"; + let result = extract_oauth_bearer(input); + assert_eq!(result, Some("validtoken")); + + let input = b"auth=Invalid validtoken"; + let result = extract_oauth_bearer(input); + assert_eq!(result, None); + + let input = b"auth=Bearer"; + let result = extract_oauth_bearer(input); + assert_eq!(result, None); + + let input = b""; + let result = extract_oauth_bearer(input); + assert_eq!(result, None); + + let input = b"auth=Bearer token1\x01auth=Bearer token2"; + let result = extract_oauth_bearer(input); + assert_eq!(result, Some("token1")); + + let input = b"auth=Bearer VALIDTOKEN"; + let result = extract_oauth_bearer(input); + assert_eq!(result, Some("VALIDTOKEN")); + + let input = b"auth=Bearer token with spaces"; + let result = extract_oauth_bearer(input); + assert_eq!(result, Some("token with spaces")); + + let input = b"auth=Bearer token_with_special_chars!@#"; + let result = extract_oauth_bearer(input); + assert_eq!(result, Some("token_with_special_chars!@#")); + + let input = "n,a=user@example.com,\x01host=server.example.com\x01port=143\x01auth=Bearer vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg==\x01\x01"; + let result = extract_oauth_bearer(input.as_bytes()); + assert_eq!(result, Some("vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg==")); + } +} |