diff --git a/pkgs/clan-app/clan_app/views/webview.py b/pkgs/clan-app/clan_app/views/webview.py index 054c6303..ade90b0b 100644 --- a/pkgs/clan-app/clan_app/views/webview.py +++ b/pkgs/clan-app/clan_app/views/webview.py @@ -4,9 +4,11 @@ import logging import sys import threading from collections.abc import Callable +from dataclasses import fields, is_dataclass from pathlib import Path from threading import Lock -from typing import Any +from types import UnionType +from typing import Any, get_args import gi from clan_cli.api import API @@ -37,7 +39,7 @@ def dataclass_to_dict(obj: Any) -> Any: elif isinstance(obj, dict): return {k: dataclass_to_dict(v) for k, v in obj.items()} else: - return obj + return str(obj) # Implement the abstract open_file function @@ -135,6 +137,58 @@ def open_file(file_request: FileRequest) -> str | None: return selected_path +def is_union_type(type_hint: type) -> bool: + return type(type_hint) is UnionType + + +def get_inner_type(type_hint: type) -> type: + if is_union_type(type_hint): + # Return the first non-None type + return next(t for t in get_args(type_hint) if t is not type(None)) + return type_hint + + +def from_dict(t: type, data: dict[str, Any] | None) -> Any: + """ + Dynamically instantiate a data class from a dictionary, handling nested data classes. + """ + if not data: + return None + + try: + # Attempt to create an instance of the data_class + field_values = {} + for field in fields(t): + field_value = data.get(field.name) + field_type = get_inner_type(field.type) + if field_value is not None: + # If the field is another dataclass, recursively instantiate it + if is_dataclass(field_type): + field_value = from_dict(field_type, field_value) + elif isinstance(field_type, Path | str) and isinstance( + field_value, str + ): + field_value = ( + Path(field_value) if field_type == Path else field_value + ) + + if ( + field.default is not dataclasses.MISSING + or field.default_factory is not dataclasses.MISSING + ): + # Field has a default value. We cannot set the value to None + if field_value is not None: + field_values[field.name] = field_value + else: + field_values[field.name] = field_value + + return t(**field_values) + + except (TypeError, ValueError) as e: + print(f"Failed to instantiate {t.__name__}: {e}") + return None + + class WebView: def __init__(self, methods: dict[str, Callable]) -> None: self.method_registry: dict[str, Callable] = methods @@ -216,14 +270,13 @@ class WebView: # But the js api returns dictionaries. # Introspect the function and create the expected dataclass from dict dynamically # Depending on the introspected argument_type - arg_type = API.get_method_argtype(method_name, k) - if dataclasses.is_dataclass(arg_type): - reconciled_arguments[k] = arg_type(**v) + arg_class = API.get_method_argtype(method_name, k) + if dataclasses.is_dataclass(arg_class): + reconciled_arguments[k] = from_dict(arg_class, v) else: reconciled_arguments[k] = v result = handler_fn(**reconciled_arguments) - serialized = json.dumps(dataclass_to_dict(result)) # Use idle_add to queue the response call to js on the main GTK thread diff --git a/pkgs/clan-cli/clan_cli/api/util.py b/pkgs/clan-cli/clan_cli/api/util.py index d356b28a..84f96691 100644 --- a/pkgs/clan-cli/clan_cli/api/util.py +++ b/pkgs/clan-cli/clan_cli/api/util.py @@ -1,6 +1,7 @@ import copy import dataclasses import pathlib +from dataclasses import MISSING from types import NoneType, UnionType from typing import ( Annotated, @@ -77,20 +78,29 @@ def type_to_dict(t: Any, scope: str = "", type_map: dict[TypeVar, type] = {}) -> for f in fields } - required = [] + required = set() for pn, pv in properties.items(): if pv.get("type") is not None: if "null" not in pv["type"]: - required.append(pn) + required.add(pn) elif pv.get("oneOf") is not None: if "null" not in [i["type"] for i in pv.get("oneOf", [])]: - required.append(pn) + required.add(pn) + + required_fields = { + f.name + for f in fields + if f.default is MISSING and f.default_factory is MISSING + } + + # Find intersection + intersection = required & required_fields return { "type": "object", "properties": properties, - "required": required, + "required": list(intersection), # Dataclasses can only have the specified properties "additionalProperties": False, }