177 lines
5.1 KiB
Python
177 lines
5.1 KiB
Python
import dataclasses as dc
|
|
import datetime
|
|
import decimal
|
|
import enum
|
|
import ipaddress
|
|
import operator
|
|
import pathlib
|
|
import re
|
|
import uuid
|
|
from enum import Enum
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Container,
|
|
Iterable,
|
|
Mapping,
|
|
MutableMapping,
|
|
Optional,
|
|
Tuple,
|
|
)
|
|
|
|
import bson
|
|
import pydantic
|
|
|
|
import beanie
|
|
from beanie.odm.fields import Link, LinkTypes
|
|
from beanie.odm.utils.pydantic import (
|
|
IS_PYDANTIC_V2,
|
|
IS_PYDANTIC_V2_10,
|
|
get_model_fields,
|
|
)
|
|
|
|
SingleArgCallable = Callable[[Any], Any]

# Fallback encoders consulted by Encoder.encode after the BSON scalar
# pass-through. Insertion order matters: when there is no entry for the exact
# type, _get_encoder scans this mapping in order and uses the first key the
# value is an instance of.
DEFAULT_CUSTOM_ENCODERS: MutableMapping[type, SingleArgCallable] = {
    # Network addresses/interfaces/networks and paths are stored via str().
    **{
        text_type: str
        for text_type in (
            ipaddress.IPv4Address,
            ipaddress.IPv4Interface,
            ipaddress.IPv4Network,
            ipaddress.IPv6Address,
            ipaddress.IPv6Interface,
            ipaddress.IPv6Network,
            pathlib.PurePath,
        )
    },
    # Secret wrappers are unwrapped to their underlying value.
    pydantic.SecretBytes: pydantic.SecretBytes.get_secret_value,
    pydantic.SecretStr: pydantic.SecretStr.get_secret_value,
    # A bare date becomes a datetime at midnight (datetime.time.min).
    datetime.date: lambda day: datetime.datetime.combine(
        day, datetime.time.min
    ),
    # Durations are stored as a float number of seconds.
    datetime.timedelta: operator.methodcaller("total_seconds"),
    enum.Enum: operator.attrgetter("value"),
    Link: operator.attrgetter("ref"),
    bytes: bson.Binary,
    decimal.Decimal: bson.Decimal128,
    uuid.UUID: bson.Binary.from_uuid,
    re.Pattern: bson.Regex.from_native,
}

if IS_PYDANTIC_V2:
    from pydantic_core import Url

    DEFAULT_CUSTOM_ENCODERS[Url] = str

if IS_PYDANTIC_V2_10:
    # NOTE(review): registered separately, presumably because from pydantic
    # 2.10 on AnyUrl is no longer matched by the pydantic_core.Url entry.
    from pydantic import AnyUrl

    DEFAULT_CUSTOM_ENCODERS[AnyUrl] = str
