diff --git a/gm_assistant/oracle/__init__.py b/gm_assistant/oracle/__init__.py index b1ae0c6..b014a52 100644 --- a/gm_assistant/oracle/__init__.py +++ b/gm_assistant/oracle/__init__.py @@ -2,3 +2,29 @@ # SPDX-FileContributor: Gergely Polonkai # # SPDX-License-Identifier: GPL-3.0-or-later +"""Oracle classes and related functions""" + +from typing import Any, Type + +from .base import Oracle +from .object_generator import ObjectGeneratorOracle +from .random_choice import RandomChoiceOracle + + +def generate_type_classes(class_list: dict[str, Any]) -> dict[str, Type[Oracle]]: + """Generate a dictionary of oracle type handlers""" + + ret: dict[str, Type[Oracle]] = {} + + for klass in class_list.values(): + if not isinstance(klass, type) or klass == Oracle or not issubclass(klass, Oracle): + continue + + if klass.TYPE_MARKER in ret: + raise KeyError( + f"{ret[klass.TYPE_MARKER].__name__} is already registered as a handler for {klass.TYPE_MARKER}" + ) + + ret[klass.TYPE_MARKER] = klass + + return ret diff --git a/tests/test_oracle_type_class_lister.py b/tests/test_oracle_type_class_lister.py new file mode 100644 index 0000000..90438e0 --- /dev/null +++ b/tests/test_oracle_type_class_lister.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: 2025 2025 +# SPDX-FileContributor: Gergely Polonkai +# +# SPDX-License-Identifier: GPL-3.0-or-later +"""Tests for the type class lister""" + +import pytest + +from gm_assistant.oracle import generate_type_classes +from gm_assistant.oracle.base import Oracle +from gm_assistant.oracle.object_generator import ObjectGeneratorOracle + + +class _TestOracle(Oracle): + """Test oracle class that has the same marker as ObjectGeneratorOracle""" + + TYPE_MARKER = ObjectGeneratorOracle.TYPE_MARKER + + def generate(self) -> str: # pragma: no cover + return "" + + +def test_generate_empty() -> None: + """Test generating the type class list from an empty dictionary""" + + assert generate_type_classes({}) == {} # pylint: disable=use-implicit-booleaness-not-comparison + + +def test_nontype_not_present() -> None: + """Test that non-types don’t get included in the results""" + + assert generate_type_classes({"test": True}) == {} # pylint: disable=use-implicit-booleaness-not-comparison + + +def test_non_oracle_not_present() -> None: + """Test that non-oracle types don’t get included in the results""" + + assert generate_type_classes({"test": dict}) == {} # pylint: disable=use-implicit-booleaness-not-comparison + + +def test_oracle_not_present() -> None: + """Test that the ``Oracle`` class doesn’t get included in the results""" + + assert generate_type_classes({"oracle": Oracle}) == {} # pylint: disable=use-implicit-booleaness-not-comparison + + +def test_duplace_type_marker() -> None: + """Test if ``generate_type_classes`` raises an error if a type marker appears twice""" + + with pytest.raises(KeyError) as ctx: + generate_type_classes({"ObjectGeneratorOracle": ObjectGeneratorOracle, "TestOracle": _TestOracle}) + + assert str(ctx.value) == "'ObjectGeneratorOracle is already registered as a handler for object-generator'"