282 lines
8.4 KiB
JavaScript
282 lines
8.4 KiB
JavaScript
![]() |
import { tool } from "@langchain/core/tools";
|
||
|
import { z } from "zod";
|
||
|
import { ToolNode } from "@langchain/langgraph/prebuilt";
|
||
|
import { ChatOpenAI } from "@langchain/openai";
|
||
|
import {
|
||
|
StateGraph,
|
||
|
MessagesAnnotation,
|
||
|
END,
|
||
|
START,
|
||
|
interrupt,
|
||
|
MemorySaver,
|
||
|
Command,
|
||
|
} from "@langchain/langgraph";
|
||
|
import {
|
||
|
AIMessage,
|
||
|
ToolMessage,
|
||
|
HumanMessage,
|
||
|
isAIMessage,
|
||
|
} from "@langchain/core/messages";
|
||
|
// npm install langchain inquirer @langchain/langgraph @langchain/openai zod
|
||
|
import inquirer from "inquirer";
|
||
|
|
||
|
/**
 * Asks the operator, one tool call at a time, whether each pending call may run.
 *
 * @param {Array<{ value: Array<{ id: string, name: string, args: object }> }>} interruptEvent
 *   Interrupt entries; only the first entry's `value` (the list of tool calls
 *   awaiting approval) is read.
 * @returns {Promise<Command>} A resume command whose `resume` value maps each
 *   tool-call id to the operator's yes/no answer (the prompt defaults to "no").
 */
async function handleToolApproval(interruptEvent) {
  const pendingCalls = interruptEvent[0].value;
  const decisions = {};

  for (const call of pendingCalls) {
    const question = {
      type: "confirm",
      name: "approved",
      message: `Approve execution of ${call.name} with parameters:\n${JSON.stringify(call.args)}?`,
      default: false,
    };
    // Prompts are sequential on purpose: one terminal question at a time.
    const { approved } = await inquirer.prompt([question]);
    decisions[call.id] = approved;
  }

  return new Command({ resume: decisions });
}
|
||
|
|
||
|
/**
 * Creates a demo tool that should only "succeed" after human approval.
 *
 * The tool returns a success string when `debug` is truthy and a denial
 * string otherwise. `debug` is part of the schema so the value can reach the
 * tool, but it is not meant to be chosen by the model: the `checkApproval`
 * graph node rewrites it from the reviewer's decision before the tool runs.
 *
 * @param {string} toolName - Tool name exposed to the model; also used in the
 *   description and in the returned messages.
 * @returns The LangChain tool created by `tool(...)`.
 */
function createApprovalRequiredTool(toolName) {
  return tool(
    ({ testParam, debug = false }) =>
      debug
        ? `${toolName} called successfully with ${testParam}`
        : `${toolName} call failed with ${testParam}. User denied the request`,
    {
      name: toolName,
      description: `Call when user ask for the function ${toolName}`,
      schema: z.object({
        testParam: z
          .string()
          .describe("Some random string you come up with just for testing"),
        debug: z
          .boolean()
          .default(false)
          .describe("User controlled variable, you cannot control this"),
      }),
    },
  );
}

const approveRequiredTool1 = createApprovalRequiredTool("approveRequiredTool1");
const approveRequiredTool2 = createApprovalRequiredTool("approveRequiredTool2");
|
||
|
|
||
|
/**
 * Builds a tool that needs no approval: it accepts a single string input,
 * ignores it, and always returns `reply`.
 *
 * @param {string} name - Tool name exposed to the model.
 * @param {string} reply - Fixed string returned on every call.
 * @param {string} description - Tool description shown to the model.
 * @returns The LangChain tool created by `tool(...)`.
 */
function makeRegularTool(name, reply, description) {
  return tool((_input) => reply, { name, description, schema: z.string() });
}

const regularTool1 = makeRegularTool(
  "regularTool1",
  "regularTool1 called successfully",
  "Call when user ask for the function regularTool1",
);

const regularTool2 = makeRegularTool(
  "regularTool2",
  "regularTool2 called successfully",
  "Call when user ask for the function regularTool2",
);
|
||
|
|
||
|
// Tools whose calls are routed through the human approval step.
const approveRequiredTools = [approveRequiredTool1, approveRequiredTool2];

// Every tool the model may call: the regular tools first, then the gated ones.
const allTools = [regularTool1, regularTool2].concat(approveRequiredTools);

// One prebuilt executor node shared by all tools.
const allToolsNode = new ToolNode(allTools);
|
||
|
// The API key comes from the environment; secrets must never be committed.
// NOTE(review): a key was previously hard-coded here — treat it as leaked
// and revoke it in the OpenAI dashboard.
const openAIApiKey = process.env.OPENAI_API_KEY;
if (!openAIApiKey) {
  throw new Error("OPENAI_API_KEY environment variable is not set");
}

