Browse Source

WIP: rate-limiting

rate-limiting
Matthias Ahouansou 9 months ago
parent
commit
3d48a05353
No known key found for this signature in database
  1. 2
      conduit-config/Cargo.toml
  2. 10
      conduit-config/src/lib.rs
  3. 236
      conduit-config/src/rate_limiting.rs
  4. 350
      conduit/src/api/client_server/media.rs
  5. 437
      conduit/src/api/ruma_wrapper/axum.rs
  6. 3
      conduit/src/api/ruma_wrapper/mod.rs
  7. 23
      conduit/src/api/server_server.rs
  8. 34
      conduit/src/database/key_value/media.rs
  9. 9
      conduit/src/service/admin/mod.rs
  10. 5
      conduit/src/service/media/data.rs
  11. 45
      conduit/src/service/media/mod.rs
  12. 4
      conduit/src/service/mod.rs
  13. 581
      conduit/src/service/rate_limiting/mod.rs

2
conduit-config/Cargo.toml

@ -23,7 +23,7 @@ reqwest.workspace = true
# default room version, server name, ignored keys # default room version, server name, ignored keys
[dependencies.ruma] [dependencies.ruma]
features = ["federation-api"] features = ["client-api", "federation-api"]
workspace = true workspace = true
[features] [features]

10
conduit-config/src/lib.rs

