Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 116 additions & 4 deletions crates/sdk-core/src/core_tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,38 @@ use crate::{
},
worker::{
PollerBehavior,
client::mocks::{mock_manual_worker_client, mock_worker_client},
client::mocks::{
DEFAULT_WORKERS_REGISTRY, MockManualWorkerClient, mock_manual_worker_client,
mock_worker_client,
},
},
};
use futures_util::FutureExt;
use std::{sync::LazyLock, time::Duration};
use std::{
future,
sync::{Arc, LazyLock},
time::Duration,
};
use temporalio_common::protos::{
coresdk::{
workflow_activation::{WorkflowActivationJob, workflow_activation_job},
workflow_completion::WorkflowActivationCompletion,
},
temporal::api::{
enums::v1::EventType, history::v1::WorkflowExecutionOptionsUpdatedEventAttributes,
enums::v1::EventType,
history::v1::WorkflowExecutionOptionsUpdatedEventAttributes,
namespace::v1::{NamespaceInfo, namespace_info::Capabilities},
workflowservice::v1::{
DescribeNamespaceResponse, PollActivityTaskQueueResponse,
RecordActivityTaskHeartbeatResponse,
},
},
};
use tokio::{sync::Barrier, time::sleep};
use temporalio_common::worker::WorkerTaskTypes;
use tokio::{
sync::{Barrier, Notify},
time::sleep,
};

#[tokio::test]
async fn after_shutdown_server_is_not_polled() {
Expand Down Expand Up @@ -111,6 +128,101 @@ async fn shutdown_interrupts_both_polls() {
};
}

#[tokio::test]
async fn graceful_activity_poll_shutdown_handles_unimplemented_shutdown_worker() {
let activity_poll_started = Arc::new(Notify::new());
let activity_poll_started_clone = activity_poll_started.clone();
let shutdown_worker_called = Arc::new(Notify::new());
let shutdown_worker_called_clone = shutdown_worker_called.clone();

let mut mock_client = MockManualWorkerClient::new();
mock_client.expect_capabilities().returning(|| None);
mock_client
.expect_workers()
.returning(|| DEFAULT_WORKERS_REGISTRY.clone());
mock_client.expect_is_mock().returning(|| true);
mock_client
.expect_sdk_name_and_version()
.returning(|| ("test-core".to_string(), "0.0.0".to_string()));
mock_client
.expect_identity()
.returning(|| "test-identity".to_string());
mock_client
.expect_worker_grouping_key()
.returning(uuid::Uuid::new_v4);
mock_client
.expect_worker_instance_key()
.returning(uuid::Uuid::new_v4);
mock_client
.expect_describe_namespace()
.times(1)
.returning(|| {
async {
Ok(DescribeNamespaceResponse {
namespace_info: Some(NamespaceInfo {
capabilities: Some(Capabilities {
worker_poll_complete_on_shutdown: true,
..Default::default()
}),
..Default::default()
}),
..Default::default()
})
}
.boxed()
});
mock_client
.expect_shutdown_worker()
.times(1)
.returning(move |_, _, _, _| {
let shutdown_worker_called = shutdown_worker_called_clone.clone();
async move {
shutdown_worker_called.notify_one();
Err(tonic::Status::unimplemented(
"ShutdownWorker disabled by server",
))
}
.boxed()
});
mock_client
.expect_poll_activity_task()
.times(1)
.returning(move |_, _| {
let activity_poll_started = activity_poll_started_clone.clone();
async move {
activity_poll_started.notify_one();
future::pending::<Result<PollActivityTaskQueueResponse, tonic::Status>>().await
}
.boxed()
});
mock_client
.expect_record_activity_heartbeat()
.returning(|_, _| async { Ok(RecordActivityTaskHeartbeatResponse::default()) }.boxed());

let mut cfg = test_worker_cfg()
.activity_task_poller_behavior(PollerBehavior::SimpleMaximum(1_usize))
.build()
.unwrap();
cfg.task_types = WorkerTaskTypes::activity_only();
let worker = Worker::new_test(cfg, mock_client);
worker.validate().await.unwrap();

let poll_fut = async { worker.poll_activity_task().await };
let shutdown_fut = async {
activity_poll_started.notified().await;
worker.initiate_shutdown();
shutdown_worker_called.notified().await;
};

let (poll_result, _) = tokio::time::timeout(Duration::from_millis(500), async {
tokio::join!(poll_fut, shutdown_fut)
})
.await
.expect("activity poll remained pending after shutdown_worker returned UNIMPLEMENTED");

assert_matches!(poll_result.unwrap_err(), PollError::ShutDown);
}

#[tokio::test]
async fn ignores_workflow_options_updated_event() {
let mut t = TestHistoryBuilder::default();
Expand Down
12 changes: 11 additions & 1 deletion crates/sdk-core/src/pollers/poll_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,17 @@ where
let capabilities = capabilities.clone();
let poll_task = tokio::spawn(async move {
let r = if capabilities.graceful_poll_shutdown() {
pf(timeout_override).await
let shutdown_for_graceful_fallback = shutdown.clone();
let local_interrupt_after_graceful_disabled = async move {
shutdown_for_graceful_fallback.cancelled().await;
while capabilities.graceful_poll_shutdown() {
tokio::time::sleep(Duration::from_millis(10)).await;
}
};
tokio::select! {
r = pf(timeout_override) => r,
_ = local_interrupt_after_graceful_disabled => return,
}
} else {
let poll_interruptor = shutdown.cancelled().then(|_| async move {
if let Some(w) = poll_shutdown_interrupt_wait {
Expand Down
15 changes: 13 additions & 2 deletions crates/sdk-core/src/worker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1439,23 +1439,34 @@ impl Worker {
.heartbeat_manager
.as_ref()
.map(|hm| hm.heartbeat_callback.clone()());
let capabilities = self.capabilities.clone();
let handle = tokio::spawn(async move {
match client
.shutdown_worker(sticky_name, task_queue, task_queue_types, heartbeat)
.await
{
Err(err)
if !matches!(
if matches!(
err.code(),
tonic::Code::Unimplemented | tonic::Code::Unavailable
) =>
{
capabilities
.graceful_poll_shutdown
.store(false, Ordering::Relaxed);
debug!(
"shutdown_worker rpc unavailable during worker shutdown; \
disabling graceful poll shutdown and interrupting polls locally: {:?}",
err
);
}
Err(err) => {
warn!(
"shutdown_worker rpc errored during worker shutdown: {:?}",
err
);
}
_ => {}
Ok(_) => {}
}
});
*guard = Some(handle);
Expand Down
Loading