// temperature 0 keeps the scripted multi-round demo as repeatable as possible.
const model = new ChatOpenAI({
  model: "gpt-4o",
  temperature: 0,
  apiKey: openAIApiKey,
});

// Advertise every tool (regular and approval-gated) to the model.
const modelWithTools = model.bindTools(allTools);
|
||
|
|
||
|
/**
 * Routing function for the conditional edge leaving the "agent" node.
 *
 * @param {{ messages: Array }} state - Graph state; only the newest message
 *   is inspected.
 * @returns {string} "checkApproval" when the newest message carries a
 *   non-empty `tool_calls` array, otherwise END.
 */
function shouldContinue(state) {
  const { messages } = state;
  const newest = messages[messages.length - 1];
  const hasToolCalls =
    "tool_calls" in newest &&
    Array.isArray(newest.tool_calls) &&
    newest.tool_calls.length > 0;
  return hasToolCalls ? "checkApproval" : END;
}
|
||
|
|
||
|
/**
 * "agent" node: sends the full conversation to the tool-bound model.
 *
 * @param {{ messages: Array }} state - Graph state holding the conversation.
 * @returns {Promise<{ messages: Array }>} A state update containing only the
 *   model's reply.
 */
async function callModel(state) {
  const reply = await modelWithTools.invoke(state.messages);
  return { messages: [reply] };
}
|
||
|
|
||
|
/**
 * Graph node that pauses for human approval of approval-required tool calls.
 *
 * Reads the tool calls on the latest message. If none target a tool in
 * `approveRequiredTools`, the message is returned unchanged. Otherwise the
 * graph is paused with `interrupt`, passing copies of the gated calls with
 * `debug` forced to `false`. On resume, the resume value is expected to map
 * tool-call id -> boolean (as built by `handleToolApproval`).
 *
 * A gated call gets `debug: true` only when the reviewer answered exactly
 * `true` for its id. Any `debug` value the model supplied is discarded, so the
 * model cannot approve its own call. A missing decision (or a missing resume
 * value) counts as a denial. Non-gated calls are passed through untouched.
 *
 * @param {{ messages: Array }} state - Graph state; the last message is
 *   expected to be an AI message with `tool_calls`.
 * @returns {{ messages: Array }} State update carrying the (possibly
 *   rewritten) AI message under its original id, so it replaces the old one.
 */
function checkApproval(state) {
  const lastMessage = state.messages[state.messages.length - 1];
  const toolCalls = lastMessage.tool_calls ?? [];

  const approvalToolNames = new Set(approveRequiredTools.map((t) => t.name));
  const needsApproval = (tc) => approvalToolNames.has(tc.name);

  const toolCallsNeedingApproval = toolCalls.filter(needsApproval);
  if (toolCallsNeedingApproval.length === 0) {
    return { messages: [lastMessage] }; // No approval needed, return as is
  }

  // Show the reviewer each gated call with `debug` reset to false so the
  // model's own value is never presented as an approval.
  const resetToolCalls = toolCallsNeedingApproval.map((tc) => ({
    ...tc,
    args: { ...tc.args, debug: false },
  }));

  // First run: pauses the graph. On resume LangGraph re-runs this node from
  // the top and `interrupt` returns the value passed via Command({ resume }).
  const decisions = interrupt(resetToolCalls) ?? {};

  const updatedToolCalls = toolCalls.map((tc) =>
    needsApproval(tc)
      ? { ...tc, args: { ...tc.args, debug: decisions[tc.id] === true } }
      : tc,
  );

  const updatedMessage = new AIMessage({
    content: lastMessage.content,
    tool_calls: updatedToolCalls,
    id: lastMessage.id,
  });

  return { messages: [updatedMessage] };
}
|
||
|
|
||
|
|
||
|
// Checkpointer that stores graph state per thread so an interrupted run can
// be resumed (LangGraph's `interrupt` requires one).
const memory = new MemorySaver();

// Topology:
//   START -> agent
//   agent -> checkApproval (tool calls present) | END (no tool calls)
//   checkApproval -> allToolsNode -> agent
const workflow = new StateGraph(MessagesAnnotation)
  .addNode("agent", callModel)
  .addNode("checkApproval", checkApproval)
  .addNode("allToolsNode", allToolsNode)
  .addEdge(START, "agent")
  .addConditionalEdges("agent", shouldContinue, ["checkApproval", END])
  .addEdge("checkApproval", "allToolsNode")
  .addEdge("allToolsNode", "agent");

