diff --git a/.github/workflows/docs-pages.yml b/.github/workflows/docs-pages.yml index ef23559b..c70abfb1 100644 --- a/.github/workflows/docs-pages.yml +++ b/.github/workflows/docs-pages.yml @@ -1,4 +1,4 @@ -name: Docs: Build and deploy MkDocs site +name: Docs Build and deploy MkDocs site on: push: diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index bf69acfb..a110af96 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -2378,8 +2378,25 @@ async def update(self, **field_values): raise NotImplementedError async def save( - self: "Model", pipeline: Optional[redis.client.Pipeline] = None - ) -> "Model": + self: "Model", + pipeline: Optional[redis.client.Pipeline] = None, + nx: bool = False, + xx: bool = False, + ) -> Optional["Model"]: + """Save the model instance to Redis. + + Args: + pipeline: Optional Redis pipeline for batching operations. + nx: If True, only save if the key does NOT exist (insert-only). + xx: If True, only save if the key already exists (update-only). + + Returns: + The model instance if saved successfully, None if nx/xx condition + was not met. + + Raises: + ValueError: If both nx and xx are True. + """ raise NotImplementedError async def expire( @@ -2615,8 +2632,19 @@ def __init_subclass__(cls, **kwargs): ) async def save( - self: "Model", pipeline: Optional[redis.client.Pipeline] = None - ) -> "Model": + self: "Model", + pipeline: Optional[redis.client.Pipeline] = None, + nx: bool = False, + xx: bool = False, + ) -> Optional["Model"]: + if nx and xx: + raise ValueError("Cannot specify both nx and xx") + if pipeline and (nx or xx): + raise ValueError( + "Cannot use nx or xx with pipeline for HashModel. " + "Use JsonModel if you need conditional saves with pipelines." + ) + self.check() db = self._get_db(pipeline) @@ -2636,9 +2664,23 @@ async def save( for k, v in document.items() } + key = self.key() + + async def _do_save(conn): + # Check nx/xx conditions (HSET doesn't support these natively) + if nx or xx: + exists = await conn.exists(key) + if nx and exists: + return None # Key exists, nx means don't overwrite + if xx and not exists: + return None # Key doesn't exist, xx means only update existing + + await conn.hset(key, mapping=document) + return self + # TODO: Wrap any Redis response errors in a custom exception? try: - await db.hset(self.key(), mapping=document) + return await _do_save(db) except RuntimeError as e: if "Event loop is closed" in str(e): # Connection is bound to closed event loop, refresh it and retry @@ -2646,10 +2688,9 @@ async def save( self.__class__._meta.database = get_redis_connection() db = self._get_db(pipeline) - await db.hset(self.key(), mapping=document) + return await _do_save(db) else: raise - return self @classmethod async def all_pks(cls): # type: ignore @@ -2835,8 +2876,14 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) async def save( - self: "Model", pipeline: Optional[redis.client.Pipeline] = None - ) -> "Model": + self: "Model", + pipeline: Optional[redis.client.Pipeline] = None, + nx: bool = False, + xx: bool = False, + ) -> Optional["Model"]: + if nx and xx: + raise ValueError("Cannot specify both nx and xx") + self.check() db = self._get_db(pipeline) @@ -2847,9 +2894,20 @@ async def save( # Apply JSON encoding for complex types (Enums, UUIDs, Sets, etc.) data = jsonable_encoder(data) + key = self.key() + path = Path.root_path() + + async def _do_save(conn): + # JSON.SET supports nx and xx natively + result = await conn.json().set(key, path, data, nx=nx, xx=xx) + # JSON.SET returns None if nx/xx condition not met, "OK" otherwise + if result is None: + return None + return self + # TODO: Wrap response errors in a custom exception? try: - await db.json().set(self.key(), Path.root_path(), data) + return await _do_save(db) except RuntimeError as e: if "Event loop is closed" in str(e): # Connection is bound to closed event loop, refresh it and retry @@ -2857,10 +2915,9 @@ async def save( self.__class__._meta.database = get_redis_connection() db = self._get_db(pipeline) - await db.json().set(self.key(), Path.root_path(), data) + return await _do_save(db) else: raise - return self @classmethod async def all_pks(cls): # type: ignore diff --git a/tests/test_hash_model.py b/tests/test_hash_model.py index af8a9f2a..22cd8366 100644 --- a/tests/test_hash_model.py +++ b/tests/test_hash_model.py @@ -1357,3 +1357,113 @@ class Meta: assert len(rematerialized) == 1 assert rematerialized[0].pk == loc1.pk + + +@py_test_mark_asyncio +async def test_save_nx_only_saves_if_not_exists(m): + """Test that save(nx=True) only saves if the key doesn't exist.""" + await Migrator().run() + + member = m.Member( + id=1000, + first_name="Andrew", + last_name="Brookins", + email="a@example.com", + join_date=today, + age=38, + bio="Original bio", + ) + + # First save should succeed with nx=True + result = await member.save(nx=True) + assert result is not None + assert result.pk == member.pk + + # Second save with same pk should return None (key exists) + member2 = m.Member( + id=1000, + first_name="Different", + last_name="Name", + email="b@example.com", + join_date=today, + age=25, + bio="Different bio", + ) + result = await member2.save(nx=True) + assert result is None + + # Verify original data is unchanged + fetched = await m.Member.get(member.id) + assert fetched.first_name == "Andrew" + + +@py_test_mark_asyncio +async def test_save_xx_only_saves_if_exists(m): + """Test that save(xx=True) only saves if the key already exists.""" + await Migrator().run() + + member = m.Member( + id=2000, + first_name="Andrew", + last_name="Brookins", + email="a@example.com", + join_date=today, + age=38, + bio="Original bio", + ) + + # First save with xx=True should return None (key doesn't exist) + result = await member.save(xx=True) + assert result is None + + # Save without flags to create the key + await member.save() + + # Now update with xx=True should succeed + member.first_name = "Updated" + result = await member.save(xx=True) + assert result is not None + + # Verify data was updated + fetched = await m.Member.get(member.id) + assert fetched.first_name == "Updated" + + +@py_test_mark_asyncio +async def test_save_nx_xx_mutually_exclusive(m): + """Test that save() raises ValueError if both nx and xx are True.""" + await Migrator().run() + + member = m.Member( + id=3000, + first_name="Andrew", + last_name="Brookins", + email="a@example.com", + join_date=today, + age=38, + bio="Some bio", + ) + + with pytest.raises(ValueError, match="Cannot specify both nx and xx"): + await member.save(nx=True, xx=True) + + +@py_test_mark_asyncio +async def test_save_nx_with_pipeline_raises_error(m): + """Test that save(nx=True) with pipeline raises an error for HashModel.""" + await Migrator().run() + + member = m.Member( + id=4000, + first_name="Andrew", + last_name="Brookins", + email="a@example.com", + join_date=today, + age=38, + bio="Bio 1", + ) + + # HashModel doesn't support nx/xx with pipeline (HSET doesn't support it natively) + async with m.Member.db().pipeline(transaction=True) as pipe: + with pytest.raises(ValueError, match="Cannot use nx or xx with pipeline"): + await member.save(pipeline=pipe, nx=True) diff --git a/tests/test_json_model.py b/tests/test_json_model.py index 5a3d5ed1..a40179e7 100644 --- a/tests/test_json_model.py +++ b/tests/test_json_model.py @@ -1572,3 +1572,124 @@ class Meta: # Test sorting by NUMERIC field still works results = await Product.find().sort_by("price").all() assert results == [product3, product2, product1] # 30, 50, 100 + + +@py_test_mark_asyncio +async def test_save_nx_only_saves_if_not_exists(m, address): + """Test that save(nx=True) only saves if the key doesn't exist.""" + await Migrator().run() + + member = m.Member( + first_name="Andrew", + last_name="Brookins", + email="a@example.com", + join_date=today, + age=38, + address=address, + ) + + # First save should succeed with nx=True + result = await member.save(nx=True) + assert result is not None + assert result.pk == member.pk + + # Second save with same pk should return None (key exists) + member2 = m.Member( + pk=member.pk, + first_name="Different", + last_name="Name", + email="b@example.com", + join_date=today, + age=25, + address=address, + ) + result = await member2.save(nx=True) + assert result is None + + # Verify original data is unchanged + fetched = await m.Member.get(member.pk) + assert fetched.first_name == "Andrew" + + +@py_test_mark_asyncio +async def test_save_xx_only_saves_if_exists(m, address): + """Test that save(xx=True) only saves if the key already exists.""" + await Migrator().run() + + member = m.Member( + first_name="Andrew", + last_name="Brookins", + email="a@example.com", + join_date=today, + age=38, + address=address, + ) + + # First save with xx=True should return None (key doesn't exist) + result = await member.save(xx=True) + assert result is None + + # Save without flags to create the key + await member.save() + + # Now update with xx=True should succeed + member.first_name = "Updated" + result = await member.save(xx=True) + assert result is not None + + # Verify data was updated + fetched = await m.Member.get(member.pk) + assert fetched.first_name == "Updated" + + +@py_test_mark_asyncio +async def test_save_nx_xx_mutually_exclusive(m, address): + """Test that save() raises ValueError if both nx and xx are True.""" + await Migrator().run() + + member = m.Member( + first_name="Andrew", + last_name="Brookins", + email="a@example.com", + join_date=today, + age=38, + address=address, + ) + + with pytest.raises(ValueError, match="Cannot specify both nx and xx"): + await member.save(nx=True, xx=True) + + +@py_test_mark_asyncio +async def test_save_nx_with_pipeline(m, address): + """Test that save(nx=True) works with pipeline.""" + await Migrator().run() + + member1 = m.Member( + first_name="Andrew", + last_name="Brookins", + email="a@example.com", + join_date=today, + age=38, + address=address, + ) + member2 = m.Member( + first_name="Kim", + last_name="Brookins", + email="k@example.com", + join_date=today, + age=34, + address=address, + ) + + # Save both with nx=True via pipeline + async with m.Member.db().pipeline(transaction=True) as pipe: + await member1.save(pipeline=pipe, nx=True) + await member2.save(pipeline=pipe, nx=True) + await pipe.execute() + + # Verify both were saved + fetched1 = await m.Member.get(member1.pk) + fetched2 = await m.Member.get(member2.pk) + assert fetched1.first_name == "Andrew" + assert fetched2.first_name == "Kim"