mirror of https://gitlab.com/famedly/conduit.git
13 changed files with 1386 additions and 353 deletions
@ -0,0 +1,236 @@
|
||||
use std::{collections::HashMap, num::NonZeroU64}; |
||||
|
||||
use bytesize::ByteSize; |
||||
use ruma::api::Metadata; |
||||
use serde::Deserialize; |
||||
|
||||
/// Rate-limiting configuration, split into per-target and global limits.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    /// Limits applied per individual target.
    ///
    /// Flattened during deserialization, so its fields appear directly at the
    /// top level of this config section.
    #[serde(flatten)]
    pub target: ConfigFragment,
    /// Limits applied across all targets combined.
    pub global: ConfigFragment,
}
||||
|
||||
/// One set of limits, covering both API surfaces.
#[derive(Debug, Clone, Deserialize)]
pub struct ConfigFragment {
    /// Limits for client-server API endpoints.
    pub client: ConfigClientFragment,
    /// Limits for federation (server-server) API endpoints.
    pub federation: ConfigFederationFragment,
}
||||
|
||||
/// Client-server API limits.
#[derive(Debug, Clone, Deserialize)]
pub struct ConfigClientFragment {
    /// Request limits per client restriction category.
    pub map: HashMap<ClientRestriction, RequestLimitation>,
    /// Byte-based limits for client media traffic.
    pub media: ClientMediaConfig,
    // TODO: Only have available on target, not global (same with most authenticated endpoints too maybe)?
    /// Limit applied to failed authentication attempts.
    pub authentication_failures: RequestLimitation,
}
||||
|
||||
/// Federation API limits.
#[derive(Debug, Clone, Deserialize)]
pub struct ConfigFederationFragment {
    /// Request limits per federation restriction category.
    pub map: HashMap<FederationRestriction, RequestLimitation>,
    /// Byte-based limits for federation media traffic.
    pub media: FederationMediaConfig,
}
||||
|
||||
impl ConfigFragment { |
||||
pub fn get(&self, restriction: &Restriction) -> &RequestLimitation { |
||||
// Maybe look into https://github.com/moriyoshi-kasuga/enum-table
|
||||
match restriction { |
||||
Restriction::Client(client_restriction) => { |
||||
self.client.map.get(client_restriction).unwrap() |
||||
} |
||||
Restriction::Federation(federation_restriction) => { |
||||
self.federation.map.get(federation_restriction).unwrap() |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// A rate-limiting category, namespaced by the API surface it belongs to.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Restriction {
    /// A client-server API category.
    Client(ClientRestriction),
    /// A federation API category.
    Federation(FederationRestriction),
}
||||
|
||||
/// Client-server API rate-limiting categories.
///
/// Deserialized from snake_case keys in the config (e.g. `send_event`).
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[serde(rename_all = "snake_case")]
pub enum ClientRestriction {
    // Account / session endpoints
    Registration,
    Login,
    RegistrationTokenValidity,

    // Sending events (messages and state)
    SendEvent,

    // Room membership
    Join,
    Invite,
    Knock,

    // Reporting and aliases
    SendReport,
    CreateAlias,

    // Media
    MediaDownload,
    MediaCreate,
}
||||
|
||||
/// Federation API rate-limiting categories.
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[serde(rename_all = "snake_case")]
pub enum FederationRestriction {
    // Remote membership handshakes
    Join,
    Knock,
    Invite,

    // Transactions should be handled by a completely dedicated rate-limiter
    Transaction,

    // Authenticated media served to other servers
    MediaDownload,
}
||||
|
||||
impl TryFrom<Metadata> for Restriction {
    type Error = ();

    /// Maps an endpoint's `Metadata` to its rate-limiting [`Restriction`].
    ///
    /// Matches the request metadata against the `METADATA` associated
    /// constants of the known ruma endpoints; endpoints without an associated
    /// restriction produce `Err(())` and are not rate-limited here.
    fn try_from(value: Metadata) -> Result<Self, Self::Error> {
        use Restriction::*;
        use ruma::api::{
            IncomingRequest,
            client::{
                account::{check_registration_token_validity, register},
                alias::create_alias,
                authenticated_media::{
                    get_content, get_content_as_filename, get_content_thumbnail, get_media_preview,
                },
                knock::knock_room,
                media::{self, create_content, create_mxc_uri},
                membership::{invite_user, join_room_by_id, join_room_by_id_or_alias},
                message::send_message_event,
                reporting::report_user,
                room::{report_content, report_room},
                session::login,
                state::send_state_event,
            },
            federation::{
                authenticated_media::{
                    get_content as federation_get_content,
                    get_content_thumbnail as federation_get_content_thumbnail,
                },
                membership::{create_invite, create_join_event, create_knock_event},
            },
        };

        Ok(match value {
            register::v3::Request::METADATA => Client(ClientRestriction::Registration),
            check_registration_token_validity::v1::Request::METADATA => {
                Client(ClientRestriction::RegistrationTokenValidity)
            }
            login::v3::Request::METADATA => Client(ClientRestriction::Login),
            send_message_event::v3::Request::METADATA | send_state_event::v3::Request::METADATA => {
                Client(ClientRestriction::SendEvent)
            }
            join_room_by_id::v3::Request::METADATA
            | join_room_by_id_or_alias::v3::Request::METADATA => Client(ClientRestriction::Join),
            invite_user::v3::Request::METADATA => Client(ClientRestriction::Invite),
            knock_room::v3::Request::METADATA => Client(ClientRestriction::Knock),
            report_user::v3::Request::METADATA
            | report_content::v3::Request::METADATA
            | report_room::v3::Request::METADATA => Client(ClientRestriction::SendReport),
            create_alias::v3::Request::METADATA => Client(ClientRestriction::CreateAlias),
            // NOTE: handle async media upload in a way that doesn't halve the number of
            // uploads you can do within a short timeframe, while not allowing
            // pre-generation of MXC URIs to allow uploading double the number of media at
            // once. For now both endpoints share the MediaCreate bucket.
            create_content::v3::Request::METADATA | create_mxc_uri::v1::Request::METADATA => {
                Client(ClientRestriction::MediaCreate)
            }
            // Unauthenticated media endpoints are deprecated but still share the same
            // download bucket as the authenticated ones.
            #[allow(deprecated)]
            media::get_content::v3::Request::METADATA
            | media::get_content_as_filename::v3::Request::METADATA
            | media::get_content_thumbnail::v3::Request::METADATA
            | media::get_media_preview::v3::Request::METADATA
            | get_content::v1::Request::METADATA
            | get_content_as_filename::v1::Request::METADATA
            | get_content_thumbnail::v1::Request::METADATA
            | get_media_preview::v1::Request::METADATA => Client(ClientRestriction::MediaDownload),
            federation_get_content::v1::Request::METADATA
            | federation_get_content_thumbnail::v1::Request::METADATA => {
                Federation(FederationRestriction::MediaDownload)
            }
            // v1 is deprecated
            #[allow(deprecated)]
            create_join_event::v1::Request::METADATA | create_join_event::v2::Request::METADATA => {
                Federation(FederationRestriction::Join)
            }
            create_knock_event::v1::Request::METADATA => Federation(FederationRestriction::Knock),
            create_invite::v1::Request::METADATA | create_invite::v2::Request::METADATA => {
                Federation(FederationRestriction::Invite)
            }

            // Endpoint has no associated restriction.
            _ => return Err(()),
        })
    }
}
||||
|
||||
/// A request-count limit: a sustained rate plus a burst allowance.
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct RequestLimitation {
    /// Sustained request rate, flattened into this table (e.g. `per_second = 5`).
    #[serde(flatten)]
    pub timeframe: Timeframe,
    /// Number of requests that may be spent at once beyond the sustained rate.
    pub burst_capacity: NonZeroU64,
}
||||
|
||||
/// Sustained request rate, expressed as "N requests per unit of time".
#[derive(Deserialize, Clone, Copy, Debug)]
#[serde(rename_all = "snake_case")]
// When deserializing, we want this `per_` prefix on every variant
#[allow(clippy::enum_variant_names)]
pub enum Timeframe {
    PerSecond(NonZeroU64),
    PerMinute(NonZeroU64),
    PerHour(NonZeroU64),
    PerDay(NonZeroU64),
}
||||
|
||||
impl Timeframe { |
||||
pub fn nano_gap(&self) -> u64 { |
||||
match self { |
||||
Timeframe::PerSecond(t) => 1000 * 1000 * 1000 / t.get(), |
||||
Timeframe::PerMinute(t) => 1000 * 1000 * 1000 * 60 / t.get(), |
||||
Timeframe::PerHour(t) => 1000 * 1000 * 1000 * 60 * 60 / t.get(), |
||||
Timeframe::PerDay(t) => 1000 * 1000 * 1000 * 60 * 60 * 24 / t.get(), |
||||
} |
||||
} |
||||
} |
||||
|
||||
/// Byte-based media limits for the client-server API.
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct ClientMediaConfig {
    /// Media served to clients.
    pub download: MediaLimitation,
    /// Media uploaded by clients.
    pub upload: MediaLimitation,
    /// Remote media fetched on behalf of clients.
    pub fetch: MediaLimitation,
}
||||
|
||||
/// Byte-based media limits for the federation API.
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct FederationMediaConfig {
    /// Media served to other homeservers.
    pub download: MediaLimitation,
}
||||
|
||||
/// A byte-throughput limit: a sustained rate plus a burst allowance in bytes.
#[derive(Clone, Copy, Debug, Deserialize)]
pub struct MediaLimitation {
    /// Sustained throughput, flattened into this table (e.g. `per_minute = "10MiB"`).
    #[serde(flatten)]
    pub timeframe: MediaTimeframe,
    /// Number of bytes that may be spent at once beyond the sustained rate.
    pub burst_capacity: ByteSize,
}
||||
|
||||
/// Sustained media throughput, expressed as "N bytes per unit of time".
#[derive(Deserialize, Clone, Copy, Debug)]
#[serde(rename_all = "snake_case")]
// When deserializing, we want this `per_` prefix on every variant
#[allow(clippy::enum_variant_names)]
pub enum MediaTimeframe {
    PerSecond(ByteSize),
    PerMinute(ByteSize),
    PerHour(ByteSize),
    PerDay(ByteSize),
}
||||
|
||||
impl MediaTimeframe { |
||||
pub fn bytes_per_sec(&self) -> u64 { |
||||
match self { |
||||
MediaTimeframe::PerSecond(t) => t.as_u64(), |
||||
MediaTimeframe::PerMinute(t) => t.as_u64() / 60, |
||||
MediaTimeframe::PerHour(t) => t.as_u64() / (60 * 60), |
||||
MediaTimeframe::PerDay(t) => t.as_u64() / (60 * 60 * 24), |
||||
} |
||||
} |
||||
} |
||||
@ -0,0 +1,581 @@
|
||||
use std::{ |
||||
collections::{HashMap, hash_map::Entry}, |
||||
net::IpAddr, |
||||
sync::Arc, |
||||
time::Duration, |
||||
}; |
||||
|
||||
use conduit_config::{ |
||||
Config, |
||||
rate_limiting::{MediaLimitation, RequestLimitation, Restriction}, |
||||
}; |
||||
use ruma::{ |
||||
OwnedServerName, OwnedUserId, UserId, |
||||
api::{ |
||||
Metadata, |
||||
client::error::{ErrorKind, RetryAfter}, |
||||
}, |
||||
}; |
||||
use tokio::{ |
||||
sync::{Mutex, MutexGuard, RwLock}, |
||||
time::Instant, |
||||
}; |
||||
|
||||
use crate::{Error, Result, service::appservice::RegistrationInfo, services}; |
||||
|
||||
/// The entity a rate limit is applied to.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Target {
    /// A user, identified by their Matrix user ID.
    User(OwnedUserId),
    // Server endpoints should be rate-limited on a server and room basis
    Server(OwnedServerName),
    /// An appservice; `rate_limited` records whether it is subject to limits
    /// for this particular request.
    Appservice { id: String, rate_limited: bool },
    /// An unauthenticated client, identified only by its IP address.
    Ip(IpAddr),
}
||||
|
||||
impl Target {
    /// Determines the target for an authenticated client request.
    ///
    /// Requests carrying appservice registration info become
    /// [`Target::Appservice`]; everything else is attributed to the sender.
    pub fn from_client_request(
        registration_info: Option<RegistrationInfo>,
        sender_user: &UserId,
    ) -> Self {
        if let Some(info) = registration_info {
            // `rate_limited` only affects "masqueraded users"; "the sender [user?] is
            // excluded" — i.e. the appservice's own local sender user is never
            // rate-limited here.
            return Target::Appservice {
                id: info.registration.id,
                rate_limited: info.registration.rate_limited.unwrap_or(true)
                    && !(sender_user.server_name() == services().globals.server_name()
                        && info.registration.sender_localpart == sender_user.localpart()),
            };
        }

        Target::User(sender_user.to_owned())
    }

    /// Like [`Self::from_client_request`], but for endpoints where authentication is
    /// optional: unauthenticated requests fall back to the client IP, or `None` when
    /// no address is available either.
    pub fn from_client_request_optional_auth(
        registration_info: Option<RegistrationInfo>,
        sender_user: &Option<OwnedUserId>,
        ip_addr: Option<IpAddr>,
    ) -> Option<Self> {
        if let Some(sender_user) = sender_user.as_ref() {
            Some(Self::from_client_request(registration_info, sender_user))
        } else {
            ip_addr.map(Self::Ip)
        }
    }

    /// Whether this target is subject to rate limiting at all.
    ///
    /// The server's own user and appservices whose registration opted out
    /// are exempt; servers and IPs are always limited.
    fn rate_limited(&self) -> bool {
        match self {
            Target::User(user_id) => user_id != services().globals.server_user(),
            Target::Appservice {
                id: _,
                rate_limited,
            } => *rate_limited,
            _ => true,
        }
    }

    /// Whether this target was derived from an authenticated identity rather
    /// than a bare IP address.
    pub fn is_authenticated(&self) -> bool {
        !matches!(self, Target::Ip(_))
    }
}
||||
|
||||
// NOTE: still chokes on the global mutex around the map, which in theory is only
// needed for inserts, but due to Rust's ownership model I don't think it is
// (easily) possible to lock said mutex for inserts only (and use an rwlock for
// removals if memory usage ever becomes a problem).

/// Per-target media buckets; each entry is the instant after which the
/// target's next media request is permitted.
type MediaBucket = Mutex<HashMap<Target, Arc<Mutex<Instant>>>>;
/// A single shared bucket covering all targets for one kind of media traffic.
type GlobalMediaBucket = Arc<Mutex<Instant>>;

/// The instant after which the next (non-media) request is permitted.
type RequestPermittedAfter = Arc<Mutex<Instant>>;
||||
|
||||
/// Rate-limiting service. Every bucket stores the `Instant` after which the
/// next request is permitted; a request arriving before that instant is
/// rejected with a retry-after duration.
pub struct Service {
    /// Per-(target, restriction) request buckets.
    buckets: Mutex<HashMap<(Target, Restriction), RequestPermittedAfter>>,
    /// Per-restriction buckets shared across all targets.
    global_bucket: Mutex<HashMap<Restriction, RequestPermittedAfter>>,

    // Byte-based per-target media buckets.
    media_upload: MediaBucket,
    media_fetch: MediaBucket,
    media_download: MediaBucket,

    // Server-wide media buckets, seeded from the global config at build time.
    global_media_upload: GlobalMediaBucket,
    global_media_fetch: GlobalMediaBucket,
    global_media_download_client: GlobalMediaBucket,
    global_media_download_federation: GlobalMediaBucket,

    /// Per-IP buckets tracking failed authentication attempts.
    authentication_failures: RwLock<HashMap<IpAddr, Arc<RwLock<Instant>>>>,
}
||||
|
||||
impl Service { |
||||
pub fn build(config: &Config) -> Arc<Self> { |
||||
let now = Instant::now(); |
||||
let global_media_config = &config.rate_limiting.global; |
||||
|
||||
Arc::new(Self { |
||||
buckets: Mutex::new(HashMap::new()), |
||||
global_bucket: Mutex::new(HashMap::new()), |
||||
|
||||
media_upload: Mutex::new(HashMap::new()), |
||||
media_fetch: Mutex::new(HashMap::new()), |
||||
media_download: Mutex::new(HashMap::new()), |
||||
|
||||
global_media_upload: default_media_entry(global_media_config.client.media.upload, now), |
||||
global_media_fetch: default_media_entry(global_media_config.client.media.fetch, now), |
||||
global_media_download_client: default_media_entry( |
||||
global_media_config.client.media.download, |
||||
now, |
||||
), |
||||
global_media_download_federation: default_media_entry( |
||||
global_media_config.federation.media.download, |
||||
now, |
||||
), |
||||
|
||||
authentication_failures: RwLock::new(HashMap::new()), |
||||
}) |
||||
} |
||||
|
||||
//TODO: use checked and saturating arithmetic
|
||||
|
||||
/// Takes the target and request, and either accepts the request while adding to the
|
||||
/// bucket, or rejects the request, returning the duration that should be waited until
|
||||
/// the request should be retried.
|
||||
pub async fn check(&self, target: Option<Target>, request: Metadata) -> Result<()> { |
||||
let Ok(restriction) = request.try_into() else { |
||||
// Endpoint has no associated restriction
|
||||
return Ok(()); |
||||
}; |
||||
let arrival = Instant::now(); |
||||
|
||||
let config = services() |
||||
.globals |
||||
.config |
||||
.rate_limiting |
||||
.global |
||||
.get(&restriction); |
||||
|
||||
let mut map = self.global_bucket.lock().await; |
||||
|
||||
let entry = map.entry(restriction); |
||||
let proposed_entry = match &entry { |
||||
Entry::Occupied(occupied_entry) => { |
||||
let entry = Arc::clone(occupied_entry.get()); |
||||
let entry = entry.lock().await; |
||||
|
||||
if arrival.checked_duration_since(*entry).is_none() { |
||||
return instant_to_err(&entry); |
||||
} |
||||
|
||||
let min_instant = arrival |
||||
- Duration::from_nanos( |
||||
config.timeframe.nano_gap() * config.burst_capacity.get(), |
||||
); |
||||
entry.max(min_instant) + Duration::from_nanos(config.timeframe.nano_gap()) |
||||
} |
||||
Entry::Vacant(_) => { |
||||
arrival |
||||
- Duration::from_nanos( |
||||
config.timeframe.nano_gap() * (config.burst_capacity.get() - 1), |
||||
) |
||||
} |
||||
}; |
||||
|
||||
if let Some(target) = target { |
||||
let config = services() |
||||
.globals |
||||
.config |
||||
.rate_limiting |
||||
.target |
||||
.get(&restriction); |
||||
|
||||
let mut map = self.buckets.lock().await; |
||||
let entry = map.entry((target, restriction)); |
||||
match entry { |
||||
Entry::Occupied(occupied_entry) => { |
||||
let entry = Arc::clone(occupied_entry.get()); |
||||
let mut entry = entry.lock().await; |
||||
|
||||
if arrival.checked_duration_since(*entry).is_none() { |
||||
return instant_to_err(&entry); |
||||
} |
||||
|
||||
let min_instant = arrival |
||||
- Duration::from_nanos( |
||||
config.timeframe.nano_gap() * config.burst_capacity.get(), |
||||
); |
||||
*entry = |
||||
entry.max(min_instant) + Duration::from_nanos(config.timeframe.nano_gap()); |
||||
} |
||||
Entry::Vacant(vacant_entry) => { |
||||
vacant_entry.insert(Arc::new(Mutex::new( |
||||
arrival |
||||
- Duration::from_nanos( |
||||
config.timeframe.nano_gap() * (config.burst_capacity.get() - 1), |
||||
), |
||||
))); |
||||
} |
||||
} |
||||
} |
||||
|
||||
entry.insert_entry(Arc::new(Mutex::new(proposed_entry))); |
||||
|
||||
Ok(()) |
||||
} |
||||
|
||||
pub async fn check_media_download(&self, target: Option<Target>, size: u64) -> Result<()> { |
||||
// All targets besides servers use the client-server API
|
||||
let (target_limitation, global_limitation, global_bucket) = |
||||
if let Some(Target::Server(_)) = &target { |
||||
( |
||||
services() |
||||
.globals |
||||
.config |
||||
.rate_limiting |
||||
.target |
||||
.federation |
||||
.media |
||||
.download, |
||||
services() |
||||
.globals |
||||
.config |
||||
.rate_limiting |
||||
.global |
||||
.federation |
||||
.media |
||||
.download, |
||||
&self.global_media_download_federation, |
||||
) |
||||
} else { |
||||
( |
||||
services() |
||||
.globals |
||||
.config |
||||
.rate_limiting |
||||
.target |
||||
.client |
||||
.media |
||||
.download, |
||||
services() |
||||
.globals |
||||
.config |
||||
.rate_limiting |
||||
.global |
||||
.client |
||||
.media |
||||
.download, |
||||
&self.global_media_download_client, |
||||
) |
||||
}; |
||||
|
||||
check_media( |
||||
target, |
||||
size, |
||||
target_limitation, |
||||
global_limitation, |
||||
&self.media_download, |
||||
global_bucket, |
||||
) |
||||
.await |
||||
} |
||||
|
||||
pub async fn check_media_upload(&self, target: Target, size: u64) -> Result<()> { |
||||
let target_limitation = services() |
||||
.globals |
||||
.config |
||||
.rate_limiting |
||||
.target |
||||
// Media can only be uploaded on the client-server API
|
||||
.client |
||||
.media |
||||
.upload; |
||||
|
||||
let global_limitation = services() |
||||
.globals |
||||
.config |
||||
.rate_limiting |
||||
.global |
||||
// Media can only be uploaded on the client-server API
|
||||
.client |
||||
.media |
||||
.upload; |
||||
|
||||
check_media( |
||||
Some(target), |
||||
size, |
||||
target_limitation, |
||||
global_limitation, |
||||
&self.media_upload, |
||||
&self.global_media_upload, |
||||
) |
||||
.await |
||||
} |
||||
|
||||
pub async fn check_media_pre_fetch(&self, target: &Target) -> Result<()> { |
||||
if !target.rate_limited() { |
||||
return Ok(()); |
||||
} |
||||
|
||||
let arrival = Instant::now(); |
||||
|
||||
let check = async |map: &MediaBucket, global_bucket: &GlobalMediaBucket| { |
||||
let map = map.lock().await; |
||||
if let Some(mutex) = map.get(target) { |
||||
let mutex = mutex.lock().await; |
||||
|
||||
if arrival.checked_duration_since(*mutex).is_none() { |
||||
return instant_to_err(&mutex); |
||||
} |
||||
} |
||||
|
||||
let global_bucket = global_bucket.lock().await; |
||||
|
||||
if arrival.checked_duration_since(*global_bucket).is_none() { |
||||
return instant_to_err(&global_bucket); |
||||
} |
||||
|
||||
Ok(()) |
||||
}; |
||||
|
||||
// checking fetch
|
||||
check(&self.media_fetch, &self.global_media_fetch).await?; |
||||
|
||||
// checking download as well
|
||||
check(&self.media_download, &self.global_media_download_client).await |
||||
} |
||||
|
||||
/// Checks whether the ip address is has been rate limited due to too many bad access tokens being sent.
|
||||
pub async fn pre_auth_check(&self, ip_addr: IpAddr) -> Result<()> { |
||||
let arrival = Instant::now(); |
||||
|
||||
if let Some(instant) = self.authentication_failures.read().await.get(&ip_addr) { |
||||
let instant = instant.read().await; |
||||
|
||||
if arrival.checked_duration_since(*instant).is_none() { |
||||
return instant_to_err(&instant); |
||||
} |
||||
} |
||||
|
||||
Ok(()) |
||||
} |
||||
|
||||
/// Updates the bad auth rate limiter when a bad access token is sent where access tokens auth is an option.
|
||||
pub async fn update_post_auth_failure(&self, ip_addr: IpAddr) { |
||||
let arrival = Instant::now(); |
||||
|
||||
let RequestLimitation { |
||||
timeframe, |
||||
burst_capacity, |
||||
} = services() |
||||
.globals |
||||
.config |
||||
.rate_limiting |
||||
.target |
||||
.client |
||||
.authentication_failures; |
||||
|
||||
let mut map = self.authentication_failures.write().await; |
||||
let entry = map.entry(ip_addr); |
||||
|
||||
match entry { |
||||
Entry::Occupied(occupied_entry) => { |
||||
let entry = Arc::clone(occupied_entry.get()); |
||||
let mut entry = entry.write().await; |
||||
|
||||
let min_instant = |
||||
arrival - Duration::from_nanos(timeframe.nano_gap() * burst_capacity.get()); |
||||
*entry = entry.max(min_instant) + Duration::from_nanos(timeframe.nano_gap()); |
||||
} |
||||
Entry::Vacant(vacant_entry) => { |
||||
vacant_entry.insert(Arc::new(RwLock::new( |
||||
arrival - Duration::from_nanos(burst_capacity.get() / timeframe.nano_gap()), |
||||
))); |
||||
} |
||||
} |
||||
} |
||||
|
||||
pub async fn update_media_post_fetch(&self, target: Target, size: u64) { |
||||
if !target.rate_limited() { |
||||
return; |
||||
} |
||||
|
||||
let arrival = Instant::now(); |
||||
|
||||
let update = async |map: &MediaBucket, |
||||
target_limitation: &MediaLimitation, |
||||
global_bucket: &GlobalMediaBucket, |
||||
global_limitation: &MediaLimitation| { |
||||
let mut map = map.lock().await; |
||||
let entry = map.entry(target.clone()); |
||||
|
||||
match entry { |
||||
Entry::Occupied(occupied_entry) => { |
||||
let entry = Arc::clone(occupied_entry.get()); |
||||
|
||||
let _ = |
||||
update_media_entry(size, target_limitation, &arrival, entry, false).await; |
||||
} |
||||
Entry::Vacant(vacant_entry) => { |
||||
vacant_entry.insert(Arc::new(Mutex::new( |
||||
arrival |
||||
- Duration::from_nanos( |
||||
target_limitation.burst_capacity.as_u64() |
||||
/ target_limitation.timeframe.bytes_per_sec(), |
||||
), |
||||
))); |
||||
} |
||||
} |
||||
|
||||
let _ = update_media_entry( |
||||
size, |
||||
global_limitation, |
||||
&arrival, |
||||
Arc::clone(global_bucket), |
||||
false, |
||||
) |
||||
.await; |
||||
}; |
||||
|
||||
// updating fetch
|
||||
update( |
||||
&self.media_fetch, |
||||
&services() |
||||
.globals |
||||
.config |
||||
.rate_limiting |
||||
.target |
||||
.client |
||||
.media |
||||
.fetch, |
||||
&self.global_media_fetch, |
||||
&services() |
||||
.globals |
||||
.config |
||||
.rate_limiting |
||||
.global |
||||
.client |
||||
.media |
||||
.fetch, |
||||
) |
||||
.await; |
||||
|
||||
// updating download as well
|
||||
update( |
||||
&self.media_download, |
||||
&services() |
||||
.globals |
||||
.config |
||||
.rate_limiting |
||||
.target |
||||
.client |
||||
.media |
||||
.download, |
||||
&self.global_media_download_client, |
||||
&services() |
||||
.globals |
||||
.config |
||||
.rate_limiting |
||||
.global |
||||
.client |
||||
.media |
||||
.download, |
||||
) |
||||
.await; |
||||
} |
||||
} |
||||
|
||||
async fn update_media_entry( |
||||
size: u64, |
||||
limitation: &MediaLimitation, |
||||
arrival: &Instant, |
||||
entry: Arc<Mutex<Instant>>, |
||||
and_check: bool, |
||||
) -> Result<()> { |
||||
let mut entry = entry.lock().await; |
||||
|
||||
//TODO: use more precise conversion than secs
|
||||
let proposed_entry = get_proposed_entry(size, limitation, arrival, &entry, and_check)?; |
||||
|
||||
*entry = proposed_entry; |
||||
|
||||
Ok(()) |
||||
} |
||||
|
||||
/// Computes the new "permitted after" instant for a media bucket after charging
/// `size` bytes, optionally rejecting when the result still lies in the future.
///
/// The current entry is first clamped to `arrival - burst_capacity / rate` so
/// that a long idle period can never accumulate more credit than the burst
/// capacity, then advanced by `size / rate` seconds.
///
/// Returns the proposed instant, or the rate-limit error when `and_check` is
/// set and the proposal is still in the future at `arrival`.
fn get_proposed_entry(
    size: u64,
    limitation: &MediaLimitation,
    arrival: &Instant,
    entry: &MutexGuard<'_, Instant>,
    and_check: bool,
) -> Result<Instant> {
    let min_instant = *arrival
        - Duration::from_secs(
            limitation.burst_capacity.as_u64() / limitation.timeframe.bytes_per_sec(),
        );

    let proposed_entry =
        entry.max(min_instant) + Duration::from_secs(size / limitation.timeframe.bytes_per_sec());

    // `instant_to_err` always returns `Err`, so the `map` only adjusts the Ok
    // type of the result and is never actually applied.
    if and_check && arrival.checked_duration_since(proposed_entry).is_none() {
        return instant_to_err(&proposed_entry).map(|_| proposed_entry);
    }

    Ok(proposed_entry)
}
||||
|
||||
/// Shared implementation for media checks: charges `size` bytes against both
/// the per-target and the global bucket, rejecting if either would still be
/// inside its forbidden window.
///
/// The global proposal is computed first but committed only at the end, so a
/// request rejected by the target bucket does not consume global capacity.
///
/// NOTE(review): when `target` is `None` or the target is exempt, the request
/// skips rate limiting entirely, including the global bucket — confirm this is
/// intended for unattributable requests.
async fn check_media(
    target: Option<Target>,
    size: u64,
    target_limitation: MediaLimitation,
    global_limitation: MediaLimitation,
    target_map: &MediaBucket,
    global_bucket: &GlobalMediaBucket,
) -> Result<()> {
    if !target.as_ref().is_some_and(Target::rate_limited) {
        return Ok(());
    }

    let arrival = Instant::now();

    let mut global_bucket = global_bucket.lock().await;
    let proposed = get_proposed_entry(size, &global_limitation, &arrival, &global_bucket, true)?;

    if let Some(target) = target {
        let mut map = target_map.lock().await;
        let entry = map.entry(target);

        match entry {
            Entry::Occupied(occupied_entry) => {
                let entry = Arc::clone(occupied_entry.get());

                // Charge the target bucket; rejects (and leaves it untouched)
                // when the charge would overdraw it.
                update_media_entry(size, &target_limitation, &arrival, entry, true).await?;
            }
            Entry::Vacant(vacant_entry) => {
                // First request from this target: seed a fresh bucket with the
                // full burst available.
                // NOTE(review): `size` is not charged to the new target bucket
                // (only to the global one below) — confirm that is intended.
                vacant_entry.insert(default_media_entry(target_limitation, arrival));
            }
        }
    }

    // Commit the global charge only after the target bucket accepted.
    *global_bucket = proposed;

    Ok(())
}
||||
|
||||
fn default_media_entry( |
||||
target_limitation: MediaLimitation, |
||||
arrival: Instant, |
||||
) -> Arc<Mutex<Instant>> { |
||||
Arc::new(Mutex::new( |
||||
arrival |
||||
- Duration::from_nanos( |
||||
target_limitation.burst_capacity.as_u64() |
||||
/ target_limitation.timeframe.bytes_per_sec(), |
||||
), |
||||
)) |
||||
} |
||||
|
||||
fn instant_to_err(instant: &Instant) -> Result<()> { |
||||
let now = Instant::now(); |
||||
|
||||
Err(Error::BadRequest( |
||||
ErrorKind::LimitExceeded { |
||||
// Not using ::DateTime because conversion from Instant to SystemTime is convoluted
|
||||
retry_after: Some(RetryAfter::Delay(instant.duration_since(now))), |
||||
}, |
||||
"Rate limit exceeded", |
||||
)) |
||||
} |
||||
Loading…
Reference in new issue