From 0bade5317fe38ff87acc69d2e556727610153c35 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 6 Jun 2024 22:31:52 +0000 Subject: [PATCH] add connection info to router Signed-off-by: Jason Volk --- src/router/layers.rs | 7 ++----- src/router/serve/mod.rs | 4 ++-- src/router/serve/plain.rs | 5 +++-- src/router/serve/tls.rs | 5 +++-- src/router/serve/unix.rs | 24 +++++++++++++++++------- 5 files changed, 27 insertions(+), 18 deletions(-) diff --git a/src/router/layers.rs b/src/router/layers.rs index 38d4ca97..c09f7e87 100644 --- a/src/router/layers.rs +++ b/src/router/layers.rs @@ -25,7 +25,7 @@ const CONDUWUIT_CSP: &str = "sandbox; default-src 'none'; font-src 'none'; scrip form-action 'none'; base-uri 'none';"; const CONDUWUIT_PERMISSIONS_POLICY: &str = "interest-cohort=(),browsing-topics=()"; -pub(crate) fn build(server: &Arc) -> io::Result> { +pub(crate) fn build(server: &Arc) -> io::Result { let layers = ServiceBuilder::new(); #[cfg(feature = "sentry_telemetry")] @@ -74,10 +74,7 @@ pub(crate) fn build(server: &Arc) -> io::Result, app: IntoMakeService, handle: ServerHandle, shutdown: broadcast::Receiver<()>, + server: &Arc, app: Router, handle: ServerHandle, shutdown: broadcast::Receiver<()>, ) -> Result<(), Error> { let config = &server.config; let addrs = config.get_bind_addrs(); diff --git a/src/router/serve/plain.rs b/src/router/serve/plain.rs index 339f8940..b79d342d 100644 --- a/src/router/serve/plain.rs +++ b/src/router/serve/plain.rs @@ -3,15 +3,16 @@ use std::{ sync::{atomic::Ordering, Arc}, }; -use axum::{routing::IntoMakeService, Router}; +use axum::Router; use axum_server::{bind, Handle as ServerHandle}; use conduit::{debug_info, Result, Server}; use tokio::task::JoinSet; use tracing::info; pub(super) async fn serve( - server: &Arc, app: IntoMakeService, handle: ServerHandle, addrs: Vec, + server: &Arc, app: Router, handle: ServerHandle, addrs: Vec, ) -> Result<()> { + let app = app.into_make_service_with_connect_info::(); let mut join_set = JoinSet::new(); for addr in &addrs { join_set.spawn_on(bind(*addr).handle(handle.clone()).serve(app.clone()), server.runtime()); diff --git a/src/router/serve/tls.rs b/src/router/serve/tls.rs index e4edeb32..6f58ce82 100644 --- a/src/router/serve/tls.rs +++ b/src/router/serve/tls.rs @@ -1,6 +1,6 @@ use std::{net::SocketAddr, sync::Arc}; -use axum::{routing::IntoMakeService, Router}; +use axum::Router; use axum_server::{bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; #[cfg(feature = "axum_dual_protocol")] use axum_server_dual_protocol::ServerExt; @@ -9,7 +9,7 @@ use tokio::task::JoinSet; use tracing::{debug, info, warn}; pub(super) async fn serve( - server: &Arc, app: IntoMakeService, handle: ServerHandle, addrs: Vec, + server: &Arc, app: Router, handle: ServerHandle, addrs: Vec, ) -> Result<()> { let config = &server.config; let tls = config.tls.as_ref().expect("TLS configuration"); @@ -31,6 +31,7 @@ pub(super) async fn serve( } let mut join_set = JoinSet::new(); + let app = app.into_make_service_with_connect_info::(); if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol { #[cfg(feature = "axum_dual_protocol")] for addr in &addrs { diff --git a/src/router/serve/unix.rs b/src/router/serve/unix.rs index a526ba29..6c406d28 100644 --- a/src/router/serve/unix.rs +++ b/src/router/serve/unix.rs @@ -1,8 +1,15 @@ #![cfg(unix)] -use std::{path::Path, sync::Arc}; +use std::{ + net::{self, IpAddr, Ipv4Addr}, + path::Path, + sync::Arc, +}; -use axum::{extract::Request, routing::IntoMakeService, Router}; +use axum::{ + extract::{connect_info::IntoMakeServiceWithConnectInfo, Request}, + Router, +}; use conduit::{debug_error, trace, utils, Error, Result, Server}; use hyper::{body::Incoming, service::service_fn}; use hyper_util::{ @@ -19,12 +26,15 @@ use tower::{Service, ServiceExt}; use tracing::{debug, info, warn}; use utils::unwrap_infallible; +type MakeService = IntoMakeServiceWithConnectInfo; + +static NULL_ADDR: net::SocketAddr = net::SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0); + #[tracing::instrument(skip_all)] -pub(super) async fn serve( - server: &Arc, app: IntoMakeService, mut shutdown: broadcast::Receiver<()>, -) -> Result<()> { +pub(super) async fn serve(server: &Arc, app: Router, mut shutdown: broadcast::Receiver<()>) -> Result<()> { let mut tasks = JoinSet::<()>::new(); let executor = TokioExecutor::new(); + let app = app.into_make_service_with_connect_info::(); let builder = server::conn::auto::Builder::new(executor); let listener = init(server).await?; loop { @@ -46,14 +56,14 @@ pub(super) async fn serve( #[allow(clippy::let_underscore_must_use)] async fn accept( - server: &Arc, listener: &UnixListener, tasks: &mut JoinSet<()>, mut app: IntoMakeService, + server: &Arc, listener: &UnixListener, tasks: &mut JoinSet<()>, mut app: MakeService, builder: server::conn::auto::Builder, conn: (UnixStream, SocketAddr), ) { let (socket, remote) = conn; let socket = TokioIo::new(socket); trace!(?listener, ?socket, ?remote, "accepted"); - let called = unwrap_infallible(app.call(()).await); + let called = unwrap_infallible(app.call(NULL_ADDR).await); let handler = service_fn(move |req: Request| called.clone().oneshot(req)); let task = async move {