XylotrechusZ
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, Mapping, TypeVar, Union
from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator
from typing_extensions import Literal
from ..errors import PydanticErrorCodes, PydanticUserError
from ..plugin._schema_validator import PluggableSchemaValidator
if TYPE_CHECKING:
from ..dataclasses import PydanticDataclass
from ..main import BaseModel
# Type variable for the kinds of object `MockValSer` can stand in for: a schema validator
# (plain or plugin-wrapped) or a schema serializer.
ValSer = TypeVar('ValSer', bound=Union[SchemaValidator, PluggableSchemaValidator, SchemaSerializer])
# Generic result type of the attribute getters handed to the `attempt_rebuild_fn` helpers below.
T = TypeVar('T')
class MockCoreSchema(Mapping[str, Any]):
    """Lazy placeholder for a `pydantic_core.CoreSchema` that could not be built yet.

    Every mapping operation (`[]`, `len()`, iteration) first tries to obtain the real schema
    through the optional `attempt_rebuild` callback. A schema obtained this way is memoized and
    the operation is forwarded to it; if no schema can be obtained, a `PydanticUserError` with
    the configured message and code is raised.
    """

    __slots__ = '_error_message', '_code', '_attempt_rebuild', '_built_memo'

    def __init__(
        self,
        error_message: str,
        *,
        code: PydanticErrorCodes,
        attempt_rebuild: Callable[[], CoreSchema | None] | None = None,
    ) -> None:
        """Store the error details and the optional rebuild callback.

        Args:
            error_message: Message of the `PydanticUserError` raised when no schema is available.
            code: Error code attached to that `PydanticUserError`.
            attempt_rebuild: Callback returning the real schema, or `None` if it cannot be built.
        """
        self._built_memo: CoreSchema | None = None
        self._attempt_rebuild = attempt_rebuild
        self._code: PydanticErrorCodes = code
        self._error_message = error_message

    def __getitem__(self, key: str) -> Any:
        return self._get_built()[key]

    def __len__(self) -> int:
        return len(self._get_built())

    def __iter__(self) -> Iterator[str]:
        return iter(self._get_built())

    def _get_built(self) -> CoreSchema:
        """Return the real schema, building and memoizing it on first success.

        Raises:
            PydanticUserError: If there is no memoized schema and the rebuild callback is
                missing or returns `None`.
        """
        built = self._built_memo
        if built is None and self._attempt_rebuild:
            built = self._attempt_rebuild()
            if built is not None:
                # Remember the result so later accesses skip the rebuild attempt.
                self._built_memo = built
        if built is None:
            raise PydanticUserError(self._error_message, code=self._code)
        return built

    def rebuild(self) -> CoreSchema | None:
        """Drop the memoized schema and run the rebuild callback once.

        Returns:
            The freshly built schema, or `None` when no rebuild callback was configured.
            The result is not stored back into the memo.

        Raises:
            PydanticUserError: If the rebuild callback returns `None`.
        """
        self._built_memo = None
        if not self._attempt_rebuild:
            return None
        rebuilt = self._attempt_rebuild()
        if rebuilt is None:
            raise PydanticUserError(self._error_message, code=self._code)
        return rebuilt
class MockValSer(Generic[ValSer]):
    """Mocker for `pydantic_core.SchemaValidator` or `pydantic_core.SchemaSerializer` which optionally attempts to
    rebuild the thing it's mocking when one of its methods is accessed and raises an error if that fails.
    """

    __slots__ = '_error_message', '_code', '_val_or_ser', '_attempt_rebuild'

    def __init__(
        self,
        error_message: str,
        *,
        code: PydanticErrorCodes,
        val_or_ser: Literal['validator', 'serializer'],
        attempt_rebuild: Callable[[], ValSer | None] | None = None,
    ) -> None:
        """Store the error details, the mocked kind and the optional rebuild callback.

        Args:
            error_message: Message of the `PydanticUserError` raised when the rebuild fails.
            code: Error code attached to that `PydanticUserError`.
            val_or_ser: Which class is being mocked; selects the class used to check whether
                a requested attribute exists at all.
            attempt_rebuild: Callback returning the real validator/serializer, or `None` if it
                cannot be built.
        """
        self._error_message = error_message
        # Keep the real class around so `__getattr__` can tell unknown attributes apart
        # from known attributes that are merely unavailable.
        self._val_or_ser = SchemaValidator if val_or_ser == 'validator' else SchemaSerializer
        self._code: PydanticErrorCodes = code
        self._attempt_rebuild = attempt_rebuild

    def __getattr__(self, item: str) -> Any:
        """Forward attribute access to the rebuilt object, or raise.

        Returns:
            The attribute `item` of the object returned by the rebuild callback.

        Raises:
            AttributeError: If the rebuild did not succeed and the mocked class has no
                attribute named `item`.
            PydanticUserError: If the rebuild did not succeed and `item` is a valid attribute
                of the mocked class.
        """
        __tracebackhide__ = True
        if self._attempt_rebuild:
            val_ser = self._attempt_rebuild()
            if val_ser is not None:
                return getattr(val_ser, item)

        # raise an AttributeError if `item` doesn't exist
        getattr(self._val_or_ser, item)
        raise PydanticUserError(self._error_message, code=self._code)

    def rebuild(self) -> ValSer | None:
        """Run the rebuild callback once.

        Returns:
            The rebuilt validator/serializer, or `None` when no rebuild callback was configured.

        Raises:
            PydanticUserError: If the rebuild callback returns `None`.
        """
        if self._attempt_rebuild:
            val_ser = self._attempt_rebuild()
            if val_ser is not None:
                return val_ser
            else:
                raise PydanticUserError(self._error_message, code=self._code)
        return None
