diff --git a/src/specklepy/objects/base.py b/src/specklepy/objects/base.py index 4fb9588..c4e6dca 100644 --- a/src/specklepy/objects/base.py +++ b/src/specklepy/objects/base.py @@ -13,11 +13,13 @@ from typing import ( Tuple, Type, Union, + get_origin, get_type_hints, ) from warnings import warn from pydantic.alias_generators import to_pascal +from typing_extensions import get_args from specklepy.logging.exceptions import SpeckleException from specklepy.transports.memory import MemoryTransport @@ -224,28 +226,27 @@ def _validate_type(t: Optional[type], value: Any) -> Tuple[bool, Any]: if isinstance(t, ForwardRef): return True, value - origin = t.__origin__ + origin = get_origin(t) + args = get_args(t) + # below is what in nicer for >= py38 # origin = get_origin(t) # recursive validation for Unions on both types preferring the fist type - if origin is Union: - # below is what in nicer for >= py38 - # t_1, t_2 = get_args(t) - args = t.__args__ # type: ignore - for arg_t in args: - t_success, t_value = _validate_type(arg_t, value) - if t_success: - return True, t_value - return False, value + # if origin is Union or isinstance(t, UnionType): + # for arg_t in args: + # ok, v = _validate_type(arg_t, value) + # if ok: + # return True, v + # return False, value if origin is dict: if not isinstance(value, dict): return False, value - if value == {}: + if not value: return True, value - if not getattr(t, "__args__", None): + if not args: return True, value - t_key, t_value = t.__args__ # type: ignore + t_key, t_value = args if ( getattr(t_key, "__name__", None), @@ -265,11 +266,11 @@ def _validate_type(t: Optional[type], value: Any) -> Tuple[bool, Any]: if origin is list: if not isinstance(value, list): return False, value - if value == []: + if not value: return True, value - if not hasattr(t, "__args__"): + if not args: return True, value - t_items = t.__args__[0] # type: ignore + t_items = args[0] if getattr(t_items, "__name__", None) == "T": return True, value first_item_valid, _ = _validate_type(t_items, value[0]) @@ -280,10 +281,10 @@ def _validate_type(t: Optional[type], value: Any) -> Tuple[bool, Any]: if origin is tuple: if not isinstance(value, tuple): return False, value - if not hasattr(t, "__args__"): + if not args: return True, value args = t.__args__ # type: ignore - if args == tuple(): + if not args: return True, value # we're not checking for empty tuple, cause tuple lengths must match if len(args) != len(value): @@ -299,7 +300,7 @@ def _validate_type(t: Optional[type], value: Any) -> Tuple[bool, Any]: if origin is set: if not isinstance(value, set): return False, value - if not hasattr(t, "__args__"): + if not args: return True, value t_items = t.__args__[0] # type: ignore first_item_valid, _ = _validate_type(t_items, next(iter(value))) diff --git a/tests/objects/test_text.py b/tests/objects/test_text.py index 0233395..8474cd4 100644 --- a/tests/objects/test_text.py +++ b/tests/objects/test_text.py @@ -35,7 +35,7 @@ def sample_text_all_properties(sample_point: Point, sample_plane: Plane) -> Text alignmentH=AlignmentHorizontal.Center, alignmentV=AlignmentVertical.Center, plane=sample_plane, - maxWidth=20, + maxWidth=20.0, units=Units.m, ) @@ -56,7 +56,7 @@ def test_text_creation_minimal(sample_point: Point): def test_text_creation_extended(sample_point: Point, sample_plane: Plane): text_value = "text" - max_width = 20 + max_width = 20.0 text_obj = Text( value=text_value,