155 changes: 112 additions & 43 deletions redisvl/index/index.py
@@ -4,6 +4,7 @@
import time
import warnings
import weakref
from math import ceil
from typing import (
TYPE_CHECKING,
Any,
@@ -39,7 +40,7 @@
convert_bytes,
make_dict,
)
from redisvl.types import AsyncRedisClient, SyncRedisClient
from redisvl.types import AsyncRedisClient, SyncRedisClient, SyncRedisCluster
from redisvl.utils.utils import deprecated_argument, deprecated_function, sync_wrapper

if TYPE_CHECKING:
@@ -706,6 +707,32 @@ def delete(self, drop: bool = True):
except Exception as e:
raise RedisSearchError(f"Error while deleting index: {str(e)}") from e

def _delete_batch(self, batch_keys: List[str]) -> int:
"""Delete a batch of keys from Redis.

For Redis Cluster, keys are deleted individually due to potential
cross-slot limitations. For standalone Redis, keys are deleted in
a single operation for better performance.

Args:
batch_keys (List[str]): List of Redis keys to delete.

Returns:
int: Count of records deleted from Redis.
"""
client = cast(SyncRedisClient, self._redis_client)
is_cluster = isinstance(client, RedisCluster)
if is_cluster:
records_deleted_in_batch = 0
for key_to_delete in batch_keys:
try:
records_deleted_in_batch += cast(int, client.delete(key_to_delete))
except redis.exceptions.RedisError as e:
logger.warning(f"Failed to delete key {key_to_delete}: {e}")
else:
records_deleted_in_batch = cast(int, client.delete(*batch_keys))
return records_deleted_in_batch
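The per-key fallback above exists because a multi-key DEL on a Redis Cluster is rejected when the keys hash to different slots. A minimal sketch of that behavior, not part of this diff, assuming a cluster reachable at the hypothetical URL redis://localhost:7000:

```python
from redis.cluster import RedisCluster

rc = RedisCluster.from_url("redis://localhost:7000")  # hypothetical address

# Arbitrary keys usually land in different hash slots, so one multi-key DEL
# can fail with a CROSSSLOT error; issuing DELs one key at a time, as
# _delete_batch does, lets the client route each command to the right node.
deleted = sum(rc.delete(key) for key in ["doc:1", "doc:2", "doc:3"])

# Keys sharing a hash tag ("{batch}") map to the same slot, so grouping them
# makes a single multi-key DEL safe.
rc.delete("{batch}:doc:1", "{batch}:doc:2")
```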

def clear(self) -> int:
"""Clear all keys in Redis associated with the index, leaving the index
available and in-place for future insertions or updates.
@@ -717,28 +744,24 @@ def clear(self) -> int:
Returns:
int: Count of records deleted from Redis.
"""
client = cast(SyncRedisClient, self._redis_client)
batch_size = 500
max_ratio = 1.01

info = self.info()
max_records_deleted = ceil(
info["num_docs"] * max_ratio
) # Allow removing a few extra documents inserted concurrently
total_records_deleted: int = 0
query = FilterQuery(FilterExpression("*"), return_fields=["id"])
query.paging(0, batch_size)

for batch in self.paginate(
FilterQuery(FilterExpression("*"), return_fields=["id"]), page_size=500
):
batch_keys = [record["id"] for record in batch]
if batch_keys:
is_cluster = isinstance(client, RedisCluster)
if is_cluster:
records_deleted_in_batch = 0
for key_to_delete in batch_keys:
try:
records_deleted_in_batch += cast(
int, client.delete(key_to_delete)
)
except redis.exceptions.RedisError as e:
logger.warning(f"Failed to delete key {key_to_delete}: {e}")
total_records_deleted += records_deleted_in_batch
else:
record_deleted = cast(int, client.delete(*batch_keys))
total_records_deleted += record_deleted
while True:
batch = self._query(query)
if batch and total_records_deleted <= max_records_deleted:
batch_keys = [record["id"] for record in batch]
total_records_deleted += self._delete_batch(batch_keys)
else:
break

return total_records_deleted
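The loop keeps deleting while the running total is within ceil(num_docs * 1.01), which lets it absorb a few concurrent inserts yet still terminate. A quick check of the bound for the 2042-document case used in the updated tests:

```python
from math import ceil

num_docs = 2042   # doc count reported by FT.INFO when clear() starts
max_ratio = 1.01
print(ceil(num_docs * max_ratio))  # -> 2063, i.e. the stop threshold, 21 above the reported count
```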

@@ -1160,6 +1183,9 @@ def paginate(self, query: BaseQuery, page_size: int = 30) -> Generator:
batch contains. Adjust this value based on performance
considerations and the expected volume of search results.

Note:
For stable pagination, the query must have a `sort_by` clause.

"""
if not isinstance(page_size, int):
raise TypeError("page_size must be an integer")
@@ -1197,7 +1223,15 @@ def exists(self) -> bool:
def _info(name: str, redis_client: SyncRedisClient) -> Dict[str, Any]:
"""Run FT.INFO to fetch information about the index."""
try:
return convert_bytes(redis_client.ft(name).info()) # type: ignore
if isinstance(redis_client, SyncRedisCluster):
node = redis_client.get_random_node()
values = redis_client.execute_command(
"FT.INFO", name, target_nodes=node
)
info = make_dict(values)
else:
info = redis_client.ft(name).info()
return convert_bytes(info)
except Exception as e:
raise RedisSearchError(
f"Error while fetching {name} index info: {str(e)}"
@@ -1425,7 +1459,15 @@ async def _validate_client(
@staticmethod
async def _info(name: str, redis_client: AsyncRedisClient) -> Dict[str, Any]:
try:
return convert_bytes(await redis_client.ft(name).info())
if isinstance(redis_client, AsyncRedisCluster):
node = redis_client.get_random_node()
values = await redis_client.execute_command(
"FT.INFO", name, target_nodes=node
)
info = make_dict(values)
else:
info = await redis_client.ft(name).info()
return convert_bytes(info)
except Exception as e:
raise RedisSearchError(
f"Error while fetching {name} index info: {str(e)}"
@@ -1549,6 +1591,34 @@ async def delete(self, drop: bool = True):
except Exception as e:
raise RedisSearchError(f"Error while deleting index: {str(e)}") from e

async def _delete_batch(self, batch_keys: List[str]) -> int:
"""Delete a batch of keys from Redis.

For Redis Cluster, keys are deleted individually due to potential
cross-slot limitations. For standalone Redis, keys are deleted in
a single operation for better performance.

Args:
batch_keys (List[str]): List of Redis keys to delete.

Returns:
int: Count of records deleted from Redis.
"""
client = await self._get_client()
is_cluster = isinstance(client, AsyncRedisCluster)
if is_cluster:
records_deleted_in_batch = 0
for key_to_delete in batch_keys:
try:
records_deleted_in_batch += cast(
int, await client.delete(key_to_delete)
)
except redis.exceptions.RedisError as e:
logger.warning(f"Failed to delete key {key_to_delete}: {e}")
else:
records_deleted_in_batch = await client.delete(*batch_keys)
return records_deleted_in_batch

async def clear(self) -> int:
"""Clear all keys in Redis associated with the index, leaving the index
available and in-place for future insertions or updates.
@@ -1560,28 +1630,24 @@ async def clear(self) -> int:
Returns:
int: Count of records deleted from Redis.
"""
client = await self._get_client()
batch_size = 500
max_ratio = 1.01

info = await self.info()
max_records_deleted = ceil(
info["num_docs"] * max_ratio
) # Allow removing a few extra documents inserted concurrently
total_records_deleted: int = 0
query = FilterQuery(FilterExpression("*"), return_fields=["id"])
query.paging(0, batch_size)

async for batch in self.paginate(
FilterQuery(FilterExpression("*"), return_fields=["id"]), page_size=500
):
batch_keys = [record["id"] for record in batch]
if batch_keys:
is_cluster = isinstance(client, AsyncRedisCluster)
if is_cluster:
records_deleted_in_batch = 0
for key_to_delete in batch_keys:
try:
records_deleted_in_batch += cast(
int, await client.delete(key_to_delete)
)
except redis.exceptions.RedisError as e:
logger.warning(f"Failed to delete key {key_to_delete}: {e}")
total_records_deleted += records_deleted_in_batch
else:
records_deleted = await client.delete(*batch_keys)
total_records_deleted += records_deleted
while True:
batch = await self._query(query)
if batch and total_records_deleted <= max_records_deleted:
batch_keys = [record["id"] for record in batch]
total_records_deleted += await self._delete_batch(batch_keys)
else:
break

return total_records_deleted
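A brief usage sketch of the async path, assuming `index` is an AsyncSearchIndex that has already been created and loaded:

```python
async def wipe(index) -> int:
    deleted = await index.clear()   # keys are removed in batches of 500
    assert await index.exists()     # the index definition itself is left in place
    return deleted
```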

@@ -2039,6 +2105,9 @@ async def paginate(self, query: BaseQuery, page_size: int = 30) -> AsyncGenerator:
batch contains. Adjust this value based on performance
considerations and the expected volume of search results.

Note:
For stable pagination, the query must have a `sort_by` clause.

"""
if not isinstance(page_size, int):
raise TypeError("page_size must be of type int")
13 changes: 10 additions & 3 deletions tests/integration/test_async_search_index.py
@@ -1,4 +1,5 @@
import warnings
from random import choice
from unittest import mock

import pytest
@@ -245,14 +246,20 @@ async def test_search_index_delete(async_index):


@pytest.mark.asyncio
async def test_search_index_clear(async_index):
@pytest.mark.parametrize("num_docs", [0, 1, 5, 10, 2042])
async def test_search_index_clear(async_index, num_docs):
await async_index.create(overwrite=True, drop=True)
data = [{"id": "1", "test": "foo"}]
tags = ["foo", "bar", "baz"]
data = [{"id": str(i), "test": choice(tags)} for i in range(num_docs)]
await async_index.load(data, id_field="id")
info = await async_index.info()
assert info["num_docs"] == num_docs

count = await async_index.clear()
assert count == len(data)
assert count == num_docs
assert await async_index.exists()
info = await async_index.info()
assert info["num_docs"] == 0


@pytest.mark.asyncio
40 changes: 40 additions & 0 deletions tests/integration/test_redis_cluster_support.py
@@ -70,6 +70,46 @@ def test_search_index_cluster_client(redis_cluster_url):
index.delete(drop=True)


@pytest.mark.requires_cluster
def test_search_index_cluster_info(redis_cluster_url):
"""Test .info() method on SearchIndex with RedisCluster client."""
schema = IndexSchema.from_dict(
{
"index": {"name": "test_cluster_info", "prefix": "test_info"},
"fields": [{"name": "name", "type": "text"}],
}
)
client = RedisCluster.from_url(redis_cluster_url)
index = SearchIndex(schema=schema, redis_client=client)
try:
index.create(overwrite=True)
info = index.info()
assert isinstance(info, dict)
assert info.get("index_name", None) == "test_cluster_info"
finally:
index.delete(drop=True)

@pytest.mark.requires_cluster
@pytest.mark.asyncio
async def test_async_search_index_cluster_info(redis_cluster_url):
"""Test .info() method on AsyncSearchIndex with AsyncRedisCluster client."""
schema = IndexSchema.from_dict(
{
"index": {"name": "async_cluster_info", "prefix": "async_info"},
"fields": [{"name": "name", "type": "text"}],
}
)
client = AsyncRedisCluster.from_url(redis_cluster_url)
index = AsyncSearchIndex(schema=schema, redis_client=client)
try:
await index.create(overwrite=True)
info = await index.info()
assert isinstance(info, dict)
assert info.get("index_name", None) == "async_cluster_info"
finally:
await index.delete(drop=True)
await client.aclose()
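These tests exercise the new cluster-aware `_info()` paths: FT.INFO is not a key-based command, so a cluster client cannot route it by slot, and the helper now targets one node explicitly and pairs up the flat reply. A minimal sketch of that call, assuming a cluster at the hypothetical URL redis://localhost:7000 and an existing index named test_cluster_info:

```python
from redis.cluster import RedisCluster

rc = RedisCluster.from_url("redis://localhost:7000", decode_responses=True)
node = rc.get_random_node()
raw = rc.execute_command("FT.INFO", "test_cluster_info", target_nodes=node)

# FT.INFO replies with a flat [name, value, name, value, ...] list; pairing
# adjacent items yields roughly the dict that make_dict() builds in _info().
info = dict(zip(raw[::2], raw[1::2]))
print(info.get("index_name"))
```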

@pytest.mark.requires_cluster
@pytest.mark.asyncio
async def test_async_search_index_client(redis_cluster_url):
12 changes: 9 additions & 3 deletions tests/integration/test_search_index.py
@@ -1,4 +1,5 @@
import warnings
from random import choice
from unittest import mock

import pytest
@@ -303,15 +304,20 @@ def test_search_index_delete(index):
assert not index.exists()
assert index.name not in convert_bytes(index.client.execute_command("FT._LIST"))


def test_search_index_clear(index):
@pytest.mark.parametrize("num_docs", [0, 1, 5, 10, 2042])
def test_search_index_clear(index, num_docs):
index.create(overwrite=True, drop=True)
data = [{"id": "1", "test": "foo"}]
tags = ["foo", "bar", "baz"]
data = [{"id": str(i), "test": choice(tags)} for i in range(num_docs)]
index.load(data, id_field="id")
info = index.info()
assert info["num_docs"] == num_docs

count = index.clear()
assert count == len(data)
assert index.exists()
info = index.info()
assert info["num_docs"] == 0


def test_search_index_drop_key(index):
Expand Down
31 changes: 23 additions & 8 deletions tests/unit/test_error_handling.py
@@ -454,10 +454,15 @@ def test_clear_individual_key_deletion_errors(self, mock_validate):
1, # Third succeeds
]

# Mock the paginate method to return test data
with patch.object(SearchIndex, "paginate") as mock_paginate:
mock_paginate.return_value = [
[{"id": "test:key1"}, {"id": "test:key2"}, {"id": "test:key3"}]
# Mock the .info() and ._query() methods to return test data
with (
patch.object(SearchIndex, "info") as mock_info,
patch.object(SearchIndex, "_query") as mock_query,
):
mock_info.return_value = {"num_docs": 3}
mock_query.side_effect = [
[{"id": "test:key1"}, {"id": "test:key2"}, {"id": "test:key3"}],
[],
]

# Create index with mocked client
@@ -502,11 +507,21 @@ async def test_async_clear_individual_key_deletion_errors(self, mock_validate):
]
)

# Mock the paginate method to return test data
async def mock_paginate_generator(*args, **kwargs):
yield [{"id": "test:key1"}, {"id": "test:key2"}, {"id": "test:key3"}]
# Mock the .info() and ._query() methods to return test data
async def mock_info(*args, **kwargs):
return {"num_docs": 3}

with patch.object(AsyncSearchIndex, "paginate", mock_paginate_generator):
mock_query = AsyncMock(
side_effect=[
[{"id": "test:key1"}, {"id": "test:key2"}, {"id": "test:key3"}],
[],
]
)

with (
patch.object(AsyncSearchIndex, "info", mock_info),
patch.object(AsyncSearchIndex, "_query", mock_query),
):
# Create index with mocked client
index = AsyncSearchIndex(schema)
index._redis_client = mock_cluster_client
Expand Down