diff --git a/sql/schema.sql b/sql/schema.sql
index 7c77baf..5109ab7 100644
--- a/sql/schema.sql
+++ b/sql/schema.sql
@@ -792,7 +792,10 @@ begin
     perform durable.emit_event(
       p_queue_name,
       '$child:' || p_task_id::text,
-      jsonb_build_object('status', p_status) || coalesce(p_payload, '{}'::jsonb)
+      jsonb_build_object(
+        'inner', jsonb_build_object('status', p_status) || coalesce(p_payload, '{}'::jsonb),
+        'metadata', '{}'::jsonb
+      )
     );
   end if;
 
@@ -1423,6 +1426,23 @@ begin
     raise exception 'event_name must be provided';
   end if;
 
+  -- Validate that a non-null object payload has exactly the allowed keys ('inner' and 'metadata')
+  if p_payload is not null and jsonb_typeof(p_payload) = 'object' then
+    if exists (
+      select 1
+      from jsonb_object_keys(p_payload) as k
+      where k not in ('inner', 'metadata')
+    ) then
+      raise exception 'p_payload may only contain ''inner'' and ''metadata'' keys';
+    end if;
+    if not p_payload ? 'inner' then
+      raise exception 'p_payload must contain an ''inner'' key';
+    end if;
+    if not p_payload ? 'metadata' then
+      raise exception 'p_payload must contain a ''metadata'' key';
+    end if;
+  end if;
+
   -- Insert the event into the events table (first-writer-wins).
   -- Subsequent emits for the same event are no-ops.
   -- We use DO UPDATE WHERE payload IS NULL to handle the case where await_event
diff --git a/src/client.rs b/src/client.rs
index 370093d..585f2fc 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -11,8 +11,8 @@ use uuid::Uuid;
 use crate::error::{DurableError, DurableResult};
 use crate::task::{Task, TaskRegistry, TaskWrapper};
 use crate::types::{
-    CancellationPolicy, RetryStrategy, SpawnDefaults, SpawnOptions, SpawnResult, SpawnResultRow,
-    WorkerOptions,
+    CancellationPolicy, DurableEventPayload, RetryStrategy, SpawnDefaults, SpawnOptions,
+    SpawnResult, SpawnResultRow, WorkerOptions,
 };
 
 /// Internal struct for serializing spawn options to the database.
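Aside (illustrative, not part of the patch): with the validation above in place, every SQL-level emit must send the wrapped shape. A minimal sketch of a conforming call from Rust, mirroring the updated tests; the pool, queue name, event name, and values are hypothetical:

    use durable::DurableEventPayload;
    use serde_json::json;
    use sqlx::PgPool;

    async fn emit_wrapped(pool: &PgPool) -> sqlx::Result<()> {
        // Wrap the user payload; a bare '{"value": 42}'::jsonb would now be
        // rejected with "p_payload may only contain 'inner' and 'metadata' keys".
        let payload = serde_json::to_value(DurableEventPayload {
            inner: json!({"value": 42}),
            metadata: json!({}),
        })
        .unwrap();
        sqlx::query("SELECT durable.emit_event($1, $2, $3)")
            .bind("my_queue")
            .bind("my-event")
            .bind(&payload)
            .execute(pool)
            .await?;
        Ok(())
    }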
@@ -684,7 +684,21 @@ where
         #[cfg(feature = "telemetry")]
         tracing::Span::current().record("queue", queue);
 
-        let payload_json = serde_json::to_value(payload)?;
+        let inner_payload_json = serde_json::to_value(payload)?;
+
+        let mut payload_wrapper = DurableEventPayload {
+            inner: inner_payload_json,
+            metadata: JsonValue::Null,
+        };
+
+        #[allow(unused_mut)] // mut is only needed when the telemetry feature is enabled
+        let mut metadata_map: HashMap<String, String> = HashMap::new();
+
+        #[cfg(feature = "telemetry")]
+        crate::telemetry::inject_trace_context(&mut metadata_map);
+        payload_wrapper.metadata = serde_json::to_value(metadata_map)?;
+
+        let payload_json = serde_json::to_value(payload_wrapper)?;
 
         let query = "SELECT durable.emit_event($1, $2, $3)";
         sqlx::query(query)
diff --git a/src/context.rs b/src/context.rs
index 10993fe..1827eae 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -8,6 +8,7 @@ use uuid::Uuid;
 use crate::Durable;
 use crate::error::{ControlFlow, TaskError, TaskResult};
 use crate::task::Task;
+use crate::types::DurableEventPayload;
 use crate::types::{
     AwaitEventResult, CheckpointRow, ChildCompletePayload, ChildStatus, ClaimedTask, SpawnOptions,
     TaskHandle,
@@ -351,7 +352,9 @@ where
 
         // Check cache for already-received event
         if let Some(cached) = self.checkpoint_cache.get(&checkpoint_name) {
-            return Ok(serde_json::from_value(cached.clone())?);
+            let durable_event_payload: DurableEventPayload =
+                serde_json::from_value(cached.clone())?;
+            return self.process_event_payload_wrapper(durable_event_payload);
         }
 
         // Check if we were woken by this event but it timed out (null payload)
@@ -383,10 +386,39 @@ where
         }
 
         // Event arrived - cache and return
-        let payload = result.payload.unwrap_or(JsonValue::Null);
-        self.checkpoint_cache
-            .insert(checkpoint_name, payload.clone());
-        Ok(serde_json::from_value(payload)?)
+        let durable_event_payload = result.payload.unwrap_or(DurableEventPayload {
+            inner: JsonValue::Null,
+            metadata: JsonValue::Null,
+        });
+        self.checkpoint_cache.insert(
+            checkpoint_name,
+            serde_json::to_value(durable_event_payload.clone())?,
+        );
+
+        self.process_event_payload_wrapper(durable_event_payload)
     }
+
+    fn process_event_payload_wrapper<T: serde::de::DeserializeOwned>(
+        &self,
+        value: DurableEventPayload,
+    ) -> TaskResult<T> {
+        #[cfg(feature = "telemetry")]
+        {
+            use opentelemetry::KeyValue;
+            use opentelemetry::trace::TraceContextExt;
+            use tracing_opentelemetry::OpenTelemetrySpanExt;
+
+            let metadata: Option<HashMap<String, String>> =
+                serde_json::from_value(value.metadata)?;
+            if let Some(metadata) = metadata {
+                let context = crate::telemetry::extract_trace_context(&metadata);
+                tracing::Span::current().add_link_with_attributes(
+                    context.span().span_context().clone(),
+                    vec![KeyValue::new("sentry.link.type", "previous_trace")],
+                );
+            }
+        }
+        Ok(serde_json::from_value(value.inner)?)
+    }
 
     /// Emit an event to this task's queue.
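Aside (illustrative, not part of the patch): the code above assumes `crate::telemetry::inject_trace_context` / `extract_trace_context` helpers that this diff does not show. A plausible minimal shape for them, assuming the globally installed W3C text-map propagator and a `HashMap<String, String>` carrier:

    use std::collections::HashMap;

    use opentelemetry::{Context, global};

    // Writes the current span context into the carrier map (the
    // `traceparent` / `tracestate` entries) so it can travel inside the
    // event's `metadata` field.
    fn inject_trace_context(map: &mut HashMap<String, String>) {
        global::get_text_map_propagator(|prop| prop.inject_context(&Context::current(), map));
    }

    // Rebuilds a Context from a previously injected carrier map; the caller
    // above then links the waiting task's span to the emitter's trace.
    fn extract_trace_context(map: &HashMap<String, String>) -> Context {
        global::get_text_map_propagator(|prop| prop.extract(map))
    }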
@@ -404,22 +436,13 @@ where
         )
     )]
     pub async fn emit_event<T: Serialize>(&self, event_name: &str, payload: &T) -> TaskResult<()> {
-        if event_name.is_empty() {
-            return Err(TaskError::Validation {
-                message: "event_name must be non-empty".to_string(),
-            });
-        }
-
-        let payload_json = serde_json::to_value(payload)?;
-        let query = "SELECT durable.emit_event($1, $2, $3)";
-        sqlx::query(query)
-            .bind(self.durable.queue_name())
-            .bind(event_name)
-            .bind(&payload_json)
-            .execute(self.durable.pool())
-            .await?;
-
-        Ok(())
+        self.durable
+            .emit_event(event_name, payload, None)
+            .await
+            .map_err(|e| TaskError::EmitEventFailed {
+                event_name: event_name.to_string(),
+                error: e,
+            })
     }
 
     /// Extend the task's lease to prevent timeout.
@@ -693,8 +716,11 @@ where
 
         // Check cache for already-received event
        if let Some(cached) = self.checkpoint_cache.get(&checkpoint_name) {
-            let payload: ChildCompletePayload = serde_json::from_value(cached.clone())?;
-            return Self::process_child_payload(&step_name, payload);
+            let durable_event_payload: DurableEventPayload =
+                serde_json::from_value(cached.clone())?;
+            let child_complete_payload: ChildCompletePayload =
+                self.process_event_payload_wrapper(durable_event_payload)?;
+            return Self::process_child_payload(&step_name, child_complete_payload);
         }
 
         // Check if we were woken by this event but it timed out (null payload)
@@ -724,12 +750,18 @@ where
         }
 
         // Event arrived - parse and return
-        let payload_json = result.payload.unwrap_or(JsonValue::Null);
-        self.checkpoint_cache
-            .insert(checkpoint_name, payload_json.clone());
+        let durable_event_payload = result.payload.unwrap_or(DurableEventPayload {
+            inner: JsonValue::Null,
+            metadata: JsonValue::Null,
+        });
+        self.checkpoint_cache.insert(
+            checkpoint_name,
+            serde_json::to_value(durable_event_payload.clone())?,
+        );
 
-        let payload: ChildCompletePayload = serde_json::from_value(payload_json)?;
-        Self::process_child_payload(&step_name, payload)
+        let child_complete_payload: ChildCompletePayload =
+            self.process_event_payload_wrapper(durable_event_payload)?;
+        Self::process_child_payload(&step_name, child_complete_payload)
     }
 
     /// Process the child completion payload and return the appropriate result.
diff --git a/src/error.rs b/src/error.rs
index 7e97013..baeaff6 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -77,6 +77,13 @@ pub enum TaskError {
     #[error("failed to spawn subtask `{name}`: {error}")]
     SubtaskSpawnFailed { name: String, error: DurableError },
 
+    /// An error occurred while trying to emit an event.
+    #[error("failed to emit event `{event_name}`: {error}")]
+    EmitEventFailed {
+        event_name: String,
+        error: DurableError,
+    },
+
     /// A child task failed.
     ///
     /// Returned by [`TaskContext::join`](crate::TaskContext::join) when the child
@@ -231,6 +238,13 @@ pub fn serialize_task_error(err: &TaskError) -> JsonValue {
                 "subtask_name": name,
             })
         }
+        TaskError::EmitEventFailed { event_name, error } => {
+            serde_json::json!({
+                "name": "EmitEventFailed",
+                "message": error.to_string(),
+                "event_name": event_name,
+            })
+        }
         TaskError::ChildFailed { step_name, message } => {
             serde_json::json!({
                 "name": "ChildFailed",
diff --git a/src/lib.rs b/src/lib.rs
index 09e2cf2..79a4085 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -109,8 +109,8 @@
 pub use context::TaskContext;
 pub use error::{ControlFlow, DurableError, DurableResult, TaskError, TaskResult};
 pub use task::{ErasedTask, Task, TaskWrapper};
 pub use types::{
-    CancellationPolicy, ClaimedTask, RetryStrategy, SpawnDefaults, SpawnOptions, SpawnResult,
-    TaskHandle, WorkerOptions,
+    CancellationPolicy, ClaimedTask, DurableEventPayload, RetryStrategy, SpawnDefaults,
+    SpawnOptions, SpawnResult, TaskHandle, WorkerOptions,
 };
 pub use worker::Worker;
diff --git a/src/postgres/migrations/20260126135558_event_payload_structure.sql b/src/postgres/migrations/20260126135558_event_payload_structure.sql
new file mode 100644
index 0000000..edd065b
--- /dev/null
+++ b/src/postgres/migrations/20260126135558_event_payload_structure.sql
@@ -0,0 +1,173 @@
+-- Migration to enforce the structured event payload format with 'inner' and 'metadata' keys
+
+-- Update durable.cleanup_task_terminal to wrap the status in an 'inner' object
+create or replace function durable.cleanup_task_terminal (
+  p_queue_name text,
+  p_task_id uuid,
+  p_status text, -- 'completed', 'failed', 'cancelled'
+  p_payload jsonb default null,
+  p_cascade_children boolean default false
+)
+  returns void
+  language plpgsql
+as $$
+declare
+  v_parent_task_id uuid;
+begin
+  -- Get parent_task_id for event emission
+  execute format(
+    'select parent_task_id from durable.%I where task_id = $1',
+    't_' || p_queue_name
+  ) into v_parent_task_id using p_task_id;
+
+  -- Delete wait registrations for this task
+  execute format(
+    'delete from durable.%I where task_id = $1',
+    'w_' || p_queue_name
+  ) using p_task_id;
+
+  -- Emit completion event for parent (if subtask)
+  if v_parent_task_id is not null then
+    perform durable.emit_event(
+      p_queue_name,
+      '$child:' || p_task_id::text,
+      jsonb_build_object(
+        'inner', jsonb_build_object('status', p_status) || coalesce(p_payload, '{}'::jsonb),
+        'metadata', '{}'::jsonb
+      )
+    );
+  end if;
+
+  -- Cascade cancel children if requested
+  if p_cascade_children then
+    perform durable.cascade_cancel_children(p_queue_name, p_task_id);
+  end if;
+end;
+$$;
+
+-- Update durable.emit_event to validate the payload structure
+create or replace function durable.emit_event (
+  p_queue_name text,
+  p_event_name text,
+  p_payload jsonb default null
+)
+  returns void
+  language plpgsql
+as $$
+declare
+  v_now timestamptz := durable.current_time();
+  v_payload jsonb := coalesce(p_payload, 'null'::jsonb);
+  v_inserted_count integer;
+begin
+  if p_event_name is null or length(trim(p_event_name)) = 0 then
+    raise exception 'event_name must be provided';
+  end if;
+
+  -- Validate that a non-null object payload has exactly the allowed keys ('inner' and 'metadata')
+  if p_payload is not null and jsonb_typeof(p_payload) = 'object' then
+    if exists (
+      select 1
+      from jsonb_object_keys(p_payload) as k
+      where k not in ('inner', 'metadata')
+    ) then
+      raise exception 'p_payload may only contain ''inner'' and ''metadata'' keys';
+    end if;
+    if not p_payload ? 'inner' then
+      raise exception 'p_payload must contain an ''inner'' key';
+    end if;
+    if not p_payload ? 'metadata' then
+      raise exception 'p_payload must contain a ''metadata'' key';
+    end if;
+  end if;
+
+  -- Insert the event into the events table (first-writer-wins).
+  -- Subsequent emits for the same event are no-ops.
+  -- We use DO UPDATE WHERE payload IS NULL to handle the case where await_event
+  -- created a placeholder row before emit_event ran.
+  execute format(
+    'insert into durable.%I (event_name, payload, emitted_at)
+     values ($1, $2, $3)
+     on conflict (event_name) do update
+       set payload = excluded.payload, emitted_at = excluded.emitted_at
+       where durable.%I.payload is null',
+    'e_' || p_queue_name,
+    'e_' || p_queue_name
+  ) using p_event_name, v_payload, v_now;
+
+  get diagnostics v_inserted_count = row_count;
+
+  -- Only wake waiters if this call actually wrote the payload (a fresh insert
+  -- or an await_event placeholder being filled). Subsequent emits are no-ops.
+  if v_inserted_count = 0 then
+    return;
+  end if;
+
+  execute format(
+    'with expired_waits as (
+       delete from durable.%1$I w
+       where w.event_name = $1
+         and w.timeout_at is not null
+         and w.timeout_at <= $2
+       returning w.run_id
+     ),
+     affected as (
+       select run_id, task_id, step_name
+       from durable.%1$I
+       where event_name = $1
+         and (timeout_at is null or timeout_at > $2)
+     ),
+     -- Lock tasks before updating runs, and only lock sleeping ones to avoid
+     -- interfering with other operations. This prevents waking tasks that are
+     -- being cancelled concurrently (e.g., while cascade_cancel_children is
+     -- running).
+     locked_tasks as (
+       select t.task_id
+       from durable.%4$I t
+       where t.task_id in (select task_id from affected)
+         and t.state = ''sleeping''
+       for update
+     ),
+     -- update the run table for all waiting runs so they are pending again
+     updated_runs as (
+       update durable.%2$I r
+       set state = ''pending'',
+           available_at = $2,
+           wake_event = null,
+           event_payload = $3,
+           claimed_by = null,
+           claim_expires_at = null
+       where r.run_id in (select run_id from affected)
+         and r.state = ''sleeping''
+         and r.task_id in (select task_id from locked_tasks)
+       returning r.run_id, r.task_id
+     ),
+     -- update checkpoints for all affected tasks/steps so they contain the event payload
+     checkpoint_upd as (
+       insert into durable.%3$I (task_id, checkpoint_name, state, owner_run_id, updated_at)
+       select a.task_id, a.step_name, $3, a.run_id, $2
+       from affected a
+       join updated_runs ur on ur.run_id = a.run_id
+       on conflict (task_id, checkpoint_name)
+       do update set state = excluded.state,
+                     owner_run_id = excluded.owner_run_id,
+                     updated_at = excluded.updated_at
+     ),
+     -- update the task table to set affected tasks back to pending
+     updated_tasks as (
+       update durable.%4$I t
+       set state = ''pending''
+       where t.task_id in (select task_id from updated_runs)
+       returning task_id
+     )
+     -- delete the wait registrations that were satisfied
+     delete from durable.%5$I w
+     where w.event_name = $1
+       and w.run_id in (select run_id from updated_runs)',
+    'w_' || p_queue_name,
+    'r_' || p_queue_name,
+    'c_' || p_queue_name,
+    't_' || p_queue_name,
+    'w_' || p_queue_name
+  ) using p_event_name, v_now, v_payload;
+end;
+$$;
diff --git a/src/types.rs b/src/types.rs
index 3cfaa07..192a9c4 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -275,7 +275,18 @@ pub struct SpawnResultRow {
 #[derive(Debug, Clone, sqlx::FromRow)]
 pub struct AwaitEventResult {
     pub should_suspend: bool,
-    pub payload: Option<JsonValue>,
+    #[sqlx(json(nullable))]
+    pub payload: Option<DurableEventPayload>,
 }
+
+/// The wrapper type used for all durable events (including ones emitted from
+/// within the durable SQL itself).
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct DurableEventPayload {
+    /// The user-defined payload passed to `emit_event`.
+    pub inner: JsonValue,
+    /// Metadata attached by durable itself.
+    pub metadata: JsonValue,
+}
 
 /// Handle to a spawned subtask.
diff --git a/tests/event_test.rs b/tests/event_test.rs
index 36cd756..5005b78 100644
--- a/tests/event_test.rs
+++ b/tests/event_test.rs
@@ -4,7 +4,7 @@ mod common;
 
 use common::helpers::{get_task_state, wait_for_task_terminal};
 use common::tasks::{EventEmitterParams, EventEmitterTask, EventWaitParams, EventWaitingTask};
-use durable::{Durable, MIGRATOR, RetryStrategy, SpawnOptions, WorkerOptions};
+use durable::{Durable, DurableEventPayload, MIGRATOR, RetryStrategy, SpawnOptions, WorkerOptions};
 use serde_json::json;
 use sqlx::postgres::PgConnectOptions;
 use sqlx::{AssertSqlSafe, Connection, PgConnection, PgPool};
@@ -939,7 +939,12 @@ async fn test_event_race_stress(pool: PgPool) -> sqlx::Result<()> {
 async fn test_await_emit_event_race_does_not_lose_wakeup(pool: PgPool) -> sqlx::Result<()> {
     let queue = "event_race_gate";
     let event_name = "race-event";
-    let payload = json!({"value": 42});
+    // The payload must use the DurableEventPayload format (inner/metadata structure)
+    let payload = serde_json::to_value(DurableEventPayload {
+        inner: json!({"value": 42}),
+        metadata: json!({}),
+    })
+    .unwrap();
 
     // Setup: Create queue, spawn task, claim it
     sqlx::query("SELECT durable.create_queue($1)")
diff --git a/tests/lock_order_test.rs b/tests/lock_order_test.rs
index be8cbb6..4048dea 100644
--- a/tests/lock_order_test.rs
+++ b/tests/lock_order_test.rs
@@ -21,7 +21,8 @@ use common::helpers::{get_task_state, single_conn_pool, wait_for_task_terminal};
 use common::tasks::{
     DoubleParams, DoubleTask, FailingParams, FailingTask, SleepParams, SleepingTask,
 };
-use durable::{Durable, MIGRATOR, RetryStrategy, SpawnOptions, WorkerOptions};
+use durable::{Durable, DurableEventPayload, MIGRATOR, RetryStrategy, SpawnOptions, WorkerOptions};
+use serde_json::json;
 use sqlx::postgres::{PgConnectOptions, PgConnection};
 use sqlx::{AssertSqlSafe, Connection, PgPool};
 use std::time::{Duration, Instant};
@@ -258,11 +259,18 @@ async fn test_emit_event_with_lock_ordering(pool: PgPool) -> sqlx::Result<()> {
         "Task should be sleeping waiting for event"
     );
 
-    // Emit the event
-    let emit_query = AssertSqlSafe(
-        "SELECT durable.emit_event('lock_emit', 'test_event', '\"hello\"'::jsonb)".to_string(),
-    );
-    sqlx::query(emit_query).execute(&pool).await?;
+    // Emit the event - the payload must use the DurableEventPayload format
+    let payload = serde_json::to_value(DurableEventPayload {
+        inner: json!("hello"),
+        metadata: json!({}),
+    })
+    .unwrap();
+    sqlx::query("SELECT durable.emit_event($1, $2, $3)")
+        .bind("lock_emit")
+        .bind("test_event")
+        .bind(&payload)
+        .execute(&pool)
+        .await?;
 
     // Wait for task to complete
     let terminal = wait_for_task_terminal(
@@ -328,16 +336,22 @@ async fn test_concurrent_emit_and_cancel(pool: PgPool) -> sqlx::Result<()> {
         );
     }
 
-    // Cancel one task while emitting the event
+    // Cancel one task while emitting the event - the payload must use the DurableEventPayload format
     let cancel_task_id = task_ids[0];
+    let payload = serde_json::to_value(DurableEventPayload {
+        inner: json!("wakeup"),
+        metadata: json!({}),
+    })
+    .unwrap();
     let emit_handle = tokio::spawn({
         let test_pool = test_pool.clone();
         async move {
-            let emit_query = AssertSqlSafe(
-                "SELECT durable.emit_event('lock_emit_cancel', 'shared_event', '\"wakeup\"'::jsonb)"
-                    .to_string(),
-            );
-            sqlx::query(emit_query).execute(&test_pool).await
+            sqlx::query("SELECT durable.emit_event($1, $2, $3)")
+                .bind("lock_emit_cancel")
+                .bind("shared_event")
+                .bind(&payload)
+                .execute(&test_pool)
+                .await
         }
     });
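Aside (illustrative, not part of the patch): a round-trip of the new wrapper as it lands in stored event rows; the field values are hypothetical:

    #[test]
    fn durable_event_payload_round_trip() {
        // Shape written by durable.emit_event / cleanup_task_terminal.
        let raw = serde_json::json!({
            "inner": {"status": "completed"},
            "metadata": {}
        });
        let wrapper: durable::DurableEventPayload =
            serde_json::from_value(raw).expect("wrapper must deserialize");
        // Task code only ever sees `inner`; `metadata` stays internal to durable.
        assert_eq!(wrapper.inner["status"], "completed");
    }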