diff --git a/simln-lib/src/sim_node.rs b/simln-lib/src/sim_node.rs index b7f2be05..5c0e3dd5 100755 --- a/simln-lib/src/sim_node.rs +++ b/simln-lib/src/sim_node.rs @@ -845,7 +845,8 @@ pub struct HtlcRef { /// If any interceptor returns a `ForwardingError`, it triggers a shutdown signal to all other /// interceptors and waits for them to shutdown. The forwarding error is then returned. /// If any interceptor returns a `CriticalError`, it is immediately returned to trigger a -/// simulation shutdown. TODO: If a critical error happens, we could instead trigger the shutdown +/// simulation shutdown. +// If a critical error happens, we could instead trigger the shutdown /// for the interceptors and let them finish before returning. /// While waiting on the interceptors, it listens on the shutdown_listener for any signals from /// upstream and trigger shutdowns to the interceptors if needed. @@ -880,6 +881,8 @@ async fn handle_intercepted_htlc( // the HTLC. If any of the interceptors did return an error, we send a shutdown signal // to the other interceptors that may have not returned yet. let mut interceptor_failure = None; + let mut critical_error = None; + 'get_resp: loop { tokio::select! { res = intercepts.join_next() => { @@ -917,8 +920,11 @@ async fn handle_intercepted_htlc( interceptor_failure = Some(fwd_error); interceptor_trigger.trigger(); }, + // Interceptor returned a CriticalError, Trigger a shutdown and store the error + // to return after all interceptors have finished. Err(e) => { - return Err(e); + critical_error = Some(e); + interceptor_trigger.trigger(); }, } } @@ -933,6 +939,11 @@ async fn handle_intercepted_htlc( } } + // if we have a critical error, returned it after all interceptors have finished. + if let Some(e) = critical_error { + return Err(e); + } + if let Some(e) = interceptor_failure { return Ok(Err(e)); } @@ -2470,6 +2481,38 @@ mod tests { assert!(response.unwrap().unwrap() == CustomRecords::from([(1000, vec![1])])); } + /// Tests intercepted htlc with a critical error from one interceptor. + #[tokio::test] + async fn test_intercepted_htlc_critical_error() { + // Interceptor that will return a critical error. + let mut mock_interceptor = MockTestInterceptor::new(); + mock_interceptor + .expect_intercept_htlc() + .returning(|_| Err(CriticalError::InterceptorError("critical failure".into()))); + mock_interceptor + .expect_notify_resolution() + .returning(|_| Ok(())); + + let (interceptor_trigger, interceptor_listener) = triggered::trigger(); + let mock_request = create_intercept_request(interceptor_listener); + + let mock_interceptor: Arc = Arc::new(mock_interceptor); + let interceptors = vec![mock_interceptor]; + let (_, shutdown_listener) = triggered::trigger(); + + let result = handle_intercepted_htlc( + mock_request, + &interceptors, + interceptor_trigger, + shutdown_listener, + ) + .await; + + assert!(result.is_err()); + let err_string = format!("{:?}", result.unwrap_err()); + assert!(err_string.contains("critical failure")); + } + /// Tests a long resolving interceptor gets correctly interrupted during a shutdown. #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_shutdown_intercepted_htlc() {