const graph = workflow.compile({ checkpointer: memory });
|
||
|
|
||
|
// Optional: render the compiled graph to graph.png for inspection.
// Uncommenting requires `import fs from "node:fs";` at the top of the file.
// const drawableGraph = graph.getGraph();
// const image = await drawableGraph.drawMermaidPng();
// const arrayBuffer = await image.arrayBuffer();
//
// // Save the image to "graph.png" in the current directory
// fs.writeFileSync("graph.png", Buffer.from(arrayBuffer));
//
// console.log("Graph saved as graph.png");
|
||
|
|
||
|
// Scripted multi-round request exercising both regular and approval-gated tools.
// Alternative, shorter prompt:
//   "Call approveRequiredTool1 and approveRequiredTool2, then tell me what did they"
const userPrompt =
  "In the first round ,call regularTool2 then regularTool1. After receiving the response call approveRequiredTool1 and regularTool2.After receiving the response call approveRequiredTool2 and approveRequiredTool1. AFter receiving their response tell me what did each tool tell you in each round";

// Initial user turn, as a plain role/content message object.
const input = { role: "user", content: userPrompt };
|
||
|
|
||
|
// Every graph call below uses this thread_id, which selects the checkpointed
// conversation the graph reads from and writes to.
const config = {
  configurable: {
    thread_id: "3",
  },
};

// Print the user's request under a banner before the run starts.
for (const line of [
  `\n\n\n\n`,
  "=".repeat(75),
  `${'='.repeat(25)} User Message ${'='.repeat(25)}`,
  input.content,
]) {
  console.log(line);
}
|
||
|
|
||
|
/**
 * Runs the graph for one leg of the conversation and prints each node's output.
 *
 * Only `agent` and `allToolsNode` updates are printed. Interrupt events and
 * updates from any other node are skipped silently.
 *
 * @param {*} command - Graph input: the initial `{ messages }` state or a
 *   `Command` resuming an interrupt.
 * @param {object} config - Run config holding `configurable.thread_id`.
 * @returns {Promise<void>} Resolves once the stream is exhausted.
 */
async function processStream(command, config) {
  const updates = await graph.stream(command, config);

  for await (const update of updates) {
    if (update.__interrupt__) {
      // Interrupts are handled by the caller, not printed here.
      continue;
    }

    if (update.agent) {
      console.log("=".repeat(75));
      console.log(`${'='.repeat(25)} Agent Message ${'='.repeat(25)}`);

      const { content: text, tool_calls: calls } = update.agent.messages[0];
      if (text) console.log(`Agent message: ${text}`);
      if (!calls) continue;

      for (const call of calls) {
        console.log(`Tool Name: ${call.name}`);
        console.log(`Arguments:`);
        for (const [argName, argValue] of Object.entries(call.args)) {
          console.log(` - ${argName}: ${argValue}`);
        }
        console.log("--------------------");
      }
      continue;
    }

    if (update.allToolsNode) {
      console.log("=".repeat(75));
      console.log(`${'='.repeat(25)} Tools Message ${'='.repeat(25)}`);
      for (const { name, content, tool_call_id } of update.allToolsNode.messages) {
        console.log({ name, content, tool_call_id });
      }
    }
  }
}
|
||
|
|
||
|
// Drive the graph until the agent returns a final answer.
//
// The first pass sends the user's input. After each pass, the checkpointed
// state is inspected with `graph.getState`. If a task is paused on an
// `interrupt`, its payload (the tool calls awaiting approval) is shown to the
// reviewer, and the graph is resumed with their decisions.
//
// This avoids two earlier problems:
//   - relying on LangGraph's internal channel names to detect the interrupt;
//   - re-streaming the graph just to rediscover the pending interrupt.
await processStream({ messages: [input] }, config);

while (true) {
  const snapshot = await graph.getState(config);
  const pendingInterrupts = (snapshot.tasks ?? []).flatMap((task) => task.interrupts ?? []);

  if (pendingInterrupts.length === 0) {
    const lastMessage = snapshot.values?.messages?.at(-1);
    const isFinalAnswer =
      lastMessage !== undefined &&
      isAIMessage(lastMessage) &&
      !(lastMessage.tool_calls?.length > 0);
    if (isFinalAnswer) {
      console.log("Breaking loop: Last message is an AI message with no tool calls.");
    } else {
      // Nothing to resume and no final answer: stop instead of re-sending the
      // user input or spinning forever.
      console.log("Breaking loop: graph stopped without a pending interrupt or a final AI answer.");
    }
    break;
  }

  console.log("=".repeat(75));
  console.log(`${'='.repeat(25)} Graph interrupted for user input ${'='.repeat(25)}`);
  const approvalCommand = await handleToolApproval(pendingInterrupts);
  await processStream(approvalCommand, config);
}
|
||
|
|