|
|
|
|
# Values that are instances of these types (subclasses included, since the
# check is isinstance-based) are returned unchanged by Encoder.encode, unless
# an instance-level custom encoder claims them first.
BSON_SCALAR_TYPES = (
    # Plain Python values with a direct BSON counterpart.
    type(None),
    str,
    int,
    float,
    datetime.datetime,
) + (
    # Wrapper types provided by the bson package itself.
    bson.Binary,
    bson.DBRef,
    bson.Decimal128,
    bson.MaxKey,
    bson.MinKey,
    bson.ObjectId,
    bson.Regex,
)
|
|
|
|
|
|
@dc.dataclass
class Encoder:
    """
    Recursive BSON encoder for beanie documents, pydantic models and plain
    Python values.

    Fields (also the constructor arguments, in this order):
        exclude: keys skipped while iterating model items.
        custom_encoders: type -> callable overrides, consulted before any
            built-in rule.
        to_db: when true, document link fields that are neither direct nor
            list links are left out of the encoded document.
        keep_nulls: when false, model items whose value is None are dropped.
    """

    exclude: Container[str] = frozenset()
    custom_encoders: Mapping[type, SingleArgCallable] = dc.field(
        default_factory=dict
    )
    to_db: bool = False
    keep_nulls: bool = True

    def encode(self, obj: Any) -> Any:
        """
        Convert ``obj`` into a value built from BSON-compatible parts.

        Rules are tried in this order; the first one that applies wins:

        1. an encoder found in ``self.custom_encoders`` (if non-empty);
        2. instances of ``BSON_SCALAR_TYPES``, returned unchanged;
        3. an encoder found in ``DEFAULT_CUSTOM_ENCODERS``;
        4. beanie documents, via ``_encode_document``;
        5. pydantic v2 ``RootModel`` instances: their ``root`` is encoded;
        6. other pydantic models: a dict of model-item keys to encoded values;
        7. mappings: a dict whose keys are kept when they are Enum members
           and passed through ``str()`` otherwise;
        8. any other iterable: a list of encoded items.

        Results of rules 1 and 3 are returned as-is, not encoded further.

        Raises:
            ValueError: if none of the rules applies.
        """
        if self.custom_encoders:
            override = _get_encoder(obj, self.custom_encoders)
            if override is not None:
                return override(obj)

        if isinstance(obj, BSON_SCALAR_TYPES):
            return obj

        fallback = _get_encoder(obj, DEFAULT_CUSTOM_ENCODERS)
        if fallback is not None:
            return fallback(obj)

        if isinstance(obj, beanie.Document):
            return self._encode_document(obj)
        if IS_PYDANTIC_V2 and isinstance(obj, pydantic.RootModel):
            return self.encode(obj.root)
        if isinstance(obj, pydantic.BaseModel):
            return dict(
                (key, self.encode(value))
                for key, value in self._iter_model_items(obj)
            )
        if isinstance(obj, Mapping):
            encoded = {}
            for key, value in obj.items():
                # Enum keys are kept as members; any other key is stringified.
                out_key = key if isinstance(key, Enum) else str(key)
                encoded[out_key] = self.encode(value)
            return encoded
        if isinstance(obj, Iterable):
            return list(map(self.encode, obj))

        raise ValueError(f"Cannot encode {obj!r}")

    def _encode_document(self, obj: "beanie.Document") -> Mapping[str, Any]:
        """
        Encode a beanie document into a plain dict.

        Calls ``obj.parse_store()`` first. Writes the class discriminator
        under ``settings.class_id`` when ``settings.union_doc`` is set or the
        document has a truthy ``_class_id``, and replaces direct/list link
        values with their ``to_ref()`` results. Field values are encoded by a
        fresh ``Encoder`` built from the document's ``bson_encoders``;
        ``self.exclude`` is not forwarded to it.
        """
        obj.parse_store()
        settings = obj.get_settings()
        document = {}

        if settings.union_doc is not None:
            document[settings.class_id] = (
                settings.union_doc_alias or obj.__class__.__name__
            )
        if obj._class_id:
            # An explicit _class_id overwrites the union-doc default above.
            document[settings.class_id] = obj._class_id

        links = obj.get_link_fields() or {}
        # Nested values: document-level encoders, same to_db/keep_nulls,
        # but deliberately no exclude set.
        field_encoder = Encoder(
            custom_encoders=settings.bson_encoders,
            to_db=self.to_db,
            keep_nulls=self.keep_nulls,
        )

        for key, value in self._iter_model_items(obj):
            if key in links:
                kind = links[key].link_type
                if kind in (LinkTypes.DIRECT, LinkTypes.OPTIONAL_DIRECT):
                    value = None if value is None else value.to_ref()
                elif kind in (LinkTypes.LIST, LinkTypes.OPTIONAL_LIST):
                    value = (
                        None
                        if value is None
                        else [link.to_ref() for link in value]
                    )
                elif self.to_db:
                    # Any other link type is omitted when encoding for the DB.
                    continue
            document[key] = field_encoder.encode(value)

        return document

    def _iter_model_items(
        self, obj: pydantic.BaseModel
    ) -> Iterable[Tuple[str, Any]]:
        """
        Yield ``(key, value)`` pairs for the items of ``obj``.

        A key is the field's alias when the field declares one, otherwise the
        item name. Pairs whose key is in ``self.exclude`` are skipped, and so
        are ``None`` values unless ``self.keep_nulls`` is true.
        """
        excluded = self.exclude
        keep_nulls = self.keep_nulls
        field_for = get_model_fields(obj).get
        for name, value in obj.__iter__():
            info = field_for(name)
            key = name if info is None else (info.alias or name)
            if key in excluded:
                continue
            if value is None and not keep_nulls:
                continue
            yield key, value
|
|
|
|
|
|
def _get_encoder(
    obj: Any, custom_encoders: Mapping[type, SingleArgCallable]
) -> Optional[SingleArgCallable]:
    """
    Find the encoder in ``custom_encoders`` that applies to ``obj``.

    An entry keyed by ``type(obj)`` itself wins. Otherwise the encoder of the
    first key (in the mapping's iteration order) that ``obj`` is an instance
    of is returned. Returns None when nothing matches.
    """
    exact = custom_encoders.get(type(obj))
    if exact is not None:
        return exact
    return next(
        (
            encoder
            for base, encoder in custom_encoders.items()
            if isinstance(obj, base)
        ),
        None,
    )
|