Skip to content
42 changes: 42 additions & 0 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1924,6 +1924,13 @@ async def run_activity() -> Any:
raise
# Send a cancel request to the activity
handle._apply_cancel_command(self._add_command())
# Clear the cancellation counter on Python 3.11+ so the
# next await does not immediately re-raise CancelledError
if (
sys.version_info >= (3, 11)
and (t := asyncio.current_task()) is not None
):
t.uncancel() # type: ignore[union-attr]

# Create the handle and set as pending
handle = _ActivityHandle(self, input, run_activity())
Expand Down Expand Up @@ -2008,6 +2015,13 @@ async def run_child() -> Any:
return await asyncio.shield(handle._result_fut)
except asyncio.CancelledError:
apply_child_cancel_error()
# Clear the cancellation counter on Python 3.11+ so the
# next await does not immediately re-raise CancelledError
if (
sys.version_info >= (3, 11)
and (t := asyncio.current_task()) is not None
):
t.uncancel() # type: ignore[union-attr]

# Create the handle and set as pending
handle = _ChildWorkflowHandle(
Expand All @@ -2025,6 +2039,13 @@ async def run_child() -> Any:
return handle
except asyncio.CancelledError:
apply_child_cancel_error()
# Clear the cancellation counter on Python 3.11+ so the
# next await does not immediately re-raise CancelledError
if (
sys.version_info >= (3, 11)
and (t := asyncio.current_task()) is not None
):
t.uncancel() # type: ignore[union-attr]
if self._cancel_requested:
raise

Expand Down Expand Up @@ -2053,6 +2074,13 @@ async def operation_handle_fn() -> OutputT:
except asyncio.CancelledError:
cancel_command = self._add_command()
handle._apply_cancel_command(cancel_command)
# Clear the cancellation counter on Python 3.11+ so the
# next await does not immediately re-raise CancelledError
if (
sys.version_info >= (3, 11)
and (t := asyncio.current_task()) is not None
):
t.uncancel() # type: ignore[union-attr]

handle = _NexusOperationHandle(
self, self._next_seq("nexus_operation"), input, operation_handle_fn()
Expand All @@ -2067,6 +2095,13 @@ async def operation_handle_fn() -> OutputT:
except asyncio.CancelledError:
cancel_command = self._add_command()
handle._apply_cancel_command(cancel_command)
# Clear the cancellation counter on Python 3.11+ so the
# next await does not immediately re-raise CancelledError
if (
sys.version_info >= (3, 11)
and (t := asyncio.current_task()) is not None
):
t.uncancel() # type: ignore[union-attr]
if self._cancel_requested:
raise

Expand Down Expand Up @@ -2599,6 +2634,13 @@ async def _signal_external_workflow(
except asyncio.CancelledError:
cancel_command = self._add_command()
cancel_command.cancel_signal_workflow.seq = seq
# Clear the cancellation counter on Python 3.11+ so the
# next await does not immediately re-raise CancelledError
if (
sys.version_info >= (3, 11)
and (t := asyncio.current_task()) is not None
):
t.uncancel() # type: ignore[union-attr]

def _stack_trace(self) -> str:
stacks = []
Expand Down
114 changes: 67 additions & 47 deletions tests/nexus/test_workflow_caller_cancellation_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
Expand All @@ -9,6 +10,7 @@
import pytest

import temporalio.nexus._operation_handlers
import temporalio.worker._workflow_instance
from temporalio import exceptions, nexus, workflow
from temporalio.api.enums.v1 import EventType
from temporalio.client import (
Expand All @@ -20,7 +22,7 @@
from temporalio.common import WorkflowIDConflictPolicy
from temporalio.testing import WorkflowEnvironment
from temporalio.worker import Worker
from tests.helpers import assert_eventually
from tests.helpers import LogCapturer, assert_eventually
from tests.helpers.nexus import make_nexus_endpoint_name


Expand Down Expand Up @@ -268,54 +270,72 @@ async def test_cancellation_type(

client = env.client

async with Worker(
client,
task_queue=str(uuid.uuid4()),
workflows=[CallerWorkflow, HandlerWorkflow],
nexus_service_handlers=[ServiceHandler()],
) as worker:
await env.create_nexus_endpoint(
make_nexus_endpoint_name(worker.task_queue), worker.task_queue
)
log_capturer = LogCapturer()
with log_capturer.logs_captured(
temporalio.worker._workflow_instance.logger, level=logging.WARNING
):
async with Worker(
client,
task_queue=str(uuid.uuid4()),
workflows=[CallerWorkflow, HandlerWorkflow],
nexus_service_handlers=[ServiceHandler()],
) as worker:
await env.create_nexus_endpoint(
make_nexus_endpoint_name(worker.task_queue), worker.task_queue
)

# Start the caller workflow, wait for the nexus op to have started and retrieve the nexus op
# token
with_start_workflow = WithStartWorkflowOperation(
CallerWorkflow.run,
Input(
endpoint=make_nexus_endpoint_name(worker.task_queue),
cancellation_type=cancellation_type,
),
id=test_context.caller_workflow_id,
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
)
# Start the caller workflow, wait for the nexus op to have started and retrieve the nexus op
# token
with_start_workflow = WithStartWorkflowOperation(
CallerWorkflow.run,
Input(
endpoint=make_nexus_endpoint_name(worker.task_queue),
cancellation_type=cancellation_type,
),
id=test_context.caller_workflow_id,
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
)

operation_token = await client.execute_update_with_start_workflow(
CallerWorkflow.get_operation_token,
start_workflow_operation=with_start_workflow,
)
handler_wf = (
nexus.WorkflowHandle[None]
.from_token(operation_token)
._to_client_workflow_handle(client)
)
caller_wf = await with_start_workflow.workflow_handle()

if cancellation_type == workflow.NexusOperationCancellationType.ABANDON:
await check_behavior_for_abandon(caller_wf, handler_wf)
elif cancellation_type == workflow.NexusOperationCancellationType.TRY_CANCEL:
await check_behavior_for_try_cancel(caller_wf, handler_wf)
elif (
cancellation_type == workflow.NexusOperationCancellationType.WAIT_REQUESTED
):
await check_behavior_for_wait_cancellation_requested(caller_wf, handler_wf)
elif (
cancellation_type == workflow.NexusOperationCancellationType.WAIT_COMPLETED
):
await check_behavior_for_wait_cancellation_completed(caller_wf, handler_wf)
else:
pytest.fail(f"Invalid cancellation type: {cancellation_type}")
operation_token = await client.execute_update_with_start_workflow(
CallerWorkflow.get_operation_token,
start_workflow_operation=with_start_workflow,
)
handler_wf = (
nexus.WorkflowHandle[None]
.from_token(operation_token)
._to_client_workflow_handle(client)
)
caller_wf = await with_start_workflow.workflow_handle()

if cancellation_type == workflow.NexusOperationCancellationType.ABANDON:
await check_behavior_for_abandon(caller_wf, handler_wf)
elif (
cancellation_type == workflow.NexusOperationCancellationType.TRY_CANCEL
):
await check_behavior_for_try_cancel(caller_wf, handler_wf)
elif (
cancellation_type
== workflow.NexusOperationCancellationType.WAIT_REQUESTED
):
await check_behavior_for_wait_cancellation_requested(
caller_wf, handler_wf
)
elif (
cancellation_type
== workflow.NexusOperationCancellationType.WAIT_COMPLETED
):
await check_behavior_for_wait_cancellation_completed(
caller_wf, handler_wf
)
else:
pytest.fail(f"Invalid cancellation type: {cancellation_type}")

# Verify no spurious "exception in shielded future" error logs
shielded_err = log_capturer.find_log("exception in shielded future")
assert shielded_err is None, (
f"Unexpected 'exception in shielded future' log: {shielded_err}"
)


async def check_behavior_for_abandon(
Expand Down
Loading