書接上文(【AI Agent】【LangGraph】0. 快速上手:協(xié)同LangChain,LangGraph幫你用圖結(jié)構(gòu)輕松構(gòu)建多智能體),前面我們了解了 LangGraph 的概念和基本構(gòu)造方法,今天我們來看下 LangGraph 構(gòu)造中的進(jìn)階用法:給邊加個條件 - 條件分支(Conditional edges)。 LangGraph 構(gòu)造的是個圖的數(shù)據(jù)結(jié)構(gòu),有節(jié)點(diǎn)(node) 和邊(edge),那它的邊也可以是帶條件的。如何給邊加入條件呢?可以通過 add_conditional_edges 函數(shù)添加帶條件的邊。 1. 完整代碼及運(yùn)行廢話不多說,先上完整代碼,和運(yùn)行結(jié)果。先跑起來看看效果再說。 from langchain_openai import ChatOpenAI from langchain_core.messages import HumanMessage, BaseMessage from langgraph.graph import END, MessageGraph import json from langchain_core.messages import ToolMessage from langchain_core.tools import tool from langchain_core.utils.function_calling import convert_to_openai_tool from typing import List
@tool def multiply(first_number: int, second_number: int): """Multiplies two numbers together.""" return first_number * second_number
model = ChatOpenAI(temperature=0) model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
graph = MessageGraph()
def invoke_model(state: List[BaseMessage]): return model_with_tools.invoke(state)
graph.add_node("oracle", invoke_model)
def invoke_tool(state: List[BaseMessage]): tool_calls = state[-1].additional_kwargs.get("tool_calls", []) multiply_call = None
for tool_call in tool_calls: if tool_call.get("function").get("name") == "multiply": multiply_call = tool_call
if multiply_call is None: raise Exception("No adder input found.")
res = multiply.invoke( json.loads(multiply_call.get("function").get("arguments")) )
return ToolMessage( tool_call_id=multiply_call.get("id"), content=res )
graph.add_node("multiply", invoke_tool)
graph.add_edge("multiply", END)
graph.set_entry_point("oracle")
def router(state: List[BaseMessage]): tool_calls = state[-1].additional_kwargs.get("tool_calls", []) if len(tool_calls): return "multiply" else: return "end"
graph.add_conditional_edges("oracle", router, { "multiply": "multiply", "end": END, })
runnable = graph.compile()
response = runnable.invoke(HumanMessage("What is 123 * 456?")) print(response)
運(yùn)行結(jié)果如下: 2. 代碼詳解下面對上面的代碼進(jìn)行詳細(xì)解釋。 2.1 add_conditional_edges首先,我們知道了可以通過 add_conditional_edges 來對邊進(jìn)行條件添加。這部分代碼如下: graph.add_conditional_edges("oracle", router, { "multiply": "multiply", "end": END, })
add_conditional_edges 接收三個參數(shù):
如上面的代碼,意思就是往 “oracle” node上添加邊,這個node有兩條邊,一條是往“multiply” node上走,一條是往“END”上走。怎么決定往哪個方向去:條件是 router(后面解釋),如果 router 返回的是“multiply”,則往“multiply”方向走,如果 router 返回的是 “end”,則走“END”。 來看下這個函數(shù)的源碼: def add_conditional_edges( self, start_key: str, condition: Callable[..., str], conditional_edge_mapping: Optional[Dict[str, str]] = None, ) -> None: if self.compiled: logger.warning( "Adding an edge to a graph that has already been compiled. This will " "not be reflected in the compiled graph." ) if start_key not in self.nodes: raise ValueError(f"Need to add_node `{start_key}` first") if iscoroutinefunction(condition): raise ValueError("Condition cannot be a coroutine function") if conditional_edge_mapping and set( conditional_edge_mapping.values() ).difference([END]).difference(self.nodes): raise ValueError( f"Missing nodes which are in conditional edge mapping. Mapping " f"contains possible destinations: " f"{list(conditional_edge_mapping.values())}. Possible nodes are " f"{list(self.nodes.keys())}." )
self.branches[start_key].append(Branch(condition, conditional_edge_mapping))
重點(diǎn)是這一句:self.branches[start_key].append(Branch(condition, conditional_edge_mapping)) ,給當(dāng)前node添加分支Branch。 2.2 條件 router條件代碼如下:判斷執(zhí)行結(jié)果中是否有 tool_calls 參數(shù),如果有,則返回"multiply",沒有,則返回“end”。 def router(state: List[BaseMessage]): tool_calls = state[-1].additional_kwargs.get("tool_calls", []) if len(tool_calls): return "multiply" else: return "end"
2.3 各node的定義(1)起始node:oracle @tool def multiply(first_number: int, second_number: int): """Multiplies two numbers together.""" return first_number * second_number
model = ChatOpenAI(temperature=0) model_with_tools = model.bind(tools=[convert_to_openai_tool(multiply)])
graph = MessageGraph()
def invoke_model(state: List[BaseMessage]): return model_with_tools.invoke(state)
graph.add_node("oracle", invoke_model)
這個node是一個帶有Tools 的 ChatOpenAI。在LangChain中使用Tools的詳細(xì)教程請看這篇文章:【AI大模型應(yīng)用開發(fā)】【LangChain系列】5. LangChain入門:智能體Agents模塊的實戰(zhàn)詳解。簡單解釋就是:這個node的執(zhí)行結(jié)果,將返回是否應(yīng)該使用綁定的Tools。 (2)multiply def invoke_tool(state: List[BaseMessage]): tool_calls = state[-1].additional_kwargs.get("tool_calls", []) multiply_call = None
for tool_call in tool_calls: if tool_call.get("function").get("name") == "multiply": multiply_call = tool_call
if multiply_call is None: raise Exception("No adder input found.")
res = multiply.invoke( json.loads(multiply_call.get("function").get("arguments")) )
return ToolMessage( tool_call_id=multiply_call.get("id"), content=res )
graph.add_node("multiply", invoke_tool)
這個node的作用就是執(zhí)行Tools。 2.4 總體流程如果覺得本文對你有幫助,麻煩點(diǎn)個贊和關(guān)注唄 ~~~
|