@ -17,7 +17,9 @@ use url::Url;
pub mod error; pub mod error;
mod proxy; mod proxy;
use self::proxy::ProxyConfig; pub mod rate_limiting;
use self::{proxy::ProxyConfig, rate_limiting::Config as RateLimitingConfig};
const SHA256_HEX_LENGTH: u8 = 64; const SHA256_HEX_LENGTH: u8 = 64;
@ -98,6 +100,8 @@ pub struct IncompleteConfig {
#[serde(default)] #[serde(default)]
pub media: IncompleteMediaConfig, pub media: IncompleteMediaConfig,
pub rate_limiting: RateLimitingConfig,
pub emergency_password: Option<String>, pub emergency_password: Option<String>,
#[serde(flatten)] #[serde(flatten)]
@ -147,6 +151,8 @@ pub struct Config {
pub media: MediaConfig, pub media: MediaConfig,
pub rate_limiting: RateLimitingConfig,
pub emergency_password: Option<String>, pub emergency_password: Option<String>,
pub catchall: BTreeMap<String, IgnoredAny>, pub catchall: BTreeMap<String, IgnoredAny>,
@ -194,6 +200,7 @@ impl From<IncompleteConfig> for Config {
turn_ttl, turn_ttl,
turn, turn,
media, media,
rate_limiting,
emergency_password, emergency_password,
catchall, catchall,
ignored_keys, ignored_keys,
@ -295,6 +302,7 @@ impl From<IncompleteConfig> for Config {
ip_address_detection, ip_address_detection,
turn, turn,
media, media,
rate_limiting,
emergency_password, emergency_password,
catchall, catchall,
ignored_keys, ignored_keys,

236
conduit-config/src/rate_limiting.rs

@ -0,0 +1,236 @@
use std::{collections::HashMap, num::NonZeroU64};
use bytesize::ByteSize;
use ruma::api::Metadata;
use serde::Deserialize;
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
#[serde(flatten)]
pub target: ConfigFragment,
pub global: ConfigFragment,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ConfigFragment {
pub client: ConfigClientFragment,
pub federation: ConfigFederationFragment,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ConfigClientFragment {
pub map: HashMap<ClientRestriction, RequestLimitation>,
pub media: ClientMediaConfig,
// TODO: Only have available on target, not global (same with most authenticated endpoints too maybe)?
pub authentication_failures: RequestLimitation,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ConfigFederationFragment {
pub map: HashMap<FederationRestriction, RequestLimitation>,
pub media: FederationMediaConfig,
}
impl ConfigFragment {
pub fn get(&self, restriction: &Restriction) -> &RequestLimitation {
// Maybe look into https://github.com/moriyoshi-kasuga/enum-table
match restriction {
Restriction::Client(client_restriction) => {
self.client.map.get(client_restriction).unwrap()
}
Restriction::Federation(federation_restriction) => {
self.federation.map.get(federation_restriction).unwrap()
}
}
}
}
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Restriction {
Client(ClientRestriction),
Federation(FederationRestriction),
}
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[serde(rename_all = "snake_case")]
pub enum ClientRestriction {
Registration,
Login,
RegistrationTokenValidity,
SendEvent,
Join,
Invite,
Knock,
SendReport,
CreateAlias,
MediaDownload,
MediaCreate,
}
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[serde(rename_all = "snake_case")]
pub enum FederationRestriction {
Join,
Knock,
Invite,
// Transactions should be handled by a completely dedicated rate-limiter
Transaction,
MediaDownload,
}
impl TryFrom<Metadata> for Restriction {
type Error = ();
fn try_from(value: Metadata) -> Result<Self, Self::Error> {
use Restriction::*;
use ruma::api::{
IncomingRequest,
client::{
account::{check_registration_token_validity, register},
alias::create_alias,
authenticated_media::{
get_content, get_content_as_filename, get_content_thumbnail, get_media_preview,
},
knock::knock_room,
media::{self, create_content, create_mxc_uri},
membership::{invite_user, join_room_by_id, join_room_by_id_or_alias},
message::send_message_event,
reporting::report_user,
room::{report_content, report_room},
session::login,
state::send_state_event,
},
federation::{
authenticated_media::{
get_content as federation_get_content,
get_content_thumbnail as federation_get_content_thumbnail,
},
membership::{create_invite, create_join_event, create_knock_event},
},
};
Ok(match value {
register::v3::Request::METADATA => Client(ClientRestriction::Registration),
check_registration_token_validity::v1::Request::METADATA => {
Client(ClientRestriction::RegistrationTokenValidity)
}
login::v3::Request::METADATA => Client(ClientRestriction::Login),
send_message_event::v3::Request::METADATA | send_state_event::v3::Request::METADATA => {
Client(ClientRestriction::SendEvent)
}
join_room_by_id::v3::Request::METADATA
| join_room_by_id_or_alias::v3::Request::METADATA => Client(ClientRestriction::Join),
invite_user::v3::Request::METADATA => Client(ClientRestriction::Invite),
knock_room::v3::Request::METADATA => Client(ClientRestriction::Knock),
report_user::v3::Request::METADATA
| report_content::v3::Request::METADATA
| report_room::v3::Request::METADATA => Client(ClientRestriction::SendReport),
create_alias::v3::Request::METADATA => Client(ClientRestriction::CreateAlias),
// NOTE: handle async media upload in a way that doesn't half the number of uploads you can do within a short timeframe, while not allowing pre-generation of MXC uris to allow uploading double the number of media at once
create_content::v3::Request::METADATA | create_mxc_uri::v1::Request::METADATA => {
Client(ClientRestriction::MediaCreate)
}
// Unauthenticate media is deprecated
#[allow(deprecated)]
media::get_content::v3::Request::METADATA
| media::get_content_as_filename::v3::Request::METADATA
| media::get_content_thumbnail::v3::Request::METADATA
| media::get_media_preview::v3::Request::METADATA
| get_content::v1::Request::METADATA
| get_content_as_filename::v1::Request::METADATA
| get_content_thumbnail::v1::Request::METADATA
| get_media_preview::v1::Request::METADATA => Client(ClientRestriction::MediaDownload),
federation_get_content::v1::Request::METADATA
| federation_get_content_thumbnail::v1::Request::METADATA => {
Federation(FederationRestriction::MediaDownload)
}
// v1 is deprecated
#[allow(deprecated)]
create_join_event::v1::Request::METADATA | create_join_event::v2::Request::METADATA => {
Federation(FederationRestriction::Join)
}
create_knock_event::v1::Request::METADATA => Federation(FederationRestriction::Knock),
create_invite::v1::Request::METADATA | create_invite::v2::Request::METADATA => {
Federation(FederationRestriction::Invite)
}
_ => return Err(()),
})
}
}
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct RequestLimitation {
#[serde(flatten)]
pub timeframe: Timeframe,
pub burst_capacity: NonZeroU64,
}
#[derive(Deserialize, Clone, Copy, Debug)]
#[serde(rename_all = "snake_case")]
// When deserializing, we want this prefix
#[allow(clippy::enum_variant_names)]
pub enum Timeframe {
PerSecond(NonZeroU64),
PerMinute(NonZeroU64),
PerHour(NonZeroU64),
PerDay(NonZeroU64),
}
impl Timeframe {
pub fn nano_gap(&self) -> u64 {
match self {
Timeframe::PerSecond(t) => 1000 * 1000 * 1000 / t.get(),
Timeframe::PerMinute(t) => 1000 * 1000 * 1000 * 60 / t.get(),
Timeframe::PerHour(t) => 1000 * 1000 * 1000 * 60 * 60 / t.get(),
Timeframe::PerDay(t) => 1000 * 1000 * 1000 * 60 * 60 * 24 / t.get(),
}
}
}
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct ClientMediaConfig {
pub download: MediaLimitation,
pub upload: MediaLimitation,
pub fetch: MediaLimitation,
}
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct FederationMediaConfig {
pub download: MediaLimitation,
}
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct MediaLimitation {
#[serde(flatten)]
pub timeframe: MediaTimeframe,
pub burst_capacity: ByteSize,
}
#[derive(Deserialize, Clone, Copy, Debug)]
#[serde(rename_all = "snake_case")]
// When deserializing, we want this prefix
#[allow(clippy::enum_variant_names)]
pub enum MediaTimeframe {
PerSecond(ByteSize),
PerMinute(ByteSize),
PerHour(ByteSize),
PerDay(ByteSize),
}
impl MediaTimeframe {
pub fn bytes_per_sec(&self) -> u64 {
match self {
MediaTimeframe::PerSecond(t) => t.as_u64(),
MediaTimeframe::PerMinute(t) => t.as_u64() / 60,
MediaTimeframe::PerHour(t) => t.as_u64() / (60 * 60),
MediaTimeframe::PerDay(t) => t.as_u64() / (60 * 60 * 24),
}
}
}

350
conduit/src/api/client_server/media.rs

@ -3,7 +3,14 @@
use std::time::Duration; use std::time::Duration;
use crate::{Error, Result, Ruma, service::media::FileMeta, services, utils}; use crate::{
Error, Result, Ruma,
service::{
media::{FileMeta, size},
rate_limiting::Target,
},
services, utils,
};
use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE}; use http::header::{CONTENT_DISPOSITION, CONTENT_TYPE};
use ruma::{ use ruma::{
ServerName, UInt, ServerName, UInt,
@ -54,6 +61,8 @@ pub async fn get_media_config_auth_route(
pub async fn create_content_route( pub async fn create_content_route(
body: Ruma<create_content::v3::Request>, body: Ruma<create_content::v3::Request>,
) -> Result<create_content::v3::Response> { ) -> Result<create_content::v3::Response> {
let sender_user = body.sender_user.expect("user is authenticated");
let create_content::v3::Request { let create_content::v3::Request {
filename, filename,
content_type, content_type,
@ -61,6 +70,13 @@ pub async fn create_content_route(
.. ..
} = body.body; } = body.body;
let target = Target::from_client_request(body.appservice_info, &sender_user);
services()
.rate_limiting
.check_media_upload(target, size(&file)?)
.await?;
let media_id = utils::random_string(MXC_LENGTH); let media_id = utils::random_string(MXC_LENGTH);
services() services()
@ -71,7 +87,7 @@ pub async fn create_content_route(
filename.as_deref(), filename.as_deref(),
content_type.as_deref(), content_type.as_deref(),
&file, &file,
body.sender_user.as_deref(), Some(&sender_user),
) )
.await?; .await?;
@ -84,7 +100,13 @@ pub async fn create_content_route(
pub async fn get_remote_content( pub async fn get_remote_content(
server_name: &ServerName, server_name: &ServerName,
media_id: String, media_id: String,
target: Target,
) -> Result<get_content::v1::Response, Error> { ) -> Result<get_content::v1::Response, Error> {
services()
.rate_limiting
.check_media_pre_fetch(&target)
.await?;
let content_response = match services() let content_response = match services()
.sending .sending
.send_federation_request( .send_federation_request(
@ -153,6 +175,11 @@ pub async fn get_remote_content(
) )
.await?; .await?;
services()
.rate_limiting
.update_media_post_fetch(target, size(&content_response.file)?)
.await;
Ok(content_response) Ok(content_response)
} }
@ -171,11 +198,21 @@ pub async fn get_content_route(
} = get_content( } = get_content(
&body.server_name, &body.server_name,
body.media_id.clone(), body.media_id.clone(),
body.allow_remote, body.sender_ip_address.map(Target::Ip),
false,
) )
.await?; .await?;
if let Some(target) = Target::from_client_request_optional_auth(
body.appservice_info,
&body.sender_user,
body.sender_ip_address,
) {
services()
.rate_limiting
.update_media_post_fetch(target, size(&file)?)
.await;
}
Ok(media::get_content::v3::Response { Ok(media::get_content::v3::Response {
file, file,
content_type, content_type,
@ -190,14 +227,24 @@ pub async fn get_content_route(
pub async fn get_content_auth_route( pub async fn get_content_auth_route(
body: Ruma<get_content::v1::Request>, body: Ruma<get_content::v1::Request>,
) -> Result<get_content::v1::Response> { ) -> Result<get_content::v1::Response> {
get_content(&body.server_name, body.media_id.clone(), true, true).await let Ruma::<get_content::v1::Request> {
body,
sender_user,
appservice_info,
..
} = body;
let sender_user = sender_user.as_ref().expect("user is authenticated");
let target = Target::from_client_request(appservice_info, sender_user);
get_content(&body.server_name, body.media_id.clone(), Some(target)).await
} }
pub async fn get_content( pub async fn get_content(
server_name: &ServerName, server_name: &ServerName,
media_id: String, media_id: String,
allow_remote: bool, target: Option<Target>,
authenticated: bool,
) -> Result<get_content::v1::Response, Error> { ) -> Result<get_content::v1::Response, Error> {
services().media.check_blocked(server_name, &media_id)?; services().media.check_blocked(server_name, &media_id)?;
@ -207,7 +254,7 @@ pub async fn get_content(
file, file,
})) = services() })) = services()
.media .media
.get(server_name, &media_id, authenticated) .get(server_name, &media_id, target.clone())
.await .await
{ {
Ok(get_content::v1::Response { Ok(get_content::v1::Response {
@ -215,16 +262,25 @@ pub async fn get_content(
content_type, content_type,
content_disposition: Some(content_disposition), content_disposition: Some(content_disposition),
}) })
} else if server_name != services().globals.server_name() && allow_remote && authenticated {
let remote_content_response = get_remote_content(server_name, media_id.clone()).await?;
Ok(get_content::v1::Response {
content_disposition: remote_content_response.content_disposition,
content_type: remote_content_response.content_type,
file: remote_content_response.file,
})
} else { } else {
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) let error = Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
if let Some(target) = target {
if server_name != services().globals.server_name() && target.is_authenticated() {
let remote_content_response =
get_remote_content(server_name, media_id.clone(), target).await?;
Ok(get_content::v1::Response {
content_disposition: remote_content_response.content_disposition,
content_type: remote_content_response.content_type,
file: remote_content_response.file,
})
} else {
error
}
} else {
error
}
} }
} }
@ -244,8 +300,7 @@ pub async fn get_content_as_filename_route(
&body.server_name, &body.server_name,
body.media_id.clone(), body.media_id.clone(),
body.filename.clone(), body.filename.clone(),
body.allow_remote, body.sender_ip_address.map(Target::Ip),
false,
) )
.await?; .await?;
@ -263,12 +318,22 @@ pub async fn get_content_as_filename_route(
pub async fn get_content_as_filename_auth_route( pub async fn get_content_as_filename_auth_route(
body: Ruma<get_content_as_filename::v1::Request>, body: Ruma<get_content_as_filename::v1::Request>,
) -> Result<get_content_as_filename::v1::Response, Error> { ) -> Result<get_content_as_filename::v1::Response, Error> {
let Ruma::<get_content_as_filename::v1::Request> {
body,
sender_user,
appservice_info,
..
} = body;
let sender_user = sender_user.as_ref().expect("user is authenticated");
let target = Target::from_client_request(appservice_info, sender_user);
get_content_as_filename( get_content_as_filename(
&body.server_name, &body.server_name,
body.media_id.clone(), body.media_id.clone(),
body.filename.clone(), body.filename.clone(),
true, Some(target),
true,
) )
.await .await
} }
@ -277,8 +342,7 @@ async fn get_content_as_filename(
server_name: &ServerName, server_name: &ServerName,
media_id: String, media_id: String,
filename: String, filename: String,
allow_remote: bool, target: Option<Target>,
authenticated: bool,
) -> Result<get_content_as_filename::v1::Response, Error> { ) -> Result<get_content_as_filename::v1::Response, Error> {
services().media.check_blocked(server_name, &media_id)?; services().media.check_blocked(server_name, &media_id)?;
@ -286,7 +350,7 @@ async fn get_content_as_filename(
file, content_type, .. file, content_type, ..
})) = services() })) = services()
.media .media
.get(server_name, &media_id, authenticated) .get(server_name, &media_id, target.clone())
.await .await
{ {
Ok(get_content_as_filename::v1::Response { Ok(get_content_as_filename::v1::Response {
@ -297,19 +361,28 @@ async fn get_content_as_filename(
.with_filename(Some(filename.clone())), .with_filename(Some(filename.clone())),
), ),
}) })
} else if server_name != services().globals.server_name() && allow_remote && authenticated {
let remote_content_response = get_remote_content(server_name, media_id.clone()).await?;
Ok(get_content_as_filename::v1::Response {
content_disposition: Some(
ContentDisposition::new(ContentDispositionType::Inline)
.with_filename(Some(filename.clone())),
),
content_type: remote_content_response.content_type,
file: remote_content_response.file,
})
} else { } else {
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) let error = Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
if let Some(target) = target {
if server_name != services().globals.server_name() && target.is_authenticated() {
let remote_content_response =
get_remote_content(server_name, media_id.clone(), target).await?;
Ok(get_content_as_filename::v1::Response {
content_disposition: Some(
ContentDisposition::new(ContentDispositionType::Inline)
.with_filename(Some(filename.clone())),
),
content_type: remote_content_response.content_type,
file: remote_content_response.file,
})
} else {
error
}
} else {
error
}
} }
} }
@ -321,6 +394,17 @@ async fn get_content_as_filename(
pub async fn get_content_thumbnail_route( pub async fn get_content_thumbnail_route(
body: Ruma<media::get_content_thumbnail::v3::Request>, body: Ruma<media::get_content_thumbnail::v3::Request>,
) -> Result<media::get_content_thumbnail::v3::Response> { ) -> Result<media::get_content_thumbnail::v3::Response> {
let Ruma::<media::get_content_thumbnail::v3::Request> {
body,
sender_user,
sender_ip_address,
appservice_info,
..
} = body;
let target =
Target::from_client_request_optional_auth(appservice_info, &sender_user, sender_ip_address);
let get_content_thumbnail::v1::Response { let get_content_thumbnail::v1::Response {
file, file,
content_type, content_type,
@ -332,8 +416,7 @@ pub async fn get_content_thumbnail_route(
body.width, body.width,
body.method.clone(), body.method.clone(),
body.animated, body.animated,
body.allow_remote, target,
false,
) )
.await?; .await?;
@ -351,6 +434,15 @@ pub async fn get_content_thumbnail_route(
pub async fn get_content_thumbnail_auth_route( pub async fn get_content_thumbnail_auth_route(
body: Ruma<get_content_thumbnail::v1::Request>, body: Ruma<get_content_thumbnail::v1::Request>,
) -> Result<get_content_thumbnail::v1::Response> { ) -> Result<get_content_thumbnail::v1::Response> {
let Ruma::<get_content_thumbnail::v1::Request> {
body,
sender_user,
appservice_info,
..
} = body;
let sender_user = sender_user.as_ref().expect("user is authenticated");
let target = Target::from_client_request(appservice_info, sender_user);
get_content_thumbnail( get_content_thumbnail(
&body.server_name, &body.server_name,
body.media_id.clone(), body.media_id.clone(),
@ -358,8 +450,7 @@ pub async fn get_content_thumbnail_auth_route(
body.width, body.width,
body.method.clone(), body.method.clone(),
body.animated, body.animated,
true, Some(target),
true,
) )
.await .await
} }
@ -372,8 +463,7 @@ async fn get_content_thumbnail(
width: UInt, width: UInt,
method: Option<Method>, method: Option<Method>,
animated: Option<bool>, animated: Option<bool>,
allow_remote: bool, target: Option<Target>,
authenticated: bool,
) -> Result<get_content_thumbnail::v1::Response, Error> { ) -> Result<get_content_thumbnail::v1::Response, Error> {
services().media.check_blocked(server_name, &media_id)?; services().media.check_blocked(server_name, &media_id)?;
@ -392,7 +482,7 @@ async fn get_content_thumbnail(
height height
.try_into() .try_into()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?, .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?,
authenticated, target.clone(),
) )
.await? .await?
{ {
@ -401,99 +491,117 @@ async fn get_content_thumbnail(
content_type, content_type,
content_disposition: Some(content_disposition), content_disposition: Some(content_disposition),
}) })
} else if server_name != services().globals.server_name() && allow_remote && authenticated { } else {
let thumbnail_response = match services() let error = Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
.sending
.send_federation_request(
server_name,
federation_media::get_content_thumbnail::v1::Request {
height,
width,
method: method.clone(),
media_id: media_id.clone(),
timeout_ms: Duration::from_secs(20),
animated,
},
)
.await
{
Ok(federation_media::get_content_thumbnail::v1::Response {
metadata: _,
content: FileOrLocation::File(content),
}) => get_content_thumbnail::v1::Response {
file: content.file,
content_type: content.content_type,
content_disposition: content.content_disposition,
},
Ok(federation_media::get_content_thumbnail::v1::Response { if let Some(target) = target {
metadata: _, if server_name != services().globals.server_name() {
content: FileOrLocation::Location(url), services()
}) => { .rate_limiting
let get_content::v1::Response { .check_media_pre_fetch(&target)
file, .await?;
content_type,
content_disposition, let thumbnail_response = match services()
} = get_location_content(url).await?;
get_content_thumbnail::v1::Response {
file,
content_type,
content_disposition,
}
}
Err(Error::BadRequest(ErrorKind::Unrecognized, _)) => {
let media::get_content_thumbnail::v3::Response {
file,
content_type,
content_disposition,
..
} = services()
.sending .sending
.send_federation_request( .send_federation_request(
server_name, server_name,
media::get_content_thumbnail::v3::Request { federation_media::get_content_thumbnail::v1::Request {
height, height,
width, width,
method: method.clone(), method: method.clone(),
server_name: server_name.to_owned(),
media_id: media_id.clone(), media_id: media_id.clone(),
timeout_ms: Duration::from_secs(20), timeout_ms: Duration::from_secs(20),
allow_redirect: false,
animated, animated,
allow_remote: false,
}, },
) )
.await
{
Ok(federation_media::get_content_thumbnail::v1::Response {
metadata: _,
content: FileOrLocation::File(content),
}) => get_content_thumbnail::v1::Response {
file: content.file,
content_type: content.content_type,
content_disposition: content.content_disposition,
},
Ok(federation_media::get_content_thumbnail::v1::Response {
metadata: _,
content: FileOrLocation::Location(url),
}) => {
let get_content::v1::Response {
file,
content_type,
content_disposition,
} = get_location_content(url).await?;
get_content_thumbnail::v1::Response {
file,
content_type,
content_disposition,
}
}
Err(Error::BadRequest(ErrorKind::Unrecognized, _)) => {
let media::get_content_thumbnail::v3::Response {
file,
content_type,
content_disposition,
..
} = services()
.sending
.send_federation_request(
server_name,
media::get_content_thumbnail::v3::Request {
height,
width,
method: method.clone(),
server_name: server_name.to_owned(),
media_id: media_id.clone(),
timeout_ms: Duration::from_secs(20),
allow_redirect: false,
animated,
allow_remote: false,
},
)
.await?;
get_content_thumbnail::v1::Response {
file,
content_type,
content_disposition,
}
}
Err(e) => return Err(e),
};
services()
.rate_limiting
.update_media_post_fetch(target, size(&thumbnail_response.file)?)
.await;
services()
.media
.upload_thumbnail(
server_name,
&media_id,
thumbnail_response
.content_disposition
.as_ref()
.and_then(|cd| cd.filename.as_deref()),
thumbnail_response.content_type.as_deref(),
width.try_into().expect("all UInts are valid u32s"),
height.try_into().expect("all UInts are valid u32s"),
&thumbnail_response.file,
)
.await?; .await?;
get_content_thumbnail::v1::Response { Ok(thumbnail_response)
file, } else {
content_type, error
content_disposition,
}
} }
Err(e) => return Err(e), } else {
}; error
}
services()
.media
.upload_thumbnail(
server_name,
&media_id,
thumbnail_response
.content_disposition
.as_ref()
.and_then(|cd| cd.filename.as_deref()),
thumbnail_response.content_type.as_deref(),
width.try_into().expect("all UInts are valid u32s"),
height.try_into().expect("all UInts are valid u32s"),
&thumbnail_response.file,
)
.await?;
Ok(thumbnail_response)
} else {
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
} }
} }

437
conduit/src/api/ruma_wrapper/axum.rs

@ -3,7 +3,7 @@ use std::{
error::Error as _, error::Error as _,
iter::FromIterator, iter::FromIterator,
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
str, str::{self, FromStr},
}; };
use axum::{ use axum::{
@ -18,6 +18,7 @@ use axum_extra::{
typed_header::TypedHeaderRejectionReason, typed_header::TypedHeaderRejectionReason,
}; };
use bytes::{BufMut, BytesMut}; use bytes::{BufMut, BytesMut};
use conduit_config::IpAddrDetection;
use http::{Request, StatusCode}; use http::{Request, StatusCode};
use ruma::{ use ruma::{
CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId, CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId,
@ -31,12 +32,15 @@ use tracing::{debug, error, warn};
use super::{Ruma, RumaResponse}; use super::{Ruma, RumaResponse};
use crate::{ use crate::{
Error, Result, config::IpAddrDetection, service::appservice::RegistrationInfo, services, Error, Result,
service::{appservice::RegistrationInfo, rate_limiting::Target},
services,
}; };
enum Token { enum Token {
Appservice(Box<RegistrationInfo>), Appservice(Box<RegistrationInfo>),
User((OwnedUserId, OwnedDeviceId)), User((OwnedUserId, OwnedDeviceId)),
AuthRateLimited(Error),
Invalid, Invalid,
None, None,
} }
@ -122,7 +126,17 @@ where
}; };
let token = if let Some(token) = token { let token = if let Some(token) = token {
if let Some(reg_info) = services().appservice.find_from_token(token).await { let mut rate_limited = None;
if let Some(ip_addr) = sender_ip_address {
if let Err(instant) = services().rate_limiting.pre_auth_check(ip_addr).await {
rate_limited = Some(instant);
}
}
if let Some(instant) = rate_limited {
Token::AuthRateLimited(instant)
} else if let Some(reg_info) = services().appservice.find_from_token(token).await {
Token::Appservice(Box::new(reg_info.clone())) Token::Appservice(Box::new(reg_info.clone()))
} else if let Some((user_id, device_id)) = services().users.find_from_token(token)? { } else if let Some((user_id, device_id)) = services().users.find_from_token(token)? {
Token::User((user_id, device_id)) Token::User((user_id, device_id))
@ -135,217 +149,243 @@ where
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok(); let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
let (sender_user, sender_device, sender_servername, appservice_info) = let (sender_user, sender_device, sender_servername, appservice_info) = match (
match (metadata.authentication, token) { metadata.authentication,
(_, Token::Invalid) => { token,
// OpenID endpoint uses a query param with the same name, drop this once query params for user auth are removed from the spec ) {
if query_params.access_token.is_some() { (_, Token::AuthRateLimited(instant)) => {
(None, None, None, None) return Err(instant);
}
(_, Token::Invalid) => {
// OpenID endpoint uses a query param with the same name, drop this once query params for user auth are removed from the spec
if query_params.access_token.is_some() {
(None, None, None, None)
} else {
if let Some(addr) = sender_ip_address {
services()
.rate_limiting
.update_post_auth_failure(addr)
.await;
} else { } else {
return Err(Error::BadRequest( error!(
ErrorKind::UnknownToken { soft_logout: false }, "Auth failure occurred, but IP address was not extracted. Please check your Conduit & reverse proxy configuration, as if nothing is done, an attacker can brute-force access tokens and login to user's accounts"
"Unknown access token.", );
));
}
}
(AuthScheme::AccessToken, Token::Appservice(info)) => {
let user_id = query_params
.user_id
.map_or_else(
|| {
UserId::parse_with_server_name(
info.registration.sender_localpart.as_str(),
services().globals.server_name(),
)
},
UserId::parse,
)
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
})?;
if !info.is_user_match(&user_id) {
return Err(Error::BadRequest(
ErrorKind::Exclusive,
"User is not in namespace.",
));
} }
if !services().users.exists(&user_id)? { return Err(Error::BadRequest(
return Err(Error::BadRequest( ErrorKind::UnknownToken { soft_logout: false },
ErrorKind::forbidden(), "Unknown access token.",
"User does not exist.", ));
)); }
} }
(AuthScheme::AccessToken, Token::Appservice(info)) => {
let user_id = query_params
.user_id
.map_or_else(
|| {
UserId::parse_with_server_name(
info.registration.sender_localpart.as_str(),
services().globals.server_name(),
)
},
UserId::parse,
)
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
})?;
(Some(user_id), None, None, Some(*info)) if !info.is_user_match(&user_id) {
return Err(Error::BadRequest(
ErrorKind::Exclusive,
"User is not in namespace.",
));
} }
(
AuthScheme::None if !services().users.exists(&user_id)? {
| AuthScheme::AppserviceToken
| AuthScheme::AppserviceTokenOptional
| AuthScheme::AccessTokenOptional,
Token::Appservice(info),
) => (None, None, None, Some(*info)),
(AuthScheme::AppserviceToken | AuthScheme::AccessToken, Token::None) => {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::MissingToken, ErrorKind::forbidden(),
"Missing access token.", "User does not exist.",
)); ));
} }
(
AuthScheme::AccessToken | AuthScheme::AccessTokenOptional | AuthScheme::None,
Token::User((user_id, device_id)),
) => (Some(user_id), Some(device_id), None, None),
(AuthScheme::ServerSignatures, Token::None) => {
let TypedHeader(Authorization(x_matrix)) = parts
.extract::<TypedHeader<Authorization<XMatrix>>>()
.await
.map_err(|e| {
warn!("Missing or invalid Authorization header: {}", e);
let msg = match e.reason() {
TypedHeaderRejectionReason::Missing => {
"Missing Authorization header."
}
TypedHeaderRejectionReason::Error(_) => {
"Invalid X-Matrix signatures."
}
_ => "Unknown header-related error",
};
Error::BadRequest(ErrorKind::forbidden(), msg)
})?;
if let Some(dest) = x_matrix.destination {
if dest != services().globals.server_name() {
return Err(Error::BadRequest(
ErrorKind::Unauthorized,
"X-Matrix destination field does not match server name.",
));
}
};
let origin_signatures = BTreeMap::from_iter([(
x_matrix.key.clone(),
CanonicalJsonValue::String(x_matrix.sig.to_string()),
)]);
let signatures = BTreeMap::from_iter([(
x_matrix.origin.as_str().to_owned(),
CanonicalJsonValue::Object(
origin_signatures
.into_iter()
.map(|(k, v)| (k.to_string(), v))
.collect(),
),
)]);
let mut request_map = BTreeMap::from_iter([ (Some(user_id), None, None, Some(*info))
( }
"method".to_owned(), (
CanonicalJsonValue::String(parts.method.to_string()), AuthScheme::None
), | AuthScheme::AppserviceToken
( | AuthScheme::AppserviceTokenOptional
"uri".to_owned(), | AuthScheme::AccessTokenOptional,
CanonicalJsonValue::String(parts.uri.to_string()), Token::Appservice(info),
), ) => (None, None, None, Some(*info)),
( (AuthScheme::AppserviceToken | AuthScheme::AccessToken, Token::None) => {
"origin".to_owned(), return Err(Error::BadRequest(
CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()), ErrorKind::MissingToken,
), "Missing access token.",
( ));
"destination".to_owned(), }
CanonicalJsonValue::String( (
services().globals.server_name().as_str().to_owned(), AuthScheme::AccessToken | AuthScheme::AccessTokenOptional | AuthScheme::None,
), Token::User((user_id, device_id)),
), ) => (Some(user_id), Some(device_id), None, None),
( (AuthScheme::ServerSignatures, Token::None) => {
"signatures".to_owned(), let TypedHeader(Authorization(x_matrix)) = parts
CanonicalJsonValue::Object(signatures), .extract::<TypedHeader<Authorization<XMatrix>>>()
.await
.map_err(|e| {
warn!("Missing or invalid Authorization header: {}", e);
let msg = match e.reason() {
TypedHeaderRejectionReason::Missing => "Missing Authorization header.",
TypedHeaderRejectionReason::Error(_) => "Invalid X-Matrix signatures.",
_ => "Unknown header-related error",
};
Error::BadRequest(ErrorKind::forbidden(), msg)
})?;
if let Some(dest) = x_matrix.destination {
if dest != services().globals.server_name() {
return Err(Error::BadRequest(
ErrorKind::Unauthorized,
"X-Matrix destination field does not match server name.",
));
}
};
let origin_signatures = BTreeMap::from_iter([(
x_matrix.key.clone(),
CanonicalJsonValue::String(x_matrix.sig.to_string()),
)]);
let signatures = BTreeMap::from_iter([(
x_matrix.origin.as_str().to_owned(),
CanonicalJsonValue::Object(
origin_signatures
.into_iter()
.map(|(k, v)| (k.to_string(), v))
.collect(),
),
)]);
let mut request_map = BTreeMap::from_iter([
(
"method".to_owned(),
CanonicalJsonValue::String(parts.method.to_string()),
),
(
"uri".to_owned(),
CanonicalJsonValue::String(parts.uri.to_string()),
),
(
"origin".to_owned(),
CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()),
),
(
"destination".to_owned(),
CanonicalJsonValue::String(
services().globals.server_name().as_str().to_owned(),
), ),
]); ),
(
if let Some(json_body) = &json_body { "signatures".to_owned(),
request_map.insert("content".to_owned(), json_body.clone()); CanonicalJsonValue::Object(signatures),
}; ),
]);
let keys_result = services()
.rooms if let Some(json_body) = &json_body {
.event_handler request_map.insert("content".to_owned(), json_body.clone());
.fetch_signing_keys(&x_matrix.origin, vec![x_matrix.key.to_string()], false) };
.await;
let keys_result = services()
let keys = match keys_result { .rooms
Ok(b) => b, .event_handler
Err(e) => { .fetch_signing_keys(&x_matrix.origin, vec![x_matrix.key.to_string()], false)
warn!("Failed to fetch signing keys: {}", e); .await;
return Err(Error::BadRequest(
ErrorKind::forbidden(), let keys = match keys_result {
"Failed to fetch signing keys.", Ok(b) => b,
)); Err(e) => {
} warn!("Failed to fetch signing keys: {}", e);
}; return Err(Error::BadRequest(
ErrorKind::forbidden(),
// Only verify_keys that are currently valid should be used for validating requests "Failed to fetch signing keys.",
// as per MSC4029 ));
let pub_key_map = BTreeMap::from_iter([( }
x_matrix.origin.as_str().to_owned(), };
if keys.valid_until_ts > MilliSecondsSinceUnixEpoch::now() {
keys.verify_keys // Only verify_keys that are currently valid should be used for validating requests
.into_iter() // as per MSC4029
.map(|(id, key)| (id, key.key)) let pub_key_map = BTreeMap::from_iter([(
.collect() x_matrix.origin.as_str().to_owned(),
} else { if keys.valid_until_ts > MilliSecondsSinceUnixEpoch::now() {
BTreeMap::new() keys.verify_keys
}, .into_iter()
)]); .map(|(id, key)| (id, key.key))
.collect()
match ruma::signatures::verify_json(&pub_key_map, &request_map) { } else {
Ok(()) => (None, None, Some(x_matrix.origin), None), BTreeMap::new()
Err(e) => { },
)]);
match ruma::signatures::verify_json(&pub_key_map, &request_map) {
Ok(()) => (None, None, Some(x_matrix.origin), None),
Err(e) => {
warn!(
"Failed to verify json request from {}: {}\n{:?}",
x_matrix.origin, e, request_map
);
if parts.uri.to_string().contains('@') {
warn!( warn!(
"Failed to verify json request from {}: {}\n{:?}", "Request uri contained '@' character. Make sure your \
x_matrix.origin, e, request_map
);
if parts.uri.to_string().contains('@') {
warn!(
"Request uri contained '@' character. Make sure your \
reverse proxy gives Conduit the raw uri (apache: use \ reverse proxy gives Conduit the raw uri (apache: use \
nocanon)" nocanon)"
); );
}
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"Failed to verify X-Matrix signatures.",
));
} }
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"Failed to verify X-Matrix signatures.",
));
} }
} }
( }
AuthScheme::None (
| AuthScheme::AppserviceTokenOptional AuthScheme::None
| AuthScheme::AccessTokenOptional, | AuthScheme::AppserviceTokenOptional
Token::None, | AuthScheme::AccessTokenOptional,
) => (None, None, None, None), Token::None,
(AuthScheme::ServerSignatures, Token::Appservice(_) | Token::User(_)) => { ) => (None, None, None, None),
return Err(Error::BadRequest( (AuthScheme::ServerSignatures, Token::Appservice(_) | Token::User(_)) => {
ErrorKind::Unauthorized, return Err(Error::BadRequest(
"Only server signatures should be used on this endpoint.", ErrorKind::Unauthorized,
)); "Only server signatures should be used on this endpoint.",
} ));
( }
AuthScheme::AppserviceToken | AuthScheme::AppserviceTokenOptional, (AuthScheme::AppserviceToken | AuthScheme::AppserviceTokenOptional, Token::User(_)) => {
Token::User(_), return Err(Error::BadRequest(
) => { ErrorKind::Unauthorized,
return Err(Error::BadRequest( "Only appservice access tokens should be used on this endpoint.",
ErrorKind::Unauthorized, ));
"Only appservice access tokens should be used on this endpoint.", }
)); };
}
}; let sender_ip_address = parts
.headers
.get("X-Forwarded-For")
.and_then(|header| header.to_str().ok())
.map(|header| header.split_once(',').map(|(ip, _)| ip).unwrap_or(header))
.and_then(|ip| IpAddr::from_str(ip).ok());
let target = if let Some(server_name) = sender_servername.clone() {
Some(Target::Server(server_name))
} else if let Some(user) = &sender_user {
Some(Target::from_client_request(appservice_info.clone(), user))
} else {
sender_ip_address.map(Target::Ip)
};
services().rate_limiting.check(target, metadata).await?;
let mut http_request = Request::builder().uri(parts.uri).method(parts.method); let mut http_request = Request::builder().uri(parts.uri).method(parts.method);
*http_request.headers_mut().unwrap() = parts.headers; *http_request.headers_mut().unwrap() = parts.headers;
@ -397,6 +437,7 @@ where
sender_servername, sender_servername,
appservice_info, appservice_info,
json_body, json_body,
sender_ip_address,
}) })
} }
} }

