From 3d48a05353efba5fd30af6daf5a800cb61cf7e1c Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Sat, 14 Jun 2025 15:36:16 +0100 Subject: [PATCH] WIP: rate-limiting --- conduit-config/Cargo.toml | 2 +- conduit-config/src/lib.rs | 10 +- conduit-config/src/rate_limiting.rs | 236 +++++++++ conduit/src/api/client_server/media.rs | 350 +++++++++----- conduit/src/api/ruma_wrapper/axum.rs | 437 +++++++++-------- conduit/src/api/ruma_wrapper/mod.rs | 3 +- conduit/src/api/server_server.rs | 23 +- conduit/src/database/key_value/media.rs | 34 +- conduit/src/service/admin/mod.rs | 9 +- conduit/src/service/media/data.rs | 5 +- conduit/src/service/media/mod.rs | 45 +- conduit/src/service/mod.rs | 4 + conduit/src/service/rate_limiting/mod.rs | 581 +++++++++++++++++++++++ 13 files changed, 1386 insertions(+), 353 deletions(-) create mode 100644 conduit-config/src/rate_limiting.rs create mode 100644 conduit/src/service/rate_limiting/mod.rs diff --git a/conduit-config/Cargo.toml b/conduit-config/Cargo.toml index 17ea181f..c4628f0b 100644 --- a/conduit-config/Cargo.toml +++ b/conduit-config/Cargo.toml @@ -23,7 +23,7 @@ reqwest.workspace = true # default room version, server name, ignored keys [dependencies.ruma] -features = ["federation-api"] +features = ["client-api", "federation-api"] workspace = true [features] diff --git a/conduit-config/src/lib.rs b/conduit-config/src/lib.rs index 74f83868..e5ddf556 100644 --- a/conduit-config/src/lib.rs +++ b/conduit-config/src/lib.rs @@ -17,7 +17,9 @@ use url::Url; pub mod error; mod proxy; -use self::proxy::ProxyConfig; +pub mod rate_limiting; + +use self::{proxy::ProxyConfig, rate_limiting::Config as RateLimitingConfig}; const SHA256_HEX_LENGTH: u8 = 64; @@ -98,6 +100,8 @@ pub struct IncompleteConfig { #[serde(default)] pub media: IncompleteMediaConfig, + pub rate_limiting: RateLimitingConfig, + pub emergency_password: Option, #[serde(flatten)] @@ -147,6 +151,8 @@ pub struct Config { pub media: MediaConfig, + pub rate_limiting: 
RateLimitingConfig, + pub emergency_password: Option, pub catchall: BTreeMap, @@ -194,6 +200,7 @@ impl From for Config { turn_ttl, turn, media, + rate_limiting, emergency_password, catchall, ignored_keys, @@ -295,6 +302,7 @@ impl From for Config { ip_address_detection, turn, media, + rate_limiting, emergency_password, catchall, ignored_keys, diff --git a/conduit-config/src/rate_limiting.rs b/conduit-config/src/rate_limiting.rs new file mode 100644 index 00000000..738e4052 --- /dev/null +++ b/conduit-config/src/rate_limiting.rs @@ -0,0 +1,236 @@ +use std::{collections::HashMap, num::NonZeroU64}; + +use bytesize::ByteSize; +use ruma::api::Metadata; +use serde::Deserialize; + +#[derive(Debug, Clone, Deserialize)] +pub struct Config { + #[serde(flatten)] + pub target: ConfigFragment, + pub global: ConfigFragment, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ConfigFragment { + pub client: ConfigClientFragment, + pub federation: ConfigFederationFragment, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ConfigClientFragment { + pub map: HashMap, + pub media: ClientMediaConfig, + // TODO: Only have available on target, not global (same with most authenticated endpoints too maybe)? 
+ pub authentication_failures: RequestLimitation, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ConfigFederationFragment { + pub map: HashMap, + pub media: FederationMediaConfig, +} + +impl ConfigFragment { + pub fn get(&self, restriction: &Restriction) -> &RequestLimitation { + // Maybe look into https://github.com/moriyoshi-kasuga/enum-table + match restriction { + Restriction::Client(client_restriction) => { + self.client.map.get(client_restriction).unwrap() + } + Restriction::Federation(federation_restriction) => { + self.federation.map.get(federation_restriction).unwrap() + } + } + } +} + +#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum Restriction { + Client(ClientRestriction), + Federation(FederationRestriction), +} + +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[serde(rename_all = "snake_case")] +pub enum ClientRestriction { + Registration, + Login, + RegistrationTokenValidity, + + SendEvent, + + Join, + Invite, + Knock, + + SendReport, + CreateAlias, + + MediaDownload, + MediaCreate, +} + +#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[serde(rename_all = "snake_case")] +pub enum FederationRestriction { + Join, + Knock, + Invite, + + // Transactions should be handled by a completely dedicated rate-limiter + Transaction, + + MediaDownload, +} + +impl TryFrom for Restriction { + type Error = (); + + fn try_from(value: Metadata) -> Result { + use Restriction::*; + use ruma::api::{ + IncomingRequest, + client::{ + account::{check_registration_token_validity, register}, + alias::create_alias, + authenticated_media::{ + get_content, get_content_as_filename, get_content_thumbnail, get_media_preview, + }, + knock::knock_room, + media::{self, create_content, create_mxc_uri}, + membership::{invite_user, join_room_by_id, join_room_by_id_or_alias}, + message::send_message_event, + reporting::report_user, + room::{report_content, report_room}, + 
session::login, + state::send_state_event, + }, + federation::{ + authenticated_media::{ + get_content as federation_get_content, + get_content_thumbnail as federation_get_content_thumbnail, + }, + membership::{create_invite, create_join_event, create_knock_event}, + }, + }; + + Ok(match value { + register::v3::Request::METADATA => Client(ClientRestriction::Registration), + check_registration_token_validity::v1::Request::METADATA => { + Client(ClientRestriction::RegistrationTokenValidity) + } + login::v3::Request::METADATA => Client(ClientRestriction::Login), + send_message_event::v3::Request::METADATA | send_state_event::v3::Request::METADATA => { + Client(ClientRestriction::SendEvent) + } + join_room_by_id::v3::Request::METADATA + | join_room_by_id_or_alias::v3::Request::METADATA => Client(ClientRestriction::Join), + invite_user::v3::Request::METADATA => Client(ClientRestriction::Invite), + knock_room::v3::Request::METADATA => Client(ClientRestriction::Knock), + report_user::v3::Request::METADATA + | report_content::v3::Request::METADATA + | report_room::v3::Request::METADATA => Client(ClientRestriction::SendReport), + create_alias::v3::Request::METADATA => Client(ClientRestriction::CreateAlias), + // NOTE: handle async media upload in a way that doesn't half the number of uploads you can do within a short timeframe, while not allowing pre-generation of MXC uris to allow uploading double the number of media at once + create_content::v3::Request::METADATA | create_mxc_uri::v1::Request::METADATA => { + Client(ClientRestriction::MediaCreate) + } + // Unauthenticate media is deprecated + #[allow(deprecated)] + media::get_content::v3::Request::METADATA + | media::get_content_as_filename::v3::Request::METADATA + | media::get_content_thumbnail::v3::Request::METADATA + | media::get_media_preview::v3::Request::METADATA + | get_content::v1::Request::METADATA + | get_content_as_filename::v1::Request::METADATA + | get_content_thumbnail::v1::Request::METADATA + | 
get_media_preview::v1::Request::METADATA => Client(ClientRestriction::MediaDownload), + federation_get_content::v1::Request::METADATA + | federation_get_content_thumbnail::v1::Request::METADATA => { + Federation(FederationRestriction::MediaDownload) + } + // v1 is deprecated + #[allow(deprecated)] + create_join_event::v1::Request::METADATA | create_join_event::v2::Request::METADATA => { + Federation(FederationRestriction::Join) + } + create_knock_event::v1::Request::METADATA => Federation(FederationRestriction::Knock), + create_invite::v1::Request::METADATA | create_invite::v2::Request::METADATA => { + Federation(FederationRestriction::Invite) + } + + _ => return Err(()), + }) + } +} + +#[derive(Clone, Copy, Debug, Deserialize)] +pub struct RequestLimitation { + #[serde(flatten)] + pub timeframe: Timeframe, + pub burst_capacity: NonZeroU64, +} + +#[derive(Deserialize, Clone, Copy, Debug)] +#[serde(rename_all = "snake_case")] +// When deserializing, we want this prefix +#[allow(clippy::enum_variant_names)] +pub enum Timeframe { + PerSecond(NonZeroU64), + PerMinute(NonZeroU64), + PerHour(NonZeroU64), + PerDay(NonZeroU64), +} + +impl Timeframe { + pub fn nano_gap(&self) -> u64 { + match self { + Timeframe::PerSecond(t) => 1000 * 1000 * 1000 / t.get(), + Timeframe::PerMinute(t) => 1000 * 1000 * 1000 * 60 / t.get(), + Timeframe::PerHour(t) => 1000 * 1000 * 1000 * 60 * 60 / t.get(), + Timeframe::PerDay(t) => 1000 * 1000 * 1000 * 60 * 60 * 24 / t.get(), + } + } +} + +#[derive(Clone, Copy, Debug, Deserialize)] +pub struct ClientMediaConfig { + pub download: MediaLimitation, + pub upload: MediaLimitation, + pub fetch: MediaLimitation, +} + +#[derive(Clone, Copy, Debug, Deserialize)] +pub struct FederationMediaConfig { + pub download: MediaLimitation, +} + +#[derive(Clone, Copy, Debug, Deserialize)] +pub struct MediaLimitation { + #[serde(flatten)] + pub timeframe: MediaTimeframe, + pub burst_capacity: ByteSize, +} + +#[derive(Deserialize, Clone, Copy, Debug)] 
+#[serde(rename_all = "snake_case")] +// When deserializing, we want this prefix +#[allow(clippy::enum_variant_names)] +pub enum MediaTimeframe { + PerSecond(ByteSize), + PerMinute(ByteSize), + PerHour(ByteSize), + PerDay(ByteSize), +} + +impl MediaTimeframe { + pub fn bytes_per_sec(&self) -> u64 { + match self { + MediaTimeframe::PerSecond(t) => t.as_u64(), + MediaTimeframe::PerMinute(t) => t.as_u64() / 60, + MediaTimeframe::PerHour(t) => t.as_u64() / (60 * 60), + MediaTimeframe::PerDay(t) => t.as_u64() / (60 * 60 * 24), + } + } +} diff --git a/conduit/src/api/client_server/media.rs b/conduit/src/api/client_server/media.rs index b66fdc49..c6cf5fd3 100644 --- a/conduit/src/api/client_server/media.rs +++ b/conduit/src/api/client_server/media.rs @@ -3,7 +3,14 @@ use std::time::Duration; -use crate::{Error, Result, Ruma, service::media::FileMeta, services, utils}; +use crate::{ + Error, Result, Ruma, + service::{ + media::{FileMeta, size}, + rate_limiting::Target, + }, + services, utils, +}; use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE}; use ruma::{ ServerName, UInt, @@ -54,6 +61,8 @@ pub async fn get_media_config_auth_route( pub async fn create_content_route( body: Ruma, ) -> Result { + let sender_user = body.sender_user.expect("user is authenticated"); + let create_content::v3::Request { filename, content_type, @@ -61,6 +70,13 @@ pub async fn create_content_route( .. } = body.body; + let target = Target::from_client_request(body.appservice_info, &sender_user); + + services() + .rate_limiting + .check_media_upload(target, size(&file)?) 
+ .await?; + let media_id = utils::random_string(MXC_LENGTH); services() @@ -71,7 +87,7 @@ pub async fn create_content_route( filename.as_deref(), content_type.as_deref(), &file, - body.sender_user.as_deref(), + Some(&sender_user), ) .await?; @@ -84,7 +100,13 @@ pub async fn create_content_route( pub async fn get_remote_content( server_name: &ServerName, media_id: String, + target: Target, ) -> Result { + services() + .rate_limiting + .check_media_pre_fetch(&target) + .await?; + let content_response = match services() .sending .send_federation_request( @@ -153,6 +175,11 @@ pub async fn get_remote_content( ) .await?; + services() + .rate_limiting + .update_media_post_fetch(target, size(&content_response.file)?) + .await; + Ok(content_response) } @@ -171,11 +198,21 @@ pub async fn get_content_route( } = get_content( &body.server_name, body.media_id.clone(), - body.allow_remote, - false, + body.sender_ip_address.map(Target::Ip), ) .await?; + if let Some(target) = Target::from_client_request_optional_auth( + body.appservice_info, + &body.sender_user, + body.sender_ip_address, + ) { + services() + .rate_limiting + .update_media_post_fetch(target, size(&file)?) + .await; + } + Ok(media::get_content::v3::Response { file, content_type, @@ -190,14 +227,24 @@ pub async fn get_content_route( pub async fn get_content_auth_route( body: Ruma, ) -> Result { - get_content(&body.server_name, body.media_id.clone(), true, true).await + let Ruma:: { + body, + sender_user, + appservice_info, + .. 
+ } = body; + + let sender_user = sender_user.as_ref().expect("user is authenticated"); + + let target = Target::from_client_request(appservice_info, sender_user); + + get_content(&body.server_name, body.media_id.clone(), Some(target)).await } pub async fn get_content( server_name: &ServerName, media_id: String, - allow_remote: bool, - authenticated: bool, + target: Option, ) -> Result { services().media.check_blocked(server_name, &media_id)?; @@ -207,7 +254,7 @@ pub async fn get_content( file, })) = services() .media - .get(server_name, &media_id, authenticated) + .get(server_name, &media_id, target.clone()) .await { Ok(get_content::v1::Response { @@ -215,16 +262,25 @@ pub async fn get_content( content_type, content_disposition: Some(content_disposition), }) - } else if server_name != services().globals.server_name() && allow_remote && authenticated { - let remote_content_response = get_remote_content(server_name, media_id.clone()).await?; - - Ok(get_content::v1::Response { - content_disposition: remote_content_response.content_disposition, - content_type: remote_content_response.content_type, - file: remote_content_response.file, - }) } else { - Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) + let error = Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); + + if let Some(target) = target { + if server_name != services().globals.server_name() && target.is_authenticated() { + let remote_content_response = + get_remote_content(server_name, media_id.clone(), target).await?; + + Ok(get_content::v1::Response { + content_disposition: remote_content_response.content_disposition, + content_type: remote_content_response.content_type, + file: remote_content_response.file, + }) + } else { + error + } + } else { + error + } } } @@ -244,8 +300,7 @@ pub async fn get_content_as_filename_route( &body.server_name, body.media_id.clone(), body.filename.clone(), - body.allow_remote, - false, + body.sender_ip_address.map(Target::Ip), ) .await?; @@ 
-263,12 +318,22 @@ pub async fn get_content_as_filename_route( pub async fn get_content_as_filename_auth_route( body: Ruma, ) -> Result { + let Ruma:: { + body, + sender_user, + appservice_info, + .. + } = body; + + let sender_user = sender_user.as_ref().expect("user is authenticated"); + + let target = Target::from_client_request(appservice_info, sender_user); + get_content_as_filename( &body.server_name, body.media_id.clone(), body.filename.clone(), - true, - true, + Some(target), ) .await } @@ -277,8 +342,7 @@ async fn get_content_as_filename( server_name: &ServerName, media_id: String, filename: String, - allow_remote: bool, - authenticated: bool, + target: Option, ) -> Result { services().media.check_blocked(server_name, &media_id)?; @@ -286,7 +350,7 @@ async fn get_content_as_filename( file, content_type, .. })) = services() .media - .get(server_name, &media_id, authenticated) + .get(server_name, &media_id, target.clone()) .await { Ok(get_content_as_filename::v1::Response { @@ -297,19 +361,28 @@ async fn get_content_as_filename( .with_filename(Some(filename.clone())), ), }) - } else if server_name != services().globals.server_name() && allow_remote && authenticated { - let remote_content_response = get_remote_content(server_name, media_id.clone()).await?; - - Ok(get_content_as_filename::v1::Response { - content_disposition: Some( - ContentDisposition::new(ContentDispositionType::Inline) - .with_filename(Some(filename.clone())), - ), - content_type: remote_content_response.content_type, - file: remote_content_response.file, - }) } else { - Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) + let error = Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); + + if let Some(target) = target { + if server_name != services().globals.server_name() && target.is_authenticated() { + let remote_content_response = + get_remote_content(server_name, media_id.clone(), target).await?; + + Ok(get_content_as_filename::v1::Response { + 
content_disposition: Some( + ContentDisposition::new(ContentDispositionType::Inline) + .with_filename(Some(filename.clone())), + ), + content_type: remote_content_response.content_type, + file: remote_content_response.file, + }) + } else { + error + } + } else { + error + } } } @@ -321,6 +394,17 @@ async fn get_content_as_filename( pub async fn get_content_thumbnail_route( body: Ruma, ) -> Result { + let Ruma:: { + body, + sender_user, + sender_ip_address, + appservice_info, + .. + } = body; + + let target = + Target::from_client_request_optional_auth(appservice_info, &sender_user, sender_ip_address); + let get_content_thumbnail::v1::Response { file, content_type, @@ -332,8 +416,7 @@ pub async fn get_content_thumbnail_route( body.width, body.method.clone(), body.animated, - body.allow_remote, - false, + target, ) .await?; @@ -351,6 +434,15 @@ pub async fn get_content_thumbnail_route( pub async fn get_content_thumbnail_auth_route( body: Ruma, ) -> Result { + let Ruma:: { + body, + sender_user, + appservice_info, + .. + } = body; + let sender_user = sender_user.as_ref().expect("user is authenticated"); + let target = Target::from_client_request(appservice_info, sender_user); + get_content_thumbnail( &body.server_name, body.media_id.clone(), @@ -358,8 +450,7 @@ pub async fn get_content_thumbnail_auth_route( body.width, body.method.clone(), body.animated, - true, - true, + Some(target), ) .await } @@ -372,8 +463,7 @@ async fn get_content_thumbnail( width: UInt, method: Option, animated: Option, - allow_remote: bool, - authenticated: bool, + target: Option, ) -> Result { services().media.check_blocked(server_name, &media_id)?; @@ -392,7 +482,7 @@ async fn get_content_thumbnail( height .try_into() .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?, - authenticated, + target.clone(), ) .await? 
{ @@ -401,99 +491,117 @@ async fn get_content_thumbnail( content_type, content_disposition: Some(content_disposition), }) - } else if server_name != services().globals.server_name() && allow_remote && authenticated { - let thumbnail_response = match services() - .sending - .send_federation_request( - server_name, - federation_media::get_content_thumbnail::v1::Request { - height, - width, - method: method.clone(), - media_id: media_id.clone(), - timeout_ms: Duration::from_secs(20), - animated, - }, - ) - .await - { - Ok(federation_media::get_content_thumbnail::v1::Response { - metadata: _, - content: FileOrLocation::File(content), - }) => get_content_thumbnail::v1::Response { - file: content.file, - content_type: content.content_type, - content_disposition: content.content_disposition, - }, + } else { + let error = Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); - Ok(federation_media::get_content_thumbnail::v1::Response { - metadata: _, - content: FileOrLocation::Location(url), - }) => { - let get_content::v1::Response { - file, - content_type, - content_disposition, - } = get_location_content(url).await?; - - get_content_thumbnail::v1::Response { - file, - content_type, - content_disposition, - } - } - Err(Error::BadRequest(ErrorKind::Unrecognized, _)) => { - let media::get_content_thumbnail::v3::Response { - file, - content_type, - content_disposition, - .. 
- } = services() + if let Some(target) = target { + if server_name != services().globals.server_name() { + services() + .rate_limiting + .check_media_pre_fetch(&target) + .await?; + + let thumbnail_response = match services() .sending .send_federation_request( server_name, - media::get_content_thumbnail::v3::Request { + federation_media::get_content_thumbnail::v1::Request { height, width, method: method.clone(), - server_name: server_name.to_owned(), media_id: media_id.clone(), timeout_ms: Duration::from_secs(20), - allow_redirect: false, animated, - allow_remote: false, }, ) + .await + { + Ok(federation_media::get_content_thumbnail::v1::Response { + metadata: _, + content: FileOrLocation::File(content), + }) => get_content_thumbnail::v1::Response { + file: content.file, + content_type: content.content_type, + content_disposition: content.content_disposition, + }, + + Ok(federation_media::get_content_thumbnail::v1::Response { + metadata: _, + content: FileOrLocation::Location(url), + }) => { + let get_content::v1::Response { + file, + content_type, + content_disposition, + } = get_location_content(url).await?; + + get_content_thumbnail::v1::Response { + file, + content_type, + content_disposition, + } + } + Err(Error::BadRequest(ErrorKind::Unrecognized, _)) => { + let media::get_content_thumbnail::v3::Response { + file, + content_type, + content_disposition, + .. + } = services() + .sending + .send_federation_request( + server_name, + media::get_content_thumbnail::v3::Request { + height, + width, + method: method.clone(), + server_name: server_name.to_owned(), + media_id: media_id.clone(), + timeout_ms: Duration::from_secs(20), + allow_redirect: false, + animated, + allow_remote: false, + }, + ) + .await?; + + get_content_thumbnail::v1::Response { + file, + content_type, + content_disposition, + } + } + Err(e) => return Err(e), + }; + + services() + .rate_limiting + .update_media_post_fetch(target, size(&thumbnail_response.file)?) 
+ .await; + + services() + .media + .upload_thumbnail( + server_name, + &media_id, + thumbnail_response + .content_disposition + .as_ref() + .and_then(|cd| cd.filename.as_deref()), + thumbnail_response.content_type.as_deref(), + width.try_into().expect("all UInts are valid u32s"), + height.try_into().expect("all UInts are valid u32s"), + &thumbnail_response.file, + ) .await?; - get_content_thumbnail::v1::Response { - file, - content_type, - content_disposition, - } + Ok(thumbnail_response) + } else { + error } - Err(e) => return Err(e), - }; - - services() - .media - .upload_thumbnail( - server_name, - &media_id, - thumbnail_response - .content_disposition - .as_ref() - .and_then(|cd| cd.filename.as_deref()), - thumbnail_response.content_type.as_deref(), - width.try_into().expect("all UInts are valid u32s"), - height.try_into().expect("all UInts are valid u32s"), - &thumbnail_response.file, - ) - .await?; - - Ok(thumbnail_response) - } else { - Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) + } else { + error + } } } diff --git a/conduit/src/api/ruma_wrapper/axum.rs b/conduit/src/api/ruma_wrapper/axum.rs index 3e6e0272..d4bf31d0 100644 --- a/conduit/src/api/ruma_wrapper/axum.rs +++ b/conduit/src/api/ruma_wrapper/axum.rs @@ -3,7 +3,7 @@ use std::{ error::Error as _, iter::FromIterator, net::{IpAddr, SocketAddr}, - str, + str::{self, FromStr}, }; use axum::{ @@ -18,6 +18,7 @@ use axum_extra::{ typed_header::TypedHeaderRejectionReason, }; use bytes::{BufMut, BytesMut}; +use conduit_config::IpAddrDetection; use http::{Request, StatusCode}; use ruma::{ CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId, @@ -31,12 +32,15 @@ use tracing::{debug, error, warn}; use super::{Ruma, RumaResponse}; use crate::{ - Error, Result, config::IpAddrDetection, service::appservice::RegistrationInfo, services, + Error, Result, + service::{appservice::RegistrationInfo, rate_limiting::Target}, + services, }; enum Token { Appservice(Box), 
User((OwnedUserId, OwnedDeviceId)), + AuthRateLimited(Error), Invalid, None, } @@ -122,7 +126,17 @@ where }; let token = if let Some(token) = token { - if let Some(reg_info) = services().appservice.find_from_token(token).await { + let mut rate_limited = None; + + if let Some(ip_addr) = sender_ip_address { + if let Err(instant) = services().rate_limiting.pre_auth_check(ip_addr).await { + rate_limited = Some(instant); + } + } + + if let Some(instant) = rate_limited { + Token::AuthRateLimited(instant) + } else if let Some(reg_info) = services().appservice.find_from_token(token).await { Token::Appservice(Box::new(reg_info.clone())) } else if let Some((user_id, device_id)) = services().users.find_from_token(token)? { Token::User((user_id, device_id)) @@ -135,217 +149,243 @@ where let mut json_body = serde_json::from_slice::(&body).ok(); - let (sender_user, sender_device, sender_servername, appservice_info) = - match (metadata.authentication, token) { - (_, Token::Invalid) => { - // OpenID endpoint uses a query param with the same name, drop this once query params for user auth are removed from the spec - if query_params.access_token.is_some() { - (None, None, None, None) + let (sender_user, sender_device, sender_servername, appservice_info) = match ( + metadata.authentication, + token, + ) { + (_, Token::AuthRateLimited(instant)) => { + return Err(instant); + } + (_, Token::Invalid) => { + // OpenID endpoint uses a query param with the same name, drop this once query params for user auth are removed from the spec + if query_params.access_token.is_some() { + (None, None, None, None) + } else { + if let Some(addr) = sender_ip_address { + services() + .rate_limiting + .update_post_auth_failure(addr) + .await; } else { - return Err(Error::BadRequest( - ErrorKind::UnknownToken { soft_logout: false }, - "Unknown access token.", - )); - } - } - (AuthScheme::AccessToken, Token::Appservice(info)) => { - let user_id = query_params - .user_id - .map_or_else( - || { - 
UserId::parse_with_server_name( - info.registration.sender_localpart.as_str(), - services().globals.server_name(), - ) - }, - UserId::parse, - ) - .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") - })?; - - if !info.is_user_match(&user_id) { - return Err(Error::BadRequest( - ErrorKind::Exclusive, - "User is not in namespace.", - )); + error!( + "Auth failure occurred, but IP address was not extracted. Please check your Conduit & reverse proxy configuration, as if nothing is done, an attacker can brute-force access tokens and login to user's accounts" + ); } - if !services().users.exists(&user_id)? { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "User does not exist.", - )); - } + return Err(Error::BadRequest( + ErrorKind::UnknownToken { soft_logout: false }, + "Unknown access token.", + )); + } + } + (AuthScheme::AccessToken, Token::Appservice(info)) => { + let user_id = query_params + .user_id + .map_or_else( + || { + UserId::parse_with_server_name( + info.registration.sender_localpart.as_str(), + services().globals.server_name(), + ) + }, + UserId::parse, + ) + .map_err(|_| { + Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") + })?; - (Some(user_id), None, None, Some(*info)) + if !info.is_user_match(&user_id) { + return Err(Error::BadRequest( + ErrorKind::Exclusive, + "User is not in namespace.", + )); } - ( - AuthScheme::None - | AuthScheme::AppserviceToken - | AuthScheme::AppserviceTokenOptional - | AuthScheme::AccessTokenOptional, - Token::Appservice(info), - ) => (None, None, None, Some(*info)), - (AuthScheme::AppserviceToken | AuthScheme::AccessToken, Token::None) => { + + if !services().users.exists(&user_id)? 
{ return Err(Error::BadRequest( - ErrorKind::MissingToken, - "Missing access token.", + ErrorKind::forbidden(), + "User does not exist.", )); } - ( - AuthScheme::AccessToken | AuthScheme::AccessTokenOptional | AuthScheme::None, - Token::User((user_id, device_id)), - ) => (Some(user_id), Some(device_id), None, None), - (AuthScheme::ServerSignatures, Token::None) => { - let TypedHeader(Authorization(x_matrix)) = parts - .extract::>>() - .await - .map_err(|e| { - warn!("Missing or invalid Authorization header: {}", e); - - let msg = match e.reason() { - TypedHeaderRejectionReason::Missing => { - "Missing Authorization header." - } - TypedHeaderRejectionReason::Error(_) => { - "Invalid X-Matrix signatures." - } - _ => "Unknown header-related error", - }; - - Error::BadRequest(ErrorKind::forbidden(), msg) - })?; - - if let Some(dest) = x_matrix.destination { - if dest != services().globals.server_name() { - return Err(Error::BadRequest( - ErrorKind::Unauthorized, - "X-Matrix destination field does not match server name.", - )); - } - }; - - let origin_signatures = BTreeMap::from_iter([( - x_matrix.key.clone(), - CanonicalJsonValue::String(x_matrix.sig.to_string()), - )]); - - let signatures = BTreeMap::from_iter([( - x_matrix.origin.as_str().to_owned(), - CanonicalJsonValue::Object( - origin_signatures - .into_iter() - .map(|(k, v)| (k.to_string(), v)) - .collect(), - ), - )]); - let mut request_map = BTreeMap::from_iter([ - ( - "method".to_owned(), - CanonicalJsonValue::String(parts.method.to_string()), - ), - ( - "uri".to_owned(), - CanonicalJsonValue::String(parts.uri.to_string()), - ), - ( - "origin".to_owned(), - CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()), - ), - ( - "destination".to_owned(), - CanonicalJsonValue::String( - services().globals.server_name().as_str().to_owned(), - ), - ), - ( - "signatures".to_owned(), - CanonicalJsonValue::Object(signatures), + (Some(user_id), None, None, Some(*info)) + } + ( + AuthScheme::None + | 
AuthScheme::AppserviceToken + | AuthScheme::AppserviceTokenOptional + | AuthScheme::AccessTokenOptional, + Token::Appservice(info), + ) => (None, None, None, Some(*info)), + (AuthScheme::AppserviceToken | AuthScheme::AccessToken, Token::None) => { + return Err(Error::BadRequest( + ErrorKind::MissingToken, + "Missing access token.", + )); + } + ( + AuthScheme::AccessToken | AuthScheme::AccessTokenOptional | AuthScheme::None, + Token::User((user_id, device_id)), + ) => (Some(user_id), Some(device_id), None, None), + (AuthScheme::ServerSignatures, Token::None) => { + let TypedHeader(Authorization(x_matrix)) = parts + .extract::>>() + .await + .map_err(|e| { + warn!("Missing or invalid Authorization header: {}", e); + + let msg = match e.reason() { + TypedHeaderRejectionReason::Missing => "Missing Authorization header.", + TypedHeaderRejectionReason::Error(_) => "Invalid X-Matrix signatures.", + _ => "Unknown header-related error", + }; + + Error::BadRequest(ErrorKind::forbidden(), msg) + })?; + + if let Some(dest) = x_matrix.destination { + if dest != services().globals.server_name() { + return Err(Error::BadRequest( + ErrorKind::Unauthorized, + "X-Matrix destination field does not match server name.", + )); + } + }; + + let origin_signatures = BTreeMap::from_iter([( + x_matrix.key.clone(), + CanonicalJsonValue::String(x_matrix.sig.to_string()), + )]); + + let signatures = BTreeMap::from_iter([( + x_matrix.origin.as_str().to_owned(), + CanonicalJsonValue::Object( + origin_signatures + .into_iter() + .map(|(k, v)| (k.to_string(), v)) + .collect(), + ), + )]); + + let mut request_map = BTreeMap::from_iter([ + ( + "method".to_owned(), + CanonicalJsonValue::String(parts.method.to_string()), + ), + ( + "uri".to_owned(), + CanonicalJsonValue::String(parts.uri.to_string()), + ), + ( + "origin".to_owned(), + CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()), + ), + ( + "destination".to_owned(), + CanonicalJsonValue::String( + 
services().globals.server_name().as_str().to_owned(), ), - ]); - - if let Some(json_body) = &json_body { - request_map.insert("content".to_owned(), json_body.clone()); - }; - - let keys_result = services() - .rooms - .event_handler - .fetch_signing_keys(&x_matrix.origin, vec![x_matrix.key.to_string()], false) - .await; - - let keys = match keys_result { - Ok(b) => b, - Err(e) => { - warn!("Failed to fetch signing keys: {}", e); - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Failed to fetch signing keys.", - )); - } - }; - - // Only verify_keys that are currently valid should be used for validating requests - // as per MSC4029 - let pub_key_map = BTreeMap::from_iter([( - x_matrix.origin.as_str().to_owned(), - if keys.valid_until_ts > MilliSecondsSinceUnixEpoch::now() { - keys.verify_keys - .into_iter() - .map(|(id, key)| (id, key.key)) - .collect() - } else { - BTreeMap::new() - }, - )]); - - match ruma::signatures::verify_json(&pub_key_map, &request_map) { - Ok(()) => (None, None, Some(x_matrix.origin), None), - Err(e) => { + ), + ( + "signatures".to_owned(), + CanonicalJsonValue::Object(signatures), + ), + ]); + + if let Some(json_body) = &json_body { + request_map.insert("content".to_owned(), json_body.clone()); + }; + + let keys_result = services() + .rooms + .event_handler + .fetch_signing_keys(&x_matrix.origin, vec![x_matrix.key.to_string()], false) + .await; + + let keys = match keys_result { + Ok(b) => b, + Err(e) => { + warn!("Failed to fetch signing keys: {}", e); + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Failed to fetch signing keys.", + )); + } + }; + + // Only verify_keys that are currently valid should be used for validating requests + // as per MSC4029 + let pub_key_map = BTreeMap::from_iter([( + x_matrix.origin.as_str().to_owned(), + if keys.valid_until_ts > MilliSecondsSinceUnixEpoch::now() { + keys.verify_keys + .into_iter() + .map(|(id, key)| (id, key.key)) + .collect() + } else { + BTreeMap::new() + }, + )]); + + 
match ruma::signatures::verify_json(&pub_key_map, &request_map) { + Ok(()) => (None, None, Some(x_matrix.origin), None), + Err(e) => { + warn!( + "Failed to verify json request from {}: {}\n{:?}", + x_matrix.origin, e, request_map + ); + + if parts.uri.to_string().contains('@') { warn!( - "Failed to verify json request from {}: {}\n{:?}", - x_matrix.origin, e, request_map - ); - - if parts.uri.to_string().contains('@') { - warn!( - "Request uri contained '@' character. Make sure your \ + "Request uri contained '@' character. Make sure your \ reverse proxy gives Conduit the raw uri (apache: use \ nocanon)" - ); - } - - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Failed to verify X-Matrix signatures.", - )); + ); } + + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "Failed to verify X-Matrix signatures.", + )); } } - ( - AuthScheme::None - | AuthScheme::AppserviceTokenOptional - | AuthScheme::AccessTokenOptional, - Token::None, - ) => (None, None, None, None), - (AuthScheme::ServerSignatures, Token::Appservice(_) | Token::User(_)) => { - return Err(Error::BadRequest( - ErrorKind::Unauthorized, - "Only server signatures should be used on this endpoint.", - )); - } - ( - AuthScheme::AppserviceToken | AuthScheme::AppserviceTokenOptional, - Token::User(_), - ) => { - return Err(Error::BadRequest( - ErrorKind::Unauthorized, - "Only appservice access tokens should be used on this endpoint.", - )); - } - }; + } + ( + AuthScheme::None + | AuthScheme::AppserviceTokenOptional + | AuthScheme::AccessTokenOptional, + Token::None, + ) => (None, None, None, None), + (AuthScheme::ServerSignatures, Token::Appservice(_) | Token::User(_)) => { + return Err(Error::BadRequest( + ErrorKind::Unauthorized, + "Only server signatures should be used on this endpoint.", + )); + } + (AuthScheme::AppserviceToken | AuthScheme::AppserviceTokenOptional, Token::User(_)) => { + return Err(Error::BadRequest( + ErrorKind::Unauthorized, + "Only appservice access tokens should be 
used on this endpoint.", + )); + } + }; + + let sender_ip_address = parts + .headers + .get("X-Forwarded-For") + .and_then(|header| header.to_str().ok()) + .map(|header| header.split_once(',').map(|(ip, _)| ip).unwrap_or(header)) + .and_then(|ip| IpAddr::from_str(ip).ok()); + + let target = if let Some(server_name) = sender_servername.clone() { + Some(Target::Server(server_name)) + } else if let Some(user) = &sender_user { + Some(Target::from_client_request(appservice_info.clone(), user)) + } else { + sender_ip_address.map(Target::Ip) + }; + + services().rate_limiting.check(target, metadata).await?; let mut http_request = Request::builder().uri(parts.uri).method(parts.method); *http_request.headers_mut().unwrap() = parts.headers; @@ -397,6 +437,7 @@ where sender_servername, appservice_info, json_body, + sender_ip_address, }) } } diff --git a/conduit/src/api/ruma_wrapper/mod.rs b/conduit/src/api/ruma_wrapper/mod.rs index d8866194..342253b6 100644 --- a/conduit/src/api/ruma_wrapper/mod.rs +++ b/conduit/src/api/ruma_wrapper/mod.rs @@ -3,7 +3,7 @@ use ruma::{ CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, api::client::uiaa::UiaaResponse, }; -use std::ops::Deref; +use std::{net::IpAddr, ops::Deref}; #[cfg(feature = "conduit_bin")] mod axum; @@ -14,6 +14,7 @@ pub struct Ruma { pub sender_user: Option, pub sender_device: Option, pub sender_servername: Option, + pub sender_ip_address: Option, // This is None when body is not a valid string pub json_body: Option, pub appservice_info: Option, diff --git a/conduit/src/api/server_server.rs b/conduit/src/api/server_server.rs index 4036c6ee..8317ca46 100644 --- a/conduit/src/api/server_server.rs +++ b/conduit/src/api/server_server.rs @@ -7,6 +7,7 @@ use crate::{ globals::SigningKeys, media::FileMeta, pdu::{PduBuilder, gen_event_id_canonical_json}, + rate_limiting::Target, }, services, utils, }; @@ -2244,6 +2245,13 @@ pub async fn create_invite_route( pub async fn get_content_route( body: Ruma, ) -> Result { + 
let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); + + let target = Some(Target::Server(sender_servername.to_owned())); + services() .media .check_blocked(services().globals.server_name(), &body.media_id)?; @@ -2254,7 +2262,11 @@ pub async fn get_content_route( file, }) = services() .media - .get(services().globals.server_name(), &body.media_id, true) + .get( + services().globals.server_name(), + &body.media_id, + target.clone(), + ) .await? { Ok(get_content::v1::Response::new( @@ -2276,6 +2288,13 @@ pub async fn get_content_route( pub async fn get_content_thumbnail_route( body: Ruma, ) -> Result { + let Ruma:: { + body, + sender_servername, + .. + } = body; + let sender_servername = sender_servername.expect("server is authenticated"); + services() .media .check_blocked(services().globals.server_name(), &body.media_id)?; @@ -2295,7 +2314,7 @@ pub async fn get_content_thumbnail_route( body.height .try_into() .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, - true, + Some(Target::Server(sender_servername)), ) .await? else { diff --git a/conduit/src/database/key_value/media.rs b/conduit/src/database/key_value/media.rs index 14702c17..793fcba5 100644 --- a/conduit/src/database/key_value/media.rs +++ b/conduit/src/database/key_value/media.rs @@ -204,19 +204,7 @@ impl service::media::Data for KeyValueDatabase { let is_blocked_via_filehash = self.is_blocked_filehash(&sha256_digest)?; - let time_info = if let Some(filehash_meta) = self - .filehash_metadata - .get(&sha256_digest)? 
- .map(FilehashMetadata::from_vec) - { - Some(FileInfo { - creation: filehash_meta.creation(&sha256_digest)?, - last_access: filehash_meta.last_access(&sha256_digest)?, - size: filehash_meta.size(&sha256_digest)?, - }) - } else { - None - }; + let file_info = self.file_info(&sha256_digest)?; Some(MediaQueryFileInfo { uploader_localpart, @@ -225,7 +213,7 @@ impl service::media::Data for KeyValueDatabase { content_type, unauthenticated_access_permitted, is_blocked_via_filehash, - file_info: time_info, + file_info, }) } else { None @@ -1358,6 +1346,24 @@ impl service::media::Data for KeyValueDatabase { Ok(()) } } + + fn file_info(&self, sha256_digest: &[u8]) -> Result, Error> { + Ok( + if let Some(filehash_meta) = self + .filehash_metadata + .get(sha256_digest)? + .map(FilehashMetadata::from_vec) + { + Some(FileInfo { + creation: filehash_meta.creation(sha256_digest)?, + last_access: filehash_meta.last_access(sha256_digest)?, + size: filehash_meta.size(sha256_digest)?, + }) + } else { + None + }, + ) + } } impl KeyValueDatabase { diff --git a/conduit/src/service/admin/mod.rs b/conduit/src/service/admin/mod.rs index b21c03d5..b34acbaa 100644 --- a/conduit/src/service/admin/mod.rs +++ b/conduit/src/service/admin/mod.rs @@ -42,6 +42,7 @@ use tokio::sync::{Mutex, RwLock, mpsc}; use crate::{ Error, PduEvent, Result, api::client_server::{self, AUTO_GEN_PASSWORD_LENGTH, leave_all_rooms}, + service::rate_limiting::Target, services, utils::{self, HtmlEscape}, }; @@ -1174,8 +1175,12 @@ impl Service { file, content_type, content_disposition, - } = client_server::media::get_content(server_name, media_id.to_owned(), true, true) - .await?; + } = client_server::media::get_content( + server_name, + media_id.to_owned(), + Some(Target::User(services().globals.server_user().to_owned())), + ) + .await?; if let Ok(image) = image::load_from_memory(&file) { let filename = content_disposition.and_then(|cd| cd.filename); diff --git a/conduit/src/service/media/data.rs 
b/conduit/src/service/media/data.rs index 43a0382a..ce3b3a16 100644 --- a/conduit/src/service/media/data.rs +++ b/conduit/src/service/media/data.rs @@ -2,7 +2,7 @@ use conduit_config::MediaRetentionConfig; use ruma::{OwnedServerName, ServerName, UserId}; use sha2::{Sha256, digest::Output}; -use crate::{Error, Result}; +use crate::{Error, Result, service::media::FileInfo}; use super::{ BlockedMediaInfo, DbFileMeta, MediaListItem, MediaQuery, MediaType, ServerNameOrUserId, @@ -125,4 +125,7 @@ pub trait Data: Send + Sync { fn update_last_accessed(&self, server_name: &ServerName, media_id: &str) -> Result<()>; fn update_last_accessed_filehash(&self, sha256_digest: &[u8]) -> Result<()>; + + /// Returns the known information about a file + fn file_info(&self, sha256_digest: &[u8]) -> Result>; } diff --git a/conduit/src/service/media/mod.rs b/conduit/src/service/media/mod.rs index 5b2d57b1..29d3faf2 100644 --- a/conduit/src/service/media/mod.rs +++ b/conduit/src/service/media/mod.rs @@ -16,7 +16,7 @@ use rusty_s3::{ use sha2::{Digest, Sha256, digest::Output}; use tracing::{error, info, warn}; -use crate::{Error, Result, services, utils}; +use crate::{Error, Result, service::rate_limiting::Target, services, utils}; use image::imageops::FilterType; pub struct DbFileMeta { @@ -242,7 +242,7 @@ impl Service { &self, servername: &ServerName, media_id: &str, - authenticated: bool, + target: Option, ) -> Result> { let DbFileMeta { sha256_digest, @@ -251,12 +251,19 @@ impl Service { unauthenticated_access_permitted, } = self.db.search_file_metadata(servername, media_id)?; - if !(authenticated || unauthenticated_access_permitted) { + if !(target.as_ref().is_some_and(Target::is_authenticated) + || unauthenticated_access_permitted) + { return Ok(None); } let file = self.get_file(&sha256_digest, None).await?; + services() + .rate_limiting + .check_media_download(target, size(&file)?) 
+ .await?; + Ok(Some(FileMeta { content_disposition: content_disposition(filename, &content_type), content_type, @@ -293,7 +300,7 @@ impl Service { media_id: &str, width: u32, height: u32, - authenticated: bool, + target: Option, ) -> Result> { if let Some((width, height, crop)) = self.thumbnail_properties(width, height) { if let Ok(DbFileMeta { @@ -305,10 +312,19 @@ impl Service { .db .search_thumbnail_metadata(servername, media_id, width, height) { - if !(authenticated || unauthenticated_access_permitted) { + if !(target.as_ref().is_some_and(Target::is_authenticated) + || unauthenticated_access_permitted) + { return Ok(None); } + let file_info = self.file_info(&sha256_digest)?; + + services() + .rate_limiting + .check_media_download(target, file_info.size) + .await?; + // Using saved thumbnail let file = self .get_file(&sha256_digest, Some((servername, media_id))) @@ -319,19 +335,15 @@ impl Service { content_type, file, })) - } else if !authenticated { + } else if !target.as_ref().is_some_and(Target::is_authenticated) { return Ok(None); } else if let Ok(DbFileMeta { sha256_digest, filename, content_type, - unauthenticated_access_permitted, + .. 
}) = self.db.search_file_metadata(servername, media_id) { - if !(authenticated || unauthenticated_access_permitted) { - return Ok(None); - } - let content_disposition = content_disposition(filename.clone(), &content_type); // Generate a thumbnail let file = self.get_file(&sha256_digest, None).await?; @@ -431,7 +443,9 @@ impl Service { return Ok(None); }; - if !(authenticated || unauthenticated_access_permitted) { + if !(target.as_ref().is_some_and(Target::is_authenticated) + || unauthenticated_access_permitted) + { return Ok(None); } @@ -667,6 +681,13 @@ impl Service { .update_last_accessed_filehash(sha256_digest) .map(|_| file) } + + fn file_info(&self, sha256_digest: &[u8]) -> Result { + self.db + .file_info(sha256_digest) + .transpose() + .unwrap_or_else(|| Err(Error::BadRequest(ErrorKind::NotFound, "File not found"))) + } } /// Creates the media file, using the configured media backend diff --git a/conduit/src/service/mod.rs b/conduit/src/service/mod.rs index a364d2fa..39e631a7 100644 --- a/conduit/src/service/mod.rs +++ b/conduit/src/service/mod.rs @@ -17,6 +17,7 @@ pub mod key_backups; pub mod media; pub mod pdu; pub mod pusher; +pub mod rate_limiting; pub mod rooms; pub mod sending; pub mod transaction_ids; @@ -36,6 +37,7 @@ pub struct Services { pub key_backups: key_backups::Service, pub media: Arc, pub sending: Arc, + pub rate_limiting: Arc, } impl Services { @@ -123,6 +125,8 @@ impl Services { media: Arc::new(media::Service { db }), sending: sending::Service::build(db, &config), + rate_limiting: rate_limiting::Service::build(&config), + globals: globals::Service::load(db, config)?, }) } diff --git a/conduit/src/service/rate_limiting/mod.rs b/conduit/src/service/rate_limiting/mod.rs new file mode 100644 index 00000000..46453402 --- /dev/null +++ b/conduit/src/service/rate_limiting/mod.rs @@ -0,0 +1,581 @@ +use std::{ + collections::{HashMap, hash_map::Entry}, + net::IpAddr, + sync::Arc, + time::Duration, +}; + +use conduit_config::{ + Config, + 
rate_limiting::{MediaLimitation, RequestLimitation, Restriction}, +}; +use ruma::{ + OwnedServerName, OwnedUserId, UserId, + api::{ + Metadata, + client::error::{ErrorKind, RetryAfter}, + }, +}; +use tokio::{ + sync::{Mutex, MutexGuard, RwLock}, + time::Instant, +}; + +use crate::{Error, Result, service::appservice::RegistrationInfo, services}; + +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum Target { + User(OwnedUserId), + // Server endpoints should be rate-limited on a server and room basis + Server(OwnedServerName), + Appservice { id: String, rate_limited: bool }, + Ip(IpAddr), +} + +impl Target { + pub fn from_client_request( + registration_info: Option, + sender_user: &UserId, + ) -> Self { + if let Some(info) = registration_info { + // `rate_limited` only effects "masqueraded users", "The sender [user?] is excluded" + return Target::Appservice { + id: info.registration.id, + rate_limited: info.registration.rate_limited.unwrap_or(true) + && !(sender_user.server_name() == services().globals.server_name() + && info.registration.sender_localpart == sender_user.localpart()), + }; + } + + Target::User(sender_user.to_owned()) + } + + pub fn from_client_request_optional_auth( + registration_info: Option, + sender_user: &Option, + ip_addr: Option, + ) -> Option { + if let Some(sender_user) = sender_user.as_ref() { + Some(Self::from_client_request(registration_info, sender_user)) + } else { + ip_addr.map(Self::Ip) + } + } + + fn rate_limited(&self) -> bool { + match self { + Target::User(user_id) => user_id != services().globals.server_user(), + Target::Appservice { + id: _, + rate_limited, + } => *rate_limited, + _ => true, + } + } + + pub fn is_authenticated(&self) -> bool { + !matches!(self, Target::Ip(_)) + } +} + +// NOTE: still chokes on the global mutex around the map, which in theory only needed for inserts, +// but due to rust ownership model I don't think it's possible to (easily) only require to lock +// said mutex for inserts (and 
have a rwlock for removals if memory usage ever becomes a problem). +type MediaBucket = Mutex>>>; +type GlobalMediaBucket = Arc>; + +type RequestPermittedAfter = Arc>; + +pub struct Service { + buckets: Mutex>, + global_bucket: Mutex>, + + media_upload: MediaBucket, + media_fetch: MediaBucket, + media_download: MediaBucket, + + global_media_upload: GlobalMediaBucket, + global_media_fetch: GlobalMediaBucket, + global_media_download_client: GlobalMediaBucket, + global_media_download_federation: GlobalMediaBucket, + + authentication_failures: RwLock>>>, +} + +impl Service { + pub fn build(config: &Config) -> Arc { + let now = Instant::now(); + let global_media_config = &config.rate_limiting.global; + + Arc::new(Self { + buckets: Mutex::new(HashMap::new()), + global_bucket: Mutex::new(HashMap::new()), + + media_upload: Mutex::new(HashMap::new()), + media_fetch: Mutex::new(HashMap::new()), + media_download: Mutex::new(HashMap::new()), + + global_media_upload: default_media_entry(global_media_config.client.media.upload, now), + global_media_fetch: default_media_entry(global_media_config.client.media.fetch, now), + global_media_download_client: default_media_entry( + global_media_config.client.media.download, + now, + ), + global_media_download_federation: default_media_entry( + global_media_config.federation.media.download, + now, + ), + + authentication_failures: RwLock::new(HashMap::new()), + }) + } + + //TODO: use checked and saturating arithmetic + + /// Takes the target and request, and either accepts the request while adding to the + /// bucket, or rejects the request, returning the duration that should be waited until + /// the request should be retried. 
+ pub async fn check(&self, target: Option, request: Metadata) -> Result<()> { + let Ok(restriction) = request.try_into() else { + // Endpoint has no associated restriction + return Ok(()); + }; + let arrival = Instant::now(); + + let config = services() + .globals + .config + .rate_limiting + .global + .get(&restriction); + + let mut map = self.global_bucket.lock().await; + + let entry = map.entry(restriction); + let proposed_entry = match &entry { + Entry::Occupied(occupied_entry) => { + let entry = Arc::clone(occupied_entry.get()); + let entry = entry.lock().await; + + if arrival.checked_duration_since(*entry).is_none() { + return instant_to_err(&entry); + } + + let min_instant = arrival + - Duration::from_nanos( + config.timeframe.nano_gap() * config.burst_capacity.get(), + ); + entry.max(min_instant) + Duration::from_nanos(config.timeframe.nano_gap()) + } + Entry::Vacant(_) => { + arrival + - Duration::from_nanos( + config.timeframe.nano_gap() * (config.burst_capacity.get() - 1), + ) + } + }; + + if let Some(target) = target { + let config = services() + .globals + .config + .rate_limiting + .target + .get(&restriction); + + let mut map = self.buckets.lock().await; + let entry = map.entry((target, restriction)); + match entry { + Entry::Occupied(occupied_entry) => { + let entry = Arc::clone(occupied_entry.get()); + let mut entry = entry.lock().await; + + if arrival.checked_duration_since(*entry).is_none() { + return instant_to_err(&entry); + } + + let min_instant = arrival + - Duration::from_nanos( + config.timeframe.nano_gap() * config.burst_capacity.get(), + ); + *entry = + entry.max(min_instant) + Duration::from_nanos(config.timeframe.nano_gap()); + } + Entry::Vacant(vacant_entry) => { + vacant_entry.insert(Arc::new(Mutex::new( + arrival + - Duration::from_nanos( + config.timeframe.nano_gap() * (config.burst_capacity.get() - 1), + ), + ))); + } + } + } + + entry.insert_entry(Arc::new(Mutex::new(proposed_entry))); + + Ok(()) + } + + pub async fn 
check_media_download(&self, target: Option, size: u64) -> Result<()> { + // All targets besides servers use the client-server API + let (target_limitation, global_limitation, global_bucket) = + if let Some(Target::Server(_)) = &target { + ( + services() + .globals + .config + .rate_limiting + .target + .federation + .media + .download, + services() + .globals + .config + .rate_limiting + .global + .federation + .media + .download, + &self.global_media_download_federation, + ) + } else { + ( + services() + .globals + .config + .rate_limiting + .target + .client + .media + .download, + services() + .globals + .config + .rate_limiting + .global + .client + .media + .download, + &self.global_media_download_client, + ) + }; + + check_media( + target, + size, + target_limitation, + global_limitation, + &self.media_download, + global_bucket, + ) + .await + } + + pub async fn check_media_upload(&self, target: Target, size: u64) -> Result<()> { + let target_limitation = services() + .globals + .config + .rate_limiting + .target + // Media can only be uploaded on the client-server API + .client + .media + .upload; + + let global_limitation = services() + .globals + .config + .rate_limiting + .global + // Media can only be uploaded on the client-server API + .client + .media + .upload; + + check_media( + Some(target), + size, + target_limitation, + global_limitation, + &self.media_upload, + &self.global_media_upload, + ) + .await + } + + pub async fn check_media_pre_fetch(&self, target: &Target) -> Result<()> { + if !target.rate_limited() { + return Ok(()); + } + + let arrival = Instant::now(); + + let check = async |map: &MediaBucket, global_bucket: &GlobalMediaBucket| { + let map = map.lock().await; + if let Some(mutex) = map.get(target) { + let mutex = mutex.lock().await; + + if arrival.checked_duration_since(*mutex).is_none() { + return instant_to_err(&mutex); + } + } + + let global_bucket = global_bucket.lock().await; + + if 
arrival.checked_duration_since(*global_bucket).is_none() { + return instant_to_err(&global_bucket); + } + + Ok(()) + }; + + // checking fetch + check(&self.media_fetch, &self.global_media_fetch).await?; + + // checking download as well + check(&self.media_download, &self.global_media_download_client).await + } + + /// Checks whether the ip address has been rate limited due to too many bad access tokens being sent. + pub async fn pre_auth_check(&self, ip_addr: IpAddr) -> Result<()> { + let arrival = Instant::now(); + + if let Some(instant) = self.authentication_failures.read().await.get(&ip_addr) { + let instant = instant.read().await; + + if arrival.checked_duration_since(*instant).is_none() { + return instant_to_err(&instant); + } + } + + Ok(()) + } + + /// Updates the bad auth rate limiter when a bad access token is sent where access tokens auth is an option. + pub async fn update_post_auth_failure(&self, ip_addr: IpAddr) { + let arrival = Instant::now(); + + let RequestLimitation { + timeframe, + burst_capacity, + } = services() + .globals + .config + .rate_limiting + .target + .client + .authentication_failures; + + let mut map = self.authentication_failures.write().await; + let entry = map.entry(ip_addr); + + match entry { + Entry::Occupied(occupied_entry) => { + let entry = Arc::clone(occupied_entry.get()); + let mut entry = entry.write().await; + + let min_instant = + arrival - Duration::from_nanos(timeframe.nano_gap() * burst_capacity.get()); + *entry = entry.max(min_instant) + Duration::from_nanos(timeframe.nano_gap()); + } + Entry::Vacant(vacant_entry) => { + vacant_entry.insert(Arc::new(RwLock::new( + arrival - Duration::from_nanos(burst_capacity.get() / timeframe.nano_gap()), + ))); + } + } + } + + pub async fn update_media_post_fetch(&self, target: Target, size: u64) { + if !target.rate_limited() { + return; + } + + let arrival = Instant::now(); + + let update = async |map: &MediaBucket, + target_limitation: &MediaLimitation, + global_bucket: 
&GlobalMediaBucket, + global_limitation: &MediaLimitation| { + let mut map = map.lock().await; + let entry = map.entry(target.clone()); + + match entry { + Entry::Occupied(occupied_entry) => { + let entry = Arc::clone(occupied_entry.get()); + + let _ = + update_media_entry(size, target_limitation, &arrival, entry, false).await; + } + Entry::Vacant(vacant_entry) => { + vacant_entry.insert(Arc::new(Mutex::new( + arrival + - Duration::from_nanos( + target_limitation.burst_capacity.as_u64() + / target_limitation.timeframe.bytes_per_sec(), + ), + ))); + } + } + + let _ = update_media_entry( + size, + global_limitation, + &arrival, + Arc::clone(global_bucket), + false, + ) + .await; + }; + + // updating fetch + update( + &self.media_fetch, + &services() + .globals + .config + .rate_limiting + .target + .client + .media + .fetch, + &self.global_media_fetch, + &services() + .globals + .config + .rate_limiting + .global + .client + .media + .fetch, + ) + .await; + + // updating download as well + update( + &self.media_download, + &services() + .globals + .config + .rate_limiting + .target + .client + .media + .download, + &self.global_media_download_client, + &services() + .globals + .config + .rate_limiting + .global + .client + .media + .download, + ) + .await; + } +} + +async fn update_media_entry( + size: u64, + limitation: &MediaLimitation, + arrival: &Instant, + entry: Arc>, + and_check: bool, +) -> Result<()> { + let mut entry = entry.lock().await; + + //TODO: use more precise conversion than secs + let proposed_entry = get_proposed_entry(size, limitation, arrival, &entry, and_check)?; + + *entry = proposed_entry; + + Ok(()) +} + +fn get_proposed_entry( + size: u64, + limitation: &MediaLimitation, + arrival: &Instant, + entry: &MutexGuard<'_, Instant>, + and_check: bool, +) -> Result { + let min_instant = *arrival + - Duration::from_secs( + limitation.burst_capacity.as_u64() / limitation.timeframe.bytes_per_sec(), + ); + + let proposed_entry = + 
entry.max(min_instant) + Duration::from_secs(size / limitation.timeframe.bytes_per_sec()); + + if and_check && arrival.checked_duration_since(proposed_entry).is_none() { + return instant_to_err(&proposed_entry).map(|_| proposed_entry); + } + + Ok(proposed_entry) +} + +async fn check_media( + target: Option, + size: u64, + target_limitation: MediaLimitation, + global_limitation: MediaLimitation, + target_map: &MediaBucket, + global_bucket: &GlobalMediaBucket, +) -> Result<()> { + if !target.as_ref().is_some_and(Target::rate_limited) { + return Ok(()); + } + + let arrival = Instant::now(); + + let mut global_bucket = global_bucket.lock().await; + let proposed = get_proposed_entry(size, &global_limitation, &arrival, &global_bucket, true)?; + + if let Some(target) = target { + let mut map = target_map.lock().await; + let entry = map.entry(target); + + match entry { + Entry::Occupied(occupied_entry) => { + let entry = Arc::clone(occupied_entry.get()); + + update_media_entry(size, &target_limitation, &arrival, entry, true).await?; + } + Entry::Vacant(vacant_entry) => { + vacant_entry.insert(default_media_entry(target_limitation, arrival)); + } + } + } + + *global_bucket = proposed; + + Ok(()) +} + +fn default_media_entry( + target_limitation: MediaLimitation, + arrival: Instant, +) -> Arc> { + Arc::new(Mutex::new( + arrival + - Duration::from_nanos( + target_limitation.burst_capacity.as_u64() + / target_limitation.timeframe.bytes_per_sec(), + ), + )) +} + +fn instant_to_err(instant: &Instant) -> Result<()> { + let now = Instant::now(); + + Err(Error::BadRequest( + ErrorKind::LimitExceeded { + // Not using ::DateTime because conversion from Instant to SystemTime is convoluted + retry_after: Some(RetryAfter::Delay(instant.duration_since(now))), + }, + "Rate limit exceeded", + )) +}