diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 4bf928ed..4bc66f67 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -4,6 +4,7 @@ import time import warnings import weakref +from math import ceil from typing import ( TYPE_CHECKING, Any, @@ -39,7 +40,7 @@ convert_bytes, make_dict, ) -from redisvl.types import AsyncRedisClient, SyncRedisClient +from redisvl.types import AsyncRedisClient, SyncRedisClient, SyncRedisCluster from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper if TYPE_CHECKING: @@ -706,6 +707,32 @@ def delete(self, drop: bool = True): except Exception as e: raise RedisSearchError(f"Error while deleting index: {str(e)}") from e + def _delete_batch(self, batch_keys: List[str]) -> int: + """Delete a batch of keys from Redis. + + For Redis Cluster, keys are deleted individually due to potential + cross-slot limitations. For standalone Redis, keys are deleted in + a single operation for better performance. + + Args: + batch_keys (List[str]): List of Redis keys to delete. + + Returns: + int: Count of records deleted from Redis. + """ + client = cast(SyncRedisClient, self._redis_client) + is_cluster = isinstance(client, RedisCluster) + if is_cluster: + records_deleted_in_batch = 0 + for key_to_delete in batch_keys: + try: + records_deleted_in_batch += cast(int, client.delete(key_to_delete)) + except redis.exceptions.RedisError as e: + logger.warning(f"Failed to delete key {key_to_delete}: {e}") + else: + records_deleted_in_batch = cast(int, client.delete(*batch_keys)) + return records_deleted_in_batch + def clear(self) -> int: """Clear all keys in Redis associated with the index, leaving the index available and in-place for future insertions or updates. @@ -717,28 +744,24 @@ def clear(self) -> int: Returns: int: Count of records deleted from Redis. """ - client = cast(SyncRedisClient, self._redis_client) + batch_size = 500 + max_ratio = 1.01 + + info = self.info() + max_records_deleted = ceil( + info["num_docs"] * max_ratio + ) # Allow to remove some additional concurrent inserts total_records_deleted: int = 0 + query = FilterQuery(FilterExpression("*"), return_fields=["id"]) + query.paging(0, batch_size) - for batch in self.paginate( - FilterQuery(FilterExpression("*"), return_fields=["id"]), page_size=500 - ): - batch_keys = [record["id"] for record in batch] - if batch_keys: - is_cluster = isinstance(client, RedisCluster) - if is_cluster: - records_deleted_in_batch = 0 - for key_to_delete in batch_keys: - try: - records_deleted_in_batch += cast( - int, client.delete(key_to_delete) - ) - except redis.exceptions.RedisError as e: - logger.warning(f"Failed to delete key {key_to_delete}: {e}") - total_records_deleted += records_deleted_in_batch - else: - record_deleted = cast(int, client.delete(*batch_keys)) - total_records_deleted += record_deleted + while True: + batch = self._query(query) + if batch and total_records_deleted <= max_records_deleted: + batch_keys = [record["id"] for record in batch] + total_records_deleted += self._delete_batch(batch_keys) + else: + break return total_records_deleted @@ -1160,6 +1183,9 @@ def paginate(self, query: BaseQuery, page_size: int = 30) -> Generator: batch contains. Adjust this value based on performance considerations and the expected volume of search results. + Note: + For stable pagination, the query must have a `sort_by` clause. + """ if not isinstance(page_size, int): raise TypeError("page_size must be an integer") @@ -1197,7 +1223,15 @@ def exists(self) -> bool: def _info(name: str, redis_client: SyncRedisClient) -> Dict[str, Any]: """Run FT.INFO to fetch information about the index.""" try: - return convert_bytes(redis_client.ft(name).info()) # type: ignore + if isinstance(redis_client, SyncRedisCluster): + node = redis_client.get_random_node() + values = redis_client.execute_command( + "FT.INFO", name, target_nodes=node + ) + info = make_dict(values) + else: + info = redis_client.ft(name).info() + return convert_bytes(info) except Exception as e: raise RedisSearchError( f"Error while fetching {name} index info: {str(e)}" @@ -1425,7 +1459,15 @@ async def _validate_client( @staticmethod async def _info(name: str, redis_client: AsyncRedisClient) -> Dict[str, Any]: try: - return convert_bytes(await redis_client.ft(name).info()) + if isinstance(redis_client, AsyncRedisCluster): + node = redis_client.get_random_node() + values = await redis_client.execute_command( + "FT.INFO", name, target_nodes=node + ) + info = make_dict(values) + else: + info = await redis_client.ft(name).info() + return convert_bytes(info) except Exception as e: raise RedisSearchError( f"Error while fetching {name} index info: {str(e)}" @@ -1549,6 +1591,34 @@ async def delete(self, drop: bool = True): except Exception as e: raise RedisSearchError(f"Error while deleting index: {str(e)}") from e + async def _delete_batch(self, batch_keys: List[str]) -> int: + """Delete a batch of keys from Redis. + + For Redis Cluster, keys are deleted individually due to potential + cross-slot limitations. For standalone Redis, keys are deleted in + a single operation for better performance. + + Args: + batch_keys (List[str]): List of Redis keys to delete. + + Returns: + int: Count of records deleted from Redis. + """ + client = await self._get_client() + is_cluster = isinstance(client, AsyncRedisCluster) + if is_cluster: + records_deleted_in_batch = 0 + for key_to_delete in batch_keys: + try: + records_deleted_in_batch += cast( + int, await client.delete(key_to_delete) + ) + except redis.exceptions.RedisError as e: + logger.warning(f"Failed to delete key {key_to_delete}: {e}") + else: + records_deleted_in_batch = await client.delete(*batch_keys) + return records_deleted_in_batch + async def clear(self) -> int: """Clear all keys in Redis associated with the index, leaving the index available and in-place for future insertions or updates. @@ -1560,28 +1630,24 @@ async def clear(self) -> int: Returns: int: Count of records deleted from Redis. """ - client = await self._get_client() + batch_size = 500 + max_ratio = 1.01 + + info = await self.info() + max_records_deleted = ceil( + info["num_docs"] * max_ratio + ) # Allow to remove some additional concurrent inserts total_records_deleted: int = 0 + query = FilterQuery(FilterExpression("*"), return_fields=["id"]) + query.paging(0, batch_size) - async for batch in self.paginate( - FilterQuery(FilterExpression("*"), return_fields=["id"]), page_size=500 - ): - batch_keys = [record["id"] for record in batch] - if batch_keys: - is_cluster = isinstance(client, AsyncRedisCluster) - if is_cluster: - records_deleted_in_batch = 0 - for key_to_delete in batch_keys: - try: - records_deleted_in_batch += cast( - int, await client.delete(key_to_delete) - ) - except redis.exceptions.RedisError as e: - logger.warning(f"Failed to delete key {key_to_delete}: {e}") - total_records_deleted += records_deleted_in_batch - else: - records_deleted = await client.delete(*batch_keys) - total_records_deleted += records_deleted + while True: + batch = await self._query(query) + if batch and total_records_deleted <= max_records_deleted: + batch_keys = [record["id"] for record in batch] + total_records_deleted += await self._delete_batch(batch_keys) + else: + break return total_records_deleted @@ -2039,6 +2105,9 @@ async def paginate(self, query: BaseQuery, page_size: int = 30) -> AsyncGenerato batch contains. Adjust this value based on performance considerations and the expected volume of search results. + Note: + For stable pagination, the query must have a `sort_by` clause. + """ if not isinstance(page_size, int): raise TypeError("page_size must be of type int") diff --git a/tests/integration/test_async_search_index.py b/tests/integration/test_async_search_index.py index 2ce139e7..c3ca3774 100644 --- a/tests/integration/test_async_search_index.py +++ b/tests/integration/test_async_search_index.py @@ -1,4 +1,5 @@ import warnings +from random import choice from unittest import mock import pytest @@ -245,14 +246,20 @@ async def test_search_index_delete(async_index): @pytest.mark.asyncio -async def test_search_index_clear(async_index): +@pytest.mark.parametrize("num_docs", [0, 1, 5, 10, 2042]) +async def test_search_index_clear(async_index, num_docs): await async_index.create(overwrite=True, drop=True) - data = [{"id": "1", "test": "foo"}] + tags = ["foo", "bar", "baz"] + data = [{"id": str(i), "test": choice(tags)} for i in range(num_docs)] await async_index.load(data, id_field="id") + info = await async_index.info() + assert info["num_docs"] == num_docs count = await async_index.clear() - assert count == len(data) + assert count == num_docs assert await async_index.exists() + info = await async_index.info() + assert info["num_docs"] == 0 @pytest.mark.asyncio diff --git a/tests/integration/test_redis_cluster_support.py b/tests/integration/test_redis_cluster_support.py index c7134eca..80b82420 100644 --- a/tests/integration/test_redis_cluster_support.py +++ b/tests/integration/test_redis_cluster_support.py @@ -70,6 +70,46 @@ def test_search_index_cluster_client(redis_cluster_url): index.delete(drop=True) +@pytest.mark.requires_cluster +def test_search_index_cluster_info(redis_cluster_url): + """Test .info() method on SearchIndex with RedisCluster client.""" + schema = IndexSchema.from_dict( + { + "index": {"name": "test_cluster_info", "prefix": "test_info"}, + "fields": [{"name": "name", "type": "text"}], + } + ) + client = RedisCluster.from_url(redis_cluster_url) + index = SearchIndex(schema=schema, redis_client=client) + try: + index.create(overwrite=True) + info = index.info() + assert isinstance(info, dict) + assert info.get("index_name", None) == "test_cluster_info" + finally: + index.delete(drop=True) + +@pytest.mark.requires_cluster +@pytest.mark.asyncio +async def test_async_search_index_cluster_info(redis_cluster_url): + """Test .info() method on AsyncSearchIndex with AsyncRedisCluster client.""" + schema = IndexSchema.from_dict( + { + "index": {"name": "async_cluster_info", "prefix": "async_info"}, + "fields": [{"name": "name", "type": "text"}], + } + ) + client = AsyncRedisCluster.from_url(redis_cluster_url) + index = AsyncSearchIndex(schema=schema, redis_client=client) + try: + await index.create(overwrite=True) + info = await index.info() + assert isinstance(info, dict) + assert info.get("index_name", None) == "async_cluster_info" + finally: + await index.delete(drop=True) + await client.aclose() + @pytest.mark.requires_cluster @pytest.mark.asyncio async def test_async_search_index_client(redis_cluster_url): diff --git a/tests/integration/test_search_index.py b/tests/integration/test_search_index.py index 0cbdf8ab..ae64a229 100644 --- a/tests/integration/test_search_index.py +++ b/tests/integration/test_search_index.py @@ -1,4 +1,5 @@ import warnings +from random import choice from unittest import mock import pytest @@ -303,15 +304,20 @@ def test_search_index_delete(index): assert not index.exists() assert index.name not in convert_bytes(index.client.execute_command("FT._LIST")) - -def test_search_index_clear(index): +@pytest.mark.parametrize("num_docs", [0, 1, 5, 10, 2042]) +def test_search_index_clear(index, num_docs): index.create(overwrite=True, drop=True) - data = [{"id": "1", "test": "foo"}] + tags = ["foo", "bar", "baz"] + data = [{"id": str(i), "test": choice(tags)} for i in range(num_docs)] index.load(data, id_field="id") + info = index.info() + assert info["num_docs"] == num_docs count = index.clear() assert count == len(data) assert index.exists() + info = index.info() + assert info["num_docs"] == 0 def test_search_index_drop_key(index): diff --git a/tests/unit/test_error_handling.py b/tests/unit/test_error_handling.py index 478743f9..173ff226 100644 --- a/tests/unit/test_error_handling.py +++ b/tests/unit/test_error_handling.py @@ -454,10 +454,15 @@ def test_clear_individual_key_deletion_errors(self, mock_validate): 1, # Third succeeds ] - # Mock the paginate method to return test data - with patch.object(SearchIndex, "paginate") as mock_paginate: - mock_paginate.return_value = [ - [{"id": "test:key1"}, {"id": "test:key2"}, {"id": "test:key3"}] + # Mock the .info() and ._query() methods to return test data + with ( + patch.object(SearchIndex, "info") as mock_info, + patch.object(SearchIndex, "_query") as mock_query, + ): + mock_info.return_value = {"num_docs": 3} + mock_query.side_effect = [ + [{"id": "test:key1"}, {"id": "test:key2"}, {"id": "test:key3"}], + [], ] # Create index with mocked client @@ -502,11 +507,21 @@ async def test_async_clear_individual_key_deletion_errors(self, mock_validate): ] ) - # Mock the paginate method to return test data - async def mock_paginate_generator(*args, **kwargs): - yield [{"id": "test:key1"}, {"id": "test:key2"}, {"id": "test:key3"}] + # Mock the .info() and ._query() methods to return test data + async def mock_info(*args, **kwargs): + return {"num_docs": 3} - with patch.object(AsyncSearchIndex, "paginate", mock_paginate_generator): + mock_query = AsyncMock( + side_effect=[ + [{"id": "test:key1"}, {"id": "test:key2"}, {"id": "test:key3"}], + [], + ] + ) + + with ( + patch.object(AsyncSearchIndex, "info", mock_info), + patch.object(AsyncSearchIndex, "_query", mock_query), + ): # Create index with mocked client index = AsyncSearchIndex(schema) index._redis_client = mock_cluster_client