add some interruption points in recursive event handling to prevent shutdown hangs

Signed-off-by: Jason Volk <jason@zemos.net>
Jason Volk 2024-10-22 07:15:28 +00:00 committed by strawberry
parent dd6621a720
commit b08c1241a8
3 changed files with 15 additions and 4 deletions
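
The pattern behind the change is small: each long-running or recursive event-handling loop checks the server's shutdown flag once per iteration and returns an error instead of grinding through the remaining work. The sketch below is a self-contained illustration of that shape only; the names (Server, check_running, process_events) and the plain string error type are simplified stand-ins, not the actual conduwuit API. In the real diff the same check is threaded through handle_pdus and the event handler's prev-event loops, as shown further down.

use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc,
};

// Simplified stand-in for the real server state; the actual code uses the
// crate's Result alias and the err!/debug_warn! macros for its error.
struct Server {
    stopping: AtomicBool,
}

impl Server {
    fn running(&self) -> bool {
        !self.stopping.load(Ordering::Acquire)
    }

    // Interruption point: Ok while running, Err once a shutdown was requested,
    // so callers can bail out with `?`.
    fn check_running(&self) -> Result<(), &'static str> {
        if self.running() {
            Ok(())
        } else {
            Err("Server is shutting down.")
        }
    }
}

// A long-running loop checks the flag at the top of every iteration instead of
// only noticing the shutdown after the whole batch has been processed.
fn process_events(server: &Arc<Server>, events: &[&str]) -> Result<(), &'static str> {
    for event in events {
        server.check_running()?;
        println!("handling {event}");
    }
    Ok(())
}

fn main() {
    let server = Arc::new(Server { stopping: AtomicBool::new(false) });
    process_events(&server, &["$a:example.org", "$b:example.org"]).unwrap();

    server.stopping.store(true, Ordering::Release);
    assert!(process_events(&server, &["$c:example.org"]).is_err());
}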


@@ -71,7 +71,7 @@ pub(crate) async fn send_transaction_message_route(
         "Starting txn",
     );
-    let resolved_map = handle_pdus(&services, &client, &body.pdus, origin, &txn_start_time).await;
+    let resolved_map = handle_pdus(&services, &client, &body.pdus, origin, &txn_start_time).await?;
     handle_edus(&services, &client, &body.edus, origin).await;
     debug!(
@@ -93,7 +93,7 @@ pub(crate) async fn send_transaction_message_route(
 async fn handle_pdus(
     services: &Services, _client: &IpAddr, pdus: &[Box<RawJsonValue>], origin: &ServerName, txn_start_time: &Instant,
-) -> ResolvedMap {
+) -> Result<ResolvedMap> {
     let mut parsed_pdus = Vec::with_capacity(pdus.len());
     for pdu in pdus {
         parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu).await {
@@ -110,6 +110,7 @@ async fn handle_pdus(
     let mut resolved_map = BTreeMap::new();
     for (event_id, value, room_id) in parsed_pdus {
+        services.server.check_running()?;
         let pdu_start_time = Instant::now();
         let mutex_lock = services
             .rooms
@ -143,7 +144,7 @@ async fn handle_pdus(
} }
} }
resolved_map Ok(resolved_map)
} }
async fn handle_edus(services: &Services, client: &IpAddr, edus: &[Raw<Edu>], origin: &ServerName) { async fn handle_edus(services: &Services, client: &IpAddr, edus: &[Raw<Edu>], origin: &ServerName) {


@@ -5,7 +5,7 @@ use std::{
 use tokio::{runtime, sync::broadcast};
-use crate::{config::Config, log::Log, metrics::Metrics, Err, Result};
+use crate::{config::Config, err, log::Log, metrics::Metrics, Err, Result};
 /// Server runtime state; public portion
 pub struct Server {
@@ -107,6 +107,13 @@ impl Server {
             .expect("runtime handle available in Server")
     }
+    #[inline]
+    pub fn check_running(&self) -> Result {
+        self.running()
+            .then_some(())
+            .ok_or_else(|| err!(debug_warn!("Server is shutting down.")))
+    }
+
     #[inline]
     pub fn running(&self) -> bool { !self.stopping.load(Ordering::Acquire) }
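
The new method maps the boolean running() onto a Result so callers can use `?`. Below is a standalone illustration of that combinator chain only; the real method returns the crate's Result alias and builds its error with err!(debug_warn!(..)), so the plain String error here is a stand-in.

// bool::then_some(()) yields Some(()) when the bool is true and None when it
// is false; ok_or_else then turns None into an error. Net effect: Ok(())
// while running, Err(..) once shutdown has been requested.
fn check(running: bool) -> Result<(), String> {
    running
        .then_some(())
        .ok_or_else(|| String::from("Server is shutting down."))
}

fn main() {
    assert_eq!(check(true), Ok(()));
    assert!(check(false).is_err());
}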


@@ -205,6 +205,7 @@ impl Service {
         debug!(events = ?sorted_prev_events, "Got previous events");
         for prev_id in sorted_prev_events {
+            self.services.server.check_running()?;
             match self
                 .handle_prev_pdu(
                     origin,
@@ -1268,6 +1269,8 @@ impl Service {
         let mut amount = 0;
         while let Some(prev_event_id) = todo_outlier_stack.pop() {
+            self.services.server.check_running()?;
+
             if let Some((pdu, mut json_opt)) = self
                 .fetch_and_handle_outliers(origin, &[prev_event_id.clone()], create_event, room_id, room_version_id)
                 .boxed()
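
These two additions guard the loops most likely to outlive a shutdown request: the pass over sorted_prev_events and the drain of todo_outlier_stack while fetching missing previous events. Because both loops sit inside functions that already return Result, the `?` unwinds the whole recursive fetch promptly. The following is a rough sketch of that worklist shape under simplified assumptions; drain_prev_events and its signature are illustrative, not the real service API.

use std::collections::VecDeque;

// Illustrative worklist loop in the spirit of draining todo_outlier_stack;
// names and types are simplified, not the real service code.
fn drain_prev_events(running: fn() -> bool, start: &str) -> Result<Vec<String>, String> {
    let mut todo: VecDeque<String> = VecDeque::from([start.to_owned()]);
    let mut handled = Vec::new();

    while let Some(event_id) = todo.pop_front() {
        // Interruption point: without it, a long chain of missing prev events
        // keeps this loop (and the transaction driving it) alive through shutdown.
        if !running() {
            return Err("Server is shutting down.".into());
        }

        // ... fetch `event_id` here and push its unknown prev_events onto `todo` ...
        handled.push(event_id);
    }

    Ok(handled)
}

fn main() {
    let result = drain_prev_events(|| true, "$root:example.org");
    println!("{result:?}");
}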