mirror of
https://github.com/trushildhokiya/allininx-2.git
synced 2025-03-15 05:28:39 +00:00
commit -2
This commit is contained in:
parent
48da88668c
commit
bdfe11f039
24
cli/.gitignore
vendored
Normal file
24
cli/.gitignore
vendored
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
|
||||||
|
|
||||||
|
# dependencies
|
||||||
|
/node_modules
|
||||||
|
/.pnp
|
||||||
|
.pnp.js
|
||||||
|
|
||||||
|
# testing
|
||||||
|
/coverage
|
||||||
|
|
||||||
|
# production
|
||||||
|
/build
|
||||||
|
|
||||||
|
# misc
|
||||||
|
.DS_Store
|
||||||
|
*.pem
|
||||||
|
|
||||||
|
# debug
|
||||||
|
npm-debug.log*
|
||||||
|
yarn-debug.log*
|
||||||
|
yarn-error.log*
|
||||||
|
.pnpm-debug.log*
|
||||||
|
|
||||||
|
.eslintcache
|
27
cli/README.md
Normal file
27
cli/README.md
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
## AgentGPT CLI
|
||||||
|
|
||||||
|
AgentGPT CLI is a utility designed to streamline the setup process of your AgentGPT environment.
|
||||||
|
It uses Inquirer to interactively build up ENV values while also validating they are correct.
|
||||||
|
|
||||||
|
This was first created by @JPDucky on GitHub.
|
||||||
|
|
||||||
|
### Running the tool
|
||||||
|
|
||||||
|
```
|
||||||
|
// Running from the root of the project
|
||||||
|
./setup.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
// Running from the cli directory
|
||||||
|
cd cli/
|
||||||
|
npm run start
|
||||||
|
```
|
||||||
|
|
||||||
|
### Updating ENV values
|
||||||
|
|
||||||
|
To update ENV values:
|
||||||
|
|
||||||
|
- Add a question to the list of questions in `index.js` for the ENV value
|
||||||
|
- Add a value in the `envDefinition` for the ENV value
|
||||||
|
- Add the ENV value to the `.env.example` in the root of the project
|
2865
cli/package-lock.json
generated
Normal file
2865
cli/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
32
cli/package.json
Normal file
32
cli/package.json
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
{
|
||||||
|
"name": "agentgpt-cli",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"description": "A CLI to create your AgentGPT environment",
|
||||||
|
"private": true,
|
||||||
|
"engines": {
|
||||||
|
"node": ">=18.0.0 <19.0.0"
|
||||||
|
},
|
||||||
|
"type": "module",
|
||||||
|
"main": "index.js",
|
||||||
|
"scripts": {
|
||||||
|
"start": "node src/index.js",
|
||||||
|
"dev": "node src/index.js"
|
||||||
|
},
|
||||||
|
"author": "reworkd",
|
||||||
|
"dependencies": {
|
||||||
|
"@octokit/auth-basic": "^1.4.8",
|
||||||
|
"@octokit/rest": "^20.0.2",
|
||||||
|
"chalk": "^5.3.0",
|
||||||
|
"clear": "^0.1.0",
|
||||||
|
"clui": "^0.3.6",
|
||||||
|
"configstore": "^6.0.0",
|
||||||
|
"dotenv": "^16.3.1",
|
||||||
|
"figlet": "^1.7.0",
|
||||||
|
"inquirer": "^9.2.12",
|
||||||
|
"lodash": "^4.17.21",
|
||||||
|
"minimist": "^1.2.8",
|
||||||
|
"node-fetch": "^3.3.2",
|
||||||
|
"simple-git": "^3.20.0",
|
||||||
|
"touch": "^3.1.0"
|
||||||
|
}
|
||||||
|
}
|
142
cli/src/envGenerator.js
Normal file
142
cli/src/envGenerator.js
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
import crypto from "crypto";
|
||||||
|
import fs from "fs";
|
||||||
|
import chalk from "chalk";
|
||||||
|
|
||||||
|
export const generateEnv = (envValues) => {
|
||||||
|
let isDockerCompose = envValues.runOption === "docker-compose";
|
||||||
|
let dbPort = isDockerCompose ? 3307 : 3306;
|
||||||
|
let platformUrl = isDockerCompose
|
||||||
|
? "http://host.docker.internal:8000"
|
||||||
|
: "http://localhost:8000";
|
||||||
|
|
||||||
|
const envDefinition = getEnvDefinition(
|
||||||
|
envValues,
|
||||||
|
isDockerCompose,
|
||||||
|
dbPort,
|
||||||
|
platformUrl
|
||||||
|
);
|
||||||
|
|
||||||
|
const envFileContent = generateEnvFileContent(envDefinition);
|
||||||
|
saveEnvFile(envFileContent);
|
||||||
|
};
|
||||||
|
|
||||||
|
const getEnvDefinition = (envValues, isDockerCompose, dbPort, platformUrl) => {
|
||||||
|
return {
|
||||||
|
"Deployment Environment": {
|
||||||
|
NODE_ENV: "development",
|
||||||
|
NEXT_PUBLIC_VERCEL_ENV: "${NODE_ENV}",
|
||||||
|
},
|
||||||
|
NextJS: {
|
||||||
|
NEXT_PUBLIC_BACKEND_URL: "http://localhost:8000",
|
||||||
|
NEXT_PUBLIC_MAX_LOOPS: 100,
|
||||||
|
},
|
||||||
|
"Next Auth config": {
|
||||||
|
NEXTAUTH_SECRET: generateAuthSecret(),
|
||||||
|
NEXTAUTH_URL: "http://localhost:3000",
|
||||||
|
},
|
||||||
|
"Auth providers (Use if you want to get out of development mode sign-in)": {
|
||||||
|
GOOGLE_CLIENT_ID: "***",
|
||||||
|
GOOGLE_CLIENT_SECRET: "***",
|
||||||
|
GITHUB_CLIENT_ID: "***",
|
||||||
|
GITHUB_CLIENT_SECRET: "***",
|
||||||
|
DISCORD_CLIENT_SECRET: "***",
|
||||||
|
DISCORD_CLIENT_ID: "***",
|
||||||
|
},
|
||||||
|
Backend: {
|
||||||
|
REWORKD_PLATFORM_ENVIRONMENT: "${NODE_ENV}",
|
||||||
|
REWORKD_PLATFORM_FF_MOCK_MODE_ENABLED: false,
|
||||||
|
REWORKD_PLATFORM_MAX_LOOPS: "${NEXT_PUBLIC_MAX_LOOPS}",
|
||||||
|
REWORKD_PLATFORM_OPENAI_API_KEY:
|
||||||
|
envValues.OpenAIApiKey || '"<change me>"',
|
||||||
|
REWORKD_PLATFORM_FRONTEND_URL: "http://localhost:3000",
|
||||||
|
REWORKD_PLATFORM_RELOAD: true,
|
||||||
|
REWORKD_PLATFORM_OPENAI_API_BASE: "https://api.openai.com/v1",
|
||||||
|
REWORKD_PLATFORM_SERP_API_KEY: envValues.serpApiKey || '""',
|
||||||
|
REWORKD_PLATFORM_REPLICATE_API_KEY: envValues.replicateApiKey || '""',
|
||||||
|
},
|
||||||
|
"Database (Backend)": {
|
||||||
|
REWORKD_PLATFORM_DATABASE_USER: "reworkd_platform",
|
||||||
|
REWORKD_PLATFORM_DATABASE_PASSWORD: "reworkd_platform",
|
||||||
|
REWORKD_PLATFORM_DATABASE_HOST: "agentgpt_db",
|
||||||
|
REWORKD_PLATFORM_DATABASE_PORT: dbPort,
|
||||||
|
REWORKD_PLATFORM_DATABASE_NAME: "reworkd_platform",
|
||||||
|
REWORKD_PLATFORM_DATABASE_URL:
|
||||||
|
"mysql://${REWORKD_PLATFORM_DATABASE_USER}:${REWORKD_PLATFORM_DATABASE_PASSWORD}@${REWORKD_PLATFORM_DATABASE_HOST}:${REWORKD_PLATFORM_DATABASE_PORT}/${REWORKD_PLATFORM_DATABASE_NAME}",
|
||||||
|
},
|
||||||
|
"Database (Frontend)": {
|
||||||
|
DATABASE_USER: "reworkd_platform",
|
||||||
|
DATABASE_PASSWORD: "reworkd_platform",
|
||||||
|
DATABASE_HOST: "agentgpt_db",
|
||||||
|
DATABASE_PORT: dbPort,
|
||||||
|
DATABASE_NAME: "reworkd_platform",
|
||||||
|
DATABASE_URL:
|
||||||
|
"mysql://${DATABASE_USER}:${DATABASE_PASSWORD}@${DATABASE_HOST}:${DATABASE_PORT}/${DATABASE_NAME}",
|
||||||
|
},
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const generateEnvFileContent = (config) => {
|
||||||
|
let configFile = "";
|
||||||
|
|
||||||
|
Object.entries(config).forEach(([section, variables]) => {
|
||||||
|
configFile += `# ${section}:\n`;
|
||||||
|
Object.entries(variables).forEach(([key, value]) => {
|
||||||
|
configFile += `${key}=${value}\n`;
|
||||||
|
});
|
||||||
|
configFile += "\n";
|
||||||
|
});
|
||||||
|
|
||||||
|
return configFile.trim();
|
||||||
|
};
|
||||||
|
|
||||||
|
const generateAuthSecret = () => {
|
||||||
|
const length = 32;
|
||||||
|
const buffer = crypto.randomBytes(length);
|
||||||
|
return buffer.toString("base64");
|
||||||
|
};
|
||||||
|
|
||||||
|
const ENV_PATH = "../next/.env";
|
||||||
|
const BACKEND_ENV_PATH = "../platform/.env";
|
||||||
|
|
||||||
|
export const doesEnvFileExist = () => {
|
||||||
|
return fs.existsSync(ENV_PATH);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Read the existing env file, test if it is missing any keys or contains any extra keys
|
||||||
|
export const testEnvFile = () => {
|
||||||
|
const data = fs.readFileSync(ENV_PATH, "utf8");
|
||||||
|
|
||||||
|
// Make a fake definition to compare the keys of
|
||||||
|
const envDefinition = getEnvDefinition({}, "", "", "", "");
|
||||||
|
|
||||||
|
const lines = data
|
||||||
|
.split("\n")
|
||||||
|
.filter((line) => !line.startsWith("#") && line.trim() !== "");
|
||||||
|
const envKeysFromFile = lines.map((line) => line.split("=")[0]);
|
||||||
|
|
||||||
|
const envKeysFromDef = Object.entries(envDefinition).flatMap(
|
||||||
|
([section, entries]) => Object.keys(entries)
|
||||||
|
);
|
||||||
|
|
||||||
|
const missingFromFile = envKeysFromDef.filter(
|
||||||
|
(key) => !envKeysFromFile.includes(key)
|
||||||
|
);
|
||||||
|
|
||||||
|
if (missingFromFile.length > 0) {
|
||||||
|
let errorMessage = "\nYour ./next/.env is missing the following keys:\n";
|
||||||
|
missingFromFile.forEach((key) => {
|
||||||
|
errorMessage += chalk.whiteBright(`- ❌ ${key}\n`);
|
||||||
|
});
|
||||||
|
errorMessage += "\n";
|
||||||
|
|
||||||
|
errorMessage += chalk.red(
|
||||||
|
"We recommend deleting your .env file(s) and restarting this script."
|
||||||
|
);
|
||||||
|
throw new Error(errorMessage);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
export const saveEnvFile = (envFileContent) => {
|
||||||
|
fs.writeFileSync(ENV_PATH, envFileContent);
|
||||||
|
fs.writeFileSync(BACKEND_ENV_PATH, envFileContent);
|
||||||
|
};
|
26
cli/src/helpers.js
Normal file
26
cli/src/helpers.js
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import chalk from "chalk";
|
||||||
|
import figlet from "figlet";
|
||||||
|
|
||||||
|
export const printTitle = () => {
|
||||||
|
console.log(
|
||||||
|
chalk.red(
|
||||||
|
figlet.textSync("AgentGPT", {
|
||||||
|
horizontalLayout: "full",
|
||||||
|
font: "ANSI Shadow",
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
console.log(
|
||||||
|
"Welcome to the AgentGPT CLI! This CLI will generate the required .env files."
|
||||||
|
);
|
||||||
|
console.log(
|
||||||
|
"Copies of the generated envs will be created in `./next/.env` and `./platform/.env`.\n"
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Function to check if entered api key is in the correct format or empty
|
||||||
|
export const isValidKey = (apikey, pattern) => {
|
||||||
|
return (apikey === "" || pattern.test(apikey))
|
||||||
|
};
|
||||||
|
|
||||||
|
export const validKeyErrorMessage = "\nInvalid api key. Please try again."
|
60
cli/src/index.js
Normal file
60
cli/src/index.js
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
import inquirer from "inquirer";
|
||||||
|
import dotenv from "dotenv";
|
||||||
|
import { printTitle } from "./helpers.js";
|
||||||
|
import { doesEnvFileExist, generateEnv, testEnvFile } from "./envGenerator.js";
|
||||||
|
import { newEnvQuestions } from "./questions/newEnvQuestions.js";
|
||||||
|
import { existingEnvQuestions } from "./questions/existingEnvQuestions.js";
|
||||||
|
import { spawn } from "child_process";
|
||||||
|
import chalk from "chalk";
|
||||||
|
|
||||||
|
const handleExistingEnv = () => {
|
||||||
|
console.log(chalk.yellow("Existing ./next/env file found. Validating..."));
|
||||||
|
|
||||||
|
try {
|
||||||
|
testEnvFile();
|
||||||
|
} catch (e) {
|
||||||
|
console.log(e.message);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
inquirer.prompt(existingEnvQuestions).then((answers) => {
|
||||||
|
handleRunOption(answers.runOption);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleNewEnv = () => {
|
||||||
|
inquirer.prompt(newEnvQuestions).then((answers) => {
|
||||||
|
dotenv.config({ path: "./.env" });
|
||||||
|
generateEnv(answers);
|
||||||
|
console.log("\nEnv files successfully created!");
|
||||||
|
handleRunOption(answers.runOption);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleRunOption = (runOption) => {
|
||||||
|
if (runOption === "docker-compose") {
|
||||||
|
const dockerComposeUp = spawn("docker-compose", ["up", "--build"], {
|
||||||
|
stdio: "inherit",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (runOption === "manual") {
|
||||||
|
console.log(
|
||||||
|
"Please go into the ./next folder and run `npm install && npm run dev`."
|
||||||
|
);
|
||||||
|
console.log(
|
||||||
|
"Please also go into the ./platform folder and run `poetry install && poetry run python -m reworkd_platform`."
|
||||||
|
);
|
||||||
|
console.log(
|
||||||
|
"Please use or update the MySQL database configuration in the env file(s)."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
printTitle();
|
||||||
|
|
||||||
|
if (doesEnvFileExist()) {
|
||||||
|
handleExistingEnv();
|
||||||
|
} else {
|
||||||
|
handleNewEnv();
|
||||||
|
}
|
5
cli/src/questions/existingEnvQuestions.js
Normal file
5
cli/src/questions/existingEnvQuestions.js
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
import { RUN_OPTION_QUESTION } from "./sharedQuestions.js";
|
||||||
|
|
||||||
|
export const existingEnvQuestions = [
|
||||||
|
RUN_OPTION_QUESTION
|
||||||
|
];
|
87
cli/src/questions/newEnvQuestions.js
Normal file
87
cli/src/questions/newEnvQuestions.js
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
import { isValidKey, validKeyErrorMessage } from "../helpers.js";
|
||||||
|
import { RUN_OPTION_QUESTION } from "./sharedQuestions.js";
|
||||||
|
import fetch from "node-fetch";
|
||||||
|
|
||||||
|
export const newEnvQuestions = [
|
||||||
|
RUN_OPTION_QUESTION,
|
||||||
|
{
|
||||||
|
type: "input",
|
||||||
|
name: "OpenAIApiKey",
|
||||||
|
message:
|
||||||
|
"Enter your openai key (eg: sk...) or press enter to continue with no key:",
|
||||||
|
validate: async(apikey) => {
|
||||||
|
if(apikey === "") return true;
|
||||||
|
|
||||||
|
if(!isValidKey(apikey, /^sk(-proj)?-[a-zA-Z0-9\-\_]+$/)) {
|
||||||
|
return validKeyErrorMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
const endpoint = "https://api.openai.com/v1/models"
|
||||||
|
const response = await fetch(endpoint, {
|
||||||
|
headers: {
|
||||||
|
"Authorization": `Bearer ${apikey}`,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
if(!response.ok) {
|
||||||
|
return validKeyErrorMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: "input",
|
||||||
|
name: "serpApiKey",
|
||||||
|
message:
|
||||||
|
"What is your SERP API key (https://serper.dev/)? Leave empty to disable web search.",
|
||||||
|
validate: async(apikey) => {
|
||||||
|
if(apikey === "") return true;
|
||||||
|
|
||||||
|
if(!isValidKey(apikey, /^[a-zA-Z0-9]{40}$/)) {
|
||||||
|
return validKeyErrorMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
const endpoint = "https://google.serper.dev/search"
|
||||||
|
const response = await fetch(endpoint, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
"X-API-KEY": apikey,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
"q": "apple inc"
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
if(!response.ok) {
|
||||||
|
return validKeyErrorMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: "input",
|
||||||
|
name: "replicateApiKey",
|
||||||
|
message:
|
||||||
|
"What is your Replicate API key (https://replicate.com/)? Leave empty to just use DALL-E for image generation.",
|
||||||
|
validate: async(apikey) => {
|
||||||
|
if(apikey === "") return true;
|
||||||
|
|
||||||
|
if(!isValidKey(apikey, /^r8_[a-zA-Z0-9]{37}$/)) {
|
||||||
|
return validKeyErrorMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
const endpoint = "https://api.replicate.com/v1/models/replicate/hello-world"
|
||||||
|
const response = await fetch(endpoint, {
|
||||||
|
headers: {
|
||||||
|
"Authorization": `Token ${apikey}`,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
if(!response.ok) {
|
||||||
|
return validKeyErrorMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
},
|
||||||
|
];
|
10
cli/src/questions/sharedQuestions.js
Normal file
10
cli/src/questions/sharedQuestions.js
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
export const RUN_OPTION_QUESTION = {
|
||||||
|
type: 'list',
|
||||||
|
name: 'runOption',
|
||||||
|
choices: [
|
||||||
|
{ value: "docker-compose", name: "🐋 Docker-compose (Recommended)" },
|
||||||
|
{ value: "manual", name: "💪 Manual (Not recommended)" },
|
||||||
|
],
|
||||||
|
message: 'How will you be running AgentGPT?',
|
||||||
|
default: "docker-compose",
|
||||||
|
}
|
115
cli/tsconfig.json
Normal file
115
cli/tsconfig.json
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
/* Visit https://aka.ms/tsconfig to read more about this file */
|
||||||
|
|
||||||
|
/* Projects */
|
||||||
|
// "incremental": true, /* Save .tsbuildinfo files to allow for incremental compilation of projects. */
|
||||||
|
// "composite": true, /* Enable constraints that allow a TypeScript project to be used with project references. */
|
||||||
|
// "tsBuildInfoFile": "./.tsbuildinfo", /* Specify the path to .tsbuildinfo incremental compilation file. */
|
||||||
|
// "disableSourceOfProjectReferenceRedirect": true, /* Disable preferring source files instead of declaration files when referencing composite projects. */
|
||||||
|
// "disableSolutionSearching": true, /* Opt a project out of multi-project reference checking when editing. */
|
||||||
|
// "disableReferencedProjectLoad": true, /* Reduce the number of projects loaded automatically by TypeScript. */
|
||||||
|
|
||||||
|
/* Language and Environment */
|
||||||
|
"target": "es2016",
|
||||||
|
/* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */
|
||||||
|
// "src": [], /* Specify a set of bundled library declaration files that describe the target runtime environment. */
|
||||||
|
// "jsx": "preserve", /* Specify what JSX code is generated. */
|
||||||
|
// "experimentalDecorators": true, /* Enable experimental support for legacy experimental decorators. */
|
||||||
|
// "emitDecoratorMetadata": true, /* Emit design-type metadata for decorated declarations in source files. */
|
||||||
|
// "jsxFactory": "", /* Specify the JSX factory function used when targeting React JSX emit, e.g. 'React.createElement' or 'h'. */
|
||||||
|
// "jsxFragmentFactory": "", /* Specify the JSX Fragment reference used for fragments when targeting React JSX emit e.g. 'React.Fragment' or 'Fragment'. */
|
||||||
|
// "jsxImportSource": "", /* Specify module specifier used to import the JSX factory functions when using 'jsx: react-jsx*'. */
|
||||||
|
// "reactNamespace": "", /* Specify the object invoked for 'createElement'. This only applies when targeting 'react' JSX emit. */
|
||||||
|
// "noLib": true, /* Disable including any library files, including the default src.d.ts. */
|
||||||
|
// "useDefineForClassFields": true, /* Emit ECMAScript-standard-compliant class fields. */
|
||||||
|
// "moduleDetection": "auto", /* Control what method is used to detect module-format JS files. */
|
||||||
|
|
||||||
|
/* Modules */
|
||||||
|
"module": "commonjs",
|
||||||
|
/* Specify what module code is generated. */
|
||||||
|
// "rootDir": "./", /* Specify the root folder within your source files. */
|
||||||
|
// "moduleResolution": "node10", /* Specify how TypeScript looks up a file from a given module specifier. */
|
||||||
|
// "baseUrl": "./", /* Specify the base directory to resolve non-relative module names. */
|
||||||
|
// "paths": {}, /* Specify a set of entries that re-map imports to additional lookup locations. */
|
||||||
|
// "rootDirs": [], /* Allow multiple folders to be treated as one when resolving modules. */
|
||||||
|
// "typeRoots": [], /* Specify multiple folders that act like './node_modules/@types'. */
|
||||||
|
// "types": [], /* Specify type package names to be included without being referenced in a source file. */
|
||||||
|
// "allowUmdGlobalAccess": true, /* Allow accessing UMD globals from modules. */
|
||||||
|
// "moduleSuffixes": [], /* List of file name suffixes to search when resolving a module. */
|
||||||
|
// "allowImportingTsExtensions": true, /* Allow imports to include TypeScript file extensions. Requires '--moduleResolution bundler' and either '--noEmit' or '--emitDeclarationOnly' to be set. */
|
||||||
|
// "resolvePackageJsonExports": true, /* Use the package.json 'exports' field when resolving package imports. */
|
||||||
|
// "resolvePackageJsonImports": true, /* Use the package.json 'imports' field when resolving imports. */
|
||||||
|
// "customConditions": [], /* Conditions to set in addition to the resolver-specific defaults when resolving imports. */
|
||||||
|
// "resolveJsonModule": true, /* Enable importing .json files. */
|
||||||
|
// "allowArbitraryExtensions": true, /* Enable importing files with any extension, provided a declaration file is present. */
|
||||||
|
// "noResolve": true, /* Disallow 'import's, 'require's or '<reference>'s from expanding the number of files TypeScript should add to a project. */
|
||||||
|
|
||||||
|
/* JavaScript Support */
|
||||||
|
// "allowJs": true, /* Allow JavaScript files to be a part of your program. Use the 'checkJS' option to get errors from these files. */
|
||||||
|
// "checkJs": true, /* Enable error reporting in type-checked JavaScript files. */
|
||||||
|
// "maxNodeModuleJsDepth": 1, /* Specify the maximum folder depth used for checking JavaScript files from 'node_modules'. Only applicable with 'allowJs'. */
|
||||||
|
|
||||||
|
/* Emit */
|
||||||
|
// "declaration": true, /* Generate .d.ts files from TypeScript and JavaScript files in your project. */
|
||||||
|
// "declarationMap": true, /* Create sourcemaps for d.ts files. */
|
||||||
|
// "emitDeclarationOnly": true, /* Only output d.ts files and not JavaScript files. */
|
||||||
|
// "sourceMap": true, /* Create source map files for emitted JavaScript files. */
|
||||||
|
// "inlineSourceMap": true, /* Include sourcemap files inside the emitted JavaScript. */
|
||||||
|
// "outFile": "./", /* Specify a file that bundles all outputs into one JavaScript file. If 'declaration' is true, also designates a file that bundles all .d.ts output. */
|
||||||
|
// "outDir": "./", /* Specify an output folder for all emitted files. */
|
||||||
|
// "removeComments": true, /* Disable emitting comments. */
|
||||||
|
// "noEmit": true, /* Disable emitting files from a compilation. */
|
||||||
|
// "importHelpers": true, /* Allow importing helper functions from tslib once per project, instead of including them per-file. */
|
||||||
|
// "importsNotUsedAsValues": "remove", /* Specify emit/checking behavior for imports that are only used for types. */
|
||||||
|
// "downlevelIteration": true, /* Emit more compliant, but verbose and less performant JavaScript for iteration. */
|
||||||
|
// "sourceRoot": "", /* Specify the root path for debuggers to find the reference source code. */
|
||||||
|
// "mapRoot": "", /* Specify the location where debugger should locate map files instead of generated locations. */
|
||||||
|
// "inlineSources": true, /* Include source code in the sourcemaps inside the emitted JavaScript. */
|
||||||
|
// "emitBOM": true, /* Emit a UTF-8 Byte Order Mark (BOM) in the beginning of output files. */
|
||||||
|
// "newLine": "crlf", /* Set the newline character for emitting files. */
|
||||||
|
// "stripInternal": true, /* Disable emitting declarations that have '@internal' in their JSDoc comments. */
|
||||||
|
// "noEmitHelpers": true, /* Disable generating custom helper functions like '__extends' in compiled output. */
|
||||||
|
// "noEmitOnError": true, /* Disable emitting files if any type checking errors are reported. */
|
||||||
|
// "preserveConstEnums": true, /* Disable erasing 'const enum' declarations in generated code. */
|
||||||
|
// "declarationDir": "./", /* Specify the output directory for generated declaration files. */
|
||||||
|
// "preserveValueImports": true, /* Preserve unused imported values in the JavaScript output that would otherwise be removed. */
|
||||||
|
|
||||||
|
/* Interop Constraints */
|
||||||
|
// "isolatedModules": true, /* Ensure that each file can be safely transpiled without relying on other imports. */
|
||||||
|
// "verbatimModuleSyntax": true, /* Do not transform or elide any imports or exports not marked as type-only, ensuring they are written in the output file's format based on the 'module' setting. */
|
||||||
|
// "allowSyntheticDefaultImports": true, /* Allow 'import x from y' when a module doesn't have a default export. */
|
||||||
|
"esModuleInterop": true,
|
||||||
|
/* Emit additional JavaScript to ease support for importing CommonJS modules. This enables 'allowSyntheticDefaultImports' for type compatibility. */
|
||||||
|
// "preserveSymlinks": true, /* Disable resolving symlinks to their realpath. This correlates to the same flag in node. */
|
||||||
|
"forceConsistentCasingInFileNames": true,
|
||||||
|
/* Ensure that casing is correct in imports. */
|
||||||
|
|
||||||
|
/* Type Checking */
|
||||||
|
"strict": true,
|
||||||
|
/* Enable all strict type-checking options. */
|
||||||
|
// "noImplicitAny": true, /* Enable error reporting for expressions and declarations with an implied 'any' type. */
|
||||||
|
// "strictNullChecks": true, /* When type checking, take into account 'null' and 'undefined'. */
|
||||||
|
// "strictFunctionTypes": true, /* When assigning functions, check to ensure parameters and the return values are subtype-compatible. */
|
||||||
|
// "strictBindCallApply": true, /* Check that the arguments for 'bind', 'call', and 'apply' methods match the original function. */
|
||||||
|
// "strictPropertyInitialization": true, /* Check for class properties that are declared but not set in the constructor. */
|
||||||
|
// "noImplicitThis": true, /* Enable error reporting when 'this' is given the type 'any'. */
|
||||||
|
// "useUnknownInCatchVariables": true, /* Default catch clause variables as 'unknown' instead of 'any'. */
|
||||||
|
// "alwaysStrict": true, /* Ensure 'use strict' is always emitted. */
|
||||||
|
// "noUnusedLocals": true, /* Enable error reporting when local variables aren't read. */
|
||||||
|
// "noUnusedParameters": true, /* Raise an error when a function parameter isn't read. */
|
||||||
|
// "exactOptionalPropertyTypes": true, /* Interpret optional property types as written, rather than adding 'undefined'. */
|
||||||
|
// "noImplicitReturns": true, /* Enable error reporting for codepaths that do not explicitly return in a function. */
|
||||||
|
// "noFallthroughCasesInSwitch": true, /* Enable error reporting for fallthrough cases in switch statements. */
|
||||||
|
// "noUncheckedIndexedAccess": true, /* Add 'undefined' to a type when accessed using an index. */
|
||||||
|
// "noImplicitOverride": true, /* Ensure overriding members in derived classes are marked with an override modifier. */
|
||||||
|
// "noPropertyAccessFromIndexSignature": true, /* Enforces using indexed accessors for keys declared using an indexed type. */
|
||||||
|
// "allowUnusedLabels": true, /* Disable error reporting for unused labels. */
|
||||||
|
// "allowUnreachableCode": true, /* Disable error reporting for unreachable code. */
|
||||||
|
|
||||||
|
/* Completeness */
|
||||||
|
// "skipDefaultLibCheck": true, /* Skip type checking .d.ts files that are included with TypeScript. */
|
||||||
|
"skipLibCheck": true
|
||||||
|
/* Skip type checking all .d.ts files. */
|
||||||
|
}
|
||||||
|
}
|
3
db/Dockerfile
Normal file
3
db/Dockerfile
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
FROM mysql:8.0
|
||||||
|
|
||||||
|
ADD setup.sql /docker-entrypoint-initdb.d
|
11
db/setup.sql
Normal file
11
db/setup.sql
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
-- Prisma requires DB creation privileges to create a shadow database (https://pris.ly/d/migrate-shadow)
|
||||||
|
-- This is not available to our user by default, so we must manually add this
|
||||||
|
|
||||||
|
-- Create the user
|
||||||
|
CREATE USER IF NOT EXISTS 'reworkd_platform'@'%' IDENTIFIED BY 'reworkd_platform';
|
||||||
|
|
||||||
|
-- Grant the necessary permissions
|
||||||
|
GRANT CREATE, ALTER, DROP, INSERT, UPDATE, DELETE, SELECT ON *.* TO 'reworkd_platform'@'%';
|
||||||
|
|
||||||
|
-- Apply the changes
|
||||||
|
FLUSH PRIVILEGES;
|
145
platform/.dockerignore
Normal file
145
platform/.dockerignore
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
### Python template
|
||||||
|
|
||||||
|
deploy/
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
.git/
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
poetry.lock
|
31
platform/.editorconfig
Normal file
31
platform/.editorconfig
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
root = true
|
||||||
|
|
||||||
|
[*]
|
||||||
|
tab_width = 4
|
||||||
|
end_of_line = lf
|
||||||
|
max_line_length = 88
|
||||||
|
ij_visual_guides = 88
|
||||||
|
insert_final_newline = true
|
||||||
|
trim_trailing_whitespace = true
|
||||||
|
|
||||||
|
[*.{js,py,html}]
|
||||||
|
charset = utf-8
|
||||||
|
|
||||||
|
[*.md]
|
||||||
|
trim_trailing_whitespace = false
|
||||||
|
|
||||||
|
[*.{yml,yaml}]
|
||||||
|
indent_style = space
|
||||||
|
indent_size = 2
|
||||||
|
|
||||||
|
[Makefile]
|
||||||
|
indent_style = tab
|
||||||
|
|
||||||
|
[.flake8]
|
||||||
|
indent_style = space
|
||||||
|
indent_size = 2
|
||||||
|
|
||||||
|
[*.py]
|
||||||
|
indent_style = space
|
||||||
|
indent_size = 4
|
||||||
|
ij_python_from_import_parentheses_force_if_multiline = true
|
115
platform/.flake8
Normal file
115
platform/.flake8
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
[flake8]
|
||||||
|
max-complexity = 6
|
||||||
|
inline-quotes = double
|
||||||
|
max-line-length = 88
|
||||||
|
extend-ignore = E203
|
||||||
|
docstring_style=sphinx
|
||||||
|
|
||||||
|
ignore =
|
||||||
|
; Found `f` string
|
||||||
|
WPS305,
|
||||||
|
; Missing docstring in public module
|
||||||
|
D100,
|
||||||
|
; Missing docstring in magic method
|
||||||
|
D105,
|
||||||
|
; Missing docstring in __init__
|
||||||
|
D107,
|
||||||
|
; Found `__init__.py` module with logic
|
||||||
|
WPS412,
|
||||||
|
; Found class without a base class
|
||||||
|
WPS306,
|
||||||
|
; Missing docstring in public nested class
|
||||||
|
D106,
|
||||||
|
; First line should be in imperative mood
|
||||||
|
D401,
|
||||||
|
; Found wrong variable name
|
||||||
|
WPS110,
|
||||||
|
; Found `__init__.py` module with logic
|
||||||
|
WPS326,
|
||||||
|
; Found string constant over-use
|
||||||
|
WPS226,
|
||||||
|
; Found upper-case constant in a class
|
||||||
|
WPS115,
|
||||||
|
; Found nested function
|
||||||
|
WPS602,
|
||||||
|
; Found method without arguments
|
||||||
|
WPS605,
|
||||||
|
; Found overused expression
|
||||||
|
WPS204,
|
||||||
|
; Found too many module members
|
||||||
|
WPS202,
|
||||||
|
; Found too high module cognitive complexity
|
||||||
|
WPS232,
|
||||||
|
; line break before binary operator
|
||||||
|
W503,
|
||||||
|
; Found module with too many imports
|
||||||
|
WPS201,
|
||||||
|
; Inline strong start-string without end-string.
|
||||||
|
RST210,
|
||||||
|
; Found nested class
|
||||||
|
WPS431,
|
||||||
|
; Found wrong module name
|
||||||
|
WPS100,
|
||||||
|
; Found too many methods
|
||||||
|
WPS214,
|
||||||
|
; Found too long ``try`` body
|
||||||
|
WPS229,
|
||||||
|
; Found unpythonic getter or setter
|
||||||
|
WPS615,
|
||||||
|
; Found a line that starts with a dot
|
||||||
|
WPS348,
|
||||||
|
; Found complex default value (for dependency injection)
|
||||||
|
WPS404,
|
||||||
|
; not perform function calls in argument defaults (for dependency injection)
|
||||||
|
B008,
|
||||||
|
; Model should define verbose_name in its Meta inner class
|
||||||
|
DJ10,
|
||||||
|
; Model should define verbose_name_plural in its Meta inner class
|
||||||
|
DJ11,
|
||||||
|
; Found mutable module constant.
|
||||||
|
WPS407,
|
||||||
|
; Found too many empty lines in `def`
|
||||||
|
WPS473,
|
||||||
|
; Found missing trailing comma
|
||||||
|
C812,
|
||||||
|
|
||||||
|
per-file-ignores =
|
||||||
|
; all tests
|
||||||
|
test_*.py,tests.py,tests_*.py,*/tests/*,conftest.py:
|
||||||
|
; Use of assert detected
|
||||||
|
S101,
|
||||||
|
; Found outer scope names shadowing
|
||||||
|
WPS442,
|
||||||
|
; Found too many local variables
|
||||||
|
WPS210,
|
||||||
|
; Found magic number
|
||||||
|
WPS432,
|
||||||
|
; Missing parameter(s) in Docstring
|
||||||
|
DAR101,
|
||||||
|
; Found too many arguments
|
||||||
|
WPS211,
|
||||||
|
|
||||||
|
; all init files
|
||||||
|
__init__.py:
|
||||||
|
; ignore not used imports
|
||||||
|
F401,
|
||||||
|
; ignore import with wildcard
|
||||||
|
F403,
|
||||||
|
; Found wrong metadata variable
|
||||||
|
WPS410,
|
||||||
|
|
||||||
|
exclude =
|
||||||
|
./.cache,
|
||||||
|
./.git,
|
||||||
|
./.idea,
|
||||||
|
./.mypy_cache,
|
||||||
|
./.pytest_cache,
|
||||||
|
./.venv,
|
||||||
|
./venv,
|
||||||
|
./env,
|
||||||
|
./cached_venv,
|
||||||
|
./docs,
|
||||||
|
./deploy,
|
||||||
|
./var,
|
||||||
|
./.vscode,
|
||||||
|
*migrations*,
|
144
platform/.gitignore
vendored
Normal file
144
platform/.gitignore
vendored
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
### Python template
|
||||||
|
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.pem
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
ssl/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
*.sqlite3
|
||||||
|
*.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
.env
|
63
platform/.pre-commit-config.yaml
Normal file
63
platform/.pre-commit-config.yaml
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
---
|
||||||
|
# See https://pre-commit.com for more information
|
||||||
|
# See https://pre-commit.com/hooks.html for more hooks
|
||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v2.4.0
|
||||||
|
hooks:
|
||||||
|
- id: check-ast
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- id: check-toml
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
|
||||||
|
- repo: https://github.com/asottile/add-trailing-comma
|
||||||
|
rev: v2.1.0
|
||||||
|
hooks:
|
||||||
|
- id: add-trailing-comma
|
||||||
|
|
||||||
|
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
|
||||||
|
rev: v2.1.0
|
||||||
|
hooks:
|
||||||
|
- id: pretty-format-yaml
|
||||||
|
args:
|
||||||
|
- --autofix
|
||||||
|
- --preserve-quotes
|
||||||
|
- --indent=2
|
||||||
|
|
||||||
|
- repo: local
|
||||||
|
hooks:
|
||||||
|
- id: black
|
||||||
|
name: Format with Black
|
||||||
|
entry: poetry run black
|
||||||
|
language: system
|
||||||
|
types: [python]
|
||||||
|
|
||||||
|
- id: autoflake
|
||||||
|
name: autoflake
|
||||||
|
entry: poetry run autoflake
|
||||||
|
language: system
|
||||||
|
types: [python]
|
||||||
|
args: [--in-place, --remove-all-unused-imports, --remove-duplicate-keys]
|
||||||
|
|
||||||
|
- id: isort
|
||||||
|
name: isort
|
||||||
|
entry: poetry run isort
|
||||||
|
language: system
|
||||||
|
types: [python]
|
||||||
|
|
||||||
|
- id: flake8
|
||||||
|
name: Check with Flake8
|
||||||
|
entry: poetry run flake8
|
||||||
|
language: system
|
||||||
|
pass_filenames: false
|
||||||
|
types: [python]
|
||||||
|
args: [--count, .]
|
||||||
|
|
||||||
|
- id: mypy
|
||||||
|
name: Validate types with MyPy
|
||||||
|
entry: poetry run mypy
|
||||||
|
language: system
|
||||||
|
types: [python]
|
||||||
|
pass_filenames: false
|
||||||
|
args:
|
||||||
|
- "reworkd_platform"
|
37
platform/Dockerfile
Normal file
37
platform/Dockerfile
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
FROM python:3.11-slim-buster as prod
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y \
|
||||||
|
default-libmysqlclient-dev \
|
||||||
|
gcc \
|
||||||
|
pkg-config \
|
||||||
|
openjdk-11-jdk \
|
||||||
|
build-essential \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
RUN pip install poetry==1.4.2
|
||||||
|
|
||||||
|
# Configuring poetry
|
||||||
|
RUN poetry config virtualenvs.create false
|
||||||
|
|
||||||
|
# Copying requirements of a project
|
||||||
|
COPY pyproject.toml /app/src/
|
||||||
|
WORKDIR /app/src
|
||||||
|
|
||||||
|
# Installing requirements
|
||||||
|
RUN poetry install --only main
|
||||||
|
# Removing gcc
|
||||||
|
RUN apt-get purge -y \
|
||||||
|
g++ \
|
||||||
|
gcc \
|
||||||
|
pkg-config \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Copying actual application
|
||||||
|
COPY . /app/src/
|
||||||
|
RUN poetry install --only main
|
||||||
|
|
||||||
|
CMD ["/usr/local/bin/python", "-m", "reworkd_platform"]
|
||||||
|
|
||||||
|
FROM prod as dev
|
||||||
|
|
||||||
|
RUN poetry install
|
151
platform/README.md
Normal file
151
platform/README.md
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
# reworkd_platform
|
||||||
|
|
||||||
|
This project was generated using fastapi_template.
|
||||||
|
|
||||||
|
## Poetry
|
||||||
|
|
||||||
|
This project uses poetry. It's a modern dependency management
|
||||||
|
tool.
|
||||||
|
|
||||||
|
To run the project use this set of commands:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
poetry install
|
||||||
|
poetry run python -m reworkd_platform
|
||||||
|
```
|
||||||
|
|
||||||
|
This will start the server on the configured host.
|
||||||
|
|
||||||
|
You can find swagger documentation at `/api/docs`.
|
||||||
|
|
||||||
|
You can read more about poetry here: https://python-poetry.org/
|
||||||
|
|
||||||
|
## Docker
|
||||||
|
|
||||||
|
You can start the project with docker using this command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker-compose -f deploy/docker-compose.yml --project-directory . up --build
|
||||||
|
```
|
||||||
|
|
||||||
|
If you want to develop in docker with autoreload add `-f deploy/docker-compose.dev.yml` to your docker command.
|
||||||
|
Like this:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker-compose -f deploy/docker-compose.yml -f deploy/docker-compose.dev.yml --project-directory . up --build
|
||||||
|
```
|
||||||
|
|
||||||
|
This command exposes the web application on port 8000, mounts current directory and enables autoreload.
|
||||||
|
|
||||||
|
But you have to rebuild image every time you modify `poetry.lock` or `pyproject.toml` with this command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker-compose -f deploy/docker-compose.yml --project-directory . build
|
||||||
|
```
|
||||||
|
|
||||||
|
## Project structure
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ tree "reworkd_platform"
|
||||||
|
reworkd_platform
|
||||||
|
├── conftest.py # Fixtures for all tests.
|
||||||
|
├── db # module contains db configurations
|
||||||
|
│ ├── dao # Data Access Objects. Contains different classes to interact with database.
|
||||||
|
│ └── models # Package contains different models for ORMs.
|
||||||
|
├── __main__.py # Startup script. Starts uvicorn.
|
||||||
|
├── services # Package for different external services such as rabbit or redis etc.
|
||||||
|
├── settings.py # Main configuration settings for project.
|
||||||
|
├── static # Static content.
|
||||||
|
├── tests # Tests for project.
|
||||||
|
└── web # Package contains web server. Handlers, startup config.
|
||||||
|
├── api # Package with all handlers.
|
||||||
|
│ └── router.py # Main router.
|
||||||
|
├── application.py # FastAPI application configuration.
|
||||||
|
└── lifetime.py # Contains actions to perform on startup and shutdown.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
This application can be configured with environment variables.
|
||||||
|
|
||||||
|
You can create `.env` file in the root directory and place all
|
||||||
|
environment variables here.
|
||||||
|
|
||||||
|
All environment variables should start with "REWORKD_PLATFORM_" prefix.
|
||||||
|
|
||||||
|
For example if you see in your "reworkd_platform/settings.py" a variable named like
|
||||||
|
`random_parameter`, you should provide the "REWORKD_PLATFORM_RANDOM_PARAMETER"
|
||||||
|
variable to configure the value. This behaviour can be changed by overriding `env_prefix` property
|
||||||
|
in `reworkd_platform.settings.Settings.Config`.
|
||||||
|
|
||||||
|
An example of .env file:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
REWORKD_PLATFORM_RELOAD="True"
|
||||||
|
REWORKD_PLATFORM_PORT="8000"
|
||||||
|
REWORKD_PLATFORM_ENVIRONMENT="development"
|
||||||
|
```
|
||||||
|
|
||||||
|
You can read more about BaseSettings class here: https://pydantic-docs.helpmanual.io/usage/settings/
|
||||||
|
|
||||||
|
## Pre-commit
|
||||||
|
|
||||||
|
To install pre-commit simply run inside the shell:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pre-commit install
|
||||||
|
```
|
||||||
|
|
||||||
|
pre-commit is very useful to check your code before publishing it.
|
||||||
|
It's configured using .pre-commit-config.yaml file.
|
||||||
|
|
||||||
|
By default it runs:
|
||||||
|
|
||||||
|
* black (formats your code);
|
||||||
|
* mypy (validates types);
|
||||||
|
* isort (sorts imports in all files);
|
||||||
|
* flake8 (spots possibe bugs);
|
||||||
|
|
||||||
|
You can read more about pre-commit here: https://pre-commit.com/
|
||||||
|
|
||||||
|
## Running tests
|
||||||
|
|
||||||
|
If you want to run it in docker, simply run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker-compose -f deploy/docker-compose.yml -f deploy/docker-compose.dev.yml --project-directory . run --build --rm api pytest -vv .
|
||||||
|
docker-compose -f deploy/docker-compose.yml -f deploy/docker-compose.dev.yml --project-directory . down
|
||||||
|
```
|
||||||
|
|
||||||
|
For running tests on your local machine.
|
||||||
|
|
||||||
|
1. you need to start a database.
|
||||||
|
|
||||||
|
I prefer doing it with docker:
|
||||||
|
|
||||||
|
```
|
||||||
|
docker run -p "3306:3306" -e "MYSQL_PASSWORD=reworkd_platform" -e "MYSQL_USER=reworkd_platform" -e "MYSQL_DATABASE=reworkd_platform" -e ALLOW_EMPTY_PASSWORD=yes bitnami/mysql:8.0.30
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Run the pytest.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pytest -vv .
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running linters
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Flake
|
||||||
|
poetry run black .
|
||||||
|
poetry run autoflake --in-place --remove-duplicate-keys --remove-all-unused-imports -r .
|
||||||
|
poetry run flake8
|
||||||
|
poetry run mypy .
|
||||||
|
|
||||||
|
# Pytest
|
||||||
|
poetry run pytest -vv --cov="reworkd_platform" .
|
||||||
|
|
||||||
|
# Bump packages
|
||||||
|
poetry self add poetry-plugin-up
|
||||||
|
poetry up --latest
|
||||||
|
```
|
14
platform/entrypoint.sh
Normal file
14
platform/entrypoint.sh
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
#!/usr/bin/env sh
|
||||||
|
|
||||||
|
host=agentgpt_db
|
||||||
|
port=3306
|
||||||
|
|
||||||
|
until echo "SELECT 1;" | nc "$host" "$port" > /dev/null 2>&1; do
|
||||||
|
>&2 echo "Database is unavailable - Sleeping..."
|
||||||
|
sleep 2
|
||||||
|
done
|
||||||
|
|
||||||
|
>&2 echo "Database is available! Continuing..."
|
||||||
|
|
||||||
|
# Run cmd
|
||||||
|
exec "$@"
|
3413
platform/poetry.lock
generated
Normal file
3413
platform/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
100
platform/pyproject.toml
Normal file
100
platform/pyproject.toml
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
[tool.poetry]
|
||||||
|
name = "reworkd_platform"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = ""
|
||||||
|
authors = [
|
||||||
|
"awtkns",
|
||||||
|
"asim-shrestha"
|
||||||
|
]
|
||||||
|
|
||||||
|
maintainers = [
|
||||||
|
"reworkd"
|
||||||
|
]
|
||||||
|
|
||||||
|
readme = "README.md"
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = "^3.11"
|
||||||
|
fastapi = "^0.98.0"
|
||||||
|
boto3 = "^1.28.51"
|
||||||
|
uvicorn = { version = "^0.22.0", extras = ["standard"] }
|
||||||
|
pydantic = { version = "<2.0", extras = ["dotenv"] }
|
||||||
|
ujson = "^5.8.0"
|
||||||
|
sqlalchemy = { version = "^2.0.21", extras = ["mypy", "asyncio"] }
|
||||||
|
aiomysql = "^0.1.1"
|
||||||
|
mysqlclient = "^2.2.0"
|
||||||
|
sentry-sdk = "^1.31.0"
|
||||||
|
loguru = "^0.7.2"
|
||||||
|
aiokafka = "^0.8.1"
|
||||||
|
requests = "^2.31.0"
|
||||||
|
langchain = "^0.0.295"
|
||||||
|
openai = "^0.28.0"
|
||||||
|
wikipedia = "^1.4.0"
|
||||||
|
replicate = "^0.8.4"
|
||||||
|
lanarky = "0.7.15"
|
||||||
|
tiktoken = "^0.5.1"
|
||||||
|
grpcio = "^1.58.0"
|
||||||
|
pinecone-client = "^2.2.4"
|
||||||
|
python-multipart = "^0.0.6"
|
||||||
|
aws-secretsmanager-caching = "^1.1.1.5"
|
||||||
|
botocore = "^1.31.51"
|
||||||
|
stripe = "^5.5.0"
|
||||||
|
cryptography = "^41.0.4"
|
||||||
|
httpx = "^0.25.0"
|
||||||
|
|
||||||
|
|
||||||
|
[tool.poetry.dev-dependencies]
|
||||||
|
autopep8 = "^2.0.4"
|
||||||
|
pytest = "^7.4.2"
|
||||||
|
flake8 = "~6.0.0"
|
||||||
|
mypy = "^1.5.1"
|
||||||
|
isort = "^5.12.0"
|
||||||
|
pre-commit = "^3.4.0"
|
||||||
|
wemake-python-styleguide = "^0.18.0"
|
||||||
|
black = "^23.9.1"
|
||||||
|
autoflake = "^2.2.1"
|
||||||
|
pytest-cov = "^4.1.0"
|
||||||
|
anyio = "^3.7.1"
|
||||||
|
pytest-env = "^0.8.2"
|
||||||
|
|
||||||
|
[tool.poetry.group.dev.dependencies]
|
||||||
|
dotmap = "^1.3.30"
|
||||||
|
pytest-mock = "^3.10.0"
|
||||||
|
pytest-asyncio = "^0.21.0"
|
||||||
|
mypy = "^1.4.1"
|
||||||
|
types-requests = "^2.31.0.1"
|
||||||
|
types-pytz = "^2023.3.0.0"
|
||||||
|
|
||||||
|
[tool.isort]
|
||||||
|
profile = "black"
|
||||||
|
multi_line_output = 3
|
||||||
|
src_paths = ["reworkd_platform"]
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
strict = true
|
||||||
|
ignore_missing_imports = true
|
||||||
|
allow_subclassing_any = true
|
||||||
|
allow_untyped_calls = true
|
||||||
|
pretty = true
|
||||||
|
show_error_codes = true
|
||||||
|
implicit_reexport = true
|
||||||
|
allow_untyped_decorators = true
|
||||||
|
warn_unused_ignores = false
|
||||||
|
warn_return_any = false
|
||||||
|
namespace_packages = true
|
||||||
|
exclude = "tests"
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
filterwarnings = [
|
||||||
|
"error",
|
||||||
|
"ignore::DeprecationWarning",
|
||||||
|
"ignore:.*unclosed.*:ResourceWarning",
|
||||||
|
"ignore::ImportWarning",
|
||||||
|
]
|
||||||
|
env = [
|
||||||
|
"REWORKD_PLATFORM_DB_BASE=reworkd_platform_test",
|
||||||
|
]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core>=1.0.0"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
1
platform/reworkd_platform/__init__.py
Normal file
1
platform/reworkd_platform/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""reworkd_platform package."""
|
20
platform/reworkd_platform/__main__.py
Normal file
20
platform/reworkd_platform/__main__.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import uvicorn
|
||||||
|
|
||||||
|
from reworkd_platform.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""Entrypoint of the application."""
|
||||||
|
uvicorn.run(
|
||||||
|
"reworkd_platform.web.application:get_app",
|
||||||
|
workers=settings.workers_count,
|
||||||
|
host=settings.host,
|
||||||
|
port=settings.port,
|
||||||
|
reload=settings.reload,
|
||||||
|
log_level="debug",
|
||||||
|
factory=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
107
platform/reworkd_platform/conftest.py
Normal file
107
platform/reworkd_platform/conftest.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from sqlalchemy.ext.asyncio import (
|
||||||
|
AsyncEngine,
|
||||||
|
AsyncSession,
|
||||||
|
async_sessionmaker,
|
||||||
|
create_async_engine,
|
||||||
|
)
|
||||||
|
|
||||||
|
from reworkd_platform.db.dependencies import get_db_session
|
||||||
|
from reworkd_platform.db.utils import create_database, drop_database
|
||||||
|
from reworkd_platform.settings import settings
|
||||||
|
from reworkd_platform.web.application import get_app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def anyio_backend() -> str:
|
||||||
|
"""
|
||||||
|
Backend for anyio pytest plugin.
|
||||||
|
|
||||||
|
:return: backend name.
|
||||||
|
"""
|
||||||
|
return "asyncio"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
async def _engine() -> AsyncGenerator[AsyncEngine, None]:
|
||||||
|
"""
|
||||||
|
Create engine and databases.
|
||||||
|
|
||||||
|
:yield: new engine.
|
||||||
|
"""
|
||||||
|
from reworkd_platform.db.meta import meta # noqa: WPS433
|
||||||
|
from reworkd_platform.db.models import load_all_models # noqa: WPS433
|
||||||
|
|
||||||
|
load_all_models()
|
||||||
|
|
||||||
|
await create_database()
|
||||||
|
|
||||||
|
engine = create_async_engine(str(settings.db_url))
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(meta.create_all)
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield engine
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
|
await drop_database()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def dbsession(
|
||||||
|
_engine: AsyncEngine,
|
||||||
|
) -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""
|
||||||
|
Get session to database.
|
||||||
|
|
||||||
|
Fixture that returns a SQLAlchemy session with a SAVEPOINT, and the rollback to it
|
||||||
|
after the test completes.
|
||||||
|
|
||||||
|
:param _engine: current engine.
|
||||||
|
:yields: async session.
|
||||||
|
"""
|
||||||
|
connection = await _engine.connect()
|
||||||
|
trans = await connection.begin()
|
||||||
|
|
||||||
|
session_maker = async_sessionmaker(
|
||||||
|
connection,
|
||||||
|
expire_on_commit=False,
|
||||||
|
)
|
||||||
|
session = session_maker()
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
await trans.rollback()
|
||||||
|
await connection.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fastapi_app(dbsession: AsyncSession) -> FastAPI:
|
||||||
|
"""
|
||||||
|
Fixture for creating FastAPI app.
|
||||||
|
|
||||||
|
:return: fastapi app with mocked dependencies.
|
||||||
|
"""
|
||||||
|
application = get_app()
|
||||||
|
application.dependency_overrides[get_db_session] = lambda: dbsession
|
||||||
|
return application # noqa: WPS331
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def client(
|
||||||
|
fastapi_app: FastAPI, anyio_backend: Any
|
||||||
|
) -> AsyncGenerator[AsyncClient, None]:
|
||||||
|
"""
|
||||||
|
Fixture that creates client for requesting server.
|
||||||
|
|
||||||
|
:param fastapi_app: the application.
|
||||||
|
:yield: client for the app.
|
||||||
|
"""
|
||||||
|
async with AsyncClient(app=fastapi_app, base_url="http://test") as ac:
|
||||||
|
yield ac
|
1
platform/reworkd_platform/constants.py
Normal file
1
platform/reworkd_platform/constants.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
ENV_PREFIX = "REWORKD_PLATFORM_"
|
0
platform/reworkd_platform/db/__init__.py
Normal file
0
platform/reworkd_platform/db/__init__.py
Normal file
68
platform/reworkd_platform/db/base.py
Normal file
68
platform/reworkd_platform/db/base.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Type, TypeVar
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, String, func
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||||
|
|
||||||
|
from reworkd_platform.db.meta import meta
|
||||||
|
from reworkd_platform.web.api.http_responses import not_found
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="Base")
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
"""Base for all models."""
|
||||||
|
|
||||||
|
metadata = meta
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
String,
|
||||||
|
primary_key=True,
|
||||||
|
default=lambda _: str(uuid.uuid4()),
|
||||||
|
unique=True,
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get(cls: Type[T], session: AsyncSession, id_: str) -> Optional[T]:
|
||||||
|
return await session.get(cls, id_)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_or_404(cls: Type[T], session: AsyncSession, id_: str) -> T:
|
||||||
|
if model := await cls.get(session, id_):
|
||||||
|
return model
|
||||||
|
|
||||||
|
raise not_found(detail=f"{cls.__name__}[{id_}] not found")
|
||||||
|
|
||||||
|
async def save(self: T, session: AsyncSession) -> T:
|
||||||
|
session.add(self)
|
||||||
|
await session.flush()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def delete(self: T, session: AsyncSession) -> None:
|
||||||
|
await session.delete(self)
|
||||||
|
|
||||||
|
|
||||||
|
class TrackedModel(Base):
|
||||||
|
"""Base for all tracked models."""
|
||||||
|
|
||||||
|
__abstract__ = True
|
||||||
|
|
||||||
|
create_date = mapped_column(
|
||||||
|
DateTime, name="create_date", server_default=func.now(), nullable=False
|
||||||
|
)
|
||||||
|
update_date = mapped_column(
|
||||||
|
DateTime, name="update_date", onupdate=func.now(), nullable=True
|
||||||
|
)
|
||||||
|
delete_date = mapped_column(DateTime, name="delete_date", nullable=True)
|
||||||
|
|
||||||
|
async def delete(self, session: AsyncSession) -> None:
|
||||||
|
"""Marks the model as deleted."""
|
||||||
|
self.delete_date = datetime.now()
|
||||||
|
await self.save(session)
|
||||||
|
|
||||||
|
|
||||||
|
class UserMixin:
|
||||||
|
user_id = mapped_column(String, name="user_id", nullable=False)
|
||||||
|
organization_id = mapped_column(String, name="organization_id", nullable=True)
|
0
platform/reworkd_platform/db/crud/__init__.py
Normal file
0
platform/reworkd_platform/db/crud/__init__.py
Normal file
58
platform/reworkd_platform/db/crud/agent.py
Normal file
58
platform/reworkd_platform/db/crud/agent.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
from fastapi import HTTPException
|
||||||
|
from sqlalchemy import and_, func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from reworkd_platform.db.crud.base import BaseCrud
|
||||||
|
from reworkd_platform.db.models.agent import AgentRun, AgentTask
|
||||||
|
from reworkd_platform.schemas.agent import Loop_Step
|
||||||
|
from reworkd_platform.schemas.user import UserBase
|
||||||
|
from reworkd_platform.settings import settings
|
||||||
|
from reworkd_platform.web.api.errors import MaxLoopsError, MultipleSummaryError
|
||||||
|
|
||||||
|
|
||||||
|
class AgentCRUD(BaseCrud):
|
||||||
|
def __init__(self, session: AsyncSession, user: UserBase):
|
||||||
|
super().__init__(session)
|
||||||
|
self.user = user
|
||||||
|
|
||||||
|
async def create_run(self, goal: str) -> AgentRun:
|
||||||
|
return await AgentRun(
|
||||||
|
user_id=self.user.id,
|
||||||
|
goal=goal,
|
||||||
|
).save(self.session)
|
||||||
|
|
||||||
|
async def create_task(self, run_id: str, type_: Loop_Step) -> AgentTask:
|
||||||
|
await self.validate_task_count(run_id, type_)
|
||||||
|
return await AgentTask(
|
||||||
|
run_id=run_id,
|
||||||
|
type_=type_,
|
||||||
|
).save(self.session)
|
||||||
|
|
||||||
|
async def validate_task_count(self, run_id: str, type_: str) -> None:
|
||||||
|
if not await AgentRun.get(self.session, run_id):
|
||||||
|
raise HTTPException(404, f"Run {run_id} not found")
|
||||||
|
|
||||||
|
query = select(func.count(AgentTask.id)).where(
|
||||||
|
and_(
|
||||||
|
AgentTask.run_id == run_id,
|
||||||
|
AgentTask.type_ == type_,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
task_count = (await self.session.execute(query)).scalar_one()
|
||||||
|
max_ = settings.max_loops
|
||||||
|
|
||||||
|
if task_count >= max_:
|
||||||
|
raise MaxLoopsError(
|
||||||
|
StopIteration(),
|
||||||
|
f"Max loops of {max_} exceeded, shutting down.",
|
||||||
|
429,
|
||||||
|
should_log=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if type_ == "summarize" and task_count > 1:
|
||||||
|
raise MultipleSummaryError(
|
||||||
|
StopIteration(),
|
||||||
|
"Multiple summary tasks are not allowed",
|
||||||
|
429,
|
||||||
|
)
|
10
platform/reworkd_platform/db/crud/base.py
Normal file
10
platform/reworkd_platform/db/crud/base.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="BaseCrud")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCrud:
|
||||||
|
def __init__(self, session: AsyncSession):
|
||||||
|
self.session = session
|
77
platform/reworkd_platform/db/crud/oauth.py
Normal file
77
platform/reworkd_platform/db/crud/oauth.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
import secrets
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from reworkd_platform.db.crud.base import BaseCrud
|
||||||
|
from reworkd_platform.db.dependencies import get_db_session
|
||||||
|
from reworkd_platform.db.models.auth import OauthCredentials
|
||||||
|
from reworkd_platform.schemas import UserBase
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthCrud(BaseCrud):
|
||||||
|
@classmethod
|
||||||
|
async def inject(
|
||||||
|
cls,
|
||||||
|
session: AsyncSession = Depends(get_db_session),
|
||||||
|
) -> "OAuthCrud":
|
||||||
|
return cls(session)
|
||||||
|
|
||||||
|
async def create_installation(
|
||||||
|
self, user: UserBase, provider: str, redirect_uri: Optional[str]
|
||||||
|
) -> OauthCredentials:
|
||||||
|
return await OauthCredentials(
|
||||||
|
user_id=user.id,
|
||||||
|
organization_id=user.organization_id,
|
||||||
|
provider=provider,
|
||||||
|
state=secrets.token_hex(16),
|
||||||
|
redirect_uri=redirect_uri,
|
||||||
|
).save(self.session)
|
||||||
|
|
||||||
|
async def get_installation_by_state(self, state: str) -> Optional[OauthCredentials]:
|
||||||
|
query = select(OauthCredentials).filter(OauthCredentials.state == state)
|
||||||
|
|
||||||
|
return (await self.session.execute(query)).scalar_one_or_none()
|
||||||
|
|
||||||
|
async def get_installation_by_user_id(
|
||||||
|
self, user_id: str, provider: str
|
||||||
|
) -> Optional[OauthCredentials]:
|
||||||
|
query = select(OauthCredentials).filter(
|
||||||
|
OauthCredentials.user_id == user_id,
|
||||||
|
OauthCredentials.provider == provider,
|
||||||
|
OauthCredentials.access_token_enc.isnot(None),
|
||||||
|
)
|
||||||
|
|
||||||
|
return (await self.session.execute(query)).scalars().first()
|
||||||
|
|
||||||
|
async def get_installation_by_organization_id(
|
||||||
|
self, organization_id: str, provider: str
|
||||||
|
) -> Optional[OauthCredentials]:
|
||||||
|
query = select(OauthCredentials).filter(
|
||||||
|
OauthCredentials.organization_id == organization_id,
|
||||||
|
OauthCredentials.provider == provider,
|
||||||
|
OauthCredentials.access_token_enc.isnot(None),
|
||||||
|
OauthCredentials.organization_id.isnot(None),
|
||||||
|
)
|
||||||
|
|
||||||
|
return (await self.session.execute(query)).scalars().first()
|
||||||
|
|
||||||
|
async def get_all(self, user: UserBase) -> Dict[str, str]:
|
||||||
|
query = (
|
||||||
|
select(
|
||||||
|
OauthCredentials.provider,
|
||||||
|
func.any_value(OauthCredentials.access_token_enc),
|
||||||
|
)
|
||||||
|
.filter(
|
||||||
|
OauthCredentials.access_token_enc.isnot(None),
|
||||||
|
OauthCredentials.organization_id == user.organization_id,
|
||||||
|
)
|
||||||
|
.group_by(OauthCredentials.provider)
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
provider: token
|
||||||
|
for provider, token in (await self.session.execute(query)).all()
|
||||||
|
}
|
98
platform/reworkd_platform/db/crud/organization.py
Normal file
98
platform/reworkd_platform/db/crud/organization.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
# from
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import and_, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import aliased
|
||||||
|
|
||||||
|
from reworkd_platform.db.crud.base import BaseCrud
|
||||||
|
from reworkd_platform.db.dependencies import get_db_session
|
||||||
|
from reworkd_platform.db.models.auth import Organization, OrganizationUser
|
||||||
|
from reworkd_platform.db.models.user import User
|
||||||
|
from reworkd_platform.schemas import UserBase
|
||||||
|
from reworkd_platform.web.api.dependencies import get_current_user
|
||||||
|
|
||||||
|
|
||||||
|
class OrgUser(BaseModel):
|
||||||
|
id: str
|
||||||
|
role: str
|
||||||
|
user: UserBase
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationUsers(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
users: List[OrgUser]
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationCrud(BaseCrud):
|
||||||
|
def __init__(self, session: AsyncSession, user: UserBase):
|
||||||
|
super().__init__(session)
|
||||||
|
self.user = user
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def inject(
|
||||||
|
cls,
|
||||||
|
session: AsyncSession = Depends(get_db_session),
|
||||||
|
user: UserBase = Depends(get_current_user),
|
||||||
|
) -> "OrganizationCrud":
|
||||||
|
return cls(session, user)
|
||||||
|
|
||||||
|
async def create_organization(self, name: str) -> Organization:
|
||||||
|
return await Organization(
|
||||||
|
created_by=self.user.id,
|
||||||
|
name=name,
|
||||||
|
).save(self.session)
|
||||||
|
|
||||||
|
async def get_by_name(self, name: str) -> Optional[OrganizationUsers]:
|
||||||
|
owner = aliased(OrganizationUser, name="owner")
|
||||||
|
|
||||||
|
query = (
|
||||||
|
select(
|
||||||
|
Organization,
|
||||||
|
User,
|
||||||
|
OrganizationUser,
|
||||||
|
)
|
||||||
|
.join(
|
||||||
|
OrganizationUser,
|
||||||
|
and_(
|
||||||
|
Organization.id == OrganizationUser.organization_id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.join(
|
||||||
|
User,
|
||||||
|
User.id == OrganizationUser.user_id,
|
||||||
|
)
|
||||||
|
.join( # Owner
|
||||||
|
owner,
|
||||||
|
and_(
|
||||||
|
OrganizationUser.organization_id == Organization.id,
|
||||||
|
OrganizationUser.user_id == self.user.id,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.filter(Organization.name == name)
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = (await self.session.execute(query)).all()
|
||||||
|
if not rows:
|
||||||
|
return None
|
||||||
|
|
||||||
|
org: Organization = rows[0][0]
|
||||||
|
return OrganizationUsers(
|
||||||
|
id=org.id,
|
||||||
|
name=org.name,
|
||||||
|
users=[
|
||||||
|
OrgUser(
|
||||||
|
id=org_user.user_id,
|
||||||
|
role=org_user.role,
|
||||||
|
user=UserBase(
|
||||||
|
id=user.id,
|
||||||
|
email=user.email,
|
||||||
|
name=user.name,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for [_, user, org_user] in rows
|
||||||
|
],
|
||||||
|
)
|
31
platform/reworkd_platform/db/crud/user.py
Normal file
31
platform/reworkd_platform/db/crud/user.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import and_, select
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from reworkd_platform.db.crud.base import BaseCrud
|
||||||
|
from reworkd_platform.db.models.auth import OrganizationUser
|
||||||
|
from reworkd_platform.db.models.user import UserSession
|
||||||
|
|
||||||
|
|
||||||
|
class UserCrud(BaseCrud):
|
||||||
|
async def get_user_session(self, token: str) -> UserSession:
|
||||||
|
query = (
|
||||||
|
select(UserSession)
|
||||||
|
.filter(UserSession.session_token == token)
|
||||||
|
.options(selectinload(UserSession.user))
|
||||||
|
)
|
||||||
|
return (await self.session.execute(query)).scalar_one()
|
||||||
|
|
||||||
|
async def get_user_organization(
|
||||||
|
self, user_id: str, organization_id: str
|
||||||
|
) -> Optional[OrganizationUser]:
|
||||||
|
query = select(OrganizationUser).filter(
|
||||||
|
and_(
|
||||||
|
OrganizationUser.user_id == user_id,
|
||||||
|
OrganizationUser.organization_id == organization_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Only returns the first organization
|
||||||
|
return (await self.session.execute(query)).scalar()
|
20
platform/reworkd_platform/db/dependencies.py
Normal file
20
platform/reworkd_platform/db/dependencies.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db_session(request: Request) -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""
|
||||||
|
Create and get database session.
|
||||||
|
|
||||||
|
:param request: current request.
|
||||||
|
:yield: database session.
|
||||||
|
"""
|
||||||
|
session: AsyncSession = request.app.state.db_session_factory()
|
||||||
|
|
||||||
|
try: # noqa: WPS501
|
||||||
|
yield session
|
||||||
|
await session.commit()
|
||||||
|
finally:
|
||||||
|
await session.close()
|
3
platform/reworkd_platform/db/meta.py
Normal file
3
platform/reworkd_platform/db/meta.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
meta = sa.MetaData()
|
14
platform/reworkd_platform/db/models/__init__.py
Normal file
14
platform/reworkd_platform/db/models/__init__.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
"""reworkd_platform models."""
|
||||||
|
import pkgutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def load_all_models() -> None:
|
||||||
|
"""Load all models from this folder."""
|
||||||
|
package_dir = Path(__file__).resolve().parent
|
||||||
|
modules = pkgutil.walk_packages(
|
||||||
|
path=[str(package_dir)],
|
||||||
|
prefix="reworkd_platform.db.models.",
|
||||||
|
)
|
||||||
|
for module in modules:
|
||||||
|
__import__(module.name) # noqa: WPS421
|
24
platform/reworkd_platform/db/models/agent.py
Normal file
24
platform/reworkd_platform/db/models/agent.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
from sqlalchemy import DateTime, String, Text, func
|
||||||
|
from sqlalchemy.orm import mapped_column
|
||||||
|
|
||||||
|
from reworkd_platform.db.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRun(Base):
|
||||||
|
__tablename__ = "agent_run"
|
||||||
|
|
||||||
|
user_id = mapped_column(String, nullable=False)
|
||||||
|
goal = mapped_column(Text, nullable=False)
|
||||||
|
create_date = mapped_column(
|
||||||
|
DateTime, name="create_date", server_default=func.now(), nullable=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTask(Base):
|
||||||
|
__tablename__ = "agent_task"
|
||||||
|
|
||||||
|
run_id = mapped_column(String, nullable=False)
|
||||||
|
type_ = mapped_column(String, nullable=False, name="type")
|
||||||
|
create_date = mapped_column(
|
||||||
|
DateTime, name="create_date", server_default=func.now(), nullable=False
|
||||||
|
)
|
36
platform/reworkd_platform/db/models/auth.py
Normal file
36
platform/reworkd_platform/db/models/auth.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from sqlalchemy import DateTime, String
|
||||||
|
from sqlalchemy.orm import mapped_column
|
||||||
|
|
||||||
|
from reworkd_platform.db.base import TrackedModel
|
||||||
|
|
||||||
|
|
||||||
|
class Organization(TrackedModel):
|
||||||
|
__tablename__ = "organization"
|
||||||
|
|
||||||
|
name = mapped_column(String, nullable=False)
|
||||||
|
created_by = mapped_column(String, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationUser(TrackedModel):
|
||||||
|
__tablename__ = "organization_user"
|
||||||
|
|
||||||
|
user_id = mapped_column(String, nullable=False)
|
||||||
|
organization_id = mapped_column(String, nullable=False)
|
||||||
|
role = mapped_column(String, nullable=False, default="member")
|
||||||
|
|
||||||
|
|
||||||
|
class OauthCredentials(TrackedModel):
|
||||||
|
__tablename__ = "oauth_credentials"
|
||||||
|
|
||||||
|
user_id = mapped_column(String, nullable=False)
|
||||||
|
organization_id = mapped_column(String, nullable=True)
|
||||||
|
provider = mapped_column(String, nullable=False)
|
||||||
|
state = mapped_column(String, nullable=False)
|
||||||
|
redirect_uri = mapped_column(String, nullable=False)
|
||||||
|
|
||||||
|
# Post-installation
|
||||||
|
token_type = mapped_column(String, nullable=True)
|
||||||
|
access_token_enc = mapped_column(String, nullable=True)
|
||||||
|
access_token_expiration = mapped_column(DateTime, nullable=True)
|
||||||
|
refresh_token_enc = mapped_column(String, nullable=True)
|
||||||
|
scope = mapped_column(String, nullable=True)
|
38
platform/reworkd_platform/db/models/user.py
Normal file
38
platform/reworkd_platform/db/models/user.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from sqlalchemy import DateTime, ForeignKey, Index, String, text
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from reworkd_platform.db.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class UserSession(Base):
|
||||||
|
__tablename__ = "Session"
|
||||||
|
|
||||||
|
session_token = mapped_column(String, unique=True, name="sessionToken")
|
||||||
|
user_id = mapped_column(
|
||||||
|
String, ForeignKey("User.id", ondelete="CASCADE"), name="userId"
|
||||||
|
)
|
||||||
|
expires = mapped_column(DateTime)
|
||||||
|
|
||||||
|
user = relationship("User")
|
||||||
|
|
||||||
|
__table_args__ = (Index("user_id"),)
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = "User"
|
||||||
|
|
||||||
|
name = mapped_column(String, nullable=True)
|
||||||
|
email = mapped_column(String, nullable=True, unique=True)
|
||||||
|
email_verified = mapped_column(DateTime, nullable=True, name="emailVerified")
|
||||||
|
image = mapped_column(String, nullable=True)
|
||||||
|
create_date = mapped_column(
|
||||||
|
DateTime, server_default=text("(now())"), name="createDate"
|
||||||
|
)
|
||||||
|
|
||||||
|
sessions: Mapped[List["UserSession"]] = relationship(
|
||||||
|
"UserSession", back_populates="user"
|
||||||
|
)
|
||||||
|
|
||||||
|
__table_args__ = (Index("email"),)
|
57
platform/reworkd_platform/db/utils.py
Normal file
57
platform/reworkd_platform/db/utils.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
from ssl import CERT_REQUIRED
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||||
|
|
||||||
|
from reworkd_platform.services.ssl import get_ssl_context
|
||||||
|
from reworkd_platform.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
def create_engine() -> AsyncEngine:
|
||||||
|
"""
|
||||||
|
Creates SQLAlchemy engine instance.
|
||||||
|
|
||||||
|
:return: SQLAlchemy engine instance.
|
||||||
|
"""
|
||||||
|
if settings.environment == "development":
|
||||||
|
return create_async_engine(
|
||||||
|
str(settings.db_url),
|
||||||
|
echo=settings.db_echo,
|
||||||
|
)
|
||||||
|
|
||||||
|
ssl_context = get_ssl_context(settings)
|
||||||
|
ssl_context.verify_mode = CERT_REQUIRED
|
||||||
|
connect_args = {"ssl": ssl_context}
|
||||||
|
|
||||||
|
return create_async_engine(
|
||||||
|
str(settings.db_url),
|
||||||
|
echo=settings.db_echo,
|
||||||
|
connect_args=connect_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_database() -> None:
|
||||||
|
"""Create a database."""
|
||||||
|
engine = create_async_engine(str(settings.db_url.with_path("/mysql")))
|
||||||
|
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
database_existance = await conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT 1 FROM INFORMATION_SCHEMA.SCHEMATA" # noqa: S608
|
||||||
|
f" WHERE SCHEMA_NAME='{settings.db_base}';",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
database_exists = database_existance.scalar() == 1
|
||||||
|
|
||||||
|
if database_exists:
|
||||||
|
await drop_database()
|
||||||
|
|
||||||
|
async with engine.connect() as conn: # noqa: WPS440
|
||||||
|
await conn.execute(text(f"CREATE DATABASE {settings.db_base};"))
|
||||||
|
|
||||||
|
|
||||||
|
async def drop_database() -> None:
|
||||||
|
"""Drop current database."""
|
||||||
|
engine = create_async_engine(str(settings.db_url.with_path("/mysql")))
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
await conn.execute(text(f"DROP DATABASE {settings.db_base};"))
|
63
platform/reworkd_platform/logging.py
Normal file
63
platform/reworkd_platform/logging.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from reworkd_platform.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class InterceptHandler(logging.Handler):
|
||||||
|
"""
|
||||||
|
Default handler from examples in loguru documentation.
|
||||||
|
|
||||||
|
This handler intercepts all log requests and
|
||||||
|
passes them to loguru.
|
||||||
|
|
||||||
|
For more info see:
|
||||||
|
https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging
|
||||||
|
"""
|
||||||
|
|
||||||
|
def emit(self, record: logging.LogRecord) -> None: # pragma: no cover
|
||||||
|
"""
|
||||||
|
Propagates logs to loguru.
|
||||||
|
|
||||||
|
:param record: record to log.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
level: Union[str, int] = logger.level(record.levelname).name
|
||||||
|
except ValueError:
|
||||||
|
level = record.levelno
|
||||||
|
|
||||||
|
# Find caller from where originated the logged message
|
||||||
|
frame, depth = logging.currentframe(), 2
|
||||||
|
while frame.f_code.co_filename == logging.__file__:
|
||||||
|
frame = frame.f_back # type: ignore
|
||||||
|
depth += 1
|
||||||
|
|
||||||
|
logger.opt(depth=depth, exception=record.exc_info).log(
|
||||||
|
level,
|
||||||
|
record.getMessage(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def configure_logging() -> None: # pragma: no cover
|
||||||
|
"""Configures logging."""
|
||||||
|
intercept_handler = InterceptHandler()
|
||||||
|
|
||||||
|
logging.basicConfig(handlers=[intercept_handler], level=logging.NOTSET)
|
||||||
|
|
||||||
|
for logger_name in logging.root.manager.loggerDict:
|
||||||
|
if logger_name.startswith("uvicorn."):
|
||||||
|
logging.getLogger(logger_name).handlers = []
|
||||||
|
|
||||||
|
# change handler for default uvicorn logger
|
||||||
|
logging.getLogger("uvicorn").handlers = [intercept_handler]
|
||||||
|
logging.getLogger("uvicorn.access").handlers = [intercept_handler]
|
||||||
|
|
||||||
|
# set logs output, level and format
|
||||||
|
logger.remove()
|
||||||
|
logger.add(
|
||||||
|
sys.stdout,
|
||||||
|
level=settings.log_level,
|
||||||
|
)
|
2
platform/reworkd_platform/schemas/__init__.py
Normal file
2
platform/reworkd_platform/schemas/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
from .agent import ModelSettings
|
||||||
|
from .user import UserBase
|
87
platform/reworkd_platform/schemas/agent.py
Normal file
87
platform/reworkd_platform/schemas/agent.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
|
|
||||||
|
from reworkd_platform.web.api.agent.analysis import Analysis
|
||||||
|
|
||||||
|
LLM_Model = Literal[
|
||||||
|
"gpt-3.5-turbo",
|
||||||
|
"gpt-3.5-turbo-16k",
|
||||||
|
"gpt-4o",
|
||||||
|
]
|
||||||
|
Loop_Step = Literal[
|
||||||
|
"start",
|
||||||
|
"analyze",
|
||||||
|
"execute",
|
||||||
|
"create",
|
||||||
|
"summarize",
|
||||||
|
"chat",
|
||||||
|
]
|
||||||
|
LLM_MODEL_MAX_TOKENS: Dict[LLM_Model, int] = {
|
||||||
|
"gpt-3.5-turbo": 4000,
|
||||||
|
"gpt-3.5-turbo-16k": 16000,
|
||||||
|
"gpt-4o": 8000,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSettings(BaseModel):
|
||||||
|
model: LLM_Model = Field(default="gpt-4o")
|
||||||
|
custom_api_key: Optional[str] = Field(default=None)
|
||||||
|
temperature: float = Field(default=0.9, ge=0.0, le=1.0)
|
||||||
|
max_tokens: int = Field(default=500, ge=0)
|
||||||
|
language: str = Field(default="English")
|
||||||
|
|
||||||
|
@validator("max_tokens")
|
||||||
|
def validate_max_tokens(cls, v: float, values: Dict[str, Any]) -> float:
|
||||||
|
model = values["model"]
|
||||||
|
if v > (max_tokens := LLM_MODEL_MAX_TOKENS[model]):
|
||||||
|
raise ValueError(f"Model {model} only supports {max_tokens} tokens")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRunCreate(BaseModel):
|
||||||
|
goal: str
|
||||||
|
model_settings: ModelSettings = Field(default=ModelSettings())
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRun(AgentRunCreate):
|
||||||
|
run_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTaskAnalyze(AgentRun):
|
||||||
|
task: str
|
||||||
|
tool_names: List[str] = Field(default=[])
|
||||||
|
model_settings: ModelSettings = Field(default=ModelSettings())
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTaskExecute(AgentRun):
|
||||||
|
task: str
|
||||||
|
analysis: Analysis
|
||||||
|
|
||||||
|
|
||||||
|
class AgentTaskCreate(AgentRun):
|
||||||
|
tasks: List[str] = Field(default=[])
|
||||||
|
last_task: Optional[str] = Field(default=None)
|
||||||
|
result: Optional[str] = Field(default=None)
|
||||||
|
completed_tasks: List[str] = Field(default=[])
|
||||||
|
|
||||||
|
|
||||||
|
class AgentSummarize(AgentRun):
|
||||||
|
results: List[str] = Field(default=[])
|
||||||
|
|
||||||
|
|
||||||
|
class AgentChat(AgentRun):
|
||||||
|
message: str
|
||||||
|
results: List[str] = Field(default=[])
|
||||||
|
|
||||||
|
|
||||||
|
class NewTasksResponse(BaseModel):
|
||||||
|
run_id: str
|
||||||
|
new_tasks: List[str] = Field(alias="newTasks")
|
||||||
|
|
||||||
|
|
||||||
|
class RunCount(BaseModel):
|
||||||
|
count: int
|
||||||
|
first_run: Optional[datetime]
|
||||||
|
last_run: Optional[datetime]
|
21
platform/reworkd_platform/schemas/user.py
Normal file
21
platform/reworkd_platform/schemas/user.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationRole(BaseModel):
|
||||||
|
id: str
|
||||||
|
role: str
|
||||||
|
organization_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class UserBase(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: Optional[str]
|
||||||
|
email: Optional[str]
|
||||||
|
image: Optional[str] = Field(default=None)
|
||||||
|
organization: Optional[OrganizationRole] = Field(default=None)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def organization_id(self) -> Optional[str]:
|
||||||
|
return self.organization.organization_id if self.organization else None
|
1
platform/reworkd_platform/services/__init__.py
Normal file
1
platform/reworkd_platform/services/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""Services for reworkd_platform."""
|
42
platform/reworkd_platform/services/anthropic.py
Normal file
42
platform/reworkd_platform/services/anthropic.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from anthropic import AsyncAnthropic
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractPrompt(BaseModel):
|
||||||
|
def to_string(self) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class HumanAssistantPrompt(AbstractPrompt):
|
||||||
|
assistant_prompt: str
|
||||||
|
human_prompt: str
|
||||||
|
|
||||||
|
def to_string(self) -> str:
|
||||||
|
return (
|
||||||
|
f"""\n\nHuman: {self.human_prompt}\n\nAssistant: {self.assistant_prompt}"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeService:
|
||||||
|
def __init__(self, api_key: Optional[str], model: str = "claude-2"):
|
||||||
|
self.claude = AsyncAnthropic(api_key=api_key)
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
async def completion(
|
||||||
|
self,
|
||||||
|
prompt: AbstractPrompt,
|
||||||
|
max_tokens_to_sample: int,
|
||||||
|
temperature: int = 0,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
return (
|
||||||
|
await self.claude.completions.create(
|
||||||
|
model=self.model,
|
||||||
|
prompt=prompt.to_string(),
|
||||||
|
max_tokens_to_sample=max_tokens_to_sample,
|
||||||
|
temperature=temperature,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
).completion.strip()
|
7
platform/reworkd_platform/services/aws/__init__.py
Normal file
7
platform/reworkd_platform/services/aws/__init__.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
import boto3
|
||||||
|
from botocore.exceptions import ProfileNotFound
|
||||||
|
|
||||||
|
try:
|
||||||
|
boto3.setup_default_session(profile_name="dev")
|
||||||
|
except ProfileNotFound:
|
||||||
|
pass
|
85
platform/reworkd_platform/services/aws/s3.py
Normal file
85
platform/reworkd_platform/services/aws/s3.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
import io
|
||||||
|
import os
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from aiohttp import ClientError
|
||||||
|
from boto3 import client as boto3_client
|
||||||
|
from loguru import logger
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
REGION = "us-east-1"
|
||||||
|
|
||||||
|
|
||||||
|
# noinspection SpellCheckingInspection
|
||||||
|
class PresignedPost(BaseModel):
|
||||||
|
url: str
|
||||||
|
fields: Dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleStorageService:
|
||||||
|
# TODO: would be great if with could make this async
|
||||||
|
|
||||||
|
def __init__(self, bucket: Optional[str]) -> None:
|
||||||
|
if not bucket:
|
||||||
|
raise ValueError("Bucket name must be provided")
|
||||||
|
|
||||||
|
self._client = boto3_client("s3", region_name=REGION)
|
||||||
|
self._bucket = bucket
|
||||||
|
|
||||||
|
def create_presigned_upload_url(
|
||||||
|
self,
|
||||||
|
object_name: str,
|
||||||
|
) -> PresignedPost:
|
||||||
|
return PresignedPost(
|
||||||
|
**self._client.generate_presigned_post(
|
||||||
|
Bucket=self._bucket,
|
||||||
|
Key=object_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_presigned_download_url(self, object_name: str) -> str:
|
||||||
|
return self._client.generate_presigned_url(
|
||||||
|
"get_object",
|
||||||
|
Params={"Bucket": self._bucket, "Key": object_name},
|
||||||
|
)
|
||||||
|
|
||||||
|
def upload_to_bucket(
|
||||||
|
self,
|
||||||
|
object_name: str,
|
||||||
|
file: io.BytesIO,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
self._client.put_object(
|
||||||
|
Bucket=self._bucket, Key=object_name, Body=file.getvalue()
|
||||||
|
)
|
||||||
|
except ClientError as e:
|
||||||
|
logger.error(e)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def download_file(self, object_name: str, local_filename: str) -> None:
|
||||||
|
self._client.download_file(
|
||||||
|
Bucket=self._bucket, Key=object_name, Filename=local_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
def list_keys(self, prefix: str) -> List[str]:
|
||||||
|
files = self._client.list_objects_v2(Bucket=self._bucket, Prefix=prefix)
|
||||||
|
if "Contents" not in files:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [file["Key"] for file in files["Contents"]]
|
||||||
|
|
||||||
|
def download_folder(self, prefix: str, path: str) -> List[str]:
|
||||||
|
local_files = []
|
||||||
|
for key in self.list_keys(prefix):
|
||||||
|
local_filename = os.path.join(path, key.split("/")[-1])
|
||||||
|
self.download_file(key, local_filename)
|
||||||
|
local_files.append(local_filename)
|
||||||
|
|
||||||
|
return local_files
|
||||||
|
|
||||||
|
def delete_folder(self, prefix: str) -> None:
|
||||||
|
keys = self.list_keys(prefix)
|
||||||
|
self._client.delete_objects(
|
||||||
|
Bucket=self._bucket,
|
||||||
|
Delete={"Objects": [{"Key": key} for key in keys]},
|
||||||
|
)
|
147
platform/reworkd_platform/services/oauth_installers.py
Normal file
147
platform/reworkd_platform/services/oauth_installers.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
import json
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from fastapi import Depends, Path
|
||||||
|
|
||||||
|
from reworkd_platform.db.crud.oauth import OAuthCrud
|
||||||
|
from reworkd_platform.db.models.auth import OauthCredentials
|
||||||
|
from reworkd_platform.schemas import UserBase
|
||||||
|
from reworkd_platform.services.security import encryption_service
|
||||||
|
from reworkd_platform.settings import Settings
|
||||||
|
from reworkd_platform.settings import settings as platform_settings
|
||||||
|
from reworkd_platform.web.api.http_responses import forbidden
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthInstaller(ABC):
|
||||||
|
def __init__(self, crud: OAuthCrud, settings: Settings):
|
||||||
|
self.crud = crud
|
||||||
|
self.settings = settings
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def install(self, user: UserBase, redirect_uri: str) -> str:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def install_callback(self, code: str, state: str) -> OauthCredentials:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def uninstall(self, user: UserBase) -> bool:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def store_access_token(creds: OauthCredentials, access_token: str) -> None:
|
||||||
|
creds.access_token_enc = encryption_service.encrypt(access_token)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def store_refresh_token(creds: OauthCredentials, refresh_token: str) -> None:
|
||||||
|
creds.refresh_token_enc = encryption_service.encrypt(refresh_token)
|
||||||
|
|
||||||
|
|
||||||
|
class SIDInstaller(OAuthInstaller):
|
||||||
|
PROVIDER = "sid"
|
||||||
|
|
||||||
|
async def install(self, user: UserBase, redirect_uri: str) -> str:
|
||||||
|
# gracefully handle the case where the installation already exists
|
||||||
|
# this can happen if the user starts the process from multiple tabs
|
||||||
|
installation = await self.crud.get_installation_by_user_id(
|
||||||
|
user.id, self.PROVIDER
|
||||||
|
)
|
||||||
|
if not installation:
|
||||||
|
installation = await self.crud.create_installation(
|
||||||
|
user,
|
||||||
|
self.PROVIDER,
|
||||||
|
redirect_uri,
|
||||||
|
)
|
||||||
|
scopes = ["data:query", "offline_access"]
|
||||||
|
params = {
|
||||||
|
"client_id": self.settings.sid_client_id,
|
||||||
|
"redirect_uri": self.settings.sid_redirect_uri,
|
||||||
|
"response_type": "code",
|
||||||
|
"scope": " ".join(scopes),
|
||||||
|
"state": installation.state,
|
||||||
|
"audience": "https://api.sid.ai/api/v1/",
|
||||||
|
}
|
||||||
|
auth_url = "https://me.sid.ai/api/oauth/authorize"
|
||||||
|
auth_url += "?" + urlencode(params)
|
||||||
|
return auth_url
|
||||||
|
|
||||||
|
async def install_callback(self, code: str, state: str) -> OauthCredentials:
|
||||||
|
creds = await self.crud.get_installation_by_state(state)
|
||||||
|
if not creds:
|
||||||
|
raise forbidden()
|
||||||
|
req = {
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"client_id": self.settings.sid_client_id,
|
||||||
|
"client_secret": self.settings.sid_client_secret,
|
||||||
|
"redirect_uri": self.settings.sid_redirect_uri,
|
||||||
|
"code": code,
|
||||||
|
}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
"https://auth.sid.ai/oauth/token",
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json",
|
||||||
|
},
|
||||||
|
data=json.dumps(req),
|
||||||
|
) as response:
|
||||||
|
res_data = await response.json()
|
||||||
|
|
||||||
|
OAuthInstaller.store_access_token(creds, res_data["access_token"])
|
||||||
|
OAuthInstaller.store_refresh_token(creds, res_data["refresh_token"])
|
||||||
|
creds.access_token_expiration = datetime.now() + timedelta(
|
||||||
|
seconds=res_data["expires_in"]
|
||||||
|
)
|
||||||
|
return await creds.save(self.crud.session)
|
||||||
|
|
||||||
|
async def uninstall(self, user: UserBase) -> bool:
|
||||||
|
creds = await self.crud.get_installation_by_user_id(user.id, self.PROVIDER)
|
||||||
|
# check if credentials exist and contain a refresh token
|
||||||
|
if not creds:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# use refresh token to revoke access
|
||||||
|
delete_token = encryption_service.decrypt(creds.refresh_token_enc)
|
||||||
|
# delete credentials from database
|
||||||
|
await self.crud.session.delete(creds)
|
||||||
|
|
||||||
|
# revoke refresh token
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
await session.post(
|
||||||
|
"https://auth.sid.ai/oauth/revoke",
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
data=json.dumps(
|
||||||
|
{
|
||||||
|
"client_id": self.settings.sid_client_id,
|
||||||
|
"client_secret": self.settings.sid_client_secret,
|
||||||
|
"token": delete_token,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
integrations = {
|
||||||
|
SIDInstaller.PROVIDER: SIDInstaller,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def installer_factory(
|
||||||
|
provider: str = Path(description="OAuth Provider"),
|
||||||
|
crud: OAuthCrud = Depends(OAuthCrud.inject),
|
||||||
|
) -> OAuthInstaller:
|
||||||
|
"""Factory for OAuth installers
|
||||||
|
Args:
|
||||||
|
provider (str): OAuth Provider (can be slack, github, etc.) (injected)
|
||||||
|
crud (OAuthCrud): OAuth Crud (injected)
|
||||||
|
"""
|
||||||
|
|
||||||
|
if provider in integrations:
|
||||||
|
return integrations[provider](crud, platform_settings)
|
||||||
|
raise NotImplementedError()
|
11
platform/reworkd_platform/services/pinecone/lifetime.py
Normal file
11
platform/reworkd_platform/services/pinecone/lifetime.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
import pinecone
|
||||||
|
|
||||||
|
from reworkd_platform.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
def init_pinecone() -> None:
|
||||||
|
if settings.pinecone_api_key and settings.pinecone_environment:
|
||||||
|
pinecone.init(
|
||||||
|
api_key=settings.pinecone_api_key,
|
||||||
|
environment=settings.pinecone_environment,
|
||||||
|
)
|
98
platform/reworkd_platform/services/pinecone/pinecone.py
Normal file
98
platform/reworkd_platform/services/pinecone/pinecone.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from pinecone import Index # import doesnt work on plane wifi
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from reworkd_platform.settings import settings
|
||||||
|
from reworkd_platform.timer import timed_function
|
||||||
|
from reworkd_platform.web.api.memory.memory import AgentMemory
|
||||||
|
|
||||||
|
OPENAI_EMBEDDING_DIM = 1536
|
||||||
|
|
||||||
|
|
||||||
|
class Row(BaseModel):
|
||||||
|
id: str
|
||||||
|
values: List[float]
|
||||||
|
metadata: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class QueryResult(BaseModel):
|
||||||
|
id: str
|
||||||
|
score: float
|
||||||
|
metadata: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class PineconeMemory(AgentMemory):
|
||||||
|
"""
|
||||||
|
Wrapper around pinecone
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, index_name: str, namespace: str = ""):
|
||||||
|
self.index = Index(settings.pinecone_index_name)
|
||||||
|
self.namespace = namespace or index_name
|
||||||
|
|
||||||
|
@timed_function(level="DEBUG")
|
||||||
|
def __enter__(self) -> AgentMemory:
|
||||||
|
self.embeddings: Embeddings = OpenAIEmbeddings(
|
||||||
|
client=None, # Meta private value but mypy will complain its missing
|
||||||
|
openai_api_key=settings.openai_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@timed_function(level="DEBUG")
|
||||||
|
def reset_class(self) -> None:
|
||||||
|
self.index.delete(delete_all=True, namespace=self.namespace)
|
||||||
|
|
||||||
|
@timed_function(level="DEBUG")
|
||||||
|
def add_tasks(self, tasks: List[str]) -> List[str]:
|
||||||
|
if len(tasks) == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
embeds = self.embeddings.embed_documents(tasks)
|
||||||
|
|
||||||
|
if len(tasks) != len(embeds):
|
||||||
|
raise ValueError("Embeddings and tasks are not the same length")
|
||||||
|
|
||||||
|
rows = [
|
||||||
|
Row(values=vector, metadata={"text": tasks[i]}, id=str(uuid.uuid4()))
|
||||||
|
for i, vector in enumerate(embeds)
|
||||||
|
]
|
||||||
|
|
||||||
|
self.index.upsert(
|
||||||
|
vectors=[row.dict() for row in rows], namespace=self.namespace
|
||||||
|
)
|
||||||
|
|
||||||
|
return [row.id for row in rows]
|
||||||
|
|
||||||
|
@timed_function(level="DEBUG")
|
||||||
|
def get_similar_tasks(
|
||||||
|
self, text: str, score_threshold: float = 0.95
|
||||||
|
) -> List[QueryResult]:
|
||||||
|
# Get similar tasks
|
||||||
|
vector = self.embeddings.embed_query(text)
|
||||||
|
results = self.index.query(
|
||||||
|
vector=vector,
|
||||||
|
top_k=5,
|
||||||
|
include_metadata=True,
|
||||||
|
include_values=True,
|
||||||
|
namespace=self.namespace,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
QueryResult(id=row.id, score=row.score, metadata=row.metadata)
|
||||||
|
for row in getattr(results, "matches", [])
|
||||||
|
if row.score > score_threshold
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def should_use() -> bool:
|
||||||
|
return False
|
23
platform/reworkd_platform/services/security.py
Normal file
23
platform/reworkd_platform/services/security.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
|
||||||
|
from reworkd_platform.settings import settings
|
||||||
|
from reworkd_platform.web.api.http_responses import forbidden
|
||||||
|
|
||||||
|
|
||||||
|
class EncryptionService:
|
||||||
|
def __init__(self, secret: bytes):
|
||||||
|
self.fernet = Fernet(secret)
|
||||||
|
|
||||||
|
def encrypt(self, text: str) -> bytes:
|
||||||
|
return self.fernet.encrypt(text.encode("utf-8"))
|
||||||
|
|
||||||
|
def decrypt(self, encoded_bytes: Union[bytes, str]) -> str:
|
||||||
|
try:
|
||||||
|
return self.fernet.decrypt(encoded_bytes).decode("utf-8")
|
||||||
|
except InvalidToken:
|
||||||
|
raise forbidden()
|
||||||
|
|
||||||
|
|
||||||
|
encryption_service = EncryptionService(settings.secret_signing_key.encode())
|
25
platform/reworkd_platform/services/ssl.py
Normal file
25
platform/reworkd_platform/services/ssl.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from ssl import SSLContext, create_default_context
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from reworkd_platform.settings import Settings
|
||||||
|
|
||||||
|
MACOS_CERT_PATH = "/etc/ssl/cert.pem"
|
||||||
|
DOCKER_CERT_PATH = "/etc/ssl/certs/ca-certificates.crt"
|
||||||
|
|
||||||
|
|
||||||
|
def get_ssl_context(
|
||||||
|
settings: Settings, paths: Optional[List[str]] = None
|
||||||
|
) -> SSLContext:
|
||||||
|
if settings.db_ca_path:
|
||||||
|
return create_default_context(cafile=settings.db_ca_path)
|
||||||
|
|
||||||
|
for path in paths or [MACOS_CERT_PATH, DOCKER_CERT_PATH]:
|
||||||
|
try:
|
||||||
|
return create_default_context(cafile=path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
"No CA certificates found for your OS. To fix this, please run change "
|
||||||
|
"db_ca_path in your settings.py to point to a valid CA certificate file."
|
||||||
|
)
|
1
platform/reworkd_platform/services/tokenizer/__init__.py
Normal file
1
platform/reworkd_platform/services/tokenizer/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""Token Service"""
|
@ -0,0 +1,7 @@
|
|||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from reworkd_platform.services.tokenizer.token_service import TokenService
|
||||||
|
|
||||||
|
|
||||||
|
def get_token_service(request: Request) -> TokenService:
|
||||||
|
return TokenService(request.app.state.token_encoding)
|
16
platform/reworkd_platform/services/tokenizer/lifetime.py
Normal file
16
platform/reworkd_platform/services/tokenizer/lifetime.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
import tiktoken
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
ENCODING_NAME = "cl100k_base" # gpt-4, gpt-3.5-turbo, text-embedding-ada-002
|
||||||
|
|
||||||
|
|
||||||
|
def init_tokenizer(app: FastAPI) -> None: # pragma: no cover
|
||||||
|
"""
|
||||||
|
Initialize tokenizer.
|
||||||
|
|
||||||
|
TikToken downloads the encoding on start. It is then
|
||||||
|
stored in the state of the application.
|
||||||
|
|
||||||
|
:param app: current application.
|
||||||
|
"""
|
||||||
|
app.state.token_encoding = tiktoken.get_encoding(ENCODING_NAME)
|
@ -0,0 +1,33 @@
|
|||||||
|
from tiktoken import Encoding, get_encoding
|
||||||
|
|
||||||
|
from reworkd_platform.schemas.agent import LLM_MODEL_MAX_TOKENS, LLM_Model
|
||||||
|
from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
|
||||||
|
|
||||||
|
|
||||||
|
class TokenService:
|
||||||
|
def __init__(self, encoding: Encoding):
|
||||||
|
self.encoding = encoding
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, encoding: str = "cl100k_base") -> "TokenService":
|
||||||
|
return cls(get_encoding(encoding))
|
||||||
|
|
||||||
|
def tokenize(self, text: str) -> list[int]:
|
||||||
|
return self.encoding.encode(text)
|
||||||
|
|
||||||
|
def detokenize(self, tokens: list[int]) -> str:
|
||||||
|
return self.encoding.decode(tokens)
|
||||||
|
|
||||||
|
def count(self, text: str) -> int:
|
||||||
|
return len(self.tokenize(text))
|
||||||
|
|
||||||
|
def get_completion_space(self, model: LLM_Model, *prompts: str) -> int:
|
||||||
|
max_allowed_tokens = LLM_MODEL_MAX_TOKENS.get(model, 4000)
|
||||||
|
prompt_tokens = sum([self.count(p) for p in prompts])
|
||||||
|
return max_allowed_tokens - prompt_tokens
|
||||||
|
|
||||||
|
def calculate_max_tokens(self, model: WrappedChatOpenAI, *prompts: str) -> None:
|
||||||
|
requested_tokens = self.get_completion_space(model.model_name, *prompts)
|
||||||
|
|
||||||
|
model.max_tokens = min(model.max_tokens, requested_tokens)
|
||||||
|
model.max_tokens = max(model.max_tokens, 1)
|
178
platform/reworkd_platform/settings.py
Normal file
178
platform/reworkd_platform/settings.py
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
import platform
|
||||||
|
from pathlib import Path
|
||||||
|
from tempfile import gettempdir
|
||||||
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseSettings
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
|
from reworkd_platform.constants import ENV_PREFIX
|
||||||
|
|
||||||
|
TEMP_DIR = Path(gettempdir())
|
||||||
|
|
||||||
|
LOG_LEVEL = Literal[
|
||||||
|
"NOTSET",
|
||||||
|
"DEBUG",
|
||||||
|
"INFO",
|
||||||
|
"WARNING",
|
||||||
|
"ERROR",
|
||||||
|
"FATAL",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
SASL_MECHANISM = Literal[
|
||||||
|
"PLAIN",
|
||||||
|
"SCRAM-SHA-256",
|
||||||
|
]
|
||||||
|
|
||||||
|
ENVIRONMENT = Literal[
|
||||||
|
"development",
|
||||||
|
"production",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
"""
|
||||||
|
Application settings.
|
||||||
|
|
||||||
|
These parameters can be configured
|
||||||
|
with environment variables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Application settings
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 8000
|
||||||
|
workers_count: int = 1
|
||||||
|
reload: bool = True
|
||||||
|
environment: ENVIRONMENT = "development"
|
||||||
|
log_level: LOG_LEVEL = "DEBUG"
|
||||||
|
|
||||||
|
# Make sure you update this with your own secret key
|
||||||
|
# Must be 32 url-safe base64-encoded bytes
|
||||||
|
secret_signing_key: str = "JF52S66x6WMoifP5gZreiguYs9LYMn0lkXqgPYoNMD0="
|
||||||
|
|
||||||
|
# OpenAI
|
||||||
|
openai_api_base: str = "https://api.openai.com/v1"
|
||||||
|
openai_api_key: str = "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"
|
||||||
|
openai_api_version: str = "2023-08-01-preview"
|
||||||
|
azure_openai_deployment_name: str = "<Should be updated via env if using azure>"
|
||||||
|
|
||||||
|
# Helicone
|
||||||
|
helicone_api_base: str = "https://oai.hconeai.com/v1"
|
||||||
|
helicone_api_key: Optional[str] = None
|
||||||
|
|
||||||
|
replicate_api_key: Optional[str] = None
|
||||||
|
serp_api_key: Optional[str] = None
|
||||||
|
|
||||||
|
# Frontend URL for CORS
|
||||||
|
frontend_url: str = "https://allinix.ai"
|
||||||
|
allowed_origins_regex: Optional[str] = None
|
||||||
|
|
||||||
|
# Variables for the database
|
||||||
|
db_host: str = "65.19.178.35"
|
||||||
|
db_port: int = 3307
|
||||||
|
db_user: str = "reworkd_platform"
|
||||||
|
db_pass: str = "reworkd_platform"
|
||||||
|
db_base: str = "default"
|
||||||
|
db_echo: bool = False
|
||||||
|
db_ca_path: Optional[str] = None
|
||||||
|
|
||||||
|
# Variables for Pinecone DB
|
||||||
|
pinecone_api_key: Optional[str] = None
|
||||||
|
pinecone_index_name: Optional[str] = None
|
||||||
|
pinecone_environment: Optional[str] = None
|
||||||
|
|
||||||
|
# Sentry's configuration.
|
||||||
|
sentry_dsn: Optional[str] = None
|
||||||
|
sentry_sample_rate: float = 1.0
|
||||||
|
|
||||||
|
kafka_bootstrap_servers: Union[str, List[str]] = []
|
||||||
|
kafka_username: Optional[str] = None
|
||||||
|
kafka_password: Optional[str] = None
|
||||||
|
kafka_ssal_mechanism: SASL_MECHANISM = "PLAIN"
|
||||||
|
|
||||||
|
# Websocket settings
|
||||||
|
pusher_app_id: Optional[str] = None
|
||||||
|
pusher_key: Optional[str] = None
|
||||||
|
pusher_secret: Optional[str] = None
|
||||||
|
pusher_cluster: Optional[str] = None
|
||||||
|
|
||||||
|
# Application Settings
|
||||||
|
ff_mock_mode_enabled: bool = False # Controls whether calls are mocked
|
||||||
|
max_loops: int = 25 # Maximum number of loops to run
|
||||||
|
|
||||||
|
# Settings for sid
|
||||||
|
sid_client_id: Optional[str] = None
|
||||||
|
sid_client_secret: Optional[str] = None
|
||||||
|
sid_redirect_uri: Optional[str] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def kafka_consumer_group(self) -> str:
|
||||||
|
"""
|
||||||
|
Kafka consumer group will be the name of the host in development
|
||||||
|
mode, making it easier to share a dev cluster.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.environment == "development":
|
||||||
|
return platform.node()
|
||||||
|
|
||||||
|
return "platform"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def db_url(self) -> URL:
|
||||||
|
return URL.build(
|
||||||
|
scheme="mysql+aiomysql",
|
||||||
|
host=self.db_host,
|
||||||
|
port=self.db_port,
|
||||||
|
user=self.db_user,
|
||||||
|
password=self.db_pass,
|
||||||
|
path=f"/{self.db_base}",
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pusher_enabled(self) -> bool:
|
||||||
|
return all(
|
||||||
|
[
|
||||||
|
self.pusher_app_id,
|
||||||
|
self.pusher_key,
|
||||||
|
self.pusher_secret,
|
||||||
|
self.pusher_cluster,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def kafka_enabled(self) -> bool:
|
||||||
|
return all(
|
||||||
|
[
|
||||||
|
self.kafka_bootstrap_servers,
|
||||||
|
self.kafka_username,
|
||||||
|
self.kafka_password,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def helicone_enabled(self) -> bool:
|
||||||
|
return all(
|
||||||
|
[
|
||||||
|
self.helicone_api_base,
|
||||||
|
self.helicone_api_key,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sid_enabled(self) -> bool:
|
||||||
|
return all(
|
||||||
|
[
|
||||||
|
self.sid_client_id,
|
||||||
|
self.sid_client_secret,
|
||||||
|
self.sid_redirect_uri,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
env_prefix = ENV_PREFIX
|
||||||
|
env_file_encoding = "utf-8"
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
1
platform/reworkd_platform/tests/__init__.py
Normal file
1
platform/reworkd_platform/tests/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""Tests for reworkd_platform."""
|
31
platform/reworkd_platform/tests/agent/test_analysis.py
Normal file
31
platform/reworkd_platform/tests/agent/test_analysis.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from reworkd_platform.web.api.agent.analysis import Analysis
|
||||||
|
from reworkd_platform.web.api.agent.tools.tools import get_default_tool, get_tool_name
|
||||||
|
|
||||||
|
|
||||||
|
def test_analysis_model() -> None:
|
||||||
|
valid_tool_name = get_tool_name(get_default_tool())
|
||||||
|
analysis = Analysis(action=valid_tool_name, arg="arg", reasoning="reasoning")
|
||||||
|
|
||||||
|
assert analysis.action == valid_tool_name
|
||||||
|
assert analysis.arg == "arg"
|
||||||
|
assert analysis.reasoning == "reasoning"
|
||||||
|
|
||||||
|
|
||||||
|
def test_analysis_model_search_empty_arg() -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Analysis(action="search", arg="", reasoning="reasoning")
|
||||||
|
|
||||||
|
|
||||||
|
def test_analysis_model_search_non_empty_arg() -> None:
|
||||||
|
analysis = Analysis(action="search", arg="non-empty arg", reasoning="reasoning")
|
||||||
|
assert analysis.action == "search"
|
||||||
|
assert analysis.arg == "non-empty arg"
|
||||||
|
assert analysis.reasoning == "reasoning"
|
||||||
|
|
||||||
|
|
||||||
|
def test_analysis_model_invalid_tool() -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Analysis(action="invalid tool name", arg="test argument", reasoning="reasoning")
|
63
platform/reworkd_platform/tests/agent/test_crud.py
Normal file
63
platform/reworkd_platform/tests/agent/test_crud.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
from reworkd_platform.db.crud.agent import AgentCRUD
|
||||||
|
from reworkd_platform.settings import settings
|
||||||
|
from reworkd_platform.web.api.errors import MaxLoopsError, MultipleSummaryError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_task_count_no_error(mocker) -> None:
|
||||||
|
mock_agent_run_exists(mocker, True)
|
||||||
|
session = mock_session_with_run_count(mocker, 0)
|
||||||
|
agent_crud: AgentCRUD = AgentCRUD(session, mocker.MagicMock())
|
||||||
|
|
||||||
|
# Doesn't throw an exception
|
||||||
|
await agent_crud.validate_task_count("test", "summarize")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_task_count_when_run_not_found(mocker: MockerFixture) -> None:
|
||||||
|
mock_agent_run_exists(mocker, False)
|
||||||
|
agent_crud: AgentCRUD = AgentCRUD(mocker.AsyncMock(), mocker.MagicMock())
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException):
|
||||||
|
await agent_crud.validate_task_count("test", "test")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_task_count_max_loops_error(mocker: MockerFixture) -> None:
|
||||||
|
mock_agent_run_exists(mocker, True)
|
||||||
|
session = mock_session_with_run_count(mocker, settings.max_loops)
|
||||||
|
agent_crud: AgentCRUD = AgentCRUD(session, mocker.AsyncMock())
|
||||||
|
|
||||||
|
with pytest.raises(MaxLoopsError):
|
||||||
|
await agent_crud.validate_task_count("test", "test")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_task_count_multiple_summary_error(
|
||||||
|
mocker: MockerFixture,
|
||||||
|
) -> None:
|
||||||
|
mock_agent_run_exists(mocker, True)
|
||||||
|
session = mock_session_with_run_count(mocker, 2)
|
||||||
|
agent_crud: AgentCRUD = AgentCRUD(session, mocker.MagicMock())
|
||||||
|
|
||||||
|
with pytest.raises(MultipleSummaryError):
|
||||||
|
await agent_crud.validate_task_count("test", "summarize")
|
||||||
|
|
||||||
|
|
||||||
|
def mock_agent_run_exists(mocker: MockerFixture, exists: bool) -> None:
|
||||||
|
mocker.patch("reworkd_platform.db.models.agent.AgentRun.get", return_value=exists)
|
||||||
|
|
||||||
|
|
||||||
|
def mock_session_with_run_count(mocker: MockerFixture, run_count: int) -> AsyncMock:
|
||||||
|
session = mocker.AsyncMock()
|
||||||
|
scalar_mock = mocker.MagicMock()
|
||||||
|
|
||||||
|
session.execute.return_value = scalar_mock
|
||||||
|
scalar_mock.scalar_one.return_value = run_count
|
||||||
|
return session
|
138
platform/reworkd_platform/tests/agent/test_model_factory.py
Normal file
138
platform/reworkd_platform/tests/agent/test_model_factory.py
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
import itertools
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
|
||||||
|
|
||||||
|
from reworkd_platform.schemas import ModelSettings, UserBase
|
||||||
|
from reworkd_platform.settings import Settings
|
||||||
|
from reworkd_platform.web.api.agent.model_factory import (
|
||||||
|
WrappedAzureChatOpenAI,
|
||||||
|
WrappedChatOpenAI,
|
||||||
|
create_model,
|
||||||
|
get_base_and_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_helicone_enabled_without_custom_api_key():
|
||||||
|
model_settings = ModelSettings()
|
||||||
|
user = UserBase(id="user_id")
|
||||||
|
settings = Settings(
|
||||||
|
helicone_api_key="some_key",
|
||||||
|
helicone_api_base="helicone_base",
|
||||||
|
openai_api_base="openai_base",
|
||||||
|
)
|
||||||
|
|
||||||
|
base, headers, use_helicone = get_base_and_headers(settings, model_settings, user)
|
||||||
|
|
||||||
|
assert use_helicone is True
|
||||||
|
assert base == "helicone_base"
|
||||||
|
assert headers == {
|
||||||
|
"Helicone-Auth": "Bearer some_key",
|
||||||
|
"Helicone-Cache-Enabled": "true",
|
||||||
|
"Helicone-User-Id": "user_id",
|
||||||
|
"Helicone-OpenAI-Api-Base": "openai_base",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_helicone_disabled():
|
||||||
|
model_settings = ModelSettings()
|
||||||
|
user = UserBase(id="user_id")
|
||||||
|
settings = Settings()
|
||||||
|
|
||||||
|
base, headers, use_helicone = get_base_and_headers(settings, model_settings, user)
|
||||||
|
assert base == "https://api.openai.com/v1"
|
||||||
|
assert headers is None
|
||||||
|
assert use_helicone is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_helicone_enabled_with_custom_api_key():
|
||||||
|
model_settings = ModelSettings(
|
||||||
|
custom_api_key="custom_key",
|
||||||
|
)
|
||||||
|
user = UserBase(id="user_id")
|
||||||
|
settings = Settings(
|
||||||
|
openai_api_base="openai_base",
|
||||||
|
helicone_api_key="some_key",
|
||||||
|
helicone_api_base="helicone_base",
|
||||||
|
)
|
||||||
|
|
||||||
|
base, headers, use_helicone = get_base_and_headers(settings, model_settings, user)
|
||||||
|
|
||||||
|
assert base == "https://api.openai.com/v1"
|
||||||
|
assert headers is None
|
||||||
|
assert use_helicone is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"streaming, use_azure",
|
||||||
|
list(
|
||||||
|
itertools.product(
|
||||||
|
[True, False],
|
||||||
|
[True, False],
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_create_model(streaming, use_azure):
|
||||||
|
user = UserBase(id="user_id")
|
||||||
|
settings = Settings()
|
||||||
|
model_settings = ModelSettings(
|
||||||
|
temperature=0.7,
|
||||||
|
model="gpt-4o",
|
||||||
|
max_tokens=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.openai_api_base = (
|
||||||
|
"https://api.openai.com" if not use_azure else "https://oai.azure.com"
|
||||||
|
)
|
||||||
|
settings.openai_api_key = "key"
|
||||||
|
settings.openai_api_version = "version"
|
||||||
|
|
||||||
|
result = create_model(settings, model_settings, user, streaming)
|
||||||
|
assert issubclass(result.__class__, WrappedChatOpenAI)
|
||||||
|
assert issubclass(result.__class__, ChatOpenAI)
|
||||||
|
|
||||||
|
# Check if the required keys are set properly
|
||||||
|
assert result.openai_api_base == settings.openai_api_base
|
||||||
|
assert result.openai_api_key == settings.openai_api_key
|
||||||
|
assert result.temperature == model_settings.temperature
|
||||||
|
assert result.max_tokens == model_settings.max_tokens
|
||||||
|
assert result.streaming == streaming
|
||||||
|
assert result.max_retries == 5
|
||||||
|
|
||||||
|
# For Azure specific checks
|
||||||
|
if use_azure:
|
||||||
|
assert isinstance(result, WrappedAzureChatOpenAI)
|
||||||
|
assert issubclass(result.__class__, AzureChatOpenAI)
|
||||||
|
assert result.openai_api_version == settings.openai_api_version
|
||||||
|
assert result.deployment_name == "gpt-35-turbo"
|
||||||
|
assert result.openai_api_type == "azure"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model_settings, streaming",
|
||||||
|
list(
|
||||||
|
itertools.product(
|
||||||
|
[
|
||||||
|
ModelSettings(
|
||||||
|
customTemperature=0.222,
|
||||||
|
customModelName="gpt-4",
|
||||||
|
maxTokens=1234,
|
||||||
|
),
|
||||||
|
ModelSettings(),
|
||||||
|
],
|
||||||
|
[True, False],
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_custom_model_settings(model_settings: ModelSettings, streaming: bool):
|
||||||
|
model = create_model(
|
||||||
|
Settings(),
|
||||||
|
model_settings,
|
||||||
|
UserBase(id="", email="test@example.com"),
|
||||||
|
streaming=streaming,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert model.temperature == model_settings.temperature
|
||||||
|
assert model.model_name.startswith(model_settings.model)
|
||||||
|
assert model.max_tokens == model_settings.max_tokens
|
||||||
|
assert model.streaming == streaming
|
203
platform/reworkd_platform/tests/agent/test_task_output_parser.py
Normal file
203
platform/reworkd_platform/tests/agent/test_task_output_parser.py
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
from typing import List, Type
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain.schema import OutputParserException
|
||||||
|
|
||||||
|
from reworkd_platform.web.api.agent.task_output_parser import (
|
||||||
|
TaskOutputParser,
|
||||||
|
extract_array,
|
||||||
|
real_tasks_filter,
|
||||||
|
remove_prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_text,expected_output",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
'["Task 1: Do something", "Task 2: Do something else", "Task 3: Do '
|
||||||
|
'another thing"]',
|
||||||
|
["Do something", "Do something else", "Do another thing"],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
'Some random stuff ["1: Hello"]',
|
||||||
|
["Hello"],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"[]",
|
||||||
|
[],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_parse_success(input_text: str, expected_output: List[str]) -> None:
|
||||||
|
parser = TaskOutputParser(completed_tasks=[])
|
||||||
|
result = parser.parse(input_text)
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_with_completed_tasks() -> None:
|
||||||
|
input_text = '["One", "Two", "Three"]'
|
||||||
|
completed = ["One"]
|
||||||
|
expected = ["Two", "Three"]
|
||||||
|
|
||||||
|
parser = TaskOutputParser(completed_tasks=completed)
|
||||||
|
|
||||||
|
result = parser.parse(input_text)
|
||||||
|
assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_text, exception",
|
||||||
|
[
|
||||||
|
# Test cases for non-array and non-multiline string inputs
|
||||||
|
("This is not an array", OutputParserException),
|
||||||
|
("123456", OutputParserException),
|
||||||
|
("Some random text", OutputParserException),
|
||||||
|
("[abc]", OutputParserException),
|
||||||
|
# Test cases for malformed arrays
|
||||||
|
("[1, 2, 3", OutputParserException),
|
||||||
|
("'item1', 'item2']", OutputParserException),
|
||||||
|
("['item1', 'item2", OutputParserException),
|
||||||
|
# Test case for invalid multiline strings
|
||||||
|
("This is not\na valid\nmultiline string.", OutputParserException),
|
||||||
|
# Test case for multiline strings that don't start with digit + period
|
||||||
|
("Some text\nMore text\nAnd more text.", OutputParserException),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_parse_failure(input_text: str, exception: Type[Exception]) -> None:
|
||||||
|
parser = TaskOutputParser(completed_tasks=[])
|
||||||
|
with pytest.raises(exception):
|
||||||
|
parser.parse(input_text)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_str, expected",
|
||||||
|
[
|
||||||
|
# Test cases for empty array
|
||||||
|
("[]", []),
|
||||||
|
# Test cases for arrays with one element
|
||||||
|
('["One"]', ["One"]),
|
||||||
|
("['Single quote']", ["Single quote"]),
|
||||||
|
# Test cases for arrays with multiple elements
|
||||||
|
('["Research", "Develop", "Integrate"]', ["Research", "Develop", "Integrate"]),
|
||||||
|
('["Search", "Identify"]', ["Search", "Identify"]),
|
||||||
|
('["Item 1","Item 2","Item 3"]', ["Item 1", "Item 2", "Item 3"]),
|
||||||
|
# Test cases for arrays with special characters in elements
|
||||||
|
("['Single with \"quote\"']", ['Single with "quote"']),
|
||||||
|
('["Escape \\" within"]', ['Escape " within']),
|
||||||
|
# Test case for array embedded in other text
|
||||||
|
("Random stuff ['Search', 'Identify']", ["Search", "Identify"]),
|
||||||
|
# Test case for array within JSON
|
||||||
|
('{"array": ["123", "456"]}', ["123", "456"]),
|
||||||
|
# Multiline string cases
|
||||||
|
(
|
||||||
|
"1. Identify the target\n2. Conduct research\n3. Implement the methods",
|
||||||
|
[
|
||||||
|
"1. Identify the target",
|
||||||
|
"2. Conduct research",
|
||||||
|
"3. Implement the methods",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
("1. Step one.\n2. Step two.", ["1. Step one.", "2. Step two."]),
|
||||||
|
(
|
||||||
|
"""1. Review and understand the code to be debugged
|
||||||
|
2. Identify and address any errors or issues found during the review process
|
||||||
|
3. Print out debug information and setup initial variables
|
||||||
|
4. Start necessary threads and execute program logic.""",
|
||||||
|
[
|
||||||
|
"1. Review and understand the code to be debugged",
|
||||||
|
"2. Identify and address any errors or issues found during the review "
|
||||||
|
"process",
|
||||||
|
"3. Print out debug information and setup initial variables",
|
||||||
|
"4. Start necessary threads and execute program logic.",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
# Test cases with sentences before the digit + period pattern
|
||||||
|
(
|
||||||
|
"Any text before 1. Identify the task to be repeated\nUnrelated info 2. "
|
||||||
|
"Determine the frequency of the repetition\nAnother sentence 3. Create a "
|
||||||
|
"schedule or system to ensure completion of the task at the designated "
|
||||||
|
"frequency\nMore text 4. Execute the task according to the established "
|
||||||
|
"schedule or system",
|
||||||
|
[
|
||||||
|
"1. Identify the task to be repeated",
|
||||||
|
"2. Determine the frequency of the repetition",
|
||||||
|
"3. Create a schedule or system to ensure completion of the task at "
|
||||||
|
"the designated frequency",
|
||||||
|
"4. Execute the task according to the established schedule or system",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_extract_array_success(input_str: str, expected: List[str]) -> None:
|
||||||
|
print(extract_array(input_str), expected)
|
||||||
|
assert extract_array(input_str) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_str, exception",
|
||||||
|
[
|
||||||
|
(None, TypeError),
|
||||||
|
("123", RuntimeError),
|
||||||
|
("Some random text", RuntimeError),
|
||||||
|
('"single_string"', RuntimeError),
|
||||||
|
('{"test": 123}', RuntimeError),
|
||||||
|
('["Unclosed array", "other"', RuntimeError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_extract_array_exception(input_str: str, exception: Type[Exception]) -> None:
|
||||||
|
with pytest.raises(exception):
|
||||||
|
extract_array(input_str)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"task_input, expected_output",
|
||||||
|
[
|
||||||
|
("Task: This is a sample task", "This is a sample task"),
|
||||||
|
(
|
||||||
|
"Task 1: Perform a comprehensive analysis of system performance.",
|
||||||
|
"Perform a comprehensive analysis of system performance.",
|
||||||
|
),
|
||||||
|
("Task 2. Create a python script", "Create a python script"),
|
||||||
|
("5 - This is a sample task", "This is a sample task"),
|
||||||
|
("2: This is a sample task", "This is a sample task"),
|
||||||
|
(
|
||||||
|
"This is a sample task without a prefix",
|
||||||
|
"This is a sample task without a prefix",
|
||||||
|
),
|
||||||
|
("Step: This is a sample task", "This is a sample task"),
|
||||||
|
(
|
||||||
|
"Step 1: Perform a comprehensive analysis of system performance.",
|
||||||
|
"Perform a comprehensive analysis of system performance.",
|
||||||
|
),
|
||||||
|
("Step 2:Create a python script", "Create a python script"),
|
||||||
|
("Step:This is a sample task", "This is a sample task"),
|
||||||
|
(
|
||||||
|
". Conduct research on the history of Nike",
|
||||||
|
"Conduct research on the history of Nike",
|
||||||
|
),
|
||||||
|
(".This is a sample task", "This is a sample task"),
|
||||||
|
(
|
||||||
|
"1. Research the history and background of Nike company.",
|
||||||
|
"Research the history and background of Nike company.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_remove_task_prefix(task_input: str, expected_output: str) -> None:
|
||||||
|
output = remove_prefix(task_input)
|
||||||
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_text, expected_result",
|
||||||
|
[
|
||||||
|
("Write the report", True),
|
||||||
|
("No new task needed", False),
|
||||||
|
("Task completed", False),
|
||||||
|
("Do nothing", False),
|
||||||
|
("", False), # empty_string
|
||||||
|
("no new task needed", False), # case_insensitive
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_real_tasks_filter_no_task(input_text: str, expected_result: bool) -> None:
|
||||||
|
assert real_tasks_filter(input_text) == expected_result
|
56
platform/reworkd_platform/tests/agent/test_tools.py
Normal file
56
platform/reworkd_platform/tests/agent/test_tools.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from typing import List, Type
|
||||||
|
|
||||||
|
from reworkd_platform.web.api.agent.tools.conclude import Conclude
|
||||||
|
from reworkd_platform.web.api.agent.tools.image import Image
|
||||||
|
from reworkd_platform.web.api.agent.tools.reason import Reason
|
||||||
|
from reworkd_platform.web.api.agent.tools.search import Search
|
||||||
|
from reworkd_platform.web.api.agent.tools.sidsearch import SID
|
||||||
|
from reworkd_platform.web.api.agent.tools.tools import (
|
||||||
|
Tool,
|
||||||
|
format_tool_name,
|
||||||
|
get_default_tool,
|
||||||
|
get_tool_from_name,
|
||||||
|
get_tool_name,
|
||||||
|
get_tools_overview,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_tool_name() -> None:
|
||||||
|
assert get_tool_name(Image) == "image"
|
||||||
|
assert get_tool_name(Search) == "search"
|
||||||
|
assert get_tool_name(Reason) == "reason"
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_tool_name() -> None:
|
||||||
|
assert format_tool_name("Search") == "search"
|
||||||
|
assert format_tool_name("reason") == "reason"
|
||||||
|
assert format_tool_name("Conclude") == "conclude"
|
||||||
|
assert format_tool_name("CoNcLuDe") == "conclude"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_tools_overview_no_duplicates() -> None:
|
||||||
|
"""Test to assert that the tools overview doesn't include duplicates."""
|
||||||
|
tools: List[Type[Tool]] = [Image, Search, Reason, Conclude, Image, Search]
|
||||||
|
overview = get_tools_overview(tools)
|
||||||
|
|
||||||
|
# Check if each unique tool description is included in the overview
|
||||||
|
for tool in set(tools):
|
||||||
|
expected_description = f"'{get_tool_name(tool)}': {tool.description}"
|
||||||
|
assert expected_description in overview
|
||||||
|
|
||||||
|
# Check for duplicates in the overview
|
||||||
|
overview_list = overview.split("\n")
|
||||||
|
assert len(overview_list) == len(
|
||||||
|
set(overview_list)
|
||||||
|
), "Overview includes duplicate entries"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_default_tool() -> None:
|
||||||
|
assert get_default_tool() == Search
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_tool_from_name() -> None:
|
||||||
|
assert get_tool_from_name("Search") == Search
|
||||||
|
assert get_tool_from_name("CoNcLuDe") == Conclude
|
||||||
|
assert get_tool_from_name("NonExistingTool") == Search
|
||||||
|
assert get_tool_from_name("SID") == SID
|
@ -0,0 +1,43 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from reworkd_platform.web.api.memory.memory_with_fallback import MemoryWithFallback
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"method_name, args",
|
||||||
|
[
|
||||||
|
("add_tasks", (["task1", "task2"],)),
|
||||||
|
("get_similar_tasks", ("task1",)),
|
||||||
|
("reset_class", ()),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_memory_primary(mocker, method_name: str, args) -> None:
|
||||||
|
primary = mocker.Mock()
|
||||||
|
secondary = mocker.Mock()
|
||||||
|
memory_with_fallback = MemoryWithFallback(primary, secondary)
|
||||||
|
|
||||||
|
# Use getattr() to call the method on the object with args
|
||||||
|
getattr(memory_with_fallback, method_name)(*args)
|
||||||
|
getattr(primary, method_name).assert_called_once_with(*args)
|
||||||
|
getattr(secondary, method_name).assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"method_name, args",
|
||||||
|
[
|
||||||
|
("add_tasks", (["task1", "task2"],)),
|
||||||
|
("get_similar_tasks", ("task1",)),
|
||||||
|
("reset_class", ()),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_memory_fallback(mocker, method_name: str, args) -> None:
|
||||||
|
primary = mocker.Mock()
|
||||||
|
secondary = mocker.Mock()
|
||||||
|
memory_with_fallback = MemoryWithFallback(primary, secondary)
|
||||||
|
|
||||||
|
getattr(primary, method_name).side_effect = Exception("Primary Failed")
|
||||||
|
|
||||||
|
# Call the method again, this time it should fall back to secondary
|
||||||
|
getattr(memory_with_fallback, method_name)(*args)
|
||||||
|
getattr(primary, method_name).assert_called_once_with(*args)
|
||||||
|
getattr(secondary, method_name).assert_called_once_with(*args)
|
26
platform/reworkd_platform/tests/test_dependancies.py
Normal file
26
platform/reworkd_platform/tests/test_dependancies.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from reworkd_platform.web.api.agent import dependancies
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"validator, step",
|
||||||
|
[
|
||||||
|
(dependancies.agent_summarize_validator, "summarize"),
|
||||||
|
(dependancies.agent_chat_validator, "chat"),
|
||||||
|
(dependancies.agent_analyze_validator, "analyze"),
|
||||||
|
(dependancies.agent_create_validator, "create"),
|
||||||
|
(dependancies.agent_execute_validator, "execute"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_agent_validate(mocker, validator, step):
|
||||||
|
run_id = "asim"
|
||||||
|
crud = mocker.Mock()
|
||||||
|
body = mocker.Mock()
|
||||||
|
body.run_id = run_id
|
||||||
|
|
||||||
|
crud.create_task = mocker.AsyncMock()
|
||||||
|
|
||||||
|
await validator(body, crud)
|
||||||
|
crud.create_task.assert_called_once_with(run_id, step)
|
45
platform/reworkd_platform/tests/test_helpers.py
Normal file
45
platform/reworkd_platform/tests/test_helpers.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
import pytest
|
||||||
|
from openai.error import InvalidRequestError, ServiceUnavailableError
|
||||||
|
|
||||||
|
from reworkd_platform.schemas.agent import ModelSettings
|
||||||
|
from reworkd_platform.web.api.agent.helpers import openai_error_handler
|
||||||
|
from reworkd_platform.web.api.errors import OpenAIError
|
||||||
|
|
||||||
|
|
||||||
|
async def act(*args, settings: ModelSettings = ModelSettings(), **kwargs):
|
||||||
|
return await openai_error_handler(*args, settings=settings, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_service_unavailable_error():
|
||||||
|
async def mock_service_unavailable_error():
|
||||||
|
raise ServiceUnavailableError("Service Unavailable")
|
||||||
|
|
||||||
|
with pytest.raises(OpenAIError):
|
||||||
|
await act(mock_service_unavailable_error)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"settings,should_log",
|
||||||
|
[
|
||||||
|
(ModelSettings(custom_api_key="xyz"), False),
|
||||||
|
(ModelSettings(custom_api_key=None), True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_should_log(settings, should_log):
|
||||||
|
async def mock_invalid_request_error_model_access():
|
||||||
|
raise InvalidRequestError(
|
||||||
|
"The model: xyz does not exist or you do not have access to it.",
|
||||||
|
param="model",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await openai_error_handler(
|
||||||
|
mock_invalid_request_error_model_access, settings=settings
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(exc_info.value, OpenAIError)
|
||||||
|
error: OpenAIError = exc_info.value
|
||||||
|
|
||||||
|
assert error.should_log == should_log
|
15
platform/reworkd_platform/tests/test_oauth_installers.py
Normal file
15
platform/reworkd_platform/tests/test_oauth_installers.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from reworkd_platform.services.oauth_installers import installer_factory
|
||||||
|
|
||||||
|
|
||||||
|
def test_installer_factory(mocker):
|
||||||
|
crud = mocker.Mock()
|
||||||
|
installer_factory("sid", crud)
|
||||||
|
|
||||||
|
|
||||||
|
def test_integration_dne(mocker):
|
||||||
|
crud = mocker.Mock()
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
installer_factory("asim", crud)
|
18
platform/reworkd_platform/tests/test_reworkd_platform.py
Normal file
18
platform/reworkd_platform/tests/test_reworkd_platform.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from starlette import status
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Mysql needs to be mocked")
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_health(client: AsyncClient, fastapi_app: FastAPI) -> None:
|
||||||
|
"""
|
||||||
|
Checks the health endpoint.
|
||||||
|
|
||||||
|
:param client: client for the app.
|
||||||
|
:param fastapi_app: current FastAPI application.
|
||||||
|
"""
|
||||||
|
url = fastapi_app.url_path_for("health_check")
|
||||||
|
response = await client.get(url)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
21
platform/reworkd_platform/tests/test_s3.py
Normal file
21
platform/reworkd_platform/tests/test_s3.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from reworkd_platform.services.aws.s3 import SimpleStorageService
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_signed_post(mocker):
|
||||||
|
post_url = {
|
||||||
|
"url": "https://my_bucket.s3.amazonaws.com/my_object",
|
||||||
|
"fields": {"key": "value"},
|
||||||
|
}
|
||||||
|
|
||||||
|
boto3_mock = mocker.Mock()
|
||||||
|
boto3_mock.generate_presigned_post.return_value = post_url
|
||||||
|
mocker.patch(
|
||||||
|
"reworkd_platform.services.aws.s3.boto3_client", return_value=boto3_mock
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
SimpleStorageService(bucket="my_bucket").create_presigned_upload_url(
|
||||||
|
object_name="json"
|
||||||
|
)
|
||||||
|
== post_url
|
||||||
|
)
|
61
platform/reworkd_platform/tests/test_schemas.py
Normal file
61
platform/reworkd_platform/tests/test_schemas.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from reworkd_platform.schemas.agent import ModelSettings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"settings",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"max_tokens": 7000,
|
||||||
|
"temperature": 0.5,
|
||||||
|
"language": "french",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"max_tokens": 3000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": "gpt-3.5-turbo-16k",
|
||||||
|
"max_tokens": 16000,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_model_settings_valid(settings):
|
||||||
|
result = ModelSettings(**settings)
|
||||||
|
assert result.model == settings.get("model", "gpt-4o")
|
||||||
|
assert result.max_tokens == settings.get("max_tokens", 500)
|
||||||
|
assert result.temperature == settings.get("temperature", 0.9)
|
||||||
|
assert result.language == settings.get("language", "English")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"settings",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"model": "gpt-4-32k",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"temperature": -1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"max_tokens": 8000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model": "gpt-4",
|
||||||
|
"max_tokens": 32000,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_model_settings_invalid(settings):
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
ModelSettings(**settings)
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_settings_default():
|
||||||
|
settings = ModelSettings(**{})
|
||||||
|
assert settings.model == "gpt-3.5-turbo"
|
||||||
|
assert settings.temperature == 0.9
|
||||||
|
assert settings.max_tokens == 500
|
||||||
|
assert settings.language == "English"
|
29
platform/reworkd_platform/tests/test_security.py
Normal file
29
platform/reworkd_platform/tests/test_security.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
import pytest
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from reworkd_platform.services.security import EncryptionService
|
||||||
|
|
||||||
|
|
||||||
|
def test_encrypt_decrypt():
|
||||||
|
key = Fernet.generate_key()
|
||||||
|
service = EncryptionService(key)
|
||||||
|
|
||||||
|
original_text = "Hello, world!"
|
||||||
|
encrypted = service.encrypt(original_text)
|
||||||
|
decrypted = service.decrypt(encrypted)
|
||||||
|
|
||||||
|
assert original_text == decrypted
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_key():
|
||||||
|
key = Fernet.generate_key()
|
||||||
|
|
||||||
|
different_key = Fernet.generate_key()
|
||||||
|
different_service = EncryptionService(different_key)
|
||||||
|
|
||||||
|
original_text = "Hello, world!"
|
||||||
|
encrypted = Fernet(key).encrypt(original_text.encode())
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException):
|
||||||
|
different_service.decrypt(encrypted)
|
5
platform/reworkd_platform/tests/test_settings.py
Normal file
5
platform/reworkd_platform/tests/test_settings.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
from reworkd_platform.settings import Settings
|
||||||
|
|
||||||
|
|
||||||
|
def test_settings_create():
|
||||||
|
assert Settings() is not None
|
105
platform/reworkd_platform/tests/test_token_service.py
Normal file
105
platform/reworkd_platform/tests/test_token_service.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
from reworkd_platform.schemas.agent import LLM_MODEL_MAX_TOKENS
|
||||||
|
from reworkd_platform.services.tokenizer.token_service import TokenService
|
||||||
|
|
||||||
|
encoding = tiktoken.get_encoding("cl100k_base")
|
||||||
|
|
||||||
|
|
||||||
|
def test_happy_path() -> None:
|
||||||
|
service = TokenService(encoding)
|
||||||
|
text = "Hello world!"
|
||||||
|
validate_tokenize_and_detokenize(service, text, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nothing() -> None:
|
||||||
|
service = TokenService(encoding)
|
||||||
|
text = ""
|
||||||
|
validate_tokenize_and_detokenize(service, text, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_tokenize_and_detokenize(
|
||||||
|
service: TokenService, text: str, expected_token_count: int
|
||||||
|
) -> None:
|
||||||
|
tokens = service.tokenize(text)
|
||||||
|
assert text == service.detokenize(tokens)
|
||||||
|
assert len(tokens) == service.count(text)
|
||||||
|
assert len(tokens) == expected_token_count
|
||||||
|
|
||||||
|
|
||||||
|
def test_calculate_max_tokens_with_small_max_tokens() -> None:
|
||||||
|
initial_max_tokens = 3000
|
||||||
|
service = TokenService(encoding)
|
||||||
|
model = Mock(spec=["model_name", "max_tokens"])
|
||||||
|
model.model_name = "gpt-3.5-turbo"
|
||||||
|
model.max_tokens = initial_max_tokens
|
||||||
|
|
||||||
|
service.calculate_max_tokens(model, "Hello")
|
||||||
|
|
||||||
|
assert model.max_tokens == initial_max_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def test_calculate_max_tokens_with_high_completion_tokens() -> None:
|
||||||
|
service = TokenService(encoding)
|
||||||
|
prompt_tokens = service.count(LONG_TEXT)
|
||||||
|
model = Mock(spec=["model_name", "max_tokens"])
|
||||||
|
model.model_name = "gpt-3.5-turbo"
|
||||||
|
model.max_tokens = 8000
|
||||||
|
|
||||||
|
service.calculate_max_tokens(model, LONG_TEXT)
|
||||||
|
|
||||||
|
assert model.max_tokens == (
|
||||||
|
LLM_MODEL_MAX_TOKENS.get("gpt-3.5-turbo") - prompt_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_calculate_max_tokens_with_negative_result() -> None:
|
||||||
|
service = TokenService(encoding)
|
||||||
|
model = Mock(spec=["model_name", "max_tokens"])
|
||||||
|
model.model_name = "gpt-3.5-turbo"
|
||||||
|
model.max_tokens = 8000
|
||||||
|
|
||||||
|
service.calculate_max_tokens(model, *([LONG_TEXT] * 100))
|
||||||
|
|
||||||
|
# We use the minimum length of 1
|
||||||
|
assert model.max_tokens == 1
|
||||||
|
|
||||||
|
|
||||||
|
LONG_TEXT = """
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
This is some long text. This is some long text. This is some long text.
|
||||||
|
"""
|
34
platform/reworkd_platform/timer.py
Normal file
34
platform/reworkd_platform/timer.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from functools import wraps
|
||||||
|
from time import time
|
||||||
|
from typing import Any, Callable, Literal
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
Log_Level = Literal[
|
||||||
|
"TRACE",
|
||||||
|
"DEBUG",
|
||||||
|
"INFO",
|
||||||
|
"SUCCESS",
|
||||||
|
"WARNING",
|
||||||
|
"ERROR",
|
||||||
|
"CRITICAL",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def timed_function(level: Log_Level = "INFO") -> Callable[..., Any]:
|
||||||
|
def decorator(func: Any) -> Callable[..., Any]:
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
start_time = time()
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
execution_time = time() - start_time
|
||||||
|
logger.log(
|
||||||
|
level,
|
||||||
|
f"Function '{func.__qualname__}' executed in {execution_time:.4f} seconds",
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
1
platform/reworkd_platform/web/__init__.py
Normal file
1
platform/reworkd_platform/web/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""WEB API for reworkd_platform."""
|
1
platform/reworkd_platform/web/api/__init__.py
Normal file
1
platform/reworkd_platform/web/api/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""reworkd_platform API package."""
|
4
platform/reworkd_platform/web/api/agent/__init__.py
Normal file
4
platform/reworkd_platform/web/api/agent/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
"""API for checking running agents"""
|
||||||
|
from reworkd_platform.web.api.agent.views import router
|
||||||
|
|
||||||
|
__all__ = ["router"]
|
@ -0,0 +1,51 @@
|
|||||||
|
from typing import List, Optional, Protocol
|
||||||
|
|
||||||
|
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
|
||||||
|
|
||||||
|
from reworkd_platform.web.api.agent.analysis import Analysis
|
||||||
|
|
||||||
|
|
||||||
|
class AgentService(Protocol):
|
||||||
|
async def start_goal_agent(self, *, goal: str) -> List[str]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def analyze_task_agent(
|
||||||
|
self, *, goal: str, task: str, tool_names: List[str]
|
||||||
|
) -> Analysis:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def execute_task_agent(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
goal: str,
|
||||||
|
task: str,
|
||||||
|
analysis: Analysis,
|
||||||
|
) -> FastAPIStreamingResponse:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def create_tasks_agent(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
goal: str,
|
||||||
|
tasks: List[str],
|
||||||
|
last_task: str,
|
||||||
|
result: str,
|
||||||
|
completed_tasks: Optional[List[str]] = None,
|
||||||
|
) -> List[str]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def summarize_task_agent(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
goal: str,
|
||||||
|
results: List[str],
|
||||||
|
) -> FastAPIStreamingResponse:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
message: str,
|
||||||
|
results: List[str],
|
||||||
|
) -> FastAPIStreamingResponse:
|
||||||
|
pass
|
@ -0,0 +1,53 @@
|
|||||||
|
from typing import Any, Callable, Coroutine, Optional
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
|
||||||
|
from reworkd_platform.db.crud.oauth import OAuthCrud
|
||||||
|
from reworkd_platform.schemas.agent import AgentRun, LLM_Model
|
||||||
|
from reworkd_platform.schemas.user import UserBase
|
||||||
|
from reworkd_platform.services.tokenizer.dependencies import get_token_service
|
||||||
|
from reworkd_platform.services.tokenizer.token_service import TokenService
|
||||||
|
from reworkd_platform.settings import settings
|
||||||
|
from reworkd_platform.web.api.agent.agent_service.agent_service import AgentService
|
||||||
|
from reworkd_platform.web.api.agent.agent_service.mock_agent_service import (
|
||||||
|
MockAgentService,
|
||||||
|
)
|
||||||
|
from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import (
|
||||||
|
OpenAIAgentService,
|
||||||
|
)
|
||||||
|
from reworkd_platform.web.api.agent.model_factory import create_model
|
||||||
|
from reworkd_platform.web.api.dependencies import get_current_user
|
||||||
|
|
||||||
|
|
||||||
|
def get_agent_service(
|
||||||
|
validator: Callable[..., Coroutine[Any, Any, AgentRun]],
|
||||||
|
streaming: bool = False,
|
||||||
|
llm_model: Optional[LLM_Model] = None,
|
||||||
|
) -> Callable[..., AgentService]:
|
||||||
|
def func(
|
||||||
|
run: AgentRun = Depends(validator),
|
||||||
|
user: UserBase = Depends(get_current_user),
|
||||||
|
token_service: TokenService = Depends(get_token_service),
|
||||||
|
oauth_crud: OAuthCrud = Depends(OAuthCrud.inject),
|
||||||
|
) -> AgentService:
|
||||||
|
if settings.ff_mock_mode_enabled:
|
||||||
|
return MockAgentService()
|
||||||
|
|
||||||
|
model = create_model(
|
||||||
|
settings,
|
||||||
|
run.model_settings,
|
||||||
|
user,
|
||||||
|
streaming=streaming,
|
||||||
|
force_model=llm_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
return OpenAIAgentService(
|
||||||
|
model,
|
||||||
|
run.model_settings,
|
||||||
|
token_service,
|
||||||
|
callbacks=None,
|
||||||
|
user=user,
|
||||||
|
oauth_crud=oauth_crud,
|
||||||
|
)
|
||||||
|
|
||||||
|
return func
|
@ -0,0 +1,74 @@
|
|||||||
|
import time
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
|
||||||
|
|
||||||
|
from reworkd_platform.web.api.agent.agent_service.agent_service import (
|
||||||
|
AgentService,
|
||||||
|
Analysis,
|
||||||
|
)
|
||||||
|
from reworkd_platform.web.api.agent.stream_mock import stream_string
|
||||||
|
|
||||||
|
|
||||||
|
class MockAgentService(AgentService):
|
||||||
|
async def start_goal_agent(self, **kwargs: Any) -> List[str]:
|
||||||
|
time.sleep(1)
|
||||||
|
return ["Task X", "Task Y", "Task Z"]
|
||||||
|
|
||||||
|
async def create_tasks_agent(self, **kwargs: Any) -> List[str]:
|
||||||
|
time.sleep(1)
|
||||||
|
return ["Some random task that doesn't exist"]
|
||||||
|
|
||||||
|
async def analyze_task_agent(self, **kwargs: Any) -> Analysis:
|
||||||
|
time.sleep(1.5)
|
||||||
|
return Analysis(
|
||||||
|
action="reason",
|
||||||
|
arg="Mock analysis",
|
||||||
|
reasoning="Mock to avoid wasting money calling the OpenAI API.",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def execute_task_agent(self, **kwargs: Any) -> FastAPIStreamingResponse:
|
||||||
|
time.sleep(0.5)
|
||||||
|
return stream_string(
|
||||||
|
""" This is going to be a longer task result such that
|
||||||
|
We make the stream of this string take time and feel long. The reality is... this is a mock!
|
||||||
|
|
||||||
|
Lorem Ipsum is simply dummy text of the printing and typesetting industry.
|
||||||
|
Lorem Ipsum has been the industry's standard dummy text ever since the 1500s,
|
||||||
|
when an unknown printer took a galley of type and scrambled it to make a type specimen book.
|
||||||
|
It has survived not only five centuries, but also the leap into electronic typesetting, remaining unchanged.
|
||||||
|
"""
|
||||||
|
+ kwargs.get("task", "task"),
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def summarize_task_agent(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
goal: str,
|
||||||
|
results: List[str],
|
||||||
|
) -> FastAPIStreamingResponse:
|
||||||
|
time.sleep(0.5)
|
||||||
|
return stream_string(
|
||||||
|
""" This is going to be a longer task result such that
|
||||||
|
We make the stream of this string take time and feel long. The reality is... this is a mock!
|
||||||
|
|
||||||
|
Lorem Ipsum is simply dummy text of the printing and typesetting industry.
|
||||||
|
Lorem Ipsum has been the industry's standard dummy text ever since the 1500s,
|
||||||
|
when an unknown printer took a galley of type and scrambled it to make a type specimen book.
|
||||||
|
It has survived not only five centuries, but also the leap into electronic typesetting, remaining unchanged.
|
||||||
|
""",
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
message: str,
|
||||||
|
results: List[str],
|
||||||
|
) -> FastAPIStreamingResponse:
|
||||||
|
time.sleep(0.5)
|
||||||
|
return stream_string(
|
||||||
|
"What do you want dude?",
|
||||||
|
True,
|
||||||
|
)
|
@ -0,0 +1,227 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
|
||||||
|
from lanarky.responses import StreamingResponse
|
||||||
|
from langchain import LLMChain
|
||||||
|
from langchain.callbacks.base import AsyncCallbackHandler
|
||||||
|
from langchain.output_parsers import PydanticOutputParser
|
||||||
|
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
|
||||||
|
from langchain.schema import HumanMessage
|
||||||
|
from loguru import logger
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from reworkd_platform.db.crud.oauth import OAuthCrud
|
||||||
|
from reworkd_platform.schemas.agent import ModelSettings
|
||||||
|
from reworkd_platform.schemas.user import UserBase
|
||||||
|
from reworkd_platform.services.tokenizer.token_service import TokenService
|
||||||
|
from reworkd_platform.web.api.agent.agent_service.agent_service import AgentService
|
||||||
|
from reworkd_platform.web.api.agent.analysis import Analysis, AnalysisArguments
|
||||||
|
from reworkd_platform.web.api.agent.helpers import (
|
||||||
|
call_model_with_handling,
|
||||||
|
openai_error_handler,
|
||||||
|
parse_with_handling,
|
||||||
|
)
|
||||||
|
from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
|
||||||
|
from reworkd_platform.web.api.agent.prompts import (
|
||||||
|
analyze_task_prompt,
|
||||||
|
chat_prompt,
|
||||||
|
create_tasks_prompt,
|
||||||
|
start_goal_prompt,
|
||||||
|
)
|
||||||
|
from reworkd_platform.web.api.agent.task_output_parser import TaskOutputParser
|
||||||
|
from reworkd_platform.web.api.agent.tools.open_ai_function import get_tool_function
|
||||||
|
from reworkd_platform.web.api.agent.tools.tools import (
|
||||||
|
get_default_tool,
|
||||||
|
get_tool_from_name,
|
||||||
|
get_tool_name,
|
||||||
|
get_user_tools,
|
||||||
|
)
|
||||||
|
from reworkd_platform.web.api.agent.tools.utils import summarize
|
||||||
|
from reworkd_platform.web.api.errors import OpenAIError
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIAgentService(AgentService):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: WrappedChatOpenAI,
|
||||||
|
settings: ModelSettings,
|
||||||
|
token_service: TokenService,
|
||||||
|
callbacks: Optional[List[AsyncCallbackHandler]],
|
||||||
|
user: UserBase,
|
||||||
|
oauth_crud: OAuthCrud,
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.settings = settings
|
||||||
|
self.token_service = token_service
|
||||||
|
self.callbacks = callbacks
|
||||||
|
self.user = user
|
||||||
|
self.oauth_crud = oauth_crud
|
||||||
|
|
||||||
|
async def start_goal_agent(self, *, goal: str) -> List[str]:
|
||||||
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
[SystemMessagePromptTemplate(prompt=start_goal_prompt)]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.token_service.calculate_max_tokens(
|
||||||
|
self.model,
|
||||||
|
prompt.format_prompt(
|
||||||
|
goal=goal,
|
||||||
|
language=self.settings.language,
|
||||||
|
).to_string(),
|
||||||
|
)
|
||||||
|
|
||||||
|
completion = await call_model_with_handling(
|
||||||
|
self.model,
|
||||||
|
ChatPromptTemplate.from_messages(
|
||||||
|
[SystemMessagePromptTemplate(prompt=start_goal_prompt)]
|
||||||
|
),
|
||||||
|
{"goal": goal, "language": self.settings.language},
|
||||||
|
settings=self.settings,
|
||||||
|
callbacks=self.callbacks,
|
||||||
|
)
|
||||||
|
|
||||||
|
task_output_parser = TaskOutputParser(completed_tasks=[])
|
||||||
|
tasks = parse_with_handling(task_output_parser, completion)
|
||||||
|
|
||||||
|
return tasks
|
||||||
|
|
||||||
|
async def analyze_task_agent(
|
||||||
|
self, *, goal: str, task: str, tool_names: List[str]
|
||||||
|
) -> Analysis:
|
||||||
|
user_tools = await get_user_tools(tool_names, self.user, self.oauth_crud)
|
||||||
|
functions = list(map(get_tool_function, user_tools))
|
||||||
|
prompt = analyze_task_prompt.format_prompt(
|
||||||
|
goal=goal,
|
||||||
|
task=task,
|
||||||
|
language=self.settings.language,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.token_service.calculate_max_tokens(
|
||||||
|
self.model,
|
||||||
|
prompt.to_string(),
|
||||||
|
str(functions),
|
||||||
|
)
|
||||||
|
|
||||||
|
message = await openai_error_handler(
|
||||||
|
func=self.model.apredict_messages,
|
||||||
|
messages=prompt.to_messages(),
|
||||||
|
functions=functions,
|
||||||
|
settings=self.settings,
|
||||||
|
callbacks=self.callbacks,
|
||||||
|
)
|
||||||
|
|
||||||
|
function_call = message.additional_kwargs.get("function_call", {})
|
||||||
|
completion = function_call.get("arguments", "")
|
||||||
|
|
||||||
|
try:
|
||||||
|
pydantic_parser = PydanticOutputParser(pydantic_object=AnalysisArguments)
|
||||||
|
analysis_arguments = parse_with_handling(pydantic_parser, completion)
|
||||||
|
return Analysis(
|
||||||
|
action=function_call.get("name", get_tool_name(get_default_tool())),
|
||||||
|
**analysis_arguments.dict(),
|
||||||
|
)
|
||||||
|
except (OpenAIError, ValidationError):
|
||||||
|
return Analysis.get_default_analysis(task)
|
||||||
|
|
||||||
|
async def execute_task_agent(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
goal: str,
|
||||||
|
task: str,
|
||||||
|
analysis: Analysis,
|
||||||
|
) -> StreamingResponse:
|
||||||
|
# TODO: More mature way of calculating max_tokens
|
||||||
|
if self.model.max_tokens > 3000:
|
||||||
|
self.model.max_tokens = max(self.model.max_tokens - 1000, 3000)
|
||||||
|
|
||||||
|
tool_class = get_tool_from_name(analysis.action)
|
||||||
|
return await tool_class(self.model, self.settings.language).call(
|
||||||
|
goal,
|
||||||
|
task,
|
||||||
|
analysis.arg,
|
||||||
|
self.user,
|
||||||
|
self.oauth_crud,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create_tasks_agent(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
goal: str,
|
||||||
|
tasks: List[str],
|
||||||
|
last_task: str,
|
||||||
|
result: str,
|
||||||
|
completed_tasks: Optional[List[str]] = None,
|
||||||
|
) -> List[str]:
|
||||||
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
[SystemMessagePromptTemplate(prompt=create_tasks_prompt)]
|
||||||
|
)
|
||||||
|
|
||||||
|
args = {
|
||||||
|
"goal": goal,
|
||||||
|
"language": self.settings.language,
|
||||||
|
"tasks": "\n".join(tasks),
|
||||||
|
"lastTask": last_task,
|
||||||
|
"result": result,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.token_service.calculate_max_tokens(
|
||||||
|
self.model, prompt.format_prompt(**args).to_string()
|
||||||
|
)
|
||||||
|
|
||||||
|
completion = await call_model_with_handling(
|
||||||
|
self.model, prompt, args, settings=self.settings, callbacks=self.callbacks
|
||||||
|
)
|
||||||
|
|
||||||
|
previous_tasks = (completed_tasks or []) + tasks
|
||||||
|
return [completion] if completion not in previous_tasks else []
|
||||||
|
|
||||||
|
async def summarize_task_agent(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
goal: str,
|
||||||
|
results: List[str],
|
||||||
|
) -> FastAPIStreamingResponse:
|
||||||
|
self.model.model_name = "gpt-4o"
|
||||||
|
self.model.max_tokens = 8000 # Total tokens = prompt tokens + completion tokens
|
||||||
|
|
||||||
|
snippet_max_tokens = 7000 # Leave room for the rest of the prompt
|
||||||
|
text_tokens = self.token_service.tokenize("".join(results))
|
||||||
|
text = self.token_service.detokenize(text_tokens[0:snippet_max_tokens])
|
||||||
|
logger.info(f"Summarizing text: {text}")
|
||||||
|
|
||||||
|
return summarize(
|
||||||
|
model=self.model,
|
||||||
|
language=self.settings.language,
|
||||||
|
goal=goal,
|
||||||
|
text=text,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
message: str,
|
||||||
|
results: List[str],
|
||||||
|
) -> FastAPIStreamingResponse:
|
||||||
|
self.model.model_name = "gpt-4o"
|
||||||
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
[
|
||||||
|
SystemMessagePromptTemplate(prompt=chat_prompt),
|
||||||
|
*[HumanMessage(content=result) for result in results],
|
||||||
|
HumanMessage(content=message),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.token_service.calculate_max_tokens(
|
||||||
|
self.model,
|
||||||
|
prompt.format_prompt(
|
||||||
|
language=self.settings.language,
|
||||||
|
).to_string(),
|
||||||
|
)
|
||||||
|
|
||||||
|
chain = LLMChain(llm=self.model, prompt=prompt)
|
||||||
|
|
||||||
|
return StreamingResponse.from_chain(
|
||||||
|
chain,
|
||||||
|
{"language": self.settings.language},
|
||||||
|
media_type="text/event-stream",
|
||||||
|
)
|
45
platform/reworkd_platform/web/api/agent/analysis.py
Normal file
45
platform/reworkd_platform/web/api/agent/analysis.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from pydantic import BaseModel, validator
|
||||||
|
|
||||||
|
|
||||||
|
class AnalysisArguments(BaseModel):
|
||||||
|
"""
|
||||||
|
Arguments for the analysis function of a tool. OpenAI functions will resolve these values but leave out the action.
|
||||||
|
"""
|
||||||
|
|
||||||
|
reasoning: str
|
||||||
|
arg: str
|
||||||
|
|
||||||
|
|
||||||
|
class Analysis(AnalysisArguments):
|
||||||
|
action: str
|
||||||
|
|
||||||
|
@validator("action")
|
||||||
|
def action_must_be_valid_tool(cls, v: str) -> str:
|
||||||
|
# TODO: Remove circular import
|
||||||
|
from reworkd_platform.web.api.agent.tools.tools import get_available_tools_names
|
||||||
|
|
||||||
|
if v not in get_available_tools_names():
|
||||||
|
raise ValueError(f"Analysis action '{v}' is not a valid tool")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@validator("action")
|
||||||
|
def search_action_must_have_arg(cls, v: str, values: Dict[str, str]) -> str:
|
||||||
|
from reworkd_platform.web.api.agent.tools.search import Search
|
||||||
|
from reworkd_platform.web.api.agent.tools.tools import get_tool_name
|
||||||
|
|
||||||
|
if v == get_tool_name(Search) and not values["arg"]:
|
||||||
|
raise ValueError("Analysis arg cannot be empty if action is 'search'")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_default_analysis(cls, task: str) -> "Analysis":
|
||||||
|
# TODO: Remove circular import
|
||||||
|
from reworkd_platform.web.api.agent.tools.tools import get_default_tool_name
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
reasoning="Hmm... I'll try searching it up",
|
||||||
|
action=get_default_tool_name(),
|
||||||
|
arg=task,
|
||||||
|
)
|
95
platform/reworkd_platform/web/api/agent/dependancies.py
Normal file
95
platform/reworkd_platform/web/api/agent/dependancies.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
from fastapi import Body, Depends
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from reworkd_platform.db.crud.agent import AgentCRUD
|
||||||
|
from reworkd_platform.db.dependencies import get_db_session
|
||||||
|
from reworkd_platform.schemas.agent import (
|
||||||
|
AgentChat,
|
||||||
|
AgentRun,
|
||||||
|
AgentRunCreate,
|
||||||
|
AgentSummarize,
|
||||||
|
AgentTaskAnalyze,
|
||||||
|
AgentTaskCreate,
|
||||||
|
AgentTaskExecute,
|
||||||
|
Loop_Step,
|
||||||
|
)
|
||||||
|
from reworkd_platform.schemas.user import UserBase
|
||||||
|
from reworkd_platform.web.api.dependencies import get_current_user
|
||||||
|
|
||||||
|
T = TypeVar(
|
||||||
|
"T", AgentTaskAnalyze, AgentTaskExecute, AgentTaskCreate, AgentSummarize, AgentChat
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def agent_crud(
|
||||||
|
user: UserBase = Depends(get_current_user),
|
||||||
|
session: AsyncSession = Depends(get_db_session),
|
||||||
|
) -> AgentCRUD:
|
||||||
|
return AgentCRUD(session, user)
|
||||||
|
|
||||||
|
|
||||||
|
async def agent_start_validator(
|
||||||
|
body: AgentRunCreate = Body(
|
||||||
|
example={
|
||||||
|
"goal": "Create business plan for a bagel company",
|
||||||
|
"modelSettings": {
|
||||||
|
"customModelName": "gpt-4o",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
crud: AgentCRUD = Depends(agent_crud),
|
||||||
|
) -> AgentRun:
|
||||||
|
id_ = (await crud.create_run(body.goal)).id
|
||||||
|
return AgentRun(**body.dict(), run_id=str(id_))
|
||||||
|
|
||||||
|
|
||||||
|
async def validate(body: T, crud: AgentCRUD, type_: Loop_Step) -> T:
|
||||||
|
body.run_id = (await crud.create_task(body.run_id, type_)).id
|
||||||
|
return body
|
||||||
|
|
||||||
|
|
||||||
|
async def agent_analyze_validator(
|
||||||
|
body: AgentTaskAnalyze = Body(),
|
||||||
|
crud: AgentCRUD = Depends(agent_crud),
|
||||||
|
) -> AgentTaskAnalyze:
|
||||||
|
return await validate(body, crud, "analyze")
|
||||||
|
|
||||||
|
|
||||||
|
async def agent_execute_validator(
|
||||||
|
body: AgentTaskExecute = Body(
|
||||||
|
example={
|
||||||
|
"goal": "Perform tasks accurately",
|
||||||
|
"task": "Write code to make a platformer",
|
||||||
|
"analysis": {
|
||||||
|
"reasoning": "I like to write code.",
|
||||||
|
"action": "code",
|
||||||
|
"arg": "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
crud: AgentCRUD = Depends(agent_crud),
|
||||||
|
) -> AgentTaskExecute:
|
||||||
|
return await validate(body, crud, "execute")
|
||||||
|
|
||||||
|
|
||||||
|
async def agent_create_validator(
|
||||||
|
body: AgentTaskCreate = Body(),
|
||||||
|
crud: AgentCRUD = Depends(agent_crud),
|
||||||
|
) -> AgentTaskCreate:
|
||||||
|
return await validate(body, crud, "create")
|
||||||
|
|
||||||
|
|
||||||
|
async def agent_summarize_validator(
|
||||||
|
body: AgentSummarize = Body(),
|
||||||
|
crud: AgentCRUD = Depends(agent_crud),
|
||||||
|
) -> AgentSummarize:
|
||||||
|
return await validate(body, crud, "summarize")
|
||||||
|
|
||||||
|
|
||||||
|
async def agent_chat_validator(
|
||||||
|
body: AgentChat = Body(),
|
||||||
|
crud: AgentCRUD = Depends(agent_crud),
|
||||||
|
) -> AgentChat:
|
||||||
|
return await validate(body, crud, "chat")
|
76
platform/reworkd_platform/web/api/agent/helpers.py
Normal file
76
platform/reworkd_platform/web/api/agent/helpers.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
from typing import Any, Callable, Dict, TypeVar
|
||||||
|
|
||||||
|
from langchain import BasePromptTemplate, LLMChain
|
||||||
|
from langchain.chat_models.base import BaseChatModel
|
||||||
|
from langchain.schema import BaseOutputParser, OutputParserException
|
||||||
|
from openai.error import (
|
||||||
|
AuthenticationError,
|
||||||
|
InvalidRequestError,
|
||||||
|
RateLimitError,
|
||||||
|
ServiceUnavailableError,
|
||||||
|
)
|
||||||
|
|
||||||
|
from reworkd_platform.schemas.agent import ModelSettings
|
||||||
|
from reworkd_platform.web.api.errors import OpenAIError
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_with_handling(parser: BaseOutputParser[T], completion: str) -> T:
|
||||||
|
try:
|
||||||
|
return parser.parse(completion)
|
||||||
|
except OutputParserException as e:
|
||||||
|
raise OpenAIError(
|
||||||
|
e, "There was an issue parsing the response from the AI model."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def openai_error_handler(
|
||||||
|
func: Callable[..., Any], *args: Any, settings: ModelSettings, **kwargs: Any
|
||||||
|
) -> Any:
|
||||||
|
try:
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
except ServiceUnavailableError as e:
|
||||||
|
raise OpenAIError(
|
||||||
|
e,
|
||||||
|
"OpenAI is experiencing issues. Visit "
|
||||||
|
"https://status.openai.com/ for more info.",
|
||||||
|
should_log=not settings.custom_api_key,
|
||||||
|
)
|
||||||
|
except InvalidRequestError as e:
|
||||||
|
if e.user_message.startswith("The model:"):
|
||||||
|
raise OpenAIError(
|
||||||
|
e,
|
||||||
|
f"Your API key does not have access to your current model. Please use a different model.",
|
||||||
|
should_log=not settings.custom_api_key,
|
||||||
|
)
|
||||||
|
raise OpenAIError(e, e.user_message)
|
||||||
|
except AuthenticationError as e:
|
||||||
|
raise OpenAIError(
|
||||||
|
e,
|
||||||
|
"Authentication error: Ensure a valid API key is being used.",
|
||||||
|
should_log=not settings.custom_api_key,
|
||||||
|
)
|
||||||
|
except RateLimitError as e:
|
||||||
|
if e.user_message.startswith("You exceeded your current quota"):
|
||||||
|
raise OpenAIError(
|
||||||
|
e,
|
||||||
|
f"Your API key exceeded your current quota, please check your plan and billing details.",
|
||||||
|
should_log=not settings.custom_api_key,
|
||||||
|
)
|
||||||
|
raise OpenAIError(e, e.user_message)
|
||||||
|
except Exception as e:
|
||||||
|
raise OpenAIError(
|
||||||
|
e, "There was an unexpected issue getting a response from the AI model."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def call_model_with_handling(
|
||||||
|
model: BaseChatModel,
|
||||||
|
prompt: BasePromptTemplate,
|
||||||
|
args: Dict[str, str],
|
||||||
|
settings: ModelSettings,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
chain = LLMChain(llm=model, prompt=prompt)
|
||||||
|
return await openai_error_handler(chain.arun, args, settings=settings, **kwargs)
|
97
platform/reworkd_platform/web/api/agent/model_factory.py
Normal file
97
platform/reworkd_platform/web/api/agent/model_factory.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
from typing import Any, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from reworkd_platform.schemas.agent import LLM_Model, ModelSettings
|
||||||
|
from reworkd_platform.schemas.user import UserBase
|
||||||
|
from reworkd_platform.settings import Settings
|
||||||
|
|
||||||
|
|
||||||
|
class WrappedChatOpenAI(ChatOpenAI):
|
||||||
|
client: Any = Field(
|
||||||
|
default=None,
|
||||||
|
description="Meta private value but mypy will complain its missing",
|
||||||
|
)
|
||||||
|
max_tokens: int
|
||||||
|
model_name: LLM_Model = Field(alias="model")
|
||||||
|
|
||||||
|
|
||||||
|
class WrappedAzureChatOpenAI(AzureChatOpenAI, WrappedChatOpenAI):
|
||||||
|
openai_api_base: str
|
||||||
|
openai_api_version: str
|
||||||
|
deployment_name: str
|
||||||
|
|
||||||
|
|
||||||
|
WrappedChat = Union[WrappedAzureChatOpenAI, WrappedChatOpenAI]
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(
|
||||||
|
settings: Settings,
|
||||||
|
model_settings: ModelSettings,
|
||||||
|
user: UserBase,
|
||||||
|
streaming: bool = False,
|
||||||
|
force_model: Optional[LLM_Model] = None,
|
||||||
|
) -> WrappedChat:
|
||||||
|
use_azure = (
|
||||||
|
not model_settings.custom_api_key and "azure" in settings.openai_api_base
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_model = force_model or model_settings.model
|
||||||
|
model: Type[WrappedChat] = WrappedChatOpenAI
|
||||||
|
base, headers, use_helicone = get_base_and_headers(settings, model_settings, user)
|
||||||
|
kwargs = {
|
||||||
|
"openai_api_base": base,
|
||||||
|
"openai_api_key": model_settings.custom_api_key or settings.openai_api_key,
|
||||||
|
"temperature": model_settings.temperature,
|
||||||
|
"model": llm_model,
|
||||||
|
"max_tokens": model_settings.max_tokens,
|
||||||
|
"streaming": streaming,
|
||||||
|
"max_retries": 5,
|
||||||
|
"model_kwargs": {"user": user.email, "headers": headers},
|
||||||
|
}
|
||||||
|
|
||||||
|
if use_azure:
|
||||||
|
model = WrappedAzureChatOpenAI
|
||||||
|
deployment_name = llm_model.replace(".", "")
|
||||||
|
kwargs.update(
|
||||||
|
{
|
||||||
|
"openai_api_version": settings.openai_api_version,
|
||||||
|
"deployment_name": deployment_name,
|
||||||
|
"openai_api_type": "azure",
|
||||||
|
"openai_api_base": base.rstrip("v1"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_helicone:
|
||||||
|
kwargs["model"] = deployment_name
|
||||||
|
|
||||||
|
return model(**kwargs) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def get_base_and_headers(
|
||||||
|
settings_: Settings, model_settings: ModelSettings, user: UserBase
|
||||||
|
) -> Tuple[str, Optional[Dict[str, str]], bool]:
|
||||||
|
use_helicone = settings_.helicone_enabled and not model_settings.custom_api_key
|
||||||
|
base = (
|
||||||
|
settings_.helicone_api_base
|
||||||
|
if use_helicone
|
||||||
|
else (
|
||||||
|
"https://api.openai.com/v1"
|
||||||
|
if model_settings.custom_api_key
|
||||||
|
else settings_.openai_api_base
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = (
|
||||||
|
{
|
||||||
|
"Helicone-Auth": f"Bearer {settings_.helicone_api_key}",
|
||||||
|
"Helicone-Cache-Enabled": "true",
|
||||||
|
"Helicone-User-Id": user.id,
|
||||||
|
"Helicone-OpenAI-Api-Base": settings_.openai_api_base,
|
||||||
|
}
|
||||||
|
if use_helicone
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
return base, headers, use_helicone
|
185
platform/reworkd_platform/web/api/agent/prompts.py
Normal file
185
platform/reworkd_platform/web/api/agent/prompts.py
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
from langchain import PromptTemplate
|
||||||
|
|
||||||
|
# Create initial tasks using plan and solve prompting
|
||||||
|
# https://github.com/AGI-Edgerunners/Plan-and-Solve-Prompting
|
||||||
|
start_goal_prompt = PromptTemplate(
|
||||||
|
template="""You are a task creation AI called Allinix.
|
||||||
|
You answer in the "{language}" language. You have the following objective "{goal}".
|
||||||
|
Return a list of search queries that would be required to answer the entirety of the objective.
|
||||||
|
Limit the list to a maximum of 5 queries. Ensure the queries are as succinct as possible.
|
||||||
|
For simple questions use a single query.
|
||||||
|
|
||||||
|
Return the response as a JSON array of strings. Examples:
|
||||||
|
|
||||||
|
query: "Who is considered the best NBA player in the current season?", answer: ["current NBA MVP candidates"]
|
||||||
|
query: "How does the Olympicpayroll brand currently stand in the market, and what are its prospects and strategies for expansion in NJ, NY, and PA?", answer: ["Olympicpayroll brand comprehensive analysis 2023", "customer reviews of Olympicpayroll.com", "Olympicpayroll market position analysis", "payroll industry trends forecast 2023-2025", "payroll services expansion strategies in NJ, NY, PA"]
|
||||||
|
query: "How can I create a function to add weight to edges in a digraph using {language}?", answer: ["algorithm to add weight to digraph edge in {language}"]
|
||||||
|
query: "What is the current weather in New York?", answer: ["current weather in New York"]
|
||||||
|
query: "5 + 5?", answer: ["Sum of 5 and 5"]
|
||||||
|
query: "What is a good homemade recipe for KFC-style chicken?", answer: ["KFC style chicken recipe at home"]
|
||||||
|
query: "What are the nutritional values of almond milk and soy milk?", answer: ["nutritional information of almond milk", "nutritional information of soy milk"]""",
|
||||||
|
input_variables=["goal", "language"],
|
||||||
|
)
|
||||||
|
|
||||||
|
analyze_task_prompt = PromptTemplate(
|
||||||
|
template="""
|
||||||
|
High level objective: "{goal}"
|
||||||
|
Current task: "{task}"
|
||||||
|
|
||||||
|
Based on this information, use the best function to make progress or accomplish the task entirely.
|
||||||
|
Select the correct function by being smart and efficient. Ensure "reasoning" and only "reasoning" is in the
|
||||||
|
{language} language.
|
||||||
|
|
||||||
|
Note you MUST select a function.
|
||||||
|
""",
|
||||||
|
input_variables=["goal", "task", "language"],
|
||||||
|
)
|
||||||
|
|
||||||
|
code_prompt = PromptTemplate(
|
||||||
|
template="""
|
||||||
|
You are a world-class software engineer and an expert in all programing languages,
|
||||||
|
software systems, and architecture.
|
||||||
|
|
||||||
|
For reference, your high level goal is {goal}
|
||||||
|
|
||||||
|
Write code in English but explanations/comments in the "{language}" language.
|
||||||
|
|
||||||
|
Provide no information about who you are and focus on writing code.
|
||||||
|
Ensure code is bug and error free and explain complex concepts through comments
|
||||||
|
Respond in well-formatted markdown. Ensure code blocks are used for code sections.
|
||||||
|
Approach problems step by step and file by file, for each section, use a heading to describe the section.
|
||||||
|
|
||||||
|
Write code to accomplish the following:
|
||||||
|
{task}
|
||||||
|
""",
|
||||||
|
input_variables=["goal", "language", "task"],
|
||||||
|
)
|
||||||
|
|
||||||
|
execute_task_prompt = PromptTemplate(
|
||||||
|
template="""Answer in the "{language}" language. Given
|
||||||
|
the following overall objective `{goal}` and the following sub-task, `{task}`.
|
||||||
|
|
||||||
|
Perform the task by understanding the problem, extracting variables, and being smart
|
||||||
|
and efficient. Write a detailed response that address the task.
|
||||||
|
When confronted with choices, make a decision yourself with reasoning.
|
||||||
|
""",
|
||||||
|
input_variables=["goal", "language", "task"],
|
||||||
|
)
|
||||||
|
|
||||||
|
create_tasks_prompt = PromptTemplate(
|
||||||
|
template="""You are an AI task creation agent. You must answer in the "{language}"
|
||||||
|
language. You have the following objective `{goal}`.
|
||||||
|
|
||||||
|
You have the following incomplete tasks:
|
||||||
|
`{tasks}`
|
||||||
|
|
||||||
|
You just completed the following task:
|
||||||
|
`{lastTask}`
|
||||||
|
|
||||||
|
And received the following result:
|
||||||
|
`{result}`.
|
||||||
|
|
||||||
|
Based on this, create a single new task to be completed by your AI system such that your goal is closer reached.
|
||||||
|
If there are no more tasks to be done, return nothing. Do not add quotes to the task.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Search the web for NBA news
|
||||||
|
Create a function to add a new vertex with a specified weight to the digraph.
|
||||||
|
Search for any additional information on Bertie W.
|
||||||
|
""
|
||||||
|
""",
|
||||||
|
input_variables=["goal", "language", "tasks", "lastTask", "result"],
|
||||||
|
)
|
||||||
|
|
||||||
|
summarize_prompt = PromptTemplate(
|
||||||
|
template="""You must answer in the "{language}" language.
|
||||||
|
|
||||||
|
Combine the following text into a cohesive document:
|
||||||
|
|
||||||
|
"{text}"
|
||||||
|
|
||||||
|
Write using clear markdown formatting in a style expected of the goal "{goal}".
|
||||||
|
Be as clear, informative, and descriptive as necessary.
|
||||||
|
You will not make up information or add any information outside of the above text.
|
||||||
|
Only use the given information and nothing more.
|
||||||
|
|
||||||
|
If there is no information provided, say "There is nothing to summarize".
|
||||||
|
""",
|
||||||
|
input_variables=["goal", "language", "text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
company_context_prompt = PromptTemplate(
|
||||||
|
template="""You must answer in the "{language}" language.
|
||||||
|
|
||||||
|
Create a short description on "{company_name}".
|
||||||
|
Find out what sector it is in and what are their primary products.
|
||||||
|
|
||||||
|
Be as clear, informative, and descriptive as necessary.
|
||||||
|
You will not make up information or add any information outside of the above text.
|
||||||
|
Only use the given information and nothing more.
|
||||||
|
|
||||||
|
If there is no information provided, say "There is nothing to summarize".
|
||||||
|
""",
|
||||||
|
input_variables=["company_name", "language"],
|
||||||
|
)
|
||||||
|
|
||||||
|
summarize_pdf_prompt = PromptTemplate(
|
||||||
|
template="""You must answer in the "{language}" language.
|
||||||
|
|
||||||
|
For the given text: "{text}", you have the following objective "{query}".
|
||||||
|
|
||||||
|
Be as clear, informative, and descriptive as necessary.
|
||||||
|
You will not make up information or add any information outside of the above text.
|
||||||
|
Only use the given information and nothing more.
|
||||||
|
""",
|
||||||
|
input_variables=["query", "language", "text"],
|
||||||
|
)
|
||||||
|
|
||||||
|
summarize_with_sources_prompt = PromptTemplate(
|
||||||
|
template="""You must answer in the "{language}" language.
|
||||||
|
|
||||||
|
Answer the following query: "{query}" using the following information: "{snippets}".
|
||||||
|
Write using clear markdown formatting and use markdown lists where possible.
|
||||||
|
|
||||||
|
Cite sources for sentences via markdown links using the source link as the link and the index as the text.
|
||||||
|
Use in-line sources. Do not separately list sources at the end of the writing.
|
||||||
|
|
||||||
|
If the query cannot be answered with the provided information, mention this and provide a reason why along with what it does mention.
|
||||||
|
Also cite the sources of what is actually mentioned.
|
||||||
|
|
||||||
|
Example sentences of the paragraph:
|
||||||
|
"So this is a cited sentence at the end of a paragraph[1](https://test.com). This is another sentence."
|
||||||
|
"Stephen curry is an american basketball player that plays for the warriors[1](https://www.britannica.com/biography/Stephen-Curry)."
|
||||||
|
"The economic growth forecast for the region has been adjusted from 2.5% to 3.1% due to improved trade relations[1](https://economictimes.com), while inflation rates are expected to remain steady at around 1.7% according to financial analysts[2](https://financeworld.com)."
|
||||||
|
""",
|
||||||
|
input_variables=["language", "query", "snippets"],
|
||||||
|
)
|
||||||
|
|
||||||
|
summarize_sid_prompt = PromptTemplate(
|
||||||
|
template="""You must answer in the "{language}" language.
|
||||||
|
|
||||||
|
Parse and summarize the following text snippets "{snippets}".
|
||||||
|
Write using clear markdown formatting in a style expected of the goal "{goal}".
|
||||||
|
Be as clear, informative, and descriptive as necessary and attempt to
|
||||||
|
answer the query: "{query}" as best as possible.
|
||||||
|
If any of the snippets are not relevant to the query,
|
||||||
|
ignore them, and do not include them in the summary.
|
||||||
|
Do not mention that you are ignoring them.
|
||||||
|
|
||||||
|
If there is no information provided, say "There is nothing to summarize".
|
||||||
|
""",
|
||||||
|
input_variables=["goal", "language", "query", "snippets"],
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_prompt = PromptTemplate(
|
||||||
|
template="""You must answer in the "{language}" language.
|
||||||
|
|
||||||
|
You are a helpful AI Assistant that will provide responses based on the current conversation history.
|
||||||
|
|
||||||
|
The human will provide previous messages as context. Use ONLY this information for your responses.
|
||||||
|
Do not make anything up and do not add any additional information.
|
||||||
|
If you have no information for a given question in the conversation history,
|
||||||
|
say "I do not have any information on this".
|
||||||
|
""",
|
||||||
|
input_variables=["language"],
|
||||||
|
)
|
26
platform/reworkd_platform/web/api/agent/stream_mock.py
Normal file
26
platform/reworkd_platform/web/api/agent/stream_mock.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
def stream_string(data: str, delayed: bool = False) -> FastAPIStreamingResponse:
|
||||||
|
return FastAPIStreamingResponse(
|
||||||
|
stream_generator(data, delayed),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def stream_generator(data: str, delayed: bool) -> AsyncGenerator[bytes, None]:
|
||||||
|
if delayed:
|
||||||
|
encoding = tiktoken.get_encoding("cl100k_base")
|
||||||
|
token_data = encoding.encode(data)
|
||||||
|
|
||||||
|
for token in token_data:
|
||||||
|
yield encoding.decode([token]).encode("utf-8")
|
||||||
|
await asyncio.sleep(0.025) # simulate slow processing
|
||||||
|
else:
|
||||||
|
yield data.encode()
|
@ -0,0 +1,88 @@
|
|||||||
|
import ast
|
||||||
|
import re
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain.schema import BaseOutputParser, OutputParserException
|
||||||
|
|
||||||
|
|
||||||
|
class TaskOutputParser(BaseOutputParser[List[str]]):
|
||||||
|
"""
|
||||||
|
Extension of LangChain's BaseOutputParser
|
||||||
|
Responsible for parsing task creation output into a list of task strings
|
||||||
|
"""
|
||||||
|
|
||||||
|
completed_tasks: List[str] = []
|
||||||
|
|
||||||
|
def __init__(self, *, completed_tasks: List[str]):
|
||||||
|
super().__init__()
|
||||||
|
self.completed_tasks = completed_tasks
|
||||||
|
|
||||||
|
def parse(self, text: str) -> List[str]:
|
||||||
|
try:
|
||||||
|
array_str = extract_array(text)
|
||||||
|
all_tasks = [
|
||||||
|
remove_prefix(task) for task in array_str if real_tasks_filter(task)
|
||||||
|
]
|
||||||
|
return [task for task in all_tasks if task not in self.completed_tasks]
|
||||||
|
except Exception as e:
|
||||||
|
msg = f"Failed to parse tasks from completion '{text}'. Exception: {e}"
|
||||||
|
raise OutputParserException(msg)
|
||||||
|
|
||||||
|
def get_format_instructions(self) -> str:
|
||||||
|
return """
|
||||||
|
The response should be a JSON array of strings. Example:
|
||||||
|
|
||||||
|
["Search the web for NBA news", "Write some code to build a web scraper"]
|
||||||
|
|
||||||
|
This should be parsable by json.loads()
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def extract_array(input_str: str) -> List[str]:
|
||||||
|
regex = (
|
||||||
|
r"\[\s*\]|" # Empty array check
|
||||||
|
r"(\[(?:\s*(?:\"(?:[^\"\\]|\\.)*\"|\'(?:[^\'\\]|\\.)*\')\s*,?)*\s*\])"
|
||||||
|
)
|
||||||
|
match = re.search(regex, input_str)
|
||||||
|
if match is not None:
|
||||||
|
return ast.literal_eval(match[0])
|
||||||
|
else:
|
||||||
|
return handle_multiline_string(input_str)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_multiline_string(input_str: str) -> List[str]:
|
||||||
|
# Handle multiline string as a list
|
||||||
|
processed_lines = [
|
||||||
|
re.sub(r".*?(\d+\..+)", r"\1", line).strip()
|
||||||
|
for line in input_str.split("\n")
|
||||||
|
if line.strip() != ""
|
||||||
|
]
|
||||||
|
|
||||||
|
# Check if there is at least one line that starts with a digit and a period
|
||||||
|
if any(re.match(r"\d+\..+", line) for line in processed_lines):
|
||||||
|
return processed_lines
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Failed to extract array from {input_str}")
|
||||||
|
|
||||||
|
|
||||||
|
def remove_prefix(input_str: str) -> str:
|
||||||
|
prefix_pattern = (
|
||||||
|
r"^(Task\s*\d*\.\s*|Task\s*\d*[-:]?\s*|Step\s*\d*["
|
||||||
|
r"-:]?\s*|Step\s*[-:]?\s*|\d+\.\s*|\d+\s*[-:]?\s*|^\.\s*|^\.*)"
|
||||||
|
)
|
||||||
|
return re.sub(prefix_pattern, "", input_str, flags=re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
def real_tasks_filter(input_str: str) -> bool:
|
||||||
|
no_task_regex = (
|
||||||
|
r"^No( (new|further|additional|extra|other))? tasks? (is )?("
|
||||||
|
r"required|needed|added|created|inputted).*"
|
||||||
|
)
|
||||||
|
task_complete_regex = r"^Task (complete|completed|finished|done|over|success).*"
|
||||||
|
do_nothing_regex = r"^(\s*|Do nothing(\s.*)?)$"
|
||||||
|
|
||||||
|
return (
|
||||||
|
not re.search(no_task_regex, input_str, re.IGNORECASE)
|
||||||
|
and not re.search(task_complete_regex, input_str, re.IGNORECASE)
|
||||||
|
and not re.search(do_nothing_regex, input_str, re.IGNORECASE)
|
||||||
|
)
|
25
platform/reworkd_platform/web/api/agent/tools/code.py
Normal file
25
platform/reworkd_platform/web/api/agent/tools/code.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
|
||||||
|
from lanarky.responses import StreamingResponse
|
||||||
|
from langchain import LLMChain
|
||||||
|
|
||||||
|
from reworkd_platform.web.api.agent.tools.tool import Tool
|
||||||
|
|
||||||
|
|
||||||
|
class Code(Tool):
|
||||||
|
description = "Should only be used to write code, refactor code, fix code bugs, and explain programming concepts."
|
||||||
|
public_description = "Write and review code."
|
||||||
|
|
||||||
|
async def call(
|
||||||
|
self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
|
||||||
|
) -> FastAPIStreamingResponse:
|
||||||
|
from reworkd_platform.web.api.agent.prompts import code_prompt
|
||||||
|
|
||||||
|
chain = LLMChain(llm=self.model, prompt=code_prompt)
|
||||||
|
|
||||||
|
return StreamingResponse.from_chain(
|
||||||
|
chain,
|
||||||
|
{"goal": goal, "language": self.language, "task": task},
|
||||||
|
media_type="text/event-stream",
|
||||||
|
)
|
15
platform/reworkd_platform/web/api/agent/tools/conclude.py
Normal file
15
platform/reworkd_platform/web/api/agent/tools/conclude.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
|
||||||
|
|
||||||
|
from reworkd_platform.web.api.agent.stream_mock import stream_string
|
||||||
|
from reworkd_platform.web.api.agent.tools.tool import Tool
|
||||||
|
|
||||||
|
|
||||||
|
class Conclude(Tool):
|
||||||
|
description = "Use when there is nothing else to do. The task has been concluded."
|
||||||
|
|
||||||
|
async def call(
|
||||||
|
self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
|
||||||
|
) -> FastAPIStreamingResponse:
|
||||||
|
return stream_string("Task execution concluded.", delayed=True)
|
66
platform/reworkd_platform/web/api/agent/tools/image.py
Normal file
66
platform/reworkd_platform/web/api/agent/tools/image.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import replicate
|
||||||
|
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
|
||||||
|
from replicate.exceptions import ModelError
|
||||||
|
from replicate.exceptions import ReplicateError as ReplicateAPIError
|
||||||
|
|
||||||
|
from reworkd_platform.settings import settings
|
||||||
|
from reworkd_platform.web.api.agent.stream_mock import stream_string
|
||||||
|
from reworkd_platform.web.api.agent.tools.tool import Tool
|
||||||
|
from reworkd_platform.web.api.errors import ReplicateError
|
||||||
|
|
||||||
|
|
||||||
|
async def get_replicate_image(input_str: str) -> str:
|
||||||
|
if settings.replicate_api_key is None or settings.replicate_api_key == "":
|
||||||
|
raise RuntimeError("Replicate API key not set")
|
||||||
|
|
||||||
|
client = replicate.Client(settings.replicate_api_key)
|
||||||
|
try:
|
||||||
|
output = client.run(
|
||||||
|
"stability-ai/stable-diffusion"
|
||||||
|
":db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
|
||||||
|
input={"prompt": input_str},
|
||||||
|
image_dimensions="512x512",
|
||||||
|
)
|
||||||
|
except ModelError as e:
|
||||||
|
raise ReplicateError(e, "Image generation failed due to NSFW image.")
|
||||||
|
except ReplicateAPIError as e:
|
||||||
|
raise ReplicateError(e, "Failed to generate an image.")
|
||||||
|
|
||||||
|
return output[0]
|
||||||
|
|
||||||
|
|
||||||
|
# Use AI to generate an Image based on a prompt
|
||||||
|
async def get_open_ai_image(input_str: str) -> str:
|
||||||
|
response = openai.Image.create(
|
||||||
|
api_key=settings.openai_api_key,
|
||||||
|
prompt=input_str,
|
||||||
|
n=1,
|
||||||
|
size="256x256",
|
||||||
|
)
|
||||||
|
|
||||||
|
return response["data"][0]["url"]
|
||||||
|
|
||||||
|
|
||||||
|
class Image(Tool):
|
||||||
|
description = "Used to sketch, draw, or generate an image."
|
||||||
|
public_description = "Generate AI images."
|
||||||
|
arg_description = (
|
||||||
|
"The input prompt to the image generator. "
|
||||||
|
"This should be a detailed description of the image touching on image "
|
||||||
|
"style, image focus, color, etc."
|
||||||
|
)
|
||||||
|
image_url = "/tools/replicate.png"
|
||||||
|
|
||||||
|
async def call(
|
||||||
|
self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
|
||||||
|
) -> FastAPIStreamingResponse:
|
||||||
|
# Use the replicate API if its available, otherwise use DALL-E
|
||||||
|
try:
|
||||||
|
url = await get_replicate_image(input_str)
|
||||||
|
except RuntimeError:
|
||||||
|
url = await get_open_ai_image(input_str)
|
||||||
|
|
||||||
|
return stream_string(f"")
|
@ -0,0 +1,43 @@
|
|||||||
|
from typing import Type, TypedDict
|
||||||
|
|
||||||
|
from reworkd_platform.web.api.agent.tools.tool import Tool
|
||||||
|
from reworkd_platform.web.api.agent.tools.tools import get_tool_name
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionDescription(TypedDict):
|
||||||
|
"""Representation of a callable function to the OpenAI API."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
"""The name of the function."""
|
||||||
|
description: str
|
||||||
|
"""A description of the function."""
|
||||||
|
parameters: dict[str, object]
|
||||||
|
"""The parameters of the function."""
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_function(tool: Type[Tool]) -> FunctionDescription:
|
||||||
|
"""A function that will return the tool's function specification"""
|
||||||
|
name = get_tool_name(tool)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"name": name,
|
||||||
|
"description": tool.description,
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"reasoning": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
f"Reasoning is how the task will be accomplished with the current function. "
|
||||||
|
"Detail your overall plan along with any concerns you have."
|
||||||
|
"Ensure this reasoning value is in the user defined langauge "
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"arg": {
|
||||||
|
"type": "string",
|
||||||
|
"description": tool.arg_description,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["reasoning", "arg"],
|
||||||
|
},
|
||||||
|
}
|
27
platform/reworkd_platform/web/api/agent/tools/reason.py
Normal file
27
platform/reworkd_platform/web/api/agent/tools/reason.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
|
||||||
|
from lanarky.responses import StreamingResponse
|
||||||
|
from langchain import LLMChain
|
||||||
|
|
||||||
|
from reworkd_platform.web.api.agent.tools.tool import Tool
|
||||||
|
|
||||||
|
|
||||||
|
class Reason(Tool):
|
||||||
|
description = (
|
||||||
|
"Reason about task via existing information or understanding. "
|
||||||
|
"Make decisions / selections from options."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def call(
|
||||||
|
self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
|
||||||
|
) -> FastAPIStreamingResponse:
|
||||||
|
from reworkd_platform.web.api.agent.prompts import execute_task_prompt
|
||||||
|
|
||||||
|
chain = LLMChain(llm=self.model, prompt=execute_task_prompt)
|
||||||
|
|
||||||
|
return StreamingResponse.from_chain(
|
||||||
|
chain,
|
||||||
|
{"goal": goal, "language": self.language, "task": task},
|
||||||
|
media_type="text/event-stream",
|
||||||
|
)
|
109
platform/reworkd_platform/web/api/agent/tools/search.py
Normal file
109
platform/reworkd_platform/web/api/agent/tools/search.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
from typing import Any, List
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp import ClientResponseError
|
||||||
|
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from reworkd_platform.settings import settings
|
||||||
|
from reworkd_platform.web.api.agent.stream_mock import stream_string
|
||||||
|
from reworkd_platform.web.api.agent.tools.reason import Reason
|
||||||
|
from reworkd_platform.web.api.agent.tools.tool import Tool
|
||||||
|
from reworkd_platform.web.api.agent.tools.utils import (
|
||||||
|
CitedSnippet,
|
||||||
|
summarize_with_sources,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search google via serper.dev. Adapted from LangChain
|
||||||
|
# https://github.com/hwchase17/langchain/blob/master/langchain/utilities
|
||||||
|
|
||||||
|
|
||||||
|
async def _google_serper_search_results(
|
||||||
|
search_term: str, search_type: str = "search"
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
headers = {
|
||||||
|
"X-API-KEY": settings.serp_api_key or "",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
params = {
|
||||||
|
"q": search_term,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"https://google.serper.dev/{search_type}", headers=headers, params=params
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
search_results = await response.json()
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
|
||||||
|
class Search(Tool):
|
||||||
|
description = (
|
||||||
|
"Search Google for short up to date searches for simple questions about public information "
|
||||||
|
"news and people.\n"
|
||||||
|
)
|
||||||
|
public_description = "Search google for information about current events."
|
||||||
|
arg_description = "The query argument to search for. This value is always populated and cannot be an empty string."
|
||||||
|
image_url = "/tools/google.png"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def available() -> bool:
|
||||||
|
return settings.serp_api_key is not None and settings.serp_api_key != ""
|
||||||
|
|
||||||
|
async def call(
|
||||||
|
self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
|
||||||
|
) -> FastAPIStreamingResponse:
|
||||||
|
try:
|
||||||
|
return await self._call(goal, task, input_str, *args, **kwargs)
|
||||||
|
except ClientResponseError:
|
||||||
|
logger.exception("Error calling Serper API, falling back to reasoning")
|
||||||
|
return await Reason(self.model, self.language).call(
|
||||||
|
goal, task, input_str, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _call(
|
||||||
|
self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
|
||||||
|
) -> FastAPIStreamingResponse:
|
||||||
|
results = await _google_serper_search_results(
|
||||||
|
input_str,
|
||||||
|
)
|
||||||
|
|
||||||
|
k = 5 # Number of results to return
|
||||||
|
snippets: List[CitedSnippet] = []
|
||||||
|
|
||||||
|
if results.get("answerBox"):
|
||||||
|
answer_values = []
|
||||||
|
answer_box = results.get("answerBox", {})
|
||||||
|
if answer_box.get("answer"):
|
||||||
|
answer_values.append(answer_box.get("answer"))
|
||||||
|
elif answer_box.get("snippet"):
|
||||||
|
answer_values.append(answer_box.get("snippet").replace("\n", " "))
|
||||||
|
elif answer_box.get("snippetHighlighted"):
|
||||||
|
answer_values.append(", ".join(answer_box.get("snippetHighlighted")))
|
||||||
|
|
||||||
|
if len(answer_values) > 0:
|
||||||
|
snippets.append(
|
||||||
|
CitedSnippet(
|
||||||
|
len(snippets) + 1,
|
||||||
|
"\n".join(answer_values),
|
||||||
|
f"https://www.google.com/search?q={quote(input_str)}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, result in enumerate(results["organic"][:k]):
|
||||||
|
texts = []
|
||||||
|
link = ""
|
||||||
|
if "snippet" in result:
|
||||||
|
texts.append(result["snippet"])
|
||||||
|
if "link" in result:
|
||||||
|
link = result["link"]
|
||||||
|
for attribute, value in result.get("attributes", {}).items():
|
||||||
|
texts.append(f"{attribute}: {value}.")
|
||||||
|
snippets.append(CitedSnippet(len(snippets) + 1, "\n".join(texts), link))
|
||||||
|
|
||||||
|
if len(snippets) == 0:
|
||||||
|
return stream_string("No good Google Search Result was found", True)
|
||||||
|
|
||||||
|
return summarize_with_sources(self.model, self.language, goal, task, snippets)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user