From 26072365377f79e2d30a97bf168ab4a0c5e9f84e Mon Sep 17 00:00:00 2001
From: Helia Mohammadi
Date: Sat, 27 Sep 2025 17:51:45 -0400
Subject: [PATCH 01/13] feat: expand preprocessing into a multi-step workflow.

- Implement parallel execution of safety and scope checks, query expansion,
  and language detection
---
 src/agent/profiles/base.py                    |  85 ++++++++++----
 src/agent/profiles/react_to_me.py             | 106 +++++++++++++++---
 .../unsafe_question.py                        |  42 +++++++
 src/tools/preprocessing/__init__.py           |   6 +
 src/tools/preprocessing/state.py              |  20 ++++
 src/tools/preprocessing/workflow.py           |  80 +++++++++++++
 6 files changed, 298 insertions(+), 41 deletions(-)
 create mode 100644 src/agent/tasks/final_answer_generation/unsafe_question.py
 create mode 100644 src/tools/preprocessing/__init__.py
 create mode 100644 src/tools/preprocessing/state.py
 create mode 100644 src/tools/preprocessing/workflow.py

diff --git a/src/agent/profiles/base.py b/src/agent/profiles/base.py
index 9a6e26c..6536978 100644
--- a/src/agent/profiles/base.py
+++ b/src/agent/profiles/base.py
@@ -1,4 +1,4 @@
-from typing import Annotated, TypedDict
+from typing import Annotated, Literal, TypedDict
 
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
@@ -6,53 +6,88 @@
 from langchain_core.runnables import Runnable, RunnableConfig
 from langgraph.graph.message import add_messages
 
-from agent.tasks.rephrase import create_rephrase_chain
 from tools.external_search.state import SearchState, WebSearchResult
 from tools.external_search.workflow import create_search_workflow
+from tools.preprocessing.state import PreprocessingState
+from tools.preprocessing.workflow import create_preprocessing_workflow
+
+# Constants
+SAFETY_SAFE: Literal["true"] = "true"
+SAFETY_UNSAFE: Literal["false"] = "false"
+DEFAULT_LANGUAGE: str = "English"
 
 
 class AdditionalContent(TypedDict, total=False):
+    """Additional content sent on graph completion."""
+
     search_results: list[WebSearchResult]
 
 
 class InputState(TypedDict, total=False):
-    user_input: str  # User input text
+    """Input state for user queries."""
+
+    user_input: str
 
 
 class OutputState(TypedDict, total=False):
-    answer: str  # primary LLM response that is streamed to the user
-    additional_content: AdditionalContent  # sends on graph completion
+    """Output state for responses."""
+
+    answer: str
+    additional_content: AdditionalContent
 
 
 class BaseState(InputState, OutputState, total=False):
-    rephrased_input: str  # LLM-generated query from user input
+    """Base state containing all common fields for agent workflows."""
+
+    rephrased_input: str
    chat_history: Annotated[list[BaseMessage], add_messages]
 
+    # Preprocessing results
+    safety: str = SAFETY_SAFE
+    reason_unsafe: str = ""
+    expanded_queries: list[str] = []
+    detected_language: str = DEFAULT_LANGUAGE
+
 
 class BaseGraphBuilder:
-    # NOTE: Anything that is common to all graph builders goes here
-
-    def __init__(
-        self,
-        llm: BaseChatModel,
-        embedding: Embeddings,
-    ) -> None:
-        self.rephrase_chain: Runnable = create_rephrase_chain(llm)
+    """Base class for all graph builders with common preprocessing and postprocessing."""
+
+    def __init__(self, llm: BaseChatModel, embedding: Embeddings) -> None:
+        """Initialize with LLM and embedding models."""
+        self.preprocessing_workflow: Runnable = create_preprocessing_workflow(llm)
         self.search_workflow: Runnable = create_search_workflow(llm)
 
     async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
-        rephrased_input: str = await 
self.rephrase_chain.ainvoke( - { - "user_input": state["user_input"], - "chat_history": state["chat_history"], - }, + """Run the complete preprocessing workflow and map results to state.""" + result: PreprocessingState = await self.preprocessing_workflow.ainvoke( + PreprocessingState( + user_input=state["user_input"], + chat_history=state["chat_history"], + ), config, ) - return BaseState(rephrased_input=rephrased_input) + + return self._map_preprocessing_result(result) + + def _map_preprocessing_result(self, result: PreprocessingState) -> BaseState: + """Map preprocessing results to BaseState with defaults.""" + return BaseState( + rephrased_input=result["rephrased_input"], + safety=result.get("safety", SAFETY_SAFE), + reason_unsafe=result.get("reason_unsafe", ""), + expanded_queries=result.get("expanded_queries", []), + detected_language=result.get("detected_language", DEFAULT_LANGUAGE), + ) async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState: + """Postprocess that preserves existing state and conditionally adds search results.""" search_results: list[WebSearchResult] = [] - if config["configurable"]["enable_postprocess"]: + + # Only run external search for safe questions + if ( + state.get("safety") == SAFETY_SAFE + and config["configurable"]["enable_postprocess"] + ): result: SearchState = await self.search_workflow.ainvoke( SearchState( input=state["rephrased_input"], @@ -61,6 +96,10 @@ async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseSta config=RunnableConfig(callbacks=config["callbacks"]), ) search_results = result["search_results"] - return BaseState( - additional_content=AdditionalContent(search_results=search_results) + + # Create new state with updated additional_content + new_state = dict(state) # Copy existing state + new_state["additional_content"] = AdditionalContent( + search_results=search_results ) + return BaseState(**new_state) diff --git a/src/agent/profiles/react_to_me.py b/src/agent/profiles/react_to_me.py index c162ac7..f0e4980 100644 --- a/src/agent/profiles/react_to_me.py +++ b/src/agent/profiles/react_to_me.py @@ -1,52 +1,123 @@ -from typing import Any +from typing import Any, Literal from langchain_core.embeddings import Embeddings from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import AIMessage, HumanMessage from langchain_core.runnables import Runnable, RunnableConfig +from langchain_openai import ChatOpenAI from langgraph.graph.state import StateGraph -from agent.profiles.base import BaseGraphBuilder, BaseState +from agent.profiles.base import (SAFETY_SAFE, SAFETY_UNSAFE, BaseGraphBuilder, + BaseState) +from agent.tasks.final_answer_generation.unsafe_question import \ + create_unsafe_answer_generator from retrievers.reactome.rag import create_reactome_rag class ReactToMeState(BaseState): + """ReactToMe state extends BaseState with all preprocessing results.""" + pass class ReactToMeGraphBuilder(BaseGraphBuilder): - def __init__( - self, - llm: BaseChatModel, - embedding: Embeddings, - ) -> None: + """Graph builder for ReactToMe profile with Reactome-specific functionality.""" + + def __init__(self, llm: BaseChatModel, embedding: Embeddings) -> None: + """Initialize ReactToMe graph builder with required components.""" super().__init__(llm, embedding) - # Create runnables (tasks & tools) + # Create a streaming LLM instance only for final answer generation + streaming_llm = ChatOpenAI( + model=llm.model_name if hasattr(llm, "model_name") else 
"gpt-4o-mini", + temperature=0.0, + streaming=True, + ) + + self.unsafe_answer_generator = create_unsafe_answer_generator(streaming_llm) self.reactome_rag: Runnable = create_reactome_rag( - llm, embedding, streaming=True + streaming_llm, embedding, streaming=True ) - # Create graph + self.uncompiled_graph: StateGraph = self._build_workflow() + + def _build_workflow(self) -> StateGraph: + """Build and configure the ReactToMe workflow graph.""" state_graph = StateGraph(ReactToMeState) - # Set up nodes + + # Add workflow nodes state_graph.add_node("preprocess", self.preprocess) state_graph.add_node("model", self.call_model) + state_graph.add_node("generate_unsafe_response", self.generate_unsafe_response) state_graph.add_node("postprocess", self.postprocess) - # Set up edges + + # Configure workflow edges state_graph.set_entry_point("preprocess") - state_graph.add_edge("preprocess", "model") + state_graph.add_conditional_edges( + "preprocess", + self.proceed_with_research, + {"Continue": "model", "Finish": "generate_unsafe_response"}, + ) state_graph.add_edge("model", "postprocess") + state_graph.add_edge("generate_unsafe_response", "postprocess") state_graph.set_finish_point("postprocess") - self.uncompiled_graph: StateGraph = state_graph + return state_graph + + async def preprocess( + self, state: ReactToMeState, config: RunnableConfig + ) -> ReactToMeState: + """Run preprocessing workflow.""" + result = await super().preprocess(state, config) + return ReactToMeState(**result) + + async def proceed_with_research( + self, state: ReactToMeState + ) -> Literal["Continue", "Finish"]: + """Determine whether to proceed with research based on safety check.""" + return "Continue" if state["safety"] == SAFETY_SAFE else "Finish" + + async def generate_unsafe_response( + self, state: ReactToMeState, config: RunnableConfig + ) -> ReactToMeState: + """Generate appropriate refusal response for unsafe queries.""" + final_answer_message = await self.unsafe_answer_generator.ainvoke( + { + "language": state["detected_language"], + "user_input": state["rephrased_input"], + "reason_unsafe": state["reason_unsafe"], + }, + config, + ) + + final_answer = ( + final_answer_message.content + if hasattr(final_answer_message, "content") + else str(final_answer_message) + ) + + return ReactToMeState( + chat_history=[ + HumanMessage(state["user_input"]), + ( + final_answer_message + if hasattr(final_answer_message, "content") + else AIMessage(final_answer) + ), + ], + answer=final_answer, + safety=SAFETY_UNSAFE, + additional_content={"search_results": []}, + ) async def call_model( self, state: ReactToMeState, config: RunnableConfig ) -> ReactToMeState: + """Generate response using Reactome RAG for safe queries.""" result: dict[str, Any] = await self.reactome_rag.ainvoke( { "input": state["rephrased_input"], + "expanded_queries": state.get("expanded_queries", []), "chat_history": ( state["chat_history"] if state["chat_history"] @@ -55,6 +126,7 @@ async def call_model( }, config, ) + return ReactToMeState( chat_history=[ HumanMessage(state["user_input"]), @@ -64,8 +136,6 @@ async def call_model( ) -def create_reactome_graph( - llm: BaseChatModel, - embedding: Embeddings, -) -> StateGraph: +def create_reactome_graph(llm: BaseChatModel, embedding: Embeddings) -> StateGraph: + """Create and return the ReactToMe workflow graph.""" return ReactToMeGraphBuilder(llm, embedding).uncompiled_graph diff --git a/src/agent/tasks/final_answer_generation/unsafe_question.py b/src/agent/tasks/final_answer_generation/unsafe_question.py 
new file mode 100644
index 0000000..5369332
--- /dev/null
+++ b/src/agent/tasks/final_answer_generation/unsafe_question.py
@@ -0,0 +1,42 @@
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import Runnable
+
+# unsafe or out of scope answer generator
+def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable:
+    """
+    Create an unsafe answer generator chain.
+
+    Args:
+        llm: Language model to use
+
+    Returns:
+        Runnable that takes language, user_input, and reason_unsafe and returns a refusal message
+    """
+    system_prompt = """
+    You are an expert scientific assistant operating under the React-to-Me platform. React-to-Me helps both experts and non-experts explore molecular biology using trusted data from the Reactome database.
+
+You have advanced training in scientific ethics, dual-use research concerns, and responsible AI use.
+
+You will receive three inputs:
+1. The user's question.
+2. A system-generated variable called `reason_unsafe`, which explains why the question cannot be answered.
+3. The user's preferred language (as a language code or name).
+
+Your task is to clearly, respectfully, and firmly explain to the user *why* their question cannot be answered, based solely on the `reason_unsafe` input. Do **not** attempt to answer, rephrase, or guide the user toward answering the original question.
+
+You must:
+- Respond in the user’s preferred language.
+- Politely explain the refusal, grounded in the `reason_unsafe`.
+- Emphasize React-to-Me’s mission: to support responsible exploration of molecular biology through trusted databases.
+- Suggest examples of appropriate topics (e.g., protein function, pathways, gene interactions using Reactome/UniProt).
+
+You must not provide any workaround, implicit answer, or redirection toward unsafe content. 
+""" + prompt = ChatPromptTemplate.from_messages([ + ("system", system_prompt), + ("user", "Language:{language}\n\nQuestion:{user_input}\n\n Reason for unsafe or out of scope: {reason_unsafe}") + ]) + + return prompt | llm \ No newline at end of file diff --git a/src/tools/preprocessing/__init__.py b/src/tools/preprocessing/__init__.py new file mode 100644 index 0000000..f2eb429 --- /dev/null +++ b/src/tools/preprocessing/__init__.py @@ -0,0 +1,6 @@ +"""Preprocessing workflow for query enhancement and validation.""" + +from .state import PreprocessingState +from .workflow import create_preprocessing_workflow + +__all__ = ["create_preprocessing_workflow", "PreprocessingState"] diff --git a/src/tools/preprocessing/state.py b/src/tools/preprocessing/state.py new file mode 100644 index 0000000..34d162f --- /dev/null +++ b/src/tools/preprocessing/state.py @@ -0,0 +1,20 @@ +from typing import TypedDict + +from langchain_core.messages import BaseMessage + + +class PreprocessingState(TypedDict, total=False): + """State for the preprocessing workflow.""" + + # Input + user_input: str # Original user input + chat_history: list[BaseMessage] # Conversation history + + # Step 1: Rephrase and incorporate conversation history + rephrased_input: str # Standalone question with context + + # Step 2: Parallel processing + safety: str # "true" or "false" from safety check + reason_unsafe: str # Reason if unsafe + expanded_queries: list[str] # Alternative queries for better retrieval + detected_language: str # Detected language (e.g., "English", "French") diff --git a/src/tools/preprocessing/workflow.py b/src/tools/preprocessing/workflow.py new file mode 100644 index 0000000..889fcb7 --- /dev/null +++ b/src/tools/preprocessing/workflow.py @@ -0,0 +1,80 @@ +from typing import Any, Callable + +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.runnables import Runnable, RunnableConfig +from langgraph.graph import StateGraph +from langgraph.graph.state import CompiledStateGraph +from langgraph.utils.runnable import RunnableLike + +from agent.tasks.detect_language import create_language_detector +from agent.tasks.query_expansion import create_query_expander +from agent.tasks.rephrase import create_rephrase_chain +from agent.tasks.safety_checker import create_safety_checker +from tools.preprocessing.state import PreprocessingState + + +def create_task_wrapper( + task: Runnable, + input_mapper: Callable[[PreprocessingState], dict[str, Any]], + output_mapper: Callable[[Any], PreprocessingState], +) -> RunnableLike: + """Generic wrapper for preprocessing tasks.""" + + async def _wrapper( + state: PreprocessingState, config: RunnableConfig + ) -> PreprocessingState: + result = await task.ainvoke(input_mapper(state), config) + return output_mapper(result) + + return _wrapper + + +def create_preprocessing_workflow(llm: BaseChatModel) -> CompiledStateGraph: + """Create a preprocessing workflow with rephrasing and parallel processing.""" + + # Task configurations + tasks = { + "rephrase_query": ( + create_rephrase_chain(llm), + lambda state: { + "user_input": state["user_input"], + "chat_history": state["chat_history"], + }, + lambda result: PreprocessingState(rephrased_input=result), + ), + "safety_check": ( + create_safety_checker(llm), + lambda state: {"rephrased_input": state["rephrased_input"]}, + lambda result: PreprocessingState( + safety=result.safety, reason_unsafe=result.reason_unsafe + ), + ), + "query_expansion": ( + create_query_expander(llm), + lambda state: 
{"rephrased_input": state["rephrased_input"]}, + lambda result: PreprocessingState(expanded_queries=result), + ), + "detect_language": ( + create_language_detector(llm), + lambda state: {"user_input": state["user_input"]}, + lambda result: PreprocessingState(detected_language=result), + ), + } + + workflow = StateGraph(PreprocessingState) + + # Add nodes + for node_name, (task, input_mapper, output_mapper) in tasks.items(): + workflow.add_node( + node_name, create_task_wrapper(task, input_mapper, output_mapper) + ) + + # Configure workflow + workflow.set_entry_point("rephrase_query") + + # Parallel execution after rephrasing + for parallel_node in ["safety_check", "query_expansion", "detect_language"]: + workflow.add_edge("rephrase_query", parallel_node) + workflow.set_finish_point(parallel_node) + + return workflow.compile() From 8b3502967183de0baf29797dbc63572e96086438 Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Sat, 27 Sep 2025 18:00:06 -0400 Subject: [PATCH 02/13] feat: Add new runnables for checking question safety and scope, query expansion and conversation history management --- src/agent/graph.py | 2 +- src/agent/models.py | 2 +- src/agent/tasks/rephrase.py | 5 +- src/agent/tasks/safety_checker.py | 86 ++++++++++++++++++++++++------- 4 files changed, 70 insertions(+), 25 deletions(-) diff --git a/src/agent/graph.py b/src/agent/graph.py index 012df27..2fab414 100644 --- a/src/agent/graph.py +++ b/src/agent/graph.py @@ -97,4 +97,4 @@ async def ainvoke( }, ), ) - return result + return result \ No newline at end of file diff --git a/src/agent/models.py b/src/agent/models.py index 01b324c..b80413a 100644 --- a/src/agent/models.py +++ b/src/agent/models.py @@ -65,4 +65,4 @@ def get_llm( base_url=base_url, ) else: - raise ValueError(f"Unknown provider: {provider}") + raise ValueError(f"Unknown provider: {provider}") \ No newline at end of file diff --git a/src/agent/tasks/rephrase.py b/src/agent/tasks/rephrase.py index 1851747..7c1be6a 100644 --- a/src/agent/tasks/rephrase.py +++ b/src/agent/tasks/rephrase.py @@ -13,8 +13,7 @@ the returned question should always be in English. If the user’s question is already in English, self-contained and well-formed, return it as is. -Do NOT answer the question or provide explanations. -""" +Do NOT answer the question or provide explanations.""" contextualize_q_prompt = ChatPromptTemplate.from_messages( [ @@ -28,4 +27,4 @@ def create_rephrase_chain(llm: BaseChatModel) -> Runnable: return (contextualize_q_prompt | llm | StrOutputParser()).with_config( run_name="rephrase_question" - ) + ) \ No newline at end of file diff --git a/src/agent/tasks/safety_checker.py b/src/agent/tasks/safety_checker.py index ac372e3..5e080d6 100644 --- a/src/agent/tasks/safety_checker.py +++ b/src/agent/tasks/safety_checker.py @@ -1,30 +1,76 @@ +from langchain.prompts import ChatPromptTemplate from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import Runnable from pydantic import BaseModel, Field -safety_check_message = """ -You are an expert grader with a strong understanding of ethical guidelines and responsible AI usage. - -Your task is to determine whether a user-submitted question is appropriate and not harmful. -Provide a binary output as either: -Yes: The question is appropriate and not harmful. -No: The question is harmful, offensive, inappropriate, or unethical. 
-"""
-
-safety_check_prompt = ChatPromptTemplate.from_messages(
-    [
-        ("system", safety_check_message),
-        ("human", "User question: \n\n {input}"),
-    ]
-)
-
 
 class SafetyCheck(BaseModel):
-    binary_score: str = Field(
-        description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'Yes' or 'No'."
+    safety: str = Field(
+        description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'true' or 'false'."
+    )
+    reason_unsafe: str = Field(
+        description="If 'safety' is false, briefly state the reason; if 'safety' is true, leave this field empty."
     )
 
 
 def create_safety_checker(llm: BaseChatModel) -> Runnable:
-    return safety_check_prompt | llm.with_structured_output(SafetyCheck)
+    """
+    Create a safety checker chain.
+
+    Args:
+        llm: Language model to use
+
+    Returns:
+        Runnable that takes user_input and returns SafetyCheck
+    """
+    safety_check_message = """
+You are an expert scientific assistant. You have advanced training in scientific ethics, dual-use assessment, and responsible AI.
+
+Tasks:
+
+1. Safety & Appropriateness Check
+   - Assess the question and determine if it is safe and appropriate for responsible AI use.
+   - The question and any plausible answer to it must NOT be harmful, offensive, unethical, or violate scientific or ethical guidelines.
+   - The question or its likely answers must NOT request, describe, or explore how to carry out harmful, illegal, dangerous, or dual-use activities (such as modifying pathogens, increasing infectivity, or other forms of misuse)—even if the scenario is hypothetical, fictional, or indirect.
+   - The question must NOT request personal, medical, legal, or other types of advice outside the scope of scientific knowledge.
+   - If a question is ambiguous or could plausibly be used for harm, treat it as unsafe.
+   - Treat hypothetical, fictional, or made-up scenarios with the same level of scrutiny as real-world questions.
+
+2. Reactome Relevance Check
+   - Determine if the question is relevant to biology, life sciences, molecular biology, or related topics.
+   - Mark questions as not relevant if they are about unrelated topics (such as programming, math, history, trivia, etc.).
+
+IMPORTANT:
+   - If the standalone question is unsafe or not relevant, return "safety": "false".
+   - If the standalone question is both safe **and** relevant, return "safety": "true".
+
+Return only a JSON object in the following format:
+   "safety": "true" or "false",  // Use string, not boolean. Mark as "false" if unsafe OR not relevant.
+   "reason_unsafe": "..."  // If 'safety' is false, briefly state the reason
+
+Examples:
+  1. Q: Describe methods for bypassing cellular immune responses to viral infection
+     "safety": "false",
+     "reason_unsafe": "Could be used for harmful or unethical manipulation of immune evasion."
+
+  2. Q: What are common side effects of statins in the treatment of high cholesterol?
+     "safety": "false",
+     "reason_unsafe": "Requests medical advice, which is out of scientific scope."
+
+  3. Q: How can gene drives be used to eliminate invasive species in an ecosystem?
+     "safety": "false",
+     "reason_unsafe": "Enables the use of gene drives, which have significant ecological and dual-use risks."
+
+  4. Q: What is the role of the immune system in the treatment of cancer? 
+ "safety": "true", + "reason_unsafe": "" +""" + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", safety_check_message), + ("human", "User question: \n\n {rephrased_input}"), + ] + ) + + return prompt | llm.with_structured_output(SafetyCheck) \ No newline at end of file From 8b2578f9526880920e323b953647c23f84c99b4e Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Sat, 27 Sep 2025 18:04:13 -0400 Subject: [PATCH 03/13] feat:improved hybrid retrieval - Replace SelfQueryRetriever with efficient hybrid search (BM25 + vector) - Add RRF (Reciprocal Rank Fusion) support for query expansion - Implement parallel processing for improved performance --- src/retrievers/csv_chroma.py | 350 ++++++++++++++++++++++++++---- src/retrievers/rag_chain.py | 73 +++++-- src/retrievers/reactome/prompt.py | 43 ++-- src/retrievers/reactome/rag.py | 34 +-- src/retrievers/retrieval_utils.py | 27 +++ src/retrievers/uniprot/prompt.py | 2 +- src/retrievers/uniprot/rag.py | 32 +-- 7 files changed, 448 insertions(+), 113 deletions(-) create mode 100644 src/retrievers/retrieval_utils.py diff --git a/src/retrievers/csv_chroma.py b/src/retrievers/csv_chroma.py index 691b884..15a1a99 100644 --- a/src/retrievers/csv_chroma.py +++ b/src/retrievers/csv_chroma.py @@ -1,70 +1,326 @@ +import asyncio +import hashlib +import logging from pathlib import Path +from typing import Any, Dict, List, Optional, Union import chromadb.config -from langchain.chains.query_constructor.schema import AttributeInfo -from langchain.retrievers import EnsembleRetriever -from langchain.retrievers.merger_retriever import MergerRetriever -from langchain.retrievers.self_query.base import SelfQueryRetriever +import pandas as pd from langchain_chroma.vectorstores import Chroma -from langchain_community.document_loaders.csv_loader import CSVLoader from langchain_community.retrievers import BM25Retriever +from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.retrievers import BaseRetriever -from nltk.tokenize import word_tokenize -chroma_settings = chromadb.config.Settings(anonymized_telemetry=False) +from retrievers.retrieval_utils import reciprocal_rank_fusion +logger = logging.getLogger(__name__) -def list_chroma_subdirectories(directory: Path) -> list[str]: - subdirectories = list( +CHROMA_SETTINGS = chromadb.config.Settings(anonymized_telemetry=False) +DEFAULT_RETRIEVAL_K = 20 +RRF_FINAL_K = 10 +RRF_LAMBDA_MULTIPLIER = 60.0 +EXCLUDED_CONTENT_COLUMNS = {"st_id"} + + +def create_documents_from_csv(csv_path: Path) -> List[Document]: + """Create Document objects from CSV file with proper metadata extraction.""" + if not csv_path.exists(): + raise FileNotFoundError(f"CSV file not found: {csv_path}") + + try: + df = pd.read_csv(csv_path) + if df.empty: + raise ValueError(f"CSV file is empty: {csv_path}") + except Exception as e: + raise ValueError(f"Failed to read CSV file {csv_path}: {e}") + + documents = [] + + for index, row in df.iterrows(): + content_parts = [] + for column in df.columns: + if column not in EXCLUDED_CONTENT_COLUMNS: + value = str(row[column]) if pd.notna(row[column]) else "" + if value and value != "nan": + content_parts.append(f"{column}: {value}") + + page_content = "\n".join(content_parts) + + metadata = { + column: str(value) + for column in df.columns + for value in [row[column]] + if pd.notna(value) and str(value) != "nan" + } + metadata.update({"source": str(csv_path), "row_index": index}) 
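+        # metadata now mirrors every non-null CSV column as a string; "source"
+        # and "row_index" let each retrieved chunk be traced back to its
+        # originating CSV row.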
+ + documents.append(Document(page_content=page_content, metadata=metadata)) + + return documents + + +def list_chroma_subdirectories(directory: Path) -> List[str]: + """Discover all subdirectories containing ChromaDB files.""" + if not directory.exists(): + raise ValueError(f"Directory does not exist: {directory}") + + subdirectories = [ chroma_file.parent.name for chroma_file in directory.glob("*/chroma.sqlite3") - ) + ] + + if not subdirectories: + logger.warning(f"No ChromaDB subdirectories found in {directory}") + return subdirectories -def create_bm25_chroma_ensemble_retriever( - llm: BaseChatModel, - embedding: Embeddings, - embeddings_directory: Path, - *, - descriptions_info: dict[str, str], - field_info: dict[str, list[AttributeInfo]], -) -> MergerRetriever: - retriever_list: list[BaseRetriever] = [] - for subdirectory in list_chroma_subdirectories(embeddings_directory): - # set up BM25 retriever - csv_file_name = subdirectory + ".csv" - reactome_csvs_dir: Path = embeddings_directory / "csv_files" - loader = CSVLoader(file_path=reactome_csvs_dir / csv_file_name) - data = loader.load() - bm25_retriever = BM25Retriever.from_documents( - data, - preprocess_func=lambda text: word_tokenize( - text.casefold(), language="english" - ), +class HybridRetriever: + """Advanced hybrid retriever supporting RRF, parallel processing, and multi-source search.""" + + def __init__(self, embedding: Embeddings, embeddings_directory: Path): + self.embedding = embedding + self.embeddings_directory = embeddings_directory + self._retrievers: Dict[ + str, Dict[str, Optional[Union[BM25Retriever, object]]] + ] = {} + + try: + self._initialize_retrievers() + except Exception as e: + logger.error(f"Failed to initialize hybrid retriever: {e}") + raise RuntimeError(f"Hybrid retriever initialization failed: {e}") from e + + def _initialize_retrievers(self) -> None: + """Initialize BM25 and vector retrievers for all discovered subdirectories.""" + subdirectories = list_chroma_subdirectories(self.embeddings_directory) + + if not subdirectories: + raise ValueError(f"No subdirectories found in {self.embeddings_directory}") + + for subdirectory in subdirectories: + bm25_retriever = self._create_bm25_retriever(subdirectory) + vector_retriever = self._create_vector_retriever(subdirectory) + + self._retrievers[subdirectory] = { + "bm25": bm25_retriever, + "vector": vector_retriever, + } + + logger.info(f"Initialized retrievers for {len(subdirectories)} subdirectories") + + def _create_bm25_retriever(self, subdirectory: str) -> Optional[BM25Retriever]: + """Create BM25 retriever for a specific subdirectory.""" + csv_path = self.embeddings_directory / "csv_files" / f"{subdirectory}.csv" + + if not csv_path.exists(): + logger.warning(f"CSV file not found for {subdirectory}: {csv_path}") + return None + + try: + documents = create_documents_from_csv(csv_path) + retriever = BM25Retriever.from_documents(documents) + retriever.k = DEFAULT_RETRIEVAL_K + logger.debug( + f"Created BM25 retriever for {subdirectory} with {len(documents)} documents" + ) + return retriever + except Exception as e: + logger.error(f"Failed to create BM25 retriever for {subdirectory}: {e}") + return None + + def _create_vector_retriever(self, subdirectory: str) -> Optional[object]: + """Create vector retriever for a specific subdirectory.""" + vector_directory = self.embeddings_directory / subdirectory + + if not vector_directory.exists(): + logger.warning( + f"Vector directory not found for {subdirectory}: {vector_directory}" + ) + return None + + try: 
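+            # Open the Chroma collection persisted under this subdirectory;
+            # the embedding function must match the one used when the
+            # collection was built.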
+ vector_store = Chroma( + persist_directory=str(vector_directory), + embedding_function=self.embedding, + client_settings=CHROMA_SETTINGS, + ) + retriever = vector_store.as_retriever( + search_kwargs={"k": DEFAULT_RETRIEVAL_K} + ) + logger.debug(f"Created vector retriever for {subdirectory}") + return retriever + except Exception as e: + logger.error(f"Failed to create vector retriever for {subdirectory}: {e}") + return None + + async def _search_with_bm25( + self, query: str, retriever: BM25Retriever + ) -> List[Document]: + """Search using BM25 retriever asynchronously.""" + return await asyncio.to_thread(retriever.get_relevant_documents, query) + + async def _search_with_vector( + self, query: str, retriever: object + ) -> List[Document]: + """Search using vector retriever asynchronously.""" + return await asyncio.to_thread(retriever.get_relevant_documents, query) + + async def _execute_hybrid_search( + self, query: str, subdirectory: str + ) -> List[Document]: + """Execute hybrid search (BM25 + vector) for a single query on a subdirectory.""" + retriever_info = self._retrievers.get(subdirectory) + if not retriever_info: + logger.warning(f"No retrievers found for subdirectory: {subdirectory}") + return [] + + search_tasks = [] + + if retriever_info["bm25"]: + search_tasks.append(self._search_with_bm25(query, retriever_info["bm25"])) + + if retriever_info["vector"]: + search_tasks.append( + self._search_with_vector(query, retriever_info["vector"]) + ) + + if not search_tasks: + logger.warning(f"No active retrievers for subdirectory: {subdirectory}") + return [] + + try: + search_results = await asyncio.gather(*search_tasks, return_exceptions=True) + + combined_documents = [] + for result in search_results: + if isinstance(result, list): + combined_documents.extend(result) + elif isinstance(result, Exception): + logger.error(f"Search error in {subdirectory}: {result}") + + return combined_documents + except Exception as e: + logger.error(f"Failed to execute hybrid search for {subdirectory}: {e}") + return [] + + def _generate_document_identifier(self, document: Document) -> str: + """Generate unique identifier for a document.""" + for field in ["url", "id", "st_id"]: + if document.metadata.get(field): + return document.metadata[field] + + return hashlib.md5(document.page_content.encode()).hexdigest() + + async def _apply_reciprocal_rank_fusion( + self, queries: List[str], subdirectory: str + ) -> List[Document]: + """Apply Reciprocal Rank Fusion to results from multiple queries on a subdirectory.""" + logger.info( + f"Executing hybrid search for {len(queries)} queries in {subdirectory}" ) - bm25_retriever.k = 10 - # set up vectorstore SelfQuery retriever - vectordb = Chroma( - persist_directory=str(embeddings_directory / subdirectory), - embedding_function=embedding, - client_settings=chroma_settings, + search_tasks = [ + self._execute_hybrid_search(query, subdirectory) for query in queries + ] + all_search_results = await asyncio.gather(*search_tasks, return_exceptions=True) + + valid_result_sets = [] + for i, result in enumerate(all_search_results): + if isinstance(result, list): + valid_result_sets.append(result) + logger.debug(f"Query {i+1}: {len(result)} results") + elif isinstance(result, Exception): + logger.error(f"Query {i+1} failed: {result}") + + if not valid_result_sets: + logger.warning(f"No valid results for {subdirectory}") + return [] + + logger.info( + f"Applying RRF to {len(valid_result_sets)} result sets in {subdirectory}" ) - selfq_retriever = 
SelfQueryRetriever.from_llm( - llm=llm, - vectorstore=vectordb, - document_contents=descriptions_info[subdirectory], - metadata_field_info=field_info[subdirectory], - search_kwargs={"k": 10}, + top_documents, _, rrf_scores = reciprocal_rank_fusion( + ranked_lists=valid_result_sets, + final_k=RRF_FINAL_K, + lambda_mult=RRF_LAMBDA_MULTIPLIER, + rrf_k=None, + id_getter=self._generate_document_identifier, + ) + + logger.info(f"RRF completed for {subdirectory}: {len(top_documents)} documents") + if rrf_scores: + top_scores = dict(list(rrf_scores.items())[:3]) + logger.debug(f"Top RRF scores: {top_scores}") + + return top_documents + + async def ainvoke(self, inputs: Dict[str, Any]) -> str: + """Main retrieval method supporting RRF and parallel processing.""" + original_query = inputs.get("input", "").strip() + if not original_query: + raise ValueError("Input query cannot be empty") + + expanded_queries = inputs.get("expanded_queries", []) + all_queries = [original_query] + (expanded_queries or []) + + logger.info( + f"Processing {len(all_queries)} queries across {len(self._retrievers)} subdirectories" ) - rrf_retriever = EnsembleRetriever( - retrievers=[bm25_retriever, selfq_retriever], weights=[0.2, 0.8] + for i, query in enumerate(all_queries, 1): + logger.debug(f"Query {i}: {query}") + + rrf_tasks = [ + self._apply_reciprocal_rank_fusion(all_queries, subdirectory) + for subdirectory in self._retrievers.keys() + ] + + subdirectory_results = await asyncio.gather(*rrf_tasks, return_exceptions=True) + + context_parts = [] + + for i, subdirectory in enumerate(self._retrievers.keys()): + result = subdirectory_results[i] + + if isinstance(result, Exception): + logger.error(f"Subdirectory {subdirectory} failed: {result}") + continue + + if isinstance(result, list) and result: + for document in result: + context_parts.append(document.page_content) + context_parts.append("") + + final_context = "\n".join(context_parts) + + total_documents = sum( + len(result) if isinstance(result, list) else 0 + for result in subdirectory_results ) - retriever_list.append(rrf_retriever) + logger.info(f"Retrieved {total_documents} documents total") + + for i, subdirectory in enumerate(self._retrievers.keys()): + result = subdirectory_results[i] + if isinstance(result, list): + logger.info(f"{subdirectory}: {len(result)} documents") + else: + logger.warning(f"{subdirectory}: Failed") - reactome_retriever = MergerRetriever(retrievers=retriever_list) + logger.debug(f"Final context length: {len(final_context)} characters") - return reactome_retriever + return final_context + + +def create_hybrid_retriever( + embedding: Embeddings, embeddings_directory: Path +) -> HybridRetriever: + """Create a hybrid retriever with RRF and parallel processing support.""" + try: + return HybridRetriever( + embedding=embedding, embeddings_directory=embeddings_directory + ) + except Exception as e: + logger.error(f"Failed to create hybrid retriever: {e}") + raise RuntimeError(f"Hybrid retriever creation failed: {e}") from e \ No newline at end of file diff --git a/src/retrievers/rag_chain.py b/src/retrievers/rag_chain.py index 3e5df8e..5b78b66 100644 --- a/src/retrievers/rag_chain.py +++ b/src/retrievers/rag_chain.py @@ -1,26 +1,63 @@ -from langchain.chains.combine_documents import create_stuff_documents_chain -from langchain.chains.retrieval import create_retrieval_chain +from pathlib import Path +from typing import Any, Dict + +from langchain_core.embeddings import Embeddings from langchain_core.language_models.chat_models import 
BaseChatModel -from langchain_core.prompts import ChatPromptTemplate -from langchain_core.retrievers import BaseRetriever -from langchain_core.runnables import Runnable +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables import Runnable, RunnableLambda + +from retrievers.csv_chroma import create_hybrid_retriever -def create_rag_chain( +def create_advanced_rag_chain( llm: BaseChatModel, - retriever: BaseRetriever, - qa_prompt: ChatPromptTemplate, + embedding: Embeddings, + embeddings_directory: Path, + system_prompt: str, + *, + streaming: bool = False, ) -> Runnable: - # Create the documents chain - question_answer_chain: Runnable = create_stuff_documents_chain( - llm=llm, - prompt=qa_prompt, - ) + """ + Create an advanced RAG chain with hybrid retrieval, query expansion, and streaming support. + + Args: + llm: Language model for generation + embedding: Embedding model for retrieval + embeddings_directory: Directory containing embeddings and CSV files + system_prompt: System prompt for the LLM + streaming: Whether to enable streaming responses + + Returns: + Runnable RAG chain + """ + retriever = create_hybrid_retriever(embedding, embeddings_directory) - # Create the retrieval chain - rag_chain: Runnable = create_retrieval_chain( - retriever=retriever, - combine_docs_chain=question_answer_chain, + prompt = ChatPromptTemplate.from_messages( + [ + ("system", system_prompt), + MessagesPlaceholder(variable_name="chat_history"), + ("user", "Context:\n{context}\n\nQuestion: {input}"), + ] ) - return rag_chain + if streaming: + llm = llm.model_copy(update={"streaming": True}) + + async def rag_chain(inputs: Dict[str, Any]) -> Dict[str, Any]: + user_input = inputs["input"] + chat_history = inputs.get("chat_history", []) + expanded_queries = inputs.get("expanded_queries", []) + + context = await retriever.ainvoke( + {"input": user_input, "expanded_queries": expanded_queries} + ) + + response = await llm.ainvoke( + prompt.format_messages( + context=context, input=user_input, chat_history=chat_history + ) + ) + + return {"answer": response.content, "context": context} + + return RunnableLambda(rag_chain) \ No newline at end of file diff --git a/src/retrievers/reactome/prompt.py b/src/retrievers/reactome/prompt.py index 9a11526..68af525 100644 --- a/src/retrievers/reactome/prompt.py +++ b/src/retrievers/reactome/prompt.py @@ -1,25 +1,32 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder reactome_system_prompt = """ -You are an expert in molecular biology with access to the Reactome Knowledgebase. -Your primary responsibility is to answer the user's questions comprehensively, accurately, and in an engaging manner, based strictly on the context provided from the Reactome Knowledgebase. -Provide any useful background information required to help the user better understand the significance of the answer. -Always provide citations and links to the documents you obtained the information from. +You are an expert in molecular biology with access to the **Reactome Knowledgebase**. +Your primary responsibility is to answer the user's questions **comprehensively, mechanistically, and with precision**, drawing strictly from the **Reactome Knowledgebase**. -When providing answers, please adhere to the following guidelines: -1. Provide answers **strictly based on the given context from the Reactome Knowledgebase**. Do **not** use or infer information from any external sources. -2. 
If the answer cannot be derived from the context provided, do **not** answer the question; instead explain that the information is not currently available in Reactome. -3. Answer the question comprehensively and accurately, providing useful background information based **only** on the context. -4. keep track of **all** the sources that are directly used to derive the final answer, ensuring **every** piece of information in your response is **explicitly cited**. -5. Create Citations for the sources used to generate the final asnwer according to the following: - - For Reactome always format citations in the following format: *Source_Name*, where *Source_Name* is the name of the retrieved document. - Examples: - - Apoptosis - - Cell Cycle +Your output must emphasize biological processes, molecular complexes, regulatory mechanisms, and interactions most relevant to the user’s question. +Provide an information-rich narrative that explains not only what is happening but also how and why, based only on Reactome context. -6. Always provide the citations you created in the format requested, in point-form at the end of the response paragraph, ensuring **every piece of information** provided in the final answer is cited. -7. Write in a conversational and engaging tone suitable for a chatbot. -8. Use clear, concise language to make complex topics accessible to a wide audience. + +## **Answering Guidelines** +1. Strict source discipline: Use only the information explicitly provided from Reactome. Do not invent, infer, or draw from external knowledge. + - If the answer cannot be derived from the context, explicitly state that the information is not currently available in Reactome. +2. Inline citations required: Every factual statement must include ≥1 inline anchor citation in the format: display_name + - If multiple entries support the same fact, cite them together (space-separated). +3. Comprehensiveness: Capture all mechanistically relevant details available in Reactome, focusing on processes, complexes, regulations, and interactions. +4. Tone & Style: + - Write in a clear, engaging, and conversational tone. + - Use accessible language while maintaining technical precision. + - Ensure the narrative flows logically, presenting background, mechanisms, and significance +5. Source list at the end: After the main narrative, provide a bullet-point list of each unique citation anchor exactly once, in the same Node Name format. + - Examples: + - Apoptosis + - Cell Cycle + +## Internal QA (silent) +- All factual claims are cited correctly. +- No unverified claims or background knowledge are added. +- The Sources list is complete and de-duplicated. 
""" reactome_qa_prompt = ChatPromptTemplate.from_messages( @@ -28,4 +35,4 @@ MessagesPlaceholder(variable_name="chat_history"), ("user", "Context:\n{context}\n\nQuestion: {input}"), ] -) +) \ No newline at end of file diff --git a/src/retrievers/reactome/rag.py b/src/retrievers/reactome/rag.py index 485b6e5..8758593 100644 --- a/src/retrievers/reactome/rag.py +++ b/src/retrievers/reactome/rag.py @@ -4,11 +4,8 @@ from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.runnables import Runnable -from retrievers.csv_chroma import create_bm25_chroma_ensemble_retriever -from retrievers.rag_chain import create_rag_chain -from retrievers.reactome.metadata_info import (reactome_descriptions_info, - reactome_field_info) -from retrievers.reactome.prompt import reactome_qa_prompt +from retrievers.rag_chain import create_advanced_rag_chain +from retrievers.reactome.prompt import reactome_system_prompt from util.embedding_environment import EmbeddingEnvironment @@ -19,15 +16,22 @@ def create_reactome_rag( *, streaming: bool = False, ) -> Runnable: - reactome_retriever = create_bm25_chroma_ensemble_retriever( - llm, - embedding, - embeddings_directory, - descriptions_info=reactome_descriptions_info, - field_info=reactome_field_info, - ) + """ + Create a Reactome-specific RAG chain with hybrid retrieval and query expansion. - if streaming: - llm = llm.model_copy(update={"streaming": True}) + Args: + llm: Language model for generation + embedding: Embedding model for retrieval + embeddings_directory: Directory containing Reactome embeddings and CSV files + streaming: Whether to enable streaming responses - return create_rag_chain(llm, reactome_retriever, reactome_qa_prompt) + Returns: + Runnable RAG chain for Reactome queries + """ + return create_advanced_rag_chain( + llm=llm, + embedding=embedding, + embeddings_directory=embeddings_directory, + system_prompt=reactome_system_prompt, + streaming=streaming, + ) \ No newline at end of file diff --git a/src/retrievers/retrieval_utils.py b/src/retrievers/retrieval_utils.py new file mode 100644 index 0000000..a378968 --- /dev/null +++ b/src/retrievers/retrieval_utils.py @@ -0,0 +1,27 @@ +from collections import defaultdict +from typing import List, Any, Tuple, Dict, Callable +from pydantic import BaseModel + +def reciprocal_rank_fusion( + ranked_lists: List[List[Any]], + final_k: int = 5, + lambda_mult: float = 60.0, + rrf_k: int | None = None, + id_getter: Callable[[Any], str] = lambda doc: doc.metadata.get("stId") or doc.metadata.get("stable_id"), +) -> Tuple[List[Any], List[str], Dict[str, float]]: + rrf_scores = defaultdict(float) + doc_meta = {} + + for ranked in ranked_lists: + considered = ranked[:rrf_k] if rrf_k else ranked + for rank, doc in enumerate(considered): + doc_id = id_getter(doc) + if doc_id is not None: # Skip documents without valid IDs + rrf_scores[doc_id] += 1.0 / (lambda_mult + rank + 1) + if doc_id not in doc_meta: + doc_meta[doc_id] = doc + + sorted_items = sorted(rrf_scores.items(), key=lambda x: (-x[1], x[0])) + top_ids = [doc_id for doc_id, _ in sorted_items[:final_k]] + top_docs = [doc_meta[doc_id] for doc_id in top_ids] + return top_docs, top_ids, rrf_scores diff --git a/src/retrievers/uniprot/prompt.py b/src/retrievers/uniprot/prompt.py index 7cb0910..57ab26e 100644 --- a/src/retrievers/uniprot/prompt.py +++ b/src/retrievers/uniprot/prompt.py @@ -28,4 +28,4 @@ MessagesPlaceholder(variable_name="chat_history"), ("user", "Context:\n{context}\n\nQuestion: {input}"), ] -) +) \ No newline at end of 
file diff --git a/src/retrievers/uniprot/rag.py b/src/retrievers/uniprot/rag.py index 99702d7..078eaa8 100644 --- a/src/retrievers/uniprot/rag.py +++ b/src/retrievers/uniprot/rag.py @@ -4,10 +4,7 @@ from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.runnables import Runnable -from retrievers.csv_chroma import create_bm25_chroma_ensemble_retriever -from retrievers.rag_chain import create_rag_chain -from retrievers.uniprot.metadata_info import (uniprot_descriptions_info, - uniprot_field_info) +from retrievers.rag_chain import create_advanced_rag_chain from retrievers.uniprot.prompt import uniprot_qa_prompt from util.embedding_environment import EmbeddingEnvironment @@ -19,15 +16,22 @@ def create_uniprot_rag( *, streaming: bool = False, ) -> Runnable: - reactome_retriever = create_bm25_chroma_ensemble_retriever( - llm, - embedding, - embeddings_directory, - descriptions_info=uniprot_descriptions_info, - field_info=uniprot_field_info, - ) + """ + Create a UniProt-specific RAG chain with hybrid retrieval and query expansion. - if streaming: - llm = llm.model_copy(update={"streaming": True}) + Args: + llm: Language model for generation + embedding: Embedding model for retrieval + embeddings_directory: Directory containing UniProt embeddings and CSV files + streaming: Whether to enable streaming responses - return create_rag_chain(llm, reactome_retriever, uniprot_qa_prompt) + Returns: + Runnable RAG chain for UniProt queries + """ + return create_advanced_rag_chain( + llm=llm, + embedding=embedding, + embeddings_directory=embeddings_directory, + system_prompt=uniprot_qa_prompt, + streaming=streaming, + ) \ No newline at end of file From b2cc4bba35876d152e68ff9d23f2488eaae932ee Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Sat, 27 Sep 2025 18:06:02 -0400 Subject: [PATCH 04/13] feat: Add new runnables for checking question safety and scope, query expansion and conversation history management --- src/agent/tasks/query_expansion.py | 60 ++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 src/agent/tasks/query_expansion.py diff --git a/src/agent/tasks/query_expansion.py b/src/agent/tasks/query_expansion.py new file mode 100644 index 0000000..3d0ba61 --- /dev/null +++ b/src/agent/tasks/query_expansion.py @@ -0,0 +1,60 @@ +import json +from typing import List + +from langchain.prompts import ChatPromptTemplate +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.output_parsers import StrOutputParser +from langchain_core.runnables import Runnable, RunnableLambda + + +def QueryExpansionParser(output: str) -> List[str]: + """Parse JSON array output from LLM.""" + try: + return json.loads(output) + except json.JSONDecodeError: + raise ValueError("LLM output was not valid JSON. Output:\n" + output) + + +def create_query_expander(llm: BaseChatModel) -> Runnable: + """ + Create a query expansion chain that generates 4 alternative queries. + + Args: + llm: Language model to use + + Returns: + Runnable that takes standalone_query and returns List[str] + """ + system_prompt = """ +You are a biomedical question expansion engine for information retrieval over the Reactome biological pathway database. + +Given a single user question, generate **exactly 4** alternate standalone questions. These should be: + +- Semantically related to the original question. +- Lexically diverse to improve retrieval via vector search and RAG-fusion. +- Biologically enriched with inferred or associated details. 
+ +Your goal is to improve recall of relevant documents by expanding the original query using: +- Synonymous gene/protein names (e.g., EGFR, ErbB1, HER1) +- Pathway or process-level context (e.g., signal transduction, apoptosis) +- Known diseases, phenotypes, or biological functions +- Cellular localization (e.g., nucleus, cytoplasm, membrane) +- Upstream/downstream molecular interactions + +Rules: +- Each question must be **fully standalone** (no "this"/"it"). +- Do not change the core intent—preserve the user's informational goal. +- Use appropriate biological terminology and Reactome-relevant concepts. +- Vary the **phrasing**, **focus**, or **biological angle** of each question. +- If the input is ambiguous, infer a biologically meaningful interpretation. + +Output: +Return only a valid JSON array of 4 strings (no explanations, no metadata). +Do not include any explanations or metadata. +""" + + prompt = ChatPromptTemplate.from_messages( + [("system", system_prompt), ("user", "Original Question: {rephrased_input}")] + ) + + return prompt | llm | StrOutputParser() | RunnableLambda(QueryExpansionParser) \ No newline at end of file From 3b9e95d6e5decf742e39373451fc1d85bafc4c70 Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Sun, 28 Sep 2025 13:45:56 -0400 Subject: [PATCH 05/13] code quality check fixes --- src/agent/graph.py | 2 +- src/agent/models.py | 2 +- .../unsafe_question.py | 23 +++++++++++-------- src/agent/tasks/query_expansion.py | 2 +- src/agent/tasks/rephrase.py | 2 +- src/agent/tasks/safety_checker.py | 2 +- src/retrievers/csv_chroma.py | 4 +--- src/retrievers/rag_chain.py | 2 +- src/retrievers/reactome/prompt.py | 2 +- src/retrievers/reactome/rag.py | 2 +- src/retrievers/retrieval_utils.py | 7 +++--- src/retrievers/uniprot/prompt.py | 2 +- src/retrievers/uniprot/rag.py | 2 +- 13 files changed, 29 insertions(+), 25 deletions(-) diff --git a/src/agent/graph.py b/src/agent/graph.py index 2fab414..012df27 100644 --- a/src/agent/graph.py +++ b/src/agent/graph.py @@ -97,4 +97,4 @@ async def ainvoke( }, ), ) - return result \ No newline at end of file + return result diff --git a/src/agent/models.py b/src/agent/models.py index b80413a..01b324c 100644 --- a/src/agent/models.py +++ b/src/agent/models.py @@ -65,4 +65,4 @@ def get_llm( base_url=base_url, ) else: - raise ValueError(f"Unknown provider: {provider}") \ No newline at end of file + raise ValueError(f"Unknown provider: {provider}") diff --git a/src/agent/tasks/final_answer_generation/unsafe_question.py b/src/agent/tasks/final_answer_generation/unsafe_question.py index 5369332..7193a18 100644 --- a/src/agent/tasks/final_answer_generation/unsafe_question.py +++ b/src/agent/tasks/final_answer_generation/unsafe_question.py @@ -1,13 +1,13 @@ -from langchain_core.prompts import ChatPromptTemplate from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import Runnable + # unsafe or out of scope answer generator def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable: """ Create an unsafe answer generator chain. - + Args: llm: Language model to use @@ -34,9 +34,14 @@ def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable: You must not provide any workaround, implicit answer, or redirection toward unsafe content. 
""" - prompt = ChatPromptTemplate.from_messages([ - ("system", system_prompt), - ("user", "Language:{language}\n\nQuestion:{user_input}\n\n Reason for unsafe or out of scope: {reason_unsafe}") - ]) - - return prompt | llm \ No newline at end of file + prompt = ChatPromptTemplate.from_messages( + [ + ("system", system_prompt), + ( + "user", + "Language:{language}\n\nQuestion:{user_input}\n\n Reason for unsafe or out of scope: {reason_unsafe}", + ), + ] + ) + + return prompt | llm diff --git a/src/agent/tasks/query_expansion.py b/src/agent/tasks/query_expansion.py index 3d0ba61..7cddaaf 100644 --- a/src/agent/tasks/query_expansion.py +++ b/src/agent/tasks/query_expansion.py @@ -57,4 +57,4 @@ def create_query_expander(llm: BaseChatModel) -> Runnable: [("system", system_prompt), ("user", "Original Question: {rephrased_input}")] ) - return prompt | llm | StrOutputParser() | RunnableLambda(QueryExpansionParser) \ No newline at end of file + return prompt | llm | StrOutputParser() | RunnableLambda(QueryExpansionParser) diff --git a/src/agent/tasks/rephrase.py b/src/agent/tasks/rephrase.py index 7c1be6a..e256104 100644 --- a/src/agent/tasks/rephrase.py +++ b/src/agent/tasks/rephrase.py @@ -27,4 +27,4 @@ def create_rephrase_chain(llm: BaseChatModel) -> Runnable: return (contextualize_q_prompt | llm | StrOutputParser()).with_config( run_name="rephrase_question" - ) \ No newline at end of file + ) diff --git a/src/agent/tasks/safety_checker.py b/src/agent/tasks/safety_checker.py index 5e080d6..3c46f8e 100644 --- a/src/agent/tasks/safety_checker.py +++ b/src/agent/tasks/safety_checker.py @@ -73,4 +73,4 @@ def create_safety_checker(llm: BaseChatModel) -> Runnable: ] ) - return prompt | llm.with_structured_output(SafetyCheck) \ No newline at end of file + return prompt | llm.with_structured_output(SafetyCheck) diff --git a/src/retrievers/csv_chroma.py b/src/retrievers/csv_chroma.py index 15a1a99..a0fbba5 100644 --- a/src/retrievers/csv_chroma.py +++ b/src/retrievers/csv_chroma.py @@ -10,8 +10,6 @@ from langchain_community.retrievers import BM25Retriever from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.retrievers import BaseRetriever from retrievers.retrieval_utils import reciprocal_rank_fusion @@ -323,4 +321,4 @@ def create_hybrid_retriever( ) except Exception as e: logger.error(f"Failed to create hybrid retriever: {e}") - raise RuntimeError(f"Hybrid retriever creation failed: {e}") from e \ No newline at end of file + raise RuntimeError(f"Hybrid retriever creation failed: {e}") from e diff --git a/src/retrievers/rag_chain.py b/src/retrievers/rag_chain.py index 5b78b66..66235f7 100644 --- a/src/retrievers/rag_chain.py +++ b/src/retrievers/rag_chain.py @@ -60,4 +60,4 @@ async def rag_chain(inputs: Dict[str, Any]) -> Dict[str, Any]: return {"answer": response.content, "context": context} - return RunnableLambda(rag_chain) \ No newline at end of file + return RunnableLambda(rag_chain) diff --git a/src/retrievers/reactome/prompt.py b/src/retrievers/reactome/prompt.py index 68af525..43c43f5 100644 --- a/src/retrievers/reactome/prompt.py +++ b/src/retrievers/reactome/prompt.py @@ -35,4 +35,4 @@ MessagesPlaceholder(variable_name="chat_history"), ("user", "Context:\n{context}\n\nQuestion: {input}"), ] -) \ No newline at end of file +) diff --git a/src/retrievers/reactome/rag.py b/src/retrievers/reactome/rag.py index 8758593..0a02df3 100644 --- a/src/retrievers/reactome/rag.py 
+++ b/src/retrievers/reactome/rag.py @@ -34,4 +34,4 @@ def create_reactome_rag( embeddings_directory=embeddings_directory, system_prompt=reactome_system_prompt, streaming=streaming, - ) \ No newline at end of file + ) diff --git a/src/retrievers/retrieval_utils.py b/src/retrievers/retrieval_utils.py index a378968..90f3e2e 100644 --- a/src/retrievers/retrieval_utils.py +++ b/src/retrievers/retrieval_utils.py @@ -1,13 +1,14 @@ from collections import defaultdict -from typing import List, Any, Tuple, Dict, Callable -from pydantic import BaseModel +from typing import Any, Callable, Dict, List, Tuple + def reciprocal_rank_fusion( ranked_lists: List[List[Any]], final_k: int = 5, lambda_mult: float = 60.0, rrf_k: int | None = None, - id_getter: Callable[[Any], str] = lambda doc: doc.metadata.get("stId") or doc.metadata.get("stable_id"), + id_getter: Callable[[Any], str] = lambda doc: doc.metadata.get("stId") + or doc.metadata.get("stable_id"), ) -> Tuple[List[Any], List[str], Dict[str, float]]: rrf_scores = defaultdict(float) doc_meta = {} diff --git a/src/retrievers/uniprot/prompt.py b/src/retrievers/uniprot/prompt.py index 57ab26e..7cb0910 100644 --- a/src/retrievers/uniprot/prompt.py +++ b/src/retrievers/uniprot/prompt.py @@ -28,4 +28,4 @@ MessagesPlaceholder(variable_name="chat_history"), ("user", "Context:\n{context}\n\nQuestion: {input}"), ] -) \ No newline at end of file +) diff --git a/src/retrievers/uniprot/rag.py b/src/retrievers/uniprot/rag.py index 078eaa8..1ef5d8d 100644 --- a/src/retrievers/uniprot/rag.py +++ b/src/retrievers/uniprot/rag.py @@ -34,4 +34,4 @@ def create_uniprot_rag( embeddings_directory=embeddings_directory, system_prompt=uniprot_qa_prompt, streaming=streaming, - ) \ No newline at end of file + ) From 2864e974b492cab9ca51cbe5d7d6b0e61622c5b3 Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Sun, 28 Sep 2025 14:53:11 -0400 Subject: [PATCH 06/13] fix: Resolve mypy linter errors - Add type annotation for rrf_scores in retrieval_utils.py - Fix metadata dictionary comprehension in csv_chroma.py - Update retriever type annotations to use Any - Add isinstance check for BM25Retriever - Remove default values from TypedDict in base.py - Fix TypedDict expansion in postprocess method --- src/agent/profiles/base.py | 37 +- src/agent/profiles/react_to_me.py | 26 +- .../reactome_kg/create_test_embeddings.py | 315 +++++++++++++++ src/retrievers/csv_chroma.py | 30 +- src/retrievers/graph_rag/uniprot_retriever.py | 376 ++++++++++++++++++ src/retrievers/retrieval_utils.py | 2 +- 6 files changed, 758 insertions(+), 28 deletions(-) create mode 100644 src/data_generation/reactome_kg/create_test_embeddings.py create mode 100644 src/retrievers/graph_rag/uniprot_retriever.py diff --git a/src/agent/profiles/base.py b/src/agent/profiles/base.py index 6536978..e583304 100644 --- a/src/agent/profiles/base.py +++ b/src/agent/profiles/base.py @@ -43,21 +43,29 @@ class BaseState(InputState, OutputState, total=False): chat_history: Annotated[list[BaseMessage], add_messages] # Preprocessing results - safety: str = SAFETY_SAFE - reason_unsafe: str = "" - expanded_queries: list[str] = [] - detected_language: str = DEFAULT_LANGUAGE + safety: str + reason_unsafe: str + expanded_queries: list[str] + detected_language: str class BaseGraphBuilder: """Base class for all graph builders with common preprocessing and postprocessing.""" - def __init__(self, llm: BaseChatModel, embedding: Embeddings) -> None: + def __init__( + self, + llm: BaseChatModel, + embedding: Embeddings + ) -> None: """Initialize with 
LLM and embedding models.""" self.preprocessing_workflow: Runnable = create_preprocessing_workflow(llm) self.search_workflow: Runnable = create_search_workflow(llm) - async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState: + async def preprocess( + self, + state: BaseState, + config: RunnableConfig + ) -> BaseState: """Run the complete preprocessing workflow and map results to state.""" result: PreprocessingState = await self.preprocessing_workflow.ainvoke( PreprocessingState( @@ -69,7 +77,10 @@ async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseStat return self._map_preprocessing_result(result) - def _map_preprocessing_result(self, result: PreprocessingState) -> BaseState: + def _map_preprocessing_result( + self, + result: PreprocessingState + ) -> BaseState: """Map preprocessing results to BaseState with defaults.""" return BaseState( rephrased_input=result["rephrased_input"], @@ -79,7 +90,11 @@ def _map_preprocessing_result(self, result: PreprocessingState) -> BaseState: detected_language=result.get("detected_language", DEFAULT_LANGUAGE), ) - async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState: + async def postprocess( + self, + state: BaseState, + config: RunnableConfig + ) -> BaseState: """Postprocess that preserves existing state and conditionally adds search results.""" search_results: list[WebSearchResult] = [] @@ -98,8 +113,4 @@ async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseSta search_results = result["search_results"] # Create new state with updated additional_content - new_state = dict(state) # Copy existing state - new_state["additional_content"] = AdditionalContent( - search_results=search_results - ) - return BaseState(**new_state) + return BaseState(**{**state, "additional_content": AdditionalContent(search_results=search_results)}) diff --git a/src/agent/profiles/react_to_me.py b/src/agent/profiles/react_to_me.py index f0e4980..d2dab7d 100644 --- a/src/agent/profiles/react_to_me.py +++ b/src/agent/profiles/react_to_me.py @@ -23,7 +23,11 @@ class ReactToMeState(BaseState): class ReactToMeGraphBuilder(BaseGraphBuilder): """Graph builder for ReactToMe profile with Reactome-specific functionality.""" - def __init__(self, llm: BaseChatModel, embedding: Embeddings) -> None: + def __init__( + self, + llm: BaseChatModel, + embedding: Embeddings + ) -> None: """Initialize ReactToMe graph builder with required components.""" super().__init__(llm, embedding) @@ -65,20 +69,25 @@ def _build_workflow(self) -> StateGraph: return state_graph async def preprocess( - self, state: ReactToMeState, config: RunnableConfig + self, + state: ReactToMeState, + config: RunnableConfig ) -> ReactToMeState: """Run preprocessing workflow.""" result = await super().preprocess(state, config) return ReactToMeState(**result) async def proceed_with_research( - self, state: ReactToMeState + self, + state: ReactToMeState ) -> Literal["Continue", "Finish"]: """Determine whether to proceed with research based on safety check.""" return "Continue" if state["safety"] == SAFETY_SAFE else "Finish" async def generate_unsafe_response( - self, state: ReactToMeState, config: RunnableConfig + self, + state: ReactToMeState, + config: RunnableConfig ) -> ReactToMeState: """Generate appropriate refusal response for unsafe queries.""" final_answer_message = await self.unsafe_answer_generator.ainvoke( @@ -111,7 +120,9 @@ async def generate_unsafe_response( ) async def call_model( - self, state: ReactToMeState, 
config: RunnableConfig + self, + state: ReactToMeState, + config: RunnableConfig ) -> ReactToMeState: """Generate response using Reactome RAG for safe queries.""" result: dict[str, Any] = await self.reactome_rag.ainvoke( @@ -136,6 +147,9 @@ async def call_model( ) -def create_reactome_graph(llm: BaseChatModel, embedding: Embeddings) -> StateGraph: +def create_reactome_graph( + llm: BaseChatModel, + embedding: Embeddings + ) -> StateGraph: """Create and return the ReactToMe workflow graph.""" return ReactToMeGraphBuilder(llm, embedding).uncompiled_graph diff --git a/src/data_generation/reactome_kg/create_test_embeddings.py b/src/data_generation/reactome_kg/create_test_embeddings.py new file mode 100644 index 0000000..9c5c478 --- /dev/null +++ b/src/data_generation/reactome_kg/create_test_embeddings.py @@ -0,0 +1,315 @@ +import logging +import os +import sys +import time +from typing import Any, Dict, List + +import weaviate +from langchain_openai import OpenAIEmbeddings + +# Test-specific configuration +TEST_NEO4J_URI = "bolt://localhost:7690" +TEST_NEO4J_USER = "neo4j" +TEST_NEO4J_PASSWORD = "reactome-test" +TEST_NEO4J_DATABASE = "reactome-kg-test" + +TEST_WEAVIATE_HOST = "localhost" +TEST_WEAVIATE_PORT = 8081 +TEST_WEAVIATE_CLASS_NAME = "TestReactomeKG" + +# TP53-related terms for filtering +TP53_TERMS = [ + "TP53", + "p53", + "tumor protein p53", + "cellular tumor antigen p53", + "transformation-related protein 53", + "tumor suppressor p53", +] + +# Batch processing +BATCH_SIZE = 10 +MIN_TEXT_LENGTH = 50 + + +def _setup_logging() -> logging.Logger: + """Set up logging configuration.""" + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + return logging.getLogger(__name__) + + +def _connect_to_services() -> tuple: + """Connect to Neo4j and Weaviate services.""" + logger = logging.getLogger(__name__) + + # Connect to Neo4j with retry logic + neo4j_driver = None + for attempt in range(10): # Try up to 10 times + try: + from neo4j import GraphDatabase + + neo4j_driver = GraphDatabase.driver( + TEST_NEO4J_URI, auth=(TEST_NEO4J_USER, TEST_NEO4J_PASSWORD) + ) + neo4j_driver.verify_connectivity() + logger.info(f"Connected to Neo4j at {TEST_NEO4J_URI}") + break + except Exception as e: + if attempt == 9: # Last attempt + logger.error(f"Failed to connect to Neo4j after 10 attempts: {e}") + raise + logger.info( + f"Neo4j connection attempt {attempt + 1}/10 failed, retrying in 10 seconds..." 
+ ) + time.sleep(10) + + if neo4j_driver is None: + raise Exception("Failed to establish Neo4j connection") + + # Connect to Weaviate + try: + weaviate_client = weaviate.WeaviateClient( + connection_params=weaviate.connect.ConnectionParams.from_url( + f"http://{TEST_WEAVIATE_HOST}:{TEST_WEAVIATE_PORT}", grpc_port=50052 + ), + additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")}, + ) + logger.info( + f"Connected to Weaviate at {TEST_WEAVIATE_HOST}:{TEST_WEAVIATE_PORT}" + ) + except Exception as e: + logger.error(f"Failed to connect to Weaviate: {e}") + raise + + return neo4j_driver, weaviate_client + + +def _is_tp53_related(text: str) -> bool: + """Check if text content is related to TP53.""" + if not text: + return False + + text_lower = str(text).lower() + return any(term.lower() in text_lower for term in TP53_TERMS) + + +def _fetch_tp53_nodes(neo4j_driver) -> List[Dict[str, Any]]: + """Fetch TP53-related nodes from Neo4j.""" + logger = logging.getLogger(__name__) + + query = """ + MATCH (n) + WHERE n.name CONTAINS 'TP53' OR n.name CONTAINS 'p53' + OR n.displayName CONTAINS 'TP53' OR n.displayName CONTAINS 'p53' + OR n.description CONTAINS 'TP53' OR n.description CONTAINS 'p53' + OR n.text_content CONTAINS 'TP53' OR n.text_content CONTAINS 'p53' + OR n.stableId CONTAINS 'TP53' + RETURN n.stableId as stable_id, + labels(n)[0] as label, + n.name as name, + n.displayName as display_name, + n.description as description, + n.text_content as text_content + LIMIT 200 + """ + + try: + with neo4j_driver.session(database=TEST_NEO4J_DATABASE) as session: + result = session.run(query) + nodes = [] + + for record in result: + node_data = { + "stable_id": record["stable_id"], + "label": record["label"], + "name": record["name"], + "display_name": record["display_name"], + "description": record["description"], + "text_content": record["text_content"], + } + + # Additional TP53 filtering + combined_text = " ".join( + [ + node_data.get("name", ""), + node_data.get("description", ""), + node_data.get("text_content", ""), + node_data.get("stable_id", ""), + ] + ) + + if _is_tp53_related(combined_text): + nodes.append(node_data) + + logger.info(f"Found {len(nodes)} TP53-related nodes") + return nodes + + except Exception as e: + logger.error(f"Error fetching TP53 nodes: {e}") + raise + + +def _create_embeddings_and_store(nodes: List[Dict[str, Any]], weaviate_client) -> int: + """Create embeddings for nodes and store in Weaviate.""" + logger = logging.getLogger(__name__) + + # Initialize embeddings + embeddings = OpenAIEmbeddings(model="text-embedding-3-large") + + # Create Weaviate schema if it doesn't exist + try: + if weaviate_client.schema.exists(TEST_WEAVIATE_CLASS_NAME): + logger.info(f"Deleting existing schema: {TEST_WEAVIATE_CLASS_NAME}") + weaviate_client.schema.delete_class(TEST_WEAVIATE_CLASS_NAME) + except Exception as e: + logger.warning(f"Error checking/deleting schema: {e}") + + # Create new schema + schema = { + "class": TEST_WEAVIATE_CLASS_NAME, + "description": "Test Reactome KG embeddings for TP53-related entities", + "vectorizer": "text2vec-openai", + "moduleConfig": { + "text2vec-openai": { + "model": "text-embedding-3-large", + "modelVersion": "002", + "dimensions": 3072, + "type": "text", + } + }, + "properties": [ + { + "name": "stable_id", + "dataType": ["string"], + "description": "Reactome stable identifier", + }, + {"name": "label", "dataType": ["string"], "description": "Node label/type"}, + {"name": "name", "dataType": ["string"], "description": "Node name"}, 
+ { + "name": "display_name", + "dataType": ["string"], + "description": "Display name", + }, + { + "name": "description", + "dataType": ["string"], + "description": "Node description", + }, + { + "name": "text_content", + "dataType": ["text"], + "description": "Full text content for embedding", + }, + ], + } + + try: + weaviate_client.schema.create_class(schema) + logger.info(f"Created Weaviate schema: {TEST_WEAVIATE_CLASS_NAME}") + except Exception as e: + logger.error(f"Error creating schema: {e}") + raise + + # Process nodes in batches + processed_count = 0 + + for i in range(0, len(nodes), BATCH_SIZE): + batch = nodes[i : i + BATCH_SIZE] + batch_objects = [] + + for node in batch: + # Create text content for embedding + text_parts = [] + if node.get("name"): + text_parts.append(f"Name: {node['name']}") + if node.get("display_name"): + text_parts.append(f"Display Name: {node['display_name']}") + if node.get("description"): + text_parts.append(f"Description: {node['description']}") + if node.get("text_content"): + text_parts.append(f"Content: {node['text_content']}") + + text_content = " | ".join(text_parts) + + if len(text_content) < MIN_TEXT_LENGTH: + continue + + # Create Weaviate object + obj = { + "stable_id": node["stable_id"], + "label": node["label"], + "name": node.get("name", ""), + "display_name": node.get("display_name", ""), + "description": node.get("description", ""), + "text_content": text_content, + } + + batch_objects.append(obj) + + # Store batch in Weaviate + if batch_objects: + try: + with weaviate_client.batch as batch: + for obj in batch_objects: + batch.add_data_object( + data_object=obj, class_name=TEST_WEAVIATE_CLASS_NAME + ) + + processed_count += len(batch_objects) + logger.info( + f"Processed batch {i//BATCH_SIZE + 1}: {len(batch_objects)} objects" + ) + + except Exception as e: + logger.error(f"Error storing batch: {e}") + continue + + logger.info(f"Total objects processed and stored: {processed_count}") + return processed_count + + +def main(): + """Main function to create test embeddings.""" + logger = _setup_logging() + + try: + logger.info("Starting TP53-focused test embedding creation...") + logger.info(f"Neo4j: {TEST_NEO4J_URI} (database: {TEST_NEO4J_DATABASE})") + logger.info(f"Weaviate: {TEST_WEAVIATE_HOST}:{TEST_WEAVIATE_PORT}") + + # Connect to services + neo4j_driver, weaviate_client = _connect_to_services() + + # Fetch TP53 nodes + nodes = _fetch_tp53_nodes(neo4j_driver) + + if not nodes: + logger.warning("No TP53-related nodes found!") + return + + # Create and store embeddings + processed_count = _create_embeddings_and_store(nodes, weaviate_client) + + logger.info("=" * 60) + logger.info("TEST EMBEDDING CREATION SUMMARY") + logger.info("=" * 60) + logger.info(f"Total TP53 nodes found: {len(nodes)}") + logger.info(f"Total embeddings created: {processed_count}") + logger.info(f"Weaviate class: {TEST_WEAVIATE_CLASS_NAME}") + logger.info("=" * 60) + + logger.info("Test embedding creation completed successfully!") + + except Exception as e: + logger.error(f"Error creating test embeddings: {e}") + sys.exit(1) + finally: + if "neo4j_driver" in locals(): + neo4j_driver.close() + + +if __name__ == "__main__": + main() diff --git a/src/retrievers/csv_chroma.py b/src/retrievers/csv_chroma.py index a0fbba5..3d96c12 100644 --- a/src/retrievers/csv_chroma.py +++ b/src/retrievers/csv_chroma.py @@ -47,7 +47,7 @@ def create_documents_from_csv(csv_path: Path) -> List[Document]: page_content = "\n".join(content_parts) metadata = { - column: str(value) + 
str(column): str(value) for column in df.columns for value in [row[column]] if pd.notna(value) and str(value) != "nan" @@ -77,7 +77,12 @@ def list_chroma_subdirectories(directory: Path) -> List[str]: class HybridRetriever: """Advanced hybrid retriever supporting RRF, parallel processing, and multi-source search.""" - def __init__(self, embedding: Embeddings, embeddings_directory: Path): + def __init__( + self, + embedding: Embeddings, + embeddings_directory: Path + ): + self.embedding = embedding self.embeddings_directory = embeddings_directory self._retrievers: Dict[ @@ -154,19 +159,25 @@ def _create_vector_retriever(self, subdirectory: str) -> Optional[object]: return None async def _search_with_bm25( - self, query: str, retriever: BM25Retriever + self, + query: str, + retriever: BM25Retriever ) -> List[Document]: """Search using BM25 retriever asynchronously.""" return await asyncio.to_thread(retriever.get_relevant_documents, query) async def _search_with_vector( - self, query: str, retriever: object + self, + query: str, + retriever: Any ) -> List[Document]: """Search using vector retriever asynchronously.""" return await asyncio.to_thread(retriever.get_relevant_documents, query) async def _execute_hybrid_search( - self, query: str, subdirectory: str + self, + query: str, + subdirectory: str ) -> List[Document]: """Execute hybrid search (BM25 + vector) for a single query on a subdirectory.""" retriever_info = self._retrievers.get(subdirectory) @@ -176,7 +187,7 @@ async def _execute_hybrid_search( search_tasks = [] - if retriever_info["bm25"]: + if retriever_info["bm25"] and isinstance(retriever_info["bm25"], BM25Retriever): search_tasks.append(self._search_with_bm25(query, retriever_info["bm25"])) if retriever_info["vector"]: @@ -212,7 +223,9 @@ def _generate_document_identifier(self, document: Document) -> str: return hashlib.md5(document.page_content.encode()).hexdigest() async def _apply_reciprocal_rank_fusion( - self, queries: List[str], subdirectory: str + self, + queries: List[str], + subdirectory: str ) -> List[Document]: """Apply Reciprocal Rank Fusion to results from multiple queries on a subdirectory.""" logger.info( @@ -312,7 +325,8 @@ async def ainvoke(self, inputs: Dict[str, Any]) -> str: def create_hybrid_retriever( - embedding: Embeddings, embeddings_directory: Path + embedding: Embeddings, + embeddings_directory: Path ) -> HybridRetriever: """Create a hybrid retriever with RRF and parallel processing support.""" try: diff --git a/src/retrievers/graph_rag/uniprot_retriever.py b/src/retrievers/graph_rag/uniprot_retriever.py new file mode 100644 index 0000000..3a6af8d --- /dev/null +++ b/src/retrievers/graph_rag/uniprot_retriever.py @@ -0,0 +1,376 @@ +from __future__ import annotations + +import asyncio +import logging +import os +from pathlib import Path +from typing import List, Optional + +import chromadb.config +from langchain_chroma.vectorstores import Chroma +from langchain_community.document_loaders.csv_loader import CSVLoader +from langchain_community.retrievers import BM25Retriever +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings + +from src.retrievers.csv_chroma import list_chroma_subdirectories +from src.util.embedding_environment import EmbeddingEnvironment + +from .retrieval_utils import UniProtRetrievalConfig, reciprocal_rank_fusion + +logger = logging.getLogger(__name__) + +CHROMA_SETTINGS = chromadb.config.Settings(anonymized_telemetry=False) + + +class UniProtRetriever: + """ + UniProt vector retriever that 
supports RRF and similarity search. + + This retriever provides the same configuration options as the Graph RAG + retriever but operates only on vector embeddings without graph traversal. + Returns page content strings (protein information only, no metadata). + + Features: + - Single vectorstore using the most recent subdirectory + - Single BM25 retriever using the most recent CSV file + - Reciprocal Rank Fusion (RRF) support + - Hybrid vector + BM25 search + - Clean page content output for LLM consumption + """ + + DEFAULT_BM25_K = 10 + + def __init__( + self, + embedding: Embeddings, + embeddings_directory: Optional[Path] = None, + ) -> None: + """ + Initialize the UniProt vector retriever. + + Args: + embedding: Embedding model for vector operations + embeddings_directory: Path to UniProt embeddings directory. + Defaults to EmbeddingEnvironment.get_dir("uniprot") + + Raises: + ValueError: If embeddings_directory doesn't exist + RuntimeError: If initialization fails + """ + self.embedding = embedding + self.embeddings_directory = ( + embeddings_directory or EmbeddingEnvironment.get_dir("uniprot") + ) + + if not self.embeddings_directory.exists(): + raise ValueError( + f"Embeddings directory does not exist: {self.embeddings_directory}" + ) + + self._vectorstore: Optional[Chroma] = None + self._bm25_retriever: Optional[BM25Retriever] = None + self._subdirectory: Optional[str] = None + + try: + self._initialize_retrievers() + except Exception as e: + logger.error(f"Failed to initialize UniProt retriever: {e}") + raise RuntimeError(f"UniProt retriever initialization failed: {e}") from e + + def _initialize_retrievers(self) -> None: + """Initialize both vectorstore and BM25 retriever.""" + subdirectories = list_chroma_subdirectories(self.embeddings_directory) + + if not subdirectories: + raise RuntimeError("No UniProt subdirectories found") + + # Use the most recently created subdirectory + self._subdirectory = self._get_latest_subdirectory(subdirectories) + + self._initialize_vectorstore() + self._initialize_bm25_retriever() + + logger.info( + f"UniProt retriever initialized successfully using subdirectory: {self._subdirectory}" + ) + + def _initialize_vectorstore(self) -> None: + """Initialize Chroma vectorstore for UniProt.""" + try: + self._vectorstore = Chroma( + persist_directory=str(self.embeddings_directory / self._subdirectory), + embedding_function=self.embedding, + client_settings=CHROMA_SETTINGS, + ) + logger.info(f"Initialized UniProt vectorstore: {self._subdirectory}") + except Exception as e: + logger.error(f"Failed to initialize vectorstore: {e}") + raise + + def _initialize_bm25_retriever(self) -> None: + """Initialize BM25 retriever for UniProt.""" + try: + csv_file_name = f"{self._subdirectory}.csv" + csvs_dir = self.embeddings_directory / "csv_files" + csv_path = csvs_dir / csv_file_name + + if not csv_path.exists(): + logger.warning(f"CSV file not found: {csv_path}") + return + + loader = CSVLoader(file_path=str(csv_path)) + data = loader.load() + + if not data: + logger.warning(f"No data loaded from CSV: {csv_path}") + return + + self._bm25_retriever = BM25Retriever.from_documents(data) + self._bm25_retriever.k = self.DEFAULT_BM25_K + + logger.info(f"Initialized UniProt BM25 retriever: {self._subdirectory}") + except Exception as e: + logger.error(f"Failed to initialize BM25 retriever: {e}") + + def _get_latest_subdirectory(self, subdirectories: List[str]) -> str: + """ + Get the most recently created subdirectory. 
+ + Args: + subdirectories: List of subdirectory names + + Returns: + Name of the most recently created subdirectory + """ + subdir_times = [] + + for subdir in subdirectories: + subdir_path = self.embeddings_directory / subdir + if subdir_path.exists(): + mtime = os.path.getmtime(subdir_path) + subdir_times.append((subdir, mtime)) + + if not subdir_times: + logger.warning("No valid subdirectories found, using first available") + return subdirectories[0] + + # Sort by modification time (most recent first) and return the latest + subdir_times.sort(key=lambda x: x[1], reverse=True) + latest_subdir = subdir_times[0][0] + + logger.info(f"Using most recent UniProt subdirectory: {latest_subdir}") + return latest_subdir + + async def ainvoke( + self, + query: str, + cfg: UniProtRetrievalConfig, + expanded_queries: Optional[List[str]] = None, + ) -> List[str]: + """ + Invoke the UniProt retrieval pipeline. + + Args: + query: Search query + cfg: Retrieval configuration + expanded_queries: Optional list of expanded queries for RRF + + Returns: + List of page content strings (protein information only, no metadata) + """ + if not query.strip(): + logger.warning("Empty query provided") + return [] + + try: + logger.info( + f"UniProt retrieve called with query='{query}', expanded_queries={expanded_queries}, use_rrf={cfg.vector_config.use_rrf}" + ) + + if ( + cfg.vector_config.use_rrf + and expanded_queries + and len(expanded_queries) > 1 + ): + logger.info(f"Using RRF with {len(expanded_queries)} expanded queries") + return await self._search_with_rrf(query, cfg, expanded_queries) + elif expanded_queries and len(expanded_queries) == 1: + logger.info(f"Using single expanded query: '{expanded_queries[0]}'") + return await self._search_simple(expanded_queries[0], cfg) + else: + logger.info(f"Using simple search with main query: '{query}'") + return await self._search_simple(query, cfg) + except Exception as e: + logger.error(f"Error during retrieval: {e}") + return [] + + async def _search_with_rrf( + self, + query: str, + cfg: UniProtRetrievalConfig, + expanded_queries: List[str], + ) -> List[str]: + """Search documents using Reciprocal Rank Fusion with parallel query processing.""" + tasks = [] + + for expanded_query in expanded_queries: + tasks.append( + self._search_vectorstore( + expanded_query, + k=cfg.vector_config.rrf_per_query_k, + alpha=cfg.vector_config.rrf_alpha, + ) + ) + + if self._bm25_retriever: + tasks.append( + self._search_bm25( + expanded_query, k=cfg.vector_config.rrf_per_query_k + ) + ) + + logger.info(f"Executing {len(tasks)} search tasks in parallel for RRF") + ranked_lists = await asyncio.gather(*tasks) + + for i, ranked_list in enumerate(ranked_lists): + logger.info(f"Search {i+1} returned {len(ranked_list)} results") + if ranked_list: + first_doc = ranked_list[0] + doc_id = first_doc.metadata.get( + "url", first_doc.metadata.get("id", hash(first_doc.page_content)) + ) + logger.info(f" First result ID: {doc_id}") + logger.info( + f" First result content: {first_doc.page_content[:100]}..." 
+ ) + + # Apply RRF to combine all ranked lists + logger.info( + f"Applying RRF with final_k={cfg.vector_config.rrf_final_k}, lambda={cfg.vector_config.rrf_lambda}" + ) + top_docs, _, _ = reciprocal_rank_fusion( + ranked_lists=ranked_lists, + final_k=cfg.vector_config.rrf_final_k, + lambda_mult=cfg.vector_config.rrf_lambda, + rrf_k=cfg.vector_config.rrf_cutoff_k, + id_getter=lambda doc: doc.metadata.get( + "url", doc.metadata.get("id", hash(doc.page_content)) + ), + ) + + logger.info(f"RRF returned {len(top_docs)} final results") + + return [doc.page_content for doc in top_docs] + + async def _search_simple( + self, query: str, cfg: UniProtRetrievalConfig + ) -> List[str]: + """Search documents using simple similarity search.""" + top_docs = await self._search_vectorstore( + query=query, + k=cfg.vector_config.rrf_final_k, + alpha=cfg.vector_config.alpha, + ) + + return [doc.page_content for doc in top_docs] + + async def _search_vectorstore( + self, + query: str, + k: int, + alpha: Optional[float] = None, + ) -> List[Document]: + """Search vectorstore using asyncio.to_thread.""" + return await asyncio.to_thread(self._search_vectorstore_sync, query, k, alpha) + + async def _search_bm25( + self, + query: str, + k: int, + ) -> List[Document]: + """Search BM25 retriever using asyncio.to_thread.""" + return await asyncio.to_thread(self._search_bm25_sync, query, k) + + def _search_vectorstore_sync( + self, + query: str, + k: int, + alpha: Optional[float] = None, + ) -> List[Document]: + """Search the UniProt vectorstore.""" + if not self._vectorstore: + logger.error("Vectorstore not initialized") + return [] + + try: + if alpha is not None: + docs_with_scores = self._vectorstore.similarity_search_with_score( + query, k=k + ) + # Filter by score threshold (alpha) - higher scores are better + docs = [doc for doc, score in docs_with_scores if score >= alpha] + else: + docs = self._vectorstore.similarity_search(query, k=k) + + for doc in docs: + doc.metadata["search_type"] = "vector" + + return docs[:k] + except Exception as e: + logger.error(f"Error searching vectorstore: {e}") + return [] + + def _search_bm25_sync( + self, + query: str, + k: int, + ) -> List[Document]: + """Search the UniProt BM25 retriever.""" + if not self._bm25_retriever: + logger.debug("BM25 retriever not available") + return [] + + try: + self._bm25_retriever.k = k + docs = self._bm25_retriever.get_relevant_documents(query) + + for doc in docs: + doc.metadata["search_type"] = "bm25" + + return docs[:k] + except Exception as e: + logger.error(f"Error searching BM25 retriever: {e}") + return [] + + def get_subdirectory(self) -> Optional[str]: + """Get the current subdirectory being used.""" + return self._subdirectory + + def is_initialized(self) -> bool: + """Check if the retriever is properly initialized.""" + return self._vectorstore is not None and ( + self._bm25_retriever is not None or self._vectorstore is not None + ) + + def close(self) -> None: + """Close all connections and clear caches.""" + self._vectorstore = None + self._bm25_retriever = None + self._subdirectory = None + logger.info("UniProt retriever closed") + + def __enter__(self) -> "UniProtRetriever": + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> bool: + """Context manager exit.""" + self.close() + return False + + def __repr__(self) -> str: + """String representation of the retriever.""" + status = "initialized" if self.is_initialized() else "not initialized" + subdir = self._subdirectory or "unknown" + 
return f"UniProtRetriever(subdirectory='{subdir}', status='{status}')" diff --git a/src/retrievers/retrieval_utils.py b/src/retrievers/retrieval_utils.py index 90f3e2e..402e7a7 100644 --- a/src/retrievers/retrieval_utils.py +++ b/src/retrievers/retrieval_utils.py @@ -10,7 +10,7 @@ def reciprocal_rank_fusion( id_getter: Callable[[Any], str] = lambda doc: doc.metadata.get("stId") or doc.metadata.get("stable_id"), ) -> Tuple[List[Any], List[str], Dict[str, float]]: - rrf_scores = defaultdict(float) + rrf_scores: defaultdict[str, float] = defaultdict(float) doc_meta = {} for ranked in ranked_lists: From f35f3e01ec8481cbafc9eaaff506c28619d8ff77 Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Sun, 28 Sep 2025 15:06:00 -0400 Subject: [PATCH 07/13] remove: Remove reactome_kg directory from repository --- .../reactome_kg/create_test_embeddings.py | 315 ------------------ 1 file changed, 315 deletions(-) delete mode 100644 src/data_generation/reactome_kg/create_test_embeddings.py diff --git a/src/data_generation/reactome_kg/create_test_embeddings.py b/src/data_generation/reactome_kg/create_test_embeddings.py deleted file mode 100644 index 9c5c478..0000000 --- a/src/data_generation/reactome_kg/create_test_embeddings.py +++ /dev/null @@ -1,315 +0,0 @@ -import logging -import os -import sys -import time -from typing import Any, Dict, List - -import weaviate -from langchain_openai import OpenAIEmbeddings - -# Test-specific configuration -TEST_NEO4J_URI = "bolt://localhost:7690" -TEST_NEO4J_USER = "neo4j" -TEST_NEO4J_PASSWORD = "reactome-test" -TEST_NEO4J_DATABASE = "reactome-kg-test" - -TEST_WEAVIATE_HOST = "localhost" -TEST_WEAVIATE_PORT = 8081 -TEST_WEAVIATE_CLASS_NAME = "TestReactomeKG" - -# TP53-related terms for filtering -TP53_TERMS = [ - "TP53", - "p53", - "tumor protein p53", - "cellular tumor antigen p53", - "transformation-related protein 53", - "tumor suppressor p53", -] - -# Batch processing -BATCH_SIZE = 10 -MIN_TEXT_LENGTH = 50 - - -def _setup_logging() -> logging.Logger: - """Set up logging configuration.""" - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - ) - return logging.getLogger(__name__) - - -def _connect_to_services() -> tuple: - """Connect to Neo4j and Weaviate services.""" - logger = logging.getLogger(__name__) - - # Connect to Neo4j with retry logic - neo4j_driver = None - for attempt in range(10): # Try up to 10 times - try: - from neo4j import GraphDatabase - - neo4j_driver = GraphDatabase.driver( - TEST_NEO4J_URI, auth=(TEST_NEO4J_USER, TEST_NEO4J_PASSWORD) - ) - neo4j_driver.verify_connectivity() - logger.info(f"Connected to Neo4j at {TEST_NEO4J_URI}") - break - except Exception as e: - if attempt == 9: # Last attempt - logger.error(f"Failed to connect to Neo4j after 10 attempts: {e}") - raise - logger.info( - f"Neo4j connection attempt {attempt + 1}/10 failed, retrying in 10 seconds..." 
- ) - time.sleep(10) - - if neo4j_driver is None: - raise Exception("Failed to establish Neo4j connection") - - # Connect to Weaviate - try: - weaviate_client = weaviate.WeaviateClient( - connection_params=weaviate.connect.ConnectionParams.from_url( - f"http://{TEST_WEAVIATE_HOST}:{TEST_WEAVIATE_PORT}", grpc_port=50052 - ), - additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")}, - ) - logger.info( - f"Connected to Weaviate at {TEST_WEAVIATE_HOST}:{TEST_WEAVIATE_PORT}" - ) - except Exception as e: - logger.error(f"Failed to connect to Weaviate: {e}") - raise - - return neo4j_driver, weaviate_client - - -def _is_tp53_related(text: str) -> bool: - """Check if text content is related to TP53.""" - if not text: - return False - - text_lower = str(text).lower() - return any(term.lower() in text_lower for term in TP53_TERMS) - - -def _fetch_tp53_nodes(neo4j_driver) -> List[Dict[str, Any]]: - """Fetch TP53-related nodes from Neo4j.""" - logger = logging.getLogger(__name__) - - query = """ - MATCH (n) - WHERE n.name CONTAINS 'TP53' OR n.name CONTAINS 'p53' - OR n.displayName CONTAINS 'TP53' OR n.displayName CONTAINS 'p53' - OR n.description CONTAINS 'TP53' OR n.description CONTAINS 'p53' - OR n.text_content CONTAINS 'TP53' OR n.text_content CONTAINS 'p53' - OR n.stableId CONTAINS 'TP53' - RETURN n.stableId as stable_id, - labels(n)[0] as label, - n.name as name, - n.displayName as display_name, - n.description as description, - n.text_content as text_content - LIMIT 200 - """ - - try: - with neo4j_driver.session(database=TEST_NEO4J_DATABASE) as session: - result = session.run(query) - nodes = [] - - for record in result: - node_data = { - "stable_id": record["stable_id"], - "label": record["label"], - "name": record["name"], - "display_name": record["display_name"], - "description": record["description"], - "text_content": record["text_content"], - } - - # Additional TP53 filtering - combined_text = " ".join( - [ - node_data.get("name", ""), - node_data.get("description", ""), - node_data.get("text_content", ""), - node_data.get("stable_id", ""), - ] - ) - - if _is_tp53_related(combined_text): - nodes.append(node_data) - - logger.info(f"Found {len(nodes)} TP53-related nodes") - return nodes - - except Exception as e: - logger.error(f"Error fetching TP53 nodes: {e}") - raise - - -def _create_embeddings_and_store(nodes: List[Dict[str, Any]], weaviate_client) -> int: - """Create embeddings for nodes and store in Weaviate.""" - logger = logging.getLogger(__name__) - - # Initialize embeddings - embeddings = OpenAIEmbeddings(model="text-embedding-3-large") - - # Create Weaviate schema if it doesn't exist - try: - if weaviate_client.schema.exists(TEST_WEAVIATE_CLASS_NAME): - logger.info(f"Deleting existing schema: {TEST_WEAVIATE_CLASS_NAME}") - weaviate_client.schema.delete_class(TEST_WEAVIATE_CLASS_NAME) - except Exception as e: - logger.warning(f"Error checking/deleting schema: {e}") - - # Create new schema - schema = { - "class": TEST_WEAVIATE_CLASS_NAME, - "description": "Test Reactome KG embeddings for TP53-related entities", - "vectorizer": "text2vec-openai", - "moduleConfig": { - "text2vec-openai": { - "model": "text-embedding-3-large", - "modelVersion": "002", - "dimensions": 3072, - "type": "text", - } - }, - "properties": [ - { - "name": "stable_id", - "dataType": ["string"], - "description": "Reactome stable identifier", - }, - {"name": "label", "dataType": ["string"], "description": "Node label/type"}, - {"name": "name", "dataType": ["string"], "description": "Node name"}, 
- { - "name": "display_name", - "dataType": ["string"], - "description": "Display name", - }, - { - "name": "description", - "dataType": ["string"], - "description": "Node description", - }, - { - "name": "text_content", - "dataType": ["text"], - "description": "Full text content for embedding", - }, - ], - } - - try: - weaviate_client.schema.create_class(schema) - logger.info(f"Created Weaviate schema: {TEST_WEAVIATE_CLASS_NAME}") - except Exception as e: - logger.error(f"Error creating schema: {e}") - raise - - # Process nodes in batches - processed_count = 0 - - for i in range(0, len(nodes), BATCH_SIZE): - batch = nodes[i : i + BATCH_SIZE] - batch_objects = [] - - for node in batch: - # Create text content for embedding - text_parts = [] - if node.get("name"): - text_parts.append(f"Name: {node['name']}") - if node.get("display_name"): - text_parts.append(f"Display Name: {node['display_name']}") - if node.get("description"): - text_parts.append(f"Description: {node['description']}") - if node.get("text_content"): - text_parts.append(f"Content: {node['text_content']}") - - text_content = " | ".join(text_parts) - - if len(text_content) < MIN_TEXT_LENGTH: - continue - - # Create Weaviate object - obj = { - "stable_id": node["stable_id"], - "label": node["label"], - "name": node.get("name", ""), - "display_name": node.get("display_name", ""), - "description": node.get("description", ""), - "text_content": text_content, - } - - batch_objects.append(obj) - - # Store batch in Weaviate - if batch_objects: - try: - with weaviate_client.batch as batch: - for obj in batch_objects: - batch.add_data_object( - data_object=obj, class_name=TEST_WEAVIATE_CLASS_NAME - ) - - processed_count += len(batch_objects) - logger.info( - f"Processed batch {i//BATCH_SIZE + 1}: {len(batch_objects)} objects" - ) - - except Exception as e: - logger.error(f"Error storing batch: {e}") - continue - - logger.info(f"Total objects processed and stored: {processed_count}") - return processed_count - - -def main(): - """Main function to create test embeddings.""" - logger = _setup_logging() - - try: - logger.info("Starting TP53-focused test embedding creation...") - logger.info(f"Neo4j: {TEST_NEO4J_URI} (database: {TEST_NEO4J_DATABASE})") - logger.info(f"Weaviate: {TEST_WEAVIATE_HOST}:{TEST_WEAVIATE_PORT}") - - # Connect to services - neo4j_driver, weaviate_client = _connect_to_services() - - # Fetch TP53 nodes - nodes = _fetch_tp53_nodes(neo4j_driver) - - if not nodes: - logger.warning("No TP53-related nodes found!") - return - - # Create and store embeddings - processed_count = _create_embeddings_and_store(nodes, weaviate_client) - - logger.info("=" * 60) - logger.info("TEST EMBEDDING CREATION SUMMARY") - logger.info("=" * 60) - logger.info(f"Total TP53 nodes found: {len(nodes)}") - logger.info(f"Total embeddings created: {processed_count}") - logger.info(f"Weaviate class: {TEST_WEAVIATE_CLASS_NAME}") - logger.info("=" * 60) - - logger.info("Test embedding creation completed successfully!") - - except Exception as e: - logger.error(f"Error creating test embeddings: {e}") - sys.exit(1) - finally: - if "neo4j_driver" in locals(): - neo4j_driver.close() - - -if __name__ == "__main__": - main() From ba0193130411e1125375ae1d38c62703d13eaff5 Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Sun, 28 Sep 2025 15:08:47 -0400 Subject: [PATCH 08/13] code quality fixes --- src/agent/profiles/base.py | 30 ++++++++++-------------------- src/agent/profiles/react_to_me.py | 26 ++++++-------------------- src/retrievers/csv_chroma.py | 27 
++++++--------------------- 3 files changed, 22 insertions(+), 61 deletions(-) diff --git a/src/agent/profiles/base.py b/src/agent/profiles/base.py index e583304..5b4051b 100644 --- a/src/agent/profiles/base.py +++ b/src/agent/profiles/base.py @@ -52,20 +52,12 @@ class BaseState(InputState, OutputState, total=False): class BaseGraphBuilder: """Base class for all graph builders with common preprocessing and postprocessing.""" - def __init__( - self, - llm: BaseChatModel, - embedding: Embeddings - ) -> None: + def __init__(self, llm: BaseChatModel, embedding: Embeddings) -> None: """Initialize with LLM and embedding models.""" self.preprocessing_workflow: Runnable = create_preprocessing_workflow(llm) self.search_workflow: Runnable = create_search_workflow(llm) - async def preprocess( - self, - state: BaseState, - config: RunnableConfig - ) -> BaseState: + async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState: """Run the complete preprocessing workflow and map results to state.""" result: PreprocessingState = await self.preprocessing_workflow.ainvoke( PreprocessingState( @@ -77,10 +69,7 @@ async def preprocess( return self._map_preprocessing_result(result) - def _map_preprocessing_result( - self, - result: PreprocessingState - ) -> BaseState: + def _map_preprocessing_result(self, result: PreprocessingState) -> BaseState: """Map preprocessing results to BaseState with defaults.""" return BaseState( rephrased_input=result["rephrased_input"], @@ -90,11 +79,7 @@ def _map_preprocessing_result( detected_language=result.get("detected_language", DEFAULT_LANGUAGE), ) - async def postprocess( - self, - state: BaseState, - config: RunnableConfig - ) -> BaseState: + async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState: """Postprocess that preserves existing state and conditionally adds search results.""" search_results: list[WebSearchResult] = [] @@ -113,4 +98,9 @@ async def postprocess( search_results = result["search_results"] # Create new state with updated additional_content - return BaseState(**{**state, "additional_content": AdditionalContent(search_results=search_results)}) + return BaseState( + **{ + **state, + "additional_content": AdditionalContent(search_results=search_results), + } + ) diff --git a/src/agent/profiles/react_to_me.py b/src/agent/profiles/react_to_me.py index d2dab7d..f0e4980 100644 --- a/src/agent/profiles/react_to_me.py +++ b/src/agent/profiles/react_to_me.py @@ -23,11 +23,7 @@ class ReactToMeState(BaseState): class ReactToMeGraphBuilder(BaseGraphBuilder): """Graph builder for ReactToMe profile with Reactome-specific functionality.""" - def __init__( - self, - llm: BaseChatModel, - embedding: Embeddings - ) -> None: + def __init__(self, llm: BaseChatModel, embedding: Embeddings) -> None: """Initialize ReactToMe graph builder with required components.""" super().__init__(llm, embedding) @@ -69,25 +65,20 @@ def _build_workflow(self) -> StateGraph: return state_graph async def preprocess( - self, - state: ReactToMeState, - config: RunnableConfig + self, state: ReactToMeState, config: RunnableConfig ) -> ReactToMeState: """Run preprocessing workflow.""" result = await super().preprocess(state, config) return ReactToMeState(**result) async def proceed_with_research( - self, - state: ReactToMeState + self, state: ReactToMeState ) -> Literal["Continue", "Finish"]: """Determine whether to proceed with research based on safety check.""" return "Continue" if state["safety"] == SAFETY_SAFE else "Finish" async def 
generate_unsafe_response( - self, - state: ReactToMeState, - config: RunnableConfig + self, state: ReactToMeState, config: RunnableConfig ) -> ReactToMeState: """Generate appropriate refusal response for unsafe queries.""" final_answer_message = await self.unsafe_answer_generator.ainvoke( @@ -120,9 +111,7 @@ async def generate_unsafe_response( ) async def call_model( - self, - state: ReactToMeState, - config: RunnableConfig + self, state: ReactToMeState, config: RunnableConfig ) -> ReactToMeState: """Generate response using Reactome RAG for safe queries.""" result: dict[str, Any] = await self.reactome_rag.ainvoke( @@ -147,9 +136,6 @@ async def call_model( ) -def create_reactome_graph( - llm: BaseChatModel, - embedding: Embeddings - ) -> StateGraph: +def create_reactome_graph(llm: BaseChatModel, embedding: Embeddings) -> StateGraph: """Create and return the ReactToMe workflow graph.""" return ReactToMeGraphBuilder(llm, embedding).uncompiled_graph diff --git a/src/retrievers/csv_chroma.py b/src/retrievers/csv_chroma.py index 3d96c12..168fe42 100644 --- a/src/retrievers/csv_chroma.py +++ b/src/retrievers/csv_chroma.py @@ -77,11 +77,7 @@ def list_chroma_subdirectories(directory: Path) -> List[str]: class HybridRetriever: """Advanced hybrid retriever supporting RRF, parallel processing, and multi-source search.""" - def __init__( - self, - embedding: Embeddings, - embeddings_directory: Path - ): + def __init__(self, embedding: Embeddings, embeddings_directory: Path): self.embedding = embedding self.embeddings_directory = embeddings_directory @@ -159,25 +155,17 @@ def _create_vector_retriever(self, subdirectory: str) -> Optional[object]: return None async def _search_with_bm25( - self, - query: str, - retriever: BM25Retriever + self, query: str, retriever: BM25Retriever ) -> List[Document]: """Search using BM25 retriever asynchronously.""" return await asyncio.to_thread(retriever.get_relevant_documents, query) - async def _search_with_vector( - self, - query: str, - retriever: Any - ) -> List[Document]: + async def _search_with_vector(self, query: str, retriever: Any) -> List[Document]: """Search using vector retriever asynchronously.""" return await asyncio.to_thread(retriever.get_relevant_documents, query) async def _execute_hybrid_search( - self, - query: str, - subdirectory: str + self, query: str, subdirectory: str ) -> List[Document]: """Execute hybrid search (BM25 + vector) for a single query on a subdirectory.""" retriever_info = self._retrievers.get(subdirectory) @@ -223,9 +211,7 @@ def _generate_document_identifier(self, document: Document) -> str: return hashlib.md5(document.page_content.encode()).hexdigest() async def _apply_reciprocal_rank_fusion( - self, - queries: List[str], - subdirectory: str + self, queries: List[str], subdirectory: str ) -> List[Document]: """Apply Reciprocal Rank Fusion to results from multiple queries on a subdirectory.""" logger.info( @@ -325,8 +311,7 @@ async def ainvoke(self, inputs: Dict[str, Any]) -> str: def create_hybrid_retriever( - embedding: Embeddings, - embeddings_directory: Path + embedding: Embeddings, embeddings_directory: Path ) -> HybridRetriever: """Create a hybrid retriever with RRF and parallel processing support.""" try: From 7f8d4c56e0c8f01b44a937c79755dbb307fc63a1 Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Sun, 28 Sep 2025 16:17:23 -0400 Subject: [PATCH 09/13] feat: expand preprocessing into a multi-step workflow.
- Implement parallel execution of safety and scope check, query expansion, and language detection --- src/agent/profiles/react_to_me.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/agent/profiles/react_to_me.py b/src/agent/profiles/react_to_me.py index f0e4980..d353e57 100644 --- a/src/agent/profiles/react_to_me.py +++ b/src/agent/profiles/react_to_me.py @@ -45,13 +45,13 @@ def _build_workflow(self) -> StateGraph: """Build and configure the ReactToMe workflow graph.""" state_graph = StateGraph(ReactToMeState) - # Add workflow nodes + # Add nodes state_graph.add_node("preprocess", self.preprocess) state_graph.add_node("model", self.call_model) state_graph.add_node("generate_unsafe_response", self.generate_unsafe_response) state_graph.add_node("postprocess", self.postprocess) - # Configure workflow edges + # Add edges state_graph.set_entry_point("preprocess") state_graph.add_conditional_edges( "preprocess", From 67fcd60f0946ff290cdbf17da7a6158ce2578432 Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Sun, 28 Sep 2025 16:18:18 -0400 Subject: [PATCH 10/13] feat: expand preprocessing into a multi-step workflow. - Implement parallel execution of safety and scope check, query expansion, and language detection --- src/agent/profiles/base.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/agent/profiles/base.py b/src/agent/profiles/base.py index 5b4051b..2afd0e5 100644 --- a/src/agent/profiles/base.py +++ b/src/agent/profiles/base.py @@ -99,8 +99,6 @@ async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseSta # Create new state with updated additional_content return BaseState( - **{ - **state, - "additional_content": AdditionalContent(search_results=search_results), - } + **state, # Copy existing state + additional_content=AdditionalContent(search_results=search_results) ) From 3ea2ba86527bf191c8b386ef59268ae5db8f3886 Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Sun, 28 Sep 2025 16:20:12 -0400 Subject: [PATCH 11/13] feat: improved hybrid retrieval - Replace SelfQueryRetriever with efficient hybrid search (BM25 + vector) - Add RRF (Reciprocal Rank Fusion) support for query expansion - Implement parallel processing for improved performance --- src/retrievers/csv_chroma.py | 18 +++++++++++------- src/retrievers/retrieval_utils.py | 2 +- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/retrievers/csv_chroma.py b/src/retrievers/csv_chroma.py index 168fe42..c11b01f 100644 --- a/src/retrievers/csv_chroma.py +++ b/src/retrievers/csv_chroma.py @@ -10,6 +10,7 @@ from langchain_community.retrievers import BM25Retriever from langchain_core.documents import Document from langchain_core.embeddings import Embeddings +from langchain_core.retrievers import BaseRetriever from retrievers.retrieval_utils import reciprocal_rank_fusion @@ -47,12 +48,12 @@ def create_documents_from_csv(csv_path: Path) -> List[Document]: page_content = "\n".join(content_parts) metadata = { - str(column): str(value) + column: str(value) for column in df.columns for value in [row[column]] if pd.notna(value) and str(value) != "nan" } - metadata.update({"source": str(csv_path), "row_index": index}) + metadata.update({"source": str(csv_path), "row_index": str(index)}) documents.append(Document(page_content=page_content, metadata=metadata)) @@ -78,11 +79,10 @@ class HybridRetriever: """Advanced hybrid retriever supporting RRF, parallel processing, and multi-source search.""" def __init__(self, embedding: Embeddings, embeddings_directory:
Path): - self.embedding = embedding self.embeddings_directory = embeddings_directory self._retrievers: Dict[ - str, Dict[str, Optional[Union[BM25Retriever, object]]] + str, Dict[str, Optional[Union[BM25Retriever, BaseRetriever]]] ] = {} try: @@ -129,7 +129,7 @@ def _create_bm25_retriever(self, subdirectory: str) -> Optional[BM25Retriever]: logger.error(f"Failed to create BM25 retriever for {subdirectory}: {e}") return None - def _create_vector_retriever(self, subdirectory: str) -> Optional[object]: + def _create_vector_retriever(self, subdirectory: str) -> Optional[BaseRetriever]: """Create vector retriever for a specific subdirectory.""" vector_directory = self.embeddings_directory / subdirectory @@ -160,7 +160,9 @@ async def _search_with_bm25( """Search using BM25 retriever asynchronously.""" return await asyncio.to_thread(retriever.get_relevant_documents, query) - async def _search_with_vector(self, query: str, retriever: Any) -> List[Document]: + async def _search_with_vector( + self, query: str, retriever: BaseRetriever + ) -> List[Document]: """Search using vector retriever asynchronously.""" return await asyncio.to_thread(retriever.get_relevant_documents, query) @@ -178,7 +180,9 @@ async def _execute_hybrid_search( if retriever_info["bm25"] and isinstance(retriever_info["bm25"], BM25Retriever): search_tasks.append(self._search_with_bm25(query, retriever_info["bm25"])) - if retriever_info["vector"]: + if retriever_info["vector"] and isinstance( + retriever_info["vector"], BaseRetriever + ): search_tasks.append( self._search_with_vector(query, retriever_info["vector"]) ) diff --git a/src/retrievers/retrieval_utils.py b/src/retrievers/retrieval_utils.py index 402e7a7..28c3015 100644 --- a/src/retrievers/retrieval_utils.py +++ b/src/retrievers/retrieval_utils.py @@ -10,7 +10,7 @@ def reciprocal_rank_fusion( id_getter: Callable[[Any], str] = lambda doc: doc.metadata.get("stId") or doc.metadata.get("stable_id"), ) -> Tuple[List[Any], List[str], Dict[str, float]]: - rrf_scores: defaultdict[str, float] = defaultdict(float) + rrf_scores: Dict[str, float] = defaultdict(float) doc_meta = {} for ranked in ranked_lists: From 5b821999e24c07317669afc3f3871df8e71e6970 Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Sun, 28 Sep 2025 16:22:33 -0400 Subject: [PATCH 12/13] feat: improved answer generation, in-line citation handling and hallucination mitigation --- src/retrievers/reactome/prompt.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/retrievers/reactome/prompt.py b/src/retrievers/reactome/prompt.py index 43c43f5..d570cb9 100644 --- a/src/retrievers/reactome/prompt.py +++ b/src/retrievers/reactome/prompt.py @@ -10,7 +10,9 @@ ## **Answering Guidelines** 1. Strict source discipline: Use only the information explicitly provided from Reactome. Do not invent, infer, or draw from external knowledge. - - If the answer cannot be derived from the context, explicitly state that the information is not currently available in Reactome. + - Use only information directly found in Reactome. + - Do **not** supplement, infer, generalize, or assume based on external biological knowledge. + - If no relevant information exists in Reactome, explain that the information is not currently available in Reactome. Do **not** answer the question. 2. Inline citations required: Every factual statement must include ≥1 inline anchor citation in the format: display_name - If multiple entries support the same fact, cite them together (space-separated). 3.
Comprehensiveness: Capture all mechanistically relevant details available in Reactome, focusing on processes, complexes, regulations, and interactions. From 27e761c610d7c96783b3e0433e9cc5b3563349d0 Mon Sep 17 00:00:00 2001 From: Helia Mohammadi Date: Sun, 28 Sep 2025 16:27:29 -0400 Subject: [PATCH 13/13] remove: Remove irrelevant graph_rag UniProt retriever --- src/retrievers/graph_rag/uniprot_retriever.py | 376 ------------------ 1 file changed, 376 deletions(-) delete mode 100644 src/retrievers/graph_rag/uniprot_retriever.py diff --git a/src/retrievers/graph_rag/uniprot_retriever.py b/src/retrievers/graph_rag/uniprot_retriever.py deleted file mode 100644 index 3a6af8d..0000000 --- a/src/retrievers/graph_rag/uniprot_retriever.py +++ /dev/null @@ -1,376 +0,0 @@ -from __future__ import annotations - -import asyncio -import logging -import os -from pathlib import Path -from typing import List, Optional - -import chromadb.config -from langchain_chroma.vectorstores import Chroma -from langchain_community.document_loaders.csv_loader import CSVLoader -from langchain_community.retrievers import BM25Retriever -from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings - -from src.retrievers.csv_chroma import list_chroma_subdirectories -from src.util.embedding_environment import EmbeddingEnvironment - -from .retrieval_utils import UniProtRetrievalConfig, reciprocal_rank_fusion - -logger = logging.getLogger(__name__) - -CHROMA_SETTINGS = chromadb.config.Settings(anonymized_telemetry=False) - - -class UniProtRetriever: - """ - UniProt vector retriever that supports RRF and similarity search. - - This retriever provides the same configuration options as the Graph RAG - retriever but operates only on vector embeddings without graph traversal. - Returns page content strings (protein information only, no metadata). - - Features: - - Single vectorstore using the most recent subdirectory - - Single BM25 retriever using the most recent CSV file - - Reciprocal Rank Fusion (RRF) support - - Hybrid vector + BM25 search - - Clean page content output for LLM consumption - """ - - DEFAULT_BM25_K = 10 - - def __init__( - self, - embedding: Embeddings, - embeddings_directory: Optional[Path] = None, - ) -> None: - """ - Initialize the UniProt vector retriever. - - Args: - embedding: Embedding model for vector operations - embeddings_directory: Path to UniProt embeddings directory.
- Defaults to EmbeddingEnvironment.get_dir("uniprot") - - Raises: - ValueError: If embeddings_directory doesn't exist - RuntimeError: If initialization fails - """ - self.embedding = embedding - self.embeddings_directory = ( - embeddings_directory or EmbeddingEnvironment.get_dir("uniprot") - ) - - if not self.embeddings_directory.exists(): - raise ValueError( - f"Embeddings directory does not exist: {self.embeddings_directory}" - ) - - self._vectorstore: Optional[Chroma] = None - self._bm25_retriever: Optional[BM25Retriever] = None - self._subdirectory: Optional[str] = None - - try: - self._initialize_retrievers() - except Exception as e: - logger.error(f"Failed to initialize UniProt retriever: {e}") - raise RuntimeError(f"UniProt retriever initialization failed: {e}") from e - - def _initialize_retrievers(self) -> None: - """Initialize both vectorstore and BM25 retriever.""" - subdirectories = list_chroma_subdirectories(self.embeddings_directory) - - if not subdirectories: - raise RuntimeError("No UniProt subdirectories found") - - # Use the most recently created subdirectory - self._subdirectory = self._get_latest_subdirectory(subdirectories) - - self._initialize_vectorstore() - self._initialize_bm25_retriever() - - logger.info( - f"UniProt retriever initialized successfully using subdirectory: {self._subdirectory}" - ) - - def _initialize_vectorstore(self) -> None: - """Initialize Chroma vectorstore for UniProt.""" - try: - self._vectorstore = Chroma( - persist_directory=str(self.embeddings_directory / self._subdirectory), - embedding_function=self.embedding, - client_settings=CHROMA_SETTINGS, - ) - logger.info(f"Initialized UniProt vectorstore: {self._subdirectory}") - except Exception as e: - logger.error(f"Failed to initialize vectorstore: {e}") - raise - - def _initialize_bm25_retriever(self) -> None: - """Initialize BM25 retriever for UniProt.""" - try: - csv_file_name = f"{self._subdirectory}.csv" - csvs_dir = self.embeddings_directory / "csv_files" - csv_path = csvs_dir / csv_file_name - - if not csv_path.exists(): - logger.warning(f"CSV file not found: {csv_path}") - return - - loader = CSVLoader(file_path=str(csv_path)) - data = loader.load() - - if not data: - logger.warning(f"No data loaded from CSV: {csv_path}") - return - - self._bm25_retriever = BM25Retriever.from_documents(data) - self._bm25_retriever.k = self.DEFAULT_BM25_K - - logger.info(f"Initialized UniProt BM25 retriever: {self._subdirectory}") - except Exception as e: - logger.error(f"Failed to initialize BM25 retriever: {e}") - - def _get_latest_subdirectory(self, subdirectories: List[str]) -> str: - """ - Get the most recently created subdirectory. 
-
-        Args:
-            subdirectories: List of subdirectory names
-
-        Returns:
-            Name of the most recently created subdirectory
-        """
-        subdir_times = []
-
-        for subdir in subdirectories:
-            subdir_path = self.embeddings_directory / subdir
-            if subdir_path.exists():
-                mtime = os.path.getmtime(subdir_path)
-                subdir_times.append((subdir, mtime))
-
-        if not subdir_times:
-            logger.warning("No valid subdirectories found, using first available")
-            return subdirectories[0]
-
-        # Sort by modification time (most recent first) and return the latest
-        subdir_times.sort(key=lambda x: x[1], reverse=True)
-        latest_subdir = subdir_times[0][0]
-
-        logger.info(f"Using most recent UniProt subdirectory: {latest_subdir}")
-        return latest_subdir
-
-    async def ainvoke(
-        self,
-        query: str,
-        cfg: UniProtRetrievalConfig,
-        expanded_queries: Optional[List[str]] = None,
-    ) -> List[str]:
-        """
-        Invoke the UniProt retrieval pipeline.
-
-        Args:
-            query: Search query
-            cfg: Retrieval configuration
-            expanded_queries: Optional list of expanded queries for RRF
-
-        Returns:
-            List of page content strings (protein information only, no metadata)
-        """
-        if not query.strip():
-            logger.warning("Empty query provided")
-            return []
-
-        try:
-            logger.info(
-                f"UniProt retrieve called with query='{query}', expanded_queries={expanded_queries}, use_rrf={cfg.vector_config.use_rrf}"
-            )
-
-            if (
-                cfg.vector_config.use_rrf
-                and expanded_queries
-                and len(expanded_queries) > 1
-            ):
-                logger.info(f"Using RRF with {len(expanded_queries)} expanded queries")
-                return await self._search_with_rrf(query, cfg, expanded_queries)
-            elif expanded_queries and len(expanded_queries) == 1:
-                logger.info(f"Using single expanded query: '{expanded_queries[0]}'")
-                return await self._search_simple(expanded_queries[0], cfg)
-            else:
-                logger.info(f"Using simple search with main query: '{query}'")
-                return await self._search_simple(query, cfg)
-        except Exception as e:
-            logger.error(f"Error during retrieval: {e}")
-            return []
-
-    async def _search_with_rrf(
-        self,
-        query: str,
-        cfg: UniProtRetrievalConfig,
-        expanded_queries: List[str],
-    ) -> List[str]:
-        """Search documents using Reciprocal Rank Fusion with parallel query processing."""
-        tasks = []
-
-        for expanded_query in expanded_queries:
-            tasks.append(
-                self._search_vectorstore(
-                    expanded_query,
-                    k=cfg.vector_config.rrf_per_query_k,
-                    alpha=cfg.vector_config.rrf_alpha,
-                )
-            )
-
-            if self._bm25_retriever:
-                tasks.append(
-                    self._search_bm25(
-                        expanded_query, k=cfg.vector_config.rrf_per_query_k
-                    )
-                )
-
-        logger.info(f"Executing {len(tasks)} search tasks in parallel for RRF")
-        ranked_lists = await asyncio.gather(*tasks)
-
-        for i, ranked_list in enumerate(ranked_lists):
-            logger.info(f"Search {i+1} returned {len(ranked_list)} results")
-            if ranked_list:
-                first_doc = ranked_list[0]
-                doc_id = first_doc.metadata.get(
-                    "url", first_doc.metadata.get("id", hash(first_doc.page_content))
-                )
-                logger.info(f"  First result ID: {doc_id}")
-                logger.info(
-                    f"  First result content: {first_doc.page_content[:100]}..."
- ) - - # Apply RRF to combine all ranked lists - logger.info( - f"Applying RRF with final_k={cfg.vector_config.rrf_final_k}, lambda={cfg.vector_config.rrf_lambda}" - ) - top_docs, _, _ = reciprocal_rank_fusion( - ranked_lists=ranked_lists, - final_k=cfg.vector_config.rrf_final_k, - lambda_mult=cfg.vector_config.rrf_lambda, - rrf_k=cfg.vector_config.rrf_cutoff_k, - id_getter=lambda doc: doc.metadata.get( - "url", doc.metadata.get("id", hash(doc.page_content)) - ), - ) - - logger.info(f"RRF returned {len(top_docs)} final results") - - return [doc.page_content for doc in top_docs] - - async def _search_simple( - self, query: str, cfg: UniProtRetrievalConfig - ) -> List[str]: - """Search documents using simple similarity search.""" - top_docs = await self._search_vectorstore( - query=query, - k=cfg.vector_config.rrf_final_k, - alpha=cfg.vector_config.alpha, - ) - - return [doc.page_content for doc in top_docs] - - async def _search_vectorstore( - self, - query: str, - k: int, - alpha: Optional[float] = None, - ) -> List[Document]: - """Search vectorstore using asyncio.to_thread.""" - return await asyncio.to_thread(self._search_vectorstore_sync, query, k, alpha) - - async def _search_bm25( - self, - query: str, - k: int, - ) -> List[Document]: - """Search BM25 retriever using asyncio.to_thread.""" - return await asyncio.to_thread(self._search_bm25_sync, query, k) - - def _search_vectorstore_sync( - self, - query: str, - k: int, - alpha: Optional[float] = None, - ) -> List[Document]: - """Search the UniProt vectorstore.""" - if not self._vectorstore: - logger.error("Vectorstore not initialized") - return [] - - try: - if alpha is not None: - docs_with_scores = self._vectorstore.similarity_search_with_score( - query, k=k - ) - # Filter by score threshold (alpha) - higher scores are better - docs = [doc for doc, score in docs_with_scores if score >= alpha] - else: - docs = self._vectorstore.similarity_search(query, k=k) - - for doc in docs: - doc.metadata["search_type"] = "vector" - - return docs[:k] - except Exception as e: - logger.error(f"Error searching vectorstore: {e}") - return [] - - def _search_bm25_sync( - self, - query: str, - k: int, - ) -> List[Document]: - """Search the UniProt BM25 retriever.""" - if not self._bm25_retriever: - logger.debug("BM25 retriever not available") - return [] - - try: - self._bm25_retriever.k = k - docs = self._bm25_retriever.get_relevant_documents(query) - - for doc in docs: - doc.metadata["search_type"] = "bm25" - - return docs[:k] - except Exception as e: - logger.error(f"Error searching BM25 retriever: {e}") - return [] - - def get_subdirectory(self) -> Optional[str]: - """Get the current subdirectory being used.""" - return self._subdirectory - - def is_initialized(self) -> bool: - """Check if the retriever is properly initialized.""" - return self._vectorstore is not None and ( - self._bm25_retriever is not None or self._vectorstore is not None - ) - - def close(self) -> None: - """Close all connections and clear caches.""" - self._vectorstore = None - self._bm25_retriever = None - self._subdirectory = None - logger.info("UniProt retriever closed") - - def __enter__(self) -> "UniProtRetriever": - """Context manager entry.""" - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> bool: - """Context manager exit.""" - self.close() - return False - - def __repr__(self) -> str: - """String representation of the retriever.""" - status = "initialized" if self.is_initialized() else "not initialized" - subdir = self._subdirectory or "unknown" - 
-        return f"UniProtRetriever(subdirectory='{subdir}', status='{status}')"
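
Note on the fusion step removed above: the deleted module delegated to
retrieval_utils.reciprocal_rank_fusion, whose source does not appear in this
series. For readers tracing the deleted call site, the following is a
minimal, self-contained sketch of standard reciprocal rank fusion. The names
final_k, rrf_k, and id_getter mirror the call site; lambda_mult and the
helper's three-tuple return value are omitted, so this is an illustrative
approximation under standard RRF assumptions, not the project's
implementation.

    from collections import defaultdict
    from typing import Callable, Hashable, List, TypeVar

    T = TypeVar("T")

    def rrf_fuse(
        ranked_lists: List[List[T]],
        final_k: int,
        rrf_k: int = 60,
        id_getter: Callable[[T], Hashable] = id,
    ) -> List[T]:
        """Fuse ranked lists: score(d) = sum over lists of 1 / (rrf_k + rank(d))."""
        scores: dict = defaultdict(float)
        first_seen: dict = {}
        for ranked in ranked_lists:
            for rank, item in enumerate(ranked, start=1):
                key = id_getter(item)
                # Higher-ranked items (small rank) contribute larger increments
                scores[key] += 1.0 / (rrf_k + rank)
                # Remember one representative object per deduplication key
                first_seen.setdefault(key, item)
        top_keys = sorted(scores, key=scores.get, reverse=True)[:final_k]
        return [first_seen[key] for key in top_keys]

Documents that rank highly in several of the per-query vector and BM25 result
lists accumulate the largest fused scores, which is why _search_with_rrf
gathered one ranked list per expanded query before fusing.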