diff --git a/src/agent/profiles/base.py b/src/agent/profiles/base.py index 9a6e26c..2afd0e5 100644 --- a/src/agent/profiles/base.py +++ b/src/agent/profiles/base.py @@ -1,4 +1,4 @@ -from typing import Annotated, TypedDict +from typing import Annotated, Literal, TypedDict from langchain_core.embeddings import Embeddings from langchain_core.language_models.chat_models import BaseChatModel @@ -6,53 +6,88 @@ from langchain_core.runnables import Runnable, RunnableConfig from langgraph.graph.message import add_messages -from agent.tasks.rephrase import create_rephrase_chain from tools.external_search.state import SearchState, WebSearchResult from tools.external_search.workflow import create_search_workflow +from tools.preprocessing.state import PreprocessingState +from tools.preprocessing.workflow import create_preprocessing_workflow + +# Constants +SAFETY_SAFE: Literal["true"] = "true" +SAFETY_UNSAFE: Literal["false"] = "false" +DEFAULT_LANGUAGE: str = "English" class AdditionalContent(TypedDict, total=False): + """Additional content sent on graph completion.""" + search_results: list[WebSearchResult] class InputState(TypedDict, total=False): - user_input: str # User input text + """Input state for user queries.""" + + user_input: str class OutputState(TypedDict, total=False): - answer: str # primary LLM response that is streamed to the user - additional_content: AdditionalContent # sends on graph completion + """Output state for responses.""" + + answer: str + additional_content: AdditionalContent class BaseState(InputState, OutputState, total=False): - rephrased_input: str # LLM-generated query from user input + """Base state containing all common fields for agent workflows.""" + + rephrased_input: str chat_history: Annotated[list[BaseMessage], add_messages] + # Preprocessing results + safety: str + reason_unsafe: str + expanded_queries: list[str] + detected_language: str + class BaseGraphBuilder: - # NOTE: Anything that is common to all graph builders goes here - - def __init__( - self, - llm: BaseChatModel, - embedding: Embeddings, - ) -> None: - self.rephrase_chain: Runnable = create_rephrase_chain(llm) + """Base class for all graph builders with common preprocessing and postprocessing.""" + + def __init__(self, llm: BaseChatModel, embedding: Embeddings) -> None: + """Initialize with LLM and embedding models.""" + self.preprocessing_workflow: Runnable = create_preprocessing_workflow(llm) self.search_workflow: Runnable = create_search_workflow(llm) async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState: - rephrased_input: str = await self.rephrase_chain.ainvoke( - { - "user_input": state["user_input"], - "chat_history": state["chat_history"], - }, + """Run the complete preprocessing workflow and map results to state.""" + result: PreprocessingState = await self.preprocessing_workflow.ainvoke( + PreprocessingState( + user_input=state["user_input"], + chat_history=state["chat_history"], + ), config, ) - return BaseState(rephrased_input=rephrased_input) + + return self._map_preprocessing_result(result) + + def _map_preprocessing_result(self, result: PreprocessingState) -> BaseState: + """Map preprocessing results to BaseState with defaults.""" + return BaseState( + rephrased_input=result["rephrased_input"], + safety=result.get("safety", SAFETY_SAFE), + reason_unsafe=result.get("reason_unsafe", ""), + expanded_queries=result.get("expanded_queries", []), + detected_language=result.get("detected_language", DEFAULT_LANGUAGE), + ) async def postprocess(self, state: 
BaseState, config: RunnableConfig) -> BaseState: + """Postprocess that preserves existing state and conditionally adds search results.""" search_results: list[WebSearchResult] = [] - if config["configurable"]["enable_postprocess"]: + + # Only run external search for safe questions + if ( + state.get("safety") == SAFETY_SAFE + and config["configurable"]["enable_postprocess"] + ): result: SearchState = await self.search_workflow.ainvoke( SearchState( input=state["rephrased_input"], @@ -61,6 +96,9 @@ async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseSta config=RunnableConfig(callbacks=config["callbacks"]), ) search_results = result["search_results"] + + # Create new state with updated additional_content return BaseState( + **state, # Copy existing state additional_content=AdditionalContent(search_results=search_results) ) diff --git a/src/agent/profiles/react_to_me.py b/src/agent/profiles/react_to_me.py index c162ac7..d353e57 100644 --- a/src/agent/profiles/react_to_me.py +++ b/src/agent/profiles/react_to_me.py @@ -1,52 +1,123 @@ -from typing import Any +from typing import Any, Literal from langchain_core.embeddings import Embeddings from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import AIMessage, HumanMessage from langchain_core.runnables import Runnable, RunnableConfig +from langchain_openai import ChatOpenAI from langgraph.graph.state import StateGraph -from agent.profiles.base import BaseGraphBuilder, BaseState +from agent.profiles.base import (SAFETY_SAFE, SAFETY_UNSAFE, BaseGraphBuilder, + BaseState) +from agent.tasks.final_answer_generation.unsafe_question import \ + create_unsafe_answer_generator from retrievers.reactome.rag import create_reactome_rag class ReactToMeState(BaseState): + """ReactToMe state extends BaseState with all preprocessing results.""" + pass class ReactToMeGraphBuilder(BaseGraphBuilder): - def __init__( - self, - llm: BaseChatModel, - embedding: Embeddings, - ) -> None: + """Graph builder for ReactToMe profile with Reactome-specific functionality.""" + + def __init__(self, llm: BaseChatModel, embedding: Embeddings) -> None: + """Initialize ReactToMe graph builder with required components.""" super().__init__(llm, embedding) - # Create runnables (tasks & tools) + # Create a streaming LLM instance only for final answer generation + streaming_llm = ChatOpenAI( + model=llm.model_name if hasattr(llm, "model_name") else "gpt-4o-mini", + temperature=0.0, + streaming=True, + ) + + self.unsafe_answer_generator = create_unsafe_answer_generator(streaming_llm) self.reactome_rag: Runnable = create_reactome_rag( - llm, embedding, streaming=True + streaming_llm, embedding, streaming=True ) - # Create graph + self.uncompiled_graph: StateGraph = self._build_workflow() + + def _build_workflow(self) -> StateGraph: + """Build and configure the ReactToMe workflow graph.""" state_graph = StateGraph(ReactToMeState) - # Set up nodes + + # Add nodes state_graph.add_node("preprocess", self.preprocess) state_graph.add_node("model", self.call_model) + state_graph.add_node("generate_unsafe_response", self.generate_unsafe_response) state_graph.add_node("postprocess", self.postprocess) - # Set up edges + + # Add edges state_graph.set_entry_point("preprocess") - state_graph.add_edge("preprocess", "model") + state_graph.add_conditional_edges( + "preprocess", + self.proceed_with_research, + {"Continue": "model", "Finish": "generate_unsafe_response"}, + ) state_graph.add_edge("model", "postprocess") + 
state_graph.add_edge("generate_unsafe_response", "postprocess") state_graph.set_finish_point("postprocess") - self.uncompiled_graph: StateGraph = state_graph + return state_graph + + async def preprocess( + self, state: ReactToMeState, config: RunnableConfig + ) -> ReactToMeState: + """Run preprocessing workflow.""" + result = await super().preprocess(state, config) + return ReactToMeState(**result) + + async def proceed_with_research( + self, state: ReactToMeState + ) -> Literal["Continue", "Finish"]: + """Determine whether to proceed with research based on safety check.""" + return "Continue" if state["safety"] == SAFETY_SAFE else "Finish" + + async def generate_unsafe_response( + self, state: ReactToMeState, config: RunnableConfig + ) -> ReactToMeState: + """Generate appropriate refusal response for unsafe queries.""" + final_answer_message = await self.unsafe_answer_generator.ainvoke( + { + "language": state["detected_language"], + "user_input": state["rephrased_input"], + "reason_unsafe": state["reason_unsafe"], + }, + config, + ) + + final_answer = ( + final_answer_message.content + if hasattr(final_answer_message, "content") + else str(final_answer_message) + ) + + return ReactToMeState( + chat_history=[ + HumanMessage(state["user_input"]), + ( + final_answer_message + if hasattr(final_answer_message, "content") + else AIMessage(final_answer) + ), + ], + answer=final_answer, + safety=SAFETY_UNSAFE, + additional_content={"search_results": []}, + ) async def call_model( self, state: ReactToMeState, config: RunnableConfig ) -> ReactToMeState: + """Generate response using Reactome RAG for safe queries.""" result: dict[str, Any] = await self.reactome_rag.ainvoke( { "input": state["rephrased_input"], + "expanded_queries": state.get("expanded_queries", []), "chat_history": ( state["chat_history"] if state["chat_history"] @@ -55,6 +126,7 @@ async def call_model( }, config, ) + return ReactToMeState( chat_history=[ HumanMessage(state["user_input"]), @@ -64,8 +136,6 @@ async def call_model( ) -def create_reactome_graph( - llm: BaseChatModel, - embedding: Embeddings, -) -> StateGraph: +def create_reactome_graph(llm: BaseChatModel, embedding: Embeddings) -> StateGraph: + """Create and return the ReactToMe workflow graph.""" return ReactToMeGraphBuilder(llm, embedding).uncompiled_graph diff --git a/src/agent/tasks/final_answer_generation/unsafe_question.py b/src/agent/tasks/final_answer_generation/unsafe_question.py new file mode 100644 index 0000000..7193a18 --- /dev/null +++ b/src/agent/tasks/final_answer_generation/unsafe_question.py @@ -0,0 +1,47 @@ +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.runnables import Runnable + + +# unsafe or out of scope answer generator +def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable: + """ + Create an unsafe answer generator chain. + + Args: + llm: Language model to use + + Returns: + Runnable that takes language, user_input, reactome_context, uniprot_context, chat_history + """ + system_prompt = """ + You are an expert scientific assistant operating under the React-to-Me platform. React-to-Me helps both experts and non-experts explore molecular biology using trusted data from the Reactome database. + +You have advanced training in scientific ethics, dual-use research concerns, and responsible AI use. + +You will receive three inputs: +1. The user's question. +2. 
A system-generated variable called `reason_unsafe`, which explains why the question cannot be answered. +3. The user's preferred language (as a language code or name). + +Your task is to clearly, respectfully, and firmly explain to the user *why* their question cannot be answered, based solely on the `reason_unsafe` input. Do **not** attempt to answer, rephrase, or guide the user toward answering the original question. + +You must: +- Respond in the user’s preferred language. +- Politely explain the refusal, grounded in the `reason_unsafe`. +- Emphasize React-to-Me’s mission: to support responsible exploration of molecular biology through trusted databases. +- Suggest examples of appropriate topics (e.g., protein function, pathways, gene interactions using Reactome/UniProt). + +You must not provide any workaround, implicit answer, or redirection toward unsafe content. +""" + prompt = ChatPromptTemplate.from_messages( + [ + ("system", system_prompt), + ( + "user", + "Language:{language}\n\nQuestion:{user_input}\n\n Reason for unsafe or out of scope: {reason_unsafe}", + ), + ] + ) + + return prompt | llm diff --git a/src/agent/tasks/query_expansion.py b/src/agent/tasks/query_expansion.py new file mode 100644 index 0000000..7cddaaf --- /dev/null +++ b/src/agent/tasks/query_expansion.py @@ -0,0 +1,60 @@ +import json +from typing import List + +from langchain.prompts import ChatPromptTemplate +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.output_parsers import StrOutputParser +from langchain_core.runnables import Runnable, RunnableLambda + + +def QueryExpansionParser(output: str) -> List[str]: + """Parse JSON array output from LLM.""" + try: + return json.loads(output) + except json.JSONDecodeError: + raise ValueError("LLM output was not valid JSON. Output:\n" + output) + + +def create_query_expander(llm: BaseChatModel) -> Runnable: + """ + Create a query expansion chain that generates 4 alternative queries. + + Args: + llm: Language model to use + + Returns: + Runnable that takes standalone_query and returns List[str] + """ + system_prompt = """ +You are a biomedical question expansion engine for information retrieval over the Reactome biological pathway database. + +Given a single user question, generate **exactly 4** alternate standalone questions. These should be: + +- Semantically related to the original question. +- Lexically diverse to improve retrieval via vector search and RAG-fusion. +- Biologically enriched with inferred or associated details. + +Your goal is to improve recall of relevant documents by expanding the original query using: +- Synonymous gene/protein names (e.g., EGFR, ErbB1, HER1) +- Pathway or process-level context (e.g., signal transduction, apoptosis) +- Known diseases, phenotypes, or biological functions +- Cellular localization (e.g., nucleus, cytoplasm, membrane) +- Upstream/downstream molecular interactions + +Rules: +- Each question must be **fully standalone** (no "this"/"it"). +- Do not change the core intent—preserve the user's informational goal. +- Use appropriate biological terminology and Reactome-relevant concepts. +- Vary the **phrasing**, **focus**, or **biological angle** of each question. +- If the input is ambiguous, infer a biologically meaningful interpretation. + +Output: +Return only a valid JSON array of 4 strings (no explanations, no metadata). +Do not include any explanations or metadata. 
+""" + + prompt = ChatPromptTemplate.from_messages( + [("system", system_prompt), ("user", "Original Question: {rephrased_input}")] + ) + + return prompt | llm | StrOutputParser() | RunnableLambda(QueryExpansionParser) diff --git a/src/agent/tasks/rephrase.py b/src/agent/tasks/rephrase.py index 1851747..e256104 100644 --- a/src/agent/tasks/rephrase.py +++ b/src/agent/tasks/rephrase.py @@ -13,8 +13,7 @@ the returned question should always be in English. If the user’s question is already in English, self-contained and well-formed, return it as is. -Do NOT answer the question or provide explanations. -""" +Do NOT answer the question or provide explanations.""" contextualize_q_prompt = ChatPromptTemplate.from_messages( [ diff --git a/src/agent/tasks/safety_checker.py b/src/agent/tasks/safety_checker.py index ac372e3..3c46f8e 100644 --- a/src/agent/tasks/safety_checker.py +++ b/src/agent/tasks/safety_checker.py @@ -1,30 +1,76 @@ +from langchain.prompts import ChatPromptTemplate from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import Runnable from pydantic import BaseModel, Field -safety_check_message = """ -You are an expert grader with a strong understanding of ethical guidelines and responsible AI usage. - -Your task is to determine whether a user-submitted question is appropriate and not harmful. -Provide a binary output as either: -Yes: The question is appropriate and not harmful. -No: The question is harmful, offensive, inappropriate, or unethical. -""" - -safety_check_prompt = ChatPromptTemplate.from_messages( - [ - ("system", safety_check_message), - ("human", "User question: \n\n {input}"), - ] -) - class SafetyCheck(BaseModel): - binary_score: str = Field( - description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'Yes' or 'No'." + safety: str = Field( + description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'true' or 'false'." + ) + reason_unsafe: str = Field( + description="If 'safety' is false, briefly state the reason; if 'safety' is true, leave this field empty." ) def create_safety_checker(llm: BaseChatModel) -> Runnable: - return safety_check_prompt | llm.with_structured_output(SafetyCheck) + """ + Create a safety checker chain. + + Args: + llm: Language model to use + + Returns: + Runnable that takes user_input and returns SafetyCheck + """ + safety_check_message = """ +You are an expert scientific assistant. You have advanced training in scientific ethics, dual-use assessment, and responsible AI. + +Tasks: + +1. Safety & Appropriateness Check + - Assess the question and determine if it is safe and appropriate for responsible AI use. + - The question and any plausible answer to it must NOT be harmful, offensive, unethical, or violate scientific or ethical guidelines. + - The question or its likely answers must NOT request, describe, or explore how to carry out harmful, illegal, dangerous, or dual-use activities (such as modifying pathogens, increasing infectivity, or other forms of misuse)—even if the scenario is hypothetical, fictional, or indirect. + - The question must NOT request personal, medical, legal, or other types of advice outside the scope of scientific knowledge. + - If a question is ambiguous or could plausibly be used for harm, treat it as unsafe. 
+ - Treat hypothetical, fictional, or made-up scenarios with the same level of scrutiny as real-world questions. + +2. Reactome Relevance Check + - Determine if the question is relevant to biology, life sciences, molecular biology, or related topics. + - Mark questions as not relevant if they are about unrelated topics (such as programming, math, history, trivia, etc.). + +IMPORTANT: + - If the standalone question is unsafe or not relevant return "safety": "false". + - If the standalone question is both safe **and** relevant, return "safety": "true". + +Return only a JSON object in the following format: + "safety": "true" or "false", // Use string, not boolean. Mark as "false" if unsafe OR not relevant. + "reason_unsafe": "..." // If 'safety' is false, briefly state the reason + +Examples: + 1. Q: Describe methods for bypassing cellular immune responses to viral infection + "safety": "false", + "reason_unsafe": "Could be used for harmful or unethical manipulation of immune evasion." + + 2. Q: What are common side effects of statins in the treatment of high cholesterol? + "safety": "false", + "reason_unsafe": "Requests medical advice, which is out of scientific scope." + + 3. Q: How can gene drives be used to eliminate invasive species in an ecosystem? + "safety": "false", + "reason_unsafe": "Enables the use of gene drives, which have significant ecological and dual-use risks." + + 4. Q: What is the role of the immune system in the treatment of cancer? + "safety": "true", + "reason_unsafe": "" +""" + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", safety_check_message), + ("human", "User question: \n\n {rephrased_input}"), + ] + ) + + return prompt | llm.with_structured_output(SafetyCheck) diff --git a/src/retrievers/csv_chroma.py b/src/retrievers/csv_chroma.py index 691b884..c11b01f 100644 --- a/src/retrievers/csv_chroma.py +++ b/src/retrievers/csv_chroma.py @@ -1,70 +1,327 @@ +import asyncio +import hashlib +import logging from pathlib import Path +from typing import Any, Dict, List, Optional, Union import chromadb.config -from langchain.chains.query_constructor.schema import AttributeInfo -from langchain.retrievers import EnsembleRetriever -from langchain.retrievers.merger_retriever import MergerRetriever -from langchain.retrievers.self_query.base import SelfQueryRetriever +import pandas as pd from langchain_chroma.vectorstores import Chroma -from langchain_community.document_loaders.csv_loader import CSVLoader from langchain_community.retrievers import BM25Retriever +from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.retrievers import BaseRetriever -from nltk.tokenize import word_tokenize -chroma_settings = chromadb.config.Settings(anonymized_telemetry=False) +from retrievers.retrieval_utils import reciprocal_rank_fusion +logger = logging.getLogger(__name__) -def list_chroma_subdirectories(directory: Path) -> list[str]: - subdirectories = list( +CHROMA_SETTINGS = chromadb.config.Settings(anonymized_telemetry=False) +DEFAULT_RETRIEVAL_K = 20 +RRF_FINAL_K = 10 +RRF_LAMBDA_MULTIPLIER = 60.0 +EXCLUDED_CONTENT_COLUMNS = {"st_id"} + + +def create_documents_from_csv(csv_path: Path) -> List[Document]: + """Create Document objects from CSV file with proper metadata extraction.""" + if not csv_path.exists(): + raise FileNotFoundError(f"CSV file not found: {csv_path}") + + try: + df = pd.read_csv(csv_path) + if df.empty: + raise ValueError(f"CSV file 
is empty: {csv_path}") + except Exception as e: + raise ValueError(f"Failed to read CSV file {csv_path}: {e}") + + documents = [] + + for index, row in df.iterrows(): + content_parts = [] + for column in df.columns: + if column not in EXCLUDED_CONTENT_COLUMNS: + value = str(row[column]) if pd.notna(row[column]) else "" + if value and value != "nan": + content_parts.append(f"{column}: {value}") + + page_content = "\n".join(content_parts) + + metadata = { + column: str(value) + for column in df.columns + for value in [row[column]] + if pd.notna(value) and str(value) != "nan" + } + metadata.update({"source": str(csv_path), "row_index": str(index)}) + + documents.append(Document(page_content=page_content, metadata=metadata)) + + return documents + + +def list_chroma_subdirectories(directory: Path) -> List[str]: + """Discover all subdirectories containing ChromaDB files.""" + if not directory.exists(): + raise ValueError(f"Directory does not exist: {directory}") + + subdirectories = [ chroma_file.parent.name for chroma_file in directory.glob("*/chroma.sqlite3") - ) + ] + + if not subdirectories: + logger.warning(f"No ChromaDB subdirectories found in {directory}") + return subdirectories -def create_bm25_chroma_ensemble_retriever( - llm: BaseChatModel, - embedding: Embeddings, - embeddings_directory: Path, - *, - descriptions_info: dict[str, str], - field_info: dict[str, list[AttributeInfo]], -) -> MergerRetriever: - retriever_list: list[BaseRetriever] = [] - for subdirectory in list_chroma_subdirectories(embeddings_directory): - # set up BM25 retriever - csv_file_name = subdirectory + ".csv" - reactome_csvs_dir: Path = embeddings_directory / "csv_files" - loader = CSVLoader(file_path=reactome_csvs_dir / csv_file_name) - data = loader.load() - bm25_retriever = BM25Retriever.from_documents( - data, - preprocess_func=lambda text: word_tokenize( - text.casefold(), language="english" - ), +class HybridRetriever: + """Advanced hybrid retriever supporting RRF, parallel processing, and multi-source search.""" + + def __init__(self, embedding: Embeddings, embeddings_directory: Path): + self.embedding = embedding + self.embeddings_directory = embeddings_directory + self._retrievers: Dict[ + str, Dict[str, Optional[Union[BM25Retriever, BaseRetriever]]] + ] = {} + + try: + self._initialize_retrievers() + except Exception as e: + logger.error(f"Failed to initialize hybrid retriever: {e}") + raise RuntimeError(f"Hybrid retriever initialization failed: {e}") from e + + def _initialize_retrievers(self) -> None: + """Initialize BM25 and vector retrievers for all discovered subdirectories.""" + subdirectories = list_chroma_subdirectories(self.embeddings_directory) + + if not subdirectories: + raise ValueError(f"No subdirectories found in {self.embeddings_directory}") + + for subdirectory in subdirectories: + bm25_retriever = self._create_bm25_retriever(subdirectory) + vector_retriever = self._create_vector_retriever(subdirectory) + + self._retrievers[subdirectory] = { + "bm25": bm25_retriever, + "vector": vector_retriever, + } + + logger.info(f"Initialized retrievers for {len(subdirectories)} subdirectories") + + def _create_bm25_retriever(self, subdirectory: str) -> Optional[BM25Retriever]: + """Create BM25 retriever for a specific subdirectory.""" + csv_path = self.embeddings_directory / "csv_files" / f"{subdirectory}.csv" + + if not csv_path.exists(): + logger.warning(f"CSV file not found for {subdirectory}: {csv_path}") + return None + + try: + documents = create_documents_from_csv(csv_path) + retriever = 
BM25Retriever.from_documents(documents) + retriever.k = DEFAULT_RETRIEVAL_K + logger.debug( + f"Created BM25 retriever for {subdirectory} with {len(documents)} documents" + ) + return retriever + except Exception as e: + logger.error(f"Failed to create BM25 retriever for {subdirectory}: {e}") + return None + + def _create_vector_retriever(self, subdirectory: str) -> Optional[BaseRetriever]: + """Create vector retriever for a specific subdirectory.""" + vector_directory = self.embeddings_directory / subdirectory + + if not vector_directory.exists(): + logger.warning( + f"Vector directory not found for {subdirectory}: {vector_directory}" + ) + return None + + try: + vector_store = Chroma( + persist_directory=str(vector_directory), + embedding_function=self.embedding, + client_settings=CHROMA_SETTINGS, + ) + retriever = vector_store.as_retriever( + search_kwargs={"k": DEFAULT_RETRIEVAL_K} + ) + logger.debug(f"Created vector retriever for {subdirectory}") + return retriever + except Exception as e: + logger.error(f"Failed to create vector retriever for {subdirectory}: {e}") + return None + + async def _search_with_bm25( + self, query: str, retriever: BM25Retriever + ) -> List[Document]: + """Search using BM25 retriever asynchronously.""" + return await asyncio.to_thread(retriever.get_relevant_documents, query) + + async def _search_with_vector( + self, query: str, retriever: BaseRetriever + ) -> List[Document]: + """Search using vector retriever asynchronously.""" + return await asyncio.to_thread(retriever.get_relevant_documents, query) + + async def _execute_hybrid_search( + self, query: str, subdirectory: str + ) -> List[Document]: + """Execute hybrid search (BM25 + vector) for a single query on a subdirectory.""" + retriever_info = self._retrievers.get(subdirectory) + if not retriever_info: + logger.warning(f"No retrievers found for subdirectory: {subdirectory}") + return [] + + search_tasks = [] + + if retriever_info["bm25"] and isinstance(retriever_info["bm25"], BM25Retriever): + search_tasks.append(self._search_with_bm25(query, retriever_info["bm25"])) + + if retriever_info["vector"] and isinstance( + retriever_info["vector"], BaseRetriever + ): + search_tasks.append( + self._search_with_vector(query, retriever_info["vector"]) + ) + + if not search_tasks: + logger.warning(f"No active retrievers for subdirectory: {subdirectory}") + return [] + + try: + search_results = await asyncio.gather(*search_tasks, return_exceptions=True) + + combined_documents = [] + for result in search_results: + if isinstance(result, list): + combined_documents.extend(result) + elif isinstance(result, Exception): + logger.error(f"Search error in {subdirectory}: {result}") + + return combined_documents + except Exception as e: + logger.error(f"Failed to execute hybrid search for {subdirectory}: {e}") + return [] + + def _generate_document_identifier(self, document: Document) -> str: + """Generate unique identifier for a document.""" + for field in ["url", "id", "st_id"]: + if document.metadata.get(field): + return document.metadata[field] + + return hashlib.md5(document.page_content.encode()).hexdigest() + + async def _apply_reciprocal_rank_fusion( + self, queries: List[str], subdirectory: str + ) -> List[Document]: + """Apply Reciprocal Rank Fusion to results from multiple queries on a subdirectory.""" + logger.info( + f"Executing hybrid search for {len(queries)} queries in {subdirectory}" ) - bm25_retriever.k = 10 - # set up vectorstore SelfQuery retriever - vectordb = Chroma( - 
persist_directory=str(embeddings_directory / subdirectory), - embedding_function=embedding, - client_settings=chroma_settings, + search_tasks = [ + self._execute_hybrid_search(query, subdirectory) for query in queries + ] + all_search_results = await asyncio.gather(*search_tasks, return_exceptions=True) + + valid_result_sets = [] + for i, result in enumerate(all_search_results): + if isinstance(result, list): + valid_result_sets.append(result) + logger.debug(f"Query {i+1}: {len(result)} results") + elif isinstance(result, Exception): + logger.error(f"Query {i+1} failed: {result}") + + if not valid_result_sets: + logger.warning(f"No valid results for {subdirectory}") + return [] + + logger.info( + f"Applying RRF to {len(valid_result_sets)} result sets in {subdirectory}" ) - selfq_retriever = SelfQueryRetriever.from_llm( - llm=llm, - vectorstore=vectordb, - document_contents=descriptions_info[subdirectory], - metadata_field_info=field_info[subdirectory], - search_kwargs={"k": 10}, + top_documents, _, rrf_scores = reciprocal_rank_fusion( + ranked_lists=valid_result_sets, + final_k=RRF_FINAL_K, + lambda_mult=RRF_LAMBDA_MULTIPLIER, + rrf_k=None, + id_getter=self._generate_document_identifier, + ) + + logger.info(f"RRF completed for {subdirectory}: {len(top_documents)} documents") + if rrf_scores: + top_scores = dict(list(rrf_scores.items())[:3]) + logger.debug(f"Top RRF scores: {top_scores}") + + return top_documents + + async def ainvoke(self, inputs: Dict[str, Any]) -> str: + """Main retrieval method supporting RRF and parallel processing.""" + original_query = inputs.get("input", "").strip() + if not original_query: + raise ValueError("Input query cannot be empty") + + expanded_queries = inputs.get("expanded_queries", []) + all_queries = [original_query] + (expanded_queries or []) + + logger.info( + f"Processing {len(all_queries)} queries across {len(self._retrievers)} subdirectories" ) - rrf_retriever = EnsembleRetriever( - retrievers=[bm25_retriever, selfq_retriever], weights=[0.2, 0.8] + for i, query in enumerate(all_queries, 1): + logger.debug(f"Query {i}: {query}") + + rrf_tasks = [ + self._apply_reciprocal_rank_fusion(all_queries, subdirectory) + for subdirectory in self._retrievers.keys() + ] + + subdirectory_results = await asyncio.gather(*rrf_tasks, return_exceptions=True) + + context_parts = [] + + for i, subdirectory in enumerate(self._retrievers.keys()): + result = subdirectory_results[i] + + if isinstance(result, Exception): + logger.error(f"Subdirectory {subdirectory} failed: {result}") + continue + + if isinstance(result, list) and result: + for document in result: + context_parts.append(document.page_content) + context_parts.append("") + + final_context = "\n".join(context_parts) + + total_documents = sum( + len(result) if isinstance(result, list) else 0 + for result in subdirectory_results ) - retriever_list.append(rrf_retriever) + logger.info(f"Retrieved {total_documents} documents total") + + for i, subdirectory in enumerate(self._retrievers.keys()): + result = subdirectory_results[i] + if isinstance(result, list): + logger.info(f"{subdirectory}: {len(result)} documents") + else: + logger.warning(f"{subdirectory}: Failed") - reactome_retriever = MergerRetriever(retrievers=retriever_list) + logger.debug(f"Final context length: {len(final_context)} characters") - return reactome_retriever + return final_context + + +def create_hybrid_retriever( + embedding: Embeddings, embeddings_directory: Path +) -> HybridRetriever: + """Create a hybrid retriever with RRF and parallel 
processing support.""" + try: + return HybridRetriever( + embedding=embedding, embeddings_directory=embeddings_directory + ) + except Exception as e: + logger.error(f"Failed to create hybrid retriever: {e}") + raise RuntimeError(f"Hybrid retriever creation failed: {e}") from e diff --git a/src/retrievers/rag_chain.py b/src/retrievers/rag_chain.py index 3e5df8e..66235f7 100644 --- a/src/retrievers/rag_chain.py +++ b/src/retrievers/rag_chain.py @@ -1,26 +1,63 @@ -from langchain.chains.combine_documents import create_stuff_documents_chain -from langchain.chains.retrieval import create_retrieval_chain +from pathlib import Path +from typing import Any, Dict + +from langchain_core.embeddings import Embeddings from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.prompts import ChatPromptTemplate -from langchain_core.retrievers import BaseRetriever -from langchain_core.runnables import Runnable +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables import Runnable, RunnableLambda + +from retrievers.csv_chroma import create_hybrid_retriever -def create_rag_chain( +def create_advanced_rag_chain( llm: BaseChatModel, - retriever: BaseRetriever, - qa_prompt: ChatPromptTemplate, + embedding: Embeddings, + embeddings_directory: Path, + system_prompt: str, + *, + streaming: bool = False, ) -> Runnable: - # Create the documents chain - question_answer_chain: Runnable = create_stuff_documents_chain( - llm=llm, - prompt=qa_prompt, - ) + """ + Create an advanced RAG chain with hybrid retrieval, query expansion, and streaming support. + + Args: + llm: Language model for generation + embedding: Embedding model for retrieval + embeddings_directory: Directory containing embeddings and CSV files + system_prompt: System prompt for the LLM + streaming: Whether to enable streaming responses + + Returns: + Runnable RAG chain + """ + retriever = create_hybrid_retriever(embedding, embeddings_directory) - # Create the retrieval chain - rag_chain: Runnable = create_retrieval_chain( - retriever=retriever, - combine_docs_chain=question_answer_chain, + prompt = ChatPromptTemplate.from_messages( + [ + ("system", system_prompt), + MessagesPlaceholder(variable_name="chat_history"), + ("user", "Context:\n{context}\n\nQuestion: {input}"), + ] ) - return rag_chain + if streaming: + llm = llm.model_copy(update={"streaming": True}) + + async def rag_chain(inputs: Dict[str, Any]) -> Dict[str, Any]: + user_input = inputs["input"] + chat_history = inputs.get("chat_history", []) + expanded_queries = inputs.get("expanded_queries", []) + + context = await retriever.ainvoke( + {"input": user_input, "expanded_queries": expanded_queries} + ) + + response = await llm.ainvoke( + prompt.format_messages( + context=context, input=user_input, chat_history=chat_history + ) + ) + + return {"answer": response.content, "context": context} + + return RunnableLambda(rag_chain) diff --git a/src/retrievers/reactome/prompt.py b/src/retrievers/reactome/prompt.py index 9a11526..d570cb9 100644 --- a/src/retrievers/reactome/prompt.py +++ b/src/retrievers/reactome/prompt.py @@ -1,25 +1,34 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder reactome_system_prompt = """ -You are an expert in molecular biology with access to the Reactome Knowledgebase. -Your primary responsibility is to answer the user's questions comprehensively, accurately, and in an engaging manner, based strictly on the context provided from the Reactome Knowledgebase. 
-Provide any useful background information required to help the user better understand the significance of the answer. -Always provide citations and links to the documents you obtained the information from. +You are an expert in molecular biology with access to the **Reactome Knowledgebase**. +Your primary responsibility is to answer the user's questions **comprehensively, mechanistically, and with precision**, drawing strictly from the **Reactome Knowledgebase**. -When providing answers, please adhere to the following guidelines: -1. Provide answers **strictly based on the given context from the Reactome Knowledgebase**. Do **not** use or infer information from any external sources. -2. If the answer cannot be derived from the context provided, do **not** answer the question; instead explain that the information is not currently available in Reactome. -3. Answer the question comprehensively and accurately, providing useful background information based **only** on the context. -4. keep track of **all** the sources that are directly used to derive the final answer, ensuring **every** piece of information in your response is **explicitly cited**. -5. Create Citations for the sources used to generate the final asnwer according to the following: - - For Reactome always format citations in the following format: *Source_Name*, where *Source_Name* is the name of the retrieved document. - Examples: - - Apoptosis - - Cell Cycle +Your output must emphasize biological processes, molecular complexes, regulatory mechanisms, and interactions most relevant to the user’s question. +Provide an information-rich narrative that explains not only what is happening but also how and why, based only on Reactome context. -6. Always provide the citations you created in the format requested, in point-form at the end of the response paragraph, ensuring **every piece of information** provided in the final answer is cited. -7. Write in a conversational and engaging tone suitable for a chatbot. -8. Use clear, concise language to make complex topics accessible to a wide audience. + +## **Answering Guidelines** +1. Strict source discipline: Use only the information explicitly provided from Reactome. Do not invent, infer, or draw from external knowledge. + - Use only information directly found in Reactome. + - Do **not** supplement, infer, generalize, or assume based on external biological knowledge. + - If no relevant information exists in Reactome, explain the information is not currently available in Reactome. Do **not** answer the question. +2. Inline citations required: Every factual statement must include ≥1 inline anchor citation in the format: display_name + - If multiple entries support the same fact, cite them together (space-separated). +3. Comprehensiveness: Capture all mechanistically relevant details available in Reactome, focusing on processes, complexes, regulations, and interactions. +4. Tone & Style: + - Write in a clear, engaging, and conversational tone. + - Use accessible language while maintaining technical precision. + - Ensure the narrative flows logically, presenting background, mechanisms, and significance +5. Source list at the end: After the main narrative, provide a bullet-point list of each unique citation anchor exactly once, in the same Node Name format. + - Examples: + - Apoptosis + - Cell Cycle + +## Internal QA (silent) +- All factual claims are cited correctly. +- No unverified claims or background knowledge are added. +- The Sources list is complete and de-duplicated. 
""" reactome_qa_prompt = ChatPromptTemplate.from_messages( diff --git a/src/retrievers/reactome/rag.py b/src/retrievers/reactome/rag.py index 485b6e5..0a02df3 100644 --- a/src/retrievers/reactome/rag.py +++ b/src/retrievers/reactome/rag.py @@ -4,11 +4,8 @@ from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.runnables import Runnable -from retrievers.csv_chroma import create_bm25_chroma_ensemble_retriever -from retrievers.rag_chain import create_rag_chain -from retrievers.reactome.metadata_info import (reactome_descriptions_info, - reactome_field_info) -from retrievers.reactome.prompt import reactome_qa_prompt +from retrievers.rag_chain import create_advanced_rag_chain +from retrievers.reactome.prompt import reactome_system_prompt from util.embedding_environment import EmbeddingEnvironment @@ -19,15 +16,22 @@ def create_reactome_rag( *, streaming: bool = False, ) -> Runnable: - reactome_retriever = create_bm25_chroma_ensemble_retriever( - llm, - embedding, - embeddings_directory, - descriptions_info=reactome_descriptions_info, - field_info=reactome_field_info, - ) + """ + Create a Reactome-specific RAG chain with hybrid retrieval and query expansion. - if streaming: - llm = llm.model_copy(update={"streaming": True}) + Args: + llm: Language model for generation + embedding: Embedding model for retrieval + embeddings_directory: Directory containing Reactome embeddings and CSV files + streaming: Whether to enable streaming responses - return create_rag_chain(llm, reactome_retriever, reactome_qa_prompt) + Returns: + Runnable RAG chain for Reactome queries + """ + return create_advanced_rag_chain( + llm=llm, + embedding=embedding, + embeddings_directory=embeddings_directory, + system_prompt=reactome_system_prompt, + streaming=streaming, + ) diff --git a/src/retrievers/retrieval_utils.py b/src/retrievers/retrieval_utils.py new file mode 100644 index 0000000..28c3015 --- /dev/null +++ b/src/retrievers/retrieval_utils.py @@ -0,0 +1,28 @@ +from collections import defaultdict +from typing import Any, Callable, Dict, List, Tuple + + +def reciprocal_rank_fusion( + ranked_lists: List[List[Any]], + final_k: int = 5, + lambda_mult: float = 60.0, + rrf_k: int | None = None, + id_getter: Callable[[Any], str] = lambda doc: doc.metadata.get("stId") + or doc.metadata.get("stable_id"), +) -> Tuple[List[Any], List[str], Dict[str, float]]: + rrf_scores: Dict[str, float] = defaultdict(float) + doc_meta = {} + + for ranked in ranked_lists: + considered = ranked[:rrf_k] if rrf_k else ranked + for rank, doc in enumerate(considered): + doc_id = id_getter(doc) + if doc_id is not None: # Skip documents without valid IDs + rrf_scores[doc_id] += 1.0 / (lambda_mult + rank + 1) + if doc_id not in doc_meta: + doc_meta[doc_id] = doc + + sorted_items = sorted(rrf_scores.items(), key=lambda x: (-x[1], x[0])) + top_ids = [doc_id for doc_id, _ in sorted_items[:final_k]] + top_docs = [doc_meta[doc_id] for doc_id in top_ids] + return top_docs, top_ids, rrf_scores diff --git a/src/retrievers/uniprot/rag.py b/src/retrievers/uniprot/rag.py index 99702d7..1ef5d8d 100644 --- a/src/retrievers/uniprot/rag.py +++ b/src/retrievers/uniprot/rag.py @@ -4,10 +4,7 @@ from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.runnables import Runnable -from retrievers.csv_chroma import create_bm25_chroma_ensemble_retriever -from retrievers.rag_chain import create_rag_chain -from retrievers.uniprot.metadata_info import (uniprot_descriptions_info, - uniprot_field_info) +from 
retrievers.rag_chain import create_advanced_rag_chain from retrievers.uniprot.prompt import uniprot_qa_prompt from util.embedding_environment import EmbeddingEnvironment @@ -19,15 +16,22 @@ def create_uniprot_rag( *, streaming: bool = False, ) -> Runnable: - reactome_retriever = create_bm25_chroma_ensemble_retriever( - llm, - embedding, - embeddings_directory, - descriptions_info=uniprot_descriptions_info, - field_info=uniprot_field_info, - ) + """ + Create a UniProt-specific RAG chain with hybrid retrieval and query expansion. - if streaming: - llm = llm.model_copy(update={"streaming": True}) + Args: + llm: Language model for generation + embedding: Embedding model for retrieval + embeddings_directory: Directory containing UniProt embeddings and CSV files + streaming: Whether to enable streaming responses - return create_rag_chain(llm, reactome_retriever, uniprot_qa_prompt) + Returns: + Runnable RAG chain for UniProt queries + """ + return create_advanced_rag_chain( + llm=llm, + embedding=embedding, + embeddings_directory=embeddings_directory, + system_prompt=uniprot_qa_prompt, + streaming=streaming, + ) diff --git a/src/tools/preprocessing/__init__.py b/src/tools/preprocessing/__init__.py new file mode 100644 index 0000000..f2eb429 --- /dev/null +++ b/src/tools/preprocessing/__init__.py @@ -0,0 +1,6 @@ +"""Preprocessing workflow for query enhancement and validation.""" + +from .state import PreprocessingState +from .workflow import create_preprocessing_workflow + +__all__ = ["create_preprocessing_workflow", "PreprocessingState"] diff --git a/src/tools/preprocessing/state.py b/src/tools/preprocessing/state.py new file mode 100644 index 0000000..34d162f --- /dev/null +++ b/src/tools/preprocessing/state.py @@ -0,0 +1,20 @@ +from typing import TypedDict + +from langchain_core.messages import BaseMessage + + +class PreprocessingState(TypedDict, total=False): + """State for the preprocessing workflow.""" + + # Input + user_input: str # Original user input + chat_history: list[BaseMessage] # Conversation history + + # Step 1: Rephrase and incorporate conversation history + rephrased_input: str # Standalone question with context + + # Step 2: Parallel processing + safety: str # "true" or "false" from safety check + reason_unsafe: str # Reason if unsafe + expanded_queries: list[str] # Alternative queries for better retrieval + detected_language: str # Detected language (e.g., "English", "French") diff --git a/src/tools/preprocessing/workflow.py b/src/tools/preprocessing/workflow.py new file mode 100644 index 0000000..889fcb7 --- /dev/null +++ b/src/tools/preprocessing/workflow.py @@ -0,0 +1,80 @@ +from typing import Any, Callable + +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.runnables import Runnable, RunnableConfig +from langgraph.graph import StateGraph +from langgraph.graph.state import CompiledStateGraph +from langgraph.utils.runnable import RunnableLike + +from agent.tasks.detect_language import create_language_detector +from agent.tasks.query_expansion import create_query_expander +from agent.tasks.rephrase import create_rephrase_chain +from agent.tasks.safety_checker import create_safety_checker +from tools.preprocessing.state import PreprocessingState + + +def create_task_wrapper( + task: Runnable, + input_mapper: Callable[[PreprocessingState], dict[str, Any]], + output_mapper: Callable[[Any], PreprocessingState], +) -> RunnableLike: + """Generic wrapper for preprocessing tasks.""" + + async def _wrapper( + state: PreprocessingState, config: 
RunnableConfig + ) -> PreprocessingState: + result = await task.ainvoke(input_mapper(state), config) + return output_mapper(result) + + return _wrapper + + +def create_preprocessing_workflow(llm: BaseChatModel) -> CompiledStateGraph: + """Create a preprocessing workflow with rephrasing and parallel processing.""" + + # Task configurations + tasks = { + "rephrase_query": ( + create_rephrase_chain(llm), + lambda state: { + "user_input": state["user_input"], + "chat_history": state["chat_history"], + }, + lambda result: PreprocessingState(rephrased_input=result), + ), + "safety_check": ( + create_safety_checker(llm), + lambda state: {"rephrased_input": state["rephrased_input"]}, + lambda result: PreprocessingState( + safety=result.safety, reason_unsafe=result.reason_unsafe + ), + ), + "query_expansion": ( + create_query_expander(llm), + lambda state: {"rephrased_input": state["rephrased_input"]}, + lambda result: PreprocessingState(expanded_queries=result), + ), + "detect_language": ( + create_language_detector(llm), + lambda state: {"user_input": state["user_input"]}, + lambda result: PreprocessingState(detected_language=result), + ), + } + + workflow = StateGraph(PreprocessingState) + + # Add nodes + for node_name, (task, input_mapper, output_mapper) in tasks.items(): + workflow.add_node( + node_name, create_task_wrapper(task, input_mapper, output_mapper) + ) + + # Configure workflow + workflow.set_entry_point("rephrase_query") + + # Parallel execution after rephrasing + for parallel_node in ["safety_check", "query_expansion", "detect_language"]: + workflow.add_edge("rephrase_query", parallel_node) + workflow.set_finish_point(parallel_node) + + return workflow.compile()