diff --git a/docs/scripts/_gen_widgets.py b/docs/scripts/_gen_widgets.py index f92cde13c..daae9a2d6 100644 --- a/docs/scripts/_gen_widgets.py +++ b/docs/scripts/_gen_widgets.py @@ -77,7 +77,7 @@ def _snap_image(_obj: type, _name: str) -> str: from qtpy.QtWidgets import QVBoxLayout, QWidget outer = QWidget() - if _obj is widgets.Container: + if _obj in (widgets.Container, widgets.ModelContainerWidget): return "" if issubclass(_obj, widgets.FunctionGui): return "" diff --git a/src/magicgui/schema/_guiclass.py b/src/magicgui/schema/_guiclass.py index bd2ca1570..f7e80bad4 100644 --- a/src/magicgui/schema/_guiclass.py +++ b/src/magicgui/schema/_guiclass.py @@ -19,7 +19,7 @@ from magicgui.schema._ui_field import build_widget from magicgui.widgets import PushButton -from magicgui.widgets.bases import BaseValueWidget, ContainerWidget +from magicgui.widgets.bases import BaseValueWidget if TYPE_CHECKING: from collections.abc import Mapping @@ -27,12 +27,15 @@ from typing_extensions import TypeGuard + from magicgui.widgets._concrete import ModelContainerWidget + from magicgui.widgets.bases._container_widget import BaseContainerWidget + # fmt: off class GuiClassProtocol(Protocol): """Protocol for a guiclass.""" @property - def gui(self) -> ContainerWidget: ... + def gui(self) -> ModelContainerWidget: ... @property def events(self) -> SignalGroup: ... # fmt: on @@ -233,7 +236,7 @@ def __set_name__(self, owner: type, name: str) -> None: evented(owner, events_namespace=self._events_namespace) setattr(owner, _GUICLASS_FLAG, True) - def widget(self) -> ContainerWidget: + def widget(self) -> ModelContainerWidget: """Return a widget for the dataclass or instance.""" if self._owner is None: raise TypeError( @@ -243,7 +246,7 @@ def widget(self) -> ContainerWidget: def __get__( self, instance: object | None, owner: type - ) -> ContainerWidget[BaseValueWidget] | GuiBuilder: + ) -> ModelContainerWidget[BaseValueWidget] | GuiBuilder: if instance is None: return self wdg = build_widget(instance) @@ -253,7 +256,8 @@ def __get__( for k, v in vars(owner).items(): if hasattr(v, _BUTTON_ATTR): kwargs = getattr(v, _BUTTON_ATTR) - button = PushButton(**kwargs) + # gui_only=True excludes button from model value construction + button = PushButton(gui_only=True, **kwargs) if instance is not None: # call the bound method if we're in an instance button.clicked.connect(getattr(instance, k)) @@ -277,7 +281,7 @@ def __get__( def bind_gui_to_instance( - gui: ContainerWidget, instance: Any, two_way: bool = True + gui: BaseContainerWidget, instance: Any, two_way: bool = True ) -> None: """Set change events in `gui` to update the corresponding attributes in `model`. @@ -340,7 +344,7 @@ def bind_gui_to_instance( signals[name].connect_setattr(widget, "value") -def unbind_gui_from_instance(gui: ContainerWidget, instance: Any) -> None: +def unbind_gui_from_instance(gui: BaseContainerWidget, instance: Any) -> None: """Unbind a gui from an instance. This will disconnect all events that were connected by `bind_gui_to_instance`. diff --git a/src/magicgui/schema/_ui_field.py b/src/magicgui/schema/_ui_field.py index 40039a18b..b0741d82a 100644 --- a/src/magicgui/schema/_ui_field.py +++ b/src/magicgui/schema/_ui_field.py @@ -15,7 +15,6 @@ Literal, TypeVar, Union, - cast, ) from typing_extensions import TypeGuard, get_args, get_origin @@ -23,7 +22,7 @@ from magicgui.types import JsonStringFormats, Undefined, _Undefined if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Mapping + from collections.abc import Iterator from typing import Protocol import attrs @@ -32,7 +31,8 @@ from attrs import Attribute from pydantic.fields import FieldInfo, ModelField - from magicgui.widgets.bases import BaseValueWidget, ContainerWidget + from magicgui.widgets import ModelContainerWidget + from magicgui.widgets.bases import BaseValueWidget class HasAttrs(Protocol): """Protocol for objects that have an ``attrs`` attribute.""" @@ -441,6 +441,16 @@ def create_widget(self, value: T | _Undefined = Undefined) -> BaseValueWidget[T] opts["min"] = d["exclusive_minimum"] + m value = value if value is not Undefined else self.get_default() # type: ignore + # build a container widget from a dataclass-like object + # TODO: should this eventually move to get_widget_class? + if _is_dataclass_like(self.type): + wdg = build_widget(self.type) + wdg.name = self.name or "" + wdg.label = self.name or "" + if value is not None: + wdg.value = value + return wdg + # create widget subclass for everything else cls, kwargs = get_widget_class(value=value, annotation=self.type, options=opts) return cls(**kwargs) # type: ignore @@ -718,6 +728,24 @@ def _ui_fields_from_annotation(cls: type) -> Iterator[UiField]: yield field.parse_annotated() +def _is_dataclass_like(object: Any) -> bool: + # check if it's a pydantic1 style dataclass + model = _get_pydantic_model(object) + if model is not None: + if hasattr(model, "model_fields"): + return True + # check if it's a pydantic2 style dataclass + if hasattr(object, "__pydantic_fields__"): + return True + # check if it's a (non-pydantic) dataclass + if dc.is_dataclass(object): + return True + # check if it's an attrs class + if _is_attrs_model(object): + return True + return False + + def _iter_ui_fields(object: Any) -> Iterator[UiField]: # check if it's a pydantic model model = _get_pydantic_model(object) @@ -781,76 +809,40 @@ def get_ui_fields(cls_or_instance: object) -> tuple[UiField, ...]: return tuple(_iter_ui_fields(cls_or_instance)) -def _uifields_to_container( - ui_fields: Iterable[UiField], - values: Mapping[str, Any] | None = None, - *, - container_kwargs: Mapping | None = None, -) -> ContainerWidget[BaseValueWidget]: - """Create a container widget from a sequence of UiFields. +# TODO: unify this with magicgui +# this todo could be the same thing as moving the logic in create_widget above +# to get_widget_cls... +def build_widget(cls_or_instance: Any) -> ModelContainerWidget: + """Build a magicgui widget from a dataclass, attrs, pydantic, or function. - This function is the heart of build_widget. + Returns a ModelContainerWidget whose `.value` property returns an instance + of the model type, constructed from the current widget values. Parameters ---------- - ui_fields : Iterable[UiField] - A sequence of UiFields to use to create the container. - values : Mapping[str, Any], optional - A mapping of field name to values to use to initialize each widget the - container, by default None. - container_kwargs : Mapping, optional - A mapping of keyword arguments to pass to the container constructor, - by default None. + cls_or_instance : Any + The class or instance to build the widget from. Returns ------- - ContainerWidget[ValueWidget] - A container widget with a widget for each UiField. - """ - from magicgui import widgets - - container = widgets.Container( - widgets=[field.create_widget() for field in ui_fields], - **(container_kwargs or {}), - ) - if values is not None: - container.update(values) - return container - - -def _get_values(obj: Any) -> dict | None: - """Return a dict of values from an object. - - The object can be a dataclass, attrs, pydantic object or named tuple. + ModelContainerWidget + The constructed widget. """ - if isinstance(obj, dict): - return obj - - # named tuple - if isinstance(obj, tuple) and hasattr(obj, "_asdict"): - return cast("dict", obj._asdict()) - - # dataclass - if dc.is_dataclass(type(obj)): - return dc.asdict(obj) - - # attrs - attr = sys.modules.get("attr") - if attr is not None and attr.has(obj): - return cast("dict", attr.asdict(obj)) - - # pydantic models - if hasattr(obj, "model_dump"): - return cast("dict", obj.model_dump()) - elif hasattr(obj, "dict"): - return cast("dict", obj.dict()) - - return None + from magicgui.widgets import ModelContainerWidget + # Get the class (type) for the model + if isinstance(cls_or_instance, type): + model_type = cls_or_instance + value = None + else: + model_type = type(cls_or_instance) + value = cls_or_instance -# TODO: unify this with magicgui -def build_widget(cls_or_instance: Any) -> ContainerWidget[BaseValueWidget]: - """Build a magicgui widget from a dataclass, attrs, pydantic, or function.""" - values = None if isinstance(cls_or_instance, type) else _get_values(cls_or_instance) fields = get_ui_fields(cls_or_instance) - return _uifields_to_container(fields, values=values) + inner_widgets = [f.create_widget() for f in fields] + + return ModelContainerWidget( + value_type=model_type, + widgets=inner_widgets, + value=value, + ) diff --git a/src/magicgui/widgets/__init__.py b/src/magicgui/widgets/__init__.py index 98b9ced76..4bdd2e682 100644 --- a/src/magicgui/widgets/__init__.py +++ b/src/magicgui/widgets/__init__.py @@ -29,6 +29,7 @@ LiteralEvalLineEdit, LogSlider, MainWindow, + ModelContainerWidget, Password, ProgressBar, PushButton, @@ -92,6 +93,7 @@ "LogSlider", "MainFunctionGui", "MainWindow", + "ModelContainerWidget", "Password", "ProgressBar", "PushButton", diff --git a/src/magicgui/widgets/_concrete.py b/src/magicgui/widgets/_concrete.py index 7eb60a092..982afe102 100644 --- a/src/magicgui/widgets/_concrete.py +++ b/src/magicgui/widgets/_concrete.py @@ -11,6 +11,7 @@ import inspect import math import os +import sys from pathlib import Path from typing import ( TYPE_CHECKING, @@ -68,6 +69,7 @@ WidgetVar = TypeVar("WidgetVar", bound=Widget) WidgetTypeVar = TypeVar("WidgetTypeVar", bound=type[Widget]) _V = TypeVar("_V") +_M = TypeVar("_M") # For model/dataclass types @overload @@ -994,6 +996,103 @@ def set_value(self, vals: Sequence[Any]) -> None: self.changed.emit(self.value) +class ModelContainerWidget(ValuedContainerWidget[_M], Generic[_M]): + """A container widget for dataclass-like models (dataclass, pydantic, attrs). + + This widget wraps a structured type (dataclass, pydantic model, attrs class, etc.) + and provides a `.value` property that returns an instance of that type, constructed + from the values of its child widgets. + + Parameters + ---------- + value_type : type[_M] + The model class to construct when getting the value. + widgets : Sequence[Widget], optional + Child widgets representing the model's fields. + **kwargs : Any + Additional arguments passed to ValuedContainerWidget. + """ + + def __init__( + self, + value_type: type[_M], + widgets: Sequence[Widget] = (), + value: _M | None | _Undefined = Undefined, + **kwargs: Any, + ) -> None: + self._value_type = value_type + super().__init__(widgets=widgets, **kwargs) + # Connect child widget changes to emit our changed signal + for w in self._list: + if isinstance(w, BaseValueWidget): + w.changed.connect(self._on_child_changed) + if not isinstance(value, _Undefined): + self.set_value(value) + + def _on_child_changed(self, _: Any = None) -> None: + """Emit changed signal when any child widget changes.""" + self.changed.emit(self.value) + + def get_value(self) -> _M: + """Construct a model instance from child widget values.""" + values: dict[str, Any] = {} + for w in self._list: + if not w.name or w.gui_only: + continue + if hasattr(w, "value"): + values[w.name] = w.value + return self._value_type(**values) + + def set_value(self, value: _M | None) -> None: + """Distribute model instance values to child widgets.""" + if value is None: + return + + vals = self._get_values(value) + if vals is None: + return + with self.changed.blocked(): + for w in self._list: + if w.name and hasattr(w, "value") and w.name in vals: + w.value = vals[w.name] + + def __repr__(self) -> str: + """Return string representation.""" + return f"<{self.__class__.__name__} value_type={self._value_type.__name__!r}>" + + @staticmethod + def _get_values(obj: Any) -> dict | None: + """Return a dict of values from an object. + + The object can be a dataclass, attrs, pydantic object or named tuple. + """ + if isinstance(obj, dict): + return obj + + # named tuple + if isinstance(obj, tuple) and hasattr(obj, "_asdict"): + return cast("dict", obj._asdict()) + + import dataclasses + + # dataclass + if dataclasses.is_dataclass(type(obj)): + return dataclasses.asdict(obj) + + # attrs + attr = sys.modules.get("attr") + if attr is not None and attr.has(obj): + return cast("dict", attr.asdict(obj)) + + # pydantic models + if hasattr(obj, "model_dump"): + return cast("dict", obj.model_dump()) + elif hasattr(obj, "dict"): + return cast("dict", obj.dict()) + + return None + + @backend_widget class ToolBar(ToolBarWidget): """Toolbar that contains a set of controls.""" diff --git a/src/magicgui/widgets/bases/_container_widget.py b/src/magicgui/widgets/bases/_container_widget.py index 9a2ebc31f..e80e10b80 100644 --- a/src/magicgui/widgets/bases/_container_widget.py +++ b/src/magicgui/widgets/bases/_container_widget.py @@ -230,6 +230,14 @@ def _pop_widget(self, index: int) -> WidgetVar: del self._list[index] return item + def asdict(self) -> dict[str, Any]: + """Return state of widget as dict.""" + return { + w.name: getattr(w, "value", None) + for w in self._list + if w.name and not w.gui_only + } + class ValuedContainerWidget( BaseContainerWidget[Widget], BaseValueWidget[T], Generic[T] @@ -269,6 +277,29 @@ def __init__( if self._bound_value is not Undefined and "visible" not in base_widget_kwargs: self.hide() + def insert(self, index: int, value: Widget) -> None: + """Insert `value` (a widget) at ``index``.""" + if isinstance(value, (BaseValueWidget, BaseContainerWidget)): + value.changed.connect(lambda: self.changed.emit(self.value)) + self._insert_widget(index, value) + + def append(self, widget: Widget) -> None: + """Append a widget to the container.""" + self.insert(len(self), widget) + + def update( + self, + mapping: Mapping | Iterable[tuple[str, Any]] = (), + **kwargs: Any, + ) -> None: + """Update the parameters in the widget from a mapping, iterable, or kwargs.""" + with self.changed.blocked(): + items = mapping.items() if isinstance(mapping, Mapping) else mapping + for key, value in chain(items, kwargs.items()): + if isinstance(wdg := self._list.get_by_name(key), BaseValueWidget): + wdg.value = value + self.changed.emit(self.value) + class ContainerWidget(BaseContainerWidget[WidgetVar], MutableSequence[WidgetVar]): """Container widget that can insert/remove child widgets. @@ -330,6 +361,10 @@ def __setattr__(self, name: str, value: Any) -> None: ) object.__setattr__(self, name, value) + def append(self, value: WidgetVar) -> None: + """Append a widget to the container.""" + self.insert(len(self), value) + def index(self, value: Any, start: int = 0, stop: int = 9223372036854775807) -> int: """Return index of a specific widget instance (or widget name).""" if isinstance(value, str): @@ -372,11 +407,11 @@ def __dir__(self) -> list[str]: d.extend([w.name for w in self._list if not w.gui_only]) return d - def insert(self, key: int, widget: WidgetVar) -> None: - """Insert widget at ``key``.""" - if isinstance(widget, (BaseValueWidget, BaseContainerWidget)): - widget.changed.connect(lambda: self.changed.emit(self)) - self._insert_widget(key, widget) + def insert(self, index: int, value: WidgetVar) -> None: + """Insert widget at ``index``.""" + if isinstance(value, (BaseValueWidget, BaseContainerWidget)): + value.changed.connect(lambda: self.changed.emit(self)) + self._insert_widget(index, value) @property def __signature__(self) -> MagicSignature: @@ -424,14 +459,6 @@ def __repr__(self) -> str: NO_VALUE = "NO_VALUE" - def asdict(self) -> dict[str, Any]: - """Return state of widget as dict.""" - return { - w.name: getattr(w, "value", None) - for w in self._list - if w.name and not w.gui_only - } - def update( self, mapping: Mapping | Iterable[tuple[str, Any]] = (), diff --git a/tests/test_gui_class.py b/tests/test_gui_class.py index 3f457dd49..efaa3df48 100644 --- a/tests/test_gui_class.py +++ b/tests/test_gui_class.py @@ -16,7 +16,7 @@ is_guiclass, unbind_gui_from_instance, ) -from magicgui.widgets import Container, PushButton +from magicgui.widgets import Container, ModelContainerWidget, PushButton def test_guiclass() -> None: @@ -44,7 +44,7 @@ def func(self) -> dict: assert foo.a == 1 assert foo.b == "bar" - assert isinstance(foo.gui, Container) + assert isinstance(foo.gui, ModelContainerWidget) assert isinstance(foo.gui.func, PushButton) assert foo.gui.a.value == 1 assert foo.gui.b.value == "bar" @@ -88,7 +88,7 @@ def func(self) -> dict: assert foo.a == 1 assert foo.b == "bar" - assert isinstance(foo.gui, Container) + assert isinstance(foo.gui, ModelContainerWidget) assert isinstance(foo.gui.get_widget("func"), PushButton) assert foo.gui.a.value == 1 assert foo.gui.b.value == "bar" @@ -128,7 +128,7 @@ class Foo: foo = Foo() assert foo.a == 1 assert foo.b == "bar" - assert isinstance(foo.gui, Container) + assert isinstance(foo.gui, ModelContainerWidget) @pytest.mark.skipif( @@ -158,7 +158,7 @@ class Foo: # note that with slots=True, the gui is recreated on every access assert foo.gui is not gui - assert isinstance(gui, Container) + assert isinstance(gui, ModelContainerWidget) assert gui.a.value == 1 foo.b = "baz" assert gui.b.value == "baz" @@ -230,6 +230,6 @@ class Foo: annotation: str = "bar" foo = Foo() - assert isinstance(foo.gui, Container) + assert isinstance(foo.gui, ModelContainerWidget) foo.gui.update({"name": "baz", "annotation": "qux"}) assert asdict(foo) == {"name": "baz", "annotation": "qux"} diff --git a/tests/test_ui_field.py b/tests/test_ui_field.py index f5d7f8355..4bb77263c 100644 --- a/tests/test_ui_field.py +++ b/tests/test_ui_field.py @@ -5,7 +5,7 @@ from typing_extensions import TypedDict from magicgui.schema._ui_field import UiField, build_widget, get_ui_fields -from magicgui.widgets import Container +from magicgui.widgets import ModelContainerWidget EXPECTED = ( UiField(name="a", type=int, nullable=True), @@ -18,7 +18,7 @@ def _assert_uifields(cls, instantiate=True): result = tuple(get_ui_fields(cls)) assert result == EXPECTED wdg = build_widget(cls) - assert isinstance(wdg, Container) + assert isinstance(wdg, ModelContainerWidget) assert wdg.asdict() == { "a": 0, "b": "", @@ -28,7 +28,7 @@ def _assert_uifields(cls, instantiate=True): instance = cls(a=1, b="hi") assert tuple(get_ui_fields(instance)) == EXPECTED wdg2 = build_widget(instance) - assert isinstance(wdg2, Container) + assert isinstance(wdg2, ModelContainerWidget) assert wdg2.asdict() == { "a": 1, "b": "hi", @@ -204,9 +204,67 @@ class Foo: # assert wdg.g.max_items == 5 # TODO -def test_resolved_type(): +def test_resolved_type() -> None: f: UiField[int] = UiField(type=Annotated["int", UiField(minimum=0)]) assert f.resolved_type is int f = UiField(type="int") assert f.resolved_type is int + + +def test_nested_dataclass() -> None: + """Test nested dataclass builds ModelContainerWidget with .value support.""" + + @dataclass + class Inner: + x: int = 1 + y: str = "hello" + + @dataclass + class Outer: + inner: Inner + a: int = 5 + + wdg = build_widget(Outer) + assert isinstance(wdg, ModelContainerWidget) + + # Check nested widget is a ModelContainerWidget with correct name + assert wdg.inner.name == "inner" + assert isinstance(wdg.inner, ModelContainerWidget) + + # Check child widget values + assert wdg.inner.x.value == 0 + assert wdg.inner.y.value == "" + + # KEY FEATURE: nested container has .value that returns model instance + inner_value = wdg.inner.value + assert isinstance(inner_value, Inner) + assert inner_value.x == 0 + assert inner_value.y == "" + + # Modify values via child widgets + wdg.a.value = 10 + wdg.inner.x.value = 42 + wdg.inner.y.value = "world" + + # Check .value reflects changes + inner_value = wdg.inner.value + assert inner_value.x == 42 + assert inner_value.y == "world" + + # asdict returns model instances for nested containers + result = wdg.asdict() + assert result["a"] == 10 + assert isinstance(result["inner"], Inner) + assert result["inner"].x == 42 + assert result["inner"].y == "world" + + # Setting .value on nested container updates child widgets + wdg.inner.value = Inner(x=99, y="updated") + assert wdg.inner.x.value == 99 + assert wdg.inner.y.value == "updated" + + # update() on outer container works with model instances + wdg.update({"a": 100, "inner": Inner(x=1, y="reset")}) + assert wdg.a.value == 100 + assert wdg.inner.value == Inner(x=1, y="reset")