import {
  approveRequiredTool1,
  approveRequiredTool2,
  regularTool1,
  regularTool2,
} from "@/server/tools/testing";
import { ToolNode } from "@langchain/langgraph/prebuilt";
import { ChatOpenAI } from "@langchain/openai";
import {
  StateGraph,
  MessagesAnnotation,
  END,
  START,
  interrupt,
} from "@langchain/langgraph";
import checkpointer from "@/server/checkpointer";
import { transferSolanaTool } from "@/server/tools/solana/transfer";

/**
 * Build and compile the tool-calling agent graph.
 *
 * Flow: START -> agent -> (tool calls present ? checkApproval : END)
 *       checkApproval -> allToolsNode -> agent
 *
 * Tool calls whose tool is in `approveRequiredTools` pause the graph via
 * `interrupt()` so a human can review them before `allToolsNode` executes.
 *
 * @param {*} solAgentKit - Passed unchanged to `transferSolanaTool` to build
 *   the SOL transfer tool (its shape is defined by that module).
 * @returns The graph compiled with the shared `checkpointer`; a checkpointer
 *   is required for `interrupt()` to be resumable.
 */
const getGraph = (solAgentKit) => {
  const transferSolTool = transferSolanaTool(solAgentKit);

  const approveRequiredTools = [approveRequiredTool1, approveRequiredTool2, transferSolTool];
  const allTools = [regularTool1, regularTool2, ...approveRequiredTools];
  // Computed once so each tool call is classified with an O(1) lookup.
  const approvalRequiredNames = new Set(approveRequiredTools.map((tool) => tool.name));

  const allToolsNode = new ToolNode(allTools);

  const model = new ChatOpenAI({
    model: "gpt-4o",
    temperature: 0,
    apiKey: process.env.OPENAI_API_KEY,
  });
  const modelWithTools = model.bindTools(allTools);

  /**
   * Routing function after the agent node: go to approval when the last
   * message carries at least one tool call, otherwise finish.
   */
  function shouldContinue(state) {
    const lastMessage = state.messages[state.messages.length - 1];
    if (
      "tool_calls" in lastMessage &&
      Array.isArray(lastMessage.tool_calls) &&
      lastMessage.tool_calls.length > 0
    ) {
      return "checkApproval";
    }
    return END;
  }

  /** Agent node: send the full conversation to the tool-bound model. */
  async function callModel(state) {
    const response = await modelWithTools.invoke(state.messages);
    return { messages: [response] };
  }

  /**
   * Approval node: pause for human review of tool calls that require it,
   * then rewrite those calls' `debug` argument from the reviewer's answer.
   *
   * The returned message reuses the original message id, so the messages
   * reducer replaces the AI message instead of appending a new one.
   */
  function checkApproval(state) {
    const lastMessage = state.messages[state.messages.length - 1];

    const toolCallsNeedingApproval = lastMessage.tool_calls.filter((tc) =>
      approvalRequiredNames.has(tc.name),
    );

    if (toolCallsNeedingApproval.length === 0) {
      return { messages: [lastMessage] }; // No approval needed, return as is
    }

    // The reviewer is shown only the approval-required calls, with debug reset to false.
    const resetToolCalls = toolCallsNeedingApproval.map((tc) => ({
      ...tc,
      args: {
        ...tc.args,
        debug: false,
      },
    }));

    // interrupt() suspends the graph; on resume it returns the value the client supplied.
    // NOTE(review): assumed to be an object mapping tool-call id -> debug value; confirm with the client.
    // A null/undefined resume value is treated as "no decisions" instead of throwing.
    const decisions = interrupt(resetToolCalls) ?? {};

    // Only the calls actually sent for review may be altered by the reviewer's answer.
    const reviewableIds = new Set(toolCallsNeedingApproval.map((tc) => tc.id));

    const updatedToolCalls = lastMessage.tool_calls.map((tc) => {
      if (!reviewableIds.has(tc.id) || decisions[tc.id] === undefined) {
        return tc;
      }
      return { ...tc, args: { ...tc.args, debug: decisions[tc.id] } };
    });

    const updatedMessage = {
      role: "ai",
      content: lastMessage.content,
      tool_calls: updatedToolCalls,
      id: lastMessage.id,
    };

    return { messages: [updatedMessage] };
  }

  const workflow = new StateGraph(MessagesAnnotation)
    .addNode("agent", callModel)
    .addNode("checkApproval", checkApproval)
    .addNode("allToolsNode", allToolsNode)
    .addEdge(START, "agent")
    .addConditionalEdges("agent", shouldContinue, ["checkApproval", END])
    .addEdge("checkApproval", "allToolsNode")
    .addEdge("allToolsNode", "agent");

  return workflow.compile({
    checkpointer,
  });
};

export default getGraph;