def set_model_mocks(cls: type[BaseModel], cls_name: str, undefined_name: str = 'all referenced types') -> None:
    """Install placeholder schema, validator and serializer objects on a model class.

    `__pydantic_core_schema__` is replaced by a `MockCoreSchema`, and `__pydantic_validator__`
    and `__pydantic_serializer__` by `MockValSer`s. Each placeholder tries `cls.model_rebuild()`
    when it is used and, on success, reads the corresponding attribute back from `cls`.

    Args:
        cls: The model class to set the mocks on
        cls_name: Name of the model class, used in error messages
        undefined_name: Name of the undefined thing, used in error messages
    """
    message = (
        f'`{cls_name}` is not fully defined; you should define {undefined_name},'
        f' then call `{cls_name}.model_rebuild()`.'
    )

    def rebuild_then_read(attr_name: str) -> Callable[[], Any]:
        """Build a callback that rebuilds `cls` and returns `cls.<attr_name>`, or `None` on failure."""

        def handler() -> Any:
            # NOTE(review): `model_rebuild` is called directly from this frame; the depth value
            # presumably depends on that, so do not wrap the call in another function.
            outcome = cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5)
            if outcome is False:
                return None
            return getattr(cls, attr_name)

        return handler

    core_schema_mock = MockCoreSchema(
        message,
        code='class-not-fully-defined',
        attempt_rebuild=rebuild_then_read('__pydantic_core_schema__'),
    )
    cls.__pydantic_core_schema__ = core_schema_mock  # type: ignore[assignment]

    validator_mock = MockValSer(
        message,
        code='class-not-fully-defined',
        val_or_ser='validator',
        attempt_rebuild=rebuild_then_read('__pydantic_validator__'),
    )
    cls.__pydantic_validator__ = validator_mock  # type: ignore[assignment]

    serializer_mock = MockValSer(
        message,
        code='class-not-fully-defined',
        val_or_ser='serializer',
        attempt_rebuild=rebuild_then_read('__pydantic_serializer__'),
    )
    cls.__pydantic_serializer__ = serializer_mock  # type: ignore[assignment]
def set_dataclass_mocks(
    cls: type[PydanticDataclass], cls_name: str, undefined_name: str = 'all referenced types'
) -> None:
    """Set `__pydantic_core_schema__`, `__pydantic_validator__` and `__pydantic_serializer__` to mocks on a dataclass.

    The core schema becomes a `MockCoreSchema`; the validator and serializer become `MockValSer`s.
    Each mock tries `rebuild_dataclass(cls)` when it is used and, on success, reads the
    corresponding attribute back from `cls`.

    Args:
        cls: The dataclass to set the mocks on
        cls_name: Name of the dataclass, used in error messages
        undefined_name: Name of the undefined thing, used in error messages
    """
    # Imported here rather than at module level, as in the original code
    # (presumably to avoid a circular import — confirm).
    from ..dataclasses import rebuild_dataclass

    undefined_type_error_message = (
        f'`{cls_name}` is not fully defined; you should define {undefined_name},'
        f' then call `pydantic.dataclasses.rebuild_dataclass({cls_name})`.'
    )

    def attempt_rebuild_fn(attr_fn: Callable[[type[PydanticDataclass]], T]) -> Callable[[], T | None]:
        """Build a callback that rebuilds `cls` and returns `attr_fn(cls)`, or `None` on failure."""

        def handler() -> T | None:
            if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False:
                return attr_fn(cls)
            else:
                return None

        return handler

    cls.__pydantic_core_schema__ = MockCoreSchema(  # type: ignore[assignment]
        undefined_type_error_message,
        code='class-not-fully-defined',
        attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_core_schema__),
    )
    cls.__pydantic_validator__ = MockValSer(  # type: ignore[assignment]
        undefined_type_error_message,
        code='class-not-fully-defined',
        val_or_ser='validator',
        attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_validator__),
    )
    cls.__pydantic_serializer__ = MockValSer(  # type: ignore[assignment]
        undefined_type_error_message,
        code='class-not-fully-defined',
        # Must be 'serializer': `MockValSer.__getattr__` checks requested attribute names against
        # the mocked class, and serializer-only attributes do not exist on `SchemaValidator`.
        val_or_ser='serializer',
        attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_serializer__),
    )