367 lines
12 KiB
Python
367 lines
12 KiB
Python
from abc import abstractmethod
|
|
from enum import Enum
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
Generator,
|
|
List,
|
|
Mapping,
|
|
Optional,
|
|
Type,
|
|
Union,
|
|
)
|
|
|
|
from pymongo import ReturnDocument
|
|
from pymongo import UpdateMany as UpdateManyPyMongo
|
|
from pymongo import UpdateOne as UpdateOnePyMongo
|
|
from pymongo.asynchronous.client_session import AsyncClientSession
|
|
from pymongo.results import InsertOneResult, UpdateResult
|
|
|
|
from beanie.odm.bulk import BulkWriter
|
|
from beanie.odm.interfaces.clone import CloneInterface
|
|
from beanie.odm.interfaces.session import SessionMethods
|
|
from beanie.odm.interfaces.update import (
|
|
UpdateMethods,
|
|
)
|
|
from beanie.odm.operators.update import BaseUpdateOperator
|
|
from beanie.odm.operators.update.general import SetRevisionId
|
|
from beanie.odm.utils.encoder import Encoder
|
|
from beanie.odm.utils.parsing import parse_obj
|
|
|
|
if TYPE_CHECKING:
|
|
from beanie.odm.documents import DocType
|
|
|
|
|
|
class UpdateResponse(str, Enum):
    """
    Selects what an awaited update query returns.

    Because of the ``str`` mixin, each member compares equal to its
    literal string value.
    """

    # the raw PyMongo update result
    UPDATE_RESULT = "UPDATE_RESULT"
    # the matched document as it was before the update
    OLD_DOCUMENT = "OLD_DOCUMENT"
    # the matched document as it is after the update
    NEW_DOCUMENT = "NEW_DOCUMENT"
|
|
|
|
|
|
class UpdateQuery(UpdateMethods, SessionMethods, CloneInterface):
    """
    Update Query base class.

    Accumulates update expressions plus the options needed to run them
    (session, bulk writer, upsert document, extra pymongo kwargs).
    Subclasses implement ``_update`` to execute or bulk-queue the update.
    """

    def __init__(
        self,
        document_model: Type["DocType"],
        find_query: Mapping[str, Any],
    ):
        """
        :param document_model: Type[DocType] - document class whose
            collection is updated
        :param find_query: Mapping[str, Any] - filter selecting the
            documents to update
        """
        self.document_model = document_model
        self.find_query = find_query
        self.update_expressions: List[Mapping[str, Any]] = []
        self.session = None
        self.is_upsert = False
        self.upsert_insert_doc: Optional["DocType"] = None
        self.bulk_writer: Optional[BulkWriter] = None
        # Custom BSON encoders from the model settings; used when rendering
        # ``update_query``.
        self.encoders: Dict[Any, Callable[[Any], Any]] = (
            self.document_model.get_settings().bson_encoders
        )
        self.pymongo_kwargs: Dict[str, Any] = {}

    @property
    def update_query(self) -> Dict[str, Any]:
        """
        Merge all collected update expressions into one encoded update.

        Supported expression types:

        - ``BaseUpdateOperator`` and plain ``dict``: merged into one update
          document with ``dict.update``. A later expression replaces an
          earlier one's top-level key (e.g. a second ``$set``).
        - ``SetRevisionId``: its ``$set`` entries are merged into any
          existing ``$set`` instead of replacing it.
        - ``list``: pipeline stages, concatenated in order.

        :return: the merged update (a dict, or a list for pipeline
            updates) encoded with the model's custom BSON encoders. With no
            expressions, the encoded value of ``None`` is returned.
        :raises TypeError: if list and non-list expressions are mixed, or an
            expression has an unsupported type.
        """
        query: Union[Dict[str, Any], List[Dict[str, Any]], None] = None
        for expression in self.update_expressions:
            if isinstance(expression, list):
                # Pipeline-style update: cannot follow a dict-style one.
                if query is None:
                    query = []
                elif isinstance(query, dict):
                    raise TypeError("Wrong expression type")
                query.extend(expression)
                continue

            if not isinstance(
                expression, (BaseUpdateOperator, dict, SetRevisionId)
            ):
                raise TypeError("Wrong expression type")

            # Dict-style update: cannot follow a pipeline-style one.
            if query is None:
                query = {}
            elif isinstance(query, list):
                raise TypeError("Wrong expression type")

            if isinstance(expression, BaseUpdateOperator):
                query.update(expression.query)
            elif isinstance(expression, dict):
                query.update(expression)
            else:
                # SetRevisionId. Build a NEW "$set" dict rather than updating
                # the existing one in place: the existing "$set" may be the
                # very mapping the caller passed in (dict.update above only
                # copies references), and it must not gain "revision_id".
                query["$set"] = {
                    **query.get("$set", {}),
                    **expression.query.get("$set", {}),
                }

        return Encoder(custom_encoders=self.encoders).encode(query)

    @abstractmethod
    async def _update(self) -> UpdateResult:
        """Run the update, or queue it on ``self.bulk_writer``."""
