diff --git a/src/specklepy/objects/base.py b/src/specklepy/objects/base.py index c4e6dca..504bb03 100644 --- a/src/specklepy/objects/base.py +++ b/src/specklepy/objects/base.py @@ -1,7 +1,7 @@ -import contextlib from dataclasses import dataclass, field from enum import Enum from inspect import isclass +from types import UnionType from typing import ( Any, ClassVar, @@ -222,23 +222,20 @@ def _validate_type(t: Optional[type], value: Any) -> Tuple[bool, Any]: if value in t._value2member_map_: return True, t(value) - if getattr(t, "__module__", None) == "typing": - if isinstance(t, ForwardRef): - return True, value + if isinstance(t, ForwardRef): + return True, value + if getattr(t, "__module__", None) in ["typing", "types"]: origin = get_origin(t) args = get_args(t) - # below is what in nicer for >= py38 - # origin = get_origin(t) - # recursive validation for Unions on both types preferring the fist type - # if origin is Union or isinstance(t, UnionType): - # for arg_t in args: - # ok, v = _validate_type(arg_t, value) - # if ok: - # return True, v - # return False, value + if origin is Union or isinstance(t, UnionType): + for arg_t in args: + ok, v = _validate_type(arg_t, value) + if ok: + return True, v + return False, value if origin is dict: if not isinstance(value, dict): return False, value @@ -311,13 +308,16 @@ def _validate_type(t: Optional[type], value: Any) -> Tuple[bool, Any]: if isinstance(value, t): return True, value - with contextlib.suppress(ValueError, TypeError): - if t is float and value is not None: - return True, float(value) - # TODO: dafuq, i had to add this not list check - # but it would also fail for objects and other complex values - if t is str and value and not isinstance(value, list): - return True, str(value) + if t is float and type(value) is int: + return True, float(value) + + # with contextlib.suppress(ValueError, TypeError): + # if t is float and value is not None: + # return True, float(value) + # # TODO: dafuq, i had to add this not list check + # # but it would also fail for objects and other complex values + # if t is str and value and not isinstance(value, list): + # return True, str(value) return False, value diff --git a/tests/unit/test_base.py b/tests/unit/test_base.py index b30c4d4..e5bf645 100644 --- a/tests/unit/test_base.py +++ b/tests/unit/test_base.py @@ -143,7 +143,7 @@ def test_type_checking() -> None: order = FrozenYoghurt() order.servings = 2 - order.price = "7" # type: ignore - it will get converted + order.price = 7 order.customer = "izzy" order.dietary = DietaryRestrictions.VEGAN order.tag = "preorder" diff --git a/tests/unit/test_type_validation.py b/tests/unit/test_type_validation.py index 7403dfb..c7c62ea 100644 --- a/tests/unit/test_type_validation.py +++ b/tests/unit/test_type_validation.py @@ -31,13 +31,13 @@ fake_bases = [FakeBase("foo"), FakeBase("bar")] @pytest.mark.parametrize( "input_type, value, is_valid, return_value", [ - (str, 10, True, "10"), + (str, 10, False, 10), (str, "foo_bar", True, "foo_bar"), ( str, {"foo": "bar"}, - True, - "{'foo': 'bar'}", + False, + {"foo": "bar"}, ), (float, 1, True, 1), # why are we allowing this??? We're lying to our users and ourselves too. @@ -85,9 +85,8 @@ fake_bases = [FakeBase("foo"), FakeBase("bar")] (Dict[int, Base], {1: test_base}, True, {1: test_base}), (Tuple[int, str, str], (1, "foo", "bar"), True, (1, "foo", "bar")), (Tuple, (1, "foo", "bar"), True, (1, "foo", "bar")), - # given our current rules, this is the reality. Its just sad... - (Tuple[str, str, str], (1, "foo", "bar"), True, ("1", "foo", "bar")), - (Tuple[str, Optional[str], str], (1, None, "bar"), True, ("1", None, "bar")), + (Tuple[str, str, str], (1, "foo", "bar"), False, (1, "foo", "bar")), + (Tuple[str, Optional[str], str], (1, None, "bar"), False, (1, None, "bar")), (Set[bool], set([1, 2]), False, set([1, 2])), (Set[int], set([1, 2]), True, set([1, 2])), (Set[int], set([None, 2]), True, set([None, 2])),