import { approveRequiredTool1 , approveRequiredTool2, regularTool1, regularTool2} from "@/server/tools/testing";
import { ToolNode } from "@langchain/langgraph/prebuilt";
import { ChatOpenAI } from "@langchain/openai";
import {
    StateGraph,
    MessagesAnnotation,
    END,
    START,
    interrupt,
  } from "@langchain/langgraph";  
import checkpointer from "@/server/checkpointer";
import { transferSolanaTool } from "@/server/tools/solana/transfer";

/**
 * Build and compile the agent graph.
 *
 * Flow: START -> "agent" (tool-bound model) -> "checkApproval" when the reply
 * contains tool calls (otherwise END) -> "allToolsNode" -> back to "agent".
 * Tool calls to tools in `approveRequiredTools` are paused with `interrupt`
 * in "checkApproval" so a reviewer can respond before the tools run.
 *
 * @param {*} solAgentKit - forwarded to `transferSolanaTool` to build the SOL
 *   transfer tool (its expected shape is defined by that helper, not here).
 * @returns the compiled graph, persisted through the shared `checkpointer`.
 */
const getGraph = (solAgentKit) => {
  const transferSolTool = transferSolanaTool(solAgentKit);
  const approveRequiredTools = [approveRequiredTool1, approveRequiredTool2, transferSolTool];
  // Built once so each tool call is classified with an O(1) lookup.
  const approveRequiredToolNames = new Set(approveRequiredTools.map((tool) => tool.name));
  const allTools = [regularTool1, regularTool2, ...approveRequiredTools];
  const allToolsNode = new ToolNode(allTools);
  const model = new ChatOpenAI({
    model: "gpt-4o",
    temperature: 0,
    apiKey: process.env.OPENAI_API_KEY,
  });

  const modelWithTools = model.bindTools(allTools);

  /**
   * Router after "agent": go to the approval gate when the last message
   * carries at least one tool call, otherwise end the run.
   */
  function shouldContinue(state) {
    const toolCalls = state.messages.at(-1)?.tool_calls;
    if (Array.isArray(toolCalls) && toolCalls.length > 0) {
      return "checkApproval";
    }
    return END;
  }

  /** Invoke the tool-bound model on the full history and append its reply. */
  async function callModel(state) {
    const response = await modelWithTools.invoke(state.messages);
    return { messages: [response] };
  }

  /**
   * Approval gate for tool calls that target approval-required tools.
   *
   * The interrupt payload lists only those calls, each with `args.debug`
   * forced to `false`. The resume value is read as an object keyed by
   * tool-call id. For every id present in it, that value is written into the
   * matching call's `args.debug`. Presumably this is the reviewer's decision
   * flag; confirm against the tool implementations.
   */
  function checkApproval(state) {
    const lastMessage = state.messages.at(-1);
    const toolCalls = lastMessage.tool_calls;

    const toolCallsNeedingApproval = toolCalls.filter((tc) => approveRequiredToolNames.has(tc.name));

    if (toolCallsNeedingApproval.length === 0) {
      return { messages: [lastMessage] }; // No approval needed, return as is
    }

    // Only the gated calls are shown to the reviewer, with debug reset to false.
    const resetToolCalls = toolCallsNeedingApproval.map((tc) => ({
      ...tc,
      args: { ...tc.args, debug: false },
    }));

    // Per LangGraph's `interrupt` docs, a resumed node re-executes from its
    // start, so everything above this line must stay side-effect free.
    // A null/undefined resume value means "no decisions" instead of a TypeError.
    const toolCallsListReviewed = interrupt(resetToolCalls) ?? {};

    const updatedToolCalls = toolCalls.map((tc) => {
      const decision = toolCallsListReviewed[tc.id];
      return decision === undefined ? tc : { ...tc, args: { ...tc.args, debug: decision } };
    });

    // Same id as the original AI message so the messages reducer replaces it
    // rather than appending a second copy.
    const updatedMessage = {
      role: "ai",
      content: lastMessage.content,
      tool_calls: updatedToolCalls,
      id: lastMessage.id,
    };

    return { messages: [updatedMessage] };
  }

  const workflow = new StateGraph(MessagesAnnotation)
    .addNode("agent", callModel)
    .addNode("checkApproval", checkApproval)
    .addNode("allToolsNode", allToolsNode)
    .addEdge(START, "agent")
    .addConditionalEdges("agent", shouldContinue, ["checkApproval", END])
    .addEdge("checkApproval", "allToolsNode")
    .addEdge("allToolsNode", "agent");

  const graph = workflow.compile({
    checkpointer,
  });
  return graph;
};

export default getGraph;