React to me architectural upgrade - Advanced hybrid retrieval, preprocessing pipeline, and safety system #97
base: main
Changes from all commits: 2607236, 8b35029, 8b2578f, b2cc4bb, 3b9e95d, 2864e97, f35f3e0, ba01931, 7f8d4c5, 67fcd60, 3ea2ba8, 5b82199, 27e761c
Four files change. First, the shared profile state and `BaseGraphBuilder` (imported elsewhere in this PR as `agent.profiles.base`): the single rephrase chain is replaced by a preprocessing workflow whose results — safety verdict, expanded queries, detected language — are carried in `BaseState`. The updated module:

```python
from typing import Annotated, Literal, TypedDict

from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable, RunnableConfig
from langgraph.graph.message import add_messages

from tools.external_search.state import SearchState, WebSearchResult
from tools.external_search.workflow import create_search_workflow
from tools.preprocessing.state import PreprocessingState
from tools.preprocessing.workflow import create_preprocessing_workflow

# Constants
SAFETY_SAFE: Literal["true"] = "true"
SAFETY_UNSAFE: Literal["false"] = "false"
DEFAULT_LANGUAGE: str = "English"


class AdditionalContent(TypedDict, total=False):
    """Additional content sent on graph completion."""

    search_results: list[WebSearchResult]


class InputState(TypedDict, total=False):
    """Input state for user queries."""

    user_input: str


class OutputState(TypedDict, total=False):
    """Output state for responses."""

    answer: str
    additional_content: AdditionalContent


class BaseState(InputState, OutputState, total=False):
    """Base state containing all common fields for agent workflows."""

    rephrased_input: str
    chat_history: Annotated[list[BaseMessage], add_messages]

    # Preprocessing results
    safety: str
    reason_unsafe: str
    expanded_queries: list[str]
    detected_language: str
```
Collaborator: remove
`BaseGraphBuilder` now runs the preprocessing workflow and maps its output into the shared state:

```python
class BaseGraphBuilder:
    """Base class for all graph builders with common preprocessing and postprocessing."""

    def __init__(self, llm: BaseChatModel, embedding: Embeddings) -> None:
        """Initialize with LLM and embedding models."""
        self.preprocessing_workflow: Runnable = create_preprocessing_workflow(llm)
        self.search_workflow: Runnable = create_search_workflow(llm)

    async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
        """Run the complete preprocessing workflow and map results to state."""
        result: PreprocessingState = await self.preprocessing_workflow.ainvoke(
            PreprocessingState(
                user_input=state["user_input"],
                chat_history=state["chat_history"],
            ),
            config,
        )
        return self._map_preprocessing_result(result)

    def _map_preprocessing_result(self, result: PreprocessingState) -> BaseState:
        """Map preprocessing results to BaseState with defaults."""
        return BaseState(
            rephrased_input=result["rephrased_input"],
            safety=result.get("safety", SAFETY_SAFE),
            reason_unsafe=result.get("reason_unsafe", ""),
            expanded_queries=result.get("expanded_queries", []),
            detected_language=result.get("detected_language", DEFAULT_LANGUAGE),
        )
```
Collaborator: remove
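`PreprocessingState` comes from `tools.preprocessing.state`, which is not part of this diff; purely from how it is constructed and read above, a plausible (hypothetical) shape would be:

```python
# Hypothetical reconstruction of tools/preprocessing/state.py, inferred from
# usage in preprocess() and _map_preprocessing_result(); not shown in this PR.
from typing import Annotated, TypedDict

from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages


class PreprocessingState(TypedDict, total=False):
    user_input: str
    chat_history: Annotated[list[BaseMessage], add_messages]
    rephrased_input: str
    safety: str  # SAFETY_SAFE ("true") or SAFETY_UNSAFE ("false")
    reason_unsafe: str
    expanded_queries: list[str]
    detected_language: str
```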
The postprocess step gates the external search on the safety verdict:

```python
    async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
        """Postprocess that preserves existing state and conditionally adds search results."""
        search_results: list[WebSearchResult] = []
        # Only run external search for safe questions
        if (
            state.get("safety") == SAFETY_SAFE
            and config["configurable"]["enable_postprocess"]
        ):
            result: SearchState = await self.search_workflow.ainvoke(
                SearchState(
                    input=state["rephrased_input"],
                    # … (context lines elided between diff hunks)
                ),
                config=RunnableConfig(callbacks=config["callbacks"]),
            )
            search_results = result["search_results"]

        # Create new state with updated additional_content
        return BaseState(
            **state,  # Copy existing state
            additional_content=AdditionalContent(search_results=search_results),
        )
```
Collaborator: Please exclude changes to code formatting & comments to unmodified existing code from the diff.
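Note that `postprocess` reads `enable_postprocess` (and `callbacks`) from the invocation config, so the web-search step is toggled per request. A minimal sketch of how a caller might pass it — `app` here is an assumed compiled profile graph, not something defined in this PR:

```python
import asyncio

from langchain_core.runnables import RunnableConfig


async def ask(app, question: str) -> dict:
    # Hypothetical caller: `app` is a compiled profile graph.
    config = RunnableConfig(
        configurable={"enable_postprocess": True},  # gates the web-search step
        callbacks=[],
    )
    return await app.ainvoke({"user_input": question, "chat_history": []}, config)


# asyncio.run(ask(app, "What pathways involve EGFR?"))
```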
Second, the ReactToMe profile builder adds safety routing and a dedicated streaming LLM for final answer generation:

```python
from typing import Any, Literal

from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_openai import ChatOpenAI
from langgraph.graph.state import StateGraph

from agent.profiles.base import (
    SAFETY_SAFE,
    SAFETY_UNSAFE,
    BaseGraphBuilder,
    BaseState,
)
from agent.tasks.final_answer_generation.unsafe_question import (
    create_unsafe_answer_generator,
)
from retrievers.reactome.rag import create_reactome_rag


class ReactToMeState(BaseState):
    """ReactToMe state extends BaseState with all preprocessing results."""

    pass


class ReactToMeGraphBuilder(BaseGraphBuilder):
    """Graph builder for ReactToMe profile with Reactome-specific functionality."""

    def __init__(self, llm: BaseChatModel, embedding: Embeddings) -> None:
        """Initialize ReactToMe graph builder with required components."""
        super().__init__(llm, embedding)

        # Create a streaming LLM instance only for final answer generation
        streaming_llm = ChatOpenAI(
            model=llm.model_name if hasattr(llm, "model_name") else "gpt-4o-mini",
            temperature=0.0,
            streaming=True,
        )

        self.unsafe_answer_generator = create_unsafe_answer_generator(streaming_llm)
        self.reactome_rag: Runnable = create_reactome_rag(
            streaming_llm, embedding, streaming=True
        )
```
Collaborator: Why create …
The workflow graph is then built with a conditional edge on the safety check:

```python
        # Create graph
        self.uncompiled_graph: StateGraph = self._build_workflow()

    def _build_workflow(self) -> StateGraph:
        """Build and configure the ReactToMe workflow graph."""
        state_graph = StateGraph(ReactToMeState)

        # Add nodes
        state_graph.add_node("preprocess", self.preprocess)
        state_graph.add_node("model", self.call_model)
        state_graph.add_node("generate_unsafe_response", self.generate_unsafe_response)
        state_graph.add_node("postprocess", self.postprocess)

        # Add edges
        state_graph.set_entry_point("preprocess")
        state_graph.add_conditional_edges(
            "preprocess",
            self.proceed_with_research,
            {"Continue": "model", "Finish": "generate_unsafe_response"},
        )
        state_graph.add_edge("model", "postprocess")
        state_graph.add_edge("generate_unsafe_response", "postprocess")
        state_graph.set_finish_point("postprocess")

        return state_graph
```
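The resulting control flow: `preprocess` routes safe queries to `model` and unsafe ones to `generate_unsafe_response`; both branches converge on `postprocess`, so every run ends at a single finish point regardless of the safety verdict.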
The profile also overrides `preprocess` to return its own state type:

```python
    async def preprocess(
        self, state: ReactToMeState, config: RunnableConfig
    ) -> ReactToMeState:
        """Run preprocessing workflow."""
        result = await super().preprocess(state, config)
        return ReactToMeState(**result)
```
Collaborator: No need to define this
The remaining nodes — the safety router, the refusal generator, and the RAG call — plus the module-level entry point:

```python
    async def proceed_with_research(
        self, state: ReactToMeState
    ) -> Literal["Continue", "Finish"]:
        """Determine whether to proceed with research based on safety check."""
        return "Continue" if state["safety"] == SAFETY_SAFE else "Finish"

    async def generate_unsafe_response(
        self, state: ReactToMeState, config: RunnableConfig
    ) -> ReactToMeState:
        """Generate appropriate refusal response for unsafe queries."""
        final_answer_message = await self.unsafe_answer_generator.ainvoke(
            {
                "language": state["detected_language"],
                "user_input": state["rephrased_input"],
                "reason_unsafe": state["reason_unsafe"],
            },
            config,
        )

        final_answer = (
            final_answer_message.content
            if hasattr(final_answer_message, "content")
            else str(final_answer_message)
        )

        return ReactToMeState(
            chat_history=[
                HumanMessage(state["user_input"]),
                (
                    final_answer_message
                    if hasattr(final_answer_message, "content")
                    else AIMessage(final_answer)
                ),
            ],
            answer=final_answer,
            safety=SAFETY_UNSAFE,
            additional_content={"search_results": []},
        )

    async def call_model(
        self, state: ReactToMeState, config: RunnableConfig
    ) -> ReactToMeState:
        """Generate response using Reactome RAG for safe queries."""
        result: dict[str, Any] = await self.reactome_rag.ainvoke(
            {
                "input": state["rephrased_input"],
                "expanded_queries": state.get("expanded_queries", []),
                "chat_history": (
                    state["chat_history"]
                    if state["chat_history"]
                    # … (context lines elided between diff hunks)
            },
            config,
        )

        return ReactToMeState(
            chat_history=[
                HumanMessage(state["user_input"]),
                # … (context lines elided between diff hunks)
        )


def create_reactome_graph(llm: BaseChatModel, embedding: Embeddings) -> StateGraph:
    """Create and return the ReactToMe workflow graph."""
    return ReactToMeGraphBuilder(llm, embedding).uncompiled_graph
```
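To make the entry point concrete, a minimal wiring sketch — the model names and the bare `compile()` call are illustrative assumptions, not taken from this PR:

```python
# Illustrative only: model choices and compile() options are assumptions.
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
embedding = OpenAIEmbeddings(model="text-embedding-3-small")

# create_reactome_graph returns the uncompiled StateGraph; compile it to run.
app = create_reactome_graph(llm, embedding).compile()
```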
Third, a new task module, `agent.tasks.final_answer_generation.unsafe_question`, generates the refusal message (the Returns line in the docstring now matches the prompt's actual inputs):

```python
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable


# Unsafe or out-of-scope answer generator
def create_unsafe_answer_generator(llm: BaseChatModel) -> Runnable:
    """
    Create an unsafe answer generator chain.

    Args:
        llm: Language model to use

    Returns:
        Runnable that takes language, user_input, and reason_unsafe
    """
    system_prompt = """
You are an expert scientific assistant operating under the React-to-Me platform. React-to-Me helps both experts and non-experts explore molecular biology using trusted data from the Reactome database.

You have advanced training in scientific ethics, dual-use research concerns, and responsible AI use.

You will receive three inputs:
1. The user's question.
2. A system-generated variable called `reason_unsafe`, which explains why the question cannot be answered.
3. The user's preferred language (as a language code or name).

Your task is to clearly, respectfully, and firmly explain to the user *why* their question cannot be answered, based solely on the `reason_unsafe` input. Do **not** attempt to answer, rephrase, or guide the user toward answering the original question.

You must:
- Respond in the user's preferred language.
- Politely explain the refusal, grounded in the `reason_unsafe`.
- Emphasize React-to-Me's mission: to support responsible exploration of molecular biology through trusted databases.
- Suggest examples of appropriate topics (e.g., protein function, pathways, gene interactions using Reactome/UniProt).

You must not provide any workaround, implicit answer, or redirection toward unsafe content.
"""
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            (
                "user",
                "Language:{language}\n\nQuestion:{user_input}\n\n Reason for unsafe or out of scope: {reason_unsafe}",
            ),
        ]
    )

    return prompt | llm
```
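A quick usage sketch — the `llm` variable and input values here are illustrative:

```python
# Illustrative invocation; the chain returns the LLM's message object.
generator = create_unsafe_answer_generator(llm)
message = generator.invoke(
    {
        "language": "English",
        "user_input": "How can I make a pathogen more transmissible?",
        "reason_unsafe": "Dual-use research concern: gain-of-function guidance.",
    }
)
print(message.content)  # a polite, grounded refusal in the requested language
```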
Finally, a new query-expansion task for RAG-fusion-style retrieval:

```python
import json
from typing import List
```
Collaborator: importing …
The module continues with the parser:

```python
from langchain.prompts import ChatPromptTemplate
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable, RunnableLambda


def QueryExpansionParser(output: str) -> List[str]:
    """Parse JSON array output from LLM."""
    try:
        return json.loads(output)
    except json.JSONDecodeError:
        raise ValueError("LLM output was not valid JSON. Output:\n" + output)
```
Collaborator: Will an LLM emitting invalid JSON crash the chatbot here?
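On the reviewer's point: as written, a malformed LLM response raises `ValueError`, which propagates up the chain unless a caller catches it. One defensive alternative — a sketch, not what the PR does — is to fall back to an empty expansion list so retrieval can still proceed on the original query:

```python
import json
import logging
from typing import List

logger = logging.getLogger(__name__)


def safe_query_expansion_parser(output: str) -> List[str]:
    """Parse the JSON array, falling back to no expansions on bad output."""
    try:
        parsed = json.loads(output)
        if isinstance(parsed, list) and all(isinstance(q, str) for q in parsed):
            return parsed
    except json.JSONDecodeError:
        pass
    logger.warning("Query expansion output was not a JSON list of strings: %r", output)
    return []  # downstream retrieval can still use the original query alone
```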
The expander chain itself (the docstring's input name is corrected to match the prompt variable `rephrased_input`, and a duplicated output instruction is dropped):

```python
def create_query_expander(llm: BaseChatModel) -> Runnable:
    """
    Create a query expansion chain that generates 4 alternative queries.

    Args:
        llm: Language model to use

    Returns:
        Runnable that takes rephrased_input and returns List[str]
    """
    system_prompt = """
You are a biomedical question expansion engine for information retrieval over the Reactome biological pathway database.
Given a single user question, generate **exactly 4** alternate standalone questions. These should be:
- Semantically related to the original question.
- Lexically diverse to improve retrieval via vector search and RAG-fusion.
- Biologically enriched with inferred or associated details.
Your goal is to improve recall of relevant documents by expanding the original query using:
- Synonymous gene/protein names (e.g., EGFR, ErbB1, HER1)
- Pathway or process-level context (e.g., signal transduction, apoptosis)
- Known diseases, phenotypes, or biological functions
- Cellular localization (e.g., nucleus, cytoplasm, membrane)
- Upstream/downstream molecular interactions
Rules:
- Each question must be **fully standalone** (no "this"/"it").
- Do not change the core intent—preserve the user's informational goal.
- Use appropriate biological terminology and Reactome-relevant concepts.
- Vary the **phrasing**, **focus**, or **biological angle** of each question.
- If the input is ambiguous, infer a biologically meaningful interpretation.
Output:
Return only a valid JSON array of 4 strings (no explanations, no metadata).
"""

    prompt = ChatPromptTemplate.from_messages(
        [("system", system_prompt), ("user", "Original Question: {rephrased_input}")]
    )

    return prompt | llm | StrOutputParser() | RunnableLambda(QueryExpansionParser)
```
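Usage sketch (`llm` is assumed to be any `BaseChatModel`):

```python
# The chain maps {"rephrased_input": ...} to a list of four expanded queries.
expander = create_query_expander(llm)
queries = expander.invoke(
    {"rephrased_input": "What role does EGFR play in signal transduction?"}
)
for q in queries:
    print(q)
```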
Collaborator: Please exclude changes to code formatting & comments to unmodified existing code from the diff.