summaryrefslogtreecommitdiff
path: root/crates/common/src/auth/oauth/oidc.rs
diff options
context:
space:
mode:
authormdecimus <mauro@stalw.art>2024-09-30 16:57:34 +0200
committermdecimus <mauro@stalw.art>2024-09-30 16:57:34 +0200
commit6a5f963b43851e1720f17df24d58b413ea6d2926 (patch)
treee9f8bb5887b119722ee93c744fc3c169c0764375 /crates/common/src/auth/oauth/oidc.rs
parent1fed40a926b7070ad59fc69cb57ee1b67d53502f (diff)
OpenID Connect implementation (closes #298)
Diffstat (limited to 'crates/common/src/auth/oauth/oidc.rs')
-rw-r--r--crates/common/src/auth/oauth/oidc.rs153
1 files changed, 153 insertions, 0 deletions
diff --git a/crates/common/src/auth/oauth/oidc.rs b/crates/common/src/auth/oauth/oidc.rs
new file mode 100644
index 00000000..16a9192b
--- /dev/null
+++ b/crates/common/src/auth/oauth/oidc.rs
@@ -0,0 +1,153 @@
+/*
+ * SPDX-FileCopyrightText: 2020 Stalwart Labs Ltd <hello@stalw.art>
+ *
+ * SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-SEL
+ */
+
+use std::fmt;
+
+use biscuit::{jws::RegisteredHeader, ClaimsSet, RegisteredClaims, SingleOrMultiple, JWT};
+use serde::{
+ de::{self, Visitor},
+ Deserialize, Deserializer, Serialize,
+};
+use store::write::now;
+
+use crate::Server;
+
+#[derive(Debug, Default, Clone, Eq, PartialEq, Deserialize, Serialize)]
+pub struct Userinfo {
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub sub: Option<String>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub name: Option<String>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub given_name: Option<String>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub family_name: Option<String>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub middle_name: Option<String>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub nickname: Option<String>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub preferred_username: Option<String>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub profile: Option<String>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub picture: Option<String>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub website: Option<String>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub email: Option<String>,
+
+ #[serde(default, deserialize_with = "any_bool")]
+ #[serde(skip_serializing_if = "std::ops::Not::not")]
+ pub email_verified: bool,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub zoneinfo: Option<String>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub locale: Option<String>,
+
+ #[serde(default)]
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub updated_at: Option<i64>,
+}
+
+impl Server {
+ pub fn issue_id_token(
+ &self,
+ subject: impl Into<String>,
+ issuer: impl Into<String>,
+ audience: impl Into<String>,
+ ) -> trc::Result<String> {
+ let now = now() as i64;
+
+ JWT::new_decoded(
+ From::from(RegisteredHeader {
+ algorithm: self.core.oauth.oidc_signature_algorithm,
+ key_id: Some("default".into()),
+ ..Default::default()
+ }),
+ ClaimsSet::<()> {
+ registered: RegisteredClaims {
+ issuer: Some(issuer.into()),
+ subject: Some(subject.into()),
+ audience: Some(SingleOrMultiple::Single(audience.into())),
+ not_before: Some(now.into()),
+ issued_at: Some(now.into()),
+ expiry: Some((now + self.core.oauth.oidc_expiry_id_token as i64).into()),
+ ..Default::default()
+ },
+ private: (),
+ },
+ )
+ .into_encoded(&self.core.oauth.oidc_signing_secret)
+ .map(|token| token.unwrap_encoded().to_string())
+ .map_err(|err| {
+ trc::AuthEvent::Error
+ .into_err()
+ .reason(err)
+ .details("Failed to encode ID token")
+ })
+ }
+}
+
+fn any_bool<'de, D>(deserializer: D) -> Result<bool, D::Error>
+where
+ D: Deserializer<'de>,
+{
+ struct AnyBoolVisitor;
+
+ impl<'de> Visitor<'de> for AnyBoolVisitor {
+ type Value = bool;
+
+ fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ formatter.write_str("a boolean value")
+ }
+
+ fn visit_str<E>(self, value: &str) -> Result<bool, E>
+ where
+ E: de::Error,
+ {
+ match value {
+ "true" => Ok(true),
+ "false" => Ok(false),
+ _ => Err(E::custom(format!("Unknown boolean: {value}"))),
+ }
+ }
+
+ fn visit_bool<E>(self, value: bool) -> Result<bool, E>
+ where
+ E: de::Error,
+ {
+ Ok(value)
+ }
+ }
+
+ deserializer.deserialize_any(AnyBoolVisitor)
+}