// 2025-02-17 15:21:20 +07:00 — 98 lines, 2.9 KiB, JavaScript

import { approveRequiredTool1 , approveRequiredTool2, regularTool1, regularTool2} from "./tools/testing.js";
import { ToolNode } from "@langchain/langgraph/prebuilt";
import { ChatOpenAI } from "@langchain/openai";
import {
StateGraph,
MessagesAnnotation,
END,
START,
interrupt,
} from "@langchain/langgraph";
import { PostgresSaver } from "@langchain/langgraph-checkpoint-postgres";
// Postgres-backed checkpointer used when compiling the graph below.
// The connection string embeds database credentials, so it is read from the
// environment instead of being committed to source (mirroring how the OpenAI
// key is supplied via process.env). The password that was previously
// hard-coded here has been exposed in version control and should be rotated.
// NOTE(review): PostgresSaver needs `await checkpointer.setup()` once to create
// its tables — confirm that happens elsewhere (not visible in this module).
const POSTGRES_CONN_STRING = process.env.POSTGRES_CONN_STRING;
if (!POSTGRES_CONN_STRING) {
  throw new Error(
    "POSTGRES_CONN_STRING is not set; it must contain the Postgres connection string used by the LangGraph checkpointer.",
  );
}
const checkpointer = PostgresSaver.fromConnString(POSTGRES_CONN_STRING);
// Chat model backing the "agent" node (gpt-4o, temperature 0).
const model = new ChatOpenAI({
  apiKey: process.env.OPENAI_API_KEY,
  model: "gpt-4o",
  temperature: 0,
});

// Tools whose calls are routed through the checkApproval step.
const approveRequiredTools = [approveRequiredTool1, approveRequiredTool2];

// Complete toolset: ordinary tools first, then the approval-gated ones.
const allTools = [regularTool1, regularTool2, ...approveRequiredTools];

// Graph node that executes tool calls against the complete toolset.
const allToolsNode = new ToolNode(allTools);

// Model instance that advertises every tool so it can emit calls to any of them.
const modelWithTools = model.bindTools(allTools);
/**
 * Router for the conditional edge leaving the "agent" node.
 *
 * @param {{ messages: Array<object> }} state - graph state.
 * @returns {string} "checkApproval" when the newest message carries a
 *   non-empty `tool_calls` array; otherwise END.
 */
function shouldContinue(state) {
  const { messages } = state;
  const newest = messages[messages.length - 1];

  if (!("tool_calls" in newest)) {
    return END;
  }
  const calls = newest.tool_calls;
  if (!Array.isArray(calls) || calls.length === 0) {
    return END;
  }
  return "checkApproval";
}
/**
 * "agent" node: sends the whole conversation to the tool-bound model.
 *
 * @param {{ messages: Array<object> }} state - graph state.
 * @returns {Promise<{ messages: Array<object> }>} state update containing the
 *   model's reply, which the MessagesAnnotation reducer merges into `messages`.
 */
async function callModel(state) {
  const { messages: history } = state;
  const reply = await modelWithTools.invoke(history);
  return { messages: [reply] };
}
/**
 * Human-in-the-loop gate between the "agent" node and the tool node.
 *
 * If none of the latest message's tool calls target a tool listed in
 * `approveRequiredTools`, the message is passed through untouched. Otherwise
 * the run is paused with `interrupt()`. The reviewer receives only the
 * approval-required calls, each with `args.debug` reset to `false`.
 *
 * The resume value is read as an object keyed by tool-call id. A defined value
 * for an id becomes that call's `args.debug`. Approval-required calls the
 * reviewer leaves unanswered keep the reset value (`debug: false`), so the
 * arguments that execute are the ones the reviewer was shown. A null or
 * undefined resume value is treated as "no answers" instead of throwing.
 *
 * Per the LangGraph docs, a node that calls `interrupt()` is re-run from the
 * top when the run resumes. The code before it must therefore stay
 * deterministic and free of side effects.
 *
 * @param {{ messages: Array<object> }} state - graph state. The last message is
 *   expected to carry a `tool_calls` array (the graph only routes here from
 *   shouldContinue when it is non-empty).
 * @returns {{ messages: Array<object> }} state update holding either the
 *   original last message or a rewritten copy with reviewed tool calls.
 */
function checkApproval(state) {
  const lastMessage = state.messages[state.messages.length - 1];
  const toolCalls = lastMessage.tool_calls;

  const needsApproval = (tc) =>
    approveRequiredTools.some((tool) => tool.name === tc.name);
  // Copy a tool call with its `args.debug` replaced; never mutates `tc`.
  const withDebug = (tc, debug) => ({ ...tc, args: { ...tc.args, debug } });

  const toolCallsNeedingApproval = toolCalls.filter(needsApproval);
  if (toolCallsNeedingApproval.length === 0) {
    return { messages: [lastMessage] }; // No approval needed, return as is
  }

  // Present only the approval-required calls, with debug reset to false.
  const resetToolCalls = toolCallsNeedingApproval.map((tc) => withDebug(tc, false));

  // Reviewer verdicts keyed by tool-call id; a missing resume value means no answers.
  const verdicts = interrupt(resetToolCalls) ?? {};

  const updatedToolCalls = toolCalls.map((tc) => {
    const verdict = verdicts[tc.id];
    if (verdict !== undefined) {
      return withDebug(tc, verdict);
    }
    // Unanswered approval-required calls use the reset value the reviewer saw,
    // not the model's original arguments.
    return needsApproval(tc) ? withDebug(tc, false) : tc;
  });

  return {
    messages: [
      {
        role: "ai",
        content: lastMessage.content,
        tool_calls: updatedToolCalls,
        // Reusing the id lets the messages reducer overwrite the original AI
        // message instead of appending a second one.
        id: lastMessage.id,
      },
    ],
  };
}
// Graph topology:
//   START -> agent --(tool calls)--> checkApproval -> allToolsNode -> agent
//                  \--(no tool calls)--> END
const workflow = new StateGraph(MessagesAnnotation)
  // Nodes
  .addNode("agent", callModel)
  .addNode("checkApproval", checkApproval)
  .addNode("allToolsNode", allToolsNode)
  // Edges
  .addEdge(START, "agent")
  .addConditionalEdges("agent", shouldContinue, ["checkApproval", END])
  .addEdge("checkApproval", "allToolsNode")
  .addEdge("allToolsNode", "agent");

// Compile with the Postgres checkpointer. LangGraph needs a checkpointer for
// interrupt() (used by checkApproval) to pause a run and resume it later.
const graph = workflow.compile({ checkpointer: checkpointer });

export { graph as default };