2025-02-17 19:44:17 +05:30

88 lines
2.1 KiB
Python

from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field, validator
from reworkd_platform.web.api.agent.analysis import Analysis
# Model identifiers accepted by ``ModelSettings.model``.
LLM_Model = Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4o"]
# Names of the agent loop steps. Nothing in this chunk references the alias;
# presumably it tags the request models below — confirm against its users.
Loop_Step = Literal["start", "analyze", "execute", "create", "summarize", "chat"]
# Upper bound on ``ModelSettings.max_tokens`` for each model, enforced by
# ``ModelSettings.validate_max_tokens``. That validator indexes this dict
# directly, so every ``LLM_Model`` value needs an entry here.
LLM_MODEL_MAX_TOKENS: Dict[LLM_Model, int] = {
    "gpt-3.5-turbo": 4_000,
    "gpt-3.5-turbo-16k": 16_000,
    "gpt-4o": 8_000,
}
class ModelSettings(BaseModel):
    """LLM configuration for a single agent run.

    ``max_tokens`` is checked against the per-model ceiling in
    ``LLM_MODEL_MAX_TOKENS``.
    """

    # ``model`` is declared before ``max_tokens`` so that it has already been
    # validated (and is present in ``values``) when the max_tokens validator runs.
    model: LLM_Model = Field(default="gpt-4o")
    # Optional caller-supplied API key — presumably used instead of a server key; confirm.
    custom_api_key: Optional[str] = Field(default=None)
    temperature: float = Field(default=0.9, ge=0.0, le=1.0)
    max_tokens: int = Field(default=500, ge=0)
    language: str = Field(default="English")

    @validator("max_tokens")
    def validate_max_tokens(cls, v: int, values: Dict[str, Any]) -> int:
        """Reject a ``max_tokens`` larger than the selected model allows.

        Args:
            v: The candidate ``max_tokens`` value.
            values: Previously validated fields of this model.

        Returns:
            ``v`` unchanged when it is within the model's limit.

        Raises:
            ValueError: If ``v`` exceeds ``LLM_MODEL_MAX_TOKENS[model]``.
        """
        model = values.get("model")
        if model is None:
            # ``model`` failed its own validation and was left out of ``values``.
            # Its error is already reported; indexing here would raise a KeyError,
            # which pydantic does not turn into a validation error.
            return v
        max_tokens = LLM_MODEL_MAX_TOKENS[model]
        if v > max_tokens:
            raise ValueError(f"Model {model} only supports {max_tokens} tokens")
        return v
class AgentRunCreate(BaseModel):
    """Base agent request payload: the ``goal`` plus per-run ``model_settings``."""

    goal: str
    # default_factory builds a fresh ModelSettings for each instance instead of
    # instantiating a single shared default when the class is defined.
    model_settings: ModelSettings = Field(default_factory=ModelSettings)
class AgentRun(AgentRunCreate):
    """``AgentRunCreate`` fields plus the ``run_id`` identifying the run."""

    run_id: str
class AgentTaskAnalyze(AgentRun):
    """Payload for analyzing one ``task`` of a run.

    ``model_settings`` is inherited from ``AgentRunCreate``. The previous
    identical redeclaration was redundant and has been dropped.
    """

    task: str
    # Tool names the analysis may consider — presumably; confirm with callers.
    tool_names: List[str] = Field(default_factory=list)
class AgentTaskExecute(AgentRun):
    """Payload for executing ``task`` according to a prior ``Analysis``."""

    task: str
    analysis: Analysis  # project type from reworkd_platform.web.api.agent.analysis
class AgentTaskCreate(AgentRun):
    """Run-progress payload, presumably for the ``"create"`` loop step.

    Every progress field may be omitted. Lists default to a fresh empty list
    per instance, and optional strings default to ``None``.
    """

    tasks: List[str] = Field(default_factory=list)
    last_task: Optional[str] = Field(default=None)
    result: Optional[str] = Field(default=None)
    completed_tasks: List[str] = Field(default_factory=list)
class AgentSummarize(AgentRun):
    """Payload carrying the ``results`` to be summarized for a run."""

    # default_factory gives each instance its own list rather than a shared literal.
    results: List[str] = Field(default_factory=list)
class AgentChat(AgentRun):
    """Chat payload: a user ``message`` plus the run's prior ``results``."""

    message: str
    # default_factory gives each instance its own list rather than a shared literal.
    results: List[str] = Field(default_factory=list)
class NewTasksResponse(BaseModel):
    """Newly generated tasks for the run identified by ``run_id``.

    ``new_tasks`` is declared with the alias ``newTasks``. With no config
    visible here enabling population by field name, input data must
    presumably use the ``newTasks`` key — confirm against callers.
    """

    run_id: str
    new_tasks: List[str] = Field(alias="newTasks")
class RunCount(BaseModel):
    """A run ``count`` together with optional first/last run timestamps."""

    count: int
    # NOTE(review): no explicit default. pydantic v1 treats a bare Optional as
    # defaulting to None, while v2 makes it required — confirm the intent.
    first_run: Optional[datetime]
    last_run: Optional[datetime]