diff --git a/Cargo.lock b/Cargo.lock index 8124c146..0415f3e1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -171,6 +171,7 @@ dependencies = [ "serde_path_to_error", "serde_urlencoded", "sync_wrapper 1.0.2", + "tokio", "tower 0.5.2", "tower-layer", "tower-service", @@ -3617,6 +3618,7 @@ dependencies = [ "futures-util", "pin-project-lite", "sync_wrapper 1.0.2", + "tokio", "tower-layer", "tower-service", ] diff --git a/conduit-config/src/lib.rs b/conduit-config/src/lib.rs index 057bdc3d..74f83868 100644 --- a/conduit-config/src/lib.rs +++ b/conduit-config/src/lib.rs @@ -81,6 +81,8 @@ pub struct IncompleteConfig { pub trusted_servers: Vec, #[serde(default = "default_log")] pub log: String, + #[serde(default)] + pub ip_address_detection: IpAddrDetection, pub turn_username: Option, pub turn_password: Option, pub turn_uris: Option>, @@ -137,6 +139,7 @@ pub struct Config { pub jwt_secret: Option, pub trusted_servers: Vec, pub log: String, + pub ip_address_detection: IpAddrDetection, pub turn: Option, @@ -183,6 +186,7 @@ impl From for Config { jwt_secret, trusted_servers, log, + ip_address_detection, turn_username, turn_password, turn_uris, @@ -288,6 +292,7 @@ impl From for Config { jwt_secret, trusted_servers, log, + ip_address_detection, turn, media, emergency_password, @@ -655,6 +660,19 @@ pub struct S3MediaBackend { pub directory_structure: DirectoryStructure, } +#[derive(Deserialize, Debug, Clone)] +#[serde(rename_all = "snake_case")] +pub enum IpAddrDetection { + SocketAddress, + Header(String), +} + +impl Default for IpAddrDetection { + fn default() -> Self { + Self::Header("X-Forwarded-For".to_owned()) + } +} + #[cfg(any(feature = "sqlite", feature = "rocksdb"))] impl std::fmt::Display for Config { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/conduit/Cargo.toml b/conduit/Cargo.toml index 8c7a3f1b..e4bc58f1 100644 --- a/conduit/Cargo.toml +++ b/conduit/Cargo.toml @@ -26,6 +26,7 @@ axum = { version = "0.8", default-features = false, features = [ "http2", "json", "matched-path", + "tokio", ], optional = true } axum-extra = { version = "0.10", features = ["typed-header"] } axum-server = { version = "0.7", features = ["tls-rustls"] } diff --git a/conduit/src/api/ruma_wrapper/axum.rs b/conduit/src/api/ruma_wrapper/axum.rs index 2b531adb..3e6e0272 100644 --- a/conduit/src/api/ruma_wrapper/axum.rs +++ b/conduit/src/api/ruma_wrapper/axum.rs @@ -1,9 +1,15 @@ -use std::{collections::BTreeMap, error::Error as _, iter::FromIterator, str}; +use std::{ + collections::BTreeMap, + error::Error as _, + iter::FromIterator, + net::{IpAddr, SocketAddr}, + str, +}; use axum::{ RequestPartsExt, body::Body, - extract::{FromRequest, Path}, + extract::{ConnectInfo, FromRequest, Path}, response::{IntoResponse, Response}, }; use axum_extra::{ @@ -24,7 +30,9 @@ use serde::Deserialize; use tracing::{debug, error, warn}; use super::{Ruma, RumaResponse}; -use crate::{Error, Result, service::appservice::RegistrationInfo, services}; +use crate::{ + Error, Result, config::IpAddrDetection, service::appservice::RegistrationInfo, services, +}; enum Token { Appservice(Box), @@ -99,6 +107,20 @@ where None => query_params.access_token.as_deref(), }; + let sender_ip_address: Option = + match &services().globals.config.ip_address_detection { + IpAddrDetection::SocketAddress => { + let addr: ConnectInfo = parts.extract().await?; + Some(addr.ip()) + } + IpAddrDetection::Header(name) => parts + .headers + .get(name) + .and_then(|header| header.to_str().ok()) + .map(|header| header.split_once(',').map(|(ip, _)| ip).unwrap_or(header)) + .and_then(|ip| IpAddr::from_str(ip).ok()), + }; + let token = if let Some(token) = token { if let Some(reg_info) = services().appservice.find_from_token(token).await { Token::Appservice(Box::new(reg_info.clone())) diff --git a/conduit/src/main.rs b/conduit/src/main.rs index cdac738c..23b9613d 100644 --- a/conduit/src/main.rs +++ b/conduit/src/main.rs @@ -271,7 +271,9 @@ async fn run_server() -> io::Result<()> { ) .layer(map_response(set_csp_header)); - let app = routes(config).layer(middlewares).into_make_service(); + let app = routes(config) + .layer(middlewares) + .into_make_service_with_connect_info::(); let handle = ServerHandle::new(); tokio::spawn(shutdown_signal(handle.clone())); diff --git a/docs/configuration.md b/docs/configuration.md index ebc16585..61becfc3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -52,6 +52,7 @@ The `global` section contains the following fields: | `jwt_secret` | `string` | The secret used in the JWT to enable JWT login without it a 400 error will be returned | N/A | | `trusted_servers` | `array` | The list of trusted servers to gather public keys of offline servers | `["matrix.org"]` | | `log` | `string` | The log verbosity to use | `"warn"` | +| `ip_address_detection` | See the [IP address detection configuration](#ip-address-detection) | See the [IP address detection configuration](#ip-address-detection) | | `turn_username` | `string` | The TURN username | `""` | | `turn_password` | `string` | The TURN password | `""` | | `turn_uris` | `array` | The TURN URIs | `[]` | @@ -182,6 +183,26 @@ space = "1GB" ``` +### IP Address detection +The method used to detect the IP address of the origin of the connection, which is currently used +for rate limiting, but may be used for other features in the future. +Currently available methods are: +- `header` (default): Reads the value from the specified header, assuming it has the same format as the [`X-Forwarded-For` header](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/X-Forwarded-For), which is the header read by default. +> **WARNING**: This assumes that the header will always be set by your reverse proxy and cannot be overridden by connecting clients, so make sure your reverse proxy is configured to do this (Caddy does this for `X-Forwarded-For` by default). +- `socket_address`: Uses the IP address of the client connecting to Conduit directly. This does not work with reverse proxies, as it would just use the IP address of the reverse proxy, so this is only recommended for testing. + +To use a header other than `X-Forwarded-For`, set the following in your configuration: +```toml +[global] +ip_address_detection.header = "A-Different-Header" +``` + +If you instead want to use `socket_address`: +```toml +[global] +ip_address_detection = "socket_address" +``` + ### TLS The `tls` table contains the following fields: - `certs`: The path to the public PEM certificate