diff --git a/cuda_core/tests/test_object_protocols.py b/cuda_core/tests/test_object_protocols.py index fd0859855a..bd92ad0696 100644 --- a/cuda_core/tests/test_object_protocols.py +++ b/cuda_core/tests/test_object_protocols.py @@ -1,10 +1,11 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 """ -Tests for Python object protocols (__eq__, __hash__, __weakref__, __repr__). +Tests for Python object protocols (__eq__, __hash__, __weakref__, __repr__, pickle). This module tests that core cuda.core classes properly implement standard Python -object protocols for identity, hashing, weak references, and string representation. +object protocols for identity, hashing, weak references, string representation, +and serialization. """ import itertools @@ -15,7 +16,17 @@ from helpers.graph_kernels import compile_common_kernels from helpers.misc import try_create_condition -from cuda.core import Buffer, Device, Kernel, LaunchConfig, Program, Stream, system +from cuda.core import ( + Buffer, + Device, + DeviceMemoryResource, + DeviceMemoryResourceOptions, + Kernel, + LaunchConfig, + Program, + Stream, + system, +) from cuda.core._graph._graphdef import GraphDef from cuda.core._program import _can_load_generated_ptx @@ -208,6 +219,30 @@ def sample_kernel_alt(sample_object_code_alt): return sample_object_code_alt.get_kernel("test_kernel_alt") +# ============================================================================= +# Fixtures - IPC samples (for pickle tests) +# ============================================================================= + +POOL_SIZE = 2097152 + + +@pytest.fixture +def sample_ipc_buffer_descriptor(ipc_device): + """An IPCBufferDescriptor.""" + options = DeviceMemoryResourceOptions(max_size=POOL_SIZE, ipc_enabled=True) + mr = DeviceMemoryResource(ipc_device, options=options) + buf = mr.allocate(64) + return buf.get_ipc_descriptor() + + +@pytest.fixture +def sample_ipc_event_descriptor(ipc_device): + """An IPCEventDescriptor.""" + stream = ipc_device.create_stream() + e = stream.record(options={"ipc_enabled": True}) + return e.get_ipc_descriptor() + + # ============================================================================= # Fixtures - Graph types (GraphDef and GraphNode) # ============================================================================= @@ -606,6 +641,20 @@ def sample_switch_node_alt(sample_graphdef): ("sample_kernel", lambda k: Kernel.from_handle(int(k.handle))), ] +# Types with __reduce__ support (pickle/cloudpickle). +# Event, Buffer, and memory resources are excluded: Event only supports +# IPC-based serialization via multiprocessing reduction; Buffer and memory +# resource __reduce__ use a cross-process registry that doesn't support +# same-process roundtrips. +PICKLE_TYPES = [ + "sample_device", + "sample_object_code_cubin", + "sample_ipc_buffer_descriptor", + "sample_ipc_event_descriptor", +] + +PICKLE_MODULES = ["pickle", "cloudpickle"] + # Derived type groupings for collection tests DICT_KEY_TYPES = sorted(set(HASH_TYPES) & set(EQ_TYPES)) WEAK_KEY_TYPES = sorted(set(HASH_TYPES) & set(EQ_TYPES) & set(WEAKREF_TYPES)) @@ -796,3 +845,18 @@ def test_repr_format(fixture_name, pattern, request): obj = request.getfixturevalue(fixture_name) result = repr(obj) assert re.fullmatch(pattern, result) + + +# ============================================================================= +# Pickle tests +# ============================================================================= + + +@pytest.mark.parametrize("pickle_module", PICKLE_MODULES) +@pytest.mark.parametrize("fixture_name", PICKLE_TYPES) +def test_pickle_roundtrip(fixture_name, pickle_module, request): + """Object survives a pickle/cloudpickle roundtrip.""" + mod = pytest.importorskip(pickle_module) + obj = request.getfixturevalue(fixture_name) + result = mod.loads(mod.dumps(obj)) + assert type(result) is type(obj)