mirror of
https://github.com/trushildhokiya/allininx-2.git
synced 2025-03-15 17:58:39 +00:00
99 lines
2.7 KiB
Python
99 lines
2.7 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import uuid
|
||
|
from typing import Any, Dict, List
|
||
|
|
||
|
from langchain.embeddings import OpenAIEmbeddings
|
||
|
from langchain.embeddings.base import Embeddings
|
||
|
from pinecone import Index # import doesnt work on plane wifi
|
||
|
from pydantic import BaseModel
|
||
|
|
||
|
from reworkd_platform.settings import settings
|
||
|
from reworkd_platform.timer import timed_function
|
||
|
from reworkd_platform.web.api.memory.memory import AgentMemory
|
||
|
|
||
|
# Dimensionality of the embedding vectors. Presumably this matches the output
# size of the default OpenAIEmbeddings model (and so the Pinecone index
# dimension); confirm before changing. It is not referenced in the code shown
# in this module.
OPENAI_EMBEDDING_DIM = 1536
|
||
|
|
||
|
|
||
|
class Row(BaseModel):
    """One vector record, serialized with ``.dict()`` and passed to Pinecone ``upsert``."""

    id: str  # record id; PineconeMemory.add_tasks uses a random UUID string
    values: List[float]  # the embedding vector for this record
    metadata: Dict[str, Any] = {}  # extra data; add_tasks stores {"text": <task>}
|
||
|
|
||
|
|
||
|
class QueryResult(BaseModel):
    """A single similarity-search match returned by ``PineconeMemory.get_similar_tasks``."""

    id: str  # id of the matched record
    score: float  # similarity score of the match; compared against score_threshold
    metadata: Dict[str, Any] = {}  # metadata stored with the matched record
|
||
|
|
||
|
|
||
|
class PineconeMemory(AgentMemory):
    """
    Wrapper around pinecone.

    Every instance uses the single index named by
    ``settings.pinecone_index_name``. Data is separated per instance by the
    Pinecone ``namespace``.

    The embeddings client is only created in ``__enter__``. Call the task
    methods inside a ``with PineconeMemory(...) as memory:`` block; otherwise
    ``self.embeddings`` does not exist yet.
    """

    def __init__(self, index_name: str, namespace: str = ""):
        """
        :param index_name: Used as the namespace when ``namespace`` is empty.
            The Pinecone index itself always comes from settings.
        :param namespace: Pinecone namespace that all reads and writes target.
        """
        self.index = Index(settings.pinecone_index_name)
        self.namespace = namespace or index_name

    @timed_function(level="DEBUG")
    def __enter__(self) -> AgentMemory:
        """Create the OpenAI embeddings client and return ``self``."""
        self.embeddings: Embeddings = OpenAIEmbeddings(
            client=None,  # Private field, but mypy complains when it is omitted
            openai_api_key=settings.openai_api_key,
        )

        return self

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        # Nothing to release. Returning None lets any exception propagate.
        pass

    @timed_function(level="DEBUG")
    def reset_class(self) -> None:
        """Delete every vector stored in this instance's namespace."""
        self.index.delete(delete_all=True, namespace=self.namespace)

    @timed_function(level="DEBUG")
    def add_tasks(self, tasks: List[str]) -> List[str]:
        """
        Embed ``tasks`` and upsert them into this instance's namespace.

        Each task is stored under a freshly generated UUID. Its text is kept
        in the ``"text"`` metadata field.

        :param tasks: Task strings to store.
        :return: The generated ids, in the same order as ``tasks``
            (an empty list when ``tasks`` is empty).
        :raises ValueError: If the embedder returns a different number of
            vectors than tasks were given.
        """
        if not tasks:
            return []

        embeds = self.embeddings.embed_documents(tasks)

        if len(tasks) != len(embeds):
            raise ValueError("Embeddings and tasks are not the same length")

        # Lengths were checked above, so zip pairs every task with its vector.
        rows = [
            Row(values=vector, metadata={"text": text}, id=str(uuid.uuid4()))
            for text, vector in zip(tasks, embeds)
        ]

        self.index.upsert(
            vectors=[row.dict() for row in rows], namespace=self.namespace
        )

        return [row.id for row in rows]

    @timed_function(level="DEBUG")
    def get_similar_tasks(
        self, text: str, score_threshold: float = 0.95
    ) -> List[QueryResult]:
        """
        Return stored tasks similar to ``text``.

        Only the 5 nearest records are considered. Of those, only matches
        whose score is strictly greater than ``score_threshold`` are returned.

        :param text: Query text to embed and search with.
        :param score_threshold: Exclusive lower bound on the match score.
        :return: Matching records as ``QueryResult`` objects.
        """
        vector = self.embeddings.embed_query(text)
        results = self.index.query(
            vector=vector,
            top_k=5,
            include_metadata=True,
            # The stored vectors are never read below, so don't download them.
            include_values=False,
            namespace=self.namespace,
        )

        return [
            QueryResult(
                id=row.id,
                score=row.score,
                # A match without metadata would otherwise pass None into a
                # non-Optional Dict field and fail validation.
                metadata=getattr(row, "metadata", None) or {},
            )
            for row in getattr(results, "matches", [])
            if row.score > score_threshold
        ]

    @staticmethod
    def should_use() -> bool:
        # Always reports that this memory backend should not be used.
        return False
|