Always parenthesize the filter expression before appending the KNN clause
(except for the "*" wildcard), so OR expressions combine correctly with KNN.
Adds a regression test for GitHub issue #557.

diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py
index fa728730..b1744168 100644
--- a/aredis_om/model/model.py
+++ b/aredis_om/model/model.py
@@ -906,11 +906,13 @@ def query(self):
             return self._query
         self._query = self._resolve_redisearch_query(self.expression)
         if self.knn:
-            self._query = (
-                self._query
-                if self._query.startswith("(") or self._query == "*"
-                else f"({self._query})"
-            ) + f"=>[{self.knn}]"
+            # Always wrap the filter expression in parentheses when combining with KNN,
+            # unless it's the wildcard "*". This ensures OR expressions like
+            # "(A)| (B)" become "((A)| (B))=>[KNN ...]" instead of the invalid
+            # "(A)| (B)=>[KNN ...]" where KNN only applies to the second term.
+            if self._query != "*":
+                self._query = f"({self._query})"
+            self._query += f"=>[{self.knn}]"
         # RETURN clause should be added to args, not to the query string
         return self._query
 
diff --git a/tests/test_knn_expression.py b/tests/test_knn_expression.py
index 1e836759..09a5d911 100644
--- a/tests/test_knn_expression.py
+++ b/tests/test_knn_expression.py
@@ -113,3 +113,91 @@ async def test_nested_vector_field(n: Type[JsonModel]):
 
     assert len(members) == 1
     assert members[0].embeddings_score is not None
+
+
+
+@pytest_asyncio.fixture
+async def album_model(key_prefix, redis):
+    """Fixture for testing OR expressions with KNN."""
+    class BaseJsonModel(JsonModel, abc.ABC):
+        class Meta:
+            global_key_prefix = key_prefix
+            database = redis
+
+    vector_options = VectorFieldOptions.flat(
+        type=VectorFieldOptions.TYPE.FLOAT32,
+        dimension=2,
+        distance_metric=VectorFieldOptions.DISTANCE_METRIC.COSINE,
+    )
+
+    class Album(BaseJsonModel, index=True):
+        title: str = Field(primary_key=True)
+        tags: str = Field(index=True)
+        title_embeddings: list[float] = Field(
+            [], index=True, vector_options=vector_options
+        )
+        embeddings_score: Optional[float] = None
+
+    await Migrator(conn=redis).run()
+
+    return Album
+
+
+@py_test_mark_asyncio
+async def test_or_expression_with_knn(album_model):
+    """Test that OR expressions work correctly with KNN.
+
+    Regression test for GitHub issue #557: Using an OR expression with a
+    KNN expression raises ResponseError with syntax error.
+    """
+    Album = album_model
+
+    # Create test data
+    albums = [
+        Album(
+            title="Rumours",
+            tags="Genre:rock|Decade:70s",
+            title_embeddings=[0.7, 0.3],
+        ),
+        Album(
+            title="Abbey Road",
+            tags="Genre:rock|Decade:60s",
+            title_embeddings=[0.6, 0.4],
+        ),
+        Album(
+            title="The Dark Side Of The Moon",
+            tags="Genre:prog-rock|Decade:70s",
+            title_embeddings=[0.5, 0.5],
+        ),
+    ]
+    for album in albums:
+        await album.save()
+
+    # Create OR expression
+    or_expr = (Album.tags == "Genre:rock|Decade:70s") | (
+        Album.tags == "Genre:rock|Decade:60s"
+    )
+
+    # Create KNN expression
+    knn = KNNExpression(
+        k=3,
+        vector_field=Album.title_embeddings,
+        score_field=Album.embeddings_score,
+        reference_vector=to_bytes([0.65, 0.35]),
+    )
+
+    # Query with just OR expression (should work)
+    or_results = await Album.find(or_expr).all()
+    assert len(or_results) == 2
+
+    # Query with just KNN (should work)
+    knn_results = await Album.find(knn=knn).all()
+    assert len(knn_results) == 3
+
+    # Query with OR expression AND KNN (this was failing before the fix)
+    combined_results = await Album.find(or_expr, knn=knn).all()
+    # Should return only the 2 albums matching the OR expression
+    assert len(combined_results) == 2
+    # All results should have an embeddings score from KNN
+    for result in combined_results:
+        assert result.embeddings_score is not None