make registration tokens reloadable, and allow configuring multiple

Signed-off-by: morguldir <morguldir@protonmail.com>
This commit is contained in:
morguldir 2025-01-31 02:36:14 +01:00
parent 69837671bb
commit f698254c41
No known key found for this signature in database
GPG key ID: 5A6025D4F6E7A8A3
4 changed files with 41 additions and 19 deletions

View file

@ -1,5 +1,5 @@
use std::{
collections::BTreeMap,
collections::{BTreeMap, HashSet},
sync::{Arc, RwLock},
};
@ -17,7 +17,7 @@ use ruma::{
CanonicalJsonValue, DeviceId, OwnedDeviceId, OwnedUserId, UserId,
};
use crate::{globals, users, Dep};
use crate::{config, globals, users, Dep};
pub struct Service {
userdevicesessionid_uiaarequest: RwLock<RequestMap>,
@ -28,6 +28,7 @@ pub struct Service {
struct Services {
globals: Dep<globals::Service>,
users: Dep<users::Service>,
config: Dep<config::Service>,
}
struct Data {
@ -49,6 +50,7 @@ impl crate::Service for Service {
services: Services {
globals: args.depend::<globals::Service>("globals"),
users: args.depend::<users::Service>("users"),
config: args.depend::<config::Service>("config"),
},
}))
}
@ -56,6 +58,26 @@ impl crate::Service for Service {
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
#[implement(Service)]
pub async fn read_tokens(&self) -> Result<HashSet<String>> {
let mut tokens = HashSet::new();
if let Some(file) = &self.services.config.registration_token_file.as_ref() {
match std::fs::read_to_string(file) {
| Ok(text) => {
text.split_ascii_whitespace().for_each(|token| {
tokens.insert(token.to_owned());
});
},
| Err(e) => error!("Failed to read the registration token file: {e}"),
}
};
if let Some(token) = &self.services.config.registration_token {
tokens.insert(token.to_owned());
}
Ok(tokens)
}
/// Creates a new Uiaa session. Make sure the session token is unique.
#[implement(Service)]
pub fn create(
@ -152,13 +174,8 @@ pub async fn try_auth(
uiaainfo.completed.push(AuthType::Password);
},
| AuthData::RegistrationToken(t) => {
if self
.services
.globals
.registration_token
.as_ref()
.is_some_and(|reg_token| t.token.trim() == reg_token)
{
let tokens = self.read_tokens().await?;
if tokens.contains(t.token.trim()) {
uiaainfo.completed.push(AuthType::RegistrationToken);
} else {
uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody {