freeleaps-ops/venv/lib/python3.12/site-packages/beanie/odm/documents.py

1386 lines
47 KiB
Python

import asyncio
import warnings
from datetime import datetime, timezone
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Coroutine,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from uuid import UUID, uuid4
from bson import DBRef, ObjectId
from lazy_model import LazyModel
from pydantic import (
ConfigDict,
Field,
PrivateAttr,
ValidationError,
)
from pydantic.class_validators import root_validator
from pydantic.main import BaseModel
from pymongo import InsertOne
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.errors import DuplicateKeyError
from pymongo.results import (
DeleteResult,
InsertManyResult,
)
from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias
from beanie.exceptions import (
CollectionWasNotInitialized,
DocumentNotFound,
DocumentWasNotSaved,
NotSupported,
ReplaceError,
RevisionIdWasChanged,
)
from beanie.odm.actions import (
ActionDirections,
EventTypes,
wrap_with_actions,
)
from beanie.odm.bulk import BulkWriter
from beanie.odm.cache import LRUCache
from beanie.odm.enums import SortDirection
from beanie.odm.fields import (
BackLink,
DeleteRules,
ExpressionField,
Link,
LinkInfo,
LinkTypes,
PydanticObjectId,
WriteRules,
)
from beanie.odm.interfaces.aggregate import AggregateInterface
from beanie.odm.interfaces.detector import ModelType
from beanie.odm.interfaces.find import FindInterface
from beanie.odm.interfaces.getters import OtherGettersInterface
from beanie.odm.interfaces.inheritance import InheritanceInterface
from beanie.odm.interfaces.setters import SettersInterface
from beanie.odm.models import (
InspectionError,
InspectionResult,
InspectionStatuses,
)
from beanie.odm.operators.find.comparison import In
from beanie.odm.operators.update.general import (
CurrentDate,
Inc,
SetRevisionId,
Unset,
)
from beanie.odm.operators.update.general import (
Set as SetOperator,
)
from beanie.odm.queries.find import FindMany, FindOne
from beanie.odm.queries.update import UpdateMany, UpdateResponse
from beanie.odm.settings.document import DocumentSettings
from beanie.odm.utils.dump import get_dict, get_top_level_nones
from beanie.odm.utils.parsing import apply_changes, merge_models
from beanie.odm.utils.pydantic import (
IS_PYDANTIC_V2,
get_extra_field_info,
get_field_type,
get_model_dump,
get_model_fields,
parse_model,
parse_object_as,
)
from beanie.odm.utils.self_validation import validate_self_before
from beanie.odm.utils.state import (
previous_saved_state_needed,
save_state_after,
saved_state_needed,
)
from beanie.odm.utils.typing import extract_id_class
if IS_PYDANTIC_V2:
from pydantic import model_validator
if TYPE_CHECKING:
from beanie.odm.views import View
# Type variable for query helpers that may operate on a Document or a View.
FindType = TypeVar("FindType", bound=Union["Document", "View"])
# Type variable bound to Document subclasses (used e.g. as `cls: Type[DocType]`).
DocType = TypeVar("DocType", bound="Document")
# Parameter specification and return type used by the method aliases below.
P = ParamSpec("P")
R = TypeVar("R")
# can describe both sync and async, where R itself is a coroutine
AnyDocMethod: TypeAlias = Callable[Concatenate[DocType, P], R]
# describes only async
AsyncDocMethod: TypeAlias = Callable[
    Concatenate[DocType, P], Coroutine[Any, Any, R]
]
# Type variable for pydantic models passed as `projection_model`.
DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel)
def json_schema_extra(schema: Dict[str, Any], model: Type["Document"]) -> None:
    """
    Strip fields marked ``exclude`` from a generated JSON schema.

    Installed as the schema hook of ``Document`` (``json_schema_extra`` on
    pydantic v2, ``schema_extra`` on v1). The schema is edited in place.

    :param schema: Dict[str, Any] - JSON schema produced by pydantic
    :param model: Type["Document"] - model the schema was generated for
    :return: None
    """
    properties = schema.get("properties")
    if not properties:
        return
    # Schema properties are keyed by the field alias when one is set.
    excluded_keys = {
        field.alias or name
        for name, field in get_model_fields(model).items()
        if (field if IS_PYDANTIC_V2 else field.field_info).exclude
    }
    for key in excluded_keys:
        properties.pop(key, None)
def document_alias_generator(s: str) -> str:
    """
    Alias the ``id`` field as MongoDB's ``_id``; keep every other name.

    :param s: str - model field name
    :return: str - the alias to use for that field
    """
    return "_id" if s == "id" else s
class MergeStrategy(str, Enum):
    """
    How ``Document.sync`` reconciles this instance with the stored document.

    - ``local``: capture local unsaved changes, merge the stored document
      in, then re-apply the captured changes (needs state management).
    - ``remote``: merge the stored document into this instance.
    """

    local = "local"
    remote = "remote"
class Document(
    LazyModel,
    SettersInterface,
    InheritanceInterface,
    FindInterface,
    AggregateInterface,
    OtherGettersInterface,
):
    """
    Document Mapping class.

    Base class for models stored in a MongoDB collection; query, setter,
    aggregation and inheritance helpers come from the mixin base classes.

    Fields:

    - `id` - MongoDB document ObjectID "_id" field.
    Mapped to the PydanticObjectId class
    """

    # Model configuration: `id` maps to `_id`, population by field name is
    # allowed, and json_schema_extra strips excluded fields from the schema.
    if IS_PYDANTIC_V2:
        model_config = ConfigDict(
            json_schema_extra=json_schema_extra,
            populate_by_name=True,
            alias_generator=document_alias_generator,
        )
    else:

        class Config:
            json_encoders = {ObjectId: str}
            allow_population_by_field_name = True
            fields = {"id": "_id"}
            schema_extra = staticmethod(json_schema_extra)

    id: Optional[PydanticObjectId] = Field(
        default=None, description="MongoDB document ObjectID"
    )

    # State
    # Revision token used by insert/replace/update when the `use_revision`
    # setting is on; marked exclude=True so it is left out of model dumps.
    revision_id: Optional[UUID] = Field(default=None, exclude=True)
    # Database-form snapshot taken by _save_state, and the snapshot before
    # it (only kept when state_management_save_previous is on).
    _saved_state: Optional[Dict[str, Any]] = PrivateAttr(default=None)
    _previous_saved_state: Optional[Dict[str, Any]] = PrivateAttr(default=None)

    # Relations
    # Link metadata keyed by field name; not assigned in this file -
    # presumably populated during initialization (TODO confirm).
    _link_fields: ClassVar[Optional[Dict[str, LinkInfo]]] = None

    # Cache
    _cache: ClassVar[Optional[LRUCache]] = None

    # Settings
    # get_settings() raises CollectionWasNotInitialized while this is None.
    _document_settings: ClassVar[Optional[DocumentSettings]] = None

    # Database
    # Not read in this file; presumably set during initialization.
    _database_major_version: ClassVar[int] = 4
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """
    Build the model, then call get_pymongo_collection().

    The collection lookup's return value is discarded; it appears to be
    called only as an initialization check - TODO confirm what it raises
    for a document class that was never initialized.
    """
    super(Document, self).__init__(*args, **kwargs)
    self.get_pymongo_collection()
@classmethod
def _fill_back_refs(cls, values):
    """
    Pre-validation hook: supply placeholders for missing back-link fields.

    For every field in ``cls._link_fields`` whose link type is a back link
    and which is absent from the raw input, a ``BackLink`` placeholder is
    added. Direct back links get a single placeholder; list back links get
    a one-element list.

    The caller's mapping is never mutated: a shallow copy with the
    placeholders is returned when something had to be added. Input that is
    not a mapping is returned unchanged, so pydantic reports its own
    validation error instead of this hook failing with a TypeError.

    :param values: raw input handed to the model validator
    :return: the input, or a copy of it with back-link placeholders added
    """
    if not cls._link_fields or not isinstance(values, Mapping):
        return values
    direct_types = [LinkTypes.BACK_DIRECT, LinkTypes.OPTIONAL_BACK_DIRECT]
    list_types = [LinkTypes.BACK_LIST, LinkTypes.OPTIONAL_BACK_LIST]
    missing: Dict[str, Any] = {}
    for field_name, link_info in cls._link_fields.items():
        if field_name in values:
            continue
        if link_info.link_type in direct_types:
            missing[field_name] = BackLink[link_info.document_class](
                link_info.document_class
            )
        elif link_info.link_type in list_types:
            missing[field_name] = [
                BackLink[link_info.document_class](link_info.document_class)
            ]
    if not missing:
        return values
    # Copy instead of mutating: the input may be the caller's own dict.
    return {**values, **missing}
# Register _fill_back_refs as a "before" model validator, using the
# decorator that matches the installed pydantic major version.
# NOTE(review): no explicit @classmethod here; pydantic v2 is expected to
# treat a mode="before" validator taking `cls` as a classmethod - confirm.
if IS_PYDANTIC_V2:

    @model_validator(mode="before")
    def fill_back_refs(cls, values):
        return cls._fill_back_refs(values)

else:

    @root_validator(pre=True)
    def fill_back_refs(cls, values):
        return cls._fill_back_refs(values)
@classmethod
async def get(
    cls: Type["DocType"],
    document_id: Any,
    session: Optional[AsyncClientSession] = None,
    ignore_cache: bool = False,
    fetch_links: bool = False,
    with_children: bool = False,
    nesting_depth: Optional[int] = None,
    nesting_depths_per_field: Optional[Dict[str, int]] = None,
    **pymongo_kwargs: Any,
) -> Optional["DocType"]:
    """
    Fetch a single document by primary key; None if it does not exist.

    ``document_id`` is first coerced to the declared type of the ``id``
    field unless it already is an instance of that type. All remaining
    options are forwarded to ``find_one``.

    :param document_id: Any - the document id
    :param session: Optional[AsyncClientSession] - pymongo session
    :param ignore_cache: bool - ignore cache (if it is turned on)
    :param fetch_links: bool - forwarded to find_one
    :param with_children: bool - forwarded to find_one
    :param nesting_depth: Optional[int] - forwarded to find_one
    :param nesting_depths_per_field: Optional[Dict[str, int]] - forwarded to find_one
    :param **pymongo_kwargs: pymongo native parameters for find operation
    :return: Union["Document", None]
    """
    id_field_type = get_field_type(get_model_fields(cls)["id"])
    if not isinstance(document_id, extract_id_class(id_field_type)):
        document_id = parse_object_as(id_field_type, document_id)
    return await cls.find_one(
        {"_id": document_id},
        session=session,
        ignore_cache=ignore_cache,
        fetch_links=fetch_links,
        with_children=with_children,
        nesting_depth=nesting_depth,
        nesting_depths_per_field=nesting_depths_per_field,
        **pymongo_kwargs,
    )
async def sync(self, merge_strategy: MergeStrategy = MergeStrategy.remote):
    """
    Refresh this instance from the copy stored in the database.

    ``MergeStrategy.remote`` merges the stored document into this instance.
    ``MergeStrategy.local`` captures the local unsaved changes first and
    re-applies them after merging; it requires state management.

    :param merge_strategy: MergeStrategy - how to merge the document
    :return: None
    :raises ValueError: local strategy without state management, or an
        unknown strategy
    :raises DocumentWasNotSaved: this instance has no id, or the stored
        document has no saved state
    :raises DocumentNotFound: nothing stored under this id
    """
    is_local = merge_strategy == MergeStrategy.local
    if is_local and self.get_settings().use_state_management is False:
        raise ValueError(
            "State management must be turned on to use local merge strategy"
        )
    if self.id is None:
        raise DocumentWasNotSaved
    stored = await self.find_one({"_id": self.id})
    if stored is None:
        raise DocumentNotFound
    if is_local:
        local_changes = self.get_changes()
        stored_state = stored.get_saved_state()
        if stored_state is None:
            raise DocumentWasNotSaved
        pending = self._collect_updates(stored_state, local_changes)
        merge_models(self, stored)
        apply_changes(pending, self)
    elif merge_strategy == MergeStrategy.remote:
        merge_models(self, stored)
    else:
        raise ValueError("Invalid merge strategy")
@wrap_with_actions(EventTypes.INSERT)
@save_state_after
@validate_self_before
async def insert(
    self: Self,
    *,
    link_rule: WriteRules = WriteRules.DO_NOTHING,
    session: Optional[AsyncClientSession] = None,
    skip_actions: Optional[List[Union[ActionDirections, str]]] = None,
) -> Self:
    """
    Insert the document (self) to the collection

    :param link_rule: WriteRules - if "WriteRules.WRITE", Document instances
        held in direct and list link fields are saved first (back links are
        not written here).
    :param session: AsyncClientSession - pymongo session
    :param skip_actions: not read in this body; presumably consumed by the
        wrap_with_actions decorator - TODO confirm
    :return: self, with ``id`` set from the server's inserted id
    """
    # Every insert gets a fresh revision id when revision tracking is on.
    if self.get_settings().use_revision:
        self.revision_id = uuid4()
    if link_rule == WriteRules.WRITE:
        link_fields = self.get_link_fields()
        if link_fields is not None:
            for field_info in link_fields.values():
                value = getattr(self, field_info.field_name)
                if field_info.link_type in [
                    LinkTypes.DIRECT,
                    LinkTypes.OPTIONAL_DIRECT,
                ]:
                    # Only Document instances are written; other values are skipped.
                    if isinstance(value, Document):
                        await value.save(
                            link_rule=WriteRules.WRITE, session=session
                        )
                if field_info.link_type in [
                    LinkTypes.LIST,
                    LinkTypes.OPTIONAL_LIST,
                ]:
                    if isinstance(value, List):
                        # Save the linked list items concurrently.
                        await asyncio.gather(
                            *[
                                obj.save(
                                    link_rule=WriteRules.WRITE,
                                    session=session,
                                )
                                for obj in value
                                if isinstance(obj, Document)
                            ]
                        )
    result = await self.get_pymongo_collection().insert_one(
        get_dict(
            self, to_db=True, keep_nulls=self.get_settings().keep_nulls
        ),
        session=session,
    )
    new_id = result.inserted_id
    # Coerce the returned id to the declared type of the `id` field.
    if not isinstance(
        new_id,
        extract_id_class(get_field_type(get_model_fields(self)["id"])),
    ):
        new_id = parse_object_as(
            get_field_type(get_model_fields(self)["id"]), new_id
        )
    self.id = new_id
    return self
async def create(
    self: Self,
    session: Optional[AsyncClientSession] = None,
) -> Self:
    """
    The same as self.insert()

    Only the session is forwarded; insert's other options keep their
    defaults (e.g. link_rule=WriteRules.DO_NOTHING).

    :param session: Optional[AsyncClientSession] - pymongo session
    :return: self
    """
    return await self.insert(session=session)
@classmethod
async def insert_one(
    cls: Type[DocType],
    document: DocType,
    session: Optional[AsyncClientSession] = None,
    bulk_writer: Optional["BulkWriter"] = None,
    link_rule: WriteRules = WriteRules.DO_NOTHING,
) -> Optional[DocType]:
    """
    Insert a single document into the collection.

    Without a bulk writer the document is inserted right away through
    ``document.insert`` and returned. With a bulk writer an ``InsertOne``
    operation is queued on it instead and ``None`` is returned.

    :param document: Document - document to insert (an instance of cls)
    :param session: AsyncClientSession - pymongo session
    :param bulk_writer: "BulkWriter" - Beanie bulk writer
    :param link_rule: WriteRules - how to manage link fields
    :return: the inserted document, or None when queued on a bulk writer
    :raises TypeError: document is not an instance of cls
    :raises NotSupported: WriteRules.WRITE combined with a bulk writer
    """
    if not isinstance(document, cls):
        raise TypeError(
            "Inserting document must be of the original document class"
        )
    if bulk_writer is None:
        return await document.insert(link_rule=link_rule, session=session)
    if link_rule == WriteRules.WRITE:
        raise NotSupported(
            "Cascade insert with bulk writing not supported"
        )
    raw_document = get_dict(
        document,
        to_db=True,
        keep_nulls=document.get_settings().keep_nulls,
    )
    bulk_writer.add_operation(type(document), InsertOne(raw_document))
    return None
@classmethod
async def insert_many(
    cls: Type[DocType],
    documents: Iterable[DocType],
    session: Optional[AsyncClientSession] = None,
    link_rule: WriteRules = WriteRules.DO_NOTHING,
    **pymongo_kwargs: Any,
) -> InsertManyResult:
    """
    Insert several documents with one ``insert_many`` call.

    Documents are converted to their database form and sent directly to
    pymongo; no per-document actions or id assignment happen here.

    :param documents: List["Document"] - documents to insert
    :param session: AsyncClientSession - pymongo session
    :param link_rule: WriteRules - cascading (WRITE) is not supported
    :param pymongo_kwargs: forwarded to pymongo's insert_many
    :return: InsertManyResult
    :raises NotSupported: link_rule is WriteRules.WRITE
    """
    if link_rule == WriteRules.WRITE:
        raise NotSupported(
            "Cascade insert not supported for insert many method"
        )
    raw_documents = [
        get_dict(doc, to_db=True, keep_nulls=doc.get_settings().keep_nulls)
        for doc in documents
    ]
    collection = cls.get_pymongo_collection()
    return await collection.insert_many(
        raw_documents, session=session, **pymongo_kwargs
    )
@wrap_with_actions(EventTypes.REPLACE)
@save_state_after
@validate_self_before
async def replace(
    self: Self,
    ignore_revision: bool = False,
    session: Optional[AsyncClientSession] = None,
    bulk_writer: Optional[BulkWriter] = None,
    link_rule: WriteRules = WriteRules.DO_NOTHING,
    skip_actions: Optional[List[Union[ActionDirections, str]]] = None,
) -> Self:
    """
    Fully update the document in the database

    :param session: Optional[AsyncClientSession] - pymongo session.
    :param ignore_revision: bool - do force replace.
        Used when revision based protection is turned on.
    :param bulk_writer: "BulkWriter" - Beanie bulk writer
    :param link_rule: WriteRules - with WriteRules.WRITE, Document instances
        in link fields (including back links) are replaced first, recursively
    :param skip_actions: not read in this body; presumably consumed by the
        wrap_with_actions decorator - TODO confirm
    :return: self
    :raises ValueError: the document has no id
    :raises NotSupported: a bulk writer combined with a link rule other
        than DO_NOTHING
    :raises RevisionIdWasChanged: revision checking is active and the
        query reported DocumentNotFound
    :raises DocumentNotFound: the query reported DocumentNotFound otherwise
    """
    if self.id is None:
        raise ValueError("Document must have an id")
    if bulk_writer is not None and link_rule != WriteRules.DO_NOTHING:
        raise NotSupported
    if link_rule == WriteRules.WRITE:
        link_fields = self.get_link_fields()
        if link_fields is not None:
            for field_info in link_fields.values():
                value = getattr(self, field_info.field_name)
                if field_info.link_type in [
                    LinkTypes.DIRECT,
                    LinkTypes.OPTIONAL_DIRECT,
                    LinkTypes.BACK_DIRECT,
                    LinkTypes.OPTIONAL_BACK_DIRECT,
                ]:
                    if isinstance(value, Document):
                        await value.replace(
                            link_rule=link_rule,
                            bulk_writer=bulk_writer,
                            ignore_revision=ignore_revision,
                            session=session,
                        )
                if field_info.link_type in [
                    LinkTypes.LIST,
                    LinkTypes.OPTIONAL_LIST,
                    LinkTypes.BACK_LIST,
                    LinkTypes.OPTIONAL_BACK_LIST,
                ]:
                    if isinstance(value, List):
                        await asyncio.gather(
                            *[
                                obj.replace(
                                    link_rule=link_rule,
                                    bulk_writer=bulk_writer,
                                    ignore_revision=ignore_revision,
                                    session=session,
                                )
                                for obj in value
                                if isinstance(obj, Document)
                            ]
                        )
    use_revision_id = self.get_settings().use_revision
    find_query: Dict[str, Any] = {"_id": self.id}
    # Match on the revision this instance last saw, then assign a new one
    # that the replacement writes. The new id is assigned before the write,
    # so it stays changed locally if the replace fails.
    if use_revision_id and not ignore_revision:
        find_query["revision_id"] = self.revision_id
        self.revision_id = uuid4()
    try:
        await self.find_one(find_query).replace_one(
            self,
            session=session,
            bulk_writer=bulk_writer,
        )
    except DocumentNotFound:
        # With a revision filter, "not found" means the revision moved on.
        if use_revision_id and not ignore_revision:
            raise RevisionIdWasChanged
        else:
            raise DocumentNotFound
    return self
@wrap_with_actions(EventTypes.SAVE)
@save_state_after
@validate_self_before
async def save(
    self: Self,
    session: Optional[AsyncClientSession] = None,
    link_rule: WriteRules = WriteRules.DO_NOTHING,
    ignore_revision: bool = False,
    **kwargs: Any,
) -> Self:
    """
    Upsert this document: update the stored copy or insert it if missing.

    Implemented as one ``update(..., upsert=True)`` that ``$set``s the full
    database form of the document. When ``keep_nulls`` is False, whatever
    ``get_top_level_nones`` reports is ``$unset`` in the same call.

    :param session: Optional[AsyncClientSession] - pymongo session.
    :param link_rule: WriteRules - with WriteRules.WRITE, Document instances
        in link fields (including back links) are saved first, recursively
    :param ignore_revision: bool - do force save.
    :param kwargs: forwarded to update()
    :return: self
    """
    if link_rule == WriteRules.WRITE:
        link_fields = self.get_link_fields()
        if link_fields is not None:
            single_link_types = [
                LinkTypes.DIRECT,
                LinkTypes.OPTIONAL_DIRECT,
                LinkTypes.BACK_DIRECT,
                LinkTypes.OPTIONAL_BACK_DIRECT,
            ]
            list_link_types = [
                LinkTypes.LIST,
                LinkTypes.OPTIONAL_LIST,
                LinkTypes.BACK_LIST,
                LinkTypes.OPTIONAL_BACK_LIST,
            ]
            for field_info in link_fields.values():
                linked = getattr(self, field_info.field_name)
                kind = field_info.link_type
                if kind in single_link_types and isinstance(linked, Document):
                    await linked.save(link_rule=link_rule, session=session)
                if kind in list_link_types and isinstance(linked, list):
                    await asyncio.gather(
                        *[
                            item.save(link_rule=link_rule, session=session)
                            for item in linked
                            if isinstance(item, Document)
                        ]
                    )
    keep_nulls = self.get_settings().keep_nulls
    operators = [
        SetOperator(get_dict(self, to_db=True, keep_nulls=keep_nulls))
    ]
    if keep_nulls is False:
        operators.append(Unset(get_top_level_nones(self)))
    return await self.update(
        *operators,
        session=session,
        ignore_revision=ignore_revision,
        upsert=True,
        **kwargs,
    )
@saved_state_needed
@wrap_with_actions(EventTypes.SAVE_CHANGES)
@validate_self_before
async def save_changes(
    self: Self,
    ignore_revision: bool = False,
    session: Optional[AsyncClientSession] = None,
    bulk_writer: Optional[BulkWriter] = None,
    skip_actions: Optional[List[Union[ActionDirections, str]]] = None,
) -> Optional[Self]:
    """
    Persist only the fields changed since the last saved state.

    State management must be turned on. Nothing is written, and None is
    returned, when the document is unchanged.

    :param ignore_revision: bool - ignore revision id, if revision is turned on
    :param session: Optional[AsyncClientSession] - pymongo session
    :param bulk_writer: "BulkWriter" - Beanie bulk writer
    :return: Optional[self]
    """
    if not self.is_changed:
        return None
    changed_fields = self.get_changes()
    if self.get_settings().keep_nulls is not False:
        return await self.set(
            changed_fields,
            ignore_revision=ignore_revision,
            session=session,
            bulk_writer=bulk_writer,
        )
    # keep_nulls is off: also $unset what get_top_level_nones reports
    # (presumably the top-level fields that are currently None).
    return await self.update(
        SetOperator(changed_fields),
        Unset(get_top_level_nones(self)),
        ignore_revision=ignore_revision,
        session=session,
        bulk_writer=bulk_writer,
    )
@classmethod
async def replace_many(
    cls: Type[DocType],
    documents: List[DocType],
    session: Optional[AsyncClientSession] = None,
) -> None:
    """
    Replace several existing documents through one bulk writer.

    The ids are checked up front: when the number of stored documents
    matching them differs from ``len(documents)``, nothing is written.
    Duplicate ids in ``documents`` also make the counts differ.

    :param documents: List["Document"] - documents to replace
    :param session: Optional[AsyncClientSession] - pymongo session.
    :return: None
    :raises ReplaceError: the id check failed
    """
    expected_ids = [doc.id for doc in documents]
    stored_count = await cls.find(In(cls.id, expected_ids)).count()
    if stored_count != len(expected_ids):
        raise ReplaceError(
            "Some of the documents are not exist in the collection"
        )
    async with BulkWriter(session=session) as writer:
        for doc in documents:
            await doc.replace(bulk_writer=writer, session=session)
@wrap_with_actions(EventTypes.UPDATE)
@save_state_after
async def update(
    self: Self,
    *args: Union[Dict[Any, Any], Mapping[Any, Any]],
    ignore_revision: bool = False,
    session: Optional[AsyncClientSession] = None,
    bulk_writer: Optional[BulkWriter] = None,
    skip_actions: Optional[List[Union[ActionDirections, str]]] = None,
    skip_sync: Optional[bool] = None,
    **pymongo_kwargs: Any,
) -> Self:
    """
    Partially update the document in the database

    The update runs through ``find_one(...).update`` asking for the new
    document, which is then merged into ``self``. When a bulk writer is
    given, the merge (and the revision check on the result) is skipped.

    :param args: *Union[dict, Mapping] - the modifications to apply.
    :param ignore_revision: bool - force update. Will update even if revision id is not the same, as stored
    :param session: AsyncClientSession - pymongo session.
    :param bulk_writer: "BulkWriter" - Beanie bulk writer
    :param skip_actions: not read in this body; presumably consumed by the
        wrap_with_actions decorator
    :param skip_sync: must be left as None; any other value raises
    :param pymongo_kwargs: pymongo native parameters for update operation
    :return: self
    :raises DeprecationWarning: skip_sync was passed
    :raises RevisionIdWasChanged: revision checking is active and either
        no document matched, or the upsert collided on a duplicate key
    :raises DuplicateKeyError: a duplicate-key error while revision
        checking is not active (e.g. a unique index violation) is
        propagated unchanged instead of being misreported as a revision
        conflict
    """
    if skip_sync is not None:
        raise DeprecationWarning(
            "skip_sync parameter is not supported. The document get synced always using atomic operation."
        )
    arguments: list[Any] = list(args)
    use_revision_id = self.get_settings().use_revision
    check_revision = use_revision_id and not ignore_revision
    if self.id is not None:
        find_query: Dict[str, Any] = {"_id": self.id}
    else:
        # No id yet: a fresh ObjectId matches nothing existing, so an
        # upsert (as issued by save) creates a new document.
        find_query = {"_id": PydanticObjectId()}
    if check_revision:
        find_query["revision_id"] = self.revision_id
    if use_revision_id:
        arguments.append(SetRevisionId(uuid4()))
    try:
        result = await self.find_one(find_query).update(
            *arguments,
            session=session,
            response_type=UpdateResponse.NEW_DOCUMENT,
            bulk_writer=bulk_writer,
            **pymongo_kwargs,
        )
    except DuplicateKeyError as err:
        # With a revision filter, an upsert that fails to match tries to
        # insert the same _id again; that collision means a stale revision.
        # Without a revision filter the error has another cause.
        if check_revision:
            raise RevisionIdWasChanged from err
        raise
    if bulk_writer is None:
        if check_revision and result is None:
            raise RevisionIdWasChanged
        merge_models(self, result)
    return self
@classmethod
def update_all(
    cls,
    *args: Union[dict, Mapping],
    session: Optional[AsyncClientSession] = None,
    bulk_writer: Optional[BulkWriter] = None,
    **pymongo_kwargs: Any,
) -> UpdateMany:
    """
    Partially update all the documents

    Not a coroutine: builds and returns the UpdateMany query from
    ``find_all()``; presumably the caller awaits it to run the update.

    :param args: *Union[dict, Mapping] - the modifications to apply.
    :param session: AsyncClientSession - pymongo session.
    :param bulk_writer: "BulkWriter" - Beanie bulk writer
    :param **pymongo_kwargs: pymongo native parameters for find operation
    :return: UpdateMany query
    """
    return cls.find_all().update_many(
        *args, session=session, bulk_writer=bulk_writer, **pymongo_kwargs
    )
def set(
    self: Self,
    expression: Dict[Union[ExpressionField, str, Any], Any],
    session: Optional[AsyncClientSession] = None,
    bulk_writer: Optional[BulkWriter] = None,
    skip_sync: Optional[bool] = None,
    **kwargs: Any,
) -> Coroutine[None, None, Self]:
    """
    Set values

    Not declared async: returns the coroutine from ``self.update`` with a
    Set operator, so the caller must await it.

    Example:

    ```python
    class Sample(Document):
        one: int

    sample = await Sample.find_one(Sample.one == 1)
    await sample.set({Sample.one: 100})
    ```

    Uses [Set operator](operators/update.md#set)

    :param expression: Dict[Union[ExpressionField, str, Any], Any] - keys and
    values to set
    :param session: Optional[AsyncClientSession] - pymongo session
    :param bulk_writer: Optional[BulkWriter] - bulk writer
    :param skip_sync: deprecated - update() raises DeprecationWarning for
        any value other than None
    :param kwargs: forwarded to update() (e.g. ignore_revision)
    :return: self
    """
    return self.update(
        SetOperator(expression),
        session=session,
        bulk_writer=bulk_writer,
        skip_sync=skip_sync,
        **kwargs,
    )
def current_date(
    self: Self,
    expression: Dict[Union[datetime, ExpressionField, str], Any],
    session: Optional[AsyncClientSession] = None,
    bulk_writer: Optional[BulkWriter] = None,
    skip_sync: Optional[bool] = None,
    **kwargs: Any,
) -> Coroutine[None, None, Self]:
    """
    Set current date

    Not declared async: returns the coroutine from ``self.update`` with a
    CurrentDate operator, so the caller must await it.

    Uses [CurrentDate operator](operators/update.md#currentdate)

    :param expression: Dict[Union[datetime, ExpressionField, str], Any]
    :param session: Optional[AsyncClientSession] - pymongo session
    :param bulk_writer: Optional[BulkWriter] - bulk writer
    :param skip_sync: deprecated - update() raises DeprecationWarning for
        any value other than None
    :param kwargs: forwarded to update()
    :return: self
    """
    return self.update(
        CurrentDate(expression),
        session=session,
        bulk_writer=bulk_writer,
        skip_sync=skip_sync,
        **kwargs,
    )
def inc(
    self: Self,
    expression: Dict[Union[ExpressionField, float, int, str], Any],
    session: Optional[AsyncClientSession] = None,
    bulk_writer: Optional[BulkWriter] = None,
    skip_sync: Optional[bool] = None,
    **kwargs: Any,
) -> Coroutine[None, None, Self]:
    """
    Increment

    Not declared async: returns the coroutine from ``self.update`` with an
    Inc operator, so the caller must await it.

    Example:

    ```python
    class Sample(Document):
        one: int

    sample = await Sample.find_one(Sample.one == 1)
    await sample.inc({Sample.one: 100})
    ```

    Uses [Inc operator](operators/update.md#inc)

    :param expression: Dict[Union[ExpressionField, float, int, str], Any]
    :param session: Optional[AsyncClientSession] - pymongo session
    :param bulk_writer: Optional[BulkWriter] - bulk writer
    :param skip_sync: deprecated - update() raises DeprecationWarning for
        any value other than None
    :param kwargs: forwarded to update()
    :return: self
    """
    return self.update(
        Inc(expression),
        session=session,
        bulk_writer=bulk_writer,
        skip_sync=skip_sync,
        **kwargs,
    )
@wrap_with_actions(EventTypes.DELETE)
async def delete(
    self,
    session: Optional[AsyncClientSession] = None,
    bulk_writer: Optional[BulkWriter] = None,
    link_rule: DeleteRules = DeleteRules.DO_NOTHING,
    skip_actions: Optional[List[Union[ActionDirections, str]]] = None,
    **pymongo_kwargs: Any,
) -> Optional[DeleteResult]:
    """
    Delete the document

    With ``DeleteRules.DELETE_LINKS``, Document instances held in link
    fields (including back links) are deleted first, recursively, using
    the same session as this delete.

    :param session: Optional[AsyncClientSession] - pymongo session.
    :param bulk_writer: "BulkWriter" - Beanie bulk writer
    :param link_rule: DeleteRules - rules for link fields
    :param skip_actions: not read in this body; presumably consumed by the
        wrap_with_actions decorator
    :param **pymongo_kwargs: pymongo native parameters for delete operation
    :return: Optional[DeleteResult] - pymongo DeleteResult instance.
    """
    if link_rule == DeleteRules.DELETE_LINKS:
        link_fields = self.get_link_fields()
        if link_fields is not None:
            for field_info in link_fields.values():
                value = getattr(self, field_info.field_name)
                if field_info.link_type in [
                    LinkTypes.DIRECT,
                    LinkTypes.OPTIONAL_DIRECT,
                    LinkTypes.BACK_DIRECT,
                    LinkTypes.OPTIONAL_BACK_DIRECT,
                ]:
                    if isinstance(value, Document):
                        # Forward the session so cascaded deletes run in the
                        # caller's session (and transaction, if any).
                        await value.delete(
                            link_rule=DeleteRules.DELETE_LINKS,
                            session=session,
                            **pymongo_kwargs,
                        )
                if field_info.link_type in [
                    LinkTypes.LIST,
                    LinkTypes.OPTIONAL_LIST,
                    LinkTypes.BACK_LIST,
                    LinkTypes.OPTIONAL_BACK_LIST,
                ]:
                    if isinstance(value, List):
                        await asyncio.gather(
                            *[
                                obj.delete(
                                    link_rule=DeleteRules.DELETE_LINKS,
                                    session=session,
                                    **pymongo_kwargs,
                                )
                                for obj in value
                                if isinstance(obj, Document)
                            ]
                        )
    return await self.find_one({"_id": self.id}).delete(
        session=session, bulk_writer=bulk_writer, **pymongo_kwargs
    )
@classmethod
async def delete_all(
    cls,
    session: Optional[AsyncClientSession] = None,
    bulk_writer: Optional[BulkWriter] = None,
    **pymongo_kwargs: Any,
) -> Optional[DeleteResult]:
    """
    Delete all the documents

    Deletes everything matched by ``find_all()``; no per-document link
    cascading or per-document actions happen here.

    :param session: Optional[AsyncClientSession] - pymongo session.
    :param bulk_writer: "BulkWriter" - Beanie bulk writer
    :param **pymongo_kwargs: pymongo native parameters for delete operation
    :return: Optional[DeleteResult] - pymongo DeleteResult instance.
    """
    return await cls.find_all().delete(
        session=session, bulk_writer=bulk_writer, **pymongo_kwargs
    )
# State management
@classmethod
def use_state_management(cls) -> bool:
    """
    Is state management turned on

    Reads ``use_state_management`` from the document settings.

    :return: bool
    """
    return cls.get_settings().use_state_management
@classmethod
def state_management_save_previous(cls) -> bool:
    """
    Should we save the previous state after a commit to database

    Reads ``state_management_save_previous`` from the document settings;
    when True, _save_state keeps the prior snapshot in
    ``_previous_saved_state``.

    :return: bool
    """
    return cls.get_settings().state_management_save_previous
@classmethod
def state_management_replace_objects(cls) -> bool:
    """
    Should objects be replaced when using state management

    When True, a changed nested dict is reported as one whole-field
    change instead of being diffed key by key.

    :return: bool
    """
    return cls.get_settings().state_management_replace_objects
def _save_state(self) -> None:
    """
    Snapshot the database form of this document. Internal method

    Does nothing unless state management is on and the document has an
    id. When previous-state saving is enabled, the old snapshot is kept
    in ``_previous_saved_state`` first. ``revision_id`` is excluded.

    :return: None
    """
    if not self.use_state_management() or self.id is None:
        return
    if self.state_management_save_previous():
        self._previous_saved_state = self._saved_state
    self._saved_state = get_dict(
        self,
        to_db=True,
        keep_nulls=self.get_settings().keep_nulls,
        exclude={"revision_id"},
    )
def get_saved_state(self) -> Optional[Dict[str, Any]]:
    """
    Saved state getter. It is protected property.

    Returns the snapshot stored by _save_state, or None if none was taken.

    :return: Optional[Dict[str, Any]] - saved state
    """
    return self._saved_state
def get_previous_saved_state(self) -> Optional[Dict[str, Any]]:
    """
    Previous state getter. It is a protected property.

    Returns the snapshot that preceded the current saved state (only kept
    when state_management_save_previous is on), or None.

    :return: Optional[Dict[str, Any]] - previous state
    """
    return self._previous_saved_state
@property
@saved_state_needed
def is_changed(self) -> bool:
    """
    True when the document differs from its last saved state.

    Both sides are compared in database form, without ``revision_id``.
    """
    current_state = get_dict(
        self,
        to_db=True,
        keep_nulls=self.get_settings().keep_nulls,
        exclude={"revision_id"},
    )
    return self._saved_state != current_state
@property
@saved_state_needed
@previous_saved_state_needed
def has_changed(self) -> bool:
    """
    True when the last save changed the stored state.

    Compares the previous snapshot with the current one; False when there
    is no previous snapshot.
    """
    previous = self._previous_saved_state
    return previous is not None and previous != self._saved_state
