diff --git a/crates/sdk-core/src/core_tests/mod.rs b/crates/sdk-core/src/core_tests/mod.rs index 29f5df27c..cc5169404 100644 --- a/crates/sdk-core/src/core_tests/mod.rs +++ b/crates/sdk-core/src/core_tests/mod.rs @@ -14,21 +14,38 @@ use crate::{ }, worker::{ PollerBehavior, - client::mocks::{mock_manual_worker_client, mock_worker_client}, + client::mocks::{ + DEFAULT_WORKERS_REGISTRY, MockManualWorkerClient, mock_manual_worker_client, + mock_worker_client, + }, }, }; use futures_util::FutureExt; -use std::{sync::LazyLock, time::Duration}; +use std::{ + future, + sync::{Arc, LazyLock}, + time::Duration, +}; use temporalio_common::protos::{ coresdk::{ workflow_activation::{WorkflowActivationJob, workflow_activation_job}, workflow_completion::WorkflowActivationCompletion, }, temporal::api::{ - enums::v1::EventType, history::v1::WorkflowExecutionOptionsUpdatedEventAttributes, + enums::v1::EventType, + history::v1::WorkflowExecutionOptionsUpdatedEventAttributes, + namespace::v1::{NamespaceInfo, namespace_info::Capabilities}, + workflowservice::v1::{ + DescribeNamespaceResponse, PollActivityTaskQueueResponse, + RecordActivityTaskHeartbeatResponse, + }, }, }; -use tokio::{sync::Barrier, time::sleep}; +use temporalio_common::worker::WorkerTaskTypes; +use tokio::{ + sync::{Barrier, Notify}, + time::sleep, +}; #[tokio::test] async fn after_shutdown_server_is_not_polled() { @@ -111,6 +128,101 @@ async fn shutdown_interrupts_both_polls() { }; } +#[tokio::test] +async fn graceful_activity_poll_shutdown_handles_unimplemented_shutdown_worker() { + let activity_poll_started = Arc::new(Notify::new()); + let activity_poll_started_clone = activity_poll_started.clone(); + let shutdown_worker_called = Arc::new(Notify::new()); + let shutdown_worker_called_clone = shutdown_worker_called.clone(); + + let mut mock_client = MockManualWorkerClient::new(); + mock_client.expect_capabilities().returning(|| None); + mock_client + .expect_workers() + .returning(|| DEFAULT_WORKERS_REGISTRY.clone()); + mock_client.expect_is_mock().returning(|| true); + mock_client + .expect_sdk_name_and_version() + .returning(|| ("test-core".to_string(), "0.0.0".to_string())); + mock_client + .expect_identity() + .returning(|| "test-identity".to_string()); + mock_client + .expect_worker_grouping_key() + .returning(uuid::Uuid::new_v4); + mock_client + .expect_worker_instance_key() + .returning(uuid::Uuid::new_v4); + mock_client + .expect_describe_namespace() + .times(1) + .returning(|| { + async { + Ok(DescribeNamespaceResponse { + namespace_info: Some(NamespaceInfo { + capabilities: Some(Capabilities { + worker_poll_complete_on_shutdown: true, + ..Default::default() + }), + ..Default::default() + }), + ..Default::default() + }) + } + .boxed() + }); + mock_client + .expect_shutdown_worker() + .times(1) + .returning(move |_, _, _, _| { + let shutdown_worker_called = shutdown_worker_called_clone.clone(); + async move { + shutdown_worker_called.notify_one(); + Err(tonic::Status::unimplemented( + "ShutdownWorker disabled by server", + )) + } + .boxed() + }); + mock_client + .expect_poll_activity_task() + .times(1) + .returning(move |_, _| { + let activity_poll_started = activity_poll_started_clone.clone(); + async move { + activity_poll_started.notify_one(); + future::pending::>().await + } + .boxed() + }); + mock_client + .expect_record_activity_heartbeat() + .returning(|_, _| async { Ok(RecordActivityTaskHeartbeatResponse::default()) }.boxed()); + + let mut cfg = test_worker_cfg() + .activity_task_poller_behavior(PollerBehavior::SimpleMaximum(1_usize)) + .build() + .unwrap(); + cfg.task_types = WorkerTaskTypes::activity_only(); + let worker = Worker::new_test(cfg, mock_client); + worker.validate().await.unwrap(); + + let poll_fut = async { worker.poll_activity_task().await }; + let shutdown_fut = async { + activity_poll_started.notified().await; + worker.initiate_shutdown(); + shutdown_worker_called.notified().await; + }; + + let (poll_result, _) = tokio::time::timeout(Duration::from_millis(500), async { + tokio::join!(poll_fut, shutdown_fut) + }) + .await + .expect("activity poll remained pending after shutdown_worker returned UNIMPLEMENTED"); + + assert_matches!(poll_result.unwrap_err(), PollError::ShutDown); +} + #[tokio::test] async fn ignores_workflow_options_updated_event() { let mut t = TestHistoryBuilder::default(); diff --git a/crates/sdk-core/src/pollers/poll_buffer.rs b/crates/sdk-core/src/pollers/poll_buffer.rs index a4d1353de..fb2ca0129 100644 --- a/crates/sdk-core/src/pollers/poll_buffer.rs +++ b/crates/sdk-core/src/pollers/poll_buffer.rs @@ -387,7 +387,17 @@ where let capabilities = capabilities.clone(); let poll_task = tokio::spawn(async move { let r = if capabilities.graceful_poll_shutdown() { - pf(timeout_override).await + let shutdown_for_graceful_fallback = shutdown.clone(); + let local_interrupt_after_graceful_disabled = async move { + shutdown_for_graceful_fallback.cancelled().await; + while capabilities.graceful_poll_shutdown() { + tokio::time::sleep(Duration::from_millis(10)).await; + } + }; + tokio::select! { + r = pf(timeout_override) => r, + _ = local_interrupt_after_graceful_disabled => return, + } } else { let poll_interruptor = shutdown.cancelled().then(|_| async move { if let Some(w) = poll_shutdown_interrupt_wait { diff --git a/crates/sdk-core/src/worker/mod.rs b/crates/sdk-core/src/worker/mod.rs index 6082ff77e..063f5b580 100644 --- a/crates/sdk-core/src/worker/mod.rs +++ b/crates/sdk-core/src/worker/mod.rs @@ -1439,23 +1439,34 @@ impl Worker { .heartbeat_manager .as_ref() .map(|hm| hm.heartbeat_callback.clone()()); + let capabilities = self.capabilities.clone(); let handle = tokio::spawn(async move { match client .shutdown_worker(sticky_name, task_queue, task_queue_types, heartbeat) .await { Err(err) - if !matches!( + if matches!( err.code(), tonic::Code::Unimplemented | tonic::Code::Unavailable ) => { + capabilities + .graceful_poll_shutdown + .store(false, Ordering::Relaxed); + debug!( + "shutdown_worker rpc unavailable during worker shutdown; \ + disabling graceful poll shutdown and interrupting polls locally: {:?}", + err + ); + } + Err(err) => { warn!( "shutdown_worker rpc errored during worker shutdown: {:?}", err ); } - _ => {} + Ok(_) => {} } }); *guard = Some(handle);