diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index c5b41ff45e..6ce31d374d 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -4,9 +4,8 @@ from collections.abc import Iterable from concurrent.futures import Future from functools import cached_property -from itertools import chain from pathlib import Path -from typing import Self +from typing import Any, Self from bluesky_stomp.messaging import MessageContext, StompClient from bluesky_stomp.models import Broker @@ -50,6 +49,32 @@ log = logging.getLogger(__name__) +def _pretty_type(schema: dict[str, Any]) -> str: + if "$ref" in schema: + return schema["$ref"].split("/")[-1] + + if schema.get("type") == "array": + item_schema = schema.get("items", {}) + inner = _pretty_type(item_schema) + return f"list[{inner}]" + + if "anyOf" in schema: + return " | ".join(_pretty_type(s) for s in schema["anyOf"]) + + json_type = schema.get("type") + type_map = { + "string": "str", + "integer": "int", + "boolean": "bool", + "number": "float", + "object": "dict", + } + if isinstance(json_type, str): + return type_map.get(json_type, json_type.split(".")[-1]) + + return "Any" + + class MissingInstrumentSessionError(Exception): pass @@ -154,7 +179,7 @@ def help_text(self) -> str: return self.model.description or f"Plan {self!r}" @property - def properties(self) -> set[str]: + def properties(self) -> dict[str, Any]: return self.model.parameter_schema.get("properties", {}).keys() @property @@ -192,9 +217,35 @@ def _build_args(self, *args, **kwargs): return params def __repr__(self): - opts = [p for p in self.properties if p not in self.required] - params = ", ".join(chain(self.required, (f"{opt}=None" for opt in opts))) - return f"{self.name}({params})" + props = self.model.parameter_schema.get("properties", {}) + required = set(self.required) + + tab = " " + args = [] + + for name, info in props.items(): + typ = _pretty_type(info) + arg = f"{name}: {typ}" + + if name not in required: + if "default" in info: + default = repr(info["default"]) + arg = f"{arg} = {default}" + else: + arg = f"{arg} | None = None" + + args.append(arg) + + single_line = f"{self.name}({', '.join(args)})" + max_length = 100 + max_args_inline = 3 + + if len(single_line) <= max_length and len(args) <= max_args_inline: + return single_line + + # Fall back to multiline if too many arguments or too long. + multiline_args = ",\n".join(f"{tab}{arg}" for arg in args) + return f"{self.name}(\n{multiline_args}\n)" class BlueapiClient: diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index 78682ccf6e..2fbc7386c4 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -5,7 +5,7 @@ from importlib import import_module from inspect import Parameter, isclass, signature from types import ModuleType, NoneType, UnionType -from typing import Any, Generic, TypeVar, Union, get_args, get_origin, get_type_hints +from typing import Any, TypeVar, Union, get_args, get_origin, get_type_hints from bluesky.protocols import HasName from bluesky.run_engine import RunEngine @@ -516,14 +516,16 @@ def _type_spec_for_function( ): default_factory = self._composite_factory(arg_type) _type = SkipJsonSchema[self._convert_type(arg_type, no_default)] + field_info = FieldInfo(default_factory=default_factory) else: - default_factory = DefaultFactory(para.default) _type = self._convert_type(arg_type, no_default) - factory = None if no_default else default_factory - new_args[name] = ( - _type, - FieldInfo(default_factory=factory), - ) + if no_default: + field_info = FieldInfo() + else: + field_info = FieldInfo(default=para.default) + + new_args[name] = (_type, field_info) + return new_args def _convert_type(self, typ: Any, no_default: bool = True) -> type: @@ -574,19 +576,3 @@ def _inject_composite(): return composite_class(**devices) return _inject_composite - - -D = TypeVar("D") - - -class DefaultFactory(Generic[D]): - _value: D - - def __init__(self, value: D): - self._value = value - - def __call__(self) -> D: - return self._value - - def __eq__(self, other) -> bool: - return other.__class__ == self.__class__ and self._value == other._value diff --git a/tests/system_tests/plans.json b/tests/system_tests/plans.json index 0124b86577..e55e3b474b 100644 --- a/tests/system_tests/plans.json +++ b/tests/system_tests/plans.json @@ -20,7 +20,8 @@ }, "num": { "title": "Num", - "type": "integer" + "type": "integer", + "default": 1 }, "delay": { "anyOf": [ @@ -34,12 +35,14 @@ "type": "array" } ], + "default": 0.0, "title": "Delay" }, "metadata": { "additionalProperties": true, "title": "Metadata", - "type": "object" + "type": "object", + "default": null } }, "required": [ @@ -681,7 +684,8 @@ "metadata": { "additionalProperties": true, "title": "Metadata", - "type": "object" + "type": "object", + "default": null } }, "required": [ @@ -711,11 +715,13 @@ }, "group": { "title": "Group", - "type": "string" + "type": "string", + "default": null }, "wait": { "title": "Wait", - "type": "boolean" + "type": "boolean", + "default": false } }, "required": [ @@ -745,11 +751,13 @@ }, "group": { "title": "Group", - "type": "string" + "type": "string", + "default": null }, "wait": { "title": "Wait", - "type": "boolean" + "type": "boolean", + "default": false } }, "required": [ @@ -773,7 +781,8 @@ }, "group": { "title": "Group", - "type": "string" + "type": "string", + "default": null } }, "required": [ @@ -796,7 +805,8 @@ }, "group": { "title": "Group", - "type": "string" + "type": "string", + "default": null } }, "required": [ @@ -832,11 +842,13 @@ "properties": { "group": { "title": "Group", - "type": "string" + "type": "string", + "default": null }, "timeout": { "title": "Timeout", - "type": "number" + "type": "number", + "default": null } }, "title": "wait", diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index a96f428e8c..111aa0c8c6 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -716,7 +716,72 @@ def test_plan_fallback_help_text(client): ), client, ) - assert plan.help_text == "Plan foo(one, two=None)" + assert plan.help_text == "Plan foo(one: Any, two: Any | None = None)" + + +def test_plan_multi_parameter_fallback_help_text(client): + plan = Plan( + "foo", + PlanModel( + name="foo", + schema={ + "properties": { + "one": {}, + "two": { + "anyOf": [{"items": {}, "type": "array"}, {"type": "boolean"}], + }, + "three": {"default": 3}, + "four": {"default": None}, + }, + "required": ["one", "two"], + }, + ), + client, + ) + assert ( + plan.help_text == "Plan foo(\n" + " one: Any,\n" + " two: list[Any] | bool,\n" + " three: Any = 3,\n" + " four: Any = None\n" + ")" + ) + + +def test_plan_help_text_with_ref(client): + schema = { + "$defs": { + "Spec": { + "properties": { + "foo": {"type": "integer"}, + "bar": {"$ref": "#/$defs/InnerSpec"}, + }, + "required": ["foo", "bar"], + }, + "InnerSpec": { + "properties": { + "x": {"type": "number"}, + "y": {"default": 10, "type": "number"}, + }, + "required": ["x"], + }, + }, + "properties": { + "spec": {"$ref": "#/$defs/Spec"}, + "meta": {"type": "string", "default": "abc"}, + }, + "required": ["spec"], + } + + plan = Plan( + "ref_plan", + PlanModel(name="ref_plan", schema=schema), + client, + ) + + expected = "Plan ref_plan(spec: Spec, meta: str = 'abc')" + + assert plan.help_text == expected def test_plan_properties(client): diff --git a/tests/unit_tests/core/test_context.py b/tests/unit_tests/core/test_context.py index 738a4c564c..f659d72fd4 100644 --- a/tests/unit_tests/core/test_context.py +++ b/tests/unit_tests/core/test_context.py @@ -1,9 +1,10 @@ from __future__ import annotations from dataclasses import dataclass, field +from inspect import Parameter from pathlib import Path from types import ModuleType, NoneType -from typing import Any, Generic, TypeVar, Union +from typing import Any, Generic, TypeVar, Union, get_args, get_type_hints from unittest.mock import ANY, MagicMock, Mock, patch import pytest @@ -46,7 +47,7 @@ TiledConfig, ) from blueapi.core import BlueskyContext, is_bluesky_compatible_device -from blueapi.core.context import DefaultFactory, generic_bounds, qualified_name +from blueapi.core.context import generic_bounds, qualified_name from blueapi.core.protocols import DeviceConnectResult, DeviceManager from blueapi.utils.connect_devices import _establish_device_connections from blueapi.utils.invalid_config_error import InvalidConfigError @@ -88,6 +89,10 @@ def has_some_params(foo: int = 42, bar: str = "bar") -> MsgGenerator: yield from () +def has_optional_parameter(foo: dict[str, Any] | None = None) -> MsgGenerator: + yield from () + + def has_typeless_param(foo) -> MsgGenerator: yield from () @@ -169,7 +174,9 @@ def some_configurable() -> SomeConfigurable: return SomeConfigurable() -@pytest.mark.parametrize("plan", [has_no_params, has_one_param, has_some_params]) +@pytest.mark.parametrize( + "plan", [has_no_params, has_one_param, has_some_params, has_optional_parameter] +) def test_add_plan(empty_context: BlueskyContext, plan: PlanGenerator): empty_context.register_plan(plan) assert plan.__name__ in empty_context.plans @@ -428,12 +435,23 @@ def test_with_config_passes_mock_to_with_dodal_module( mock_with_dodal_module.assert_called_once_with(ANY, mock=mock) -def test_function_spec(empty_context: BlueskyContext): +def test_function_spec_with_some_params(empty_context: BlueskyContext): spec = empty_context._type_spec_for_function(has_some_params) assert spec["foo"][0] is int - assert spec["foo"][1].default_factory == DefaultFactory(42) + assert spec["foo"][1].default == 42 assert spec["bar"][0] is str - assert spec["bar"][1].default_factory == DefaultFactory("bar") + assert spec["bar"][1].default == "bar" + + +def test_function_spec_with_optional_params(empty_context: BlueskyContext): + spec = empty_context._type_spec_for_function(has_optional_parameter) + types = get_type_hints(has_optional_parameter) + arg_type = types.get("foo", Parameter.empty) + + _type = SkipJsonSchema[empty_context._convert_type(arg_type, False)] + inner_type, *annotations = get_args(_type) + assert spec["foo"][0] == inner_type + assert spec["foo"][1].default is None def test_basic_type_conversion(empty_context: BlueskyContext): @@ -514,7 +532,7 @@ def default_movable(mov: Movable = inject("demo")) -> MsgGenerator: spec = empty_context._type_spec_for_function(default_movable) movable_ref = empty_context._reference(Movable) assert spec["mov"][0] == movable_ref - assert spec["mov"][1].default_factory == DefaultFactory("demo") + assert spec["mov"][1].default == "demo" def test_generic_default_device_reference(empty_context: BlueskyContext): @@ -524,7 +542,7 @@ def default_movable(mov: Movable[float] = inject("demo")) -> MsgGenerator: spec = empty_context._type_spec_for_function(default_movable) motor_ref = empty_context._reference(Movable[float]) assert spec["mov"][0] == motor_ref - assert spec["mov"][1].default_factory == DefaultFactory("demo") + assert spec["mov"][1].default == "demo" class ConcreteStoppable(Stoppable): @@ -574,7 +592,7 @@ def test_str_default(empty_context: BlueskyContext, sim_motor: Motor, alt_motor: spec = empty_context._type_spec_for_function(has_default_reference) assert spec["m"][0] is movable_ref - assert (df := spec["m"][1].default_factory) and df() == SIM_MOTOR_NAME # type: ignore + assert spec["m"][1].default == SIM_MOTOR_NAME assert has_default_reference.__name__ in empty_context.plans model = empty_context.plans[has_default_reference.__name__].model @@ -593,7 +611,7 @@ def test_nested_str_default( spec = empty_context._type_spec_for_function(has_default_nested_reference) assert spec["m"][0] == list[movable_ref] - assert (df := spec["m"][1].default_factory) and df() == [SIM_MOTOR_NAME] # type: ignore + assert spec["m"][1].default == [SIM_MOTOR_NAME] assert has_default_nested_reference.__name__ in empty_context.plans model = empty_context.plans[has_default_nested_reference.__name__].model @@ -697,7 +715,7 @@ def demo_plan(foo: int | None = None) -> MsgGenerator: empty_context.register_plan(demo_plan) schema = empty_context.plans["demo_plan"].model.model_json_schema() assert schema["properties"] == { - "foo": {"title": "Foo", "type": "integer"}, + "foo": {"title": "Foo", "type": "integer", "default": None}, } assert "foo" not in schema.get("required", []) @@ -725,7 +743,11 @@ def demo_plan(foo: int | str | None = None) -> MsgGenerator: empty_context.register_plan(demo_plan) schema = empty_context.plans["demo_plan"].model.model_json_schema() assert schema["properties"] == { - "foo": {"title": "Foo", "anyOf": [{"type": "integer"}, {"type": "string"}]} + "foo": { + "title": "Foo", + "anyOf": [{"type": "integer"}, {"type": "string"}], + "default": None, + } } assert "foo" not in schema.get("required", []) @@ -739,7 +761,10 @@ def demo_plan(foo: int | None) -> MsgGenerator: empty_context.register_plan(demo_plan) schema = empty_context.plans["demo_plan"].model.model_json_schema() assert schema["properties"] == { - "foo": {"title": "Foo", "anyOf": [{"type": "integer"}, {"type": "null"}]} + "foo": { + "title": "Foo", + "anyOf": [{"type": "integer"}, {"type": "null"}], + } } assert "foo" in schema.get("required", [])