from abc import abstractmethod
from enum import Enum
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Generator,
    List,
    Mapping,
    Optional,
    Type,
    Union,
)

from pymongo import ReturnDocument
from pymongo import UpdateMany as UpdateManyPyMongo
from pymongo import UpdateOne as UpdateOnePyMongo
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.results import InsertOneResult, UpdateResult

from beanie.odm.bulk import BulkWriter
from beanie.odm.interfaces.clone import CloneInterface
from beanie.odm.interfaces.session import SessionMethods
from beanie.odm.interfaces.update import (
    UpdateMethods,
)
from beanie.odm.operators.update import BaseUpdateOperator
from beanie.odm.operators.update.general import SetRevisionId
from beanie.odm.utils.encoder import Encoder
from beanie.odm.utils.parsing import parse_obj

if TYPE_CHECKING:
    from beanie.odm.documents import DocType


class UpdateResponse(str, Enum):
    """
    Selects what an awaited ``UpdateOne`` query returns.
    """

    UPDATE_RESULT = "UPDATE_RESULT"  # PyMongo update result
    OLD_DOCUMENT = "OLD_DOCUMENT"  # Original document
    NEW_DOCUMENT = "NEW_DOCUMENT"  # Updated document


class UpdateQuery(UpdateMethods, SessionMethods, CloneInterface):
    """
    Update Query base class.

    Accumulates update expressions, pymongo keyword arguments, an optional
    session, an optional bulk writer and an optional upsert document.
    Subclasses implement ``_update`` to send the accumulated update.
    """

    def __init__(
        self,
        document_model: Type["DocType"],
        find_query: Mapping[str, Any],
    ):
        """
        :param document_model: Type[DocType] - document class whose
            collection is updated
        :param find_query: Mapping[str, Any] - filter passed to pymongo to
            select the documents to update
        """
        self.document_model = document_model
        self.find_query = find_query
        self.update_expressions: List[Mapping[str, Any]] = []
        self.session = None
        self.is_upsert = False
        # Inserted by ``__await__`` when the update matched nothing.
        self.upsert_insert_doc: Optional["DocType"] = None
        # When set, ``_update`` queues the operation instead of running it.
        self.bulk_writer: Optional[BulkWriter] = None
        # Custom BSON encoders from the model settings (used by update_query).
        self.encoders: Dict[Any, Callable[[Any], Any]] = (
            self.document_model.get_settings().bson_encoders
        )
        self.pymongo_kwargs: Dict[str, Any] = {}

    @property
    def update_query(self) -> Dict[str, Any]:
        """
        Merge every collected update expression into one encoded update.

        - ``BaseUpdateOperator`` instances (via their ``.query``) and plain
          dicts are shallow-merged into one dict; a later expression replaces
          an earlier top-level key (e.g. a second ``"$set"`` replaces the
          first one wholesale).
        - ``SetRevisionId`` entries are merged *into* the existing ``"$set"``.
        - Lists are concatenated (list-style / pipeline updates).

        The caller-supplied expressions are never modified.

        :return: the merged query encoded with the model's custom encoders
            (the merged value is ``None`` if no expression was added)
        :raises TypeError: if dict-style and list-style expressions are
            mixed, or an expression has an unsupported type
        """
        query: Union[Dict[str, Any], List[Dict[str, Any]], None] = None
        for expression in self.update_expressions:
            if isinstance(
                expression, (BaseUpdateOperator, dict, SetRevisionId)
            ):
                if query is None:
                    query = {}
                if isinstance(query, list):
                    raise TypeError("Wrong expression type")
                if isinstance(expression, BaseUpdateOperator):
                    query.update(expression.query)
                elif isinstance(expression, dict):
                    query.update(expression)
                else:
                    # Build a fresh "$set" mapping: the current one may be
                    # the very dict the caller passed in (shallow merge
                    # above), and updating it in place would leak
                    # revision_id into the caller's object.
                    query["$set"] = {
                        **query.get("$set", {}),
                        **expression.query.get("$set", {}),
                    }
            elif isinstance(expression, list):
                if query is None:
                    query = []
                if isinstance(query, dict):
                    raise TypeError("Wrong expression type")
                query.extend(expression)
            else:
                raise TypeError("Wrong expression type")
        return Encoder(custom_encoders=self.encoders).encode(query)

    @abstractmethod
    async def _update(self) -> UpdateResult:
        """
        Execute the accumulated update, or queue it on ``bulk_writer``.
        """
        ...