def _collect_updates(
    self, old_dict: Dict[str, Any], new_dict: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Compares old_dict with new_dict and returns field paths that have been updated

    - If ``old_dict`` has keys missing from ``new_dict``, ``new_dict`` itself
      is returned: field-level paths cannot express removals.
    - Otherwise each changed field is reported under its name. When both
      sides are dicts and ``state_management_replace_objects`` is off, the
      nested dict is diffed recursively and reported with dotted paths
      (``"field.sub"``). If keys were removed from the nested dict, the
      whole field is reported instead, so a dotted ``$set`` does not leave
      the removed keys behind.

    Args:
        old_dict: dict1
        new_dict: dict2

    Returns: dictionary with updates
    """
    if old_dict.keys() - new_dict.keys():
        return new_dict
    updates: Dict[str, Any] = {}
    for field_name, field_value in new_dict.items():
        old_value = old_dict.get(field_name)
        if field_value == old_value:
            continue
        if (
            not self.state_management_replace_objects()
            and isinstance(field_value, dict)
            and isinstance(old_value, dict)
            and not (old_value.keys() - field_value.keys())
        ):
            nested = self._collect_updates(old_value, field_value)
            for sub_path, sub_value in nested.items():
                updates[f"{field_name}.{sub_path}"] = sub_value
        else:
            updates[field_name] = field_value
    return updates
@saved_state_needed
def get_changes(self) -> Dict[str, Any]:
    """
    Field paths changed since the last saved state, with their new values.

    The current document is taken in database form, without
    ``revision_id``, and diffed against the saved snapshot.
    """
    saved_snapshot = self._saved_state
    current_state = get_dict(
        self,
        to_db=True,
        keep_nulls=self.get_settings().keep_nulls,
        exclude={"revision_id"},
    )
    return self._collect_updates(saved_snapshot, current_state)  # type: ignore
@saved_state_needed
@previous_saved_state_needed
def get_previous_changes(self) -> Dict[str, Any]:
    """
    Field paths changed by the last save (previous snapshot -> current).

    Returns an empty dict when no previous snapshot exists.
    """
    previous_snapshot = self._previous_saved_state
    if previous_snapshot is None:
        return {}
    return self._collect_updates(previous_snapshot, self._saved_state)  # type: ignore
@saved_state_needed
def rollback(self) -> None:
    """
    Restore attributes from the last saved snapshot if anything changed.

    Snapshot keys are written back with setattr; the ``_id`` key is
    written to the ``id`` attribute.
    """
    if not self.is_changed:
        return
    for key, value in self._saved_state.items():  # type: ignore
        attribute = "id" if key == "_id" else key
        setattr(self, attribute, value)
# Other
@classmethod
def get_settings(cls) -> DocumentSettings:
    """
    Return the settings attached to this document class at initialization.

    :return: DocumentSettings class
    :raises CollectionWasNotInitialized: no settings attached yet
    """
    settings = cls._document_settings
    if settings is None:
        raise CollectionWasNotInitialized
    return settings
@classmethod
async def inspect_collection(
    cls, session: Optional[AsyncClientSession] = None
) -> InspectionResult:
    """
    Check that every stored document parses with the current schema.

    Each raw document in the collection is parsed with ``parse_model``.
    Every ValidationError is recorded (document id plus error text), and
    the result status switches from OK to FAIL on the first failure.

    :param session: Optional[AsyncClientSession] - pymongo session
        The session instance used for transactional operations. Defaults to None.
    :return: InspectionResult
    """
    report = InspectionResult()
    cursor = cls.get_pymongo_collection().find({}, session=session)
    async for raw_document in cursor:
        try:
            parse_model(cls, raw_document)
        except ValidationError as exc:
            if report.status == InspectionStatuses.OK:
                report.status = InspectionStatuses.FAIL
            report.errors.append(
                InspectionError(
                    document_id=raw_document["_id"], error=str(exc)
                )
            )
    return report
@classmethod
def _check_hidden_fields(cls):
    """
    Translate the deprecated ``hidden=True`` field marker into ``exclude=True``.

    If any field carries ``hidden`` set to True in its extra field info, a
    DeprecationWarning is emitted. Each such field then gets ``exclude`` set
    and its ``hidden`` key removed: from ``json_schema_extra`` on pydantic v2,
    from ``field_info.extra`` on v1, where the field is also added to
    ``__exclude_fields__``.
    """
    hidden_fields = [
        (name, field)
        for name, field in get_model_fields(cls).items()
        if get_extra_field_info(field, "hidden") is True
    ]
    if not hidden_fields:
        return
    warnings.warn(
        f"{cls.__name__}: 'hidden=True' is deprecated, please use 'exclude=True'",
        DeprecationWarning,
        stacklevel=2,
    )
    if IS_PYDANTIC_V2:
        for name, field in hidden_fields:
            field.exclude = True
            del field.json_schema_extra["hidden"]
        # Presumably needed so pydantic picks up the mutated field settings.
        cls.model_rebuild(force=True)
    else:
        for name, field in hidden_fields:
            field.field_info.exclude = True
            del field.field_info.extra["hidden"]
            cls.__exclude_fields__[name] = True
@wrap_with_actions(event_type=EventTypes.VALIDATE_ON_SAVE)
async def validate_self(self, *args: Any, **kwargs: Any):
    """
    Re-validate this document when the ``validate_on_save`` setting is on.

    The current values are dumped, parsed into a new instance of the same
    class (parse_model raises on invalid data), and merged back into self.
    Extra positional/keyword arguments are not read by this body.
    """
    # TODO: it can be sync, but needs some actions controller improvements
    if self.get_settings().validate_on_save:
        new_model = parse_model(self.__class__, get_model_dump(self))
        merge_models(self, new_model)
def to_ref(self):
    """
    Return a DBRef to this document (collection name + id).

    :raises DocumentWasNotSaved: the document has no id yet
    """
    if self.id is None:
        raise DocumentWasNotSaved("Can not create dbref without id")
    return DBRef(self.get_pymongo_collection().name, self.id)
async def fetch_link(self, field: Union[str, Any]):
    """
    Resolve one link attribute in place.

    If the attribute holds a Link, it is fetched (with fetch_links=True)
    and the result is assigned back. If it holds a non-empty list,
    ``Link.fetch_list`` resolves it and the result replaces the list.
    Other values, including a missing attribute, are left unchanged.

    :param field: name of the attribute to resolve
    """
    ref_obj = getattr(self, field, None)
    if isinstance(ref_obj, Link):
        value = await ref_obj.fetch(fetch_links=True)
        setattr(self, field, value)
    if isinstance(ref_obj, list) and ref_obj:
        values = await Link.fetch_list(ref_obj, fetch_links=True)
        setattr(self, field, values)
async def fetch_all_links(self):
    """
    Resolve every registered link field concurrently via fetch_link.
    """
    link_fields = self.get_link_fields()
    pending_fetches = (
        []
        if link_fields is None
        else [
            self.fetch_link(link_info.field_name)  # TODO lists
            for link_info in link_fields.values()
        ]
    )
    await asyncio.gather(*pending_fetches)
@classmethod
def get_link_fields(cls) -> Optional[Dict[str, LinkInfo]]:
    """Return the class's link metadata keyed by field name, or None."""
    return cls._link_fields
@classmethod
def get_model_type(cls) -> ModelType:
    """Identify this model kind as ``ModelType.Document``."""
    return ModelType.Document
@classmethod
async def distinct(
    cls,
    key: str,
    filter: Optional[Mapping[str, Any]] = None,
    session: Optional[AsyncClientSession] = None,
    **kwargs: Any,
) -> list:
    """
    Return the distinct values of ``key`` in the collection.

    Thin wrapper over the pymongo collection's ``distinct``. It queries the
    raw collection, so filters added by this class's find methods (such as
    class-id or soft-delete filters in subclasses) are not applied.

    :param key: str - field whose distinct values are collected
    :param filter: Optional[Mapping[str, Any]] - query limiting the documents considered
    :param session: Optional[AsyncClientSession] - pymongo session
    :param kwargs: forwarded to pymongo's distinct
    :return: list of distinct values
    """
    return await cls.get_pymongo_collection().distinct(
        key=key, filter=filter, session=session, **kwargs
    )
@classmethod
def link_from_id(cls, id: Any):
    """
    Build an (unfetched) Link to the document of this class with ``id``.

    No database access happens here; a DBRef into this class's collection
    is wrapped in a Link.
    """
    ref = DBRef(id=id, collection=cls.get_collection_name())
    return Link(ref, document_class=cls)
@classmethod
def bulk_writer(
    cls,
    session: Optional[AsyncClientSession] = None,
    ordered: bool = True,
    bypass_document_validation: Optional[bool] = False,
    comment: Optional[Any] = None,
) -> BulkWriter:
    """
    Returns a BulkWriter instance for handling bulk write operations.

    :param session: Optional[AsyncClientSession] - pymongo session.
        The session instance used for transactional operations.
    :param ordered: bool
        If ``True`` (the default), requests will be performed on the server serially, in the order provided. If an error
        occurs, all remaining operations are aborted. If ``False``, requests will be performed on the server in
        arbitrary order, possibly in parallel, and all operations will be attempted.
    :param bypass_document_validation: bool, optional
        If ``True``, allows the write to opt-out of document-level validation. Default is ``False``.
    :param comment: Any, optional
        A user-provided comment to attach to the BulkWriter.
    :returns: BulkWriter
        An instance of BulkWriter configured with the provided settings.

    Example Usage:
    --------------
    This method is typically used within an asynchronous context manager.

    .. code-block:: python

        async with Document.bulk_writer(ordered=True) as bulk:
            await Document.insert_one(Document(field="value"), bulk_writer=bulk)
    """
    # Arguments are passed positionally; this order must match BulkWriter's
    # constructor (session, ordered, document class, bypass, comment).
    return BulkWriter(
        session, ordered, cls, bypass_document_validation, comment
    )
class DocumentWithSoftDelete(Document):
    """
    Document whose ``delete`` marks the record instead of removing it.

    ``delete`` stamps ``deleted_at`` and saves; ``hard_delete`` goes through
    ``Document.delete``. ``find_one``/``find_many`` only match records whose
    ``deleted_at`` is None; ``find_many_in_all`` does not add that filter.
    """

    # None while the document is live; set to a UTC timestamp by delete().
    deleted_at: Optional[datetime] = None
def is_deleted(self) -> bool:
    """True if this instance carries a ``deleted_at`` timestamp."""
    return self.deleted_at is not None
async def hard_delete(
    self,
    session: Optional[AsyncClientSession] = None,
    bulk_writer: Optional[BulkWriter] = None,
    link_rule: DeleteRules = DeleteRules.DO_NOTHING,
    skip_actions: Optional[List[Union[ActionDirections, str]]] = None,
    **pymongo_kwargs: Any,
) -> Optional[DeleteResult]:
    """
    Physically remove the document, bypassing the soft-delete override.

    All arguments are forwarded unchanged to ``Document.delete``.

    NOTE(review): ``Document.delete`` builds its query with
    ``self.find_one({"_id": self.id})``, and ``find_one`` in this class
    adds a ``{"deleted_at": None}`` filter. A document that is already
    soft-deleted in the database may therefore not be matched - confirm.

    :return: Optional[DeleteResult] - whatever Document.delete returns
    """
    return await super().delete(
        session=session,
        bulk_writer=bulk_writer,
        link_rule=link_rule,
        skip_actions=skip_actions,
        **pymongo_kwargs,
    )
async def delete(
    self,
    session: Optional[AsyncClientSession] = None,
    bulk_writer: Optional[BulkWriter] = None,
    link_rule: DeleteRules = DeleteRules.DO_NOTHING,
    skip_actions: Optional[List[Union[ActionDirections, str]]] = None,
    **pymongo_kwargs: Any,
) -> Optional[DeleteResult]:
    """
    Soft-delete: stamp ``deleted_at`` with the current UTC time and save.

    The record stays in the collection. The session is forwarded to
    ``save`` so the write runs in the caller's session (and transaction,
    if any).

    :param session: Optional[AsyncClientSession] - pymongo session
    :param bulk_writer: accepted for signature compatibility; not used
    :param link_rule: accepted for signature compatibility; not used
    :param skip_actions: accepted for signature compatibility; not used
    :param pymongo_kwargs: accepted for signature compatibility; not used
    :return: None - nothing is removed, so there is no DeleteResult
    """
    self.deleted_at = datetime.now(tz=timezone.utc)
    await self.save(session=session)
    return None
@classmethod
def find_many_in_all(  # type: ignore
    cls: Type[FindType],
    *args: Union[Mapping[Any, Any], bool],
    projection_model: Optional[Type["DocumentProjectionType"]] = None,
    skip: Optional[int] = None,
    limit: Optional[int] = None,
    sort: Union[None, str, List[Tuple[str, SortDirection]]] = None,
    session: Optional[AsyncClientSession] = None,
    ignore_cache: bool = False,
    fetch_links: bool = False,
    with_children: bool = False,
    lazy_parse: bool = False,
    nesting_depth: Optional[int] = None,
    nesting_depths_per_field: Optional[Dict[str, int]] = None,
    **pymongo_kwargs: Any,
) -> Union[FindMany[FindType], FindMany["DocumentProjectionType"]]:
    """
    Like ``find_many`` but without the ``deleted_at`` filter.

    Soft-deleted documents are included. The class-id filter is applied
    exactly as in ``find_many``, so ``with_children`` is honoured (it was
    previously accepted but ignored).
    """
    args = cls._add_class_id_filter(args, with_children)
    return cls._find_many_query_class(document_model=cls).find_many(
        *args,
        sort=sort,
        skip=skip,
        limit=limit,
        projection_model=projection_model,
        session=session,
        ignore_cache=ignore_cache,
        fetch_links=fetch_links,
        lazy_parse=lazy_parse,
        nesting_depth=nesting_depth,
        nesting_depths_per_field=nesting_depths_per_field,
        **pymongo_kwargs,
    )
@classmethod
def find_many(  # type: ignore
    cls: Type[FindType],
    *args: Union[Mapping[Any, Any], bool],
    projection_model: Optional[Type["DocumentProjectionType"]] = None,
    skip: Optional[int] = None,
    limit: Optional[int] = None,
    sort: Union[None, str, List[Tuple[str, SortDirection]]] = None,
    session: Optional[AsyncClientSession] = None,
    ignore_cache: bool = False,
    fetch_links: bool = False,
    with_children: bool = False,
    lazy_parse: bool = False,
    nesting_depth: Optional[int] = None,
    nesting_depths_per_field: Optional[Dict[str, int]] = None,
    **pymongo_kwargs: Any,
) -> Union[FindMany[FindType], FindMany["DocumentProjectionType"]]:
    """
    Find documents that have not been soft-deleted.

    The class-id filter (per ``with_children``) and ``{"deleted_at": None}``
    are appended to ``args`` before delegating to the find-many query.
    """
    filters = (
        *cls._add_class_id_filter(args, with_children),
        {"deleted_at": None},
    )
    query_builder = cls._find_many_query_class(document_model=cls)
    return query_builder.find_many(
        *filters,
        sort=sort,
        skip=skip,
        limit=limit,
        projection_model=projection_model,
        session=session,
        ignore_cache=ignore_cache,
        fetch_links=fetch_links,
        lazy_parse=lazy_parse,
        nesting_depth=nesting_depth,
        nesting_depths_per_field=nesting_depths_per_field,
        **pymongo_kwargs,
    )
@classmethod
def find_one(  # type: ignore
    cls: Type[FindType],
    *args: Union[Mapping[Any, Any], bool],
    projection_model: Optional[Type["DocumentProjectionType"]] = None,
    session: Optional[AsyncClientSession] = None,
    ignore_cache: bool = False,
    fetch_links: bool = False,
    with_children: bool = False,
    nesting_depth: Optional[int] = None,
    nesting_depths_per_field: Optional[Dict[str, int]] = None,
    **pymongo_kwargs: Any,
) -> Union[FindOne[FindType], FindOne["DocumentProjectionType"]]:
    """
    Find one document that has not been soft-deleted.

    The class-id filter (per ``with_children``) and ``{"deleted_at": None}``
    are appended to ``args`` before delegating to the find-one query.
    """
    filters = (
        *cls._add_class_id_filter(args, with_children),
        {"deleted_at": None},
    )
    query_builder = cls._find_one_query_class(document_model=cls)
    return query_builder.find_one(
        *filters,
        projection_model=projection_model,
        session=session,
        ignore_cache=ignore_cache,
        fetch_links=fetch_links,
        nesting_depth=nesting_depth,
        nesting_depths_per_field=nesting_depths_per_field,
        **pymongo_kwargs,
    )