import dataclasses as dc
import datetime
import decimal
import enum
import ipaddress
import operator
import pathlib
import re
import uuid
from enum import Enum
from typing import (
    Any,
    Callable,
    Container,
    Iterable,
    Mapping,
    MutableMapping,
    Optional,
    Tuple,
)

import bson
import pydantic

import beanie
from beanie.odm.fields import Link, LinkTypes
from beanie.odm.utils.pydantic import (
    IS_PYDANTIC_V2,
    IS_PYDANTIC_V2_10,
    get_model_fields,
)

# An encoder is any callable that takes the value and returns its encoded form.
SingleArgCallable = Callable[[Any], Any]

# Built-in type -> encoder table used by ``Encoder.encode`` for values that
# are not already BSON scalars.
# NOTE: insertion order is significant. ``_get_encoder`` falls back to the
# first entry whose type passes ``isinstance``, so do not reorder casually.
DEFAULT_CUSTOM_ENCODERS: MutableMapping[type, SingleArgCallable] = {
    ipaddress.IPv4Address: str,
    ipaddress.IPv4Interface: str,
    ipaddress.IPv4Network: str,
    ipaddress.IPv6Address: str,
    ipaddress.IPv6Interface: str,
    ipaddress.IPv6Network: str,
    pathlib.PurePath: str,
    pydantic.SecretBytes: pydantic.SecretBytes.get_secret_value,
    pydantic.SecretStr: pydantic.SecretStr.get_secret_value,
    # A plain date becomes a datetime at midnight of that day.
    datetime.date: lambda d: datetime.datetime.combine(d, datetime.time.min),
    datetime.timedelta: operator.methodcaller("total_seconds"),
    enum.Enum: operator.attrgetter("value"),
    Link: operator.attrgetter("ref"),
    bytes: bson.Binary,
    decimal.Decimal: bson.Decimal128,
    uuid.UUID: bson.Binary.from_uuid,
    re.Pattern: bson.Regex.from_native,
}

# URL types are only registered when the matching pydantic version flag is set.
if IS_PYDANTIC_V2:
    from pydantic_core import Url

    DEFAULT_CUSTOM_ENCODERS[Url] = str

if IS_PYDANTIC_V2_10:
    from pydantic import AnyUrl

    DEFAULT_CUSTOM_ENCODERS[AnyUrl] = str

# Values of these types are returned by ``Encoder.encode`` unchanged.
BSON_SCALAR_TYPES = (
    type(None),
    str,
    int,
    float,
    datetime.datetime,
    bson.Binary,
    bson.DBRef,
    bson.Decimal128,
    bson.MaxKey,
    bson.MinKey,
    bson.ObjectId,
    bson.Regex,
)


