diff --git a/CHANGELOG.md b/CHANGELOG.md index fe5b80d..2a86e6d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,14 +2,19 @@ ## Unreleased +### Fixed + +* A bug where locally unavailable documents sent by peers with an announce + policy set to false would be marked as unavailable + ### Added -- `TcpDialer::new` which takes a `Url` parameter, rather than a host and a port +* `TcpDialer::new` which takes a `Url` parameter, rather than a host and a port or a socket address. -- `Repo::dial_tcp()` to simplify construction of `TcpDialer`. -- Allow documents syncing over the TCP transport to be up to 8gb size instead +* `Repo::dial_tcp()` to simplify construction of `TcpDialer`. +* Allow documents syncing over the TCP transport to be up to 8gb size instead of Tokio's default 8mb frame size -- Exposed receiving `ConnectionHandle`s via `accept()`. Users can now subscribe +* Exposed receiving `ConnectionHandle`s via `accept()`. Users can now subscribe to an `events()` stream directly on the handle, or `await` for `handshake_completed()`. diff --git a/samod-core/src/actors/document.rs b/samod-core/src/actors/document.rs index 0908c3b..59de216 100644 --- a/samod-core/src/actors/document.rs +++ b/samod-core/src/actors/document.rs @@ -40,8 +40,7 @@ mod load; mod on_disk_state; pub use on_disk_state::CompactionHash; mod peer_doc_connection; -mod ready; -mod request; +mod phase; mod spawn_args; mod with_doc_result; pub use with_doc_result::WithDocResult; diff --git a/samod-core/src/actors/document/doc_actor_result.rs b/samod-core/src/actors/document/doc_actor_result.rs index 9aa9570..e014018 100644 --- a/samod-core/src/actors/document/doc_actor_result.rs +++ b/samod-core/src/actors/document/doc_actor_result.rs @@ -3,11 +3,11 @@ use std::collections::HashMap; use automerge::ChangeHash; use crate::{ - ConnectionId, DocumentChanged, PeerId, StorageKey, + ConnectionId, DocumentChanged, DocumentId, PeerId, StorageKey, actors::{ DocToHubMsg, - document::{SyncMessageStat, io::DocumentIoTask}, - messages::DocToHubMsgPayload, + document::{DocumentStatus, SyncMessageStat, io::DocumentIoTask}, + messages::{Broadcast, DocToHubMsgPayload, SyncMessage}, }, io::{IoTask, IoTaskId, StorageTask}, network::PeerDocState, @@ -58,8 +58,60 @@ impl DocActorResult { } /// Send a message back to the hub - pub(crate) fn send_message(&mut self, message: DocToHubMsgPayload) { - self.outgoing_messages.push(DocToHubMsg(message)); + pub(crate) fn send_sync_message( + &mut self, + conn_id: ConnectionId, + doc_id: DocumentId, + message: SyncMessage, + ) { + self.outgoing_messages + .push(DocToHubMsg(DocToHubMsgPayload::SendSyncMessage { + connection_id: conn_id, + document_id: doc_id, + message, + })); + } + + pub(crate) fn send_broadcast(&mut self, connections: Vec, msg: Broadcast) { + self.outgoing_messages + .push(DocToHubMsg(DocToHubMsgPayload::Broadcast { + connections, + msg, + })); + } + + pub(crate) fn send_terminated(&mut self) { + self.outgoing_messages + .push(DocToHubMsg(DocToHubMsgPayload::Terminated)); + } + + pub(crate) fn send_peer_states_changes( + &mut self, + new_states: HashMap, + ) { + // Remove previous peer state change messages as they are redundant + self.outgoing_messages + .retain(|m| !matches!(m.0, DocToHubMsgPayload::PeerStatesChanged { .. })); + self.outgoing_messages + .push(DocToHubMsg(DocToHubMsgPayload::PeerStatesChanged { + new_states, + })); + } + + pub(crate) fn send_doc_status_update(&mut self, new_status: DocumentStatus) { + // remove any existing doc status update so that if the document status changes + // multiple times during a turn, only the latest status is sent to the hub. + // This is especially important to avoid bouncing through a NotFound state + // when loading a document as that will cause any outstanding find commands + // to fail even if the document loads successfully in this turn (as it might + // if we finish loading after receiving a sync message with the document + // content). + self.outgoing_messages + .retain(|m| !matches!(m.0, DocToHubMsgPayload::DocumentStatusChanged { .. })); + self.outgoing_messages + .push(DocToHubMsg(DocToHubMsgPayload::DocumentStatusChanged { + new_status, + })); } pub(crate) fn put(&mut self, key: StorageKey, value: Vec) -> IoTaskId { diff --git a/samod-core/src/actors/document/doc_state.rs b/samod-core/src/actors/document/doc_state.rs index 77c27da..02084f0 100644 --- a/samod-core/src/actors/document/doc_state.rs +++ b/samod-core/src/actors/document/doc_state.rs @@ -6,16 +6,21 @@ use automerge::Automerge; use crate::{ ConnectionId, DocumentId, StorageKey, UnixTimestamp, actors::{ - document::{DocActorResult, SyncDirection, SyncMessageStat}, - messages::{Broadcast, DocMessage, DocToHubMsgPayload, SyncMessage}, + document::{ + DocActorResult, SyncDirection, SyncMessageStat, + phase::{ + loading::Loading, + ready::Ready, + request::{Request, RequestState}, + }, + }, + messages::{Broadcast, DocMessage, SyncMessage}, }, }; use super::{ DocumentStatus, peer_doc_connection::{AnnouncePolicy, PeerDocConnection}, - ready::Ready, - request::{Request, RequestState}, }; #[derive(Debug)] @@ -32,9 +37,7 @@ pub(super) struct DocState { #[derive(Debug)] pub enum Phase { - Loading { - pending_sync_messages: HashMap>, - }, + Loading(Loading), Requesting(Request), Ready(Ready), NotFound, @@ -56,9 +59,7 @@ impl DocState { any_dialer_connecting: bool, ) -> Self { Self { - phase: Phase::Loading { - pending_sync_messages: HashMap::new(), - }, + phase: Phase::Loading(Loading::new()), document_id, doc, any_dialer_connecting, @@ -79,43 +80,33 @@ impl DocState { PhaseTransition::None => {} PhaseTransition::ToReady => { tracing::trace!("transitioning to ready"); - out.send_message(DocToHubMsgPayload::DocumentStatusChanged { - new_status: DocumentStatus::Ready, - }); + out.send_doc_status_update(DocumentStatus::Ready); out.emit_doc_changed(self.doc.get_heads()); self.phase = Phase::Ready(Ready::new()); } PhaseTransition::ToNotFound => { tracing::trace!("transitioning to NotFound"); - out.send_message(DocToHubMsgPayload::DocumentStatusChanged { - new_status: DocumentStatus::NotFound, - }); + out.send_doc_status_update(DocumentStatus::NotFound); if let Phase::Requesting(request) = &self.phase { for peer in request.peers_waiting_for_us_to_respond() { - out.send_message(DocToHubMsgPayload::SendSyncMessage { - connection_id: peer, - document_id: self.document_id.clone(), - message: SyncMessage::DocUnavailable, - }); + out.send_sync_message( + peer, + self.document_id.clone(), + SyncMessage::DocUnavailable, + ); } } self.phase = Phase::NotFound; } PhaseTransition::ToRequesting(request) => { tracing::trace!("transitioning to requesting"); - out.send_message(DocToHubMsgPayload::DocumentStatusChanged { - new_status: DocumentStatus::Requesting, - }); + out.send_doc_status_update(DocumentStatus::Requesting); self.phase = Phase::Requesting(request); } PhaseTransition::ToLoading => { tracing::trace!("transitioning to loading"); - out.send_message(DocToHubMsgPayload::DocumentStatusChanged { - new_status: DocumentStatus::Loading, - }); - self.phase = Phase::Loading { - pending_sync_messages: HashMap::new(), - }; + out.send_doc_status_update(DocumentStatus::Loading); + self.phase = Phase::Loading(Loading::new()); } } } @@ -159,8 +150,9 @@ impl DocState { } // self.save_state // .add_on_disk(snapshots.into_keys().chain(incrementals.into_keys())); - if matches!(self.phase, Phase::Loading { .. }) { - if self.doc.get_heads().is_empty() { + if let Phase::Loading(loading) = &mut self.phase { + let pending_sync_messages = loading.take_pending_sync_messages(); + let phase_transition = if self.doc.get_heads().is_empty() { let eligible_conns = peer_connections .values() .any(|p| p.announce_policy() != AnnouncePolicy::DontAnnounce); @@ -170,47 +162,23 @@ impl DocState { self.any_dialer_connecting, "no data found on disk, requesting document" ); - let mut next_phase = Phase::Requesting(Request::new( + PhaseTransition::ToRequesting(Request::new( self.document_id.clone(), peer_connections.values(), - )); - std::mem::swap(&mut self.phase, &mut next_phase); - let Phase::Loading { - pending_sync_messages, - } = next_phase - else { - unreachable!("we already checked"); - }; - for (conn_id, msgs) in pending_sync_messages { - for msg in msgs { - self.handle_sync_message(now, out, conn_id, peer_connections, msg, now); - } - } - out.send_message(DocToHubMsgPayload::DocumentStatusChanged { - new_status: DocumentStatus::Requesting, - }); + )) } else { tracing::debug!( "no data found on disk and no connections available, transitioning to NotFound" ); - self.handle_phase_transition(out, PhaseTransition::ToNotFound); + PhaseTransition::ToNotFound } - return; - } + } else { + tracing::trace!("load complete, transitioning to ready"); + PhaseTransition::ToReady + }; - tracing::trace!("load complete, transitioning to ready"); + self.handle_phase_transition(out, phase_transition); - let mut next_phase = Phase::Ready(Ready::new()); - std::mem::swap(&mut self.phase, &mut next_phase); - let Phase::Loading { - pending_sync_messages, - } = next_phase - else { - unreachable!("we already checked"); - }; - out.send_message(DocToHubMsgPayload::DocumentStatusChanged { - new_status: DocumentStatus::Ready, - }); for (conn_id, msgs) in pending_sync_messages { for msg in msgs { self.handle_sync_message(now, out, conn_id, peer_connections, msg, now); @@ -254,10 +222,7 @@ impl DocState { } }) .collect(); - out.send_message(DocToHubMsgPayload::Broadcast { - connections: targets, - msg: Broadcast::Gossip { msg }, - }); + out.send_broadcast(targets, Broadcast::Gossip { msg }); } DocMessage::Sync(msg) => self.handle_sync_message( now, @@ -298,13 +263,8 @@ impl DocState { }; let (transition, duration) = match &mut self.phase { - Phase::Loading { - pending_sync_messages, - } => { - pending_sync_messages - .entry(connection_id) - .or_default() - .push(msg); + Phase::Loading(loading) => { + loading.receive_sync_message(connection_id, msg); (PhaseTransition::None, None) } Phase::Requesting(request) => { @@ -380,11 +340,8 @@ impl DocState { ) -> HashMap> { let mut result: HashMap> = HashMap::new(); for (conn_id, peer_conn) in peer_connections { - if let Phase::Loading { - pending_sync_messages, - } = &self.phase - { - out.pending_sync_messages = pending_sync_messages.values().map(|v| v.len()).sum(); + if let Phase::Loading(loading) = &self.phase { + out.pending_sync_messages = loading.pending_msg_count(); continue; } diff --git a/samod-core/src/actors/document/document_actor.rs b/samod-core/src/actors/document/document_actor.rs index 276a51a..65cda62 100644 --- a/samod-core/src/actors/document/document_actor.rs +++ b/samod-core/src/actors/document/document_actor.rs @@ -70,9 +70,7 @@ impl DocumentActor { let state = if let Some(doc) = initial_content { // Let the hub know this document is ready immediately if we already have content - out.send_message(DocToHubMsgPayload::DocumentStatusChanged { - new_status: DocumentStatus::Ready, - }); + out.send_doc_status_update(DocumentStatus::Ready); DocState::new_ready(document_id.clone(), doc, any_dialer_pending) } else { DocState::new_loading(document_id.clone(), Automerge::new(), any_dialer_pending) @@ -400,7 +398,7 @@ impl DocumentActor { if self.run_state == RunState::Stopping { if self.on_disk_state.is_flushed() { self.run_state = RunState::Stopped; - out.send_message(DocToHubMsgPayload::Terminated); + out.send_terminated(); out.stopped = true; } return; @@ -458,11 +456,7 @@ impl DocumentActor { .generate_sync_messages(now, out, &mut self.peer_connections) { for msg in msgs { - out.send_message(DocToHubMsgPayload::SendSyncMessage { - connection_id: conn_id, - document_id: doc_id.clone(), - message: msg, - }); + out.send_sync_message(conn_id, doc_id.clone(), msg); } } } @@ -475,7 +469,7 @@ impl DocumentActor { .collect::>(); if !states.is_empty() { out.peer_state_changes = states.clone(); - out.send_message(DocToHubMsgPayload::PeerStatesChanged { new_states: states }) + out.send_peer_states_changes(states) } } diff --git a/samod-core/src/actors/document/phase.rs b/samod-core/src/actors/document/phase.rs new file mode 100644 index 0000000..edb36b0 --- /dev/null +++ b/samod-core/src/actors/document/phase.rs @@ -0,0 +1,3 @@ +pub(crate) mod loading; +pub(crate) mod ready; +pub(crate) mod request; diff --git a/samod-core/src/actors/document/phase/loading.rs b/samod-core/src/actors/document/phase/loading.rs new file mode 100644 index 0000000..59511d4 --- /dev/null +++ b/samod-core/src/actors/document/phase/loading.rs @@ -0,0 +1,34 @@ +use std::collections::HashMap; + +use crate::{ConnectionId, actors::messages::SyncMessage}; + +#[derive(Debug)] +pub(crate) struct Loading { + pending_sync_messages: HashMap>, +} + +impl Loading { + pub(crate) fn new() -> Self { + Self { + pending_sync_messages: HashMap::new(), + } + } + + pub(crate) fn pending_msg_count(&self) -> usize { + self.pending_sync_messages + .values() + .map(|v| v.len()) + .sum::() + } + + pub(crate) fn receive_sync_message(&mut self, conn_id: ConnectionId, msg: SyncMessage) { + self.pending_sync_messages + .entry(conn_id) + .or_default() + .push(msg); + } + + pub(crate) fn take_pending_sync_messages(&mut self) -> HashMap> { + std::mem::take(&mut self.pending_sync_messages) + } +} diff --git a/samod-core/src/actors/document/ready.rs b/samod-core/src/actors/document/phase/ready.rs similarity index 92% rename from samod-core/src/actors/document/ready.rs rename to samod-core/src/actors/document/phase/ready.rs index 9c1e4a2..72e469f 100644 --- a/samod-core/src/actors/document/ready.rs +++ b/samod-core/src/actors/document/phase/ready.rs @@ -2,9 +2,13 @@ use std::time::Duration; use automerge::{Automerge, sync}; -use crate::{UnixTimestamp, actors::messages::SyncMessage}; - -use super::peer_doc_connection::{AnnouncePolicy, PeerDocConnection}; +use crate::{ + UnixTimestamp, + actors::{ + document::peer_doc_connection::{AnnouncePolicy, PeerDocConnection}, + messages::SyncMessage, + }, +}; #[derive(Debug)] pub(crate) struct Ready; diff --git a/samod-core/src/actors/document/request.rs b/samod-core/src/actors/document/phase/request.rs similarity index 98% rename from samod-core/src/actors/document/request.rs rename to samod-core/src/actors/document/phase/request.rs index 39049b7..25bb7e8 100644 --- a/samod-core/src/actors/document/request.rs +++ b/samod-core/src/actors/document/phase/request.rs @@ -3,9 +3,13 @@ use std::time::Duration; use automerge::{Automerge, ChangeHash, ReadDoc, sync}; -use crate::{ConnectionId, DocumentId, UnixTimestamp, actors::messages::SyncMessage}; - -use super::peer_doc_connection::{AnnouncePolicy, PeerDocConnection}; +use crate::{ + ConnectionId, DocumentId, UnixTimestamp, + actors::{ + document::peer_doc_connection::{AnnouncePolicy, PeerDocConnection}, + messages::SyncMessage, + }, +}; #[derive(Debug)] pub(crate) struct Request { diff --git a/samod-core/tests/document_sync_advanced.rs b/samod-core/tests/document_sync_advanced.rs index 220fbda..0722e4a 100644 --- a/samod-core/tests/document_sync_advanced.rs +++ b/samod-core/tests/document_sync_advanced.rs @@ -652,3 +652,106 @@ fn three_chained_sync_servers() { assert_eq!(verification_result.0, "alice"); assert_eq!(verification_result.1, "hello from alice"); } + +#[test] +fn dont_announce_policy_retains_documents_synced_by_clients() { + // This test is a reproduction of an issue described in + // https://github.com/alexjg/samod/pull/85 + // + // The issue was that if we receive a sync message for a document we don't + // have from a peer who does have the document, but our announce policy is + // set to false, then we erroneously treat the document as unavailable. The + // reason for this is that we were dropping the sync message from the peer + // because of the announce policy, but the sync message contained the + // document. + init_logging(); + let mut network = Network::new(); + + let server = network.create_samod("Server"); + network + .samod(&server) + .set_announce_policy(Box::new(|_, _| false)); + + let client = network.create_samod("Client"); + + let RunningDocIds { doc_id, actor_id } = network.samod(&client).create_document(); + network + .samod(&client) + .with_document_by_actor(actor_id, |doc| { + doc.transact::<_, _, AutomergeError>(|tx| { + tx.put(automerge::ROOT, "foo", "bar")?; + Ok(()) + }) + .unwrap() + }) + .unwrap(); + + network.run_until_quiescent(); + + network.connect(client, server); + network.run_until_quiescent(); + + let server_actor = network.samod(&server).find_document(&doc_id); + assert!(server_actor.is_some(), "Server should have the document"); + + let server_actor = server_actor.unwrap(); + let val = network + .samod(&server) + .with_document_by_actor(server_actor, |doc| { + doc.get(automerge::ROOT, "foo") + .unwrap() + .map(|(v, _)| v.to_string()) + }) + .unwrap(); + + assert_eq!(val.as_deref(), Some("\"bar\"")); +} + +#[test] +fn find_doesnt_bounce_through_unavailable_when_receiving_doc() { + init_logging(); + let mut network = Network::new(); + + let server = network.create_samod("Server"); + network + .samod(&server) + .set_announce_policy(Box::new(|_, _| false)); + + let client = network.create_samod("Client"); + + network.connect(client, server); + network.run_until_quiescent(); + + // Peers are now connected, now create the document on the client whilst + // simultaenously finding it on the server + + let RunningDocIds { doc_id, actor_id } = network.samod(&client).create_document(); + network + .samod(&client) + .with_document_by_actor(actor_id, |doc| { + doc.transact::<_, _, AutomergeError>(|tx| { + tx.put(automerge::ROOT, "foo", "bar")?; + Ok(()) + }) + .unwrap() + }) + .unwrap(); + let find_command = network.samod(&server).begin_find_document(&doc_id); + + network.samod(&server).pause_storage(); + assert!( + network + .samod(&server) + .check_find_document_result(find_command) + .is_none() + ); + network.run_until_quiescent(); + network.samod(&server).resume_storage(); + network.run_until_quiescent(); + + network + .samod(&server) + .check_find_document_result(find_command) + .expect("find command should have completed") + .expect("document should be found on server"); +} diff --git a/samod-test-harness/src/doc_actor_runner.rs b/samod-test-harness/src/doc_actor_runner.rs index f84d940..79919ed 100644 --- a/samod-test-harness/src/doc_actor_runner.rs +++ b/samod-test-harness/src/doc_actor_runner.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, VecDeque}; +use std::collections::{HashMap, HashSet, VecDeque}; use automerge::Automerge; use samod_core::{ @@ -10,7 +10,7 @@ use samod_core::{ io::{DocumentIoResult, DocumentIoTask}, }, }, - io::{IoResult, IoTask}, + io::{IoResult, IoTask, IoTaskId}, network::PeerDocState, }; @@ -26,6 +26,7 @@ pub(crate) struct DocActorRunner { ephemera: Vec>, doc_changed: Vec, peer_doc_state_changes: Vec>, + pending_storage_tasks: HashSet, } impl DocActorRunner { @@ -42,6 +43,7 @@ impl DocActorRunner { ephemera: Vec::new(), doc_changed: Vec::new(), peer_doc_state_changes: Vec::new(), + pending_storage_tasks: HashSet::new(), }; runner.enqueue_events(results); runner @@ -53,6 +55,7 @@ impl DocActorRunner { storage: &mut Storage, announce_policy: &dyn Fn(DocumentId, PeerId) -> bool, ) { + self.handle_completed_storage(now, storage); while let Some(event) = self.inbox.pop_front() { if self.actor.is_stopped() { self.inbox.clear(); @@ -66,27 +69,50 @@ impl DocActorRunner { .expect("failed to handle actor message"); self.enqueue_events(result); } - ActorEvent::Io(task) => { - let io_result = match task.action { - DocumentIoTask::Storage(storage_task) => IoResult { - task_id: task.task_id, - payload: DocumentIoResult::Storage(storage.handle_task(storage_task)), - }, - DocumentIoTask::CheckAnnouncePolicy { peer_id } => IoResult { + ActorEvent::Io(task) => match task.action { + DocumentIoTask::Storage(storage_task) => { + storage.handle_task(task.task_id, storage_task); + self.pending_storage_tasks.insert(task.task_id); + } + DocumentIoTask::CheckAnnouncePolicy { peer_id } => { + let io_result = IoResult { task_id: task.task_id, payload: DocumentIoResult::CheckAnnouncePolicy(announce_policy( self.doc_id.clone(), peer_id, )), - }, - }; - let actor_result = self - .actor - .handle_io_complete(now, io_result) - .expect("failed to handle IO completion"); - self.enqueue_events(actor_result); - } + }; + let actor_result = self + .actor + .handle_io_complete(now, io_result) + .expect("failed to handle IO completion"); + self.enqueue_events(actor_result); + } + }, } + self.handle_completed_storage(now, storage); + } + } + + fn handle_completed_storage(&mut self, now: UnixTimestamp, storage: &mut Storage) { + let mut completed = Vec::new(); + self.pending_storage_tasks.retain(|task_id| { + let Some(completed_task) = storage.check_pending_task(*task_id) else { + return true; + }; + completed.push((*task_id, completed_task)); + false + }); + for (task_id, completed_task) in completed { + let io_result = IoResult { + task_id, + payload: DocumentIoResult::Storage(completed_task), + }; + let actor_result = self + .actor + .handle_io_complete(now, io_result) + .expect("failed to handle IO completion"); + self.enqueue_events(actor_result); } } diff --git a/samod-test-harness/src/samod_ref.rs b/samod-test-harness/src/samod_ref.rs index f0e758b..29097f9 100644 --- a/samod-test-harness/src/samod_ref.rs +++ b/samod-test-harness/src/samod_ref.rs @@ -211,4 +211,12 @@ impl SamodRef<'_> { ) -> &[HashMap] { self.wrapper_ref().peer_state_changes(doc_id) } + + pub fn pause_storage(&mut self) { + self.wrapper().pause_storage(); + } + + pub fn resume_storage(&mut self) { + self.wrapper().resume_storage(); + } } diff --git a/samod-test-harness/src/samod_wrapper.rs b/samod-test-harness/src/samod_wrapper.rs index 3360f1e..b7b07a9 100644 --- a/samod-test-harness/src/samod_wrapper.rs +++ b/samod-test-harness/src/samod_wrapper.rs @@ -60,7 +60,10 @@ impl SamodWrapper { match loader.step(&mut rng, now) { samod_core::LoaderState::NeedIo(tasks) => { for task in tasks { - let result = storage.handle_task(task.action); + storage.handle_task(task.task_id, task.action); + let result = storage + .check_pending_task(task.task_id) + .expect("storage should not be paused"); loader.provide_io_result(IoResult { task_id: task.task_id, payload: result, @@ -85,6 +88,14 @@ impl SamodWrapper { } } + pub fn pause_storage(&mut self) { + self.storage.pause(); + } + + pub fn resume_storage(&mut self) { + self.storage.resume(); + } + /// Register a new dialer in the hub (for connector tests). pub fn add_dialer(&mut self, config: DialerConfig) -> DialerId { let DispatchedCommand { command_id, event } = HubEvent::add_dialer(config); @@ -237,7 +248,7 @@ impl SamodWrapper { pub fn start_find_document(&mut self, document_id: &DocumentId) -> CommandId { let DispatchedCommand { command_id, event } = HubEvent::find_document(document_id.clone()); self.inbox.push_back(event); - self.handle_events(); + // self.handle_events(); command_id } @@ -360,11 +371,7 @@ impl SamodWrapper { } pub fn storage(&self) -> &HashMap> { - &self.storage.0 - } - - pub fn storage_mut(&mut self) -> &mut HashMap> { - &mut self.storage.0 + self.storage.data() } pub fn push_event(&mut self, event: HubEvent) { diff --git a/samod-test-harness/src/storage.rs b/samod-test-harness/src/storage.rs index 6fbf55a..dd1b58e 100644 --- a/samod-test-harness/src/storage.rs +++ b/samod-test-harness/src/storage.rs @@ -2,30 +2,84 @@ use std::collections::HashMap; use samod_core::{ StorageKey, - io::{StorageResult, StorageTask}, + io::{IoTaskId, StorageResult, StorageTask}, }; -pub struct Storage(pub(crate) HashMap>); +pub struct Storage { + data: HashMap>, + state: StorageState, + completed_tasks: HashMap, +} + +enum StorageState { + Running, + Paused { + pending_tasks: HashMap, + }, +} impl From>> for Storage { fn from(map: HashMap>) -> Self { - Storage(map) + Storage { + data: map, + state: StorageState::Running, + completed_tasks: HashMap::new(), + } } } impl Storage { pub(crate) fn new() -> Self { - Storage(HashMap::new()) + Storage { + data: HashMap::new(), + state: StorageState::Running, + completed_tasks: HashMap::new(), + } } - pub(crate) fn handle_task(&mut self, task: StorageTask) -> StorageResult { - match task { + pub(crate) fn data(&self) -> &HashMap> { + &self.data + } + + pub(crate) fn pause(&mut self) { + if let StorageState::Running = self.state { + self.state = StorageState::Paused { + pending_tasks: HashMap::new(), + }; + } + } + + pub(crate) fn resume(&mut self) { + if let StorageState::Paused { pending_tasks } = &mut self.state { + let tasks = std::mem::take(pending_tasks); + self.state = StorageState::Running; + for (task_id, task) in tasks { + self.perform_task(task_id, task); + } + } + } + + pub(crate) fn check_pending_task(&mut self, task_id: IoTaskId) -> Option { + self.completed_tasks.remove(&task_id) + } + + pub(crate) fn handle_task(&mut self, task_id: IoTaskId, task: StorageTask) { + match &mut self.state { + StorageState::Running => self.perform_task(task_id, task), + StorageState::Paused { pending_tasks } => { + pending_tasks.insert(task_id, task); + } + } + } + + fn perform_task(&mut self, task_id: IoTaskId, task: StorageTask) { + let result = match task { StorageTask::Load { key } => StorageResult::Load { - value: self.0.get(&key).cloned(), + value: self.data.get(&key).cloned(), }, StorageTask::LoadRange { prefix } => { let values = self - .0 + .data .iter() .filter(|(k, _)| prefix.is_prefix_of(k)) .map(|(k, v)| (k.clone(), v.clone())) @@ -33,13 +87,14 @@ impl Storage { StorageResult::LoadRange { values } } StorageTask::Put { key, value } => { - self.0.insert(key.clone(), value); + self.data.insert(key.clone(), value); StorageResult::Put } StorageTask::Delete { key } => { - self.0.remove(&key); + self.data.remove(&key); StorageResult::Delete } - } + }; + self.completed_tasks.insert(task_id, result); } }