Browse Source

feat: make IP address detection method configurable

rate-limiting
Matthias Ahouansou 5 months ago
parent
commit
1f46c677f4
No known key found for this signature in database
  1. 2
      Cargo.lock
  2. 18
      conduit-config/src/lib.rs
  3. 1
      conduit/Cargo.toml
  4. 28
      conduit/src/api/ruma_wrapper/axum.rs
  5. 4
      conduit/src/main.rs
  6. 21
      docs/configuration.md

2
Cargo.lock generated

@ -171,6 +171,7 @@ dependencies = [
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper 1.0.2",
"tokio",
"tower 0.5.2",
"tower-layer",
"tower-service",
@ -3617,6 +3618,7 @@ dependencies = [
"futures-util",
"pin-project-lite",
"sync_wrapper 1.0.2",
"tokio",
"tower-layer",
"tower-service",
]

18
conduit-config/src/lib.rs

@ -81,6 +81,8 @@ pub struct IncompleteConfig {
pub trusted_servers: Vec<OwnedServerName>,
#[serde(default = "default_log")]
pub log: String,
#[serde(default)]
pub ip_address_detection: IpAddrDetection,
pub turn_username: Option<String>,
pub turn_password: Option<String>,
pub turn_uris: Option<Vec<String>>,
@ -137,6 +139,7 @@ pub struct Config {
pub jwt_secret: Option<String>,
pub trusted_servers: Vec<OwnedServerName>,
pub log: String,
pub ip_address_detection: IpAddrDetection,
pub turn: Option<TurnConfig>,
@ -183,6 +186,7 @@ impl From<IncompleteConfig> for Config {
jwt_secret,
trusted_servers,
log,
ip_address_detection,
turn_username,
turn_password,
turn_uris,
@ -288,6 +292,7 @@ impl From<IncompleteConfig> for Config {
jwt_secret,
trusted_servers,
log,
ip_address_detection,
turn,
media,
emergency_password,
@ -655,6 +660,19 @@ pub struct S3MediaBackend {
pub directory_structure: DirectoryStructure,
}
#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum IpAddrDetection {
SocketAddress,
Header(String),
}
impl Default for IpAddrDetection {
fn default() -> Self {
Self::Header("X-Forwarded-For".to_owned())
}
}
#[cfg(any(feature = "sqlite", feature = "rocksdb"))]
impl std::fmt::Display for Config {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {

1
conduit/Cargo.toml

@ -26,6 +26,7 @@ axum = { version = "0.8", default-features = false, features = [
"http2",
"json",
"matched-path",
"tokio",
], optional = true }
axum-extra = { version = "0.10", features = ["typed-header"] }
axum-server = { version = "0.7", features = ["tls-rustls"] }

28
conduit/src/api/ruma_wrapper/axum.rs

@ -1,9 +1,15 @@
use std::{collections::BTreeMap, error::Error as _, iter::FromIterator, str};
use std::{
collections::BTreeMap,
error::Error as _,
iter::FromIterator,
net::{IpAddr, SocketAddr},
str,
};
use axum::{
RequestPartsExt,
body::Body,
extract::{FromRequest, Path},
extract::{ConnectInfo, FromRequest, Path},
response::{IntoResponse, Response},
};
use axum_extra::{
@ -24,7 +30,9 @@ use serde::Deserialize;
use tracing::{debug, error, warn};
use super::{Ruma, RumaResponse};
use crate::{Error, Result, service::appservice::RegistrationInfo, services};
use crate::{
Error, Result, config::IpAddrDetection, service::appservice::RegistrationInfo, services,
};
enum Token {
Appservice(Box<RegistrationInfo>),
@ -99,6 +107,20 @@ where
None => query_params.access_token.as_deref(),
};
let sender_ip_address: Option<IpAddr> =
match &services().globals.config.ip_address_detection {
IpAddrDetection::SocketAddress => {
let addr: ConnectInfo<SocketAddr> = parts.extract().await?;
Some(addr.ip())
}
IpAddrDetection::Header(name) => parts
.headers
.get(name)
.and_then(|header| header.to_str().ok())
.map(|header| header.split_once(',').map(|(ip, _)| ip).unwrap_or(header))
.and_then(|ip| IpAddr::from_str(ip).ok()),
};
let token = if let Some(token) = token {
if let Some(reg_info) = services().appservice.find_from_token(token).await {
Token::Appservice(Box::new(reg_info.clone()))

4
conduit/src/main.rs

@ -271,7 +271,9 @@ async fn run_server() -> io::Result<()> {
)
.layer(map_response(set_csp_header));
let app = routes(config).layer(middlewares).into_make_service();
let app = routes(config)
.layer(middlewares)
.into_make_service_with_connect_info::<SocketAddr>();
let handle = ServerHandle::new();
tokio::spawn(shutdown_signal(handle.clone()));

21
docs/configuration.md

@ -52,6 +52,7 @@ The `global` section contains the following fields:
| `jwt_secret` | `string` | The secret used in the JWT to enable JWT login without it a 400 error will be returned | N/A |
| `trusted_servers` | `array` | The list of trusted servers to gather public keys of offline servers | `["matrix.org"]` |
| `log` | `string` | The log verbosity to use | `"warn"` |
| `ip_address_detection` | See the [IP address detection configuration](#ip-address-detection) | See the [IP address detection configuration](#ip-address-detection) |
| `turn_username` | `string` | The TURN username | `""` |
| `turn_password` | `string` | The TURN password | `""` |
| `turn_uris` | `array` | The TURN URIs | `[]` |
@ -182,6 +183,26 @@ space = "1GB"
```
### IP Address detection
The method used to detect the IP address of the origin of the connection, which is currently used
for rate limiting, but may be used for other features in the future.
Currently available methods are:
- `header` (default): Reads the value from the specified header, assuming it has the same format as the [`X-Forwarded-For` header](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/X-Forwarded-For), which is the header read by default.
> **WARNING**: This assumes that the header will always be set by your reverse proxy and cannot be overridden by connecting clients, so make sure your reverse proxy is configured to do this (Caddy does this for `X-Forwarded-For` by default).
- `socket_address`: Uses the IP address of the client connecting to Conduit directly. This does not work with reverse proxies, as it would just use the IP address of the reverse proxy, so this is only recommended for testing.
To use a header other than `X-Forwarded-For`, set the following in your configuration:
```toml
[global]
ip_address_detection.header = "A-Different-Header"
```
If you instead want to use `socket_address`:
```toml
[global]
ip_address_detection = "socket_address"
```
### TLS
The `tls` table contains the following fields:
- `certs`: The path to the public PEM certificate

Loading…
Cancel
Save