mirror of https://gitlab.com/famedly/conduit.git
parent
commit
3e22bbeecd
7 changed files with 1002 additions and 0 deletions
@ -0,0 +1,160 @@
|
||||
/// Builds a StateMap by iterating over all keys that start
|
||||
/// with state_hash, this gives the full state for the given state_hash.
|
||||
#[tracing::instrument(skip(self))] |
||||
pub async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> { |
||||
let full_state = self |
||||
.load_shortstatehash_info(shortstatehash)? |
||||
.pop() |
||||
.expect("there is always one layer") |
||||
.1; |
||||
let mut result = BTreeMap::new(); |
||||
let mut i = 0; |
||||
for compressed in full_state.into_iter() { |
||||
let parsed = self.parse_compressed_state_event(compressed)?; |
||||
result.insert(parsed.0, parsed.1); |
||||
|
||||
i += 1; |
||||
if i % 100 == 0 { |
||||
tokio::task::yield_now().await; |
||||
} |
||||
} |
||||
Ok(result) |
||||
} |
||||
|
||||
#[tracing::instrument(skip(self))] |
||||
pub async fn state_full( |
||||
&self, |
||||
shortstatehash: u64, |
||||
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { |
||||
let full_state = self |
||||
.load_shortstatehash_info(shortstatehash)? |
||||
.pop() |
||||
.expect("there is always one layer") |
||||
.1; |
||||
|
||||
let mut result = HashMap::new(); |
||||
let mut i = 0; |
||||
for compressed in full_state { |
||||
let (_, eventid) = self.parse_compressed_state_event(compressed)?; |
||||
if let Some(pdu) = self.get_pdu(&eventid)? { |
||||
result.insert( |
||||
( |
||||
pdu.kind.to_string().into(), |
||||
pdu.state_key |
||||
.as_ref() |
||||
.ok_or_else(|| Error::bad_database("State event has no state key."))? |
||||
.clone(), |
||||
), |
||||
pdu, |
||||
); |
||||
} |
||||
|
||||
i += 1; |
||||
if i % 100 == 0 { |
||||
tokio::task::yield_now().await; |
||||
} |
||||
} |
||||
|
||||
Ok(result) |
||||
} |
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
#[tracing::instrument(skip(self))] |
||||
pub fn state_get_id( |
||||
&self, |
||||
shortstatehash: u64, |
||||
event_type: &StateEventType, |
||||
state_key: &str, |
||||
) -> Result<Option<Arc<EventId>>> { |
||||
let shortstatekey = match self.get_shortstatekey(event_type, state_key)? { |
||||
Some(s) => s, |
||||
None => return Ok(None), |
||||
}; |
||||
let full_state = self |
||||
.load_shortstatehash_info(shortstatehash)? |
||||
.pop() |
||||
.expect("there is always one layer") |
||||
.1; |
||||
Ok(full_state |
||||
.into_iter() |
||||
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) |
||||
.and_then(|compressed| { |
||||
self.parse_compressed_state_event(compressed) |
||||
.ok() |
||||
.map(|(_, id)| id) |
||||
})) |
||||
} |
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
#[tracing::instrument(skip(self))] |
||||
pub fn state_get( |
||||
&self, |
||||
shortstatehash: u64, |
||||
event_type: &StateEventType, |
||||
state_key: &str, |
||||
) -> Result<Option<Arc<PduEvent>>> { |
||||
self.state_get_id(shortstatehash, event_type, state_key)? |
||||
.map_or(Ok(None), |event_id| self.get_pdu(&event_id)) |
||||
} |
||||
|
||||
/// Returns the state hash for this pdu.
|
||||
pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { |
||||
self.eventid_shorteventid |
||||
.get(event_id.as_bytes())? |
||||
.map_or(Ok(None), |shorteventid| { |
||||
self.shorteventid_shortstatehash |
||||
.get(&shorteventid)? |
||||
.map(|bytes| { |
||||
utils::u64_from_bytes(&bytes).map_err(|_| { |
||||
Error::bad_database( |
||||
"Invalid shortstatehash bytes in shorteventid_shortstatehash", |
||||
) |
||||
}) |
||||
}) |
||||
.transpose() |
||||
}) |
||||
} |
||||
|
||||
/// Returns the full room state.
|
||||
#[tracing::instrument(skip(self))] |
||||
pub async fn room_state_full( |
||||
&self, |
||||
room_id: &RoomId, |
||||
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { |
||||
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { |
||||
self.state_full(current_shortstatehash).await |
||||
} else { |
||||
Ok(HashMap::new()) |
||||
} |
||||
} |
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
#[tracing::instrument(skip(self))] |
||||
pub fn room_state_get_id( |
||||
&self, |
||||
room_id: &RoomId, |
||||
event_type: &StateEventType, |
||||
state_key: &str, |
||||
) -> Result<Option<Arc<EventId>>> { |
||||
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { |
||||
self.state_get_id(current_shortstatehash, event_type, state_key) |
||||
} else { |
||||
Ok(None) |
||||
} |
||||
} |
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
#[tracing::instrument(skip(self))] |
||||
pub fn room_state_get( |
||||
&self, |
||||
room_id: &RoomId, |
||||
event_type: &StateEventType, |
||||
state_key: &str, |
||||
) -> Result<Option<Arc<PduEvent>>> { |
||||
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { |
||||
self.state_get(current_shortstatehash, event_type, state_key) |
||||
} else { |
||||
Ok(None) |
||||
} |
||||
} |
||||
|
||||
@ -0,0 +1,114 @@
|
||||
|
||||
#[tracing::instrument(skip(self))] |
||||
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { |
||||
let mut userroom_id = user_id.as_bytes().to_vec(); |
||||
userroom_id.push(0xff); |
||||
userroom_id.extend_from_slice(room_id.as_bytes()); |
||||
|
||||
self.userroomid_notificationcount |
||||
.insert(&userroom_id, &0_u64.to_be_bytes())?; |
||||
self.userroomid_highlightcount |
||||
.insert(&userroom_id, &0_u64.to_be_bytes())?; |
||||
|
||||
Ok(()) |
||||
} |
||||
|
||||
#[tracing::instrument(skip(self))] |
||||
pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { |
||||
let mut userroom_id = user_id.as_bytes().to_vec(); |
||||
userroom_id.push(0xff); |
||||
userroom_id.extend_from_slice(room_id.as_bytes()); |
||||
|
||||
self.userroomid_notificationcount |
||||
.get(&userroom_id)? |
||||
.map(|bytes| { |
||||
utils::u64_from_bytes(&bytes) |
||||
.map_err(|_| Error::bad_database("Invalid notification count in db.")) |
||||
}) |
||||
.unwrap_or(Ok(0)) |
||||
} |
||||
|
||||
#[tracing::instrument(skip(self))] |
||||
pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { |
||||
let mut userroom_id = user_id.as_bytes().to_vec(); |
||||
userroom_id.push(0xff); |
||||
userroom_id.extend_from_slice(room_id.as_bytes()); |
||||
|
||||
self.userroomid_highlightcount |
||||
.get(&userroom_id)? |
||||
.map(|bytes| { |
||||
utils::u64_from_bytes(&bytes) |
||||
.map_err(|_| Error::bad_database("Invalid highlight count in db.")) |
||||
}) |
||||
.unwrap_or(Ok(0)) |
||||
} |
||||
|
||||
pub fn associate_token_shortstatehash( |
||||
&self, |
||||
room_id: &RoomId, |
||||
token: u64, |
||||
shortstatehash: u64, |
||||
) -> Result<()> { |
||||
let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); |
||||
|
||||
let mut key = shortroomid.to_be_bytes().to_vec(); |
||||
key.extend_from_slice(&token.to_be_bytes()); |
||||
|
||||
self.roomsynctoken_shortstatehash |
||||
.insert(&key, &shortstatehash.to_be_bytes()) |
||||
} |
||||
|
||||
pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> { |
||||
let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); |
||||
|
||||
let mut key = shortroomid.to_be_bytes().to_vec(); |
||||
key.extend_from_slice(&token.to_be_bytes()); |
||||
|
||||
self.roomsynctoken_shortstatehash |
||||
.get(&key)? |
||||
.map(|bytes| { |
||||
utils::u64_from_bytes(&bytes).map_err(|_| { |
||||
Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash") |
||||
}) |
||||
}) |
||||
.transpose() |
||||
} |
||||
|
||||
#[tracing::instrument(skip(self))] |
||||
pub fn get_shared_rooms<'a>( |
||||
&'a self, |
||||
users: Vec<Box<UserId>>, |
||||
) -> Result<impl Iterator<Item = Result<Box<RoomId>>> + 'a> { |
||||
let iterators = users.into_iter().map(move |user_id| { |
||||
let mut prefix = user_id.as_bytes().to_vec(); |
||||
prefix.push(0xff); |
||||
|
||||
self.userroomid_joined |
||||
.scan_prefix(prefix) |
||||
.map(|(key, _)| { |
||||
let roomid_index = key |
||||
.iter() |
||||
.enumerate() |
||||
.find(|(_, &b)| b == 0xff) |
||||
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? |
||||
.0 |
||||
+ 1; // +1 because the room id starts AFTER the separator
|
||||
|
||||
let room_id = key[roomid_index..].to_vec(); |
||||
|
||||
Ok::<_, Error>(room_id) |
||||
}) |
||||
.filter_map(|r| r.ok()) |
||||
}); |
||||
|
||||
// We use the default compare function because keys are sorted correctly (not reversed)
|
||||
Ok(utils::common_elements(iterators, Ord::cmp) |
||||
.expect("users is not empty") |
||||
.map(|bytes| { |
||||
RoomId::parse(utils::string_from_bytes(&*bytes).map_err(|_| { |
||||
Error::bad_database("Invalid RoomId bytes in userroomid_joined") |
||||
})?) |
||||
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) |
||||
})) |
||||
} |
||||
|
||||
@ -0,0 +1,160 @@
|
||||
/// Builds a StateMap by iterating over all keys that start
|
||||
/// with state_hash, this gives the full state for the given state_hash.
|
||||
#[tracing::instrument(skip(self))] |
||||
pub async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> { |
||||
let full_state = self |
||||
.load_shortstatehash_info(shortstatehash)? |
||||
.pop() |
||||
.expect("there is always one layer") |
||||
.1; |
||||
let mut result = BTreeMap::new(); |
||||
let mut i = 0; |
||||
for compressed in full_state.into_iter() { |
||||
let parsed = self.parse_compressed_state_event(compressed)?; |
||||
result.insert(parsed.0, parsed.1); |
||||
|
||||
i += 1; |
||||
if i % 100 == 0 { |
||||
tokio::task::yield_now().await; |
||||
} |
||||
} |
||||
Ok(result) |
||||
} |
||||
|
||||
#[tracing::instrument(skip(self))] |
||||
pub async fn state_full( |
||||
&self, |
||||
shortstatehash: u64, |
||||
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { |
||||
let full_state = self |
||||
.load_shortstatehash_info(shortstatehash)? |
||||
.pop() |
||||
.expect("there is always one layer") |
||||
.1; |
||||
|
||||
let mut result = HashMap::new(); |
||||
let mut i = 0; |
||||
for compressed in full_state { |
||||
let (_, eventid) = self.parse_compressed_state_event(compressed)?; |
||||
if let Some(pdu) = self.get_pdu(&eventid)? { |
||||
result.insert( |
||||
( |
||||
pdu.kind.to_string().into(), |
||||
pdu.state_key |
||||
.as_ref() |
||||
.ok_or_else(|| Error::bad_database("State event has no state key."))? |
||||
.clone(), |
||||
), |
||||
pdu, |
||||
); |
||||
} |
||||
|
||||
i += 1; |
||||
if i % 100 == 0 { |
||||
tokio::task::yield_now().await; |
||||
} |
||||
} |
||||
|
||||
Ok(result) |
||||
} |
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
#[tracing::instrument(skip(self))] |
||||
pub fn state_get_id( |
||||
&self, |
||||
shortstatehash: u64, |
||||
event_type: &StateEventType, |
||||
state_key: &str, |
||||
) -> Result<Option<Arc<EventId>>> { |
||||
let shortstatekey = match self.get_shortstatekey(event_type, state_key)? { |
||||
Some(s) => s, |
||||
None => return Ok(None), |
||||
}; |
||||
let full_state = self |
||||
.load_shortstatehash_info(shortstatehash)? |
||||
.pop() |
||||
.expect("there is always one layer") |
||||
.1; |
||||
Ok(full_state |
||||
.into_iter() |
||||
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) |
||||
.and_then(|compressed| { |
||||
self.parse_compressed_state_event(compressed) |
||||
.ok() |
||||
.map(|(_, id)| id) |
||||
})) |
||||
} |
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
#[tracing::instrument(skip(self))] |
||||
pub fn state_get( |
||||
&self, |
||||
shortstatehash: u64, |
||||
event_type: &StateEventType, |
||||
state_key: &str, |
||||
) -> Result<Option<Arc<PduEvent>>> { |
||||
self.state_get_id(shortstatehash, event_type, state_key)? |
||||
.map_or(Ok(None), |event_id| self.get_pdu(&event_id)) |
||||
} |
||||
|
||||
/// Returns the state hash for this pdu.
|
||||
pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { |
||||
self.eventid_shorteventid |
||||
.get(event_id.as_bytes())? |
||||
.map_or(Ok(None), |shorteventid| { |
||||
self.shorteventid_shortstatehash |
||||
.get(&shorteventid)? |
||||
.map(|bytes| { |
||||
utils::u64_from_bytes(&bytes).map_err(|_| { |
||||
Error::bad_database( |
||||
"Invalid shortstatehash bytes in shorteventid_shortstatehash", |
||||
) |
||||
}) |
||||
}) |
||||
.transpose() |
||||
}) |
||||
} |
||||
|
||||
/// Returns the full room state.
|
||||
#[tracing::instrument(skip(self))] |
||||
pub async fn room_state_full( |
||||
&self, |
||||
room_id: &RoomId, |
||||
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { |
||||
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { |
||||
self.state_full(current_shortstatehash).await |
||||
} else { |
||||
Ok(HashMap::new()) |
||||
} |
||||
} |
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
#[tracing::instrument(skip(self))] |
||||
pub fn room_state_get_id( |
||||
&self, |
||||
room_id: &RoomId, |
||||
event_type: &StateEventType, |
||||
state_key: &str, |
||||
) -> Result<Option<Arc<EventId>>> { |
||||
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { |
||||
self.state_get_id(current_shortstatehash, event_type, state_key) |
||||
} else { |
||||
Ok(None) |
||||
} |
||||
} |
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
#[tracing::instrument(skip(self))] |
||||
pub fn room_state_get( |
||||
&self, |
||||
room_id: &RoomId, |
||||
event_type: &StateEventType, |
||||
state_key: &str, |
||||
) -> Result<Option<Arc<PduEvent>>> { |
||||
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { |
||||
self.state_get(current_shortstatehash, event_type, state_key) |
||||
} else { |
||||
Ok(None) |
||||
} |
||||
} |
||||
|
||||
@ -0,0 +1,114 @@
|
||||
|
||||
#[tracing::instrument(skip(self))] |
||||
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { |
||||
let mut userroom_id = user_id.as_bytes().to_vec(); |
||||
userroom_id.push(0xff); |
||||
userroom_id.extend_from_slice(room_id.as_bytes()); |
||||
|
||||
self.userroomid_notificationcount |
||||
.insert(&userroom_id, &0_u64.to_be_bytes())?; |
||||
self.userroomid_highlightcount |
||||
.insert(&userroom_id, &0_u64.to_be_bytes())?; |
||||
|
||||
Ok(()) |
||||
} |
||||
|
||||
#[tracing::instrument(skip(self))] |
||||
pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { |
||||
let mut userroom_id = user_id.as_bytes().to_vec(); |
||||
userroom_id.push(0xff); |
||||
userroom_id.extend_from_slice(room_id.as_bytes()); |
||||
|
||||
self.userroomid_notificationcount |
||||
.get(&userroom_id)? |
||||
.map(|bytes| { |
||||
utils::u64_from_bytes(&bytes) |
||||
.map_err(|_| Error::bad_database("Invalid notification count in db.")) |
||||
}) |
||||
.unwrap_or(Ok(0)) |
||||
} |
||||
|
||||
#[tracing::instrument(skip(self))] |
||||
pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> { |
||||
let mut userroom_id = user_id.as_bytes().to_vec(); |
||||
userroom_id.push(0xff); |
||||
userroom_id.extend_from_slice(room_id.as_bytes()); |
||||
|
||||
self.userroomid_highlightcount |
||||
.get(&userroom_id)? |
||||
.map(|bytes| { |
||||
utils::u64_from_bytes(&bytes) |
||||
.map_err(|_| Error::bad_database("Invalid highlight count in db.")) |
||||
}) |
||||
.unwrap_or(Ok(0)) |
||||
} |
||||
|
||||
pub fn associate_token_shortstatehash( |
||||
&self, |
||||
room_id: &RoomId, |
||||
token: u64, |
||||
shortstatehash: u64, |
||||
) -> Result<()> { |
||||
let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); |
||||
|
||||
let mut key = shortroomid.to_be_bytes().to_vec(); |
||||
key.extend_from_slice(&token.to_be_bytes()); |
||||
|
||||
self.roomsynctoken_shortstatehash |
||||
.insert(&key, &shortstatehash.to_be_bytes()) |
||||
} |
||||
|
||||
pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> { |
||||
let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); |
||||
|
||||
let mut key = shortroomid.to_be_bytes().to_vec(); |
||||
key.extend_from_slice(&token.to_be_bytes()); |
||||
|
||||
self.roomsynctoken_shortstatehash |
||||
.get(&key)? |
||||
.map(|bytes| { |
||||
utils::u64_from_bytes(&bytes).map_err(|_| { |
||||
Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash") |
||||
}) |
||||
}) |
||||
.transpose() |
||||
} |
||||
|
||||
#[tracing::instrument(skip(self))] |
||||
pub fn get_shared_rooms<'a>( |
||||
&'a self, |
||||
users: Vec<Box<UserId>>, |
||||
) -> Result<impl Iterator<Item = Result<Box<RoomId>>> + 'a> { |
||||
let iterators = users.into_iter().map(move |user_id| { |
||||
let mut prefix = user_id.as_bytes().to_vec(); |
||||
prefix.push(0xff); |
||||
|
||||
self.userroomid_joined |
||||
.scan_prefix(prefix) |
||||
.map(|(key, _)| { |
||||
let roomid_index = key |
||||
.iter() |
||||
.enumerate() |
||||
.find(|(_, &b)| b == 0xff) |
||||
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? |
||||
.0 |
||||
+ 1; // +1 because the room id starts AFTER the separator
|
||||
|
||||
let room_id = key[roomid_index..].to_vec(); |
||||
|
||||
Ok::<_, Error>(room_id) |
||||
}) |
||||
.filter_map(|r| r.ok()) |
||||
}); |
||||
|
||||
// We use the default compare function because keys are sorted correctly (not reversed)
|
||||
Ok(utils::common_elements(iterators, Ord::cmp) |
||||
.expect("users is not empty") |
||||
.map(|bytes| { |
||||
RoomId::parse(utils::string_from_bytes(&*bytes).map_err(|_| { |
||||
Error::bad_database("Invalid RoomId bytes in userroomid_joined") |
||||
})?) |
||||
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) |
||||
})) |
||||
} |
||||
|
||||
@ -0,0 +1,227 @@
|
||||
use std::{ |
||||
collections::BTreeMap, |
||||
sync::{Arc, RwLock}, |
||||
}; |
||||
|
||||
use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result}; |
||||
use ruma::{ |
||||
api::client::{ |
||||
error::ErrorKind, |
||||
uiaa::{ |
||||
AuthType, IncomingAuthData, IncomingPassword, |
||||
IncomingUserIdentifier::UserIdOrLocalpart, UiaaInfo, |
||||
}, |
||||
}, |
||||
signatures::CanonicalJsonValue, |
||||
DeviceId, UserId, |
||||
}; |
||||
use tracing::error; |
||||
|
||||
use super::abstraction::Tree; |
||||
|
||||
pub struct Uiaa { |
||||
pub(super) userdevicesessionid_uiaainfo: Arc<dyn Tree>, // User-interactive authentication
|
||||
pub(super) userdevicesessionid_uiaarequest: |
||||
RwLock<BTreeMap<(Box<UserId>, Box<DeviceId>, String), CanonicalJsonValue>>, |
||||
} |
||||
|
||||
impl Uiaa { |
||||
/// Creates a new Uiaa session. Make sure the session token is unique.
|
||||
pub fn create( |
||||
&self, |
||||
user_id: &UserId, |
||||
device_id: &DeviceId, |
||||
uiaainfo: &UiaaInfo, |
||||
json_body: &CanonicalJsonValue, |
||||
) -> Result<()> { |
||||
self.set_uiaa_request( |
||||
user_id, |
||||
device_id, |
||||
uiaainfo.session.as_ref().expect("session should be set"), // TODO: better session error handling (why is it optional in ruma?)
|
||||
json_body, |
||||
)?; |
||||
self.update_uiaa_session( |
||||
user_id, |
||||
device_id, |
||||
uiaainfo.session.as_ref().expect("session should be set"), |
||||
Some(uiaainfo), |
||||
) |
||||
} |
||||
|
||||
pub fn try_auth( |
||||
&self, |
||||
user_id: &UserId, |
||||
device_id: &DeviceId, |
||||
auth: &IncomingAuthData, |
||||
uiaainfo: &UiaaInfo, |
||||
users: &super::users::Users, |
||||
globals: &super::globals::Globals, |
||||
) -> Result<(bool, UiaaInfo)> { |
||||
let mut uiaainfo = auth |
||||
.session() |
||||
.map(|session| self.get_uiaa_session(user_id, device_id, session)) |
||||
.unwrap_or_else(|| Ok(uiaainfo.clone()))?; |
||||
|
||||
if uiaainfo.session.is_none() { |
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); |
||||
} |
||||
|
||||
match auth { |
||||
// Find out what the user completed
|
||||
IncomingAuthData::Password(IncomingPassword { |
||||
identifier, |
||||
password, |
||||
.. |
||||
}) => { |
||||
let username = match identifier { |
||||
UserIdOrLocalpart(username) => username, |
||||
_ => { |
||||
return Err(Error::BadRequest( |
||||
ErrorKind::Unrecognized, |
||||
"Identifier type not recognized.", |
||||
)) |
||||
} |
||||
}; |
||||
|
||||
let user_id = |
||||
UserId::parse_with_server_name(username.clone(), globals.server_name()) |
||||
.map_err(|_| { |
||||
Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.") |
||||
})?; |
||||
|
||||
// Check if password is correct
|
||||
if let Some(hash) = users.password_hash(&user_id)? { |
||||
let hash_matches = |
||||
argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); |
||||
|
||||
if !hash_matches { |
||||
uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody { |
||||
kind: ErrorKind::Forbidden, |
||||
message: "Invalid username or password.".to_owned(), |
||||
}); |
||||
return Ok((false, uiaainfo)); |
||||
} |
||||
} |
||||
|
||||
// Password was correct! Let's add it to `completed`
|
||||
uiaainfo.completed.push(AuthType::Password); |
||||
} |
||||
IncomingAuthData::Dummy(_) => { |
||||
uiaainfo.completed.push(AuthType::Dummy); |
||||
} |
||||
k => error!("type not supported: {:?}", k), |
||||
} |
||||
|
||||
// Check if a flow now succeeds
|
||||
let mut completed = false; |
||||
'flows: for flow in &mut uiaainfo.flows { |
||||
for stage in &flow.stages { |
||||
if !uiaainfo.completed.contains(stage) { |
||||
continue 'flows; |
||||
} |
||||
} |
||||
// We didn't break, so this flow succeeded!
|
||||
completed = true; |
||||
} |
||||
|
||||
if !completed { |
||||
self.update_uiaa_session( |
||||
user_id, |
||||
device_id, |
||||
uiaainfo.session.as_ref().expect("session is always set"), |
||||
Some(&uiaainfo), |
||||
)?; |
||||
return Ok((false, uiaainfo)); |
||||
} |
||||
|
||||
// UIAA was successful! Remove this session and return true
|
||||
self.update_uiaa_session( |
||||
user_id, |
||||
device_id, |
||||
uiaainfo.session.as_ref().expect("session is always set"), |
||||
None, |
||||
)?; |
||||
Ok((true, uiaainfo)) |
||||
} |
||||
|
||||
fn set_uiaa_request( |
||||
&self, |
||||
user_id: &UserId, |
||||
device_id: &DeviceId, |
||||
session: &str, |
||||
request: &CanonicalJsonValue, |
||||
) -> Result<()> { |
||||
self.userdevicesessionid_uiaarequest |
||||
.write() |
||||
.unwrap() |
||||
.insert( |
||||
(user_id.to_owned(), device_id.to_owned(), session.to_owned()), |
||||
request.to_owned(), |
||||
); |
||||
|
||||
Ok(()) |
||||
} |
||||
|
||||
pub fn get_uiaa_request( |
||||
&self, |
||||
user_id: &UserId, |
||||
device_id: &DeviceId, |
||||
session: &str, |
||||
) -> Option<CanonicalJsonValue> { |
||||
self.userdevicesessionid_uiaarequest |
||||
.read() |
||||
.unwrap() |
||||
.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) |
||||
.map(|j| j.to_owned()) |
||||
} |
||||
|
||||
fn update_uiaa_session( |
||||
&self, |
||||
user_id: &UserId, |
||||
device_id: &DeviceId, |
||||
session: &str, |
||||
uiaainfo: Option<&UiaaInfo>, |
||||
) -> Result<()> { |
||||
let mut userdevicesessionid = user_id.as_bytes().to_vec(); |
||||
userdevicesessionid.push(0xff); |
||||
userdevicesessionid.extend_from_slice(device_id.as_bytes()); |
||||
userdevicesessionid.push(0xff); |
||||
userdevicesessionid.extend_from_slice(session.as_bytes()); |
||||
|
||||
if let Some(uiaainfo) = uiaainfo { |
||||
self.userdevicesessionid_uiaainfo.insert( |
||||
&userdevicesessionid, |
||||
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), |
||||
)?; |
||||
} else { |
||||
self.userdevicesessionid_uiaainfo |
||||
.remove(&userdevicesessionid)?; |
||||
} |
||||
|
||||
Ok(()) |
||||
} |
||||
|
||||
fn get_uiaa_session( |
||||
&self, |
||||
user_id: &UserId, |
||||
device_id: &DeviceId, |
||||
session: &str, |
||||
) -> Result<UiaaInfo> { |
||||
let mut userdevicesessionid = user_id.as_bytes().to_vec(); |
||||
userdevicesessionid.push(0xff); |
||||
userdevicesessionid.extend_from_slice(device_id.as_bytes()); |
||||
userdevicesessionid.push(0xff); |
||||
userdevicesessionid.extend_from_slice(session.as_bytes()); |
||||
|
||||
serde_json::from_slice( |
||||
&self |
||||
.userdevicesessionid_uiaainfo |
||||
.get(&userdevicesessionid)? |
||||
.ok_or(Error::BadRequest( |
||||
ErrorKind::Forbidden, |
||||
"UIAA session does not exist.", |
||||
))?, |
||||
) |
||||
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) |
||||
} |
||||
} |
||||
@ -0,0 +1,227 @@
|
||||
use std::{ |
||||
collections::BTreeMap, |
||||
sync::{Arc, RwLock}, |
||||
}; |
||||
|
||||
use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result}; |
||||
use ruma::{ |
||||
api::client::{ |
||||
error::ErrorKind, |
||||
uiaa::{ |
||||
AuthType, IncomingAuthData, IncomingPassword, |
||||
IncomingUserIdentifier::UserIdOrLocalpart, UiaaInfo, |
||||
}, |
||||
}, |
||||
signatures::CanonicalJsonValue, |
||||
DeviceId, UserId, |
||||
}; |
||||
use tracing::error; |
||||
|
||||
use super::abstraction::Tree; |
||||
|
||||
pub struct Uiaa { |
||||
pub(super) userdevicesessionid_uiaainfo: Arc<dyn Tree>, // User-interactive authentication
|
||||
pub(super) userdevicesessionid_uiaarequest: |
||||
RwLock<BTreeMap<(Box<UserId>, Box<DeviceId>, String), CanonicalJsonValue>>, |
||||
} |
||||
|
||||
impl Uiaa { |
||||
/// Creates a new Uiaa session. Make sure the session token is unique.
|
||||
pub fn create( |
||||
&self, |
||||
user_id: &UserId, |
||||
device_id: &DeviceId, |
||||
uiaainfo: &UiaaInfo, |
||||
json_body: &CanonicalJsonValue, |
||||
) -> Result<()> { |
||||
self.set_uiaa_request( |
||||
user_id, |
||||
device_id, |
||||
uiaainfo.session.as_ref().expect("session should be set"), // TODO: better session error handling (why is it optional in ruma?)
|
||||
json_body, |
||||
)?; |
||||
self.update_uiaa_session( |
||||
user_id, |
||||
device_id, |
||||
uiaainfo.session.as_ref().expect("session should be set"), |
||||
Some(uiaainfo), |
||||
) |
||||
} |
||||
|
||||
pub fn try_auth( |
||||
&self, |
||||
user_id: &UserId, |
||||
device_id: &DeviceId, |
||||
auth: &IncomingAuthData, |
||||
uiaainfo: &UiaaInfo, |
||||
users: &super::users::Users, |
||||
globals: &super::globals::Globals, |
||||
) -> Result<(bool, UiaaInfo)> { |
||||
let mut uiaainfo = auth |
||||
.session() |
||||
.map(|session| self.get_uiaa_session(user_id, device_id, session)) |
||||
.unwrap_or_else(|| Ok(uiaainfo.clone()))?; |
||||
|
||||
if uiaainfo.session.is_none() { |
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); |
||||
} |
||||
|
||||
match auth { |
||||
// Find out what the user completed
|
||||
IncomingAuthData::Password(IncomingPassword { |
||||
identifier, |
||||
password, |
||||
.. |
||||
}) => { |
||||
let username = match identifier { |
||||
UserIdOrLocalpart(username) => username, |
||||
_ => { |
||||
return Err(Error::BadRequest( |
||||
ErrorKind::Unrecognized, |
||||
"Identifier type not recognized.", |
||||
)) |
||||
} |
||||
}; |
||||
|
||||
let user_id = |
||||
UserId::parse_with_server_name(username.clone(), globals.server_name()) |
||||
.map_err(|_| { |
||||
Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.") |
||||
})?; |
||||
|
||||
// Check if password is correct
|
||||
if let Some(hash) = users.password_hash(&user_id)? { |
||||
let hash_matches = |
||||
argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); |
||||
|
||||
if !hash_matches { |
||||
uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody { |
||||
kind: ErrorKind::Forbidden, |
||||
message: "Invalid username or password.".to_owned(), |
||||
}); |
||||
return Ok((false, uiaainfo)); |
||||
} |
||||
} |
||||
|
||||
// Password was correct! Let's add it to `completed`
|
||||
uiaainfo.completed.push(AuthType::Password); |
||||
} |
||||
IncomingAuthData::Dummy(_) => { |
||||
uiaainfo.completed.push(AuthType::Dummy); |
||||
} |
||||
k => error!("type not supported: {:?}", k), |
||||
} |
||||
|
||||
// Check if a flow now succeeds
|
||||
let mut completed = false; |
||||
'flows: for flow in &mut uiaainfo.flows { |
||||
for stage in &flow.stages { |
||||
if !uiaainfo.completed.contains(stage) { |
||||
continue 'flows; |
||||
} |
||||
} |
||||
// We didn't break, so this flow succeeded!
|
||||
completed = true; |
||||
} |
||||
|
||||
if !completed { |
||||
self.update_uiaa_session( |
||||
user_id, |
||||
device_id, |
||||
uiaainfo.session.as_ref().expect("session is always set"), |
||||
Some(&uiaainfo), |
||||
)?; |
||||
return Ok((false, uiaainfo)); |
||||
} |
||||
|
||||
// UIAA was successful! Remove this session and return true
|
||||
self.update_uiaa_session( |
||||
user_id, |
||||
device_id, |
||||
uiaainfo.session.as_ref().expect("session is always set"), |
||||
None, |
||||
)?; |
||||
Ok((true, uiaainfo)) |
||||
} |
||||
|
||||
fn set_uiaa_request( |
||||
&self, |
||||
user_id: &UserId, |
||||
device_id: &DeviceId, |
||||
session: &str, |
||||
request: &CanonicalJsonValue, |
||||
) -> Result<()> { |
||||
self.userdevicesessionid_uiaarequest |
||||
.write() |
||||
.unwrap() |
||||
.insert( |
||||
(user_id.to_owned(), device_id.to_owned(), session.to_owned()), |
||||
request.to_owned(), |
||||
); |
||||
|
||||
Ok(()) |
||||
} |
||||
|
||||
pub fn get_uiaa_request( |
||||
&self, |
||||
user_id: &UserId, |
||||
device_id: &DeviceId, |
||||
session: &str, |
||||
) -> Option<CanonicalJsonValue> { |
||||
self.userdevicesessionid_uiaarequest |
||||
.read() |
||||
.unwrap() |
||||
.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) |
||||
.map(|j| j.to_owned()) |
||||
} |
||||
|
||||
fn update_uiaa_session( |
||||
&self, |
||||
user_id: &UserId, |
||||
device_id: &DeviceId, |
||||
session: &str, |
||||
uiaainfo: Option<&UiaaInfo>, |
||||
) -> Result<()> { |
||||
let mut userdevicesessionid = user_id.as_bytes().to_vec(); |
||||
userdevicesessionid.push(0xff); |
||||
userdevicesessionid.extend_from_slice(device_id.as_bytes()); |
||||
userdevicesessionid.push(0xff); |
||||
userdevicesessionid.extend_from_slice(session.as_bytes()); |
||||
|
||||
if let Some(uiaainfo) = uiaainfo { |
||||
self.userdevicesessionid_uiaainfo.insert( |
||||
&userdevicesessionid, |
||||
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), |
||||
)?; |
||||
} else { |
||||
self.userdevicesessionid_uiaainfo |
||||
.remove(&userdevicesessionid)?; |
||||
} |
||||
|
||||
Ok(()) |
||||
} |
||||
|
||||
fn get_uiaa_session( |
||||
&self, |
||||
user_id: &UserId, |
||||
device_id: &DeviceId, |
||||
session: &str, |
||||
) -> Result<UiaaInfo> { |
||||
let mut userdevicesessionid = user_id.as_bytes().to_vec(); |
||||
userdevicesessionid.push(0xff); |
||||
userdevicesessionid.extend_from_slice(device_id.as_bytes()); |
||||
userdevicesessionid.push(0xff); |
||||
userdevicesessionid.extend_from_slice(session.as_bytes()); |
||||
|
||||
serde_json::from_slice( |
||||
&self |
||||
.userdevicesessionid_uiaainfo |
||||
.get(&userdevicesessionid)? |
||||
.ok_or(Error::BadRequest( |
||||
ErrorKind::Forbidden, |
||||
"UIAA session does not exist.", |
||||
))?, |
||||
) |
||||
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) |
||||
} |
||||
} |
||||
Loading…
Reference in new issue