102 lines
3.1 KiB
JavaScript
102 lines
3.1 KiB
JavaScript
![]() |
import { approveRequiredTool1 , approveRequiredTool2, regularTool1, regularTool2} from "@/server/tools/testing";
|
||
|
import { ToolNode } from "@langchain/langgraph/prebuilt";
|
||
|
import { ChatOpenAI } from "@langchain/openai";
|
||
|
import {
|
||
|
StateGraph,
|
||
|
MessagesAnnotation,
|
||
|
END,
|
||
|
START,
|
||
|
interrupt,
|
||
|
} from "@langchain/langgraph";
|
||
|
import checkpointer from "@/server/checkpointer";
|
||
|
import { transferSolanaTool } from "@/server/tools/solana/transfer";
|
||
|
|
||
|
/**
 * Builds and compiles the agent workflow graph.
 *
 * The graph has three nodes:
 *   - "agent":         invokes the tool-bound chat model on the message history.
 *   - "checkApproval": pauses via `interrupt` when the model requested any tool
 *                      listed in `approveRequiredTools`, then applies the
 *                      reviewer's response to the tool-call arguments.
 *   - "allToolsNode":  executes the requested tool calls.
 *
 * Flow: START -> agent -> (checkApproval -> allToolsNode -> agent)* -> END.
 *
 * @param {object} solAgentKit - Passed straight to `transferSolanaTool` to
 *   build the transfer tool; its shape is not inspected here.
 * @returns {object} The workflow compiled with the imported `checkpointer`.
 */
const getGraph = (solAgentKit) => {
  const transferSolTool = transferSolanaTool(solAgentKit);
  const approveRequiredTools = [approveRequiredTool1, approveRequiredTool2, transferSolTool];
  const allTools = [regularTool1, regularTool2, ...approveRequiredTools];
  const allToolsNode = new ToolNode(allTools);

  // Built once so each tool call is classified with a constant-time lookup
  // instead of rescanning the approval list.
  const approvalToolNames = new Set(approveRequiredTools.map((tool) => tool.name));

  const model = new ChatOpenAI({
    model: "gpt-4o",
    temperature: 0,
    apiKey: process.env.OPENAI_API_KEY,
  });

  const modelWithTools = model.bindTools(allTools);

  /**
   * Routing function for the "agent" node.
   *
   * @param {{ messages: Array<object> }} state - Current graph state.
   * @returns {string} "checkApproval" when the last message carries at least
   *   one tool call, otherwise `END`.
   */
  function shouldContinue(state) {
    const lastMessage = state.messages[state.messages.length - 1];
    const hasToolCalls =
      "tool_calls" in lastMessage &&
      Array.isArray(lastMessage.tool_calls) &&
      lastMessage.tool_calls.length > 0;

    return hasToolCalls ? "checkApproval" : END;
  }

  /**
   * "agent" node: sends the full message history to the tool-bound model.
   *
   * @param {{ messages: Array<object> }} state - Current graph state.
   * @returns {Promise<{ messages: Array<object> }>} State update containing
   *   the model's response message.
   */
  async function callModel(state) {
    const response = await modelWithTools.invoke(state.messages);
    return { messages: [response] };
  }

  /**
   * "checkApproval" node: requests human review for tool calls whose tool is
   * in `approveRequiredTools`.
   *
   * The interrupt payload is the list of approval-gated tool calls, each with
   * `args.debug` forced to `false`. The value supplied on resume is read as a
   * mapping from tool-call id to the value to store in that call's
   * `args.debug`; calls whose id is absent from the mapping are left as-is.
   * NOTE(review): the meaning of `debug` to the tools themselves is not
   * visible in this file — confirm against the tool implementations.
   *
   * @param {{ messages: Array<object> }} state - Current graph state; the last
   *   message is expected to carry `tool_calls`.
   * @returns {{ messages: Array<object> }} State update with either the
   *   untouched last message (nothing to approve) or a rebuilt AI message
   *   carrying the reviewed tool calls under the same message id.
   * @throws {TypeError} If the graph is resumed with `null`/`undefined`
   *   instead of a review mapping.
   */
  function checkApproval(state) {
    const lastMessage = state.messages[state.messages.length - 1];

    // Identify tool calls that require approval.
    const toolCallsNeedingApproval = lastMessage.tool_calls.filter((tc) =>
      approvalToolNames.has(tc.name),
    );

    if (toolCallsNeedingApproval.length === 0) {
      return { messages: [lastMessage] }; // No approval needed, return as is.
    }

    // Present only the gated calls to the reviewer, with `debug` reset.
    const resetToolCalls = toolCallsNeedingApproval.map((tc) => ({
      ...tc,
      args: {
        ...tc.args,
        debug: false,
      },
    }));

    // Pauses the graph here; the resume value becomes the return value.
    const toolCallsListReviewed = interrupt(resetToolCalls);

    // Indexing a nullish resume value would fail with an opaque error, so
    // report the actual problem instead.
    if (toolCallsListReviewed == null) {
      throw new TypeError(
        "checkApproval: expected the resume value to map tool call ids to review results, received " +
          String(toolCallsListReviewed),
      );
    }

    // Apply the reviewer's values to the matching tool calls.
    const updatedToolCalls = lastMessage.tool_calls.map((tc) =>
      toolCallsListReviewed[tc.id] !== undefined
        ? { ...tc, args: { ...tc.args, debug: toolCallsListReviewed[tc.id] } }
        : tc,
    );

    // Reuses the original message id for the rebuilt message.
    const updatedMessage = {
      role: "ai",
      content: lastMessage.content,
      tool_calls: updatedToolCalls,
      id: lastMessage.id,
    };

    return { messages: [updatedMessage] };
  }

  const workflow = new StateGraph(MessagesAnnotation)
    .addNode("agent", callModel)
    .addNode("checkApproval", checkApproval)
    .addNode("allToolsNode", allToolsNode)
    .addEdge(START, "agent")
    .addConditionalEdges("agent", shouldContinue, ["checkApproval", END])
    .addEdge("checkApproval", "allToolsNode")
    .addEdge("allToolsNode", "agent");

  const graph = workflow.compile({
    checkpointer,
  });

  return graph;
};

export default getGraph;
|