diff --git a/crates/client/src/lib.rs b/crates/client/src/lib.rs index 0fb02d22d..967507c14 100644 --- a/crates/client/src/lib.rs +++ b/crates/client/src/lib.rs @@ -1149,7 +1149,7 @@ where .and_then(|d| d.try_into().ok()), workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()), workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()), - search_attributes: options.search_attributes.map(|d| d.into()), + search_attributes: options.search_attributes.map(|t| t.into_proto()), cron_schedule: options.cron_schedule.unwrap_or_default(), header: options.header.or(start_signal.header), user_metadata, @@ -1185,7 +1185,7 @@ where .and_then(|d| d.try_into().ok()), workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()), workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()), - search_attributes: options.search_attributes.map(|d| d.into()), + search_attributes: options.search_attributes.map(|t| t.into_proto()), cron_schedule: options.cron_schedule.unwrap_or_default(), request_eager_execution: options.enable_eager_workflow_start, retry_policy: options.retry_policy, diff --git a/crates/client/src/options_structs.rs b/crates/client/src/options_structs.rs index ed9fb97e1..446c25de4 100644 --- a/crates/client/src/options_structs.rs +++ b/crates/client/src/options_structs.rs @@ -6,7 +6,7 @@ use temporalio_common::{ protos::temporal::api::{ common::{ self, - v1::{Header, Payload, Payloads}, + v1::{Header, Payloads}, }, enums::v1::{ ArchivalState, HistoryEventFilterType, QueryRejectCondition, WorkflowIdConflictPolicy, @@ -15,6 +15,7 @@ use temporalio_common::{ replication::v1::ClusterReplicationConfig, workflowservice::v1::RegisterNamespaceRequest, }, + search_attributes::SearchAttributes, telemetry::metrics::TemporalMeter, }; use tokio_rustls::rustls::client::danger::ServerCertVerifier; @@ -286,8 +287,8 @@ pub struct WorkflowStartOptions { /// Optionally set a cron schedule for the workflow pub cron_schedule: Option, - /// Optionally associate extra search attributes with a workflow - pub search_attributes: Option>, + /// Additional search attributes for the workflow. + pub search_attributes: Option, /// Optionally enable Eager Workflow Start, a latency optimization using local workers /// NOTE: Experimental diff --git a/crates/common-wasm/Cargo.toml b/crates/common-wasm/Cargo.toml index 7676ed591..2d95a1782 100644 --- a/crates/common-wasm/Cargo.toml +++ b/crates/common-wasm/Cargo.toml @@ -23,9 +23,14 @@ bon = { workspace = true } crc32fast = "1" derive_more = { workspace = true } erased-serde = "0.4" +# Only the `alloc` feature is needed for RFC3339 formatting. Do NOT use +# Utc::now() or other clock functions — they are unavailable in +# wasm32-unknown-unknown without the `wasmbind` feature. +chrono = { version = "0.4", default-features = false, features = ["alloc"] } futures = { version = "0.3", default-features = false, features = ["alloc"] } parking_lot = { version = "0.12" } prost = { workspace = true } +prost-types = { workspace = true } serde = { version = "1.0", features = ["derive"] } serde_json = { workspace = true } thiserror = { workspace = true } diff --git a/crates/common-wasm/src/lib.rs b/crates/common-wasm/src/lib.rs index 2634935a0..cf13c043c 100644 --- a/crates/common-wasm/src/lib.rs +++ b/crates/common-wasm/src/lib.rs @@ -16,11 +16,16 @@ pub mod protos { pub use temporalio_protos::*; } +pub mod search_attributes; pub mod worker; mod workflow_definition; pub use activity_definition::{ActivityDefinition, ActivityError}; pub use priority::Priority; +pub use search_attributes::{ + SearchAttributeError, SearchAttributeKey, SearchAttributeUpdate, SearchAttributeValue, + SearchAttributes, Timestamp, +}; pub use worker::WorkerDeploymentVersion; pub use workflow_definition::{ HasWorkflowDefinition, QueryDefinition, SignalDefinition, UntypedWorkflow, UpdateDefinition, diff --git a/crates/common-wasm/src/search_attributes.rs b/crates/common-wasm/src/search_attributes.rs new file mode 100644 index 000000000..b29223e7e --- /dev/null +++ b/crates/common-wasm/src/search_attributes.rs @@ -0,0 +1,1236 @@ +//! Type-safe search attribute APIs for the Temporal Rust SDK. +//! +//! Search attributes are key-value pairs attached to workflows that enable +//! server-side filtering via visibility queries. This module provides a typed +//! layer over the raw proto payloads so that attribute values are checked at +//! compile time. +//! +//! # Example +//! +//! ``` +//! use temporalio_common_wasm::search_attributes::SearchAttributeKey; +//! +//! const MY_BOOL: SearchAttributeKey = SearchAttributeKey::bool("my_bool"); +//! const MY_KW: SearchAttributeKey = SearchAttributeKey::keyword("my_keyword"); +//! +//! let update = MY_BOOL.value_set(true); +//! let unset = MY_KW.value_unset(); +//! ``` + +use std::collections::HashMap; +use std::marker::PhantomData; + +use tracing::warn; + +use crate::protos::temporal::api::common::v1::{ + Payload, SearchAttributes as ProtoSearchAttributes, +}; +use crate::protos::temporal::api::enums::v1::IndexedValueType; + +/// Metadata key for the search attribute value type, kept consistent across all SDKs. +const TYPE_METADATA_KEY: &str = "type"; + +/// Errors arising from search attribute serialization or deserialization. +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum SearchAttributeError { + /// JSON serialization failed. + #[error("failed to serialize search attribute value: {0}")] + Serialization(#[from] serde_json::Error), + + /// The payload is missing required metadata or has an unexpected encoding. + #[error("invalid search attribute payload: {reason}")] + InvalidPayload { + /// Description of what was wrong with the payload. + reason: String, + }, + + /// A timestamp value could not be formatted or parsed as RFC3339. + #[error("invalid timestamp: {0}")] + InvalidTimestamp(String), +} + +// --------------------------------------------------------------------------- +// SDK-owned Timestamp type +// --------------------------------------------------------------------------- + +/// An SDK-owned timestamp for Datetime search attributes. +/// +/// This type decouples the public API from `prost_types::Timestamp`. Conversion +/// traits are provided for [`prost_types::Timestamp`] and +/// [`std::time::SystemTime`]. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Timestamp { + seconds: i64, + nanos: i32, +} + +impl Timestamp { + /// The maximum valid value for nanoseconds. + const MAX_NANOS: i32 = 999_999_999; + + /// Creates a new `Timestamp`. + /// + /// # Arguments + /// * `seconds` — seconds since the Unix epoch (negative for pre-epoch). + /// * `nanos` — non-negative nanosecond offset within the second, + /// in the range `[0, 999_999_999]`. Values outside this range are + /// clamped. + pub fn new(seconds: i64, nanos: i32) -> Self { + Self { + seconds, + nanos: nanos.clamp(0, Self::MAX_NANOS), + } + } + + /// Returns seconds since the Unix epoch. + pub fn seconds(&self) -> i64 { + self.seconds + } + + /// Returns the nanosecond component (always in `[0, 999_999_999]`). + pub fn nanos(&self) -> i32 { + self.nanos + } + + /// Returns this timestamp as a `prost_types::Timestamp`. + pub fn to_prost(&self) -> prost_types::Timestamp { + prost_types::Timestamp { + seconds: self.seconds, + nanos: self.nanos, + } + } +} + +impl std::fmt::Display for Timestamp { + /// Formats the timestamp as an RFC3339 string (e.g., `2023-11-14T22:13:20.000000000Z`). + /// Falls back to `Debug` formatting if the timestamp is out of chrono's range. + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match timestamp_to_rfc3339(self) { + Ok(s) => f.write_str(&s), + Err(_) => write!(f, "Timestamp({}, {})", self.seconds, self.nanos), + } + } +} + +impl From for Timestamp { + fn from(ts: prost_types::Timestamp) -> Self { + Timestamp::new(ts.seconds, ts.nanos) + } +} + +impl From for prost_types::Timestamp { + fn from(ts: Timestamp) -> Self { + prost_types::Timestamp { + seconds: ts.seconds(), + nanos: ts.nanos(), + } + } +} + +impl From for Timestamp { + fn from(st: std::time::SystemTime) -> Self { + match st.duration_since(std::time::UNIX_EPOCH) { + Ok(dur) => Timestamp::new(dur.as_secs() as i64, dur.subsec_nanos() as i32), + Err(e) => { + // Normalize to protobuf convention: nanos always non-negative. + // Example: 1.25s before epoch → { seconds: -2, nanos: 750_000_000 } + let dur = e.duration(); + let secs = dur.as_secs() as i64; + let nanos = dur.subsec_nanos(); + if nanos == 0 { + Timestamp::new(-secs, 0) + } else { + Timestamp::new(-(secs + 1), (1_000_000_000 - nanos) as i32) + } + } + } + } +} + +impl TryFrom for std::time::SystemTime { + type Error = SearchAttributeError; + + fn try_from(ts: Timestamp) -> Result { + let epoch = std::time::UNIX_EPOCH; + if ts.seconds >= 0 { + epoch + .checked_add(std::time::Duration::new( + ts.seconds as u64, + ts.nanos.max(0) as u32, + )) + .ok_or_else(|| { + SearchAttributeError::InvalidTimestamp( + "timestamp out of SystemTime range".into(), + ) + }) + } else { + // Reverse the normalization: { seconds: -2, nanos: 750_000_000 } + // means 1.25s before epoch → Duration::new(1, 250_000_000) + let abs_secs = ts.seconds.unsigned_abs(); + let nanos = ts.nanos.max(0) as u32; + let dur = if nanos == 0 { + std::time::Duration::new(abs_secs, 0) + } else { + std::time::Duration::new(abs_secs - 1, 1_000_000_000 - nanos) + }; + epoch.checked_sub(dur).ok_or_else(|| { + SearchAttributeError::InvalidTimestamp("timestamp out of SystemTime range".into()) + }) + } + } +} + +// --------------------------------------------------------------------------- +// SearchAttributeValue trait +// --------------------------------------------------------------------------- + +mod private { + pub trait Sealed {} + impl Sealed for bool {} + impl Sealed for i64 {} + impl Sealed for f64 {} + impl Sealed for String {} + impl Sealed for super::Timestamp {} + impl Sealed for Vec {} +} + +/// A value type that can be stored as a Temporal search attribute. +/// +/// This trait is sealed and implemented for: `bool`, `i64`, `f64`, `String`, +/// [`Timestamp`], and `Vec`. +pub trait SearchAttributeValue: private::Sealed + Clone + Sized { + /// Encode this value into a search attribute [`Payload`]. + fn to_search_attribute_payload( + &self, + indexed_value_type: IndexedValueType, + ) -> Result; + + /// Decode a value from a search attribute [`Payload`]. + fn from_search_attribute_payload(payload: &Payload) -> Result; + + /// The default [`IndexedValueType`] for this Rust type. + /// + /// This is used internally when a key does not explicitly specify the + /// indexed value type. Most callers should use [`SearchAttributeKey`] + /// constructors rather than calling this directly. + fn default_indexed_value_type() -> IndexedValueType; +} + +// --------------------------------------------------------------------------- +// Shared JSON payload helpers (reuses the SDK's JSON payload encoding conventions) +// --------------------------------------------------------------------------- + +#[allow(unreachable_patterns)] // Wildcard is intentional for forward-compat with new proto variants +fn type_metadata_str(ivt: IndexedValueType) -> &'static str { + match ivt { + IndexedValueType::Bool => "Bool", + IndexedValueType::Int => "Int", + IndexedValueType::Double => "Double", + IndexedValueType::Keyword => "Keyword", + IndexedValueType::Text => "Text", + IndexedValueType::Datetime => "Datetime", + IndexedValueType::KeywordList => "KeywordList", + IndexedValueType::Unspecified | _ => "Unspecified", + } +} + +/// Encode a serde-serializable value into a search attribute [`Payload`]. +/// +/// This mirrors the encoding used by the SDK's +/// [`SerdeJsonPayloadConverter`][crate::data_converters], but adds the +/// search-attribute `type` metadata key. By using the same `json/plain` +/// encoding and metadata layout, payloads produced here are decode-compatible +/// with the standard payload converter and vice-versa. +fn encode_json_search_attr( + value: &T, + indexed_value_type: IndexedValueType, +) -> Result { + let data = serde_json::to_vec(value)?; + let mut metadata = HashMap::with_capacity(2); + metadata.insert("encoding".to_string(), b"json/plain".to_vec()); + metadata.insert( + TYPE_METADATA_KEY.to_string(), + type_metadata_str(indexed_value_type).as_bytes().to_vec(), + ); + Ok(Payload { + metadata, + data, + ..Default::default() + }) +} + +/// Decode a search attribute [`Payload`] back into a concrete type. +/// +/// Validates the `json/plain` encoding metadata (matching the standard payload +/// converter expectation) before attempting JSON deserialization. +fn decode_json_search_attr( + payload: &Payload, +) -> Result { + let encoding = + payload + .metadata + .get("encoding") + .ok_or_else(|| SearchAttributeError::InvalidPayload { + reason: "missing encoding metadata".into(), + })?; + if encoding.as_slice() != b"json/plain" { + return Err(SearchAttributeError::InvalidPayload { + reason: format!( + "expected encoding 'json/plain', got '{}'", + String::from_utf8_lossy(encoding) + ), + }); + } + Ok(serde_json::from_slice(&payload.data)?) +} + +// --------------------------------------------------------------------------- +// Macro for simple (serde-native) SearchAttributeValue impls +// --------------------------------------------------------------------------- + +/// Implements [`SearchAttributeValue`] for types that are directly +/// serde-serializable as their JSON wire representation (no special conversion). +macro_rules! impl_simple_search_attribute_value { + ($ty:ty, $ivt:expr) => { + impl SearchAttributeValue for $ty { + fn to_search_attribute_payload( + &self, + indexed_value_type: IndexedValueType, + ) -> Result { + encode_json_search_attr(self, indexed_value_type) + } + + fn from_search_attribute_payload( + payload: &Payload, + ) -> Result { + decode_json_search_attr(payload) + } + + fn default_indexed_value_type() -> IndexedValueType { + $ivt + } + } + }; +} + +impl_simple_search_attribute_value!(bool, IndexedValueType::Bool); +impl_simple_search_attribute_value!(i64, IndexedValueType::Int); +impl_simple_search_attribute_value!(String, IndexedValueType::Keyword); +impl_simple_search_attribute_value!(Vec, IndexedValueType::KeywordList); + +// f64 requires a manual impl to reject NaN and Infinity, which serde_json +// silently serializes as `null` rather than returning an error. +impl SearchAttributeValue for f64 { + fn to_search_attribute_payload( + &self, + indexed_value_type: IndexedValueType, + ) -> Result { + if !self.is_finite() { + return Err(SearchAttributeError::InvalidPayload { + reason: format!("f64 search attribute value must be finite, got {}", self), + }); + } + encode_json_search_attr(self, indexed_value_type) + } + + fn from_search_attribute_payload(payload: &Payload) -> Result { + decode_json_search_attr(payload) + } + + fn default_indexed_value_type() -> IndexedValueType { + IndexedValueType::Double + } +} + +// --------------------------------------------------------------------------- +// Timestamp SearchAttributeValue impl (RFC3339 string on the wire) +// --------------------------------------------------------------------------- + +/// Format a [`Timestamp`] as an RFC3339 string using `chrono`. +fn timestamp_to_rfc3339(ts: &Timestamp) -> Result { + use chrono::{DateTime, Utc}; + + let nanos = u32::try_from(ts.nanos()).unwrap_or(0); + let dt = DateTime::::from_timestamp(ts.seconds(), nanos).ok_or_else(|| { + SearchAttributeError::InvalidTimestamp(format!( + "cannot represent seconds={} nanos={} as DateTime", + ts.seconds(), + ts.nanos() + )) + })?; + Ok(dt.to_rfc3339_opts(chrono::SecondsFormat::Nanos, true)) +} + +/// Parse an RFC3339 string into a [`Timestamp`] using `chrono`. +fn rfc3339_to_timestamp(s: &str) -> Result { + use chrono::DateTime; + + // Strip surrounding quotes if present — some SDKs or raw payloads may + // pass the RFC3339 string with JSON-style quotes still attached. + let s = s.trim_matches('"'); + let dt = DateTime::parse_from_rfc3339(s).map_err(|e| { + SearchAttributeError::InvalidTimestamp(format!("failed to parse RFC3339 '{}': {}", s, e)) + })?; + Ok(Timestamp::new( + dt.timestamp(), + dt.timestamp_subsec_nanos() as i32, + )) +} + +impl SearchAttributeValue for Timestamp { + fn to_search_attribute_payload( + &self, + indexed_value_type: IndexedValueType, + ) -> Result { + let rfc3339 = timestamp_to_rfc3339(self)?; + encode_json_search_attr(&rfc3339, indexed_value_type) + } + + fn from_search_attribute_payload(payload: &Payload) -> Result { + let s: String = decode_json_search_attr(payload)?; + rfc3339_to_timestamp(&s) + } + + fn default_indexed_value_type() -> IndexedValueType { + IndexedValueType::Datetime + } +} + +// --------------------------------------------------------------------------- +// SearchAttributeKey +// --------------------------------------------------------------------------- + +/// A typed handle for a named search attribute, carrying its value type at the +/// type level. Construct via the const factory methods such as +/// [`SearchAttributeKey::bool`], [`SearchAttributeKey::keyword`], etc. +/// +/// # Key names +/// +/// Key names must be `&'static str`, which enables compile-time construction +/// via `const` but means runtime-determined key names are not supported. +/// For dynamic key names (e.g., from config), use +/// [`SearchAttributes::raw_payload`] as an escape hatch for untyped access. +/// +/// ``` +/// use temporalio_common_wasm::search_attributes::SearchAttributeKey; +/// +/// const MY_KEY: SearchAttributeKey = SearchAttributeKey::keyword("my_attr"); +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct SearchAttributeKey { + name: &'static str, + indexed_value_type: IndexedValueType, + _marker: PhantomData, +} + +impl SearchAttributeKey { + /// Returns the attribute name used as the key in the proto map. + pub fn name(&self) -> &'static str { + self.name + } + + /// Returns the [`IndexedValueType`] configured for this key. + pub fn indexed_value_type(&self) -> IndexedValueType { + self.indexed_value_type + } + + /// Create a [`SearchAttributeUpdate`] that sets the attribute to the given value. + /// + /// # Panics + /// + /// Panics if the value cannot be serialized to JSON. This can happen for + /// `f64` values that are `NaN` or `Infinity` (which are not valid JSON), + /// or for `Timestamp` values with out-of-range seconds. Use + /// [`try_value_set`](Self::try_value_set) for a fallible alternative. + pub fn value_set(&self, val: T) -> SearchAttributeUpdate { + self.try_value_set(val) + .expect("search attribute serialization failed (use try_value_set for non-finite f64 or out-of-range timestamps)") + } + + /// Fallible version of [`value_set`](Self::value_set). Returns an error + /// instead of panicking if the value cannot be serialized. + pub fn try_value_set(&self, val: T) -> Result { + let payload = val.to_search_attribute_payload(self.indexed_value_type)?; + Ok(SearchAttributeUpdate { + name: self.name.to_string(), + payload: Some(payload), + }) + } + + /// Create a [`SearchAttributeUpdate`] that removes this attribute. + pub fn value_unset(&self) -> SearchAttributeUpdate { + SearchAttributeUpdate { + name: self.name.to_string(), + payload: None, + } + } +} + +impl SearchAttributeKey { + /// Create a key for a `Bool`-typed search attribute. + pub const fn bool(name: &'static str) -> Self { + Self { + name, + indexed_value_type: IndexedValueType::Bool, + _marker: PhantomData, + } + } +} + +impl SearchAttributeKey { + /// Create a key for an `Int`-typed search attribute. + pub const fn int(name: &'static str) -> Self { + Self { + name, + indexed_value_type: IndexedValueType::Int, + _marker: PhantomData, + } + } +} + +impl SearchAttributeKey { + /// Create a key for a `Double`-typed search attribute. + pub const fn float(name: &'static str) -> Self { + Self { + name, + indexed_value_type: IndexedValueType::Double, + _marker: PhantomData, + } + } +} + +impl SearchAttributeKey { + /// Create a key for a `Keyword`-typed search attribute. + pub const fn keyword(name: &'static str) -> Self { + Self { + name, + indexed_value_type: IndexedValueType::Keyword, + _marker: PhantomData, + } + } + + /// Create a key for a `Text`-typed search attribute. + pub const fn text(name: &'static str) -> Self { + Self { + name, + indexed_value_type: IndexedValueType::Text, + _marker: PhantomData, + } + } +} + +impl SearchAttributeKey { + /// Create a key for a `Datetime`-typed search attribute. + pub const fn datetime(name: &'static str) -> Self { + Self { + name, + indexed_value_type: IndexedValueType::Datetime, + _marker: PhantomData, + } + } +} + +impl SearchAttributeKey> { + /// Create a key for a `KeywordList`-typed search attribute. + pub const fn keyword_list(name: &'static str) -> Self { + Self { + name, + indexed_value_type: IndexedValueType::KeywordList, + _marker: PhantomData, + } + } +} + +// --------------------------------------------------------------------------- +// SearchAttributeUpdate +// --------------------------------------------------------------------------- + +/// A pending mutation to a single search attribute. +/// +/// When `payload` is `None`, the attribute should be removed. The semantics +/// differ slightly depending on how the update is consumed: +/// +/// - [`SearchAttributes::new`] / [`SearchAttributes::apply`]: a `None` payload +/// removes the key from the in-memory collection (the key is simply absent). +/// - [`SearchAttributes::updates_to_proto`]: a `None` payload produces an +/// empty [`Payload`] in the proto map, signaling the server to clear that +/// attribute. +#[derive(Debug, Clone)] +pub struct SearchAttributeUpdate { + pub(crate) name: String, + pub(crate) payload: Option, +} + +impl SearchAttributeUpdate { + /// Returns the attribute name being updated. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns `true` if this update removes the attribute. + pub fn is_unset(&self) -> bool { + self.payload.is_none() + } +} + +// --------------------------------------------------------------------------- +// SearchAttributes +// --------------------------------------------------------------------------- + +/// A collection of search attribute payloads, providing type-safe access via +/// [`SearchAttributeKey`]. +#[derive(Debug, Clone, Default, PartialEq)] +pub struct SearchAttributes { + fields: HashMap, +} + +impl SearchAttributes { + /// Construct from an iterator of [`SearchAttributeUpdate`]s. + /// + /// Updates with `None` payloads remove any existing entry for that key. + pub fn new(updates: impl IntoIterator) -> Self { + let mut fields = HashMap::new(); + for update in updates { + match update.payload { + Some(payload) => { + fields.insert(update.name, payload); + } + None => { + fields.remove(&update.name); + } + } + } + Self { fields } + } + + /// Apply a single update to this collection. If the update sets a value, + /// it is inserted (replacing any existing entry); if the update unsets a + /// value, the entry is removed. + pub fn apply(&mut self, update: SearchAttributeUpdate) { + match update.payload { + Some(payload) => { + self.fields.insert(update.name, payload); + } + None => { + self.fields.remove(&update.name); + } + } + } + + /// Retrieve a typed value. Returns `None` if the key is absent or + /// deserialization fails (graceful degradation — no panic on type mismatch). + pub fn get(&self, key: &SearchAttributeKey) -> Option { + let payload = self.fields.get(key.name())?; + match T::from_search_attribute_payload(payload) { + Ok(val) => Some(val), + Err(e) => { + warn!( + key = key.name(), + error = %e, + "Failed to deserialize search attribute; returning None. \ + Use try_get() for explicit error handling." + ); + None + } + } + } + + /// Retrieve a typed value, distinguishing "key absent" from "deserialization + /// failed". Returns `Ok(None)` if the key is absent, `Ok(Some(val))` on + /// success, or `Err` if the payload is present but cannot be deserialized. + pub fn try_get( + &self, + key: &SearchAttributeKey, + ) -> Result, SearchAttributeError> { + match self.fields.get(key.name()) { + None => Ok(None), + Some(payload) => T::from_search_attribute_payload(payload).map(Some), + } + } + + /// Returns `true` if a payload exists for the given key. + pub fn contains_key(&self, key: &SearchAttributeKey) -> bool { + self.fields.contains_key(key.name()) + } + + /// Returns true if there are no search attributes. + pub fn is_empty(&self) -> bool { + self.fields.is_empty() + } + + /// Returns the number of search attributes. + pub fn len(&self) -> usize { + self.fields.len() + } + + /// Returns an iterator over the attribute names in this collection. + pub fn keys(&self) -> impl Iterator { + self.fields.keys().map(|s| s.as_str()) + } + + /// Returns a reference to the raw payload for the given attribute name, + /// if present. This is useful for advanced use cases such as forwarding + /// payloads without deserializing them. + pub fn raw_payload(&self, name: &str) -> Option<&Payload> { + self.fields.get(name) + } + + /// Convert to the proto wire representation. + pub fn to_proto(&self) -> ProtoSearchAttributes { + ProtoSearchAttributes { + indexed_fields: self.fields.clone(), + } + } + + /// Convert to the proto wire representation, consuming `self` to avoid + /// cloning. + pub fn into_proto(self) -> ProtoSearchAttributes { + ProtoSearchAttributes { + indexed_fields: self.fields, + } + } + + /// Construct from the proto wire representation by cloning the inner map. + pub fn from_proto(attrs: &ProtoSearchAttributes) -> Self { + Self { + fields: attrs.indexed_fields.clone(), + } + } +} + +impl From for SearchAttributes { + /// Construct from an owned proto, moving the inner map without cloning. + fn from(attrs: ProtoSearchAttributes) -> Self { + Self { + fields: attrs.indexed_fields, + } + } +} + +impl SearchAttributes { + /// Convert to the proto representation, producing empty-data payloads for + /// entries that were unset. This is used when building an upsert command + /// that needs to explicitly clear attributes on the server. + pub fn updates_to_proto( + updates: impl IntoIterator, + ) -> ProtoSearchAttributes { + let mut indexed_fields = HashMap::new(); + for update in updates { + let payload = update.payload.unwrap_or_default(); + indexed_fields.insert(update.name, payload); + } + ProtoSearchAttributes { indexed_fields } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const BOOL_KEY: SearchAttributeKey = SearchAttributeKey::bool("my_bool"); + const INT_KEY: SearchAttributeKey = SearchAttributeKey::int("my_int"); + const FLOAT_KEY: SearchAttributeKey = SearchAttributeKey::float("my_float"); + const KW_KEY: SearchAttributeKey = SearchAttributeKey::keyword("my_keyword"); + const TEXT_KEY: SearchAttributeKey = SearchAttributeKey::text("my_text"); + const DT_KEY: SearchAttributeKey = SearchAttributeKey::datetime("my_datetime"); + const KWL_KEY: SearchAttributeKey> = + SearchAttributeKey::keyword_list("my_keyword_list"); + + fn assert_payload_metadata(payload: &Payload, expected_type: &str) { + assert_eq!( + payload.metadata.get("encoding").unwrap(), + b"json/plain".as_slice() + ); + assert_eq!( + payload.metadata.get(TYPE_METADATA_KEY).unwrap(), + expected_type.as_bytes() + ); + } + + #[test] + fn round_trip_bool() { + let val = true; + let payload = val + .to_search_attribute_payload(IndexedValueType::Bool) + .unwrap(); + assert_payload_metadata(&payload, "Bool"); + assert_eq!(bool::from_search_attribute_payload(&payload).unwrap(), true); + } + + #[test] + fn round_trip_int() { + let val: i64 = -42; + let payload = val + .to_search_attribute_payload(IndexedValueType::Int) + .unwrap(); + assert_payload_metadata(&payload, "Int"); + assert_eq!(i64::from_search_attribute_payload(&payload).unwrap(), -42); + } + + #[test] + fn round_trip_double() { + let val: f64 = 3.14; + let payload = val + .to_search_attribute_payload(IndexedValueType::Double) + .unwrap(); + assert_payload_metadata(&payload, "Double"); + let decoded = f64::from_search_attribute_payload(&payload).unwrap(); + assert!((decoded - 3.14).abs() < f64::EPSILON); + } + + #[test] + fn round_trip_keyword() { + let val = "hello".to_string(); + let payload = val + .to_search_attribute_payload(IndexedValueType::Keyword) + .unwrap(); + assert_payload_metadata(&payload, "Keyword"); + assert_eq!( + String::from_search_attribute_payload(&payload).unwrap(), + "hello" + ); + } + + #[test] + fn round_trip_text() { + let val = "some long text".to_string(); + let payload = val + .to_search_attribute_payload(IndexedValueType::Text) + .unwrap(); + assert_payload_metadata(&payload, "Text"); + assert_eq!( + String::from_search_attribute_payload(&payload).unwrap(), + "some long text" + ); + } + + #[test] + fn round_trip_datetime() { + let ts = Timestamp::new(1_700_000_000, 123_456_789); + let payload = ts + .to_search_attribute_payload(IndexedValueType::Datetime) + .unwrap(); + assert_payload_metadata(&payload, "Datetime"); + + let json_str: String = serde_json::from_slice(&payload.data).unwrap(); + assert!(json_str.ends_with('Z')); + assert!(json_str.contains('T')); + + let decoded = Timestamp::from_search_attribute_payload(&payload).unwrap(); + assert_eq!(decoded.seconds(), ts.seconds()); + assert_eq!(decoded.nanos(), ts.nanos()); + + let attrs = SearchAttributes::new([DT_KEY.value_set(ts.clone())]); + let got = attrs.get(&DT_KEY).unwrap(); + assert_eq!(got.seconds(), ts.seconds()); + assert_eq!(got.nanos(), ts.nanos()); + } + + #[test] + fn round_trip_datetime_no_nanos() { + let ts = Timestamp::new(0, 0); + let payload = ts + .to_search_attribute_payload(IndexedValueType::Datetime) + .unwrap(); + let decoded = Timestamp::from_search_attribute_payload(&payload).unwrap(); + assert_eq!(decoded.seconds(), 0); + assert_eq!(decoded.nanos(), 0); + } + + #[test] + fn round_trip_keyword_list() { + let val = vec!["a".to_string(), "b".to_string(), "c".to_string()]; + let payload = val + .to_search_attribute_payload(IndexedValueType::KeywordList) + .unwrap(); + assert_payload_metadata(&payload, "KeywordList"); + assert_eq!( + Vec::::from_search_attribute_payload(&payload).unwrap(), + vec!["a", "b", "c"] + ); + } + + #[test] + fn typed_search_attributes_new_and_get() { + let attrs = SearchAttributes::new([ + BOOL_KEY.value_set(true), + INT_KEY.value_set(99), + FLOAT_KEY.value_set(2.72), + KW_KEY.value_set("kw_val".into()), + TEXT_KEY.value_set("text_val".into()), + KWL_KEY.value_set(vec!["x".into(), "y".into()]), + ]); + + assert_eq!(attrs.len(), 6); + assert!(!attrs.is_empty()); + assert_eq!(attrs.get(&BOOL_KEY), Some(true)); + assert_eq!(attrs.get(&INT_KEY), Some(99)); + assert!((attrs.get(&FLOAT_KEY).unwrap() - 2.72).abs() < f64::EPSILON); + assert_eq!(attrs.get(&KW_KEY), Some("kw_val".into())); + assert_eq!(attrs.get(&TEXT_KEY), Some("text_val".into())); + assert_eq!( + attrs.get(&KWL_KEY), + Some(vec!["x".to_string(), "y".to_string()]) + ); + } + + #[test] + fn to_proto_from_proto_round_trip() { + let attrs = SearchAttributes::new([BOOL_KEY.value_set(false), INT_KEY.value_set(7)]); + + let proto = attrs.to_proto(); + assert_eq!(proto.indexed_fields.len(), 2); + + let restored = SearchAttributes::from_proto(&proto); + assert_eq!(restored.get(&BOOL_KEY), Some(false)); + assert_eq!(restored.get(&INT_KEY), Some(7)); + } + + #[test] + fn value_unset_removes_entry() { + let attrs = SearchAttributes::new([BOOL_KEY.value_set(true), BOOL_KEY.value_unset()]); + assert!(attrs.is_empty()); + assert_eq!(attrs.get(&BOOL_KEY), None); + } + + #[test] + fn keyword_vs_text_disambiguation() { + let kw_update = KW_KEY.value_set("same_value".into()); + let text_update = TEXT_KEY.value_set("same_value".into()); + + let kw_payload = kw_update.payload.as_ref().unwrap(); + let text_payload = text_update.payload.as_ref().unwrap(); + + assert_eq!( + kw_payload.metadata.get(TYPE_METADATA_KEY).unwrap(), + b"Keyword" + ); + assert_eq!( + text_payload.metadata.get(TYPE_METADATA_KEY).unwrap(), + b"Text" + ); + + assert_eq!(KW_KEY.indexed_value_type(), IndexedValueType::Keyword); + assert_eq!(TEXT_KEY.indexed_value_type(), IndexedValueType::Text); + } + + #[test] + fn get_returns_none_for_missing_key() { + let attrs = SearchAttributes::default(); + assert_eq!(attrs.get(&BOOL_KEY), None); + assert!(!attrs.contains_key(&INT_KEY)); + } + + #[test] + fn get_returns_none_for_type_mismatch() { + let attrs = SearchAttributes::new([BOOL_KEY.value_set(true)]); + // Try to read the bool payload as an i64 — should gracefully return None + let mismatched_key = SearchAttributeKey::::int("my_bool"); + assert_eq!(attrs.get(&mismatched_key), None); + } + + #[test] + fn updates_to_proto_includes_empty_payload_for_unset() { + let proto = + SearchAttributes::updates_to_proto([BOOL_KEY.value_set(true), INT_KEY.value_unset()]); + + let bool_payload = proto.indexed_fields.get("my_bool").unwrap(); + assert!(!bool_payload.data.is_empty()); + + let int_payload = proto.indexed_fields.get("my_int").unwrap(); + assert!(int_payload.data.is_empty()); + assert!(int_payload.metadata.is_empty()); + } + + #[test] + fn contains_key_returns_true_when_present() { + let attrs = SearchAttributes::new([INT_KEY.value_set(42)]); + assert!(attrs.contains_key(&INT_KEY)); + } + + #[test] + fn timestamp_rfc3339_format() { + let ts = Timestamp::new(1_700_000_000, 0); + let rfc = timestamp_to_rfc3339(&ts).unwrap(); + // SecondsFormat::Nanos emits full precision even for zero nanos + assert_eq!(rfc, "2023-11-14T22:13:20.000000000Z"); + } + + #[test] + fn timestamp_rfc3339_with_nanos() { + let ts = Timestamp::new(1_700_000_000, 500_000_000); + let rfc = timestamp_to_rfc3339(&ts).unwrap(); + assert_eq!(rfc, "2023-11-14T22:13:20.500000000Z"); + + let parsed = rfc3339_to_timestamp(&rfc).unwrap(); + assert_eq!(parsed.seconds(), ts.seconds()); + assert_eq!(parsed.nanos(), ts.nanos()); + } + + #[test] + fn search_attribute_update_accessors() { + let set = BOOL_KEY.value_set(true); + assert_eq!(set.name(), "my_bool"); + assert!(!set.is_unset()); + + let unset = BOOL_KEY.value_unset(); + assert_eq!(unset.name(), "my_bool"); + assert!(unset.is_unset()); + } + + #[test] + fn timestamp_from_prost_types() { + let prost_ts = prost_types::Timestamp { + seconds: 1_000_000, + nanos: 42, + }; + let ts: Timestamp = prost_ts.into(); + assert_eq!(ts.seconds(), 1_000_000); + assert_eq!(ts.nanos(), 42); + + let back: prost_types::Timestamp = ts.into(); + assert_eq!(back.seconds, 1_000_000); + assert_eq!(back.nanos, 42); + } + + #[test] + fn timestamp_from_system_time() { + let st = std::time::UNIX_EPOCH + std::time::Duration::new(1_700_000_000, 123_456_789); + let ts: Timestamp = st.into(); + assert_eq!(ts.seconds(), 1_700_000_000); + assert_eq!(ts.nanos(), 123_456_789); + + let back: std::time::SystemTime = ts.try_into().unwrap(); + assert_eq!(back, st); + } + + // --- Edge-case tests (from review feedback) --- + + #[test] + fn timestamp_pre_epoch_normalized() { + // 1.25 seconds before epoch → { seconds: -2, nanos: 750_000_000 } + let st = std::time::UNIX_EPOCH - std::time::Duration::new(1, 250_000_000); + let ts: Timestamp = st.into(); + assert_eq!(ts.seconds(), -2); + assert_eq!(ts.nanos(), 750_000_000); + + let back: std::time::SystemTime = ts.try_into().unwrap(); + assert_eq!(back, st); + } + + #[test] + fn timestamp_pre_epoch_exact_second() { + // Exactly 5 seconds before epoch + let st = std::time::UNIX_EPOCH - std::time::Duration::new(5, 0); + let ts: Timestamp = st.into(); + assert_eq!(ts.seconds(), -5); + assert_eq!(ts.nanos(), 0); + + let back: std::time::SystemTime = ts.try_into().unwrap(); + assert_eq!(back, st); + } + + #[test] + fn timestamp_pre_epoch_rfc3339_round_trip() { + let ts = Timestamp::new(-2, 750_000_000); + let payload = ts + .to_search_attribute_payload(IndexedValueType::Datetime) + .unwrap(); + let decoded = Timestamp::from_search_attribute_payload(&payload).unwrap(); + assert_eq!(decoded.seconds(), ts.seconds()); + assert_eq!(decoded.nanos(), ts.nanos()); + } + + #[test] + #[should_panic(expected = "search attribute serialization failed")] + fn value_set_panics_on_nan() { + FLOAT_KEY.value_set(f64::NAN); + } + + #[test] + #[should_panic(expected = "search attribute serialization failed")] + fn value_set_panics_on_infinity() { + FLOAT_KEY.value_set(f64::INFINITY); + } + + #[test] + fn try_value_set_returns_error_on_nan() { + let result = FLOAT_KEY.try_value_set(f64::NAN); + assert!(result.is_err()); + } + + #[test] + fn try_value_set_returns_error_on_infinity() { + let result = FLOAT_KEY.try_value_set(f64::INFINITY); + assert!(result.is_err()); + } + + #[test] + fn round_trip_empty_string() { + let val = String::new(); + let payload = val + .to_search_attribute_payload(IndexedValueType::Keyword) + .unwrap(); + assert_eq!(String::from_search_attribute_payload(&payload).unwrap(), ""); + } + + #[test] + fn round_trip_empty_keyword_list() { + let val: Vec = vec![]; + let payload = val + .to_search_attribute_payload(IndexedValueType::KeywordList) + .unwrap(); + assert_eq!( + Vec::::from_search_attribute_payload(&payload).unwrap(), + Vec::::new() + ); + } + + #[test] + fn round_trip_large_int_boundaries() { + for val in [i64::MAX, i64::MIN, 0i64] { + let payload = val + .to_search_attribute_payload(IndexedValueType::Int) + .unwrap(); + assert_eq!(i64::from_search_attribute_payload(&payload).unwrap(), val); + } + } + + #[test] + fn decode_missing_encoding_metadata() { + let payload = Payload { + metadata: HashMap::new(), + data: b"true".to_vec(), + ..Default::default() + }; + let result = bool::from_search_attribute_payload(&payload); + assert!(result.is_err()); + } + + #[test] + fn decode_wrong_encoding_metadata() { + let mut metadata = HashMap::new(); + metadata.insert("encoding".to_string(), b"binary/plain".to_vec()); + let payload = Payload { + metadata, + data: b"true".to_vec(), + ..Default::default() + }; + let result = bool::from_search_attribute_payload(&payload); + assert!(result.is_err()); + } + + #[test] + fn decode_garbage_json_data() { + let mut metadata = HashMap::new(); + metadata.insert("encoding".to_string(), b"json/plain".to_vec()); + let payload = Payload { + metadata, + data: b"not-valid-json!!!".to_vec(), + ..Default::default() + }; + let result = bool::from_search_attribute_payload(&payload); + assert!(result.is_err()); + } + + #[test] + fn keys_returns_attribute_names() { + let attrs = SearchAttributes::new([BOOL_KEY.value_set(true), INT_KEY.value_set(42)]); + let mut keys: Vec<&str> = attrs.keys().collect(); + keys.sort(); + assert_eq!(keys, vec!["my_bool", "my_int"]); + } + + #[test] + fn raw_payload_returns_payload() { + let attrs = SearchAttributes::new([BOOL_KEY.value_set(true)]); + let payload = attrs.raw_payload("my_bool").unwrap(); + assert!(!payload.data.is_empty()); + assert!(attrs.raw_payload("nonexistent").is_none()); + } + + #[test] + fn into_proto_moves_without_clone() { + let attrs = SearchAttributes::new([INT_KEY.value_set(7)]); + let proto = attrs.into_proto(); + assert_eq!(proto.indexed_fields.len(), 1); + } + + #[test] + fn search_attribute_key_is_copy() { + let key = BOOL_KEY; + let key2 = key; // Copy, not move + assert_eq!(key.name(), key2.name()); + } + + #[test] + fn timestamp_new_clamps_negative_nanos() { + let ts = Timestamp::new(100, -42); + assert_eq!(ts.seconds(), 100); + assert_eq!(ts.nanos(), 0); // clamped to 0 + } + + #[test] + fn timestamp_new_clamps_excessive_nanos() { + let ts = Timestamp::new(100, 2_000_000_000); + assert_eq!(ts.seconds(), 100); + assert_eq!(ts.nanos(), 999_999_999); // clamped to MAX_NANOS + } + + #[test] + fn timestamp_to_prost_round_trips() { + let ts = Timestamp::new(1_700_000_000, 123_456_789); + let prost_ts = ts.to_prost(); + assert_eq!(prost_ts.seconds, 1_700_000_000); + assert_eq!(prost_ts.nanos, 123_456_789); + let back: Timestamp = prost_ts.into(); + assert_eq!(back, ts); + } + + #[test] + fn apply_inserts_and_removes() { + let mut attrs = SearchAttributes::new([INT_KEY.value_set(42)]); + assert_eq!(attrs.get(&INT_KEY), Some(42)); + + // Apply an update that changes the value + attrs.apply(INT_KEY.value_set(99)); + assert_eq!(attrs.get(&INT_KEY), Some(99)); + + // Apply an unset + attrs.apply(INT_KEY.value_unset()); + assert_eq!(attrs.get(&INT_KEY), None); + assert!(attrs.is_empty()); + } + + #[test] + fn from_owned_proto_moves_without_clone() { + let proto = ProtoSearchAttributes { + indexed_fields: { + let mut m = HashMap::new(); + m.insert("k".to_string(), INT_KEY.value_set(7).payload.unwrap()); + m + }, + }; + let attrs: SearchAttributes = proto.into(); + assert_eq!(attrs.get(&SearchAttributeKey::int("k")), Some(7)); + } + + #[test] + fn search_attributes_equality() { + let a = SearchAttributes::new([BOOL_KEY.value_set(true), INT_KEY.value_set(42)]); + let b = SearchAttributes::new([BOOL_KEY.value_set(true), INT_KEY.value_set(42)]); + let c = SearchAttributes::new([BOOL_KEY.value_set(false), INT_KEY.value_set(42)]); + assert_eq!(a, b); + assert_ne!(a, c); + } + + #[test] + fn from_proto_trait_matches_from_proto_method() { + let updates = [INT_KEY.value_set(99), BOOL_KEY.value_set(true)]; + let proto = SearchAttributes::new(updates).to_proto(); + let via_method = SearchAttributes::from_proto(&proto); + let via_trait: SearchAttributes = proto.into(); + assert_eq!(via_method, via_trait); + } +} diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index faa401a27..a3283348e 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -20,7 +20,7 @@ pub mod worker; pub use temporalio_common_wasm::{ ActivityDefinition, ActivityError, HasWorkflowDefinition, Priority, QueryDefinition, SignalDefinition, UntypedWorkflow, UpdateDefinition, WorkerDeploymentVersion, - WorkflowDefinition, data_converters, error, + WorkflowDefinition, data_converters, error, search_attributes, }; macro_rules! dbg_panic { diff --git a/crates/sdk-core/tests/integ_tests/workflow_tests/continue_as_new.rs b/crates/sdk-core/tests/integ_tests/workflow_tests/continue_as_new.rs index 8b40654ab..5695d239b 100644 --- a/crates/sdk-core/tests/integ_tests/workflow_tests/continue_as_new.rs +++ b/crates/sdk-core/tests/integ_tests/workflow_tests/continue_as_new.rs @@ -1,19 +1,15 @@ use crate::common::{CoreWfStarter, SEARCH_ATTR_TXT, build_fake_sdk}; -use std::{collections::HashMap, sync::Arc, time::Duration}; +use std::{sync::Arc, time::Duration}; use temporalio_client::WorkflowStartOptions; use temporalio_common::{ - protos::{ - coresdk::AsJsonPayloadExt, - temporal::api::{ - command::v1::command::Attributes, - common::v1::SearchAttributes, - enums::v1::{ - CommandType, - ContinueAsNewVersioningBehavior as ProtoContinueAsNewVersioningBehavior, - }, - history::v1::history_event, + protos::temporal::api::{ + command::v1::command::Attributes, + enums::v1::{ + CommandType, ContinueAsNewVersioningBehavior as ProtoContinueAsNewVersioningBehavior, }, + history::v1::history_event, }, + search_attributes::{SearchAttributeKey, SearchAttributes}, worker::WorkerTaskTypes, }; use temporalio_macros::{workflow, workflow_methods}; @@ -25,6 +21,8 @@ use temporalio_sdk_core::{ }; use temporalio_workflow::runtime::types::ContinueAsNewRequest; +const SA_TXT: SearchAttributeKey = SearchAttributeKey::text(SEARCH_ATTR_TXT); + #[workflow] #[derive(Default)] struct ContinueAsNewWf; @@ -187,7 +185,7 @@ impl ClearSearchAttrsOnContinueAsNewWf { ctx.continue_as_new(&false, opts)?; } - assert!(ctx.search_attributes().indexed_fields.is_empty()); + assert!(ctx.search_attributes().is_empty()); Ok(()) } } @@ -208,10 +206,7 @@ async fn clear_search_attributes_on_continue_as_new() { ClearSearchAttrsOnContinueAsNewWf::run, true, WorkflowStartOptions::new(task_queue, wf_name.to_string()) - .search_attributes(HashMap::from([( - SEARCH_ATTR_TXT.to_string(), - "hello".as_json_payload().unwrap(), - )])) + .search_attributes(SearchAttributes::new([SA_TXT.value_set("hello".into())])) .build(), ) .await diff --git a/crates/sdk-core/tests/integ_tests/workflow_tests/eager.rs b/crates/sdk-core/tests/integ_tests/workflow_tests/eager.rs index af309e1c4..11522b461 100644 --- a/crates/sdk-core/tests/integ_tests/workflow_tests/eager.rs +++ b/crates/sdk-core/tests/integ_tests/workflow_tests/eager.rs @@ -111,7 +111,7 @@ pub(crate) async fn eager_start( .and_then(|d| d.try_into().ok()), workflow_run_timeout: options.run_timeout.and_then(|d| d.try_into().ok()), workflow_task_timeout: options.task_timeout.and_then(|d| d.try_into().ok()), - search_attributes: options.search_attributes.map(|d| d.into()), + search_attributes: options.search_attributes.map(|d| d.into_proto()), cron_schedule: options.cron_schedule.unwrap_or_default(), request_eager_execution: options.enable_eager_workflow_start, retry_policy: options.retry_policy, diff --git a/crates/sdk-core/tests/integ_tests/workflow_tests/upsert_search_attrs.rs b/crates/sdk-core/tests/integ_tests/workflow_tests/upsert_search_attrs.rs index fa1fecc7e..cb9ec6135 100644 --- a/crates/sdk-core/tests/integ_tests/workflow_tests/upsert_search_attrs.rs +++ b/crates/sdk-core/tests/integ_tests/workflow_tests/upsert_search_attrs.rs @@ -1,18 +1,18 @@ use crate::common::{CoreWfStarter, SEARCH_ATTR_INT, SEARCH_ATTR_TXT, build_fake_sdk}; use assert_matches::assert_matches; -use std::{collections::HashMap, time::Duration}; +use std::time::Duration; use temporalio_client::{ UntypedWorkflow, WorkflowDescribeOptions, WorkflowGetResultOptions, WorkflowStartOptions, }; use temporalio_common::{ protos::{ - coresdk::{AsJsonPayloadExt, FromJsonPayloadExt}, + coresdk::FromJsonPayloadExt, temporal::api::{ command::v1::{Command, command}, - common::v1::Payload, enums::v1::EventType, }, }, + search_attributes::{SearchAttributeKey, SearchAttributes}, worker::WorkerTaskTypes, }; use temporalio_macros::{workflow, workflow_methods}; @@ -23,6 +23,9 @@ use temporalio_sdk_core::{ }; use uuid::Uuid; +const SA_INT: SearchAttributeKey = SearchAttributeKey::int(SEARCH_ATTR_INT); +const SA_TXT: SearchAttributeKey = SearchAttributeKey::text(SEARCH_ATTR_TXT); + #[workflow] #[derive(Default)] struct SearchAttrUpdater; @@ -31,17 +34,12 @@ struct SearchAttrUpdater; impl SearchAttrUpdater { #[run(name = "sends_upsert_search_attrs")] async fn run(ctx: &mut WorkflowContext) -> WorkflowResult<()> { - let mut int_val = ctx - .search_attributes() - .indexed_fields - .get(SEARCH_ATTR_INT) - .cloned() - .unwrap_or_default(); - let orig_val = int_val.data[0]; - int_val.data[0] += 1; + let typed = ctx.search_attributes(); + let orig_val = typed.get(&SA_INT).unwrap_or(0); + let new_val = orig_val + 1; ctx.upsert_search_attributes([ - (SEARCH_ATTR_TXT.to_string(), "goodbye".as_json_payload()?), - (SEARCH_ATTR_INT.to_string(), int_val), + SA_TXT.value_set("goodbye".into()), + SA_INT.value_set(new_val), ]); if orig_val == 49 { Err(WorkflowTermination::continue_as_new(Default::default())) @@ -66,12 +64,9 @@ async fn sends_upsert() { wf_name, vec![], WorkflowStartOptions::new(task_queue, wf_id.to_string()) - .search_attributes(HashMap::from([ - ( - SEARCH_ATTR_TXT.to_string(), - "hello".as_json_payload().unwrap(), - ), - (SEARCH_ATTR_INT.to_string(), 1.as_json_payload().unwrap()), + .search_attributes(SearchAttributes::new([ + SA_TXT.value_set("hello".into()), + SA_INT.value_set(1), ])) .execution_timeout(Duration::from_secs(4)) .build(), @@ -117,23 +112,11 @@ struct UpsertTestWf; impl UpsertTestWf { #[run(name = DEFAULT_WORKFLOW_TYPE)] async fn run(ctx: &mut WorkflowContext) -> WorkflowResult<()> { - const K1: &str = "foo"; - const K2: &str = "bar"; + const K1: SearchAttributeKey = SearchAttributeKey::keyword("foo"); + const K2: SearchAttributeKey = SearchAttributeKey::keyword("bar"); ctx.upsert_search_attributes([ - ( - String::from(K1), - Payload { - data: vec![0x01], - ..Default::default() - }, - ), - ( - String::from(K2), - Payload { - data: vec![0x02], - ..Default::default() - }, - ), + K1.value_set("value1".into()), + K2.value_set("value2".into()), ]); Ok(()) } @@ -159,8 +142,8 @@ async fn upsert_search_attrs_from_workflow() { let fields = &msg.search_attributes.as_ref().unwrap().indexed_fields; let payload1 = fields.get(k1).unwrap(); let payload2 = fields.get(k2).unwrap(); - assert_eq!(payload1.data[0], 0x01); - assert_eq!(payload2.data[0], 0x02); + assert_eq!(payload1.data, b"\"value1\""); + assert_eq!(payload2.data, b"\"value2\""); assert_eq!(fields.len(), 2); } ); diff --git a/crates/sdk/examples/search_attributes/starter.rs b/crates/sdk/examples/search_attributes/starter.rs index 24a7df391..10e1fa8a2 100644 --- a/crates/sdk/examples/search_attributes/starter.rs +++ b/crates/sdk/examples/search_attributes/starter.rs @@ -1,12 +1,11 @@ mod workflows; -use std::collections::HashMap; use temporalio_client::{ Client, ClientOptions, Connection, WorkflowGetResultOptions, WorkflowStartOptions, envconfig::LoadClientConfigProfileOptions, }; -use temporalio_common::protos::coresdk::AsJsonPayloadExt; -use workflows::SearchAttributesWorkflow; +use temporalio_common::search_attributes::SearchAttributes; +use workflows::{INT_FIELD, KEYWORD_FIELD, SearchAttributesWorkflow}; #[tokio::main] async fn main() -> Result<(), Box> { @@ -15,11 +14,10 @@ async fn main() -> Result<(), Box> { let connection = Connection::connect(conn_opts).await?; let client = Client::new(connection, client_opts)?; - let mut search_attrs = HashMap::new(); - search_attrs.insert( - "CustomKeywordField".to_string(), - "initial-value".as_json_payload()?, - ); + let search_attrs = SearchAttributes::new([ + KEYWORD_FIELD.value_set("initial-value".into()), + INT_FIELD.value_set(0), + ]); let handle = client .start_workflow( diff --git a/crates/sdk/examples/search_attributes/workflows.rs b/crates/sdk/examples/search_attributes/workflows.rs index 92cc52c27..cea9f44d6 100644 --- a/crates/sdk/examples/search_attributes/workflows.rs +++ b/crates/sdk/examples/search_attributes/workflows.rs @@ -1,8 +1,12 @@ #![allow(unreachable_pub)] -use temporalio_common::protos::coresdk::AsJsonPayloadExt; +use temporalio_common::search_attributes::SearchAttributeKey; use temporalio_macros::{workflow, workflow_methods}; use temporalio_sdk::{WorkflowContext, WorkflowResult}; +pub const KEYWORD_FIELD: SearchAttributeKey = + SearchAttributeKey::keyword("CustomKeywordField"); +pub const INT_FIELD: SearchAttributeKey = SearchAttributeKey::int("CustomIntField"); + #[workflow] #[derive(Default)] pub struct SearchAttributesWorkflow; @@ -11,22 +15,14 @@ pub struct SearchAttributesWorkflow; impl SearchAttributesWorkflow { #[run] pub async fn run(ctx: &mut WorkflowContext, _input: ()) -> WorkflowResult { - let initial_attrs = ctx.search_attributes(); - let initial_keyword = initial_attrs - .indexed_fields - .get("CustomKeywordField") - .and_then(|p| serde_json::from_slice::(&p.data).ok()) + let initial_keyword = ctx + .search_attributes() + .get(&KEYWORD_FIELD) .unwrap_or_default(); ctx.upsert_search_attributes([ - ( - "CustomKeywordField".to_string(), - "updated-value".as_json_payload().unwrap(), - ), - ( - "CustomIntField".to_string(), - 42i64.as_json_payload().unwrap(), - ), + KEYWORD_FIELD.value_set("updated-value".into()), + INT_FIELD.value_set(42), ]); Ok(format!( diff --git a/crates/workflow/src/workflow_context.rs b/crates/workflow/src/workflow_context.rs index 349a58999..4f257e289 100644 --- a/crates/workflow/src/workflow_context.rs +++ b/crates/workflow/src/workflow_context.rs @@ -23,11 +23,10 @@ use futures_util::{ task::Context, }; use std::{ - cell::{Cell, Ref, RefCell}, + cell::{Cell, RefCell}, collections::HashMap, future::{self, Future}, marker::PhantomData, - ops::Deref, pin::Pin, rc::Rc, sync::atomic::{AtomicBool, Ordering}, @@ -67,11 +66,12 @@ use temporalio_common_wasm::{ }, }, temporal::api::{ - common::v1::{Memo, Payload, SearchAttributes}, + common::v1::{Memo, Payload, SearchAttributes as ProtoSearchAttributes}, failure::v1::{CanceledFailureInfo, Failure, failure::FailureInfo}, }, utilities::TryIntoOrNone, }, + search_attributes::{SearchAttributeUpdate, SearchAttributes}, worker::WorkerDeploymentVersion, }; @@ -315,7 +315,7 @@ pub struct WorkflowContextView { pub cron_schedule: Option, /// User-defined memo pub memo: Option, - /// Initial search attributes + /// Initial search attributes as a typed collection. pub search_attributes: Option, } @@ -395,7 +395,10 @@ impl WorkflowContextView { retry_policy: init.retry_policy.clone(), cron_schedule, memo: init.memo.clone(), - search_attributes: init.search_attributes.clone(), + search_attributes: init + .search_attributes + .as_ref() + .map(SearchAttributes::from_proto), } } } @@ -790,9 +793,9 @@ impl SyncWorkflowContext { .map(Into::into) } - /// Return current values for workflow search attributes - pub fn search_attributes(&self) -> impl Deref + '_ { - Ref::map(self.base.inner.shared.borrow(), |s| &s.search_attributes) + /// Return current values for workflow search attributes. + pub fn search_attributes(&self) -> SearchAttributes { + SearchAttributes::from_proto(&self.base.inner.shared.borrow().search_attributes) } /// Return the workflow's randomness seed @@ -996,14 +999,35 @@ impl SyncWorkflowContext { } } - /// Add or create a set of search attributes - pub fn upsert_search_attributes(&self, attr_iter: impl IntoIterator) { + /// Add, update, or remove search attributes using typed keys. + /// + /// Updates are applied to the local in-memory view immediately so that + /// subsequent calls to [`search_attributes()`](Self::search_attributes) + /// reflect the changes. The command is also sent to the server. + pub fn upsert_search_attributes( + &self, + updates: impl IntoIterator, + ) { + // Collect so we can iterate twice: once for local state, once for the + // wire proto (which uses a different encoding for "unset"). + let updates: Vec = updates.into_iter().collect(); + + // Update local state using the typed API, which correctly removes keys + // on unset (rather than inserting empty payloads like the wire format). + { + let mut shared = self.base.inner.shared.borrow_mut(); + let mut attrs = SearchAttributes::from_proto(&shared.search_attributes); + for update in updates.iter().cloned() { + attrs.apply(update); + } + shared.search_attributes = attrs.into_proto(); + } + + let proto = SearchAttributes::updates_to_proto(updates); self.base.inner.runtime.host.push_command( workflow_command::Variant::UpsertWorkflowSearchAttributes( UpsertWorkflowSearchAttributes { - search_attributes: Some(SearchAttributes { - indexed_fields: attr_iter.into_iter().collect(), - }), + search_attributes: Some(proto), }, ) .into(), @@ -1151,8 +1175,8 @@ impl WorkflowContext { self.sync.current_deployment_version() } - /// Return current values for workflow search attributes - pub fn search_attributes(&self) -> impl Deref + '_ { + /// Return current values for workflow search attributes. + pub fn search_attributes(&self) -> SearchAttributes { self.sync.search_attributes() } @@ -1276,9 +1300,12 @@ impl WorkflowContext { self.sync.external_workflow(workflow_id, run_id) } - /// Add or create a set of search attributes - pub fn upsert_search_attributes(&self, attr_iter: impl IntoIterator) { - self.sync.upsert_search_attributes(attr_iter) + /// Add, update, or remove search attributes using typed keys. + pub fn upsert_search_attributes( + &self, + updates: impl IntoIterator, + ) { + self.sync.upsert_search_attributes(updates) } /// Add or create a set of memo fields @@ -1414,7 +1441,7 @@ struct WorkflowContextSharedData { /// Maps change ids -> resolved status changes: HashMap, activation: CoreWorkflowActivation, - search_attributes: SearchAttributes, + search_attributes: ProtoSearchAttributes, random_seed: u64, /// Current details string, surfaced via the workflow metadata query. current_details: String, @@ -2413,11 +2440,12 @@ mod tests { "header-key".to_string(), Payload::from(b"header-value".as_slice()), ); - let mut search_attributes = SearchAttributes::default(); - search_attributes.indexed_fields.insert( + let mut proto_search_attributes = ProtoSearchAttributes::default(); + proto_search_attributes.indexed_fields.insert( "CustomKeywordField".to_string(), Payload::from(b"value".as_slice()), ); + let search_attributes = SearchAttributes::from_proto(&proto_search_attributes); let termination = sync .continue_as_new( @@ -2461,7 +2489,7 @@ mod tests { backoff_start_interval: Some(Duration::from_secs(4).try_into().unwrap()), memo, headers, - search_attributes: Some(search_attributes), + search_attributes: Some(proto_search_attributes), retry_policy: Some(RetryPolicy { maximum_attempts: 5, ..Default::default() @@ -2491,7 +2519,10 @@ mod tests { unreachable!() }; - assert_eq!(cmd.search_attributes, Some(SearchAttributes::default())); + assert_eq!( + cmd.search_attributes, + Some(ProtoSearchAttributes::default()) + ); } #[test] @@ -2584,4 +2615,106 @@ mod tests { }; assert_eq!(err.to_string(), "Encoding error: serialization failure"); } + + #[test] + fn upsert_search_attributes_updates_local_state() { + use temporalio_common_wasm::search_attributes::SearchAttributeKey; + + const K: SearchAttributeKey = SearchAttributeKey::int("my_int"); + + let ctx = test_context(); + assert!(ctx.search_attributes().is_empty()); + + ctx.upsert_search_attributes([K.value_set(42)]); + let attrs = ctx.search_attributes(); + assert_eq!(attrs.get(&K), Some(42)); + } + + #[test] + fn upsert_search_attributes_unset_removes_from_local_state() { + use temporalio_common_wasm::search_attributes::SearchAttributeKey; + + const K: SearchAttributeKey = SearchAttributeKey::keyword("my_kw"); + + let ctx = test_context(); + // Set, then unset. + ctx.upsert_search_attributes([K.value_set("hello".into())]); + assert_eq!(ctx.search_attributes().get(&K), Some("hello".into())); + + ctx.upsert_search_attributes([K.value_unset()]); + assert!(!ctx.search_attributes().contains_key(&K)); + assert!(ctx.search_attributes().is_empty()); + } + + #[test] + fn upsert_search_attributes_multiple_updates_last_wins() { + use temporalio_common_wasm::search_attributes::SearchAttributeKey; + + const K: SearchAttributeKey = SearchAttributeKey::int("counter"); + + let ctx = test_context(); + ctx.upsert_search_attributes([K.value_set(1), K.value_set(2)]); + assert_eq!(ctx.search_attributes().get(&K), Some(2)); + } + + #[test] + fn upsert_search_attributes_merges_with_initial() { + use temporalio_common_wasm::search_attributes::SearchAttributeKey; + + const A: SearchAttributeKey = SearchAttributeKey::int("attr_a"); + const B: SearchAttributeKey = SearchAttributeKey::keyword("attr_b"); + + // Start with initial search attribute A. + let init_sa = SearchAttributes::new([A.value_set(1)]).into_proto(); + let init = InitializeWorkflow { + workflow_type: TestWorkflow.name().to_string(), + search_attributes: Some(init_sa), + ..Default::default() + }; + let base = BaseWorkflowContext::new( + "default".to_string(), + "tq".to_string(), + "run-id".to_string(), + init, + DataConverter::default(), + Rc::new(NoopHost), + ); + let ctx = WorkflowContext::from_base(base, Rc::new(RefCell::new(TestWorkflow))); + + assert_eq!(ctx.search_attributes().get(&A), Some(1)); + + // Upsert B — A should still be present. + ctx.upsert_search_attributes([B.value_set("hello".into())]); + assert_eq!(ctx.search_attributes().get(&A), Some(1)); + assert_eq!(ctx.search_attributes().get(&B), Some("hello".into())); + } + + #[test] + fn view_search_attributes_returns_typed() { + use temporalio_common_wasm::search_attributes::SearchAttributeKey; + + const K: SearchAttributeKey = SearchAttributeKey::bool("active"); + + let init_sa = SearchAttributes::new([K.value_set(true)]).into_proto(); + let init = InitializeWorkflow { + workflow_type: TestWorkflow.name().to_string(), + search_attributes: Some(init_sa), + ..Default::default() + }; + let base = BaseWorkflowContext::new( + "default".to_string(), + "tq".to_string(), + "run-id".to_string(), + init, + DataConverter::default(), + Rc::new(NoopHost), + ); + let ctx = WorkflowContext::from_base(base, Rc::new(RefCell::new(TestWorkflow))); + + let view = ctx.view(); + let sa = view + .search_attributes + .expect("should have search attributes"); + assert_eq!(sa.get(&K), Some(true)); + } } diff --git a/crates/workflow/src/workflow_context/options.rs b/crates/workflow/src/workflow_context/options.rs index 4281f54d1..fc6415631 100644 --- a/crates/workflow/src/workflow_context/options.rs +++ b/crates/workflow/src/workflow_context/options.rs @@ -19,7 +19,7 @@ use temporalio_common_wasm::{ }, }, temporal::api::{ - common::v1::{Payload, RetryPolicy, SearchAttributes}, + common::v1::{Payload, RetryPolicy}, enums::v1::{ ContinueAsNewVersioningBehavior as ProtoContinueAsNewVersioningBehavior, WorkflowIdReusePolicy, @@ -27,6 +27,7 @@ use temporalio_common_wasm::{ sdk::v1::UserMetadata, }, }, + search_attributes::SearchAttributes, }; /// Options for scheduling an activity #[derive(Debug, bon::Builder, Clone)] @@ -281,8 +282,8 @@ pub struct ChildWorkflowOptions { pub task_timeout: Option, /// Optionally set a cron schedule for the workflow pub cron_schedule: Option, - /// Optionally associate extra search attributes with a workflow - pub search_attributes: Option>, + /// Additional search attributes to set on the child workflow. + pub search_attributes: Option, /// Priority for the workflow pub priority: Option, } @@ -318,11 +319,7 @@ impl ChildWorkflowOptions { .task_timeout .and_then(|duration| duration.try_into().ok()), cron_schedule: self.cron_schedule.unwrap_or_default(), - search_attributes: self.search_attributes.and_then(|attrs| { - (!attrs.is_empty()).then_some(SearchAttributes { - indexed_fields: attrs, - }) - }), + search_attributes: self.search_attributes.map(|t| t.into_proto()), priority: self.priority.map(Into::into), ..Default::default() }), @@ -597,7 +594,7 @@ impl ContinueAsNewOptions { .and_then(|duration| duration.try_into().ok()), memo: self.memo.unwrap_or_default(), headers: self.headers.unwrap_or_default(), - search_attributes: self.search_attributes, + search_attributes: self.search_attributes.map(|t| t.into_proto()), retry_policy: self.retry_policy, versioning_intent: self .versioning_intent