@dc.dataclass
class Encoder:
    """
    BSON encoding class
    """

    # Field keys (after alias resolution) that are left out of encoded models.
    exclude: Container[str] = frozenset()
    # Caller-supplied type -> encoder table; consulted before anything else.
    custom_encoders: Mapping[type, SingleArgCallable] = dc.field(
        default_factory=dict
    )
    # When true, link fields that are neither direct nor list links are
    # skipped while encoding a document.
    to_db: bool = False
    # When false, model fields whose value is ``None`` are dropped.
    keep_nulls: bool = True

    def _encode_document(self, obj: "beanie.Document") -> Mapping[str, Any]:
        """
        Encode a document into a plain mapping.

        Calls ``obj.parse_store()`` first, writes the class id entry when the
        document settings or the document itself provide one, replaces link
        values by the result of their ``to_ref()`` and encodes every
        remaining value with a sub-encoder built from the document's own
        ``bson_encoders``.

        :param obj: the document to encode
        :return: mapping of field key to encoded value
        """
        obj.parse_store()
        settings = obj.get_settings()

        encoded = {}
        if settings.union_doc is not None:
            encoded[settings.class_id] = (
                settings.union_doc_alias or obj.__class__.__name__
            )
        # An explicit class id on the document overrides the union-doc name.
        if obj._class_id:
            encoded[settings.class_id] = obj._class_id

        link_fields = obj.get_link_fields() or {}
        # don't propagate self.exclude to subdocuments
        nested = Encoder(
            custom_encoders=settings.bson_encoders,
            to_db=self.to_db,
            keep_nulls=self.keep_nulls,
        )

        single_kinds = (LinkTypes.DIRECT, LinkTypes.OPTIONAL_DIRECT)
        list_kinds = (LinkTypes.LIST, LinkTypes.OPTIONAL_LIST)

        for name, value in self._iter_model_items(obj):
            if name in link_fields:
                kind = link_fields[name].link_type
                if kind in single_kinds:
                    if value is not None:
                        value = value.to_ref()
                elif kind in list_kinds:
                    if value is not None:
                        value = [item.to_ref() for item in value]
                elif self.to_db:
                    # Any other link kind is not written when encoding
                    # for the database.
                    continue
            encoded[name] = nested.encode(value)
        return encoded

    def encode(self, obj: Any) -> Any:
        """
        Encode ``obj`` recursively.

        The checks run in this order and the first match wins: caller
        supplied ``custom_encoders``, BSON scalar types (returned as is),
        ``DEFAULT_CUSTOM_ENCODERS``, documents, root models, other pydantic
        models, mappings, and finally any other iterable.

        :param obj: the value to encode
        :return: the encoded value
        :raises ValueError: if no rule applies to ``obj``
        """
        user_encoders = self.custom_encoders
        if user_encoders:
            convert = _get_encoder(obj, user_encoders)
            if convert is not None:
                return convert(obj)

        if isinstance(obj, BSON_SCALAR_TYPES):
            return obj

        convert = _get_encoder(obj, DEFAULT_CUSTOM_ENCODERS)
        if convert is not None:
            return convert(obj)

        if isinstance(obj, beanie.Document):
            return self._encode_document(obj)

        if IS_PYDANTIC_V2 and isinstance(obj, pydantic.RootModel):
            return self.encode(obj.root)

        if isinstance(obj, pydantic.BaseModel):
            return {
                name: self.encode(item)
                for name, item in self._iter_model_items(obj)
            }

        if isinstance(obj, Mapping):
            result = {}
            for raw_key, item in obj.items():
                # Enum keys are kept as they are; every other key is
                # converted with str().
                out_key = raw_key if isinstance(raw_key, Enum) else str(raw_key)
                result[out_key] = self.encode(item)
            return result

        if isinstance(obj, Iterable):
            return [self.encode(item) for item in obj]

        raise ValueError(f"Cannot encode {obj!r}")

    def _iter_model_items(
        self, obj: pydantic.BaseModel
    ) -> Iterable[Tuple[str, Any]]:
        """
        Yield ``(key, value)`` pairs of a model, honouring the encoder options.

        The key is the field alias when the field defines one. Keys found in
        ``self.exclude`` are skipped, and so are ``None`` values unless
        ``self.keep_nulls`` is set.

        :param obj: the model to iterate
        :return: iterable of ``(key, value)`` pairs
        """
        excluded = self.exclude
        keep_nulls = self.keep_nulls
        lookup_field = get_model_fields(obj).get

        for name, value in obj.__iter__():
            info = lookup_field(name)
            if info is not None:
                name = info.alias or name
            if name in excluded:
                continue
            if value is None and not keep_nulls:
                continue
            yield name, value


def _get_encoder(
    obj: Any, custom_encoders: Mapping[type, SingleArgCallable]
) -> Optional[SingleArgCallable]:
    """
    Find the encoder registered for ``obj`` in ``custom_encoders``.

    An entry for the exact type of ``obj`` is preferred. Otherwise the first
    entry, in mapping iteration order, whose type passes ``isinstance`` is
    used.

    :param obj: the value an encoder is needed for
    :param custom_encoders: mapping of type to encoder callable
    :return: the matching encoder, or ``None`` when there is none
    """
    exact = custom_encoders.get(type(obj))
    if exact is not None:
        return exact
    return next(
        (
            candidate
            for registered_type, candidate in custom_encoders.items()
            if isinstance(obj, registered_type)
        ),
        None,
    )