From 08edef5cec2d5500e002b35bec00ec0bdce942de Mon Sep 17 00:00:00 2001 From: David Steiner Date: Mon, 19 Jan 2026 13:48:35 +0100 Subject: [PATCH 1/9] Remove unwraps during session creation --- Cargo.lock | 3 + crates/hotfix-dictionary/Cargo.toml | 1 + crates/hotfix-dictionary/src/dictionary.rs | 9 +-- crates/hotfix-dictionary/src/error.rs | 10 +++ crates/hotfix-dictionary/src/lib.rs | 2 + crates/hotfix-dictionary/src/quickfix.rs | 49 +++++-------- crates/hotfix/src/initiator.rs | 15 ++-- crates/hotfix/src/session.rs | 25 +++---- crates/hotfix/src/session/session_ref.rs | 14 ++-- crates/hotfix/tests/common/setup.rs | 3 +- examples/load-testing/Cargo.toml | 1 + examples/load-testing/src/main.rs | 29 ++++---- examples/simple-new-order/Cargo.toml | 1 + examples/simple-new-order/src/main.rs | 82 ++++++++++++++-------- 14 files changed, 140 insertions(+), 104 deletions(-) create mode 100644 crates/hotfix-dictionary/src/error.rs diff --git a/Cargo.lock b/Cargo.lock index a056672..8ddaef2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1347,6 +1347,7 @@ dependencies = [ "smartstring", "strum", "strum_macros", + "thiserror", ] [[package]] @@ -1833,6 +1834,7 @@ checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" name = "load-testing" version = "0.1.0" dependencies = [ + "anyhow", "async-trait", "clap", "hotfix", @@ -3152,6 +3154,7 @@ dependencies = [ name = "simple-new-order" version = "0.1.0" dependencies = [ + "anyhow", "async-trait", "axum", "clap", diff --git a/crates/hotfix-dictionary/Cargo.toml b/crates/hotfix-dictionary/Cargo.toml index 79c322f..1362ec0 100644 --- a/crates/hotfix-dictionary/Cargo.toml +++ b/crates/hotfix-dictionary/Cargo.toml @@ -33,3 +33,4 @@ roxmltree.workspace = true smartstring = { workspace = true, optional = true } strum.workspace = true strum_macros.workspace = true +thiserror.workspace = true \ No newline at end of file diff --git a/crates/hotfix-dictionary/src/dictionary.rs b/crates/hotfix-dictionary/src/dictionary.rs index 933a413..acd6642 100644 --- a/crates/hotfix-dictionary/src/dictionary.rs +++ b/crates/hotfix-dictionary/src/dictionary.rs @@ -1,7 +1,8 @@ use crate::{Component, ComponentData, Datatype, DatatypeData, Field, FieldData}; +use crate::error::ParseError; use crate::message_definition::{MessageData, MessageDefinition}; -use crate::quickfix::{ParseDictionaryError, QuickFixReader}; +use crate::quickfix::QuickFixReader; use crate::string::SmartString; use fnv::FnvHashMap; @@ -52,9 +53,9 @@ impl Dictionary { /// Attempts to read a QuickFIX-style specification file and convert it into /// a [`Dictionary`]. - pub fn from_quickfix_spec(input: &str) -> Result { + pub fn from_quickfix_spec(input: &str) -> Result { let xml_document = - roxmltree::Document::parse(input).map_err(|_| ParseDictionaryError::InvalidFormat)?; + roxmltree::Document::parse(input).map_err(|_| ParseError::InvalidFormat)?; QuickFixReader::new(&xml_document) } @@ -71,7 +72,7 @@ impl Dictionary { self.version.as_str() } - pub fn load_from_file(path: &str) -> Result { + pub fn load_from_file(path: &str) -> Result { let spec = std::fs::read_to_string(path) .unwrap_or_else(|_| panic!("unable to read FIX dictionary file at {path}")); Dictionary::from_quickfix_spec(&spec) diff --git a/crates/hotfix-dictionary/src/error.rs b/crates/hotfix-dictionary/src/error.rs new file mode 100644 index 0000000..e143ded --- /dev/null +++ b/crates/hotfix-dictionary/src/error.rs @@ -0,0 +1,10 @@ +pub(crate) type ParseResult = Result; + +/// The error type that can arise when decoding a QuickFIX Dictionary. +#[derive(Clone, Debug, thiserror::Error)] +pub enum ParseError { + #[error("invalid format")] + InvalidFormat, + #[error("invalid data: {0}")] + InvalidData(String), +} diff --git a/crates/hotfix-dictionary/src/lib.rs b/crates/hotfix-dictionary/src/lib.rs index 83bce7b..e580be9 100644 --- a/crates/hotfix-dictionary/src/lib.rs +++ b/crates/hotfix-dictionary/src/lib.rs @@ -4,6 +4,7 @@ mod builder; mod component; mod datatype; mod dictionary; +mod error; mod field; mod layout; mod message_definition; @@ -14,6 +15,7 @@ use component::{Component, ComponentData}; use datatype::DatatypeData; pub use datatype::{Datatype, FixDatatype}; pub use dictionary::Dictionary; +pub use error::ParseError; pub use field::{Field, FieldEnum, FieldLocation, IsFieldDefinition}; use field::{FieldData, FieldEnumData}; use fnv::FnvHashMap; diff --git a/crates/hotfix-dictionary/src/quickfix.rs b/crates/hotfix-dictionary/src/quickfix.rs index 9b587d1..8df9e54 100644 --- a/crates/hotfix-dictionary/src/quickfix.rs +++ b/crates/hotfix-dictionary/src/quickfix.rs @@ -1,5 +1,6 @@ use crate::builder::DictionaryBuilder; use crate::component::{ComponentData, FixmlComponentAttributes}; +use crate::error::{ParseError, ParseResult}; use crate::message_definition::MessageData; use crate::string::SmartString; use crate::{ @@ -29,7 +30,7 @@ impl<'a> QuickFixReader<'a> { if child.is_element() { let name = child .attribute("name") - .ok_or(ParseDictionaryError::InvalidFormat)? + .ok_or(ParseError::InvalidFormat)? .to_string(); import_component(&mut reader.builder, child, &name)?; } @@ -61,23 +62,17 @@ impl<'a> QuickFixReader<'a> { let find_tagged_child = |tag: &str| { root.children() .find(|n| n.has_tag_name(tag)) - .ok_or_else(|| ParseDictionaryError::InvalidData(format!("<{tag}> tag not found"))) + .ok_or_else(|| ParseError::InvalidData(format!("<{tag}> tag not found"))) }; let version_type = root .attribute("type") - .ok_or(ParseDictionaryError::InvalidData( - "No version attribute.".to_string(), - ))?; - let version_major = root - .attribute("major") - .ok_or(ParseDictionaryError::InvalidData( - "No major version attribute.".to_string(), - ))?; - let version_minor = root - .attribute("minor") - .ok_or(ParseDictionaryError::InvalidData( - "No minor version attribute.".to_string(), - ))?; + .ok_or(ParseError::InvalidData("No version attribute.".to_string()))?; + let version_major = root.attribute("major").ok_or(ParseError::InvalidData( + "No major version attribute.".to_string(), + ))?; + let version_minor = root.attribute("minor").ok_or(ParseError::InvalidData( + "No minor version attribute.".to_string(), + ))?; let version_sp = root.attribute("servicepack").unwrap_or("0"); let version = format!( "{}.{}.{}{}", @@ -104,19 +99,19 @@ impl<'a> QuickFixReader<'a> { fn import_field(builder: &mut DictionaryBuilder, node: roxmltree::Node) -> ParseResult<()> { if node.tag_name().name() != "field" { - return Err(ParseDictionaryError::InvalidFormat); + return Err(ParseError::InvalidFormat); } let data_type_name = import_datatype(builder, node); let value_restrictions = value_restrictions_from_node(node, data_type_name.clone()); let name = node .attribute("name") - .ok_or(ParseDictionaryError::InvalidFormat)? + .ok_or(ParseError::InvalidFormat)? .into(); let tag = node .attribute("number") - .ok_or(ParseDictionaryError::InvalidFormat)? + .ok_or(ParseError::InvalidFormat)? .parse() - .map_err(|_| ParseDictionaryError::InvalidFormat)?; + .map_err(|_| ParseError::InvalidFormat)?; let field = FieldData { name, tag, @@ -142,11 +137,11 @@ fn import_message(builder: &mut DictionaryBuilder, node: roxmltree::Node) -> Par let message = MessageData { name: node .attribute("name") - .ok_or(ParseDictionaryError::InvalidFormat)? + .ok_or(ParseError::InvalidFormat)? .into(), msg_type: node .attribute("msgtype") - .ok_or(ParseDictionaryError::InvalidFormat)? + .ok_or(ParseError::InvalidFormat)? .into(), component_id: 0, layout_items, @@ -289,7 +284,7 @@ fn import_layout_item( } } _ => { - return Err(ParseDictionaryError::InvalidFormat); + return Err(ParseError::InvalidFormat); } }; let item = LayoutItemData { required, kind }; @@ -306,13 +301,3 @@ fn panic_missing_tag_in_element(elem: roxmltree::Node, tag: &str) -> ! { .unwrap_or("Error retrieving element text") ); } - -type ParseError = ParseDictionaryError; -type ParseResult = Result; - -/// The error type that can arise when decoding a QuickFIX Dictionary. -#[derive(Clone, Debug)] -pub enum ParseDictionaryError { - InvalidFormat, - InvalidData(String), -} diff --git a/crates/hotfix/src/initiator.rs b/crates/hotfix/src/initiator.rs index a9360e7..e8539cd 100644 --- a/crates/hotfix/src/initiator.rs +++ b/crates/hotfix/src/initiator.rs @@ -6,6 +6,7 @@ //! The initiator establishes the transport layer connection with //! the peer, and sends the initial Logon (35=A) message. For transport, //! `HotFIX` supports plain TCP and encrypted TLS over TCP connections. +use anyhow::Result; use std::time::Duration; use tokio::sync::watch; use tokio::time::sleep; @@ -30,8 +31,8 @@ impl Initiator { config: SessionConfig, application: impl Application, store: impl MessageStore + Send + Sync + 'static, - ) -> Self { - let session_ref = InternalSessionRef::new(config.clone(), application, store); + ) -> Result { + let session_ref = InternalSessionRef::new(config.clone(), application, store)?; let (completion_tx, completion_rx) = watch::channel(false); tokio::spawn({ @@ -40,14 +41,16 @@ impl Initiator { establish_connection(config, session_ref, completion_tx) }); - Self { + let initiator = Self { config, session_handle: session_ref.into(), completion_rx, - } + }; + + Ok(initiator) } - pub async fn send_message(&self, msg: Outbound) -> anyhow::Result<()> { + pub async fn send_message(&self, msg: Outbound) -> Result<()> { self.session_handle.send_message(msg).await?; Ok(()) @@ -61,7 +64,7 @@ impl Initiator { self.session_handle.clone() } - pub async fn shutdown(self, reconnect: bool) -> anyhow::Result<()> { + pub async fn shutdown(self, reconnect: bool) -> Result<()> { self.session_handle.shutdown(reconnect).await?; tokio::time::timeout(Duration::from_secs(5), self.wait_for_shutdown()).await?; diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index 2b6f77a..86e48c9 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -13,7 +13,7 @@ use crate::message::parser::RawFixMessage; use crate::message::{InboundMessage, generate_message}; use crate::store::MessageStore; use crate::transport::writer::WriterRef; -use anyhow::{Result, anyhow}; +use anyhow::{Result, anyhow, bail}; use chrono::Utc; use hotfix_message::dict::Dictionary; use hotfix_message::message::{Config as MessageConfig, Message}; @@ -78,16 +78,15 @@ where config: SessionConfig, application: App, store: Store, - ) -> Session { + ) -> Result> { let schedule_check_timer = sleep(Duration::from_secs(SCHEDULE_CHECK_INTERVAL)); - let dictionary = Self::get_data_dictionary(&config); + let dictionary = Self::get_data_dictionary(&config)?; let message_config = MessageConfig::default(); - let message_builder = MessageBuilder::new(dictionary, message_config) - .expect("failed to create message builder"); - let schedule = config.schedule.as_ref().try_into().unwrap(); + let message_builder = MessageBuilder::new(dictionary, message_config)?; + let schedule = config.schedule.as_ref().try_into()?; - Self { + let session = Self { config, schedule, message_config, @@ -98,17 +97,19 @@ where schedule_check_timer: Box::pin(schedule_check_timer), reset_on_next_logon: false, _phantom: std::marker::PhantomData, - } + }; + + Ok(session) } - fn get_data_dictionary(config: &SessionConfig) -> Dictionary { + fn get_data_dictionary(config: &SessionConfig) -> Result { match &config.data_dictionary_path { None => match config.begin_string.as_str() { #[cfg(feature = "fix44")] - "FIX.4.4" => Dictionary::fix44(), - _ => panic!("unsupported begin string: {}", config.begin_string), + "FIX.4.4" => Ok(Dictionary::fix44()), + _ => bail!("unsupported begin string: {}", config.begin_string), }, - Some(dictionary_path) => Dictionary::load_from_file(dictionary_path).unwrap(), + Some(dictionary_path) => Ok(Dictionary::load_from_file(dictionary_path)?), } } diff --git a/crates/hotfix/src/session/session_ref.rs b/crates/hotfix/src/session/session_ref.rs index 8a449d8..4c22b19 100644 --- a/crates/hotfix/src/session/session_ref.rs +++ b/crates/hotfix/src/session/session_ref.rs @@ -1,3 +1,7 @@ +use anyhow::Result; +use tokio::sync::{mpsc, oneshot}; +use tracing::debug; + use crate::config::SessionConfig; use crate::message::{InboundMessage, OutboundMessage, RawFixMessage}; use crate::session::Session; @@ -6,8 +10,6 @@ use crate::session::event::{AwaitingActiveSessionResponse, SessionEvent}; use crate::store::MessageStore; use crate::transport::writer::WriterRef; use crate::{Application, session}; -use tokio::sync::{mpsc, oneshot}; -use tracing::debug; #[derive(Clone)] pub struct InternalSessionRef { @@ -21,11 +23,11 @@ impl InternalSessionRef { config: SessionConfig, application: impl Application, store: impl MessageStore + Send + Sync + 'static, - ) -> Self { + ) -> Result { let (event_sender, event_receiver) = mpsc::channel::(100); let (outbound_message_sender, outbound_message_receiver) = mpsc::channel::(10); let (admin_request_sender, admin_request_receiver) = mpsc::channel::(10); - let session = Session::new(config, application, store); + let session = Session::new(config, application, store)?; tokio::spawn(session::run_session( session, event_receiver, @@ -33,11 +35,11 @@ impl InternalSessionRef { admin_request_receiver, )); - Self { + Ok(Self { event_sender, outbound_message_sender, admin_request_sender, - } + }) } pub async fn register_writer(&self, writer: WriterRef) { diff --git a/crates/hotfix/tests/common/setup.rs b/crates/hotfix/tests/common/setup.rs index 2392f3e..9480fba 100644 --- a/crates/hotfix/tests/common/setup.rs +++ b/crates/hotfix/tests/common/setup.rs @@ -28,7 +28,8 @@ pub async fn given_a_connected_session_with_store( let counterparty_config = create_counterparty_session_config(config.clone()); let (message_tx, message_rx) = tokio::sync::mpsc::unbounded_channel(); - let session = InternalSessionRef::new(config, FakeApplication::new(message_tx), message_store); + let session = InternalSessionRef::new(config, FakeApplication::new(message_tx), message_store) + .expect("session to be created successfully"); let session_spy = SessionSpy::new(session.clone().into(), message_rx); let mock_counterparty = FakeCounterparty::start(session.clone(), counterparty_config).await; diff --git a/examples/load-testing/Cargo.toml b/examples/load-testing/Cargo.toml index 6fdf5bc..3dc1964 100644 --- a/examples/load-testing/Cargo.toml +++ b/examples/load-testing/Cargo.toml @@ -9,6 +9,7 @@ publish = false [dependencies] hotfix = { path = "../../crates/hotfix", features = ["fix44", "mongodb", "redb"] } +anyhow.workspace = true async-trait.workspace = true clap = { workspace = true, features = ["derive"] } tokio = { workspace = true, features = ["full"] } diff --git a/examples/load-testing/src/main.rs b/examples/load-testing/src/main.rs index 52001de..5686451 100644 --- a/examples/load-testing/src/main.rs +++ b/examples/load-testing/src/main.rs @@ -1,8 +1,7 @@ mod application; mod messages; -use crate::application::LoadTestingApplication; -use crate::messages::{ExecutionReport, NewOrderSingle, OutboundMsg}; +use anyhow::Result; use clap::{Parser, ValueEnum}; use hotfix::config::SessionConfig; use hotfix::field_types::{Date, Timestamp}; @@ -16,6 +15,9 @@ use tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel}; use tracing::info; use tracing_subscriber::EnvFilter; +use crate::application::LoadTestingApplication; +use crate::messages::{ExecutionReport, NewOrderSingle, OutboundMsg}; + #[derive(ValueEnum, Clone, Debug)] #[clap(rename_all = "lower")] enum Database { @@ -37,7 +39,7 @@ struct Args { const WAIT_SECONDS: u64 = 3; -fn main() { +fn main() -> Result<()> { let args = Args::parse(); let runtime = Builder::new_multi_thread() @@ -50,10 +52,12 @@ fn main() { runtime.block_on(run_load_test( args.message_count, args.database.unwrap_or(Database::Memory), - )); + ))?; + + Ok(()) } -async fn run_load_test(message_count: u32, database: Database) { +async fn run_load_test(message_count: u32, database: Database) -> Result<()> { tracing_subscriber::fmt() .pretty() .with_env_filter(EnvFilter::from_default_env()) @@ -63,7 +67,7 @@ async fn run_load_test(message_count: u32, database: Database) { let (tx, rx) = unbounded_channel(); let application = LoadTestingApplication::new(tx); - let initiator = start_session(config, database, application).await; + let initiator = start_session(config, database, application).await?; for s in 0..WAIT_SECONDS { info!("starting in {} seconds", WAIT_SECONDS - s); @@ -74,24 +78,23 @@ async fn run_load_test(message_count: u32, database: Database) { let messages_handler = tokio::spawn(submit_messages(initiator.session_handle(), message_count)); let report_handler = tokio::spawn(listen_for_reports(rx, message_count)); - messages_handler.await.unwrap(); + messages_handler.await?; info!("sent all messages, awaiting responses"); - report_handler.await.unwrap(); + report_handler.await?; let duration = start.elapsed(); info!("completed run in {duration:?} seconds"); - initiator - .shutdown(false) - .await - .expect("graceful shutdown to succeed"); + initiator.shutdown(false).await?; + + Ok(()) } async fn start_session( session_config: SessionConfig, db_config: Database, app: LoadTestingApplication, -) -> Initiator { +) -> Result> { match db_config { Database::Memory => { let store = hotfix::store::in_memory::InMemoryMessageStore::default(); diff --git a/examples/simple-new-order/Cargo.toml b/examples/simple-new-order/Cargo.toml index 43ab193..51e2458 100644 --- a/examples/simple-new-order/Cargo.toml +++ b/examples/simple-new-order/Cargo.toml @@ -10,6 +10,7 @@ publish = false hotfix = { path = "../../crates/hotfix", features = ["fix44", "mongodb", "redb"] } hotfix-web = { path = "../../crates/hotfix-web", features = ["ui"] } +anyhow.workspace = true async-trait.workspace = true axum.workspace = true clap = { workspace = true, features = ["derive"] } diff --git a/examples/simple-new-order/src/main.rs b/examples/simple-new-order/src/main.rs index e09c707..c4a3ded 100644 --- a/examples/simple-new-order/src/main.rs +++ b/examples/simple-new-order/src/main.rs @@ -1,8 +1,7 @@ mod application; mod messages; -use crate::application::TestApplication; -use crate::messages::{NewOrderSingle, OutboundMsg}; +use anyhow::{Context, Result}; use clap::{Parser, ValueEnum}; use hotfix::config::Config; use hotfix::field_types::{Date, Timestamp}; @@ -18,6 +17,9 @@ use tokio_util::sync::CancellationToken; use tracing::info; use tracing_subscriber::EnvFilter; +use crate::application::TestApplication; +use crate::messages::{NewOrderSingle, OutboundMsg}; + #[derive(ValueEnum, Clone, Debug)] #[clap(rename_all = "lower")] enum Database { @@ -37,23 +39,24 @@ struct Args { } #[tokio::main] -async fn main() { +async fn main() -> Result<()> { let args = Args::parse(); if let Some(path) = args.logfile { let p = Path::new(&path); - std::fs::create_dir_all(p.parent().unwrap()).unwrap(); + let parent = p.parent().context("log file path has no parent directory")?; + std::fs::create_dir_all(parent)?; let logfile = std::fs::OpenOptions::new() .write(true) .create(true) .truncate(true) .open(p) - .expect("log file to open successfully"); + .context("failed to open log file")?; let subscriber = tracing_subscriber::fmt::Subscriber::builder() .with_writer(logfile) .with_env_filter(EnvFilter::from_default_env()) .finish(); - tracing::subscriber::set_global_default(subscriber).unwrap(); + tracing::subscriber::set_global_default(subscriber)?; } else { tracing_subscriber::fmt() .pretty() @@ -63,40 +66,48 @@ async fn main() { let db_config = args.database.unwrap_or(Database::Redb); let app = TestApplication::default(); - let initiator = start_session(&args.config, &db_config, app).await; + let initiator = start_session(&args.config, &db_config, app).await?; let status_service_token = CancellationToken::new(); - tokio::spawn(start_web_service( - initiator.session_handle(), - status_service_token.child_token(), - )); + let session_handle = initiator.session_handle(); + let child_token = status_service_token.child_token(); + tokio::spawn(async move { + if let Err(e) = start_web_service(session_handle, child_token).await { + tracing::error!("web service error: {e:?}"); + } + }); - user_loop(&initiator).await; + user_loop(&initiator).await?; status_service_token.cancel(); initiator .shutdown(false) .await - .expect("graceful shutdown to succeed"); + .context("graceful shutdown failed")?; + Ok(()) } -async fn user_loop(session: &Initiator) { +async fn user_loop(session: &Initiator) -> Result<()> { loop { println!("(q) to quit, (s) to send message"); - let command_task = spawn_blocking(|| { + let command_task = spawn_blocking(|| -> Result { let mut input = String::new(); std::io::stdin() .read_line(&mut input) - .expect("read line to succeed"); - input + .context("failed to read line from stdin")?; + Ok(input) }); - match command_task.await.unwrap().trim() { + let input: String = command_task + .await + .context("failed to join blocking task")??; + + match input.trim() { "q" => { - return; + return Ok(()); } "s" => { - send_message(session).await; + send_message(session).await?; } _ => { println!("Unrecognised command"); @@ -105,7 +116,7 @@ async fn user_loop(session: &Initiator) { } } -async fn send_message(session: &Initiator) { +async fn send_message(session: &Initiator) -> Result<()> { let mut order_id = format!("{}", uuid::Uuid::new_v4()); order_id.truncate(12); let order = NewOrderSingle { @@ -114,7 +125,7 @@ async fn send_message(session: &Initiator) { cl_ord_id: order_id, side: fix44::Side::Buy, order_qty: 230, - settlement_date: Date::new(2023, 9, 19).unwrap(), + settlement_date: Date::new(2023, 9, 19).context("invalid settlement date")?, currency: "USD".to_string(), number_of_allocations: 1, allocation_account: "acc1".to_string(), @@ -122,32 +133,39 @@ async fn send_message(session: &Initiator) { }; let msg = OutboundMsg::NewOrderSingle(order); - session.send_message(msg).await.unwrap(); + session + .send_message(msg) + .await + .context("failed to send message")?; + Ok(()) } async fn start_session( config_path: &str, db_config: &Database, app: TestApplication, -) -> Initiator { +) -> Result> { let mut config = Config::load_from_path(config_path); - let session_config = config.sessions.pop().expect("config to include a session"); + let session_config = config + .sessions + .pop() + .context("config must include a session")?; match db_config { Database::Redb => { let store = hotfix::store::redb::RedbMessageStore::new("session.db") - .expect("be able to create store"); + .context("failed to create redb store")?; Initiator::start(session_config, app, store).await } Database::Mongodb => { let uri = "mongodb://localhost:30001"; let client = Client::with_uri_str(uri) .await - .expect("able to create client"); + .context("failed to create mongodb client")?; let store = hotfix::store::mongodb::MongoDbMessageStore::new(client.database("hotfix"), None) .await - .expect("be able to create store"); + .context("failed to create mongodb store")?; Initiator::start(session_config, app, store).await } } @@ -156,13 +174,15 @@ async fn start_session( async fn start_web_service( session_handle: SessionHandle, cancellation_token: CancellationToken, -) { +) -> Result<()> { let config = RouterConfig { enable_admin_endpoints: true, }; let router = build_router_with_config(session_handle, config); let host_and_port = std::env::var("HOST_AND_PORT").unwrap_or("0.0.0.0:9881".to_string()); - let listener = tokio::net::TcpListener::bind(&host_and_port).await.unwrap(); + let listener = tokio::net::TcpListener::bind(&host_and_port) + .await + .context("failed to bind TCP listener")?; info!("starting web interface on http://{host_and_port}"); @@ -176,4 +196,6 @@ async fn start_web_service( info!("status service cancelled"); } } + + Ok(()) } From 822e3a73b00f8dafb10aaf801eff8289103abfc6 Mon Sep 17 00:00:00 2001 From: David Steiner Date: Mon, 19 Jan 2026 14:13:25 +0100 Subject: [PATCH 2/9] Replace unwraps in resend function with error propagation --- crates/hotfix/src/session.rs | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index 86e48c9..08ef691 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -13,7 +13,7 @@ use crate::message::parser::RawFixMessage; use crate::message::{InboundMessage, generate_message}; use crate::store::MessageStore; use crate::transport::writer::WriterRef; -use anyhow::{Result, anyhow, bail}; +use anyhow::{Context, Result, anyhow, bail}; use chrono::Utc; use hotfix_message::dict::Dictionary; use hotfix_message::message::{Config as MessageConfig, Message}; @@ -430,7 +430,7 @@ where self.store.increment_target_seq_number().await?; self.resend_messages(begin_seq_number, end_seq_number, message) - .await; + .await?; Ok(()) } @@ -663,13 +663,13 @@ where }; } - async fn resend_messages(&mut self, begin: u64, end: u64, _message: &Message) { + async fn resend_messages(&mut self, begin: u64, end: u64, _message: &Message) -> Result<()> { info!(begin, end, "resending messages as requested"); let messages = self .store .get_slice(begin as usize, end as usize) .await - .unwrap(); + .context("failed to retrieve messages from store")?; let no = messages.len(); debug!(number_of_messages = no, "number of messages"); @@ -682,9 +682,18 @@ where .message_builder .build(msg.as_slice()) .into_message() - .unwrap(); - sequence_number = message.header().get(MSG_SEQ_NUM).unwrap(); - let message_type: String = message.header().get::<&str>(MSG_TYPE).unwrap().to_string(); + .with_context(|| format!("failed to build message for raw message: {msg:?}"))?; + sequence_number = message.header().get::(MSG_SEQ_NUM).map_err(|e| { + anyhow!( + "sequence number in message to resend is unexpectedly missing: {:?}", + e + ) + })?; + let message_type: String = message + .header() + .get::<&str>(MSG_TYPE) + .context("message type in message to resend is unexpectedly missing")? + .to_string(); if is_admin(message_type.as_str()) { if reset_start.is_none() { @@ -708,13 +717,16 @@ where } self.send_raw( message_type.as_bytes(), - message.encode(&self.message_config).unwrap(), + message + .encode(&self.message_config) + .context("failed to encode message")?, ) .await; if enabled!(tracing::Level::DEBUG) { - let m = String::from_utf8(msg.clone()).unwrap(); - debug!(sequence_number, message = m, "resent message"); + if let Ok(m) = String::from_utf8(msg.clone()) { + debug!(sequence_number, message = m, "resent message"); + } } } @@ -724,6 +736,8 @@ where Self::log_skipped_admin_messages(begin, end); self.send_sequence_reset(begin, end).await; } + + Ok(()) } fn log_skipped_admin_messages(begin: u64, end: u64) { From 4eabb9cde090dbf620f76cf90695618e8fb4b0fb Mon Sep 17 00:00:00 2001 From: David Steiner Date: Mon, 19 Jan 2026 14:53:15 +0100 Subject: [PATCH 3/9] Swallow errors when failing to send messages to the writer --- crates/hotfix/src/transport/writer.rs | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/crates/hotfix/src/transport/writer.rs b/crates/hotfix/src/transport/writer.rs index d36f939..348c9eb 100644 --- a/crates/hotfix/src/transport/writer.rs +++ b/crates/hotfix/src/transport/writer.rs @@ -1,5 +1,6 @@ use crate::message::parser::RawFixMessage; use tokio::sync::mpsc; +use tracing::warn; #[derive(Clone, Debug)] pub enum WriterMessage { @@ -18,16 +19,20 @@ impl WriterRef { } pub async fn send_raw_message(&self, msg: RawFixMessage) { - self.sender - .send(WriterMessage::SendMessage(msg)) - .await - .expect("be able to send message"); + if let Err(err) = self.sender.send(WriterMessage::SendMessage(msg)).await { + // If the channel is closed, the writer task has terminated. + // The session will receive a Disconnected event with the actual + // disconnection reason, so we don't need to handle the error here. + // The message we failed to send will be recovered by the counterparty + // through the built-in recovery mechanisms of FIX. + warn!("trying to send message but the writer is gone: {}", err); + } } pub async fn disconnect(&self) { - self.sender - .send(WriterMessage::Disconnect) - .await - .expect("be able to disconnect") + if let Err(err) = self.sender.send(WriterMessage::Disconnect).await { + // If the channel is closed, we're already effectively disconnected. + warn!("trying to send disconnect but the writer is gone: {}", err); + } } } From 0a6ad7164212909afe30d63f916212c56cb706f8 Mon Sep 17 00:00:00 2001 From: David Steiner Date: Mon, 19 Jan 2026 16:03:11 +0100 Subject: [PATCH 4/9] Replace unwraps with anyhow errors in session code --- crates/hotfix/src/session.rs | 190 ++++++++++++++++++-------- examples/simple-new-order/src/main.rs | 4 +- 2 files changed, 135 insertions(+), 59 deletions(-) diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index 08ef691..87d16e2 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -139,7 +139,9 @@ where let reject = Reject::new(msg_seq_num) .session_reject_reason(SessionRejectReason::InvalidTagNumber) .text(&format!("invalid field {tag}")); - self.send_message(reject).await; + self.send_message(reject) + .await + .context("failed to send reject")?; } Err(err) => { error!("failed to get message seq num: {:?}", err); @@ -161,7 +163,9 @@ where SessionRejectReason::RepeatingGroupFieldsOutOfOrder, ) .text(&format!("field appears in incorrect order:{tag}")); - self.send_message(reject).await; + self.send_message(reject) + .await + .context("failed to send reject")?; } Err(err) => { error!("failed to get message seq num: {:?}", err); @@ -240,7 +244,10 @@ where } self.store.increment_target_seq_number().await?; } - Err(err) => self.handle_verification_error(err).await, + Err(err) => self + .handle_verification_error(err) + .await + .context("failed to handle verification error")?, } Ok(()) @@ -290,14 +297,16 @@ where verify_message(message, &self.config, expected_seq_number) } - async fn on_connect(&mut self, writer: WriterRef) { + async fn on_connect(&mut self, writer: WriterRef) -> Result<()> { self.state = SessionState::AwaitingLogon { writer, logon_sent: false, logon_timeout: Instant::now() + Duration::from_secs(self.config.logon_timeout), }; self.reset_peer_timer(None); - self.send_logon().await; + self.send_logon().await?; + + Ok(()) } async fn on_disconnect(&mut self, reason: String) { @@ -327,7 +336,10 @@ where self.application.on_logon().await; self.store.increment_target_seq_number().await?; } - Err(err) => self.handle_verification_error(err).await, + Err(err) => self + .handle_verification_error(err) + .await + .context("failed to handle verification error")?, } } else { error!("received unexpected logon message"); @@ -338,7 +350,7 @@ where async fn on_logout(&mut self) -> Result<()> { if self.state.is_logged_on() { - self.send_logout("Logout acknowledged").await; + self.send_logout("Logout acknowledged").await?; } self.application.on_logout("peer has logged us out").await; @@ -378,7 +390,8 @@ where self.store.increment_target_seq_number().await?; self.send_message(Heartbeat::for_request(req_id.to_string())) - .await; + .await + .context("failed to send heartbeat in response to test request")?; Ok(()) } @@ -399,7 +412,9 @@ where ) .session_reject_reason(SessionRejectReason::RequiredTagMissing) .text("missing begin sequence number for resend request"); - self.send_message(reject).await; + self.send_message(reject) + .await + .context("failed to send reject for invalid resend request")?; return Ok(()); } }; @@ -422,7 +437,9 @@ where ) .session_reject_reason(SessionRejectReason::RequiredTagMissing) .text("missing end sequence number for resend request"); - self.send_message(reject).await; + self.send_message(reject) + .await + .context("failed to send reject for invalid resend request")?; return Ok(()); } }; @@ -456,7 +473,7 @@ where .map_err(|_| anyhow!("failed to get seq number"))?; let is_gap_fill: bool = message.get(GAP_FILL_FLAG).unwrap_or(false); if let Err(err) = self.verify_message(message, is_gap_fill) { - self.handle_verification_error(err).await; + self.handle_verification_error(err).await?; return Ok(()); } @@ -470,7 +487,9 @@ where let reject = Reject::new(msg_seq_num) .session_reject_reason(SessionRejectReason::RequiredTagMissing) .text("missing NewSeqNo tag in sequence reset message"); - self.send_message(reject).await; + self.send_message(reject).await.context( + "failed to send reject message in response to invalid sequence reset message", + )?; // note: we don't increment the target seq number here // this is an ambiguous case in the specification, but leaving the @@ -490,14 +509,16 @@ where let reject = Reject::new(msg_seq_num) .session_reject_reason(SessionRejectReason::ValueIsIncorrect) .text(&text); - self.send_message(reject).await; + self.send_message(reject).await.context( + "failed to send reject message in response to invalid sequence reset message", + )?; return Ok(()); } self.store.set_target_seq_number(end - 1).await } - async fn handle_verification_error(&mut self, error: MessageVerificationError) { + async fn handle_verification_error(&mut self, error: MessageVerificationError) -> Result<()> { match error { MessageVerificationError::SeqNumberTooLow { expected, @@ -508,7 +529,8 @@ where .await; } MessageVerificationError::SeqNumberTooHigh { expected, actual } => { - self.handle_sequence_number_too_high(expected, actual).await; + self.handle_sequence_number_too_high(expected, actual) + .await?; } MessageVerificationError::IncorrectBeginString(begin_string) => { self.handle_incorrect_begin_string(begin_string).await; @@ -542,6 +564,8 @@ where .await; } } + + Ok(()) } async fn handle_incorrect_begin_string(&mut self, received_begin_string: String) { @@ -563,7 +587,9 @@ where let reject = Reject::new(msg_seq_num) .session_reject_reason(SessionRejectReason::ValueIsIncorrect) .text(&format!("invalid comp ID {received_comp_id}")); - self.send_message(reject).await; + if let Err(err) = self.send_message(reject).await { + error!("failed to send reject message with invalid comp ID: {err}"); + }; self.logout_and_terminate("incorrect comp ID received") .await; @@ -589,7 +615,7 @@ where self.state = SessionState::new_disconnected(false, &reason); } - async fn handle_sequence_number_too_high(&mut self, expected: u64, actual: u64) { + async fn handle_sequence_number_too_high(&mut self, expected: u64, actual: u64) -> Result<()> { match self .state .try_transition_to_awaiting_resend(expected, actual) @@ -598,7 +624,9 @@ where debug!( "we are behind target (ours: {expected}, theirs: {actual}), requesting resend." ); - self.send_resend_request(expected, actual).await; + self.send_resend_request(expected, actual) + .await + .context("failed to send resend request")?; } AwaitingResendTransitionOutcome::InvalidState(reason) => { error!("failed to request resend: {reason}"); @@ -618,6 +646,8 @@ where ); } } + + Ok(()) } async fn handle_invalid_msg_type(&mut self, message: Message, msg_type: &str) { @@ -626,7 +656,9 @@ where let reject = Reject::new(msg_seq_num) .session_reject_reason(SessionRejectReason::InvalidMsgtype) .text(&format!("invalid message type {msg_type}")); - self.send_message(reject).await; + if let Err(err) = self.send_message(reject).await { + error!("failed to send reject message for invalid msgtype: {err}"); + }; #[allow(clippy::collapsible_if)] if let Ok(seq_num) = message.header().get::(MSG_SEQ_NUM) @@ -647,7 +679,9 @@ where let reject = Reject::new(msg_seq_num) .session_reject_reason(SessionRejectReason::SendingtimeAccuracyProblem) .text(text); - self.send_message(reject).await; + if let Err(err) = self.send_message(reject).await { + error!("failed to send reject for time accuracy problem: {err}"); + }; if let Err(err) = self.store.increment_target_seq_number().await { error!("failed to increment target seq number: {:?}", err); }; @@ -657,7 +691,9 @@ where let reject = Reject::new(msg_seq_num) .session_reject_reason(SessionRejectReason::RequiredTagMissing) .text("original sending time is required"); - self.send_message(reject).await; + if let Err(err) = self.send_message(reject).await { + error!("failed to send reject for time missing tag: {err}"); + }; if let Err(err) = self.store.increment_target_seq_number().await { error!("failed to increment target seq number: {:?}", err); }; @@ -705,7 +741,9 @@ where if let Some(begin) = reset_start { let end = sequence_number; Self::log_skipped_admin_messages(begin, end); - self.send_sequence_reset(begin, end).await; + self.send_sequence_reset(begin, end) + .await + .context("failed to send sequence reset")?; reset_start = None; } @@ -723,10 +761,10 @@ where ) .await; - if enabled!(tracing::Level::DEBUG) { - if let Ok(m) = String::from_utf8(msg.clone()) { - debug!(sequence_number, message = m, "resent message"); - } + if enabled!(tracing::Level::DEBUG) + && let Ok(m) = String::from_utf8(msg.clone()) + { + debug!(sequence_number, message = m, "resent message"); } } @@ -734,7 +772,9 @@ where // the final reset if needed let end = sequence_number; Self::log_skipped_admin_messages(begin, end); - self.send_sequence_reset(begin, end).await; + self.send_sequence_reset(begin, end) + .await + .context("failed to send sequence reset")?; } Ok(()) @@ -757,25 +797,27 @@ where .reset_peer_timer(self.config.heartbeat_interval, test_request_id); } - async fn send_app_message(&mut self, message: Outbound) { + async fn send_app_message(&mut self, message: Outbound) -> Result<()> { match self.application.on_outbound_message(&message).await { OutboundDecision::Send => { - self.send_message(message).await; + self.send_message(message) + .await + .context("failed to send app message")?; } OutboundDecision::Drop => { debug!("dropped outbound message as instructed by the application"); } OutboundDecision::TerminateSession => { - error!("failed to send message to application"); + warn!("the application indicated we should terminate the session"); self.state.disconnect_writer().await; } } + + Ok(()) } - async fn send_message(&mut self, message: impl OutboundMessage) { + async fn send_message(&mut self, message: impl OutboundMessage) -> Result<()> { let seq_num = self.store.next_sender_seq_number(); - self.store.increment_sender_seq_number().await.unwrap(); - let msg_type = message.message_type().as_bytes().to_vec(); let msg = generate_message( &self.config.begin_string, @@ -784,9 +826,19 @@ where seq_num, message, ) - .unwrap(); - self.store.add(seq_num, &msg).await.unwrap(); + .context("failed to generate message")?; + self.store + .increment_sender_seq_number() + .await + .context("failed to increment sender seq number")?; + + self.store + .add(seq_num, &msg) + .await + .context("failed to add message to store")?; self.send_raw(&msg_type, msg).await; + + Ok(()) } async fn send_raw(&mut self, message_type: &[u8], data: Vec) { @@ -796,7 +848,7 @@ where self.reset_heartbeat_timer(); } - async fn send_sequence_reset(&mut self, begin: u64, end: u64) { + async fn send_sequence_reset(&mut self, begin: u64, end: u64) -> Result<()> { let sequence_reset = SequenceReset { gap_fill: true, new_seq_no: end, @@ -808,20 +860,22 @@ where begin, sequence_reset, ) - .unwrap(); + .context("failed to generate message")?; self.send_raw(b"4", raw_message).await; debug!(begin, end, "sent reset sequence"); + + Ok(()) } - async fn send_resend_request(&mut self, begin: u64, end: u64) { + async fn send_resend_request(&mut self, begin: u64, end: u64) -> Result<()> { let request = ResendRequest::new(begin, end); - self.send_message(request).await; + self.send_message(request).await } - async fn send_logon(&mut self) { + async fn send_logon(&mut self) -> Result<()> { let reset_config = if self.config.reset_on_logon || self.reset_on_next_logon { - self.store.reset().await.unwrap(); + self.store.reset().await?; ResetSeqNumConfig::Reset } else { ResetSeqNumConfig::NoReset(Some(self.store.next_target_seq_number())) @@ -830,12 +884,12 @@ where let logon = Logon::new(self.config.heartbeat_interval, reset_config); - self.send_message(logon).await; + self.send_message(logon).await } - async fn send_logout(&mut self, reason: &str) { + async fn send_logout(&mut self, reason: &str) -> Result<()> { let logout = Logout::with_reason(reason.to_string()); - self.send_message(logout).await; + self.send_message(logout).await } /// Sends a logout message and immediately disconnects the counterparty. @@ -846,7 +900,9 @@ where /// /// In other scenarios, [`initiate_graceful_logout`] should be preferred. async fn logout_and_terminate(&mut self, reason: &str) { - self.send_logout(reason).await; + if let Err(err) = self.send_logout(reason).await { + warn!("failed to send logout during session termination: {}", err); + } self.state.disconnect_writer().await; } @@ -855,13 +911,15 @@ where /// The session waits for a configurable timeout period for the counterparty to /// respond with a `Logout` message. If no response is received within the timeout /// period, it disconnects the counterparty. - async fn initiate_graceful_logout(&mut self, reason: &str, reconnect: bool) { + async fn initiate_graceful_logout(&mut self, reason: &str, reconnect: bool) -> Result<()> { if self.state.try_transition_to_awaiting_logout( Duration::from_secs(self.config.logout_timeout), reconnect, ) { - self.send_logout(reason).await; + self.send_logout(reason).await?; } + + Ok(()) } async fn handle_session_event(&mut self, event: SessionEvent) { @@ -881,12 +939,14 @@ where self.on_disconnect(reason).await; } SessionEvent::Connected(w) => { - self.on_connect(w).await; + if let Err(err) = self.on_connect(w).await { + error!(err = ?err, "failed to establish logon after connecting"); + } } SessionEvent::ShouldReconnect(responder) => { - responder - .send(self.state.should_reconnect()) - .expect("be able to respond"); + if let Err(_) = responder.send(self.state.should_reconnect()) { + warn!("tried to respond to ShouldReconnect query but the receiver is gone"); + } } SessionEvent::AwaitingActiveSession(responder) => { self.state.register_session_awaiter(responder); @@ -895,15 +955,21 @@ where } async fn handle_outbound_message(&mut self, message: Outbound) { - self.send_app_message(message).await; + if let Err(err) = self.send_app_message(message).await { + error!(err = ?err, "failed to send app message: {err}"); + } } async fn handle_admin_request(&mut self, request: AdminRequest) { match request { AdminRequest::InitiateGracefulShutdown { reconnect } => { warn!("initiating shutdown on request from admin.."); - self.initiate_graceful_logout("explicitly requested", reconnect) - .await; + if let Err(err) = self + .initiate_graceful_logout("explicitly requested", reconnect) + .await + { + error!(err = ?err, "initiating graceful shutdown"); + } } AdminRequest::RequestSessionInfo(responder) => { info!("session info requested"); @@ -919,7 +985,9 @@ where } async fn handle_heartbeat_timeout(&mut self) { - self.send_message(Heartbeat::default()).await; + if let Err(err) = self.send_message(Heartbeat::default()).await { + error!(err = ?err, "failed to send heartbeat message"); + } } async fn handle_peer_timeout(&mut self) { @@ -936,7 +1004,9 @@ where let req_id = format!("TEST_{}", self.store.next_target_seq_number()); info!("sending TestRequest due to peer timer expiring"); let request = TestRequest::new(req_id.clone()); - self.send_message(request).await; + if let Err(err) = self.send_message(request).await { + error!(err = ?err, "failed to send TestRequest"); + } self.reset_peer_timer(Some(req_id)); } } @@ -971,8 +1041,12 @@ where } } else if self.state.is_connected() { // we are currently outside scheduled session time - self.initiate_graceful_logout("End of session time", true) - .await; + if let Err(err) = self + .initiate_graceful_logout("End of session time", true) + .await + { + error!(err = ?err, "failed to initiate graceful logout"); + } } // we always need to reschedule the check, otherwise we won't be able to resume an inactive session diff --git a/examples/simple-new-order/src/main.rs b/examples/simple-new-order/src/main.rs index c4a3ded..5ef58a8 100644 --- a/examples/simple-new-order/src/main.rs +++ b/examples/simple-new-order/src/main.rs @@ -44,7 +44,9 @@ async fn main() -> Result<()> { if let Some(path) = args.logfile { let p = Path::new(&path); - let parent = p.parent().context("log file path has no parent directory")?; + let parent = p + .parent() + .context("log file path has no parent directory")?; std::fs::create_dir_all(parent)?; let logfile = std::fs::OpenOptions::new() .write(true) From 77d6291a2011538292766e420ae8a70c306d1651 Mon Sep 17 00:00:00 2001 From: David Steiner Date: Mon, 19 Jan 2026 16:37:45 +0100 Subject: [PATCH 5/9] Add test cases to assert the writer ref doesn't panic when the receiver is dropped --- crates/hotfix/src/transport/writer.rs | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/crates/hotfix/src/transport/writer.rs b/crates/hotfix/src/transport/writer.rs index 348c9eb..0b5d134 100644 --- a/crates/hotfix/src/transport/writer.rs +++ b/crates/hotfix/src/transport/writer.rs @@ -36,3 +36,26 @@ impl WriterRef { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn send_raw_message_does_not_panic_when_channel_closed() { + let (sender, receiver) = mpsc::channel(1); + let writer = WriterRef::new(sender); + drop(receiver); + + writer.send_raw_message(RawFixMessage::new(vec![])).await; + } + + #[tokio::test] + async fn disconnect_does_not_panic_when_channel_closed() { + let (sender, receiver) = mpsc::channel(1); + let writer = WriterRef::new(sender); + drop(receiver); + + writer.disconnect().await; + } +} From ef6dde4bec01ea08e51b080a313d79671a6a9a61 Mon Sep 17 00:00:00 2001 From: David Steiner Date: Mon, 19 Jan 2026 17:26:14 +0100 Subject: [PATCH 6/9] Add test cases for schedule handing in session layer --- crates/hotfix/src/session.rs | 381 +++++++++++++++++++++++++++++++++++ 1 file changed, 381 insertions(+) diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index 87d16e2..ff28504 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -1123,3 +1123,384 @@ async fn run_session( debug!("session is shutting down") } + +#[cfg(test)] +mod tests { + use super::*; + use crate::application::{InboundDecision, OutboundDecision}; + use crate::message::{InboundMessage, OutboundMessage}; + use chrono::{DateTime, Datelike, NaiveDate, NaiveTime, TimeDelta, Timelike}; + use chrono_tz::Tz; + use hotfix_message::message::Message; + use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; + use tokio::sync::mpsc; + + /// A controllable store for testing that implements MessageStore + #[derive(Clone)] + struct TestStore { + creation_time: DateTime, + fail_reset: Arc, + reset_called: Arc, + sender_seq: u64, + target_seq: u64, + } + + impl TestStore { + fn new(creation_time: DateTime) -> Self { + Self { + creation_time, + fail_reset: Arc::new(AtomicBool::new(false)), + reset_called: Arc::new(AtomicBool::new(false)), + sender_seq: 1, + target_seq: 1, + } + } + + fn set_fail_reset(&self) { + self.fail_reset.store(true, Ordering::SeqCst); + } + + fn was_reset_called(&self) -> bool { + self.reset_called.load(Ordering::SeqCst) + } + } + + #[async_trait::async_trait] + impl MessageStore for TestStore { + async fn add(&mut self, _sequence_number: u64, _message: &[u8]) -> Result<()> { + Ok(()) + } + + async fn get_slice(&self, _begin: usize, _end: usize) -> Result>> { + Ok(vec![]) + } + + fn next_sender_seq_number(&self) -> u64 { + self.sender_seq + } + + fn next_target_seq_number(&self) -> u64 { + self.target_seq + } + + async fn increment_sender_seq_number(&mut self) -> Result<()> { + self.sender_seq += 1; + Ok(()) + } + + async fn increment_target_seq_number(&mut self) -> Result<()> { + self.target_seq += 1; + Ok(()) + } + + async fn set_target_seq_number(&mut self, seq_number: u64) -> Result<()> { + self.target_seq = seq_number; + Ok(()) + } + + async fn reset(&mut self) -> Result<()> { + self.reset_called.store(true, Ordering::SeqCst); + if self.fail_reset.load(Ordering::SeqCst) { + bail!("simulated reset failure") + } + self.creation_time = Utc::now(); + Ok(()) + } + + fn creation_time(&self) -> DateTime { + self.creation_time + } + } + + /// Dummy message type for testing that implements required traits + #[derive(Clone)] + struct DummyMessage; + + impl OutboundMessage for DummyMessage { + fn write(&self, _msg: &mut Message) {} + fn message_type(&self) -> &str { + "0" + } + } + + impl InboundMessage for DummyMessage { + fn parse(_message: &Message) -> Self { + DummyMessage + } + } + + /// Minimal no-op application for testing + struct NoOpApp; + + #[async_trait::async_trait] + impl Application for NoOpApp { + async fn on_outbound_message(&self, _: &DummyMessage) -> OutboundDecision { + OutboundDecision::Send + } + async fn on_inbound_message(&self, _: DummyMessage) -> InboundDecision { + InboundDecision::Accept + } + async fn on_logout(&mut self, _: &str) {} + async fn on_logon(&mut self) {} + } + + fn create_writer_ref() -> WriterRef { + let (sender, _) = mpsc::channel(10); + WriterRef::new(sender) + } + + fn create_test_config() -> SessionConfig { + SessionConfig { + begin_string: "FIX.4.4".to_string(), + sender_comp_id: "SENDER".to_string(), + target_comp_id: "TARGET".to_string(), + data_dictionary_path: None, + connection_host: "localhost".to_string(), + connection_port: 9876, + tls_config: None, + heartbeat_interval: 30, + logon_timeout: 10, + logout_timeout: 2, + reconnect_interval: 30, + reset_on_logon: false, + schedule: None, + } + } + + fn create_test_session( + schedule: SessionSchedule, + state: SessionState, + store: TestStore, + ) -> Session { + let config = create_test_config(); + let message_config = MessageConfig::default(); + let dictionary = Dictionary::fix44(); + let message_builder = MessageBuilder::new(dictionary, message_config).unwrap(); + + Session { + message_config, + config, + schedule, + message_builder, + state, + application: NoOpApp, + store, + schedule_check_timer: Box::pin(sleep(Duration::from_secs(1))), + reset_on_next_logon: false, + _phantom: std::marker::PhantomData, + } + } + + /// Creates a Daily schedule that is active at the current time + fn create_active_schedule() -> SessionSchedule { + // Use a 24-hour window that's definitely active + SessionSchedule::Daily { + start_time: NaiveTime::from_hms_opt(0, 0, 0).unwrap(), + end_time: NaiveTime::from_hms_opt(23, 59, 59).unwrap(), + timezone: Tz::UTC, + } + } + + /// Creates a Daily schedule that is inactive at the current time + fn create_inactive_schedule() -> SessionSchedule { + let now = Utc::now(); + let current_hour = now.time().hour(); + // Create a 1-hour window that's 12 hours from now (definitely not the current hour) + let start_hour = (current_hour + 12) % 24; + let end_hour = (start_hour + 1) % 24; + SessionSchedule::Daily { + start_time: NaiveTime::from_hms_opt(start_hour, 0, 0).unwrap(), + end_time: NaiveTime::from_hms_opt(end_hour, 0, 0).unwrap(), + timezone: Tz::UTC, + } + } + + #[tokio::test] + async fn test_handle_schedule_check_active_same_period() { + // Use NonStop schedule - always active, always same period + let schedule = SessionSchedule::NonStop; + let writer = create_writer_ref(); + let state = SessionState::new_active(writer, 30); + let store = TestStore::new(Utc::now()); + + let mut session = create_test_session(schedule, state, store); + + session.handle_schedule_check().await; + + // State should remain Active (no logout triggered) + assert!( + session.state.is_logged_on(), + "State should remain logged on for same period" + ); + assert!( + !session.store.was_reset_called(), + "Store reset should not be called for same period" + ); + } + + #[tokio::test] + async fn test_handle_schedule_check_active_different_period() { + // Use a Daily schedule that's currently active + let schedule = create_active_schedule(); + let writer = create_writer_ref(); + let state = SessionState::new_active(writer, 30); + // Creation time is yesterday - different session period + let yesterday = Utc::now() - TimeDelta::days(1); + let store = TestStore::new(yesterday); + + let mut session = create_test_session(schedule, state, store); + + // Verify the schedule correctly identifies different periods + let now = Utc::now(); + let creation_time = session.store.creation_time(); + let same_period = session + .schedule + .is_same_session_period(&creation_time, &now); + assert!( + matches!(same_period, Ok(false)), + "Schedule should identify different periods" + ); + + session.handle_schedule_check().await; + + // Store reset should have been called (indicates Ok(false) branch was taken) + // Note: logout_and_terminate disconnects the writer but state transition to + // Disconnected happens asynchronously via event processing, not in this call + assert!( + session.store.was_reset_called(), + "Store reset should be called for different period" + ); + } + + #[tokio::test] + async fn test_handle_schedule_check_active_reset_fails() { + // Use a Daily schedule that's currently active + let schedule = create_active_schedule(); + let writer = create_writer_ref(); + let state = SessionState::new_active(writer, 30); + // Creation time is yesterday - different session period + let yesterday = Utc::now() - TimeDelta::days(1); + let store = TestStore::new(yesterday); + store.set_fail_reset(); + + let mut session = create_test_session(schedule, state, store); + + session.handle_schedule_check().await; + + // Store reset should have been attempted + assert!( + session.store.was_reset_called(), + "Store reset should be called" + ); + // When reset fails, state is explicitly set to Disconnected(reconnect=false) + assert!( + matches!(session.state, SessionState::Disconnected(_)), + "State should be Disconnected after reset failure" + ); + // Should NOT reconnect since reset failed + assert!( + !session.state.should_reconnect(), + "Should not reconnect after failed reset" + ); + } + + #[tokio::test] + async fn test_handle_schedule_check_active_period_error() { + // Use a narrow schedule that's currently active but creation_time is outside + let now = Utc::now(); + let current_hour = now.time().hour(); + + // Create a 2-hour window around current time + let start_hour = if current_hour == 0 { + 23 + } else { + current_hour - 1 + }; + let end_hour = (current_hour + 2) % 24; + + let schedule = SessionSchedule::Daily { + start_time: NaiveTime::from_hms_opt(start_hour, 0, 0).unwrap(), + end_time: NaiveTime::from_hms_opt(end_hour, 0, 0).unwrap(), + timezone: Tz::UTC, + }; + + let writer = create_writer_ref(); + let state = SessionState::new_active(writer, 30); + + // Creation time is today but at a time outside the schedule window + // Use a time that's definitely outside the window (6 hours from now) + let outside_hour = (current_hour + 6) % 24; + let creation_time = DateTime::from_naive_utc_and_offset( + NaiveDate::from_ymd_opt(now.year(), now.month(), now.day()) + .unwrap() + .and_hms_opt(outside_hour, 30, 0) + .unwrap(), + Utc, + ); + + let store = TestStore::new(creation_time); + + let mut session = create_test_session(schedule, state, store); + + // Verify that is_same_session_period will return an error + let same_period = session + .schedule + .is_same_session_period(&creation_time, &now); + assert!( + same_period.is_err(), + "Schedule should return error when creation_time is outside active window" + ); + + session.handle_schedule_check().await; + + // The Err branch calls logout_and_terminate which disconnects the writer. + // Store reset is NOT called in the Err branch, only in Ok(false). + assert!( + !session.store.was_reset_called(), + "Store reset should not be called on period check error" + ); + } + + #[tokio::test] + async fn test_handle_schedule_check_inactive_connected() { + // Use a schedule that's currently inactive + let schedule = create_inactive_schedule(); + let writer = create_writer_ref(); + let state = SessionState::new_active(writer, 30); + let store = TestStore::new(Utc::now()); + + let mut session = create_test_session(schedule, state, store); + + session.handle_schedule_check().await; + + // State should be AwaitingLogout (graceful logout initiated) + assert!( + session.state.is_awaiting_logout(), + "State should be AwaitingLogout when schedule is inactive and was connected" + ); + } + + #[tokio::test] + async fn test_handle_schedule_check_inactive_disconnected() { + // Use a schedule that's currently inactive + let schedule = create_inactive_schedule(); + let state = SessionState::new_disconnected(true, "test"); + let store = TestStore::new(Utc::now()); + + let mut session = create_test_session(schedule, state, store); + + session.handle_schedule_check().await; + + // State should remain Disconnected (no action taken) + assert!( + matches!(session.state, SessionState::Disconnected(_)), + "State should remain Disconnected when schedule is inactive and was disconnected" + ); + // Reconnect flag should be preserved + assert!( + session.state.should_reconnect(), + "Reconnect flag should be preserved" + ); + } +} From ca35e193be5c5f6f105878fde5640af641a1264d Mon Sep 17 00:00:00 2001 From: David Steiner Date: Mon, 19 Jan 2026 19:47:08 +0100 Subject: [PATCH 7/9] Add test cases for initiator to cover basic flows --- crates/hotfix/src/initiator.rs | 129 +++++++++++++++++++++++++++++++++ crates/hotfix/src/session.rs | 2 +- 2 files changed, 130 insertions(+), 1 deletion(-) diff --git a/crates/hotfix/src/initiator.rs b/crates/hotfix/src/initiator.rs index e8539cd..2f852b5 100644 --- a/crates/hotfix/src/initiator.rs +++ b/crates/hotfix/src/initiator.rs @@ -119,3 +119,132 @@ async fn establish_connection( completion_tx.send_replace(true); } + +#[cfg(all(test, feature = "fix44"))] +mod tests { + use super::*; + use crate::application::{Application, InboundDecision, OutboundDecision}; + use crate::message::InboundMessage; + use crate::store::in_memory::InMemoryMessageStore; + use hotfix_message::message::Message; + use std::time::Duration; + use tokio::net::TcpListener; + + // Minimal message type for tests + #[derive(Clone)] + struct DummyMessage; + + impl OutboundMessage for DummyMessage { + fn write(&self, _msg: &mut Message) {} + fn message_type(&self) -> &str { + "0" + } + } + + impl InboundMessage for DummyMessage { + fn parse(_message: &Message) -> Self { + DummyMessage + } + } + + // No-op application + struct NoOpApp; + + #[async_trait::async_trait] + impl Application for NoOpApp { + async fn on_outbound_message(&self, _msg: &DummyMessage) -> OutboundDecision { + OutboundDecision::Send + } + async fn on_inbound_message(&self, _msg: DummyMessage) -> InboundDecision { + InboundDecision::Accept + } + async fn on_logout(&mut self, _reason: &str) {} + async fn on_logon(&mut self) {} + } + + fn create_test_config(host: &str, port: u16) -> SessionConfig { + SessionConfig { + begin_string: "FIX.4.4".to_string(), + sender_comp_id: "TEST-SENDER".to_string(), + target_comp_id: "TEST-TARGET".to_string(), + data_dictionary_path: None, + connection_host: host.to_string(), + connection_port: port, + tls_config: None, + heartbeat_interval: 30, + logon_timeout: 10, + logout_timeout: 2, + reconnect_interval: 1, // Short for tests + reset_on_logon: false, + schedule: None, + } + } + + #[tokio::test] + async fn test_start_creates_initiator_successfully() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let config = create_test_config("127.0.0.1", port); + + let initiator = Initiator::start(config, NoOpApp, InMemoryMessageStore::default()) + .await + .unwrap(); + + // Verify initial state + assert!(!initiator.is_shutdown()); + assert!(initiator.is_interested("TEST-SENDER", "TEST-TARGET")); + assert!(!initiator.is_interested("WRONG", "TEST-TARGET")); + } + + #[tokio::test] + async fn test_initiator_connects_to_listener() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let config = create_test_config("127.0.0.1", port); + + let _initiator = Initiator::start(config, NoOpApp, InMemoryMessageStore::default()) + .await + .unwrap(); + + // Accept the connection from the initiator + let accept_result = tokio::time::timeout(Duration::from_secs(2), listener.accept()).await; + + assert!( + accept_result.is_ok(), + "Initiator should connect to listener" + ); + } + + #[tokio::test] + async fn test_initiator_reconnects_after_disconnect() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let mut config = create_test_config("127.0.0.1", port); + config.reconnect_interval = 1; // Short interval for test + + let _initiator = Initiator::::start::( + config, + NoOpApp, + InMemoryMessageStore::default(), + ) + .await + .unwrap(); + + // Accept first connection + let (conn1, _) = tokio::time::timeout(Duration::from_secs(2), listener.accept()) + .await + .expect("no connection was established within timeout duration") + .expect("IO error in connection"); + + // Drop the connection to trigger reconnect + drop(conn1); + + // Should reconnect - accept second connection + let accept_result = tokio::time::timeout(Duration::from_secs(3), listener.accept()).await; + + assert!( + accept_result.is_ok(), + "Initiator should reconnect after disconnect" + ); + } +} diff --git a/crates/hotfix/src/session.rs b/crates/hotfix/src/session.rs index ff28504..221502a 100644 --- a/crates/hotfix/src/session.rs +++ b/crates/hotfix/src/session.rs @@ -944,7 +944,7 @@ where } } SessionEvent::ShouldReconnect(responder) => { - if let Err(_) = responder.send(self.state.should_reconnect()) { + if responder.send(self.state.should_reconnect()).is_err() { warn!("tried to respond to ShouldReconnect query but the receiver is gone"); } } From b5c618195f845a8a97bf5a75d199196cbbccbd67 Mon Sep 17 00:00:00 2001 From: David Steiner Date: Mon, 19 Jan 2026 21:41:47 +0100 Subject: [PATCH 8/9] Add test cases for load_from_file --- codecov.yml | 2 ++ crates/hotfix-dictionary/src/dictionary.rs | 26 ++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/codecov.yml b/codecov.yml index fb6842c..2fb7dbc 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,3 +1,5 @@ ignore: - "**/examples/**/*" - "**/tests/**/*" + +coverage: diff --git a/crates/hotfix-dictionary/src/dictionary.rs b/crates/hotfix-dictionary/src/dictionary.rs index acd6642..d5f1d8e 100644 --- a/crates/hotfix-dictionary/src/dictionary.rs +++ b/crates/hotfix-dictionary/src/dictionary.rs @@ -305,3 +305,29 @@ impl Dictionary { .collect() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_load_from_file_success() { + let path = concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/resources/quickfix/FIX-4.4.xml" + ); + let dict = Dictionary::load_from_file(path).unwrap(); + assert_eq!(dict.version(), "FIX.4.4"); + assert!(dict.message_by_name("Heartbeat").is_some()); + } + + #[test] + fn test_load_from_file_invalid_content() { + let path = concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/test_data/quickfix_specs/empty_file.xml" + ); + let result = Dictionary::load_from_file(path); + assert!(result.is_err()); + } +} From 467cadf5793d51245a3ffc073069e4bc0aa8fd88 Mon Sep 17 00:00:00 2001 From: David Steiner Date: Mon, 19 Jan 2026 21:44:08 +0100 Subject: [PATCH 9/9] Require 80% patch coverage --- codecov.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/codecov.yml b/codecov.yml index 2fb7dbc..0d5a543 100644 --- a/codecov.yml +++ b/codecov.yml @@ -3,3 +3,10 @@ ignore: - "**/tests/**/*" coverage: + status: + patch: + default: + target: 80% + project: + default: + target: auto