3
conduit/src/api/ruma_wrapper/mod.rs

@ -3,7 +3,7 @@ use ruma::{
CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId,
api::client::uiaa::UiaaResponse, api::client::uiaa::UiaaResponse,
}; };
use std::ops::Deref; use std::{net::IpAddr, ops::Deref};
#[cfg(feature = "conduit_bin")] #[cfg(feature = "conduit_bin")]
mod axum; mod axum;
@ -14,6 +14,7 @@ pub struct Ruma<T> {
pub sender_user: Option<OwnedUserId>, pub sender_user: Option<OwnedUserId>,
pub sender_device: Option<OwnedDeviceId>, pub sender_device: Option<OwnedDeviceId>,
pub sender_servername: Option<OwnedServerName>, pub sender_servername: Option<OwnedServerName>,
pub sender_ip_address: Option<IpAddr>,
// This is None when body is not a valid string // This is None when body is not a valid string
pub json_body: Option<CanonicalJsonValue>, pub json_body: Option<CanonicalJsonValue>,
pub appservice_info: Option<RegistrationInfo>, pub appservice_info: Option<RegistrationInfo>,

23
conduit/src/api/server_server.rs

@ -7,6 +7,7 @@ use crate::{
globals::SigningKeys, globals::SigningKeys,
media::FileMeta, media::FileMeta,
pdu::{PduBuilder, gen_event_id_canonical_json}, pdu::{PduBuilder, gen_event_id_canonical_json},
rate_limiting::Target,
}, },
services, utils, services, utils,
}; };
@ -2244,6 +2245,13 @@ pub async fn create_invite_route(
pub async fn get_content_route( pub async fn get_content_route(
body: Ruma<get_content::v1::Request>, body: Ruma<get_content::v1::Request>,
) -> Result<get_content::v1::Response> { ) -> Result<get_content::v1::Response> {
let sender_servername = body
.sender_servername
.as_ref()
.expect("server is authenticated");
let target = Some(Target::Server(sender_servername.to_owned()));
services() services()
.media .media
.check_blocked(services().globals.server_name(), &body.media_id)?; .check_blocked(services().globals.server_name(), &body.media_id)?;
@ -2254,7 +2262,11 @@ pub async fn get_content_route(
file, file,
}) = services() }) = services()
.media .media
.get(services().globals.server_name(), &body.media_id, true) .get(
services().globals.server_name(),
&body.media_id,
target.clone(),
)
.await? .await?
{ {
Ok(get_content::v1::Response::new( Ok(get_content::v1::Response::new(
@ -2276,6 +2288,13 @@ pub async fn get_content_route(
pub async fn get_content_thumbnail_route( pub async fn get_content_thumbnail_route(
body: Ruma<get_content_thumbnail::v1::Request>, body: Ruma<get_content_thumbnail::v1::Request>,
) -> Result<get_content_thumbnail::v1::Response> { ) -> Result<get_content_thumbnail::v1::Response> {
let Ruma::<get_content_thumbnail::v1::Request> {
body,
sender_servername,
..
} = body;
let sender_servername = sender_servername.expect("server is authenticated");
services() services()
.media .media
.check_blocked(services().globals.server_name(), &body.media_id)?; .check_blocked(services().globals.server_name(), &body.media_id)?;
@ -2295,7 +2314,7 @@ pub async fn get_content_thumbnail_route(
body.height body.height
.try_into() .try_into()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
true, Some(Target::Server(sender_servername)),
) )
.await? .await?
else { else {

34
conduit/src/database/key_value/media.rs

@ -204,19 +204,7 @@ impl service::media::Data for KeyValueDatabase {
let is_blocked_via_filehash = self.is_blocked_filehash(&sha256_digest)?; let is_blocked_via_filehash = self.is_blocked_filehash(&sha256_digest)?;
let time_info = if let Some(filehash_meta) = self let file_info = self.file_info(&sha256_digest)?;
.filehash_metadata
.get(&sha256_digest)?
.map(FilehashMetadata::from_vec)
{
Some(FileInfo {
creation: filehash_meta.creation(&sha256_digest)?,
last_access: filehash_meta.last_access(&sha256_digest)?,
size: filehash_meta.size(&sha256_digest)?,
})
} else {
None
};
Some(MediaQueryFileInfo { Some(MediaQueryFileInfo {
uploader_localpart, uploader_localpart,
@ -225,7 +213,7 @@ impl service::media::Data for KeyValueDatabase {
content_type, content_type,
unauthenticated_access_permitted, unauthenticated_access_permitted,
is_blocked_via_filehash, is_blocked_via_filehash,
file_info: time_info, file_info,
}) })
} else { } else {
None None
@ -1358,6 +1346,24 @@ impl service::media::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
} }
fn file_info(&self, sha256_digest: &[u8]) -> Result<Option<FileInfo>, Error> {
Ok(
if let Some(filehash_meta) = self
.filehash_metadata
.get(sha256_digest)?
.map(FilehashMetadata::from_vec)
{
Some(FileInfo {
creation: filehash_meta.creation(sha256_digest)?,
last_access: filehash_meta.last_access(sha256_digest)?,
size: filehash_meta.size(sha256_digest)?,
})
} else {
None
},
)
}
} }
impl KeyValueDatabase { impl KeyValueDatabase {

9
conduit/src/service/admin/mod.rs

@ -42,6 +42,7 @@ use tokio::sync::{Mutex, RwLock, mpsc};
use crate::{ use crate::{
Error, PduEvent, Result, Error, PduEvent, Result,
api::client_server::{self, AUTO_GEN_PASSWORD_LENGTH, leave_all_rooms}, api::client_server::{self, AUTO_GEN_PASSWORD_LENGTH, leave_all_rooms},
service::rate_limiting::Target,
services, services,
utils::{self, HtmlEscape}, utils::{self, HtmlEscape},
}; };
@ -1174,8 +1175,12 @@ impl Service {
file, file,
content_type, content_type,
content_disposition, content_disposition,
} = client_server::media::get_content(server_name, media_id.to_owned(), true, true) } = client_server::media::get_content(
.await?; server_name,
media_id.to_owned(),
Some(Target::User(services().globals.server_user().to_owned())),
)
.await?;
if let Ok(image) = image::load_from_memory(&file) { if let Ok(image) = image::load_from_memory(&file) {
let filename = content_disposition.and_then(|cd| cd.filename); let filename = content_disposition.and_then(|cd| cd.filename);

5
conduit/src/service/media/data.rs

@ -2,7 +2,7 @@ use conduit_config::MediaRetentionConfig;
use ruma::{OwnedServerName, ServerName, UserId}; use ruma::{OwnedServerName, ServerName, UserId};
use sha2::{Sha256, digest::Output}; use sha2::{Sha256, digest::Output};
use crate::{Error, Result}; use crate::{Error, Result, service::media::FileInfo};
use super::{ use super::{
BlockedMediaInfo, DbFileMeta, MediaListItem, MediaQuery, MediaType, ServerNameOrUserId, BlockedMediaInfo, DbFileMeta, MediaListItem, MediaQuery, MediaType, ServerNameOrUserId,
@ -125,4 +125,7 @@ pub trait Data: Send + Sync {
fn update_last_accessed(&self, server_name: &ServerName, media_id: &str) -> Result<()>; fn update_last_accessed(&self, server_name: &ServerName, media_id: &str) -> Result<()>;
fn update_last_accessed_filehash(&self, sha256_digest: &[u8]) -> Result<()>; fn update_last_accessed_filehash(&self, sha256_digest: &[u8]) -> Result<()>;
/// Returns the known information about a file
fn file_info(&self, sha256_digest: &[u8]) -> Result<Option<FileInfo>>;
} }

45
conduit/src/service/media/mod.rs

@ -16,7 +16,7 @@ use rusty_s3::{
use sha2::{Digest, Sha256, digest::Output}; use sha2::{Digest, Sha256, digest::Output};
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use crate::{Error, Result, services, utils}; use crate::{Error, Result, service::rate_limiting::Target, services, utils};
use image::imageops::FilterType; use image::imageops::FilterType;
pub struct DbFileMeta { pub struct DbFileMeta {
@ -242,7 +242,7 @@ impl Service {
&self, &self,
servername: &ServerName, servername: &ServerName,
media_id: &str, media_id: &str,
authenticated: bool, target: Option<Target>,
) -> Result<Option<FileMeta>> { ) -> Result<Option<FileMeta>> {
let DbFileMeta { let DbFileMeta {
sha256_digest, sha256_digest,
@ -251,12 +251,19 @@ impl Service {
unauthenticated_access_permitted, unauthenticated_access_permitted,
} = self.db.search_file_metadata(servername, media_id)?; } = self.db.search_file_metadata(servername, media_id)?;
if !(authenticated || unauthenticated_access_permitted) { if !(target.as_ref().is_some_and(Target::is_authenticated)
|| unauthenticated_access_permitted)
{
return Ok(None); return Ok(None);
} }
let file = self.get_file(&sha256_digest, None).await?; let file = self.get_file(&sha256_digest, None).await?;
services()
.rate_limiting
.check_media_download(target, size(&file)?)
.await?;
Ok(Some(FileMeta { Ok(Some(FileMeta {
content_disposition: content_disposition(filename, &content_type), content_disposition: content_disposition(filename, &content_type),
content_type, content_type,
@ -293,7 +300,7 @@ impl Service {
media_id: &str, media_id: &str,
width: u32, width: u32,
height: u32, height: u32,
authenticated: bool, target: Option<Target>,
) -> Result<Option<FileMeta>> { ) -> Result<Option<FileMeta>> {
if let Some((width, height, crop)) = self.thumbnail_properties(width, height) { if let Some((width, height, crop)) = self.thumbnail_properties(width, height) {
if let Ok(DbFileMeta { if let Ok(DbFileMeta {
@ -305,10 +312,19 @@ impl Service {
.db .db
.search_thumbnail_metadata(servername, media_id, width, height) .search_thumbnail_metadata(servername, media_id, width, height)
{ {
if !(authenticated || unauthenticated_access_permitted) { if !(target.as_ref().is_some_and(Target::is_authenticated)
|| unauthenticated_access_permitted)
{
return Ok(None); return Ok(None);
} }
let file_info = self.file_info(&sha256_digest)?;
services()
.rate_limiting
.check_media_download(target, file_info.size)
.await?;
// Using saved thumbnail // Using saved thumbnail
let file = self let file = self
.get_file(&sha256_digest, Some((servername, media_id))) .get_file(&sha256_digest, Some((servername, media_id)))
@ -319,19 +335,15 @@ impl Service {
content_type, content_type,
file, file,
})) }))
} else if !authenticated { } else if !target.as_ref().is_some_and(Target::is_authenticated) {
return Ok(None); return Ok(None);
} else if let Ok(DbFileMeta { } else if let Ok(DbFileMeta {
sha256_digest, sha256_digest,
filename, filename,
content_type, content_type,
unauthenticated_access_permitted, ..
}) = self.db.search_file_metadata(servername, media_id) }) = self.db.search_file_metadata(servername, media_id)
{ {
if !(authenticated || unauthenticated_access_permitted) {
return Ok(None);
}
let content_disposition = content_disposition(filename.clone(), &content_type); let content_disposition = content_disposition(filename.clone(), &content_type);
// Generate a thumbnail // Generate a thumbnail
let file = self.get_file(&sha256_digest, None).await?; let file = self.get_file(&sha256_digest, None).await?;
@ -431,7 +443,9 @@ impl Service {
return Ok(None); return Ok(None);
}; };
if !(authenticated || unauthenticated_access_permitted) { if !(target.as_ref().is_some_and(Target::is_authenticated)
|| unauthenticated_access_permitted)
{
return Ok(None); return Ok(None);
} }
@ -667,6 +681,13 @@ impl Service {
.update_last_accessed_filehash(sha256_digest) .update_last_accessed_filehash(sha256_digest)
.map(|_| file) .map(|_| file)
} }
fn file_info(&self, sha256_digest: &[u8]) -> Result<FileInfo> {
self.db
.file_info(sha256_digest)
.transpose()
.unwrap_or_else(|| Err(Error::BadRequest(ErrorKind::NotFound, "Fi)le not found")))
}
} }
/// Creates the media file, using the configured media backend /// Creates the media file, using the configured media backend

4
conduit/src/service/mod.rs

@ -17,6 +17,7 @@ pub mod key_backups;
pub mod media; pub mod media;
pub mod pdu; pub mod pdu;
pub mod pusher; pub mod pusher;
pub mod rate_limiting;
pub mod rooms; pub mod rooms;
pub mod sending; pub mod sending;
pub mod transaction_ids; pub mod transaction_ids;
@ -36,6 +37,7 @@ pub struct Services {
pub key_backups: key_backups::Service, pub key_backups: key_backups::Service,
pub media: Arc<media::Service>, pub media: Arc<media::Service>,
pub sending: Arc<sending::Service>, pub sending: Arc<sending::Service>,
pub rate_limiting: Arc<rate_limiting::Service>,
} }
impl Services { impl Services {
@ -123,6 +125,8 @@ impl Services {
media: Arc::new(media::Service { db }), media: Arc::new(media::Service { db }),
sending: sending::Service::build(db, &config), sending: sending::Service::build(db, &config),
rate_limiting: rate_limiting::Service::build(&config),
globals: globals::Service::load(db, config)?, globals: globals::Service::load(db, config)?,
}) })
} }

581
conduit/src/service/rate_limiting/mod.rs

@ -0,0 +1,581 @@
use std::{
collections::{HashMap, hash_map::Entry},
net::IpAddr,
sync::Arc,
time::Duration,
};
use conduit_config::{
Config,
rate_limiting::{MediaLimitation, RequestLimitation, Restriction},
};
use ruma::{
OwnedServerName, OwnedUserId, UserId,
api::{
Metadata,
client::error::{ErrorKind, RetryAfter},
},
};
use tokio::{
sync::{Mutex, MutexGuard, RwLock},
time::Instant,
};
use crate::{Error, Result, service::appservice::RegistrationInfo, services};
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Target {
User(OwnedUserId),
// Server endpoints should be rate-limited on a server and room basis
Server(OwnedServerName),
Appservice { id: String, rate_limited: bool },
Ip(IpAddr),
}
impl Target {
pub fn from_client_request(
registration_info: Option<RegistrationInfo>,
sender_user: &UserId,
) -> Self {
if let Some(info) = registration_info {
// `rate_limited` only effects "masqueraded users", "The sender [user?] is excluded"
return Target::Appservice {
id: info.registration.id,
rate_limited: info.registration.rate_limited.unwrap_or(true)
&& !(sender_user.server_name() == services().globals.server_name()
&& info.registration.sender_localpart == sender_user.localpart()),
};
}
Target::User(sender_user.to_owned())
}
pub fn from_client_request_optional_auth(
registration_info: Option<RegistrationInfo>,
sender_user: &Option<OwnedUserId>,
ip_addr: Option<IpAddr>,
) -> Option<Self> {
if let Some(sender_user) = sender_user.as_ref() {
Some(Self::from_client_request(registration_info, sender_user))
} else {
ip_addr.map(Self::Ip)
}
}
fn rate_limited(&self) -> bool {
match self {
Target::User(user_id) => user_id != services().globals.server_user(),
Target::Appservice {
id: _,
rate_limited,
} => *rate_limited,
_ => true,
}
}
pub fn is_authenticated(&self) -> bool {
!matches!(self, Target::Ip(_))
}
}
// NOTE: still chokes on the global mutex around the map, which in theory only needed for inserts,
// but due to rust ownership model I don't think it's possible to (easily) only require to lock
// said mutex for inserts (and have a rwlock for removals if memory usage ever becomes a problem).
type MediaBucket = Mutex<HashMap<Target, Arc<Mutex<Instant>>>>;
type GlobalMediaBucket = Arc<Mutex<Instant>>;
type RequestPermittedAfter = Arc<Mutex<Instant>>;
pub struct Service {
buckets: Mutex<HashMap<(Target, Restriction), RequestPermittedAfter>>,
global_bucket: Mutex<HashMap<Restriction, RequestPermittedAfter>>,
media_upload: MediaBucket,
media_fetch: MediaBucket,
media_download: MediaBucket,
global_media_upload: GlobalMediaBucket,
global_media_fetch: GlobalMediaBucket,
global_media_download_client: GlobalMediaBucket,
global_media_download_federation: GlobalMediaBucket,
authentication_failures: RwLock<HashMap<IpAddr, Arc<RwLock<Instant>>>>,
}
impl Service {
pub fn build(config: &Config) -> Arc<Self> {
let now = Instant::now();
let global_media_config = &config.rate_limiting.global;
Arc::new(Self {
buckets: Mutex::new(HashMap::new()),
global_bucket: Mutex::new(HashMap::new()),
media_upload: Mutex::new(HashMap::new()),
media_fetch: Mutex::new(HashMap::new()),
media_download: Mutex::new(HashMap::new()),
global_media_upload: default_media_entry(global_media_config.client.media.upload, now),
global_media_fetch: default_media_entry(global_media_config.client.media.fetch, now),
global_media_download_client: default_media_entry(
global_media_config.client.media.download,
now,
),
global_media_download_federation: default_media_entry(
global_media_config.federation.media.download,
now,
),
authentication_failures: RwLock::new(HashMap::new()),
})
}
//TODO: use checked and saturating arithmetic
/// Takes the target and request, and either accepts the request while adding to the
/// bucket, or rejects the request, returning the duration that should be waited until
/// the request should be retried.
pub async fn check(&self, target: Option<Target>, request: Metadata) -> Result<()> {
let Ok(restriction) = request.try_into() else {
// Endpoint has no associated restriction
return Ok(());
};
let arrival = Instant::now();
let config = services()
.globals
.config
.rate_limiting
.global
.get(&restriction);
let mut map = self.global_bucket.lock().await;
let entry = map.entry(restriction);
let proposed_entry = match &entry {
Entry::Occupied(occupied_entry) => {
let entry = Arc::clone(occupied_entry.get());
let entry = entry.lock().await;
if arrival.checked_duration_since(*entry).is_none() {
return instant_to_err(&entry);
}
let min_instant = arrival
- Duration::from_nanos(
config.timeframe.nano_gap() * config.burst_capacity.get(),
);
entry.max(min_instant) + Duration::from_nanos(config.timeframe.nano_gap())
}
Entry::Vacant(_) => {
arrival
- Duration::from_nanos(
config.timeframe.nano_gap() * (config.burst_capacity.get() - 1),
)
}
};
if let Some(target) = target {
let config = services()
.globals
.config
.rate_limiting
.target
.get(&restriction);
let mut map = self.buckets.lock().await;
let entry = map.entry((target, restriction));
match entry {
Entry::Occupied(occupied_entry) => {
let entry = Arc::clone(occupied_entry.get());
let mut entry = entry.lock().await;
if arrival.checked_duration_since(*entry).is_none() {
return instant_to_err(&entry);
}
let min_instant = arrival
- Duration::from_nanos(
config.timeframe.nano_gap() * config.burst_capacity.get(),
);
*entry =
entry.max(min_instant) + Duration::from_nanos(config.timeframe.nano_gap());
}
Entry::Vacant(vacant_entry) => {
vacant_entry.insert(Arc::new(Mutex::new(
arrival
- Duration::from_nanos(
config.timeframe.nano_gap() * (config.burst_capacity.get() - 1),
),
)));
}
}
}
entry.insert_entry(Arc::new(Mutex::new(proposed_entry)));
Ok(())
}
pub async fn check_media_download(&self, target: Option<Target>, size: u64) -> Result<()> {
// All targets besides servers use the client-server API
let (target_limitation, global_limitation, global_bucket) =
if let Some(Target::Server(_)) = &target {
(
services()
.globals
.config
.rate_limiting
.target
.federation
.media
.download,
services()
.globals
.config
.rate_limiting
.global
.federation
.media
.download,
&self.global_media_download_federation,
)
} else {
(
services()
.globals
.config
.rate_limiting
.target
.client
.media
.download,
services()
.globals
.config
.rate_limiting
.global
.client
.media
.download,
&self.global_media_download_client,
)
};
check_media(
target,
size,
target_limitation,
global_limitation,
&self.media_download,
global_bucket,
)
.await
}
pub async fn check_media_upload(&self, target: Target, size: u64) -> Result<()> {
let target_limitation = services()
.globals
.config
.rate_limiting
.target
// Media can only be uploaded on the client-server API
.client
.media
.upload;
let global_limitation = services()
.globals
.config
.rate_limiting
.global
// Media can only be uploaded on the client-server API
.client
.media
.upload;
check_media(
Some(target),
size,
target_limitation,
global_limitation,
&self.media_upload,
&self.global_media_upload,
)
.await
}
pub async fn check_media_pre_fetch(&self, target: &Target) -> Result<()> {
if !target.rate_limited() {
return Ok(());
}
let arrival = Instant::now();
let check = async |map: &MediaBucket, global_bucket: &GlobalMediaBucket| {
let map = map.lock().await;
if let Some(mutex) = map.get(target) {
let mutex = mutex.lock().await;
if arrival.checked_duration_since(*mutex).is_none() {
return instant_to_err(&mutex);
}
}
let global_bucket = global_bucket.lock().await;
if arrival.checked_duration_since(*global_bucket).is_none() {
return instant_to_err(&global_bucket);
}
Ok(())
};
// checking fetch
check(&self.media_fetch, &self.global_media_fetch).await?;
// checking download as well
check(&self.media_download, &self.global_media_download_client).await
}
/// Checks whether the ip address is has been rate limited due to too many bad access tokens being sent.
pub async fn pre_auth_check(&self, ip_addr: IpAddr) -> Result<()> {
let arrival = Instant::now();
if let Some(instant) = self.authentication_failures.read().await.get(&ip_addr) {
let instant = instant.read().await;
if arrival.checked_duration_since(*instant).is_none() {
return instant_to_err(&instant);
}
}
Ok(())
}
/// Updates the bad auth rate limiter when a bad access token is sent where access tokens auth is an option.
pub async fn update_post_auth_failure(&self, ip_addr: IpAddr) {
let arrival = Instant::now();
let RequestLimitation {
timeframe,
burst_capacity,
} = services()
.globals
.config
.rate_limiting
.target
.client
.authentication_failures;
let mut map = self.authentication_failures.write().await;
let entry = map.entry(ip_addr);
match entry {
Entry::Occupied(occupied_entry) => {
let entry = Arc::clone(occupied_entry.get());
let mut entry = entry.write().await;
let min_instant =
arrival - Duration::from_nanos(timeframe.nano_gap() * burst_capacity.get());
*entry = entry.max(min_instant) + Duration::from_nanos(timeframe.nano_gap());
}
Entry::Vacant(vacant_entry) => {
vacant_entry.insert(Arc::new(RwLock::new(
arrival - Duration::from_nanos(burst_capacity.get() / timeframe.nano_gap()),
)));
}
}
}
pub async fn update_media_post_fetch(&self, target: Target, size: u64) {
if !target.rate_limited() {
return;
}
let arrival = Instant::now();
let update = async |map: &MediaBucket,
target_limitation: &MediaLimitation,
global_bucket: &GlobalMediaBucket,
global_limitation: &MediaLimitation| {
let mut map = map.lock().await;
let entry = map.entry(target.clone());
match entry {
Entry::Occupied(occupied_entry) => {
let entry = Arc::clone(occupied_entry.get());
let _ =
update_media_entry(size, target_limitation, &arrival, entry, false).await;
}
Entry::Vacant(vacant_entry) => {
vacant_entry.insert(Arc::new(Mutex::new(
arrival
- Duration::from_nanos(
target_limitation.burst_capacity.as_u64()
/ target_limitation.timeframe.bytes_per_sec(),
),
)));
}
}
let _ = update_media_entry(
size,
global_limitation,
&arrival,
Arc::clone(global_bucket),
false,
)
.await;
};
// updating fetch
update(
&self.media_fetch,
&services()
.globals
.config
.rate_limiting
.target
.client
.media
.fetch,
&self.global_media_fetch,
&services()
.globals
.config
.rate_limiting
.global
.client
.media
.fetch,
)
.await;
// updating download as well
update(
&self.media_download,
&services()
.globals
.config
.rate_limiting
.target
.client
.media
.download,
&self.global_media_download_client,
&services()
.globals
.config
.rate_limiting
.global
.client
.media
.download,
)
.await;
}
}
async fn update_media_entry(
size: u64,
limitation: &MediaLimitation,
arrival: &Instant,
entry: Arc<Mutex<Instant>>,
and_check: bool,
) -> Result<()> {
let mut entry = entry.lock().await;
//TODO: use more precise conversion than secs
let proposed_entry = get_proposed_entry(size, limitation, arrival, &entry, and_check)?;
*entry = proposed_entry;
Ok(())
}
fn get_proposed_entry(
size: u64,
limitation: &MediaLimitation,
arrival: &Instant,
entry: &MutexGuard<'_, Instant>,
and_check: bool,
) -> Result<Instant> {
let min_instant = *arrival
- Duration::from_secs(
limitation.burst_capacity.as_u64() / limitation.timeframe.bytes_per_sec(),
);
let proposed_entry =
entry.max(min_instant) + Duration::from_secs(size / limitation.timeframe.bytes_per_sec());
if and_check && arrival.checked_duration_since(proposed_entry).is_none() {
return instant_to_err(&proposed_entry).map(|_| proposed_entry);
}
Ok(proposed_entry)
}
async fn check_media(
target: Option<Target>,
size: u64,
target_limitation: MediaLimitation,
global_limitation: MediaLimitation,
target_map: &MediaBucket,
global_bucket: &GlobalMediaBucket,
) -> Result<()> {
if !target.as_ref().is_some_and(Target::rate_limited) {
return Ok(());
}
let arrival = Instant::now();
let mut global_bucket = global_bucket.lock().await;
let proposed = get_proposed_entry(size, &global_limitation, &arrival, &global_bucket, true)?;
if let Some(target) = target {
let mut map = target_map.lock().await;
let entry = map.entry(target);
match entry {
Entry::Occupied(occupied_entry) => {
let entry = Arc::clone(occupied_entry.get());
update_media_entry(size, &target_limitation, &arrival, entry, true).await?;
}
Entry::Vacant(vacant_entry) => {
vacant_entry.insert(default_media_entry(target_limitation, arrival));
}
}
}
*global_bucket = proposed;
Ok(())
}
fn default_media_entry(
target_limitation: MediaLimitation,
arrival: Instant,
) -> Arc<Mutex<Instant>> {
Arc::new(Mutex::new(
arrival
- Duration::from_nanos(
target_limitation.burst_capacity.as_u64()
/ target_limitation.timeframe.bytes_per_sec(),
),
))
}
fn instant_to_err(instant: &Instant) -> Result<()> {
let now = Instant::now();
Err(Error::BadRequest(
ErrorKind::LimitExceeded {
// Not using ::DateTime because conversion from Instant to SystemTime is convoluted
retry_after: Some(RetryAfter::Delay(instant.duration_since(now))),
},
"Rate limit exceeded",
))
}
Loading…
Cancel
Save