summaryrefslogtreecommitdiff
path: root/crates/common/src
diff options
context:
space:
mode:
Diffstat (limited to 'crates/common/src')
-rw-r--r--crates/common/src/auth/mod.rs9
-rw-r--r--crates/common/src/auth/oauth/introspect.rs97
-rw-r--r--crates/common/src/auth/oauth/mod.rs89
-rw-r--r--crates/common/src/auth/oauth/token.rs163
-rw-r--r--crates/common/src/auth/sasl.rs131
5 files changed, 373 insertions, 116 deletions
diff --git a/crates/common/src/auth/mod.rs b/crates/common/src/auth/mod.rs
index 5a67a0e0..7b74cff8 100644
--- a/crates/common/src/auth/mod.rs
+++ b/crates/common/src/auth/mod.rs
@@ -11,6 +11,7 @@ use directory::{
};
use jmap_proto::types::collection::Collection;
use mail_send::Credentials;
+use oauth::GrantType;
use utils::map::{bitmap::Bitmap, ttl_dashmap::TtlMap, vec_map::VecMap};
use crate::Server;
@@ -18,6 +19,7 @@ use crate::Server;
pub mod access_token;
pub mod oauth;
pub mod roles;
+pub mod sasl;
#[derive(Debug, Clone, Default)]
pub struct AccessToken {
@@ -58,8 +60,11 @@ impl Server {
// Validate credentials
match &req.credentials {
Credentials::OAuthBearer { token } => {
- match self.validate_access_token("access_token", token).await {
- Ok((account_id, _, _)) => self.get_cached_access_token(account_id).await,
+ match self
+ .validate_access_token(GrantType::AccessToken.into(), token)
+ .await
+ {
+ Ok(token_into) => self.get_cached_access_token(token_into.account_id).await,
Err(err) => Err(err),
}
}
diff --git a/crates/common/src/auth/oauth/introspect.rs b/crates/common/src/auth/oauth/introspect.rs
new file mode 100644
index 00000000..f4e931b5
--- /dev/null
+++ b/crates/common/src/auth/oauth/introspect.rs
@@ -0,0 +1,97 @@
+/*
+ * SPDX-FileCopyrightText: 2020 Stalwart Labs Ltd <hello@stalw.art>
+ *
+ * SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-SEL
+ */
+
+use serde::{Deserialize, Serialize};
+use trc::{AddContext, AuthEvent, EventType};
+
+use crate::{auth::AccessToken, Server};
+
+#[derive(Debug, Default, Clone, Eq, PartialEq, Deserialize, Serialize)]
+pub struct OAuthIntrospect {
+ #[serde(default)]
+ pub active: bool,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub scope: Option<String>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub client_id: Option<String>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub username: Option<String>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub token_type: Option<String>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub exp: Option<i64>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub iat: Option<i64>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub nbf: Option<i64>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub sub: Option<String>,
+ /*#[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub aud: Option<String>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub iss: Option<String>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub jti: Option<String>,*/
+}
+
+impl Server {
+ pub async fn introspect_access_token(
+ &self,
+ token: &str,
+ access_token: &AccessToken,
+ ) -> trc::Result<OAuthIntrospect> {
+ match self.validate_access_token(None, token).await {
+ Ok(token_info) => Ok(OAuthIntrospect {
+ active: true,
+ client_id: Some(token_info.client_id),
+ username: if access_token.primary_id() == token_info.account_id {
+ access_token.name.clone()
+ } else {
+ self.get_cached_access_token(token_info.account_id)
+ .await
+ .caused_by(trc::location!())?
+ .name
+ .clone()
+ }
+ .into(),
+ token_type: "bearer".to_string().into(),
+ exp: Some(token_info.expiry as i64),
+ iat: Some(token_info.issued_at as i64),
+ ..Default::default()
+ }),
+ Err(err)
+ if matches!(
+ err.inner,
+ EventType::Auth(AuthEvent::Error) | EventType::Auth(AuthEvent::TokenExpired)
+ ) =>
+ {
+ Ok(OAuthIntrospect::default())
+ }
+ Err(err) => Err(err),
+ }
+ }
+}
diff --git a/crates/common/src/auth/oauth/mod.rs b/crates/common/src/auth/oauth/mod.rs
index b58474e1..20078868 100644
--- a/crates/common/src/auth/oauth/mod.rs
+++ b/crates/common/src/auth/oauth/mod.rs
@@ -5,6 +5,7 @@
*/
pub mod crypto;
+pub mod introspect;
pub mod token;
pub const DEVICE_CODE_LEN: usize = 40;
@@ -14,68 +15,40 @@ pub const CLIENT_ID_MAX_LEN: usize = 20;
pub const USER_CODE_ALPHABET: &[u8] = b"ABCDEFGHJKLMNPQRSTUVWXYZ23456789"; // No 0, O, I, 1
-pub fn extract_oauth_bearer(bytes: &[u8]) -> Option<&str> {
- let mut start_pos = 0;
- let eof = bytes.len().saturating_sub(1);
-
- for (pos, ch) in bytes.iter().enumerate() {
- let is_separator = *ch == 1;
- if is_separator || pos == eof {
- if bytes
- .get(start_pos..start_pos + 12)
- .map_or(false, |s| s.eq_ignore_ascii_case(b"auth=Bearer "))
- {
- return bytes
- .get(start_pos + 12..if is_separator { pos } else { bytes.len() })
- .and_then(|s| std::str::from_utf8(s).ok());
- }
+#[derive(Debug, Clone, Copy, Eq, PartialEq)]
+pub enum GrantType {
+ AccessToken,
+ RefreshToken,
+ LiveTracing,
+ LiveMetrics,
+}
- start_pos = pos + 1;
+impl GrantType {
+ pub fn as_str(&self) -> &'static str {
+ match self {
+ GrantType::AccessToken => "access_token",
+ GrantType::RefreshToken => "refresh_token",
+ GrantType::LiveTracing => "live_tracing",
+ GrantType::LiveMetrics => "live_metrics",
}
}
- None
-}
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_extract_oauth_bearer() {
- let input = b"auth=Bearer validtoken";
- let result = extract_oauth_bearer(input);
- assert_eq!(result, Some("validtoken"));
-
- let input = b"auth=Invalid validtoken";
- let result = extract_oauth_bearer(input);
- assert_eq!(result, None);
-
- let input = b"auth=Bearer";
- let result = extract_oauth_bearer(input);
- assert_eq!(result, None);
-
- let input = b"";
- let result = extract_oauth_bearer(input);
- assert_eq!(result, None);
-
- let input = b"auth=Bearer token1\x01auth=Bearer token2";
- let result = extract_oauth_bearer(input);
- assert_eq!(result, Some("token1"));
-
- let input = b"auth=Bearer VALIDTOKEN";
- let result = extract_oauth_bearer(input);
- assert_eq!(result, Some("VALIDTOKEN"));
-
- let input = b"auth=Bearer token with spaces";
- let result = extract_oauth_bearer(input);
- assert_eq!(result, Some("token with spaces"));
-
- let input = b"auth=Bearer token_with_special_chars!@#";
- let result = extract_oauth_bearer(input);
- assert_eq!(result, Some("token_with_special_chars!@#"));
+ pub fn id(&self) -> u8 {
+ match self {
+ GrantType::AccessToken => 0,
+ GrantType::RefreshToken => 1,
+ GrantType::LiveTracing => 2,
+ GrantType::LiveMetrics => 3,
+ }
+ }
- let input = "n,a=user@example.com,\x01host=server.example.com\x01port=143\x01auth=Bearer vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg==\x01\x01";
- let result = extract_oauth_bearer(input.as_bytes());
- assert_eq!(result, Some("vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg=="));
+ pub fn from_id(id: u8) -> Option<Self> {
+ match id {
+ 0 => Some(GrantType::AccessToken),
+ 1 => Some(GrantType::RefreshToken),
+ 2 => Some(GrantType::LiveTracing),
+ 3 => Some(GrantType::LiveMetrics),
+ _ => None,
+ }
}
}
diff --git a/crates/common/src/auth/oauth/token.rs b/crates/common/src/auth/oauth/token.rs
index 08904983..ec6cb946 100644
--- a/crates/common/src/auth/oauth/token.rs
+++ b/crates/common/src/auth/oauth/token.rs
@@ -13,63 +13,71 @@ use store::{
blake3,
rand::{thread_rng, Rng},
};
+use trc::AddContext;
use utils::codec::leb128::{Leb128Iterator, Leb128Vec};
use crate::Server;
-use super::{crypto::SymmetricEncrypt, CLIENT_ID_MAX_LEN, RANDOM_CODE_LEN};
+use super::{crypto::SymmetricEncrypt, GrantType, CLIENT_ID_MAX_LEN, RANDOM_CODE_LEN};
+
+pub struct TokenInfo {
+ pub grant_type: GrantType,
+ pub account_id: u32,
+ pub client_id: String,
+ pub expiry: u64,
+ pub issued_at: u64,
+ pub expires_in: u64,
+}
+
+const OAUTH_EPOCH: u64 = 946684800; // Jan 1, 2000
impl Server {
- pub async fn issue_custom_token(
+ pub async fn encode_access_token(
&self,
+ grant_type: GrantType,
account_id: u32,
- grant_type: &str,
client_id: &str,
expiry_in: u64,
) -> trc::Result<String> {
- self.encode_access_token(
- grant_type,
- account_id,
- &self
- .password_hash(account_id)
- .await
- .map_err(|err| trc::StoreEvent::UnexpectedError.into_err().details(err))?,
- client_id,
- expiry_in,
- )
- .map_err(|err| trc::StoreEvent::UnexpectedError.into_err().details(err))
- }
-
- pub fn encode_access_token(
- &self,
- grant_type: &str,
- account_id: u32,
- password_hash: &str,
- client_id: &str,
- expiry_in: u64,
- ) -> Result<String, &'static str> {
// Build context
if client_id.len() > CLIENT_ID_MAX_LEN {
- return Err("ClientId is too long");
+ return Err(trc::AuthEvent::Error
+ .into_err()
+ .details("Client id too long"));
}
- let key = self.core.jmap.oauth_key.clone();
+
+ // Include password hash if expiration is over 1 hour
+ let password_hash = if expiry_in > 3600 {
+ self.password_hash(account_id)
+ .await
+ .caused_by(trc::location!())?
+ } else {
+ String::new()
+ };
+
+ let key = &self.core.jmap.oauth_key;
let context = format!(
"{} {} {} {}",
- grant_type, client_id, account_id, password_hash
+ grant_type.as_str(),
+ client_id,
+ account_id,
+ password_hash
);
- let context_nonce = format!("{} nonce {}", grant_type, password_hash);
// Set expiration time
- let expiry = SystemTime::now()
- .duration_since(SystemTime::UNIX_EPOCH)
- .map(|d| d.as_secs())
- .unwrap_or(0)
- .saturating_sub(946684800) // Jan 1, 2000
- + expiry_in;
+ let issued_at = SystemTime::now()
+ .duration_since(SystemTime::UNIX_EPOCH)
+ .map_or(0, |d| d.as_secs())
+ .saturating_sub(OAUTH_EPOCH); // Jan 1, 2000
+ let expiry = issued_at + expiry_in;
// Calculate nonce
let mut hasher = blake3::Hasher::new();
- hasher.update(context_nonce.as_bytes());
+ if !password_hash.is_empty() {
+ hasher.update(password_hash.as_bytes());
+ }
+ hasher.update(grant_type.as_str().as_bytes());
+ hasher.update(issued_at.to_be_bytes().as_slice());
hasher.update(expiry.to_be_bytes().as_slice());
let nonce = hasher
.finalize()
@@ -82,8 +90,15 @@ impl Server {
// Encrypt random bytes
let mut token = SymmetricEncrypt::new(key.as_bytes(), &context)
.encrypt(&thread_rng().gen::<[u8; RANDOM_CODE_LEN]>(), &nonce)
- .map_err(|_| "Failed to encrypt token.")?;
+ .map_err(|_| {
+ trc::AuthEvent::Error
+ .into_err()
+ .ctx(trc::Key::Reason, "Failed to encrypt token")
+ .caused_by(trc::location!())
+ })?;
token.push_leb128(account_id);
+ token.push(grant_type.id());
+ token.push_leb128(issued_at);
token.push_leb128(expiry);
token.extend_from_slice(client_id.as_bytes());
@@ -92,9 +107,9 @@ impl Server {
pub async fn validate_access_token(
&self,
- grant_type: &str,
+ expected_grant_type: Option<GrantType>,
token_: &str,
- ) -> trc::Result<(u32, String, u64)> {
+ ) -> trc::Result<TokenInfo> {
// Base64 decode token
let token = base64_decode(token_.as_bytes()).ok_or_else(|| {
trc::AuthEvent::Error
@@ -103,12 +118,14 @@ impl Server {
.caused_by(trc::location!())
.details(token_.to_string())
})?;
- let (account_id, expiry, client_id) = token
+ let (account_id, grant_type, issued_at, expiry, client_id) = token
.get((RANDOM_CODE_LEN + SymmetricEncrypt::ENCRYPT_TAG_LEN)..)
.and_then(|bytes| {
let mut bytes = bytes.iter();
(
bytes.next_leb128()?,
+ GrantType::from_id(bytes.next().copied()?)?,
+ bytes.next_leb128::<u64>()?,
bytes.next_leb128::<u64>()?,
bytes.copied().map(char::from).collect::<String>(),
)
@@ -125,30 +142,45 @@ impl Server {
// Validate expiration
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
- .map(|d| d.as_secs())
- .unwrap_or(0)
- .saturating_sub(946684800); // Jan 1, 2000
- if expiry <= now {
+ .map_or(0, |d| d.as_secs())
+ .saturating_sub(OAUTH_EPOCH); // Jan 1, 2000
+ if expiry <= now || issued_at > now {
return Err(trc::AuthEvent::TokenExpired.into_err());
}
+ // Validate grant type
+ if expected_grant_type.map_or(false, |g| g != grant_type) {
+ return Err(trc::AuthEvent::Error
+ .into_err()
+ .details("Invalid grant type"));
+ }
+
// Obtain password hash
- let password_hash = self
- .password_hash(account_id)
- .await
- .map_err(|err| trc::AuthEvent::Error.into_err().ctx(trc::Key::Details, err))?;
+ let password_hash = if expiry - issued_at > 3600 {
+ self.password_hash(account_id)
+ .await
+ .map_err(|err| trc::AuthEvent::Error.into_err().ctx(trc::Key::Details, err))?
+ } else {
+ String::new()
+ };
// Build context
let key = self.core.jmap.oauth_key.clone();
let context = format!(
"{} {} {} {}",
- grant_type, client_id, account_id, password_hash
+ grant_type.as_str(),
+ client_id,
+ account_id,
+ password_hash
);
- let context_nonce = format!("{} nonce {}", grant_type, password_hash);
// Calculate nonce
let mut hasher = blake3::Hasher::new();
- hasher.update(context_nonce.as_bytes());
+ if !password_hash.is_empty() {
+ hasher.update(password_hash.as_bytes());
+ }
+ hasher.update(grant_type.as_str().as_bytes());
+ hasher.update(issued_at.to_be_bytes().as_slice());
hasher.update(expiry.to_be_bytes().as_slice());
let nonce = hasher
.finalize()
@@ -173,27 +205,46 @@ impl Server {
})?;
// Success
- Ok((account_id, client_id, expiry - now))
+ Ok(TokenInfo {
+ grant_type,
+ account_id,
+ client_id,
+ expiry: expiry + OAUTH_EPOCH,
+ issued_at: issued_at + OAUTH_EPOCH,
+ expires_in: expiry - now,
+ })
}
- pub async fn password_hash(&self, account_id: u32) -> Result<String, &'static str> {
+ pub async fn password_hash(&self, account_id: u32) -> trc::Result<String> {
if account_id != u32::MAX {
self.core
.storage
.directory
.query(QueryBy::Id(account_id), false)
.await
- .map_err(|_| "Temporary lookup error")?
- .ok_or("Account no longer exists")?
+ .caused_by(trc::location!())?
+ .ok_or_else(|| {
+ trc::AuthEvent::Error
+ .into_err()
+ .details("Account no longer exists")
+ })?
.take_str_array(PrincipalField::Secrets)
.unwrap_or_default()
.into_iter()
.next()
- .ok_or("Failed to obtain password hash")
+ .ok_or(
+ trc::AuthEvent::Error
+ .into_err()
+ .details("Account does not contain secrets")
+ .caused_by(trc::location!()),
+ )
} else if let Some((_, secret)) = &self.core.jmap.fallback_admin {
Ok(secret.clone())
} else {
- Err("Invalid account id.")
+ Err(trc::AuthEvent::Error
+ .into_err()
+ .details("Invalid account ID")
+ .caused_by(trc::location!()))
}
}
}
diff --git a/crates/common/src/auth/sasl.rs b/crates/common/src/auth/sasl.rs
new file mode 100644
index 00000000..92ba84b6
--- /dev/null
+++ b/crates/common/src/auth/sasl.rs
@@ -0,0 +1,131 @@
+/*
+ * SPDX-FileCopyrightText: 2020 Stalwart Labs Ltd <hello@stalw.art>
+ *
+ * SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-SEL
+ */
+
+use mail_send::Credentials;
+
+pub fn sasl_decode_challenge_plain(challenge: &[u8]) -> Option<Credentials<String>> {
+ let mut username = Vec::new();
+ let mut secret = Vec::new();
+ let mut arg_num = 0;
+ for &ch in challenge {
+ if ch != 0 {
+ if arg_num == 1 {
+ username.push(ch);
+ } else if arg_num == 2 {
+ secret.push(ch);
+ }
+ } else {
+ arg_num += 1;
+ }
+ }
+
+ match (String::from_utf8(username), String::from_utf8(secret)) {
+ (Ok(username), Ok(secret)) if !username.is_empty() && !secret.is_empty() => {
+ Some((username, secret).into())
+ }
+ _ => None,
+ }
+}
+
+pub fn sasl_decode_challenge_xoauth(challenge: &[u8]) -> Option<Credentials<String>> {
+ let mut b_username = Vec::new();
+ let mut b_secret = Vec::new();
+ let mut arg_num = 0;
+ let mut in_arg = false;
+
+ for &ch in challenge {
+ if in_arg {
+ if ch != 1 {
+ if arg_num == 1 {
+ b_username.push(ch);
+ } else if arg_num == 2 {
+ b_secret.push(ch);
+ }
+ } else {
+ in_arg = false;
+ }
+ } else if ch == b'=' {
+ arg_num += 1;
+ in_arg = true;
+ }
+ }
+ match (String::from_utf8(b_username), String::from_utf8(b_secret)) {
+ (Ok(s_username), Ok(s_secret)) if !s_username.is_empty() => {
+ Some((s_username, s_secret).into())
+ }
+ _ => None,
+ }
+}
+
+pub fn sasl_decode_challenge_oauth(challenge: &[u8]) -> Option<Credentials<String>> {
+ extract_oauth_bearer(challenge).map(|s| Credentials::OAuthBearer { token: s.into() })
+}
+
+fn extract_oauth_bearer(bytes: &[u8]) -> Option<&str> {
+ let mut start_pos = 0;
+ let eof = bytes.len().saturating_sub(1);
+
+ for (pos, ch) in bytes.iter().enumerate() {
+ let is_separator = *ch == 1;
+ if is_separator || pos == eof {
+ if bytes
+ .get(start_pos..start_pos + 12)
+ .map_or(false, |s| s.eq_ignore_ascii_case(b"auth=Bearer "))
+ {
+ return bytes
+ .get(start_pos + 12..if is_separator { pos } else { bytes.len() })
+ .and_then(|s| std::str::from_utf8(s).ok());
+ }
+
+ start_pos = pos + 1;
+ }
+ }
+
+ None
+}
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_extract_oauth_bearer() {
+ let input = b"auth=Bearer validtoken";
+ let result = extract_oauth_bearer(input);
+ assert_eq!(result, Some("validtoken"));
+
+ let input = b"auth=Invalid validtoken";
+ let result = extract_oauth_bearer(input);
+ assert_eq!(result, None);
+
+ let input = b"auth=Bearer";
+ let result = extract_oauth_bearer(input);
+ assert_eq!(result, None);
+
+ let input = b"";
+ let result = extract_oauth_bearer(input);
+ assert_eq!(result, None);
+
+ let input = b"auth=Bearer token1\x01auth=Bearer token2";
+ let result = extract_oauth_bearer(input);
+ assert_eq!(result, Some("token1"));
+
+ let input = b"auth=Bearer VALIDTOKEN";
+ let result = extract_oauth_bearer(input);
+ assert_eq!(result, Some("VALIDTOKEN"));
+
+ let input = b"auth=Bearer token with spaces";
+ let result = extract_oauth_bearer(input);
+ assert_eq!(result, Some("token with spaces"));
+
+ let input = b"auth=Bearer token_with_special_chars!@#";
+ let result = extract_oauth_bearer(input);
+ assert_eq!(result, Some("token_with_special_chars!@#"));
+
+ let input = "n,a=user@example.com,\x01host=server.example.com\x01port=143\x01auth=Bearer vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg==\x01\x01";
+ let result = extract_oauth_bearer(input.as_bytes());
+ assert_eq!(result, Some("vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg=="));
+ }
+}