class UpdateMany(UpdateQuery):
    """
    Update Many query class.

    Awaiting it runs pymongo ``update_many`` (or queues a pymongo
    ``UpdateMany`` on the bulk writer). For upserts, ``on_insert`` is
    inserted with a separate ``insert_one`` when nothing matched.
    """

    def update(
        self,
        *args: Mapping[str, Any],
        session: Optional[AsyncClientSession] = None,
        bulk_writer: Optional[BulkWriter] = None,
        **pymongo_kwargs: Any,
    ) -> "UpdateQuery":
        """
        Provide modifications to the update query.

        :param args: *Union[dict, Mapping] - the modifications to apply.
        :param session: Optional[AsyncClientSession] - pymongo session
        :param bulk_writer: Optional[BulkWriter] - if truthy, the update is
            queued on it instead of being executed
        :param pymongo_kwargs: pymongo native parameters for update operation
        :return: UpdateMany query
        """
        self.set_session(session=session)
        self.update_expressions += args
        if bulk_writer:
            self.bulk_writer = bulk_writer
        self.pymongo_kwargs.update(pymongo_kwargs)
        return self

    def upsert(
        self,
        *args: Mapping[str, Any],
        on_insert: "DocType",
        session: Optional[AsyncClientSession] = None,
        **pymongo_kwargs: Any,
    ) -> "UpdateQuery":
        """
        Provide modifications to the upsert query.

        :param args: *Union[dict, Mapping] - the modifications to apply.
        :param on_insert: DocType - document to insert if there is no matched
            document in the collection
        :param session: Optional[AsyncClientSession] - pymongo session
        :param pymongo_kwargs: pymongo native parameters for update operation
        :return: UpdateMany query
        """
        self.upsert_insert_doc = on_insert  # type: ignore
        self.update(*args, session=session, **pymongo_kwargs)
        return self

    def update_many(
        self,
        *args: Mapping[str, Any],
        session: Optional[AsyncClientSession] = None,
        bulk_writer: Optional[BulkWriter] = None,
        **pymongo_kwargs: Any,
    ):
        """
        Provide modifications to the update query. The same as `update()`.

        :param args: *Union[dict, Mapping] - the modifications to apply.
        :param session: Optional[AsyncClientSession] - pymongo session
        :param bulk_writer: "BulkWriter" - Beanie bulk writer
        :param pymongo_kwargs: pymongo native parameters for update operation
        :return: UpdateMany query
        """
        return self.update(
            *args, session=session, bulk_writer=bulk_writer, **pymongo_kwargs
        )

    async def _update(self):
        """
        Run ``update_many``, or queue it when a bulk writer is set.

        :return: pymongo ``UpdateResult``; ``None`` in bulk mode
        """
        if self.bulk_writer is None:
            return (
                await self.document_model.get_pymongo_collection().update_many(
                    self.find_query,
                    self.update_query,
                    session=self.session,
                    **self.pymongo_kwargs,
                )
            )
        else:
            self.bulk_writer.add_operation(
                self.document_model,
                UpdateManyPyMongo(
                    self.find_query, self.update_query, **self.pymongo_kwargs
                ),
            )

    def __await__(
        self,
    ) -> Generator[
        Any, None, Union[UpdateResult, InsertOneResult, Optional["DocType"]]
    ]:
        """
        Run the query.

        If an upsert document was provided and the update matched no
        documents, that document is inserted with ``insert_one``. In bulk
        mode ``_update`` returns ``None``, so no insert is attempted.

        :return: the update result, or the ``insert_one`` result when the
            upsert document was inserted
        """
        update_result = yield from self._update().__await__()
        if self.upsert_insert_doc is None:
            return update_result

        if update_result is not None and update_result.matched_count == 0:
            return (
                yield from self.document_model.insert_one(
                    document=self.upsert_insert_doc,
                    session=self.session,
                    bulk_writer=self.bulk_writer,
                ).__await__()
            )

        return update_result