|
|
|
|
|
|
class UpdateMany(UpdateQuery):
    """
    Update Many query class.

    Applies the collected update expressions to every document matching
    ``find_query``. Awaiting the query runs it. When ``upsert`` was used and
    the update matched nothing, the ``on_insert`` document is inserted
    instead.
    """

    def update(
        self,
        *args: Mapping[str, Any],
        session: Optional[AsyncClientSession] = None,
        bulk_writer: Optional[BulkWriter] = None,
        **pymongo_kwargs: Any,
    ) -> "UpdateQuery":
        """
        Add modifications to the update query.

        :param args: *Union[dict, Mapping] - modifications to append
        :param session: Optional[AsyncClientSession] - pymongo session
        :param bulk_writer: Optional[BulkWriter] - when set, the update is
            queued on this writer instead of being executed directly
        :param pymongo_kwargs: pymongo native parameters for update operation
        :return: this same UpdateMany query, for chaining
        """
        self.set_session(session=session)
        self.update_expressions.extend(args)
        if bulk_writer:
            self.bulk_writer = bulk_writer
        self.pymongo_kwargs.update(pymongo_kwargs)
        return self

    def upsert(
        self,
        *args: Mapping[str, Any],
        on_insert: "DocType",
        session: Optional[AsyncClientSession] = None,
        **pymongo_kwargs: Any,
    ) -> "UpdateQuery":
        """
        Add modifications and a fallback document for an upsert.

        :param args: *Union[dict, Mapping] - modifications to append
        :param on_insert: DocType - document inserted when the update
            matches no document in the collection
        :param session: Optional[AsyncClientSession] - pymongo session
        :param pymongo_kwargs: pymongo native parameters for update operation
        :return: this same UpdateMany query, for chaining
        """
        self.upsert_insert_doc = on_insert  # type: ignore
        self.update(*args, session=session, **pymongo_kwargs)
        return self

    def update_many(
        self,
        *args: Mapping[str, Any],
        session: Optional[AsyncClientSession] = None,
        bulk_writer: Optional[BulkWriter] = None,
        **pymongo_kwargs: Any,
    ):
        """
        Alias of ``update``: add modifications to the update query.

        :param args: *Union[dict, Mapping] - modifications to append
        :param session: Optional[AsyncClientSession] - pymongo session
        :param bulk_writer: "BulkWriter" - Beanie bulk writer
        :param pymongo_kwargs: pymongo native parameters for update operation
        :return: this same UpdateMany query, for chaining
        """
        return self.update(
            *args,
            session=session,
            bulk_writer=bulk_writer,
            **pymongo_kwargs,
        )

    async def _update(self):
        # Bulk mode: only queue the operation; there is no result yet.
        if self.bulk_writer is not None:
            self.bulk_writer.add_operation(
                self.document_model,
                UpdateManyPyMongo(
                    self.find_query,
                    self.update_query,
                    **self.pymongo_kwargs,
                ),
            )
            return None

        collection = self.document_model.get_pymongo_collection()
        return await collection.update_many(
            self.find_query,
            self.update_query,
            session=self.session,
            **self.pymongo_kwargs,
        )

    def __await__(
        self,
    ) -> Generator[
        Any, None, Union[UpdateResult, InsertOneResult, Optional["DocType"]]
    ]:
        """
        Run the query.

        :return: the update result, or the insert result when an upsert
            matched no documents
        """
        update_result = yield from self._update().__await__()

        # Insert only for an upsert whose update actually ran and matched
        # nothing; in bulk mode ``update_result`` is None.
        needs_insert = (
            self.upsert_insert_doc is not None
            and update_result is not None
            and update_result.matched_count == 0
        )
        if not needs_insert:
            return update_result

        insert_awaitable = self.document_model.insert_one(
            document=self.upsert_insert_doc,
            session=self.session,
            bulk_writer=self.bulk_writer,
        )
        return (yield from insert_awaitable.__await__())
