Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 57 additions & 6 deletions src/blueapi/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
from collections.abc import Iterable
from concurrent.futures import Future
from functools import cached_property
from itertools import chain
from pathlib import Path
from typing import Self
from typing import Any, Self

from bluesky_stomp.messaging import MessageContext, StompClient
from bluesky_stomp.models import Broker
Expand Down Expand Up @@ -50,6 +49,32 @@
log = logging.getLogger(__name__)


def _pretty_type(schema: dict[str, Any]) -> str:
if "$ref" in schema:
return schema["$ref"].split("/")[-1]

if schema.get("type") == "array":
item_schema = schema.get("items", {})
inner = _pretty_type(item_schema)
return f"list[{inner}]"

if "anyOf" in schema:
return " | ".join(_pretty_type(s) for s in schema["anyOf"])

json_type = schema.get("type")
type_map = {
"string": "str",
"integer": "int",
"boolean": "bool",
"number": "float",
"object": "dict",
}
if isinstance(json_type, str):
return type_map.get(json_type, json_type.split(".")[-1])

return "Any"


class MissingInstrumentSessionError(Exception):
pass

Expand Down Expand Up @@ -154,7 +179,7 @@ def help_text(self) -> str:
return self.model.description or f"Plan {self!r}"

@property
def properties(self) -> set[str]:
def properties(self) -> dict[str, Any]:
return self.model.parameter_schema.get("properties", {}).keys()

@property
Expand Down Expand Up @@ -192,9 +217,35 @@ def _build_args(self, *args, **kwargs):
return params

def __repr__(self):
opts = [p for p in self.properties if p not in self.required]
params = ", ".join(chain(self.required, (f"{opt}=None" for opt in opts)))
return f"{self.name}({params})"
props = self.model.parameter_schema.get("properties", {})
required = set(self.required)

tab = " "
args = []

for name, info in props.items():
typ = _pretty_type(info)
arg = f"{name}: {typ}"

if name not in required:
if "default" in info:
default = repr(info["default"])
arg = f"{arg} = {default}"
else:
arg = f"{arg} | None = None"

args.append(arg)

single_line = f"{self.name}({', '.join(args)})"
max_length = 100
max_args_inline = 3

if len(single_line) <= max_length and len(args) <= max_args_inline:
return single_line

# Fall back to multiline if too many arguments or too long.
multiline_args = ",\n".join(f"{tab}{arg}" for arg in args)
return f"{self.name}(\n{multiline_args}\n)"


class BlueapiClient:
Expand Down
32 changes: 9 additions & 23 deletions src/blueapi/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from importlib import import_module
from inspect import Parameter, isclass, signature
from types import ModuleType, NoneType, UnionType
from typing import Any, Generic, TypeVar, Union, get_args, get_origin, get_type_hints
from typing import Any, TypeVar, Union, get_args, get_origin, get_type_hints

from bluesky.protocols import HasName
from bluesky.run_engine import RunEngine
Expand Down Expand Up @@ -516,14 +516,16 @@ def _type_spec_for_function(
):
default_factory = self._composite_factory(arg_type)
_type = SkipJsonSchema[self._convert_type(arg_type, no_default)]
field_info = FieldInfo(default_factory=default_factory)
else:
default_factory = DefaultFactory(para.default)
_type = self._convert_type(arg_type, no_default)
factory = None if no_default else default_factory
new_args[name] = (
_type,
FieldInfo(default_factory=factory),
)
if no_default:
field_info = FieldInfo()
else:
field_info = FieldInfo(default=para.default)

new_args[name] = (_type, field_info)

return new_args

def _convert_type(self, typ: Any, no_default: bool = True) -> type:
Expand Down Expand Up @@ -574,19 +576,3 @@ def _inject_composite():
return composite_class(**devices)

return _inject_composite


D = TypeVar("D")


class DefaultFactory(Generic[D]):
_value: D

def __init__(self, value: D):
self._value = value

def __call__(self) -> D:
return self._value

def __eq__(self, other) -> bool:
return other.__class__ == self.__class__ and self._value == other._value
34 changes: 23 additions & 11 deletions tests/system_tests/plans.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
},
"num": {
"title": "Num",
"type": "integer"
"type": "integer",
"default": 1
},
"delay": {
"anyOf": [
Expand All @@ -34,12 +35,14 @@
"type": "array"
}
],
"default": 0.0,
"title": "Delay"
},
"metadata": {
"additionalProperties": true,
"title": "Metadata",
"type": "object"
"type": "object",
"default": null
}
},
"required": [
Expand Down Expand Up @@ -681,7 +684,8 @@
"metadata": {
"additionalProperties": true,
"title": "Metadata",
"type": "object"
"type": "object",
"default": null
}
},
"required": [
Expand Down Expand Up @@ -711,11 +715,13 @@
},
"group": {
"title": "Group",
"type": "string"
"type": "string",
"default": null
},
"wait": {
"title": "Wait",
"type": "boolean"
"type": "boolean",
"default": false
}
},
"required": [
Expand Down Expand Up @@ -745,11 +751,13 @@
},
"group": {
"title": "Group",
"type": "string"
"type": "string",
"default": null
},
"wait": {
"title": "Wait",
"type": "boolean"
"type": "boolean",
"default": false
}
},
"required": [
Expand All @@ -773,7 +781,8 @@
},
"group": {
"title": "Group",
"type": "string"
"type": "string",
"default": null
}
},
"required": [
Expand All @@ -796,7 +805,8 @@
},
"group": {
"title": "Group",
"type": "string"
"type": "string",
"default": null
}
},
"required": [
Expand Down Expand Up @@ -832,11 +842,13 @@
"properties": {
"group": {
"title": "Group",
"type": "string"
"type": "string",
"default": null
},
"timeout": {
"title": "Timeout",
"type": "number"
"type": "number",
"default": null
}
},
"title": "wait",
Expand Down
67 changes: 66 additions & 1 deletion tests/unit_tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,72 @@ def test_plan_fallback_help_text(client):
),
client,
)
assert plan.help_text == "Plan foo(one, two=None)"
assert plan.help_text == "Plan foo(one: Any, two: Any | None = None)"


def test_plan_multi_parameter_fallback_help_text(client):
plan = Plan(
"foo",
PlanModel(
name="foo",
schema={
"properties": {
"one": {},
"two": {
"anyOf": [{"items": {}, "type": "array"}, {"type": "boolean"}],
},
"three": {"default": 3},
"four": {"default": None},
},
"required": ["one", "two"],
},
),
client,
)
assert (
plan.help_text == "Plan foo(\n"
" one: Any,\n"
" two: list[Any] | bool,\n"
" three: Any = 3,\n"
" four: Any = None\n"
")"
)


def test_plan_help_text_with_ref(client):
schema = {
"$defs": {
"Spec": {
"properties": {
"foo": {"type": "integer"},
"bar": {"$ref": "#/$defs/InnerSpec"},
},
"required": ["foo", "bar"],
},
"InnerSpec": {
"properties": {
"x": {"type": "number"},
"y": {"default": 10, "type": "number"},
},
"required": ["x"],
},
},
"properties": {
"spec": {"$ref": "#/$defs/Spec"},
"meta": {"type": "string", "default": "abc"},
},
"required": ["spec"],
}

plan = Plan(
"ref_plan",
PlanModel(name="ref_plan", schema=schema),
client,
)

expected = "Plan ref_plan(spec: Spec, meta: str = 'abc')"

assert plan.help_text == expected


def test_plan_properties(client):
Expand Down
Loading
Loading