from __future__ import annotations import uuid from typing import Any, Dict, List from langchain.embeddings import OpenAIEmbeddings from langchain.embeddings.base import Embeddings from pinecone import Index # import doesnt work on plane wifi from pydantic import BaseModel from reworkd_platform.settings import settings from reworkd_platform.timer import timed_function from reworkd_platform.web.api.memory.memory import AgentMemory OPENAI_EMBEDDING_DIM = 1536 class Row(BaseModel): id: str values: List[float] metadata: Dict[str, Any] = {} class QueryResult(BaseModel): id: str score: float metadata: Dict[str, Any] = {} class PineconeMemory(AgentMemory): """ Wrapper around pinecone """ def __init__(self, index_name: str, namespace: str = ""): self.index = Index(settings.pinecone_index_name) self.namespace = namespace or index_name @timed_function(level="DEBUG") def __enter__(self) -> AgentMemory: self.embeddings: Embeddings = OpenAIEmbeddings( client=None, # Meta private value but mypy will complain its missing openai_api_key=settings.openai_api_key, ) return self def __exit__(self, *args: Any, **kwargs: Any) -> None: pass @timed_function(level="DEBUG") def reset_class(self) -> None: self.index.delete(delete_all=True, namespace=self.namespace) @timed_function(level="DEBUG") def add_tasks(self, tasks: List[str]) -> List[str]: if len(tasks) == 0: return [] embeds = self.embeddings.embed_documents(tasks) if len(tasks) != len(embeds): raise ValueError("Embeddings and tasks are not the same length") rows = [ Row(values=vector, metadata={"text": tasks[i]}, id=str(uuid.uuid4())) for i, vector in enumerate(embeds) ] self.index.upsert( vectors=[row.dict() for row in rows], namespace=self.namespace ) return [row.id for row in rows] @timed_function(level="DEBUG") def get_similar_tasks( self, text: str, score_threshold: float = 0.95 ) -> List[QueryResult]: # Get similar tasks vector = self.embeddings.embed_query(text) results = self.index.query( vector=vector, top_k=5, include_metadata=True, include_values=True, namespace=self.namespace, ) return [ QueryResult(id=row.id, score=row.score, metadata=row.metadata) for row in getattr(results, "matches", []) if row.score > score_threshold ] @staticmethod def should_use() -> bool: return False