From b9bf45373154fb3eee6f2b403aee720d94724fbc Mon Sep 17 00:00:00 2001 From: Johannes Kirschbauer Date: Wed, 5 Jun 2024 09:43:45 +0200 Subject: [PATCH] clan-api: add support for generic types & annotated types --- pkgs/clan-cli/clan_cli/api/util.py | 141 ++++++++++++++++++++++++++--- 1 file changed, 126 insertions(+), 15 deletions(-) diff --git a/pkgs/clan-cli/clan_cli/api/util.py b/pkgs/clan-cli/clan_cli/api/util.py index 277a273e..57e02a24 100644 --- a/pkgs/clan-cli/clan_cli/api/util.py +++ b/pkgs/clan-cli/clan_cli/api/util.py @@ -1,20 +1,92 @@ +import copy import dataclasses import pathlib from types import NoneType, UnionType -from typing import Any, Union +from typing import ( + Annotated, + Any, + Literal, + TypeVar, + Union, + get_args, + get_origin, +) -def type_to_dict(t: Any, scope: str = "") -> dict: +class JSchemaTypeError(Exception): + pass + + +# Inspect the fields of the parameterized type +def inspect_dataclass_fields(t: type) -> dict[TypeVar, type]: + """ + Returns a map of type variables to actual types for a parameterized type. + """ + origin = get_origin(t) + type_args = get_args(t) + if origin is None: + return {} + + type_params = origin.__parameters__ + # Create a map from type parameters to actual type arguments + type_map = dict(zip(type_params, type_args)) + + return type_map + + +def apply_annotations(schema: dict[str, Any], annotations: list[Any]) -> dict[str, Any]: + """ + Add metadata from typing.annotations to the json Schema. + The annotations can be a dict, a tuple, or a string and is directly applied to the schema as shown below. + No further validation is done, the caller is responsible for following json-schema. + + Examples + + ```python + # String annotation + Annotated[int, "This is an int"] -> {"type": "integer", "description": "This is an int"} + + # Dict annotation + Annotated[int, {"minimum": 0, "maximum": 10}] -> {"type": "integer", "minimum": 0, "maximum": 10} + + # Tuple annotation + Annotated[int, ("minimum", 0)] -> {"type": "integer", "minimum": 0} + ``` + """ + for annotation in annotations: + if isinstance(annotation, dict): + # Assuming annotation is a dict that can directly apply to the schema + schema.update(annotation) + elif isinstance(annotation, tuple) and len(annotation) == 2: + # Assuming a tuple where first element is a keyword (like 'minLength') and the second is the value + schema[annotation[0]] = annotation[1] + elif isinstance(annotation, str): + # String annotations can be used for description + schema.update({"description": f"{annotation}"}) + return schema + + +def type_to_dict(t: Any, scope: str = "", type_map: dict[TypeVar, type] = {}) -> dict: if t is None: return {"type": "null"} if dataclasses.is_dataclass(t): fields = dataclasses.fields(t) properties = { - f.name: type_to_dict(f.type, f"{scope} {t.__name__}.{f.name}") + f.name: type_to_dict(f.type, f"{scope} {t.__name__}.{f.name}", type_map) for f in fields } - required = [pn for pn, pv in properties.items() if "null" not in pv["type"]] + + required = [] + for pn, pv in properties.items(): + if pv.get("type") is not None: + if "null" not in pv["type"]: + required.append(pn) + + elif pv.get("oneOf") is not None: + if "null" not in [i["type"] for i in pv.get("oneOf", [])]: + required.append(pn) + return { "type": "object", "properties": properties, @@ -22,24 +94,54 @@ def type_to_dict(t: Any, scope: str = "") -> dict: # Dataclasses can only have the specified properties "additionalProperties": False, } + elif type(t) is UnionType: return { - "type": [type_to_dict(arg, scope)["type"] for arg in t.__args__], + "oneOf": [type_to_dict(arg, scope, type_map) for arg in t.__args__], } + if isinstance(t, TypeVar): + # if t is a TypeVar, look up the type in the type_map + # And return the resolved type instead of the TypeVar + resolved = type_map.get(t) + if not resolved: + raise JSchemaTypeError( + f"{scope} - TypeVar {t} not found in type_map, map: {type_map}" + ) + return type_to_dict(type_map.get(t), scope, type_map) + elif hasattr(t, "__origin__"): # Check if it's a generic type - origin = getattr(t, "__origin__", None) + origin = get_origin(t) + args = get_args(t) if origin is None: # Non-generic user-defined or built-in type # TODO: handle custom types - raise BaseException("Unhandled Type: ", origin) + raise JSchemaTypeError("Unhandled Type: ", origin) + + elif origin is Literal: + # Handle Literal values for enums in JSON Schema + return { + "type": "string", + "enum": list(args), # assumes all args are strings + } + + elif origin is Annotated: + base_type, *metadata = get_args(t) + schema = type_to_dict(base_type, scope) # Generate schema for the base type + return apply_annotations(schema, metadata) elif origin is Union: - return {"type": [type_to_dict(arg, scope)["type"] for arg in t.__args__]} + union_types = [type_to_dict(arg, scope, type_map) for arg in t.__args__] + return { + "oneOf": union_types, + } - elif issubclass(origin, list): - return {"type": "array", "items": type_to_dict(t.__args__[0], scope)} + elif origin in {list, set, frozenset}: + return { + "type": "array", + "items": type_to_dict(t.__args__[0], scope, type_map), + } elif issubclass(origin, dict): value_type = t.__args__[1] @@ -48,10 +150,19 @@ def type_to_dict(t: Any, scope: str = "") -> dict: else: return { "type": "object", - "additionalProperties": type_to_dict(value_type, scope), + "additionalProperties": type_to_dict(value_type, scope, type_map), } + # Generic dataclass with type parameters + elif dataclasses.is_dataclass(origin): + # This behavior should mimic the scoping of typeVars in dataclasses + # Once type_to_dict() encounters a TypeVar, it will look up the type in the type_map + # When type_to_dict() returns the map goes out of scope. + # This behaves like a stack, where the type_map is pushed and popped as we traverse the dataclass fields + new_map = copy.deepcopy(type_map) + new_map.update(inspect_dataclass_fields(t)) + return type_to_dict(origin, scope, new_map) - raise BaseException(f"Error api type not yet supported {t!s}") + raise JSchemaTypeError(f"Error api type not yet supported {t!s}") elif isinstance(t, type): if t is str: @@ -65,7 +176,7 @@ def type_to_dict(t: Any, scope: str = "") -> dict: if t is object: return {"type": "object"} if t is Any: - raise BaseException( + raise JSchemaTypeError( f"Usage of the Any type is not supported for API functions. In: {scope}" ) @@ -79,6 +190,6 @@ def type_to_dict(t: Any, scope: str = "") -> dict: if t is NoneType: return {"type": "null"} - raise BaseException(f"Error primitive type not supported {t!s}") + raise JSchemaTypeError(f"Error primitive type not supported {t!s}") else: - raise BaseException(f"Error type not supported {t!s}") + raise JSchemaTypeError(f"Error type not supported {t!s}")