|
|
|
|
|
|
class UpdateOne(UpdateQuery):
    """
    Update One query class.

    Updates a single document matching ``find_query``. Depending on
    ``response_type``, awaiting the query returns the PyMongo
    ``UpdateResult`` or the matched document (parsed into
    ``document_model``) as it was before or after the update.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # By default, awaiting the query returns the PyMongo update result.
        self.response_type = UpdateResponse.UPDATE_RESULT

    def update(
        self,
        *args: Mapping[str, Any],
        session: Optional[AsyncClientSession] = None,
        bulk_writer: Optional[BulkWriter] = None,
        response_type: Optional[UpdateResponse] = None,
        **pymongo_kwargs: Any,
    ) -> "UpdateQuery":
        """
        Provide modifications to the update query.

        :param args: *Union[dict, Mapping] - the modifications to apply.
        :param session: Optional[AsyncClientSession] - pymongo session
        :param bulk_writer: Optional[BulkWriter] - when given, the update is
            queued on this writer instead of being executed directly
        :param response_type: Optional[UpdateResponse] - what awaiting the
            query returns; left unchanged when None
        :param pymongo_kwargs: pymongo native parameters for update operation
        :return: this UpdateOne query, for chaining
        """
        self.set_session(session=session)
        self.update_expressions += args
        if response_type is not None:
            self.response_type = response_type
        if bulk_writer is not None:
            self.bulk_writer = bulk_writer
        self.pymongo_kwargs.update(pymongo_kwargs)
        return self

    def upsert(
        self,
        *args: Mapping[str, Any],
        on_insert: "DocType",
        session: Optional[AsyncClientSession] = None,
        response_type: Optional[UpdateResponse] = None,
        **pymongo_kwargs: Any,
    ) -> "UpdateQuery":
        """
        Provide modifications to the upsert query.

        :param args: *Union[dict, Mapping] - the modifications to apply.
        :param on_insert: DocType - document to insert if there is no matched
            document in the collection
        :param session: Optional[AsyncClientSession] - pymongo session
        :param response_type: Optional[UpdateResponse]
        :param pymongo_kwargs: pymongo native parameters for update operation
        :return: this UpdateOne query, for chaining
        """
        self.upsert_insert_doc = on_insert  # type: ignore
        self.update(
            *args,
            response_type=response_type,
            session=session,
            **pymongo_kwargs,
        )
        return self

    def update_one(
        self,
        *args: Mapping[str, Any],
        session: Optional[AsyncClientSession] = None,
        bulk_writer: Optional[BulkWriter] = None,
        response_type: Optional[UpdateResponse] = None,
        **pymongo_kwargs: Any,
    ) -> "UpdateQuery":
        """
        Provide modifications to the update query. The same as `update()`

        :param args: *Union[dict, Mapping] - the modifications to apply.
        :param session: Optional[AsyncClientSession] - pymongo session
        :param bulk_writer: "BulkWriter" - Beanie bulk writer
        :param response_type: Optional[UpdateResponse]
        :param pymongo_kwargs: pymongo native parameters for update operation
        :return: this UpdateOne query, for chaining
        """
        return self.update(
            *args,
            session=session,
            bulk_writer=bulk_writer,
            response_type=response_type,
            **pymongo_kwargs,
        )

    async def _update(self):
        """
        Execute the update, or queue it on the bulk writer.

        :return: None in bulk mode; the PyMongo ``UpdateResult`` for
            ``UPDATE_RESULT``; otherwise the parsed document before/after
            the update, or None if no document matched.
        """
        if self.bulk_writer is not None:
            # Bulk mode ignores response_type: only an UpdateOne op is queued.
            self.bulk_writer.add_operation(
                self.document_model,
                UpdateOnePyMongo(
                    self.find_query, self.update_query, **self.pymongo_kwargs
                ),
            )
            return None

        collection = self.document_model.get_pymongo_collection()
        if self.response_type == UpdateResponse.UPDATE_RESULT:
            return await collection.update_one(
                self.find_query,
                self.update_query,
                session=self.session,
                **self.pymongo_kwargs,
            )

        return_document = (
            ReturnDocument.BEFORE
            if self.response_type == UpdateResponse.OLD_DOCUMENT
            else ReturnDocument.AFTER
        )
        raw = await collection.find_one_and_update(
            self.find_query,
            self.update_query,
            session=self.session,
            return_document=return_document,
            **self.pymongo_kwargs,
        )
        if raw is None:
            return None
        return parse_obj(self.document_model, raw)

    def __await__(
        self,
    ) -> Generator[
        Any, None, Union[UpdateResult, InsertOneResult, Optional["DocType"]]
    ]:
        """
        Run the query.

        :return: the value produced by ``_update``, or the result of
            inserting the upsert document when nothing matched
        """
        update_result = yield from self._update().__await__()
        if self.upsert_insert_doc is None:
            return update_result

        if self.response_type == UpdateResponse.UPDATE_RESULT:
            nothing_matched = (
                update_result is not None and update_result.matched_count == 0
            )
        else:
            # _update returns None when find_one_and_update found nothing.
            # NOTE(review): _update also returns None in bulk mode, so with a
            # bulk writer and a document response type the insert below is
            # always queued next to the update — confirm this is intended.
            nothing_matched = update_result is None

        if not nothing_matched:
            return update_result

        return (
            yield from self.document_model.insert_one(
                document=self.upsert_insert_doc,
                session=self.session,
                bulk_writer=self.bulk_writer,
            ).__await__()
        )
|