class UpdateOne(UpdateQuery):
    """
    Update One query class.

    Awaiting it returns either the pymongo update result or the matched
    document (before or after the update), depending on ``response_type``.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        self.response_type = UpdateResponse.UPDATE_RESULT

    def update(
        self,
        *args: Mapping[str, Any],
        session: Optional[AsyncClientSession] = None,
        bulk_writer: Optional[BulkWriter] = None,
        response_type: Optional[UpdateResponse] = None,
        **pymongo_kwargs: Any,
    ) -> "UpdateQuery":
        """
        Provide modifications to the update query.

        :param args: *Union[dict, Mapping] - the modifications to apply.
        :param session: Optional[AsyncClientSession] - pymongo session
        :param bulk_writer: Optional[BulkWriter] - if truthy, the update is
            queued on it instead of being executed
        :param response_type: Optional[UpdateResponse] - what awaiting the
            query returns; ``None`` keeps the current setting
        :param pymongo_kwargs: pymongo native parameters for update operation
        :return: UpdateOne query
        """
        self.set_session(session=session)
        self.update_expressions += args
        if response_type is not None:
            self.response_type = response_type
        if bulk_writer:
            self.bulk_writer = bulk_writer
        self.pymongo_kwargs.update(pymongo_kwargs)
        return self

    def upsert(
        self,
        *args: Mapping[str, Any],
        on_insert: "DocType",
        session: Optional[AsyncClientSession] = None,
        response_type: Optional[UpdateResponse] = None,
        **pymongo_kwargs: Any,
    ) -> "UpdateQuery":
        """
        Provide modifications to the upsert query.

        :param args: *Union[dict, Mapping] - the modifications to apply.
        :param on_insert: DocType - document to insert if there is no matched
            document in the collection
        :param session: Optional[AsyncClientSession] - pymongo session
        :param response_type: Optional[UpdateResponse]
        :param pymongo_kwargs: pymongo native parameters for update operation
        :return: UpdateOne query
        """
        self.upsert_insert_doc = on_insert  # type: ignore
        self.update(
            *args,
            response_type=response_type,
            session=session,
            **pymongo_kwargs,
        )
        return self

    def update_one(
        self,
        *args: Mapping[str, Any],
        session: Optional[AsyncClientSession] = None,
        bulk_writer: Optional[BulkWriter] = None,
        response_type: Optional[UpdateResponse] = None,
        **pymongo_kwargs: Any,
    ):
        """
        Provide modifications to the update query. The same as `update()`.

        :param args: *Union[dict, Mapping] - the modifications to apply.
        :param session: Optional[AsyncClientSession] - pymongo session
        :param bulk_writer: "BulkWriter" - Beanie bulk writer
        :param response_type: Optional[UpdateResponse]
        :param pymongo_kwargs: pymongo native parameters for update operation
        :return: UpdateOne query
        """
        return self.update(
            *args,
            session=session,
            bulk_writer=bulk_writer,
            response_type=response_type,
            **pymongo_kwargs,
        )

    async def _update(self):
        """
        Run the single-document update, or queue it on the bulk writer.

        - ``UPDATE_RESULT``: ``update_one``; returns the pymongo result.
        - ``OLD_DOCUMENT`` / ``NEW_DOCUMENT``: ``find_one_and_update`` with
          ``ReturnDocument.BEFORE`` / ``AFTER``; the raw document is parsed
          into ``document_model``, or ``None`` is returned if nothing
          matched.
        - Bulk mode: queues a pymongo ``UpdateOne`` (``response_type`` is
          ignored) and returns ``None``.
        """
        if not self.bulk_writer:
            if self.response_type == UpdateResponse.UPDATE_RESULT:
                return await self.document_model.get_pymongo_collection().update_one(
                    self.find_query,
                    self.update_query,
                    session=self.session,
                    **self.pymongo_kwargs,
                )
            else:
                result = await self.document_model.get_pymongo_collection().find_one_and_update(
                    self.find_query,
                    self.update_query,
                    session=self.session,
                    return_document=(
                        ReturnDocument.BEFORE
                        if self.response_type == UpdateResponse.OLD_DOCUMENT
                        else ReturnDocument.AFTER
                    ),
                    **self.pymongo_kwargs,
                )
                if result is not None:
                    result = parse_obj(self.document_model, result)
                return result
        else:
            self.bulk_writer.add_operation(
                self.document_model,
                UpdateOnePyMongo(
                    self.find_query, self.update_query, **self.pymongo_kwargs
                ),
            )

    def __await__(
        self,
    ) -> Generator[
        Any, None, Union[UpdateResult, InsertOneResult, Optional["DocType"]]
    ]:
        """
        Run the query.

        If an upsert document was provided, it is inserted with ``insert_one``
        when the update found nothing:
        - for ``UPDATE_RESULT``, the result's ``matched_count`` is 0;
        - for document responses, the result is ``None``.

        :return: the update result / parsed document, or the ``insert_one``
            result when the upsert document was inserted
        """
        update_result = yield from self._update().__await__()
        if self.upsert_insert_doc is None:
            return update_result

        # NOTE(review): in bulk mode ``_update`` always returns None, so
        # with a document response_type the insert below is always queued
        # alongside the UpdateOne (unlike UPDATE_RESULT / UpdateMany, which
        # skip it). Confirm this is intended.
        if (
            self.response_type == UpdateResponse.UPDATE_RESULT
            and update_result is not None
            and update_result.matched_count == 0
        ) or (
            self.response_type != UpdateResponse.UPDATE_RESULT
            and update_result is None
        ):
            return (
                yield from self.document_model.insert_one(
                    document=self.upsert_insert_doc,
                    session=self.session,
                    bulk_writer=self.bulk_writer,
                ).__await__()
            )

        return update_result