mirror of https://github.com/trushildhokiya/allininx-2.git
synced 2025-03-14 18:28:41 +00:00

commit -2
parent 48da88668c
commit bdfe11f039
24 cli/.gitignore vendored Normal file
@@ -0,0 +1,24 @@
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.

# dependencies
/node_modules
/.pnp
.pnp.js

# testing
/coverage

# production
/build

# misc
.DS_Store
*.pem

# debug
npm-debug.log*
yarn-debug.log*
yarn-error.log*
.pnpm-debug.log*

.eslintcache
27 cli/README.md Normal file
@@ -0,0 +1,27 @@
## AgentGPT CLI

AgentGPT CLI is a utility designed to streamline the setup of your AgentGPT environment.
It uses Inquirer to interactively build up ENV values while validating that they are correct.

This was first created by @JPDucky on GitHub.

### Running the tool

```
# Running from the root of the project
./setup.sh
```

```
# Running from the cli directory
cd cli/
npm run start
```

### Updating ENV values

To update ENV values:

- Add a question for the ENV value to the list of questions in `index.js`
- Add a value for it in the `envDefinition`
- Add the ENV value to the `.env.example` in the root of the project (see the sketch below)
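
As an illustration, here is a minimal sketch of wiring a hypothetical `MY_NEW_KEY` value through the CLI. The names `myNewKey` and `MY_NEW_KEY` are invented for this example and are not part of the codebase:

```
import inquirer from "inquirer";

// 1. A question like the ones in the questions list (hypothetical names):
const questions = [
  {
    type: "input",
    name: "myNewKey",
    message: "Enter your MY_NEW_KEY (or press enter to skip):",
  },
];

const answers = await inquirer.prompt(questions);

// 2. The matching envDefinition entry maps the answer to the ENV value,
//    falling back to an empty quoted string like the existing keys do:
const envDefinition = {
  Backend: {
    MY_NEW_KEY: answers.myNewKey || '""',
  },
};
console.log(envDefinition);

// 3. Finally, add MY_NEW_KEY to the .env.example in the root of the project.
```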
2865 cli/package-lock.json generated Normal file
File diff suppressed because it is too large
32 cli/package.json Normal file
@@ -0,0 +1,32 @@
{
  "name": "agentgpt-cli",
  "version": "1.0.0",
  "description": "A CLI to create your AgentGPT environment",
  "private": true,
  "engines": {
    "node": ">=18.0.0 <19.0.0"
  },
  "type": "module",
  "main": "index.js",
  "scripts": {
    "start": "node src/index.js",
    "dev": "node src/index.js"
  },
  "author": "reworkd",
  "dependencies": {
    "@octokit/auth-basic": "^1.4.8",
    "@octokit/rest": "^20.0.2",
    "chalk": "^5.3.0",
    "clear": "^0.1.0",
    "clui": "^0.3.6",
    "configstore": "^6.0.0",
    "dotenv": "^16.3.1",
    "figlet": "^1.7.0",
    "inquirer": "^9.2.12",
    "lodash": "^4.17.21",
    "minimist": "^1.2.8",
    "node-fetch": "^3.3.2",
    "simple-git": "^3.20.0",
    "touch": "^3.1.0"
  }
}
142 cli/src/envGenerator.js Normal file
@@ -0,0 +1,142 @@
import crypto from "crypto";
import fs from "fs";
import chalk from "chalk";

export const generateEnv = (envValues) => {
  let isDockerCompose = envValues.runOption === "docker-compose";
  let dbPort = isDockerCompose ? 3307 : 3306;
  let platformUrl = isDockerCompose
    ? "http://host.docker.internal:8000"
    : "http://localhost:8000";

  const envDefinition = getEnvDefinition(
    envValues,
    isDockerCompose,
    dbPort,
    platformUrl
  );

  const envFileContent = generateEnvFileContent(envDefinition);
  saveEnvFile(envFileContent);
};

const getEnvDefinition = (envValues, isDockerCompose, dbPort, platformUrl) => {
  return {
    "Deployment Environment": {
      NODE_ENV: "development",
      NEXT_PUBLIC_VERCEL_ENV: "${NODE_ENV}",
    },
    NextJS: {
      NEXT_PUBLIC_BACKEND_URL: "http://localhost:8000",
      NEXT_PUBLIC_MAX_LOOPS: 100,
    },
    "Next Auth config": {
      NEXTAUTH_SECRET: generateAuthSecret(),
      NEXTAUTH_URL: "http://localhost:3000",
    },
    "Auth providers (Use if you want to get out of development mode sign-in)": {
      GOOGLE_CLIENT_ID: "***",
      GOOGLE_CLIENT_SECRET: "***",
      GITHUB_CLIENT_ID: "***",
      GITHUB_CLIENT_SECRET: "***",
      DISCORD_CLIENT_SECRET: "***",
      DISCORD_CLIENT_ID: "***",
    },
    Backend: {
      REWORKD_PLATFORM_ENVIRONMENT: "${NODE_ENV}",
      REWORKD_PLATFORM_FF_MOCK_MODE_ENABLED: false,
      REWORKD_PLATFORM_MAX_LOOPS: "${NEXT_PUBLIC_MAX_LOOPS}",
      REWORKD_PLATFORM_OPENAI_API_KEY:
        envValues.OpenAIApiKey || '"<change me>"',
      REWORKD_PLATFORM_FRONTEND_URL: "http://localhost:3000",
      REWORKD_PLATFORM_RELOAD: true,
      REWORKD_PLATFORM_OPENAI_API_BASE: "https://api.openai.com/v1",
      REWORKD_PLATFORM_SERP_API_KEY: envValues.serpApiKey || '""',
      REWORKD_PLATFORM_REPLICATE_API_KEY: envValues.replicateApiKey || '""',
    },
    "Database (Backend)": {
      REWORKD_PLATFORM_DATABASE_USER: "reworkd_platform",
      REWORKD_PLATFORM_DATABASE_PASSWORD: "reworkd_platform",
      REWORKD_PLATFORM_DATABASE_HOST: "agentgpt_db",
      REWORKD_PLATFORM_DATABASE_PORT: dbPort,
      REWORKD_PLATFORM_DATABASE_NAME: "reworkd_platform",
      REWORKD_PLATFORM_DATABASE_URL:
        "mysql://${REWORKD_PLATFORM_DATABASE_USER}:${REWORKD_PLATFORM_DATABASE_PASSWORD}@${REWORKD_PLATFORM_DATABASE_HOST}:${REWORKD_PLATFORM_DATABASE_PORT}/${REWORKD_PLATFORM_DATABASE_NAME}",
    },
    "Database (Frontend)": {
      DATABASE_USER: "reworkd_platform",
      DATABASE_PASSWORD: "reworkd_platform",
      DATABASE_HOST: "agentgpt_db",
      DATABASE_PORT: dbPort,
      DATABASE_NAME: "reworkd_platform",
      DATABASE_URL:
        "mysql://${DATABASE_USER}:${DATABASE_PASSWORD}@${DATABASE_HOST}:${DATABASE_PORT}/${DATABASE_NAME}",
    },
  };
};

const generateEnvFileContent = (config) => {
  let configFile = "";

  Object.entries(config).forEach(([section, variables]) => {
    configFile += `# ${section}:\n`;
    Object.entries(variables).forEach(([key, value]) => {
      configFile += `${key}=${value}\n`;
    });
    configFile += "\n";
  });

  return configFile.trim();
};

const generateAuthSecret = () => {
  const length = 32;
  const buffer = crypto.randomBytes(length);
  return buffer.toString("base64");
};

const ENV_PATH = "../next/.env";
const BACKEND_ENV_PATH = "../platform/.env";

export const doesEnvFileExist = () => {
  return fs.existsSync(ENV_PATH);
};

// Read the existing env file and test whether it is missing any keys
export const testEnvFile = () => {
  const data = fs.readFileSync(ENV_PATH, "utf8");

  // Make a fake definition to compare the keys of
  const envDefinition = getEnvDefinition({}, false, "", "");

  const lines = data
    .split("\n")
    .filter((line) => !line.startsWith("#") && line.trim() !== "");
  const envKeysFromFile = lines.map((line) => line.split("=")[0]);

  const envKeysFromDef = Object.entries(envDefinition).flatMap(
    ([section, entries]) => Object.keys(entries)
  );

  const missingFromFile = envKeysFromDef.filter(
    (key) => !envKeysFromFile.includes(key)
  );

  if (missingFromFile.length > 0) {
    let errorMessage = "\nYour ./next/.env is missing the following keys:\n";
    missingFromFile.forEach((key) => {
      errorMessage += chalk.whiteBright(`- ❌ ${key}\n`);
    });
    errorMessage += "\n";

    errorMessage += chalk.red(
      "We recommend deleting your .env file(s) and restarting this script."
    );
    throw new Error(errorMessage);
  }
};

export const saveEnvFile = (envFileContent) => {
  fs.writeFileSync(ENV_PATH, envFileContent);
  fs.writeFileSync(BACKEND_ENV_PATH, envFileContent);
};
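
For reference, a minimal sketch of the output shape `generateEnvFileContent` produces. The section name and keys below are illustrative, and since the function is not exported, this snippet would run inside the module itself:

```
// Illustrative input: one section with two variables.
const sample = {
  "Example section": {
    EXAMPLE_KEY: "value",
    EXAMPLE_PORT: 3306,
  },
};

// Prints:
// # Example section:
// EXAMPLE_KEY=value
// EXAMPLE_PORT=3306
console.log(generateEnvFileContent(sample));
```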
26 cli/src/helpers.js Normal file
@@ -0,0 +1,26 @@
import chalk from "chalk";
import figlet from "figlet";

export const printTitle = () => {
  console.log(
    chalk.red(
      figlet.textSync("AgentGPT", {
        horizontalLayout: "full",
        font: "ANSI Shadow",
      })
    )
  );
  console.log(
    "Welcome to the AgentGPT CLI! This CLI will generate the required .env files."
  );
  console.log(
    "Copies of the generated envs will be created in `./next/.env` and `./platform/.env`.\n"
  );
};

// Check whether an entered API key matches the expected format or is empty
export const isValidKey = (apikey, pattern) => {
  return apikey === "" || pattern.test(apikey);
};

export const validKeyErrorMessage = "\nInvalid API key. Please try again.";
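
A quick usage sketch of `isValidKey`; the `sk-` pattern below is illustrative, not the exact pattern the CLI uses:

```
import { isValidKey } from "./helpers.js";

const pattern = /^sk-[a-zA-Z0-9]+$/; // illustrative pattern

console.log(isValidKey("", pattern)); // true: an empty answer means "skip this key"
console.log(isValidKey("sk-abc123", pattern)); // true: matches the pattern
console.log(isValidKey("not-a-key", pattern)); // false: rejected
```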
60 cli/src/index.js Normal file
@@ -0,0 +1,60 @@
import inquirer from "inquirer";
import dotenv from "dotenv";
import { printTitle } from "./helpers.js";
import { doesEnvFileExist, generateEnv, testEnvFile } from "./envGenerator.js";
import { newEnvQuestions } from "./questions/newEnvQuestions.js";
import { existingEnvQuestions } from "./questions/existingEnvQuestions.js";
import { spawn } from "child_process";
import chalk from "chalk";

const handleExistingEnv = () => {
  console.log(chalk.yellow("Existing ./next/.env file found. Validating..."));

  try {
    testEnvFile();
  } catch (e) {
    console.log(e.message);
    return;
  }

  inquirer.prompt(existingEnvQuestions).then((answers) => {
    handleRunOption(answers.runOption);
  });
};

const handleNewEnv = () => {
  inquirer.prompt(newEnvQuestions).then((answers) => {
    dotenv.config({ path: "./.env" });
    generateEnv(answers);
    console.log("\nEnv files successfully created!");
    handleRunOption(answers.runOption);
  });
};

const handleRunOption = (runOption) => {
  if (runOption === "docker-compose") {
    // Hand terminal IO over to docker-compose so build output streams live.
    spawn("docker-compose", ["up", "--build"], {
      stdio: "inherit",
    });
  }

  if (runOption === "manual") {
    console.log(
      "Please go into the ./next folder and run `npm install && npm run dev`."
    );
    console.log(
      "Please also go into the ./platform folder and run `poetry install && poetry run python -m reworkd_platform`."
    );
    console.log(
      "Please use or update the MySQL database configuration in the env file(s)."
    );
  }
};

printTitle();

if (doesEnvFileExist()) {
  handleExistingEnv();
} else {
  handleNewEnv();
}
5 cli/src/questions/existingEnvQuestions.js Normal file
@@ -0,0 +1,5 @@
import { RUN_OPTION_QUESTION } from "./sharedQuestions.js";

export const existingEnvQuestions = [
  RUN_OPTION_QUESTION,
];
87 cli/src/questions/newEnvQuestions.js Normal file
@@ -0,0 +1,87 @@
import { isValidKey, validKeyErrorMessage } from "../helpers.js";
import { RUN_OPTION_QUESTION } from "./sharedQuestions.js";
import fetch from "node-fetch";

export const newEnvQuestions = [
  RUN_OPTION_QUESTION,
  {
    type: "input",
    name: "OpenAIApiKey",
    message:
      "Enter your OpenAI key (eg: sk...) or press enter to continue with no key:",
    validate: async (apikey) => {
      if (apikey === "") return true;

      if (!isValidKey(apikey, /^sk(-proj)?-[a-zA-Z0-9\-_]+$/)) {
        return validKeyErrorMessage;
      }

      // Verify the key against the live API before accepting it.
      const endpoint = "https://api.openai.com/v1/models";
      const response = await fetch(endpoint, {
        headers: {
          Authorization: `Bearer ${apikey}`,
        },
      });
      if (!response.ok) {
        return validKeyErrorMessage;
      }

      return true;
    },
  },
  {
    type: "input",
    name: "serpApiKey",
    message:
      "What is your SERP API key (https://serper.dev/)? Leave empty to disable web search.",
    validate: async (apikey) => {
      if (apikey === "") return true;

      if (!isValidKey(apikey, /^[a-zA-Z0-9]{40}$/)) {
        return validKeyErrorMessage;
      }

      // Issue a minimal test search to confirm the key works.
      const endpoint = "https://google.serper.dev/search";
      const response = await fetch(endpoint, {
        method: "POST",
        headers: {
          "X-API-KEY": apikey,
          "Content-Type": "application/json",
        },
        body: JSON.stringify({
          q: "apple inc",
        }),
      });
      if (!response.ok) {
        return validKeyErrorMessage;
      }

      return true;
    },
  },
  {
    type: "input",
    name: "replicateApiKey",
    message:
      "What is your Replicate API key (https://replicate.com/)? Leave empty to just use DALL-E for image generation.",
    validate: async (apikey) => {
      if (apikey === "") return true;

      if (!isValidKey(apikey, /^r8_[a-zA-Z0-9]{37}$/)) {
        return validKeyErrorMessage;
      }

      // Fetch a public model to confirm the token is accepted.
      const endpoint =
        "https://api.replicate.com/v1/models/replicate/hello-world";
      const response = await fetch(endpoint, {
        headers: {
          Authorization: `Token ${apikey}`,
        },
      });
      if (!response.ok) {
        return validKeyErrorMessage;
      }

      return true;
    },
  },
];
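
The `validate` functions above rely on Inquirer's documented contract: returning `true` accepts the answer, while returning a string re-prompts with that string shown as the error. A minimal sketch of the contract, with a hypothetical prompt:

```
import inquirer from "inquirer";

const { port } = await inquirer.prompt([
  {
    type: "input",
    name: "port",
    message: "Port to use:",
    // Return true to accept, or a string to display as a validation error.
    validate: (value) =>
      /^\d+$/.test(value) ? true : "Please enter a numeric port.",
  },
]);
console.log(port);
```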
10 cli/src/questions/sharedQuestions.js Normal file
@@ -0,0 +1,10 @@
export const RUN_OPTION_QUESTION = {
  type: "list",
  name: "runOption",
  choices: [
    { value: "docker-compose", name: "🐋 Docker-compose (Recommended)" },
    { value: "manual", name: "💪 Manual (Not recommended)" },
  ],
  message: "How will you be running AgentGPT?",
  default: "docker-compose",
};
115 cli/tsconfig.json Normal file
@@ -0,0 +1,115 @@
{
  "compilerOptions": {
    /* Visit https://aka.ms/tsconfig to read more about this file */

    /* Projects */
    // "incremental": true, /* Save .tsbuildinfo files to allow for incremental compilation of projects. */
    // "composite": true, /* Enable constraints that allow a TypeScript project to be used with project references. */
    // "tsBuildInfoFile": "./.tsbuildinfo", /* Specify the path to .tsbuildinfo incremental compilation file. */
    // "disableSourceOfProjectReferenceRedirect": true, /* Disable preferring source files instead of declaration files when referencing composite projects. */
    // "disableSolutionSearching": true, /* Opt a project out of multi-project reference checking when editing. */
    // "disableReferencedProjectLoad": true, /* Reduce the number of projects loaded automatically by TypeScript. */

    /* Language and Environment */
    "target": "es2016", /* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */
    // "lib": [], /* Specify a set of bundled library declaration files that describe the target runtime environment. */
    // "jsx": "preserve", /* Specify what JSX code is generated. */
    // "experimentalDecorators": true, /* Enable experimental support for legacy experimental decorators. */
    // "emitDecoratorMetadata": true, /* Emit design-type metadata for decorated declarations in source files. */
    // "jsxFactory": "", /* Specify the JSX factory function used when targeting React JSX emit, e.g. 'React.createElement' or 'h'. */
    // "jsxFragmentFactory": "", /* Specify the JSX Fragment reference used for fragments when targeting React JSX emit e.g. 'React.Fragment' or 'Fragment'. */
    // "jsxImportSource": "", /* Specify module specifier used to import the JSX factory functions when using 'jsx: react-jsx*'. */
    // "reactNamespace": "", /* Specify the object invoked for 'createElement'. This only applies when targeting 'react' JSX emit. */
    // "noLib": true, /* Disable including any library files, including the default lib.d.ts. */
    // "useDefineForClassFields": true, /* Emit ECMAScript-standard-compliant class fields. */
    // "moduleDetection": "auto", /* Control what method is used to detect module-format JS files. */

    /* Modules */
    "module": "commonjs", /* Specify what module code is generated. */
    // "rootDir": "./", /* Specify the root folder within your source files. */
    // "moduleResolution": "node10", /* Specify how TypeScript looks up a file from a given module specifier. */
    // "baseUrl": "./", /* Specify the base directory to resolve non-relative module names. */
    // "paths": {}, /* Specify a set of entries that re-map imports to additional lookup locations. */
    // "rootDirs": [], /* Allow multiple folders to be treated as one when resolving modules. */
    // "typeRoots": [], /* Specify multiple folders that act like './node_modules/@types'. */
    // "types": [], /* Specify type package names to be included without being referenced in a source file. */
    // "allowUmdGlobalAccess": true, /* Allow accessing UMD globals from modules. */
    // "moduleSuffixes": [], /* List of file name suffixes to search when resolving a module. */
    // "allowImportingTsExtensions": true, /* Allow imports to include TypeScript file extensions. Requires '--moduleResolution bundler' and either '--noEmit' or '--emitDeclarationOnly' to be set. */
    // "resolvePackageJsonExports": true, /* Use the package.json 'exports' field when resolving package imports. */
    // "resolvePackageJsonImports": true, /* Use the package.json 'imports' field when resolving imports. */
    // "customConditions": [], /* Conditions to set in addition to the resolver-specific defaults when resolving imports. */
    // "resolveJsonModule": true, /* Enable importing .json files. */
    // "allowArbitraryExtensions": true, /* Enable importing files with any extension, provided a declaration file is present. */
    // "noResolve": true, /* Disallow 'import's, 'require's or '<reference>'s from expanding the number of files TypeScript should add to a project. */

    /* JavaScript Support */
    // "allowJs": true, /* Allow JavaScript files to be a part of your program. Use the 'checkJS' option to get errors from these files. */
    // "checkJs": true, /* Enable error reporting in type-checked JavaScript files. */
    // "maxNodeModuleJsDepth": 1, /* Specify the maximum folder depth used for checking JavaScript files from 'node_modules'. Only applicable with 'allowJs'. */

    /* Emit */
    // "declaration": true, /* Generate .d.ts files from TypeScript and JavaScript files in your project. */
    // "declarationMap": true, /* Create sourcemaps for d.ts files. */
    // "emitDeclarationOnly": true, /* Only output d.ts files and not JavaScript files. */
    // "sourceMap": true, /* Create source map files for emitted JavaScript files. */
    // "inlineSourceMap": true, /* Include sourcemap files inside the emitted JavaScript. */
    // "outFile": "./", /* Specify a file that bundles all outputs into one JavaScript file. If 'declaration' is true, also designates a file that bundles all .d.ts output. */
    // "outDir": "./", /* Specify an output folder for all emitted files. */
    // "removeComments": true, /* Disable emitting comments. */
    // "noEmit": true, /* Disable emitting files from a compilation. */
    // "importHelpers": true, /* Allow importing helper functions from tslib once per project, instead of including them per-file. */
    // "importsNotUsedAsValues": "remove", /* Specify emit/checking behavior for imports that are only used for types. */
    // "downlevelIteration": true, /* Emit more compliant, but verbose and less performant JavaScript for iteration. */
    // "sourceRoot": "", /* Specify the root path for debuggers to find the reference source code. */
    // "mapRoot": "", /* Specify the location where debugger should locate map files instead of generated locations. */
    // "inlineSources": true, /* Include source code in the sourcemaps inside the emitted JavaScript. */
    // "emitBOM": true, /* Emit a UTF-8 Byte Order Mark (BOM) in the beginning of output files. */
    // "newLine": "crlf", /* Set the newline character for emitting files. */
    // "stripInternal": true, /* Disable emitting declarations that have '@internal' in their JSDoc comments. */
    // "noEmitHelpers": true, /* Disable generating custom helper functions like '__extends' in compiled output. */
    // "noEmitOnError": true, /* Disable emitting files if any type checking errors are reported. */
    // "preserveConstEnums": true, /* Disable erasing 'const enum' declarations in generated code. */
    // "declarationDir": "./", /* Specify the output directory for generated declaration files. */
    // "preserveValueImports": true, /* Preserve unused imported values in the JavaScript output that would otherwise be removed. */

    /* Interop Constraints */
    // "isolatedModules": true, /* Ensure that each file can be safely transpiled without relying on other imports. */
    // "verbatimModuleSyntax": true, /* Do not transform or elide any imports or exports not marked as type-only, ensuring they are written in the output file's format based on the 'module' setting. */
    // "allowSyntheticDefaultImports": true, /* Allow 'import x from y' when a module doesn't have a default export. */
    "esModuleInterop": true, /* Emit additional JavaScript to ease support for importing CommonJS modules. This enables 'allowSyntheticDefaultImports' for type compatibility. */
    // "preserveSymlinks": true, /* Disable resolving symlinks to their realpath. This correlates to the same flag in node. */
    "forceConsistentCasingInFileNames": true, /* Ensure that casing is correct in imports. */

    /* Type Checking */
    "strict": true, /* Enable all strict type-checking options. */
    // "noImplicitAny": true, /* Enable error reporting for expressions and declarations with an implied 'any' type. */
    // "strictNullChecks": true, /* When type checking, take into account 'null' and 'undefined'. */
    // "strictFunctionTypes": true, /* When assigning functions, check to ensure parameters and the return values are subtype-compatible. */
    // "strictBindCallApply": true, /* Check that the arguments for 'bind', 'call', and 'apply' methods match the original function. */
    // "strictPropertyInitialization": true, /* Check for class properties that are declared but not set in the constructor. */
    // "noImplicitThis": true, /* Enable error reporting when 'this' is given the type 'any'. */
    // "useUnknownInCatchVariables": true, /* Default catch clause variables as 'unknown' instead of 'any'. */
    // "alwaysStrict": true, /* Ensure 'use strict' is always emitted. */
    // "noUnusedLocals": true, /* Enable error reporting when local variables aren't read. */
    // "noUnusedParameters": true, /* Raise an error when a function parameter isn't read. */
    // "exactOptionalPropertyTypes": true, /* Interpret optional property types as written, rather than adding 'undefined'. */
    // "noImplicitReturns": true, /* Enable error reporting for codepaths that do not explicitly return in a function. */
    // "noFallthroughCasesInSwitch": true, /* Enable error reporting for fallthrough cases in switch statements. */
    // "noUncheckedIndexedAccess": true, /* Add 'undefined' to a type when accessed using an index. */
    // "noImplicitOverride": true, /* Ensure overriding members in derived classes are marked with an override modifier. */
    // "noPropertyAccessFromIndexSignature": true, /* Enforces using indexed accessors for keys declared using an indexed type. */
    // "allowUnusedLabels": true, /* Disable error reporting for unused labels. */
    // "allowUnreachableCode": true, /* Disable error reporting for unreachable code. */

    /* Completeness */
    // "skipDefaultLibCheck": true, /* Skip type checking .d.ts files that are included with TypeScript. */
    "skipLibCheck": true /* Skip type checking all .d.ts files. */
  }
}
3 db/Dockerfile Normal file
@@ -0,0 +1,3 @@
FROM mysql:8.0

ADD setup.sql /docker-entrypoint-initdb.d
11 db/setup.sql Normal file
@@ -0,0 +1,11 @@
-- Prisma requires DB creation privileges to create a shadow database (https://pris.ly/d/migrate-shadow).
-- Our user does not have this privilege by default, so we must grant it manually.

-- Create the user
CREATE USER IF NOT EXISTS 'reworkd_platform'@'%' IDENTIFIED BY 'reworkd_platform';

-- Grant the necessary permissions
GRANT CREATE, ALTER, DROP, INSERT, UPDATE, DELETE, SELECT ON *.* TO 'reworkd_platform'@'%';

-- Apply the changes
FLUSH PRIVILEGES;
145 platform/.dockerignore Normal file
@@ -0,0 +1,145 @@
### Python template

deploy/
.idea/
.vscode/
.git/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
poetry.lock
31 platform/.editorconfig Normal file
@@ -0,0 +1,31 @@
root = true

[*]
tab_width = 4
end_of_line = lf
max_line_length = 88
ij_visual_guides = 88
insert_final_newline = true
trim_trailing_whitespace = true

[*.{js,py,html}]
charset = utf-8

[*.md]
trim_trailing_whitespace = false

[*.{yml,yaml}]
indent_style = space
indent_size = 2

[Makefile]
indent_style = tab

[.flake8]
indent_style = space
indent_size = 2

[*.py]
indent_style = space
indent_size = 4
ij_python_from_import_parentheses_force_if_multiline = true
115 platform/.flake8 Normal file
@@ -0,0 +1,115 @@
[flake8]
max-complexity = 6
inline-quotes = double
max-line-length = 88
extend-ignore = E203
docstring_style = sphinx

ignore =
  ; Found `f` string
  WPS305,
  ; Missing docstring in public module
  D100,
  ; Missing docstring in magic method
  D105,
  ; Missing docstring in __init__
  D107,
  ; Found `__init__.py` module with logic
  WPS412,
  ; Found class without a base class
  WPS306,
  ; Missing docstring in public nested class
  D106,
  ; First line should be in imperative mood
  D401,
  ; Found wrong variable name
  WPS110,
  ; Found `__init__.py` module with logic
  WPS326,
  ; Found string constant over-use
  WPS226,
  ; Found upper-case constant in a class
  WPS115,
  ; Found nested function
  WPS602,
  ; Found method without arguments
  WPS605,
  ; Found overused expression
  WPS204,
  ; Found too many module members
  WPS202,
  ; Found too high module cognitive complexity
  WPS232,
  ; line break before binary operator
  W503,
  ; Found module with too many imports
  WPS201,
  ; Inline strong start-string without end-string.
  RST210,
  ; Found nested class
  WPS431,
  ; Found wrong module name
  WPS100,
  ; Found too many methods
  WPS214,
  ; Found too long ``try`` body
  WPS229,
  ; Found unpythonic getter or setter
  WPS615,
  ; Found a line that starts with a dot
  WPS348,
  ; Found complex default value (for dependency injection)
  WPS404,
  ; not perform function calls in argument defaults (for dependency injection)
  B008,
  ; Model should define verbose_name in its Meta inner class
  DJ10,
  ; Model should define verbose_name_plural in its Meta inner class
  DJ11,
  ; Found mutable module constant.
  WPS407,
  ; Found too many empty lines in `def`
  WPS473,
  ; Found missing trailing comma
  C812,

per-file-ignores =
  ; all tests
  test_*.py,tests.py,tests_*.py,*/tests/*,conftest.py:
  ; Use of assert detected
  S101,
  ; Found outer scope names shadowing
  WPS442,
  ; Found too many local variables
  WPS210,
  ; Found magic number
  WPS432,
  ; Missing parameter(s) in Docstring
  DAR101,
  ; Found too many arguments
  WPS211,

  ; all init files
  __init__.py:
  ; ignore not used imports
  F401,
  ; ignore import with wildcard
  F403,
  ; Found wrong metadata variable
  WPS410,

exclude =
  ./.cache,
  ./.git,
  ./.idea,
  ./.mypy_cache,
  ./.pytest_cache,
  ./.venv,
  ./venv,
  ./env,
  ./cached_venv,
  ./docs,
  ./deploy,
  ./var,
  ./.vscode,
  *migrations*,
144 platform/.gitignore vendored Normal file
@@ -0,0 +1,144 @@
### Python template

.idea/
.vscode/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.pem

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
ssl/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
*.sqlite3
*.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
.env
63 platform/.pre-commit-config.yaml Normal file
@@ -0,0 +1,63 @@
---
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.4.0
    hooks:
      - id: check-ast
      - id: trailing-whitespace
      - id: check-toml
      - id: end-of-file-fixer

  - repo: https://github.com/asottile/add-trailing-comma
    rev: v2.1.0
    hooks:
      - id: add-trailing-comma

  - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
    rev: v2.1.0
    hooks:
      - id: pretty-format-yaml
        args:
          - --autofix
          - --preserve-quotes
          - --indent=2

  - repo: local
    hooks:
      - id: black
        name: Format with Black
        entry: poetry run black
        language: system
        types: [python]

      - id: autoflake
        name: autoflake
        entry: poetry run autoflake
        language: system
        types: [python]
        args: [--in-place, --remove-all-unused-imports, --remove-duplicate-keys]

      - id: isort
        name: isort
        entry: poetry run isort
        language: system
        types: [python]

      - id: flake8
        name: Check with Flake8
        entry: poetry run flake8
        language: system
        pass_filenames: false
        types: [python]
        args: [--count, .]

      - id: mypy
        name: Validate types with MyPy
        entry: poetry run mypy
        language: system
        types: [python]
        pass_filenames: false
        args:
          - "reworkd_platform"
37 platform/Dockerfile Normal file
@@ -0,0 +1,37 @@
FROM python:3.11-slim-buster as prod

RUN apt-get update && apt-get install -y \
  default-libmysqlclient-dev \
  gcc \
  pkg-config \
  openjdk-11-jdk \
  build-essential \
  && rm -rf /var/lib/apt/lists/*

RUN pip install poetry==1.4.2

# Configuring poetry
RUN poetry config virtualenvs.create false

# Copying requirements of a project
COPY pyproject.toml /app/src/
WORKDIR /app/src

# Installing requirements
RUN poetry install --only main
# Removing gcc
RUN apt-get purge -y \
  g++ \
  gcc \
  pkg-config \
  && rm -rf /var/lib/apt/lists/*

# Copying actual application
COPY . /app/src/
RUN poetry install --only main

CMD ["/usr/local/bin/python", "-m", "reworkd_platform"]

FROM prod as dev

RUN poetry install
151 platform/README.md Normal file
@@ -0,0 +1,151 @@
# reworkd_platform

This project was generated using fastapi_template.

## Poetry

This project uses Poetry, a modern dependency management tool.

To run the project, use this set of commands:

```bash
poetry install
poetry run python -m reworkd_platform
```

This will start the server on the configured host.

Swagger documentation is available at `/api/docs`.

You can read more about Poetry here: https://python-poetry.org/

## Docker

You can start the project with docker using this command:

```bash
docker-compose -f deploy/docker-compose.yml --project-directory . up --build
```

If you want to develop in docker with autoreload, add `-f deploy/docker-compose.dev.yml` to your docker command.
Like this:

```bash
docker-compose -f deploy/docker-compose.yml -f deploy/docker-compose.dev.yml --project-directory . up --build
```

This command exposes the web application on port 8000, mounts the current directory, and enables autoreload.

But you have to rebuild the image every time you modify `poetry.lock` or `pyproject.toml`, with this command:

```bash
docker-compose -f deploy/docker-compose.yml --project-directory . build
```

## Project structure

```bash
$ tree "reworkd_platform"
reworkd_platform
├── conftest.py  # Fixtures for all tests.
├── db  # Module contains db configurations.
│   ├── dao  # Data Access Objects. Contains different classes to interact with the database.
│   └── models  # Package contains different models for ORMs.
├── __main__.py  # Startup script. Starts uvicorn.
├── services  # Package for different external services such as rabbit or redis etc.
├── settings.py  # Main configuration settings for the project.
├── static  # Static content.
├── tests  # Tests for the project.
└── web  # Package contains the web server. Handlers, startup config.
    ├── api  # Package with all handlers.
    │   └── router.py  # Main router.
    ├── application.py  # FastAPI application configuration.
    └── lifetime.py  # Contains actions to perform on startup and shutdown.
```

## Configuration

This application can be configured with environment variables.

You can create a `.env` file in the root directory and place all
environment variables there.

All environment variables should start with the "REWORKD_PLATFORM_" prefix.

For example, if "reworkd_platform/settings.py" has a variable named
`random_parameter`, you should provide the "REWORKD_PLATFORM_RANDOM_PARAMETER"
variable to configure the value. This behaviour can be changed by overriding the `env_prefix` property
in `reworkd_platform.settings.Settings.Config`.

An example of a .env file:

```bash
REWORKD_PLATFORM_RELOAD="True"
REWORKD_PLATFORM_PORT="8000"
REWORKD_PLATFORM_ENVIRONMENT="development"
```

You can read more about the BaseSettings class here: https://pydantic-docs.helpmanual.io/usage/settings/

## Pre-commit

To install pre-commit, simply run inside the shell:

```bash
pre-commit install
```

pre-commit is very useful for checking your code before publishing it.
It's configured using the .pre-commit-config.yaml file.

By default it runs:

* black (formats your code);
* mypy (validates types);
* isort (sorts imports in all files);
* flake8 (spots possible bugs);

You can read more about pre-commit here: https://pre-commit.com/

## Running tests

If you want to run the tests in docker, simply run:

```bash
docker-compose -f deploy/docker-compose.yml -f deploy/docker-compose.dev.yml --project-directory . run --build --rm api pytest -vv .
docker-compose -f deploy/docker-compose.yml -f deploy/docker-compose.dev.yml --project-directory . down
```

For running tests on your local machine:

1. You need to start a database.

I prefer doing it with docker:

```
docker run -p "3306:3306" -e "MYSQL_PASSWORD=reworkd_platform" -e "MYSQL_USER=reworkd_platform" -e "MYSQL_DATABASE=reworkd_platform" -e ALLOW_EMPTY_PASSWORD=yes bitnami/mysql:8.0.30
```

2. Run pytest.

```bash
pytest -vv .
```

## Running linters

```bash
# Flake
poetry run black .
poetry run autoflake --in-place --remove-duplicate-keys --remove-all-unused-imports -r .
poetry run flake8
poetry run mypy .

# Pytest
poetry run pytest -vv --cov="reworkd_platform" .

# Bump packages
poetry self add poetry-plugin-up
poetry up --latest
```
14 platform/entrypoint.sh Normal file
@@ -0,0 +1,14 @@
#!/usr/bin/env sh

host=agentgpt_db
port=3306

until echo "SELECT 1;" | nc "$host" "$port" > /dev/null 2>&1; do
  >&2 echo "Database is unavailable - Sleeping..."
  sleep 2
done

>&2 echo "Database is available! Continuing..."

# Run cmd
exec "$@"
3413 platform/poetry.lock generated Normal file
File diff suppressed because it is too large
100 platform/pyproject.toml Normal file
@@ -0,0 +1,100 @@
[tool.poetry]
name = "reworkd_platform"
version = "0.1.0"
description = ""
authors = [
    "awtkns",
    "asim-shrestha"
]

maintainers = [
    "reworkd"
]

readme = "README.md"

[tool.poetry.dependencies]
python = "^3.11"
fastapi = "^0.98.0"
boto3 = "^1.28.51"
uvicorn = { version = "^0.22.0", extras = ["standard"] }
pydantic = { version = "<2.0", extras = ["dotenv"] }
ujson = "^5.8.0"
sqlalchemy = { version = "^2.0.21", extras = ["mypy", "asyncio"] }
aiomysql = "^0.1.1"
mysqlclient = "^2.2.0"
sentry-sdk = "^1.31.0"
loguru = "^0.7.2"
aiokafka = "^0.8.1"
requests = "^2.31.0"
langchain = "^0.0.295"
openai = "^0.28.0"
wikipedia = "^1.4.0"
replicate = "^0.8.4"
lanarky = "0.7.15"
tiktoken = "^0.5.1"
grpcio = "^1.58.0"
pinecone-client = "^2.2.4"
python-multipart = "^0.0.6"
aws-secretsmanager-caching = "^1.1.1.5"
botocore = "^1.31.51"
stripe = "^5.5.0"
cryptography = "^41.0.4"
httpx = "^0.25.0"

[tool.poetry.dev-dependencies]
autopep8 = "^2.0.4"
pytest = "^7.4.2"
flake8 = "~6.0.0"
mypy = "^1.5.1"
isort = "^5.12.0"
pre-commit = "^3.4.0"
wemake-python-styleguide = "^0.18.0"
black = "^23.9.1"
autoflake = "^2.2.1"
pytest-cov = "^4.1.0"
anyio = "^3.7.1"
pytest-env = "^0.8.2"

[tool.poetry.group.dev.dependencies]
dotmap = "^1.3.30"
pytest-mock = "^3.10.0"
pytest-asyncio = "^0.21.0"
mypy = "^1.4.1"
types-requests = "^2.31.0.1"
types-pytz = "^2023.3.0.0"

[tool.isort]
profile = "black"
multi_line_output = 3
src_paths = ["reworkd_platform"]

[tool.mypy]
strict = true
ignore_missing_imports = true
allow_subclassing_any = true
allow_untyped_calls = true
pretty = true
show_error_codes = true
implicit_reexport = true
allow_untyped_decorators = true
warn_unused_ignores = false
warn_return_any = false
namespace_packages = true
exclude = "tests"

[tool.pytest.ini_options]
filterwarnings = [
    "error",
    "ignore::DeprecationWarning",
    "ignore:.*unclosed.*:ResourceWarning",
    "ignore::ImportWarning",
]
env = [
    "REWORKD_PLATFORM_DB_BASE=reworkd_platform_test",
]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
1 platform/reworkd_platform/__init__.py Normal file
@@ -0,0 +1 @@
"""reworkd_platform package."""
20 platform/reworkd_platform/__main__.py Normal file
@@ -0,0 +1,20 @@
import uvicorn

from reworkd_platform.settings import settings


def main() -> None:
    """Entrypoint of the application."""
    uvicorn.run(
        "reworkd_platform.web.application:get_app",
        workers=settings.workers_count,
        host=settings.host,
        port=settings.port,
        reload=settings.reload,
        log_level="debug",
        factory=True,
    )


if __name__ == "__main__":
    main()
107 platform/reworkd_platform/conftest.py Normal file
@@ -0,0 +1,107 @@
from typing import Any, AsyncGenerator

import pytest
from fastapi import FastAPI
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)

from reworkd_platform.db.dependencies import get_db_session
from reworkd_platform.db.utils import create_database, drop_database
from reworkd_platform.settings import settings
from reworkd_platform.web.application import get_app


@pytest.fixture(scope="session")
def anyio_backend() -> str:
    """
    Backend for anyio pytest plugin.

    :return: backend name.
    """
    return "asyncio"


@pytest.fixture(scope="session")
async def _engine() -> AsyncGenerator[AsyncEngine, None]:
    """
    Create engine and databases.

    :yield: new engine.
    """
    from reworkd_platform.db.meta import meta  # noqa: WPS433
    from reworkd_platform.db.models import load_all_models  # noqa: WPS433

    load_all_models()

    await create_database()

    engine = create_async_engine(str(settings.db_url))
    async with engine.begin() as conn:
        await conn.run_sync(meta.create_all)

    try:
        yield engine
    finally:
        await engine.dispose()
        await drop_database()


@pytest.fixture
async def dbsession(
    _engine: AsyncEngine,
) -> AsyncGenerator[AsyncSession, None]:
    """
    Get session to database.

    Fixture that returns a SQLAlchemy session with a SAVEPOINT, and the rollback to it
    after the test completes.

    :param _engine: current engine.
    :yields: async session.
    """
    connection = await _engine.connect()
    trans = await connection.begin()

    session_maker = async_sessionmaker(
        connection,
        expire_on_commit=False,
    )
    session = session_maker()

    try:
        yield session
    finally:
        await session.close()
        await trans.rollback()
        await connection.close()


@pytest.fixture
def fastapi_app(dbsession: AsyncSession) -> FastAPI:
    """
    Fixture for creating FastAPI app.

    :return: fastapi app with mocked dependencies.
    """
    application = get_app()
    application.dependency_overrides[get_db_session] = lambda: dbsession
    return application  # noqa: WPS331


@pytest.fixture
async def client(
    fastapi_app: FastAPI, anyio_backend: Any
) -> AsyncGenerator[AsyncClient, None]:
    """
    Fixture that creates client for requesting server.

    :param fastapi_app: the application.
    :yield: client for the app.
    """
    async with AsyncClient(app=fastapi_app, base_url="http://test") as ac:
        yield ac
1 platform/reworkd_platform/constants.py Normal file
@@ -0,0 +1 @@
ENV_PREFIX = "REWORKD_PLATFORM_"
0 platform/reworkd_platform/db/__init__.py Normal file
68 platform/reworkd_platform/db/base.py Normal file
@@ -0,0 +1,68 @@
import uuid
from datetime import datetime
from typing import Optional, Type, TypeVar

from sqlalchemy import DateTime, String, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from reworkd_platform.db.meta import meta
from reworkd_platform.web.api.http_responses import not_found

T = TypeVar("T", bound="Base")


class Base(DeclarativeBase):
    """Base for all models."""

    metadata = meta
    id: Mapped[str] = mapped_column(
        String,
        primary_key=True,
        default=lambda _: str(uuid.uuid4()),
        unique=True,
        nullable=False,
    )

    @classmethod
    async def get(cls: Type[T], session: AsyncSession, id_: str) -> Optional[T]:
        return await session.get(cls, id_)

    @classmethod
    async def get_or_404(cls: Type[T], session: AsyncSession, id_: str) -> T:
        if model := await cls.get(session, id_):
            return model

        raise not_found(detail=f"{cls.__name__}[{id_}] not found")

    async def save(self: T, session: AsyncSession) -> T:
        session.add(self)
        await session.flush()
        return self

    async def delete(self: T, session: AsyncSession) -> None:
        await session.delete(self)


class TrackedModel(Base):
    """Base for all tracked models."""

    __abstract__ = True

    create_date = mapped_column(
        DateTime, name="create_date", server_default=func.now(), nullable=False
    )
    update_date = mapped_column(
        DateTime, name="update_date", onupdate=func.now(), nullable=True
    )
    delete_date = mapped_column(DateTime, name="delete_date", nullable=True)

    async def delete(self, session: AsyncSession) -> None:
        """Marks the model as deleted."""
        self.delete_date = datetime.now()
        await self.save(session)


class UserMixin:
    user_id = mapped_column(String, name="user_id", nullable=False)
    organization_id = mapped_column(String, name="organization_id", nullable=True)
0 platform/reworkd_platform/db/crud/__init__.py Normal file
58 platform/reworkd_platform/db/crud/agent.py Normal file
@@ -0,0 +1,58 @@
from fastapi import HTTPException
from sqlalchemy import and_, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from reworkd_platform.db.crud.base import BaseCrud
from reworkd_platform.db.models.agent import AgentRun, AgentTask
from reworkd_platform.schemas.agent import Loop_Step
from reworkd_platform.schemas.user import UserBase
from reworkd_platform.settings import settings
from reworkd_platform.web.api.errors import MaxLoopsError, MultipleSummaryError


class AgentCRUD(BaseCrud):
    def __init__(self, session: AsyncSession, user: UserBase):
        super().__init__(session)
        self.user = user

    async def create_run(self, goal: str) -> AgentRun:
        return await AgentRun(
            user_id=self.user.id,
            goal=goal,
        ).save(self.session)

    async def create_task(self, run_id: str, type_: Loop_Step) -> AgentTask:
        await self.validate_task_count(run_id, type_)
        return await AgentTask(
            run_id=run_id,
            type_=type_,
        ).save(self.session)

    async def validate_task_count(self, run_id: str, type_: str) -> None:
        if not await AgentRun.get(self.session, run_id):
            raise HTTPException(404, f"Run {run_id} not found")

        query = select(func.count(AgentTask.id)).where(
            and_(
                AgentTask.run_id == run_id,
                AgentTask.type_ == type_,
            )
        )

        task_count = (await self.session.execute(query)).scalar_one()
        max_ = settings.max_loops

        if task_count >= max_:
            raise MaxLoopsError(
                StopIteration(),
                f"Max loops of {max_} exceeded, shutting down.",
                429,
                should_log=False,
            )

        if type_ == "summarize" and task_count > 1:
            raise MultipleSummaryError(
                StopIteration(),
                "Multiple summary tasks are not allowed",
                429,
            )
10 platform/reworkd_platform/db/crud/base.py Normal file
@@ -0,0 +1,10 @@
from typing import TypeVar

from sqlalchemy.ext.asyncio import AsyncSession

T = TypeVar("T", bound="BaseCrud")


class BaseCrud:
    def __init__(self, session: AsyncSession):
        self.session = session
77 platform/reworkd_platform/db/crud/oauth.py Normal file
@@ -0,0 +1,77 @@
|
||||
import secrets
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from reworkd_platform.db.crud.base import BaseCrud
|
||||
from reworkd_platform.db.dependencies import get_db_session
|
||||
from reworkd_platform.db.models.auth import OauthCredentials
|
||||
from reworkd_platform.schemas import UserBase
|
||||
|
||||
|
||||
class OAuthCrud(BaseCrud):
|
||||
@classmethod
|
||||
async def inject(
|
||||
cls,
|
||||
session: AsyncSession = Depends(get_db_session),
|
||||
) -> "OAuthCrud":
|
||||
return cls(session)
|
||||
|
||||
async def create_installation(
|
||||
self, user: UserBase, provider: str, redirect_uri: Optional[str]
|
||||
) -> OauthCredentials:
|
||||
return await OauthCredentials(
|
||||
user_id=user.id,
|
||||
organization_id=user.organization_id,
|
||||
provider=provider,
|
||||
state=secrets.token_hex(16),
|
||||
redirect_uri=redirect_uri,
|
||||
).save(self.session)
|
||||
|
||||
async def get_installation_by_state(self, state: str) -> Optional[OauthCredentials]:
|
||||
query = select(OauthCredentials).filter(OauthCredentials.state == state)
|
||||
|
||||
return (await self.session.execute(query)).scalar_one_or_none()
|
||||
|
||||
async def get_installation_by_user_id(
|
||||
self, user_id: str, provider: str
|
||||
) -> Optional[OauthCredentials]:
|
||||
query = select(OauthCredentials).filter(
|
||||
OauthCredentials.user_id == user_id,
|
||||
OauthCredentials.provider == provider,
|
||||
OauthCredentials.access_token_enc.isnot(None),
|
||||
)
|
||||
|
||||
return (await self.session.execute(query)).scalars().first()
|
||||
|
||||
async def get_installation_by_organization_id(
|
||||
self, organization_id: str, provider: str
|
||||
) -> Optional[OauthCredentials]:
|
||||
query = select(OauthCredentials).filter(
|
||||
OauthCredentials.organization_id == organization_id,
|
||||
OauthCredentials.provider == provider,
|
||||
OauthCredentials.access_token_enc.isnot(None),
|
||||
OauthCredentials.organization_id.isnot(None),
|
||||
)
|
||||
|
||||
return (await self.session.execute(query)).scalars().first()
|
||||
|
||||
async def get_all(self, user: UserBase) -> Dict[str, str]:
|
||||
query = (
|
||||
select(
|
||||
OauthCredentials.provider,
|
||||
func.any_value(OauthCredentials.access_token_enc),
|
||||
)
|
||||
.filter(
|
||||
OauthCredentials.access_token_enc.isnot(None),
|
||||
OauthCredentials.organization_id == user.organization_id,
|
||||
)
|
||||
.group_by(OauthCredentials.provider)
|
||||
)
|
||||
|
||||
return {
|
||||
provider: token
|
||||
for provider, token in (await self.session.execute(query)).all()
|
||||
}
|
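A sketch of resolving a user's completed installation for a provider; note the token field is still encrypted here (decryption lives in `services/security.py`, shown later in this diff):

```
# Hypothetical usage sketch -- "sid" matches the provider used later in the diff.
from typing import Optional

from reworkd_platform.db.crud.oauth import OAuthCrud

async def get_encrypted_token(session, user_id: str) -> Optional[bytes]:
    crud = OAuthCrud(session)
    creds = await crud.get_installation_by_user_id(user_id, provider="sid")
    return creds.access_token_enc if creds else None  # still encrypted
```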
98
platform/reworkd_platform/db/crud/organization.py
Normal file
@ -0,0 +1,98 @@
from typing import List, Optional

from fastapi import Depends
from pydantic import BaseModel
from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import aliased

from reworkd_platform.db.crud.base import BaseCrud
from reworkd_platform.db.dependencies import get_db_session
from reworkd_platform.db.models.auth import Organization, OrganizationUser
from reworkd_platform.db.models.user import User
from reworkd_platform.schemas import UserBase
from reworkd_platform.web.api.dependencies import get_current_user


class OrgUser(BaseModel):
    id: str
    role: str
    user: UserBase


class OrganizationUsers(BaseModel):
    id: str
    name: str
    users: List[OrgUser]


class OrganizationCrud(BaseCrud):
    def __init__(self, session: AsyncSession, user: UserBase):
        super().__init__(session)
        self.user = user

    @classmethod
    def inject(
        cls,
        session: AsyncSession = Depends(get_db_session),
        user: UserBase = Depends(get_current_user),
    ) -> "OrganizationCrud":
        return cls(session, user)

    async def create_organization(self, name: str) -> Organization:
        return await Organization(
            created_by=self.user.id,
            name=name,
        ).save(self.session)

    async def get_by_name(self, name: str) -> Optional[OrganizationUsers]:
        owner = aliased(OrganizationUser, name="owner")

        query = (
            select(
                Organization,
                User,
                OrganizationUser,
            )
            .join(
                OrganizationUser,
                and_(
                    Organization.id == OrganizationUser.organization_id,
                ),
            )
            .join(
                User,
                User.id == OrganizationUser.user_id,
            )
            .join(  # Owner
                owner,
                and_(
                    OrganizationUser.organization_id == Organization.id,
                    OrganizationUser.user_id == self.user.id,
                ),
            )
            .filter(Organization.name == name)
        )

        rows = (await self.session.execute(query)).all()
        if not rows:
            return None

        org: Organization = rows[0][0]
        return OrganizationUsers(
            id=org.id,
            name=org.name,
            users=[
                OrgUser(
                    id=org_user.user_id,
                    role=org_user.role,
                    user=UserBase(
                        id=user.id,
                        email=user.email,
                        name=user.name,
                    ),
                )
                for [_, user, org_user] in rows
            ],
        )
31
platform/reworkd_platform/db/crud/user.py
Normal file
@ -0,0 +1,31 @@
from typing import Optional

from sqlalchemy import and_, select
from sqlalchemy.orm import selectinload

from reworkd_platform.db.crud.base import BaseCrud
from reworkd_platform.db.models.auth import OrganizationUser
from reworkd_platform.db.models.user import UserSession


class UserCrud(BaseCrud):
    async def get_user_session(self, token: str) -> UserSession:
        query = (
            select(UserSession)
            .filter(UserSession.session_token == token)
            .options(selectinload(UserSession.user))
        )
        return (await self.session.execute(query)).scalar_one()

    async def get_user_organization(
        self, user_id: str, organization_id: str
    ) -> Optional[OrganizationUser]:
        query = select(OrganizationUser).filter(
            and_(
                OrganizationUser.user_id == user_id,
                OrganizationUser.organization_id == organization_id,
            )
        )

        # TODO: Only returns the first organization
        return (await self.session.execute(query)).scalar()
20
platform/reworkd_platform/db/dependencies.py
Normal file
@ -0,0 +1,20 @@
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncSession
from starlette.requests import Request


async def get_db_session(request: Request) -> AsyncGenerator[AsyncSession, None]:
    """
    Create and get database session.

    :param request: current request.
    :yield: database session.
    """
    session: AsyncSession = request.app.state.db_session_factory()

    try:  # noqa: WPS501
        yield session
        await session.commit()
    finally:
        await session.close()
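A sketch of how this dependency is typically wired into a route; the endpoint name and path below are illustrative, not part of this diff:

```
# Hypothetical route showing the session dependency in use.
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession

from reworkd_platform.db.dependencies import get_db_session

router = APIRouter()

@router.get("/health/db")
async def db_health(session: AsyncSession = Depends(get_db_session)) -> dict:
    # get_db_session commits and closes the session after the request.
    return {"ok": session.is_active}
```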
3
platform/reworkd_platform/db/meta.py
Normal file
@ -0,0 +1,3 @@
import sqlalchemy as sa

meta = sa.MetaData()
14
platform/reworkd_platform/db/models/__init__.py
Normal file
@ -0,0 +1,14 @@
"""reworkd_platform models."""
import pkgutil
from pathlib import Path


def load_all_models() -> None:
    """Load all models from this folder."""
    package_dir = Path(__file__).resolve().parent
    modules = pkgutil.walk_packages(
        path=[str(package_dir)],
        prefix="reworkd_platform.db.models.",
    )
    for module in modules:
        __import__(module.name)  # noqa: WPS421
24
platform/reworkd_platform/db/models/agent.py
Normal file
@ -0,0 +1,24 @@
from sqlalchemy import DateTime, String, Text, func
from sqlalchemy.orm import mapped_column

from reworkd_platform.db.base import Base


class AgentRun(Base):
    __tablename__ = "agent_run"

    user_id = mapped_column(String, nullable=False)
    goal = mapped_column(Text, nullable=False)
    create_date = mapped_column(
        DateTime, name="create_date", server_default=func.now(), nullable=False
    )


class AgentTask(Base):
    __tablename__ = "agent_task"

    run_id = mapped_column(String, nullable=False)
    type_ = mapped_column(String, nullable=False, name="type")
    create_date = mapped_column(
        DateTime, name="create_date", server_default=func.now(), nullable=False
    )
36
platform/reworkd_platform/db/models/auth.py
Normal file
@ -0,0 +1,36 @@
from sqlalchemy import DateTime, String
from sqlalchemy.orm import mapped_column

from reworkd_platform.db.base import TrackedModel


class Organization(TrackedModel):
    __tablename__ = "organization"

    name = mapped_column(String, nullable=False)
    created_by = mapped_column(String, nullable=False)


class OrganizationUser(TrackedModel):
    __tablename__ = "organization_user"

    user_id = mapped_column(String, nullable=False)
    organization_id = mapped_column(String, nullable=False)
    role = mapped_column(String, nullable=False, default="member")


class OauthCredentials(TrackedModel):
    __tablename__ = "oauth_credentials"

    user_id = mapped_column(String, nullable=False)
    organization_id = mapped_column(String, nullable=True)
    provider = mapped_column(String, nullable=False)
    state = mapped_column(String, nullable=False)
    redirect_uri = mapped_column(String, nullable=False)

    # Post-installation
    token_type = mapped_column(String, nullable=True)
    access_token_enc = mapped_column(String, nullable=True)
    access_token_expiration = mapped_column(DateTime, nullable=True)
    refresh_token_enc = mapped_column(String, nullable=True)
    scope = mapped_column(String, nullable=True)
38
platform/reworkd_platform/db/models/user.py
Normal file
@ -0,0 +1,38 @@
from typing import List

from sqlalchemy import DateTime, ForeignKey, Index, String, text
from sqlalchemy.orm import Mapped, mapped_column, relationship

from reworkd_platform.db.base import Base


class UserSession(Base):
    __tablename__ = "Session"

    session_token = mapped_column(String, unique=True, name="sessionToken")
    user_id = mapped_column(
        String, ForeignKey("User.id", ondelete="CASCADE"), name="userId"
    )
    expires = mapped_column(DateTime)

    user = relationship("User")

    __table_args__ = (Index("user_id"),)


class User(Base):
    __tablename__ = "User"

    name = mapped_column(String, nullable=True)
    email = mapped_column(String, nullable=True, unique=True)
    email_verified = mapped_column(DateTime, nullable=True, name="emailVerified")
    image = mapped_column(String, nullable=True)
    create_date = mapped_column(
        DateTime, server_default=text("(now())"), name="createDate"
    )

    sessions: Mapped[List["UserSession"]] = relationship(
        "UserSession", back_populates="user"
    )

    __table_args__ = (Index("email"),)
57
platform/reworkd_platform/db/utils.py
Normal file
@ -0,0 +1,57 @@
from ssl import CERT_REQUIRED

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from reworkd_platform.services.ssl import get_ssl_context
from reworkd_platform.settings import settings


def create_engine() -> AsyncEngine:
    """
    Creates SQLAlchemy engine instance.

    :return: SQLAlchemy engine instance.
    """
    if settings.environment == "development":
        return create_async_engine(
            str(settings.db_url),
            echo=settings.db_echo,
        )

    ssl_context = get_ssl_context(settings)
    ssl_context.verify_mode = CERT_REQUIRED
    connect_args = {"ssl": ssl_context}

    return create_async_engine(
        str(settings.db_url),
        echo=settings.db_echo,
        connect_args=connect_args,
    )


async def create_database() -> None:
    """Create a database."""
    engine = create_async_engine(str(settings.db_url.with_path("/mysql")))

    async with engine.connect() as conn:
        database_existance = await conn.execute(
            text(
                "SELECT 1 FROM INFORMATION_SCHEMA.SCHEMATA"  # noqa: S608
                f" WHERE SCHEMA_NAME='{settings.db_base}';",
            )
        )
        database_exists = database_existance.scalar() == 1

    if database_exists:
        await drop_database()

    async with engine.connect() as conn:  # noqa: WPS440
        await conn.execute(text(f"CREATE DATABASE {settings.db_base};"))


async def drop_database() -> None:
    """Drop current database."""
    engine = create_async_engine(str(settings.db_url.with_path("/mysql")))
    async with engine.connect() as conn:
        await conn.execute(text(f"DROP DATABASE {settings.db_base};"))
63
platform/reworkd_platform/logging.py
Normal file
@ -0,0 +1,63 @@
import logging
import sys
from typing import Union

from loguru import logger

from reworkd_platform.settings import settings


class InterceptHandler(logging.Handler):
    """
    Default handler from examples in loguru documentation.

    This handler intercepts all log requests and
    passes them to loguru.

    For more info see:
    https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging
    """

    def emit(self, record: logging.LogRecord) -> None:  # pragma: no cover
        """
        Propagates logs to loguru.

        :param record: record to log.
        """
        try:
            level: Union[str, int] = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno

        # Find the caller from which the logged message originated
        frame, depth = logging.currentframe(), 2
        while frame.f_code.co_filename == logging.__file__:
            frame = frame.f_back  # type: ignore
            depth += 1

        logger.opt(depth=depth, exception=record.exc_info).log(
            level,
            record.getMessage(),
        )


def configure_logging() -> None:  # pragma: no cover
    """Configures logging."""
    intercept_handler = InterceptHandler()

    logging.basicConfig(handlers=[intercept_handler], level=logging.NOTSET)

    for logger_name in logging.root.manager.loggerDict:
        if logger_name.startswith("uvicorn."):
            logging.getLogger(logger_name).handlers = []

    # Change the handler for the default uvicorn loggers
    logging.getLogger("uvicorn").handlers = [intercept_handler]
    logging.getLogger("uvicorn.access").handlers = [intercept_handler]

    # Set log output, level, and format
    logger.remove()
    logger.add(
        sys.stdout,
        level=settings.log_level,
    )
2
platform/reworkd_platform/schemas/__init__.py
Normal file
@ -0,0 +1,2 @@
from .agent import ModelSettings
from .user import UserBase
87
platform/reworkd_platform/schemas/agent.py
Normal file
@ -0,0 +1,87 @@
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional

from pydantic import BaseModel, Field, validator

from reworkd_platform.web.api.agent.analysis import Analysis

LLM_Model = Literal[
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-16k",
    "gpt-4o",
]
Loop_Step = Literal[
    "start",
    "analyze",
    "execute",
    "create",
    "summarize",
    "chat",
]
LLM_MODEL_MAX_TOKENS: Dict[LLM_Model, int] = {
    "gpt-3.5-turbo": 4000,
    "gpt-3.5-turbo-16k": 16000,
    "gpt-4o": 8000,
}


class ModelSettings(BaseModel):
    model: LLM_Model = Field(default="gpt-4o")
    custom_api_key: Optional[str] = Field(default=None)
    temperature: float = Field(default=0.9, ge=0.0, le=1.0)
    max_tokens: int = Field(default=500, ge=0)
    language: str = Field(default="English")

    @validator("max_tokens")
    def validate_max_tokens(cls, v: float, values: Dict[str, Any]) -> float:
        model = values["model"]
        if v > (max_tokens := LLM_MODEL_MAX_TOKENS[model]):
            raise ValueError(f"Model {model} only supports {max_tokens} tokens")
        return v


class AgentRunCreate(BaseModel):
    goal: str
    model_settings: ModelSettings = Field(default=ModelSettings())


class AgentRun(AgentRunCreate):
    run_id: str


class AgentTaskAnalyze(AgentRun):
    task: str
    tool_names: List[str] = Field(default=[])
    model_settings: ModelSettings = Field(default=ModelSettings())


class AgentTaskExecute(AgentRun):
    task: str
    analysis: Analysis


class AgentTaskCreate(AgentRun):
    tasks: List[str] = Field(default=[])
    last_task: Optional[str] = Field(default=None)
    result: Optional[str] = Field(default=None)
    completed_tasks: List[str] = Field(default=[])


class AgentSummarize(AgentRun):
    results: List[str] = Field(default=[])


class AgentChat(AgentRun):
    message: str
    results: List[str] = Field(default=[])


class NewTasksResponse(BaseModel):
    run_id: str
    new_tasks: List[str] = Field(alias="newTasks")


class RunCount(BaseModel):
    count: int
    first_run: Optional[datetime]
    last_run: Optional[datetime]
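For example, the `max_tokens` validator rejects requests that exceed the per-model cap from `LLM_MODEL_MAX_TOKENS`; the values below are illustrative:

```
# Illustrative: exceeding the 8000-token cap for gpt-4o raises a ValidationError.
from pydantic import ValidationError

from reworkd_platform.schemas.agent import ModelSettings

ModelSettings(model="gpt-4o", max_tokens=4000)  # OK, under the 8000 cap
try:
    ModelSettings(model="gpt-4o", max_tokens=9001)
except ValidationError as e:
    print(e)  # "Model gpt-4o only supports 8000 tokens"
```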
21
platform/reworkd_platform/schemas/user.py
Normal file
@ -0,0 +1,21 @@
from typing import Optional

from pydantic import BaseModel, Field


class OrganizationRole(BaseModel):
    id: str
    role: str
    organization_id: str


class UserBase(BaseModel):
    id: str
    name: Optional[str]
    email: Optional[str]
    image: Optional[str] = Field(default=None)
    organization: Optional[OrganizationRole] = Field(default=None)

    @property
    def organization_id(self) -> Optional[str]:
        return self.organization.organization_id if self.organization else None
1
platform/reworkd_platform/services/__init__.py
Normal file
@ -0,0 +1 @@
"""Services for reworkd_platform."""
42
platform/reworkd_platform/services/anthropic.py
Normal file
@ -0,0 +1,42 @@
from typing import Any, Optional

from anthropic import AsyncAnthropic
from pydantic import BaseModel


class AbstractPrompt(BaseModel):
    def to_string(self) -> str:
        raise NotImplementedError


class HumanAssistantPrompt(AbstractPrompt):
    assistant_prompt: str
    human_prompt: str

    def to_string(self) -> str:
        return (
            f"""\n\nHuman: {self.human_prompt}\n\nAssistant: {self.assistant_prompt}"""
        )


class ClaudeService:
    def __init__(self, api_key: Optional[str], model: str = "claude-2"):
        self.claude = AsyncAnthropic(api_key=api_key)
        self.model = model

    async def completion(
        self,
        prompt: AbstractPrompt,
        max_tokens_to_sample: int,
        temperature: int = 0,
        **kwargs: Any,
    ) -> str:
        return (
            await self.claude.completions.create(
                model=self.model,
                prompt=prompt.to_string(),
                max_tokens_to_sample=max_tokens_to_sample,
                temperature=temperature,
                **kwargs,
            )
        ).completion.strip()
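A hedged usage sketch; the API key is a placeholder, and the `claude-2` default comes from the constructor above:

```
# Hypothetical usage sketch -- "sk-ant-..." is a placeholder API key.
import asyncio

from reworkd_platform.services.anthropic import ClaudeService, HumanAssistantPrompt

async def main() -> None:
    service = ClaudeService(api_key="sk-ant-...")
    prompt = HumanAssistantPrompt(
        human_prompt="Summarize the goal in one sentence.",
        assistant_prompt="",
    )
    print(await service.completion(prompt, max_tokens_to_sample=256))

asyncio.run(main())
```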
7
platform/reworkd_platform/services/aws/__init__.py
Normal file
@ -0,0 +1,7 @@
import boto3
from botocore.exceptions import ProfileNotFound

try:
    boto3.setup_default_session(profile_name="dev")
except ProfileNotFound:
    pass
85
platform/reworkd_platform/services/aws/s3.py
Normal file
@ -0,0 +1,85 @@
import io
import os
from typing import Dict, List, Optional

from aiohttp import ClientError
from boto3 import client as boto3_client
from loguru import logger
from pydantic import BaseModel

REGION = "us-east-1"


# noinspection SpellCheckingInspection
class PresignedPost(BaseModel):
    url: str
    fields: Dict[str, str]


class SimpleStorageService:
    # TODO: would be great if we could make this async

    def __init__(self, bucket: Optional[str]) -> None:
        if not bucket:
            raise ValueError("Bucket name must be provided")

        self._client = boto3_client("s3", region_name=REGION)
        self._bucket = bucket

    def create_presigned_upload_url(
        self,
        object_name: str,
    ) -> PresignedPost:
        return PresignedPost(
            **self._client.generate_presigned_post(
                Bucket=self._bucket,
                Key=object_name,
            )
        )

    def create_presigned_download_url(self, object_name: str) -> str:
        return self._client.generate_presigned_url(
            "get_object",
            Params={"Bucket": self._bucket, "Key": object_name},
        )

    def upload_to_bucket(
        self,
        object_name: str,
        file: io.BytesIO,
    ) -> None:
        try:
            self._client.put_object(
                Bucket=self._bucket, Key=object_name, Body=file.getvalue()
            )
        except ClientError as e:
            logger.error(e)
            raise e

    def download_file(self, object_name: str, local_filename: str) -> None:
        self._client.download_file(
            Bucket=self._bucket, Key=object_name, Filename=local_filename
        )

    def list_keys(self, prefix: str) -> List[str]:
        files = self._client.list_objects_v2(Bucket=self._bucket, Prefix=prefix)
        if "Contents" not in files:
            return []

        return [file["Key"] for file in files["Contents"]]

    def download_folder(self, prefix: str, path: str) -> List[str]:
        local_files = []
        for key in self.list_keys(prefix):
            local_filename = os.path.join(path, key.split("/")[-1])
            self.download_file(key, local_filename)
            local_files.append(local_filename)

        return local_files

    def delete_folder(self, prefix: str) -> None:
        keys = self.list_keys(prefix)
        self._client.delete_objects(
            Bucket=self._bucket,
            Delete={"Objects": [{"Key": key} for key in keys]},
        )
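A minimal sketch of the presigned-upload flow; the bucket name and object key are placeholders:

```
# Hypothetical usage sketch -- "my-bucket" and the key are placeholders.
import io

from reworkd_platform.services.aws.s3 import SimpleStorageService

s3 = SimpleStorageService(bucket="my-bucket")
post = s3.create_presigned_upload_url("runs/run-123/output.txt")
print(post.url, post.fields)  # hand these to the browser for a direct upload

# Or upload server-side instead:
s3.upload_to_bucket("runs/run-123/output.txt", io.BytesIO(b"hello"))
```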
147
platform/reworkd_platform/services/oauth_installers.py
Normal file
@ -0,0 +1,147 @@
import json
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from urllib.parse import urlencode

import aiohttp
from fastapi import Depends, Path

from reworkd_platform.db.crud.oauth import OAuthCrud
from reworkd_platform.db.models.auth import OauthCredentials
from reworkd_platform.schemas import UserBase
from reworkd_platform.services.security import encryption_service
from reworkd_platform.settings import Settings
from reworkd_platform.settings import settings as platform_settings
from reworkd_platform.web.api.http_responses import forbidden


class OAuthInstaller(ABC):
    def __init__(self, crud: OAuthCrud, settings: Settings):
        self.crud = crud
        self.settings = settings

    @abstractmethod
    async def install(self, user: UserBase, redirect_uri: str) -> str:
        raise NotImplementedError()

    @abstractmethod
    async def install_callback(self, code: str, state: str) -> OauthCredentials:
        raise NotImplementedError()

    @abstractmethod
    async def uninstall(self, user: UserBase) -> bool:
        raise NotImplementedError()

    @staticmethod
    def store_access_token(creds: OauthCredentials, access_token: str) -> None:
        creds.access_token_enc = encryption_service.encrypt(access_token)

    @staticmethod
    def store_refresh_token(creds: OauthCredentials, refresh_token: str) -> None:
        creds.refresh_token_enc = encryption_service.encrypt(refresh_token)


class SIDInstaller(OAuthInstaller):
    PROVIDER = "sid"

    async def install(self, user: UserBase, redirect_uri: str) -> str:
        # Gracefully handle the case where the installation already exists.
        # This can happen if the user starts the process from multiple tabs.
        installation = await self.crud.get_installation_by_user_id(
            user.id, self.PROVIDER
        )
        if not installation:
            installation = await self.crud.create_installation(
                user,
                self.PROVIDER,
                redirect_uri,
            )

        scopes = ["data:query", "offline_access"]
        params = {
            "client_id": self.settings.sid_client_id,
            "redirect_uri": self.settings.sid_redirect_uri,
            "response_type": "code",
            "scope": " ".join(scopes),
            "state": installation.state,
            "audience": "https://api.sid.ai/api/v1/",
        }
        auth_url = "https://me.sid.ai/api/oauth/authorize"
        auth_url += "?" + urlencode(params)
        return auth_url

    async def install_callback(self, code: str, state: str) -> OauthCredentials:
        creds = await self.crud.get_installation_by_state(state)
        if not creds:
            raise forbidden()

        req = {
            "grant_type": "authorization_code",
            "client_id": self.settings.sid_client_id,
            "client_secret": self.settings.sid_client_secret,
            "redirect_uri": self.settings.sid_redirect_uri,
            "code": code,
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "https://auth.sid.ai/oauth/token",
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                },
                data=json.dumps(req),
            ) as response:
                res_data = await response.json()

        OAuthInstaller.store_access_token(creds, res_data["access_token"])
        OAuthInstaller.store_refresh_token(creds, res_data["refresh_token"])
        creds.access_token_expiration = datetime.now() + timedelta(
            seconds=res_data["expires_in"]
        )
        return await creds.save(self.crud.session)

    async def uninstall(self, user: UserBase) -> bool:
        creds = await self.crud.get_installation_by_user_id(user.id, self.PROVIDER)
        # Check that credentials exist and contain a refresh token
        if not creds:
            return False

        # Use the refresh token to revoke access
        delete_token = encryption_service.decrypt(creds.refresh_token_enc)
        # Delete the credentials from the database
        await self.crud.session.delete(creds)

        # Revoke the refresh token
        async with aiohttp.ClientSession() as session:
            await session.post(
                "https://auth.sid.ai/oauth/revoke",
                headers={
                    "Content-Type": "application/json",
                },
                data=json.dumps(
                    {
                        "client_id": self.settings.sid_client_id,
                        "client_secret": self.settings.sid_client_secret,
                        "token": delete_token,
                    }
                ),
            )
        return True


integrations = {
    SIDInstaller.PROVIDER: SIDInstaller,
}


def installer_factory(
    provider: str = Path(description="OAuth Provider"),
    crud: OAuthCrud = Depends(OAuthCrud.inject),
) -> OAuthInstaller:
    """Factory for OAuth installers.

    Args:
        provider (str): OAuth provider (can be slack, github, etc.) (injected)
        crud (OAuthCrud): OAuth CRUD (injected)
    """
    if provider in integrations:
        return integrations[provider](crud, platform_settings)
    raise NotImplementedError()
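A sketch of how the factory slots into a route via the `provider` path parameter; the route path and wiring below are illustrative and not part of this diff:

```
# Hypothetical route -- the real router wiring is not shown in this commit.
from fastapi import APIRouter, Depends

from reworkd_platform.schemas import UserBase
from reworkd_platform.services.oauth_installers import OAuthInstaller, installer_factory
from reworkd_platform.web.api.dependencies import get_current_user

router = APIRouter()

@router.get("/oauth/{provider}/install")
async def install(
    redirect_uri: str,
    user: UserBase = Depends(get_current_user),
    installer: OAuthInstaller = Depends(installer_factory),
) -> str:
    # Returns the provider's authorization URL to redirect the user to.
    return await installer.install(user, redirect_uri)
```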
11
platform/reworkd_platform/services/pinecone/lifetime.py
Normal file
@ -0,0 +1,11 @@
import pinecone

from reworkd_platform.settings import settings


def init_pinecone() -> None:
    if settings.pinecone_api_key and settings.pinecone_environment:
        pinecone.init(
            api_key=settings.pinecone_api_key,
            environment=settings.pinecone_environment,
        )
98
platform/reworkd_platform/services/pinecone/pinecone.py
Normal file
@ -0,0 +1,98 @@
from __future__ import annotations

import uuid
from typing import Any, Dict, List

from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.base import Embeddings
from pinecone import Index  # import doesn't work on plane wifi
from pydantic import BaseModel

from reworkd_platform.settings import settings
from reworkd_platform.timer import timed_function
from reworkd_platform.web.api.memory.memory import AgentMemory

OPENAI_EMBEDDING_DIM = 1536


class Row(BaseModel):
    id: str
    values: List[float]
    metadata: Dict[str, Any] = {}


class QueryResult(BaseModel):
    id: str
    score: float
    metadata: Dict[str, Any] = {}


class PineconeMemory(AgentMemory):
    """
    Wrapper around Pinecone.
    """

    def __init__(self, index_name: str, namespace: str = ""):
        self.index = Index(settings.pinecone_index_name)
        self.namespace = namespace or index_name

    @timed_function(level="DEBUG")
    def __enter__(self) -> AgentMemory:
        self.embeddings: Embeddings = OpenAIEmbeddings(
            client=None,  # Meta private value, but mypy will complain it's missing
            openai_api_key=settings.openai_api_key,
        )

        return self

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        pass

    @timed_function(level="DEBUG")
    def reset_class(self) -> None:
        self.index.delete(delete_all=True, namespace=self.namespace)

    @timed_function(level="DEBUG")
    def add_tasks(self, tasks: List[str]) -> List[str]:
        if len(tasks) == 0:
            return []

        embeds = self.embeddings.embed_documents(tasks)

        if len(tasks) != len(embeds):
            raise ValueError("Embeddings and tasks are not the same length")

        rows = [
            Row(values=vector, metadata={"text": tasks[i]}, id=str(uuid.uuid4()))
            for i, vector in enumerate(embeds)
        ]

        self.index.upsert(
            vectors=[row.dict() for row in rows], namespace=self.namespace
        )

        return [row.id for row in rows]

    @timed_function(level="DEBUG")
    def get_similar_tasks(
        self, text: str, score_threshold: float = 0.95
    ) -> List[QueryResult]:
        # Get similar tasks
        vector = self.embeddings.embed_query(text)
        results = self.index.query(
            vector=vector,
            top_k=5,
            include_metadata=True,
            include_values=True,
            namespace=self.namespace,
        )

        return [
            QueryResult(id=row.id, score=row.score, metadata=row.metadata)
            for row in getattr(results, "matches", [])
            if row.score > score_threshold
        ]

    @staticmethod
    def should_use() -> bool:
        return False
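Since `__enter__` builds the embeddings client, the class is meant to be used as a context manager. A usage sketch with placeholder names:

```
# Hypothetical usage sketch -- "run-123" is a placeholder namespace/index name.
from reworkd_platform.services.pinecone.pinecone import PineconeMemory

with PineconeMemory(index_name="run-123") as memory:
    ids = memory.add_tasks(["Research the topic", "Write the summary"])
    similar = memory.get_similar_tasks("Summarize the research", score_threshold=0.9)
```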
23
platform/reworkd_platform/services/security.py
Normal file
@ -0,0 +1,23 @@
from typing import Union

from cryptography.fernet import Fernet, InvalidToken

from reworkd_platform.settings import settings
from reworkd_platform.web.api.http_responses import forbidden


class EncryptionService:
    def __init__(self, secret: bytes):
        self.fernet = Fernet(secret)

    def encrypt(self, text: str) -> bytes:
        return self.fernet.encrypt(text.encode("utf-8"))

    def decrypt(self, encoded_bytes: Union[bytes, str]) -> str:
        try:
            return self.fernet.decrypt(encoded_bytes).decode("utf-8")
        except InvalidToken:
            raise forbidden()


encryption_service = EncryptionService(settings.secret_signing_key.encode())
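A round-trip sketch; the key below is generated on the spot rather than using the app's signing key:

```
# Illustrative round trip with a freshly generated Fernet key.
from cryptography.fernet import Fernet

from reworkd_platform.services.security import EncryptionService

service = EncryptionService(Fernet.generate_key())
token = service.encrypt("access-token-123")  # bytes
assert service.decrypt(token) == "access-token-123"
```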
25
platform/reworkd_platform/services/ssl.py
Normal file
@ -0,0 +1,25 @@
from ssl import SSLContext, create_default_context
from typing import List, Optional

from reworkd_platform.settings import Settings

MACOS_CERT_PATH = "/etc/ssl/cert.pem"
DOCKER_CERT_PATH = "/etc/ssl/certs/ca-certificates.crt"


def get_ssl_context(
    settings: Settings, paths: Optional[List[str]] = None
) -> SSLContext:
    if settings.db_ca_path:
        return create_default_context(cafile=settings.db_ca_path)

    for path in paths or [MACOS_CERT_PATH, DOCKER_CERT_PATH]:
        try:
            return create_default_context(cafile=path)
        except FileNotFoundError:
            continue

    raise ValueError(
        "No CA certificates found for your OS. To fix this, please change "
        "db_ca_path in your settings to point to a valid CA certificate file."
    )
1
platform/reworkd_platform/services/tokenizer/__init__.py
Normal file
@ -0,0 +1 @@
"""Token Service"""
@ -0,0 +1,7 @@
from fastapi import Request

from reworkd_platform.services.tokenizer.token_service import TokenService


def get_token_service(request: Request) -> TokenService:
    return TokenService(request.app.state.token_encoding)
16
platform/reworkd_platform/services/tokenizer/lifetime.py
Normal file
@ -0,0 +1,16 @@
import tiktoken
from fastapi import FastAPI

ENCODING_NAME = "cl100k_base"  # gpt-4, gpt-3.5-turbo, text-embedding-ada-002


def init_tokenizer(app: FastAPI) -> None:  # pragma: no cover
    """
    Initialize tokenizer.

    TikToken downloads the encoding on start. It is then
    stored in the state of the application.

    :param app: current application.
    """
    app.state.token_encoding = tiktoken.get_encoding(ENCODING_NAME)
@ -0,0 +1,33 @@
from tiktoken import Encoding, get_encoding

from reworkd_platform.schemas.agent import LLM_MODEL_MAX_TOKENS, LLM_Model
from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI


class TokenService:
    def __init__(self, encoding: Encoding):
        self.encoding = encoding

    @classmethod
    def create(cls, encoding: str = "cl100k_base") -> "TokenService":
        return cls(get_encoding(encoding))

    def tokenize(self, text: str) -> list[int]:
        return self.encoding.encode(text)

    def detokenize(self, tokens: list[int]) -> str:
        return self.encoding.decode(tokens)

    def count(self, text: str) -> int:
        return len(self.tokenize(text))

    def get_completion_space(self, model: LLM_Model, *prompts: str) -> int:
        max_allowed_tokens = LLM_MODEL_MAX_TOKENS.get(model, 4000)
        prompt_tokens = sum([self.count(p) for p in prompts])
        return max_allowed_tokens - prompt_tokens

    def calculate_max_tokens(self, model: WrappedChatOpenAI, *prompts: str) -> None:
        requested_tokens = self.get_completion_space(model.model_name, *prompts)

        model.max_tokens = min(model.max_tokens, requested_tokens)
        model.max_tokens = max(model.max_tokens, 1)
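A sketch of counting prompt tokens and budgeting the remaining completion space:

```
# Illustrative: tokenize round-trips, and the completion budget is the
# model's max-token cap minus the tokens used by the prompts.
from reworkd_platform.services.tokenizer.token_service import TokenService

service = TokenService.create()  # cl100k_base by default
assert service.detokenize(service.tokenize("hello")) == "hello"

prompt = "You are a helpful agent."
space = service.get_completion_space("gpt-3.5-turbo", prompt)
print(service.count(prompt), space)  # prompt tokens, and 4000 minus that count
```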
178
platform/reworkd_platform/settings.py
Normal file
@ -0,0 +1,178 @@
import platform
from pathlib import Path
from tempfile import gettempdir
from typing import List, Literal, Optional, Union

from pydantic import BaseSettings
from yarl import URL

from reworkd_platform.constants import ENV_PREFIX

TEMP_DIR = Path(gettempdir())

LOG_LEVEL = Literal[
    "NOTSET",
    "DEBUG",
    "INFO",
    "WARNING",
    "ERROR",
    "FATAL",
]

SASL_MECHANISM = Literal[
    "PLAIN",
    "SCRAM-SHA-256",
]

ENVIRONMENT = Literal[
    "development",
    "production",
]


class Settings(BaseSettings):
    """
    Application settings.

    These parameters can be configured
    with environment variables.
    """

    # Application settings
    host: str = "0.0.0.0"
    port: int = 8000
    workers_count: int = 1
    reload: bool = True
    environment: ENVIRONMENT = "development"
    log_level: LOG_LEVEL = "DEBUG"

    # Make sure you update this with your own secret key
    # Must be 32 url-safe base64-encoded bytes
    secret_signing_key: str = "JF52S66x6WMoifP5gZreiguYs9LYMn0lkXqgPYoNMD0="

    # OpenAI
    openai_api_base: str = "https://api.openai.com/v1"
    openai_api_key: str = "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"
    openai_api_version: str = "2023-08-01-preview"
    azure_openai_deployment_name: str = "<Should be updated via env if using azure>"

    # Helicone
    helicone_api_base: str = "https://oai.hconeai.com/v1"
    helicone_api_key: Optional[str] = None

    replicate_api_key: Optional[str] = None
    serp_api_key: Optional[str] = None

    # Frontend URL for CORS
    frontend_url: str = "https://allinix.ai"
    allowed_origins_regex: Optional[str] = None

    # Variables for the database
    db_host: str = "65.19.178.35"
    db_port: int = 3307
    db_user: str = "reworkd_platform"
    db_pass: str = "reworkd_platform"
    db_base: str = "default"
    db_echo: bool = False
    db_ca_path: Optional[str] = None

    # Variables for Pinecone DB
    pinecone_api_key: Optional[str] = None
    pinecone_index_name: Optional[str] = None
    pinecone_environment: Optional[str] = None

    # Sentry's configuration.
    sentry_dsn: Optional[str] = None
    sentry_sample_rate: float = 1.0

    kafka_bootstrap_servers: Union[str, List[str]] = []
    kafka_username: Optional[str] = None
    kafka_password: Optional[str] = None
    kafka_ssal_mechanism: SASL_MECHANISM = "PLAIN"

    # Websocket settings
    pusher_app_id: Optional[str] = None
    pusher_key: Optional[str] = None
    pusher_secret: Optional[str] = None
    pusher_cluster: Optional[str] = None

    # Application Settings
    ff_mock_mode_enabled: bool = False  # Controls whether calls are mocked
    max_loops: int = 25  # Maximum number of loops to run

    # Settings for sid
    sid_client_id: Optional[str] = None
    sid_client_secret: Optional[str] = None
    sid_redirect_uri: Optional[str] = None

    @property
    def kafka_consumer_group(self) -> str:
        """
        Kafka consumer group will be the name of the host in development
        mode, making it easier to share a dev cluster.
        """
        if self.environment == "development":
            return platform.node()

        return "platform"

    @property
    def db_url(self) -> URL:
        return URL.build(
            scheme="mysql+aiomysql",
            host=self.db_host,
            port=self.db_port,
            user=self.db_user,
            password=self.db_pass,
            path=f"/{self.db_base}",
        )

    @property
    def pusher_enabled(self) -> bool:
        return all(
            [
                self.pusher_app_id,
                self.pusher_key,
                self.pusher_secret,
                self.pusher_cluster,
            ]
        )

    @property
    def kafka_enabled(self) -> bool:
        return all(
            [
                self.kafka_bootstrap_servers,
                self.kafka_username,
                self.kafka_password,
            ]
        )

    @property
    def helicone_enabled(self) -> bool:
        return all(
            [
                self.helicone_api_base,
                self.helicone_api_key,
            ]
        )

    @property
    def sid_enabled(self) -> bool:
        return all(
            [
                self.sid_client_id,
                self.sid_client_secret,
                self.sid_redirect_uri,
            ]
        )

    class Config:
        env_file = ".env"
        env_prefix = ENV_PREFIX
        env_file_encoding = "utf-8"


settings = Settings()
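Because of `env_prefix`, every field can be overridden through prefixed environment variables. `ENV_PREFIX` comes from `reworkd_platform.constants`, which is not shown in this diff; the sketch below assumes it is `REWORKD_PLATFORM_`:

```
# Illustrative override -- assumes ENV_PREFIX == "REWORKD_PLATFORM_".
import os

os.environ["REWORKD_PLATFORM_ENVIRONMENT"] = "production"
os.environ["REWORKD_PLATFORM_DB_HOST"] = "db.internal"

from reworkd_platform.settings import Settings

settings = Settings()
assert settings.environment == "production"
assert settings.db_host == "db.internal"
```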
1
platform/reworkd_platform/tests/__init__.py
Normal file
@ -0,0 +1 @@
"""Tests for reworkd_platform."""
31
platform/reworkd_platform/tests/agent/test_analysis.py
Normal file
@ -0,0 +1,31 @@
import pytest
from pydantic import ValidationError

from reworkd_platform.web.api.agent.analysis import Analysis
from reworkd_platform.web.api.agent.tools.tools import get_default_tool, get_tool_name


def test_analysis_model() -> None:
    valid_tool_name = get_tool_name(get_default_tool())
    analysis = Analysis(action=valid_tool_name, arg="arg", reasoning="reasoning")

    assert analysis.action == valid_tool_name
    assert analysis.arg == "arg"
    assert analysis.reasoning == "reasoning"


def test_analysis_model_search_empty_arg() -> None:
    with pytest.raises(ValidationError):
        Analysis(action="search", arg="", reasoning="reasoning")


def test_analysis_model_search_non_empty_arg() -> None:
    analysis = Analysis(action="search", arg="non-empty arg", reasoning="reasoning")
    assert analysis.action == "search"
    assert analysis.arg == "non-empty arg"
    assert analysis.reasoning == "reasoning"


def test_analysis_model_invalid_tool() -> None:
    with pytest.raises(ValidationError):
        Analysis(action="invalid tool name", arg="test argument", reasoning="reasoning")
63
platform/reworkd_platform/tests/agent/test_crud.py
Normal file
@ -0,0 +1,63 @@
from unittest.mock import AsyncMock

import pytest
from fastapi import HTTPException
from pytest_mock import MockerFixture

from reworkd_platform.db.crud.agent import AgentCRUD
from reworkd_platform.settings import settings
from reworkd_platform.web.api.errors import MaxLoopsError, MultipleSummaryError


@pytest.mark.asyncio
async def test_validate_task_count_no_error(mocker) -> None:
    mock_agent_run_exists(mocker, True)
    session = mock_session_with_run_count(mocker, 0)
    agent_crud: AgentCRUD = AgentCRUD(session, mocker.MagicMock())

    # Doesn't throw an exception
    await agent_crud.validate_task_count("test", "summarize")


@pytest.mark.asyncio
async def test_validate_task_count_when_run_not_found(mocker: MockerFixture) -> None:
    mock_agent_run_exists(mocker, False)
    agent_crud: AgentCRUD = AgentCRUD(mocker.AsyncMock(), mocker.MagicMock())

    with pytest.raises(HTTPException):
        await agent_crud.validate_task_count("test", "test")


@pytest.mark.asyncio
async def test_validate_task_count_max_loops_error(mocker: MockerFixture) -> None:
    mock_agent_run_exists(mocker, True)
    session = mock_session_with_run_count(mocker, settings.max_loops)
    agent_crud: AgentCRUD = AgentCRUD(session, mocker.AsyncMock())

    with pytest.raises(MaxLoopsError):
        await agent_crud.validate_task_count("test", "test")


@pytest.mark.asyncio
async def test_validate_task_count_multiple_summary_error(
    mocker: MockerFixture,
) -> None:
    mock_agent_run_exists(mocker, True)
    session = mock_session_with_run_count(mocker, 2)
    agent_crud: AgentCRUD = AgentCRUD(session, mocker.MagicMock())

    with pytest.raises(MultipleSummaryError):
        await agent_crud.validate_task_count("test", "summarize")


def mock_agent_run_exists(mocker: MockerFixture, exists: bool) -> None:
    mocker.patch("reworkd_platform.db.models.agent.AgentRun.get", return_value=exists)


def mock_session_with_run_count(mocker: MockerFixture, run_count: int) -> AsyncMock:
    session = mocker.AsyncMock()
    scalar_mock = mocker.MagicMock()

    session.execute.return_value = scalar_mock
    scalar_mock.scalar_one.return_value = run_count
    return session
138
platform/reworkd_platform/tests/agent/test_model_factory.py
Normal file
@ -0,0 +1,138 @@
import itertools

import pytest
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI

from reworkd_platform.schemas import ModelSettings, UserBase
from reworkd_platform.settings import Settings
from reworkd_platform.web.api.agent.model_factory import (
    WrappedAzureChatOpenAI,
    WrappedChatOpenAI,
    create_model,
    get_base_and_headers,
)


def test_helicone_enabled_without_custom_api_key():
    model_settings = ModelSettings()
    user = UserBase(id="user_id")
    settings = Settings(
        helicone_api_key="some_key",
        helicone_api_base="helicone_base",
        openai_api_base="openai_base",
    )

    base, headers, use_helicone = get_base_and_headers(settings, model_settings, user)

    assert use_helicone is True
    assert base == "helicone_base"
    assert headers == {
        "Helicone-Auth": "Bearer some_key",
        "Helicone-Cache-Enabled": "true",
        "Helicone-User-Id": "user_id",
        "Helicone-OpenAI-Api-Base": "openai_base",
    }


def test_helicone_disabled():
    model_settings = ModelSettings()
    user = UserBase(id="user_id")
    settings = Settings()

    base, headers, use_helicone = get_base_and_headers(settings, model_settings, user)
    assert base == "https://api.openai.com/v1"
    assert headers is None
    assert use_helicone is False


def test_helicone_enabled_with_custom_api_key():
    model_settings = ModelSettings(
        custom_api_key="custom_key",
    )
    user = UserBase(id="user_id")
    settings = Settings(
        openai_api_base="openai_base",
        helicone_api_key="some_key",
        helicone_api_base="helicone_base",
    )

    base, headers, use_helicone = get_base_and_headers(settings, model_settings, user)

    assert base == "https://api.openai.com/v1"
    assert headers is None
    assert use_helicone is False


@pytest.mark.parametrize(
    "streaming, use_azure",
    list(
        itertools.product(
            [True, False],
            [True, False],
        )
    ),
)
def test_create_model(streaming, use_azure):
    user = UserBase(id="user_id")
    settings = Settings()
    model_settings = ModelSettings(
        temperature=0.7,
        model="gpt-4o",
        max_tokens=100,
    )

    settings.openai_api_base = (
        "https://api.openai.com" if not use_azure else "https://oai.azure.com"
    )
    settings.openai_api_key = "key"
    settings.openai_api_version = "version"

    result = create_model(settings, model_settings, user, streaming)
    assert issubclass(result.__class__, WrappedChatOpenAI)
    assert issubclass(result.__class__, ChatOpenAI)

    # Check if the required keys are set properly
    assert result.openai_api_base == settings.openai_api_base
    assert result.openai_api_key == settings.openai_api_key
    assert result.temperature == model_settings.temperature
    assert result.max_tokens == model_settings.max_tokens
    assert result.streaming == streaming
    assert result.max_retries == 5

    # For Azure specific checks
    if use_azure:
        assert isinstance(result, WrappedAzureChatOpenAI)
        assert issubclass(result.__class__, AzureChatOpenAI)
        assert result.openai_api_version == settings.openai_api_version
        assert result.deployment_name == "gpt-35-turbo"
        assert result.openai_api_type == "azure"


@pytest.mark.parametrize(
    "model_settings, streaming",
    list(
        itertools.product(
            [
                ModelSettings(
                    customTemperature=0.222,
                    customModelName="gpt-4",
                    maxTokens=1234,
                ),
                ModelSettings(),
            ],
            [True, False],
        )
    ),
)
def test_custom_model_settings(model_settings: ModelSettings, streaming: bool):
    model = create_model(
        Settings(),
        model_settings,
        UserBase(id="", email="test@example.com"),
        streaming=streaming,
    )

    assert model.temperature == model_settings.temperature
    assert model.model_name.startswith(model_settings.model)
    assert model.max_tokens == model_settings.max_tokens
    assert model.streaming == streaming
203
platform/reworkd_platform/tests/agent/test_task_output_parser.py
Normal file
@ -0,0 +1,203 @@
from typing import List, Type

import pytest
from langchain.schema import OutputParserException

from reworkd_platform.web.api.agent.task_output_parser import (
    TaskOutputParser,
    extract_array,
    real_tasks_filter,
    remove_prefix,
)


@pytest.mark.parametrize(
    "input_text,expected_output",
    [
        (
            '["Task 1: Do something", "Task 2: Do something else", "Task 3: Do '
            'another thing"]',
            ["Do something", "Do something else", "Do another thing"],
        ),
        (
            'Some random stuff ["1: Hello"]',
            ["Hello"],
        ),
        (
            "[]",
            [],
        ),
    ],
)
def test_parse_success(input_text: str, expected_output: List[str]) -> None:
    parser = TaskOutputParser(completed_tasks=[])
    result = parser.parse(input_text)
    assert result == expected_output


def test_parse_with_completed_tasks() -> None:
    input_text = '["One", "Two", "Three"]'
    completed = ["One"]
    expected = ["Two", "Three"]

    parser = TaskOutputParser(completed_tasks=completed)

    result = parser.parse(input_text)
    assert result == expected


@pytest.mark.parametrize(
    "input_text, exception",
    [
        # Test cases for non-array and non-multiline string inputs
        ("This is not an array", OutputParserException),
        ("123456", OutputParserException),
        ("Some random text", OutputParserException),
        ("[abc]", OutputParserException),
        # Test cases for malformed arrays
        ("[1, 2, 3", OutputParserException),
        ("'item1', 'item2']", OutputParserException),
        ("['item1', 'item2", OutputParserException),
        # Test case for invalid multiline strings
        ("This is not\na valid\nmultiline string.", OutputParserException),
        # Test case for multiline strings that don't start with digit + period
        ("Some text\nMore text\nAnd more text.", OutputParserException),
    ],
)
def test_parse_failure(input_text: str, exception: Type[Exception]) -> None:
    parser = TaskOutputParser(completed_tasks=[])
    with pytest.raises(exception):
        parser.parse(input_text)


@pytest.mark.parametrize(
    "input_str, expected",
    [
        # Test cases for empty array
        ("[]", []),
        # Test cases for arrays with one element
        ('["One"]', ["One"]),
        ("['Single quote']", ["Single quote"]),
        # Test cases for arrays with multiple elements
        ('["Research", "Develop", "Integrate"]', ["Research", "Develop", "Integrate"]),
        ('["Search", "Identify"]', ["Search", "Identify"]),
        ('["Item 1","Item 2","Item 3"]', ["Item 1", "Item 2", "Item 3"]),
        # Test cases for arrays with special characters in elements
        ("['Single with \"quote\"']", ['Single with "quote"']),
        ('["Escape \\" within"]', ['Escape " within']),
        # Test case for array embedded in other text
        ("Random stuff ['Search', 'Identify']", ["Search", "Identify"]),
        # Test case for array within JSON
        ('{"array": ["123", "456"]}', ["123", "456"]),
        # Multiline string cases
        (
            "1. Identify the target\n2. Conduct research\n3. Implement the methods",
            [
                "1. Identify the target",
                "2. Conduct research",
                "3. Implement the methods",
            ],
        ),
        ("1. Step one.\n2. Step two.", ["1. Step one.", "2. Step two."]),
        (
            """1. Review and understand the code to be debugged
2. Identify and address any errors or issues found during the review process
3. Print out debug information and setup initial variables
4. Start necessary threads and execute program logic.""",
            [
                "1. Review and understand the code to be debugged",
                "2. Identify and address any errors or issues found during the review "
                "process",
                "3. Print out debug information and setup initial variables",
                "4. Start necessary threads and execute program logic.",
            ],
        ),
        # Test cases with sentences before the digit + period pattern
        (
            "Any text before 1. Identify the task to be repeated\nUnrelated info 2. "
            "Determine the frequency of the repetition\nAnother sentence 3. Create a "
            "schedule or system to ensure completion of the task at the designated "
            "frequency\nMore text 4. Execute the task according to the established "
            "schedule or system",
            [
                "1. Identify the task to be repeated",
                "2. Determine the frequency of the repetition",
                "3. Create a schedule or system to ensure completion of the task at "
                "the designated frequency",
                "4. Execute the task according to the established schedule or system",
            ],
        ),
    ],
)
def test_extract_array_success(input_str: str, expected: List[str]) -> None:
    print(extract_array(input_str), expected)
    assert extract_array(input_str) == expected


@pytest.mark.parametrize(
    "input_str, exception",
    [
        (None, TypeError),
        ("123", RuntimeError),
        ("Some random text", RuntimeError),
        ('"single_string"', RuntimeError),
        ('{"test": 123}', RuntimeError),
        ('["Unclosed array", "other"', RuntimeError),
    ],
)
def test_extract_array_exception(input_str: str, exception: Type[Exception]) -> None:
    with pytest.raises(exception):
        extract_array(input_str)


@pytest.mark.parametrize(
    "task_input, expected_output",
    [
        ("Task: This is a sample task", "This is a sample task"),
        (
            "Task 1: Perform a comprehensive analysis of system performance.",
            "Perform a comprehensive analysis of system performance.",
        ),
        ("Task 2. Create a python script", "Create a python script"),
        ("5 - This is a sample task", "This is a sample task"),
        ("2: This is a sample task", "This is a sample task"),
        (
            "This is a sample task without a prefix",
            "This is a sample task without a prefix",
        ),
        ("Step: This is a sample task", "This is a sample task"),
        (
            "Step 1: Perform a comprehensive analysis of system performance.",
            "Perform a comprehensive analysis of system performance.",
        ),
        ("Step 2:Create a python script", "Create a python script"),
        ("Step:This is a sample task", "This is a sample task"),
        (
            ". Conduct research on the history of Nike",
            "Conduct research on the history of Nike",
        ),
        (".This is a sample task", "This is a sample task"),
        (
            "1. Research the history and background of Nike company.",
            "Research the history and background of Nike company.",
        ),
    ],
)
def test_remove_task_prefix(task_input: str, expected_output: str) -> None:
    output = remove_prefix(task_input)
    assert output == expected_output


@pytest.mark.parametrize(
    "input_text, expected_result",
    [
        ("Write the report", True),
        ("No new task needed", False),
        ("Task completed", False),
        ("Do nothing", False),
        ("", False),  # empty_string
        ("no new task needed", False),  # case_insensitive
    ],
)
def test_real_tasks_filter_no_task(input_text: str, expected_result: bool) -> None:
    assert real_tasks_filter(input_text) == expected_result
56
platform/reworkd_platform/tests/agent/test_tools.py
Normal file
@ -0,0 +1,56 @@
from typing import List, Type

from reworkd_platform.web.api.agent.tools.conclude import Conclude
from reworkd_platform.web.api.agent.tools.image import Image
from reworkd_platform.web.api.agent.tools.reason import Reason
from reworkd_platform.web.api.agent.tools.search import Search
from reworkd_platform.web.api.agent.tools.sidsearch import SID
from reworkd_platform.web.api.agent.tools.tools import (
    Tool,
    format_tool_name,
    get_default_tool,
    get_tool_from_name,
    get_tool_name,
    get_tools_overview,
)


def test_get_tool_name() -> None:
    assert get_tool_name(Image) == "image"
    assert get_tool_name(Search) == "search"
    assert get_tool_name(Reason) == "reason"


def test_format_tool_name() -> None:
    assert format_tool_name("Search") == "search"
    assert format_tool_name("reason") == "reason"
    assert format_tool_name("Conclude") == "conclude"
    assert format_tool_name("CoNcLuDe") == "conclude"


def test_get_tools_overview_no_duplicates() -> None:
    """Test to assert that the tools overview doesn't include duplicates."""
    tools: List[Type[Tool]] = [Image, Search, Reason, Conclude, Image, Search]
    overview = get_tools_overview(tools)

    # Check if each unique tool description is included in the overview
    for tool in set(tools):
        expected_description = f"'{get_tool_name(tool)}': {tool.description}"
        assert expected_description in overview

    # Check for duplicates in the overview
    overview_list = overview.split("\n")
    assert len(overview_list) == len(
        set(overview_list)
    ), "Overview includes duplicate entries"


def test_get_default_tool() -> None:
    assert get_default_tool() == Search


def test_get_tool_from_name() -> None:
    assert get_tool_from_name("Search") == Search
    assert get_tool_from_name("CoNcLuDe") == Conclude
    assert get_tool_from_name("NonExistingTool") == Search
    assert get_tool_from_name("SID") == SID
@ -0,0 +1,43 @@
import pytest

from reworkd_platform.web.api.memory.memory_with_fallback import MemoryWithFallback


@pytest.mark.parametrize(
    "method_name, args",
    [
        ("add_tasks", (["task1", "task2"],)),
        ("get_similar_tasks", ("task1",)),
        ("reset_class", ()),
    ],
)
def test_memory_primary(mocker, method_name: str, args) -> None:
    primary = mocker.Mock()
    secondary = mocker.Mock()
    memory_with_fallback = MemoryWithFallback(primary, secondary)

    # Use getattr() to call the method on the object with args
    getattr(memory_with_fallback, method_name)(*args)
    getattr(primary, method_name).assert_called_once_with(*args)
    getattr(secondary, method_name).assert_not_called()


@pytest.mark.parametrize(
    "method_name, args",
    [
        ("add_tasks", (["task1", "task2"],)),
        ("get_similar_tasks", ("task1",)),
        ("reset_class", ()),
    ],
)
def test_memory_fallback(mocker, method_name: str, args) -> None:
    primary = mocker.Mock()
    secondary = mocker.Mock()
    memory_with_fallback = MemoryWithFallback(primary, secondary)

    getattr(primary, method_name).side_effect = Exception("Primary Failed")

    # Call the method again, this time it should fall back to secondary
    getattr(memory_with_fallback, method_name)(*args)
    getattr(primary, method_name).assert_called_once_with(*args)
    getattr(secondary, method_name).assert_called_once_with(*args)
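`MemoryWithFallback` itself is not part of this hunk, so as a reading aid, here is a minimal sketch of the primary/secondary delegation the two tests above pin down. Only the constructor shape and the three method names come from the tests; the body is an assumption:

```python
# Hypothetical sketch only -- not the code under test.
from typing import Any, Callable, List


class MemoryWithFallbackSketch:
    def __init__(self, primary: Any, secondary: Any) -> None:
        self.primary = primary
        self.secondary = secondary

    def _delegate(self, method_name: str) -> Callable[..., Any]:
        def call(*args: Any, **kwargs: Any) -> Any:
            try:
                # Happy path: only the primary memory is touched
                return getattr(self.primary, method_name)(*args, **kwargs)
            except Exception:
                # Primary raised: replay the exact same call on the secondary
                return getattr(self.secondary, method_name)(*args, **kwargs)

        return call

    def add_tasks(self, tasks: List[str]) -> Any:
        return self._delegate("add_tasks")(tasks)

    def get_similar_tasks(self, query: str) -> Any:
        return self._delegate("get_similar_tasks")(query)

    def reset_class(self) -> Any:
        return self._delegate("reset_class")()
```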
26
platform/reworkd_platform/tests/test_dependancies.py
Normal file
@ -0,0 +1,26 @@
import pytest

from reworkd_platform.web.api.agent import dependancies


@pytest.mark.anyio
@pytest.mark.parametrize(
    "validator, step",
    [
        (dependancies.agent_summarize_validator, "summarize"),
        (dependancies.agent_chat_validator, "chat"),
        (dependancies.agent_analyze_validator, "analyze"),
        (dependancies.agent_create_validator, "create"),
        (dependancies.agent_execute_validator, "execute"),
    ],
)
async def test_agent_validate(mocker, validator, step):
    run_id = "asim"
    crud = mocker.Mock()
    body = mocker.Mock()
    body.run_id = run_id

    crud.create_task = mocker.AsyncMock()

    await validator(body, crud)
    crud.create_task.assert_called_once_with(run_id, step)
45
platform/reworkd_platform/tests/test_helpers.py
Normal file
@ -0,0 +1,45 @@
import pytest
from openai.error import InvalidRequestError, ServiceUnavailableError

from reworkd_platform.schemas.agent import ModelSettings
from reworkd_platform.web.api.agent.helpers import openai_error_handler
from reworkd_platform.web.api.errors import OpenAIError


async def act(*args, settings: ModelSettings = ModelSettings(), **kwargs):
    return await openai_error_handler(*args, settings=settings, **kwargs)


@pytest.mark.asyncio
async def test_service_unavailable_error():
    async def mock_service_unavailable_error():
        raise ServiceUnavailableError("Service Unavailable")

    with pytest.raises(OpenAIError):
        await act(mock_service_unavailable_error)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "settings,should_log",
    [
        (ModelSettings(custom_api_key="xyz"), False),
        (ModelSettings(custom_api_key=None), True),
    ],
)
async def test_should_log(settings, should_log):
    async def mock_invalid_request_error_model_access():
        raise InvalidRequestError(
            "The model: xyz does not exist or you do not have access to it.",
            param="model",
        )

    with pytest.raises(Exception) as exc_info:
        await openai_error_handler(
            mock_invalid_request_error_model_access, settings=settings
        )

    assert isinstance(exc_info.value, OpenAIError)
    error: OpenAIError = exc_info.value

    assert error.should_log == should_log
15
platform/reworkd_platform/tests/test_oauth_installers.py
Normal file
@ -0,0 +1,15 @@
import pytest

from reworkd_platform.services.oauth_installers import installer_factory


def test_installer_factory(mocker):
    crud = mocker.Mock()
    installer_factory("sid", crud)


def test_integration_dne(mocker):
    crud = mocker.Mock()

    with pytest.raises(NotImplementedError):
        installer_factory("asim", crud)
18
platform/reworkd_platform/tests/test_reworkd_platform.py
Normal file
@ -0,0 +1,18 @@
import pytest
from fastapi import FastAPI
from httpx import AsyncClient
from starlette import status


@pytest.mark.skip(reason="Mysql needs to be mocked")
@pytest.mark.anyio
async def test_health(client: AsyncClient, fastapi_app: FastAPI) -> None:
    """
    Checks the health endpoint.

    :param client: client for the app.
    :param fastapi_app: current FastAPI application.
    """
    url = fastapi_app.url_path_for("health_check")
    response = await client.get(url)
    assert response.status_code == status.HTTP_200_OK
21
platform/reworkd_platform/tests/test_s3.py
Normal file
@ -0,0 +1,21 @@
from reworkd_platform.services.aws.s3 import SimpleStorageService


def test_create_signed_post(mocker):
    post_url = {
        "url": "https://my_bucket.s3.amazonaws.com/my_object",
        "fields": {"key": "value"},
    }

    boto3_mock = mocker.Mock()
    boto3_mock.generate_presigned_post.return_value = post_url
    mocker.patch(
        "reworkd_platform.services.aws.s3.boto3_client", return_value=boto3_mock
    )

    assert (
        SimpleStorageService(bucket="my_bucket").create_presigned_upload_url(
            object_name="json"
        )
        == post_url
    )
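`generate_presigned_post` is the real boto3 client call being mocked here; the `SimpleStorageService` wrapper is not in this hunk. A plausible sketch of what the test exercises, with everything beyond the boto3 call an assumption:

```python
# Hypothetical sketch -- only generate_presigned_post is a known boto3 API.
import boto3


def create_presigned_upload_url_sketch(bucket: str, object_name: str) -> dict:
    client = boto3.client("s3")
    # Returns {"url": ..., "fields": {...}} for a browser-side POST upload
    return client.generate_presigned_post(Bucket=bucket, Key=object_name)
```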
61
platform/reworkd_platform/tests/test_schemas.py
Normal file
@ -0,0 +1,61 @@
import pytest

from reworkd_platform.schemas.agent import ModelSettings


@pytest.mark.parametrize(
    "settings",
    [
        {
            "model": "gpt-4o",
            "max_tokens": 7000,
            "temperature": 0.5,
            "language": "french",
        },
        {
            "model": "gpt-3.5-turbo",
            "max_tokens": 3000,
        },
        {
            "model": "gpt-3.5-turbo-16k",
            "max_tokens": 16000,
        },
    ],
)
def test_model_settings_valid(settings):
    result = ModelSettings(**settings)
    assert result.model == settings.get("model", "gpt-4o")
    assert result.max_tokens == settings.get("max_tokens", 500)
    assert result.temperature == settings.get("temperature", 0.9)
    assert result.language == settings.get("language", "English")


@pytest.mark.parametrize(
    "settings",
    [
        {
            "model": "gpt-4-32k",
        },
        {
            "temperature": -1,
        },
        {
            "max_tokens": 8000,
        },
        {
            "model": "gpt-4",
            "max_tokens": 32000,
        },
    ],
)
def test_model_settings_invalid(settings):
    with pytest.raises(Exception):
        ModelSettings(**settings)


def test_model_settings_default():
    settings = ModelSettings(**{})
    assert settings.model == "gpt-3.5-turbo"
    assert settings.temperature == 0.9
    assert settings.max_tokens == 500
    assert settings.language == "English"
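For quick reference, the defaults these tests pin down can be restated as a runnable snippet:

```python
from reworkd_platform.schemas.agent import ModelSettings

settings = ModelSettings()  # no overrides
assert settings.model == "gpt-3.5-turbo"
assert settings.temperature == 0.9
assert settings.max_tokens == 500
assert settings.language == "English"

# Budgets are validated per model: e.g. max_tokens=8000 on the default
# model is rejected, as the invalid cases above show.
```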
29
platform/reworkd_platform/tests/test_security.py
Normal file
@ -0,0 +1,29 @@
import pytest
from cryptography.fernet import Fernet
from fastapi import HTTPException

from reworkd_platform.services.security import EncryptionService


def test_encrypt_decrypt():
    key = Fernet.generate_key()
    service = EncryptionService(key)

    original_text = "Hello, world!"
    encrypted = service.encrypt(original_text)
    decrypted = service.decrypt(encrypted)

    assert original_text == decrypted


def test_invalid_key():
    key = Fernet.generate_key()

    different_key = Fernet.generate_key()
    different_service = EncryptionService(different_key)

    original_text = "Hello, world!"
    encrypted = Fernet(key).encrypt(original_text.encode())

    with pytest.raises(HTTPException):
        different_service.decrypt(encrypted)
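`EncryptionService` is not included in this hunk. A minimal sketch consistent with these tests (Fernet round-trip, `HTTPException` on a key mismatch) might be:

```python
# Hypothetical sketch only; the Fernet API calls are real, the class is assumed.
from cryptography.fernet import Fernet, InvalidToken
from fastapi import HTTPException


class EncryptionServiceSketch:
    def __init__(self, key: bytes) -> None:
        self.fernet = Fernet(key)

    def encrypt(self, text: str) -> bytes:
        return self.fernet.encrypt(text.encode())

    def decrypt(self, token: bytes) -> str:
        try:
            return self.fernet.decrypt(token).decode()
        except InvalidToken:
            # Wrong key: surface as an HTTP error, matching test_invalid_key
            raise HTTPException(status_code=400, detail="Invalid token")
```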
5
platform/reworkd_platform/tests/test_settings.py
Normal file
@ -0,0 +1,5 @@
from reworkd_platform.settings import Settings


def test_settings_create():
    assert Settings() is not None
105
platform/reworkd_platform/tests/test_token_service.py
Normal file
@ -0,0 +1,105 @@
from unittest.mock import Mock

import tiktoken

from reworkd_platform.schemas.agent import LLM_MODEL_MAX_TOKENS
from reworkd_platform.services.tokenizer.token_service import TokenService

encoding = tiktoken.get_encoding("cl100k_base")


def test_happy_path() -> None:
    service = TokenService(encoding)
    text = "Hello world!"
    validate_tokenize_and_detokenize(service, text, 3)


def test_nothing() -> None:
    service = TokenService(encoding)
    text = ""
    validate_tokenize_and_detokenize(service, text, 0)


def validate_tokenize_and_detokenize(
    service: TokenService, text: str, expected_token_count: int
) -> None:
    tokens = service.tokenize(text)
    assert text == service.detokenize(tokens)
    assert len(tokens) == service.count(text)
    assert len(tokens) == expected_token_count


def test_calculate_max_tokens_with_small_max_tokens() -> None:
    initial_max_tokens = 3000
    service = TokenService(encoding)
    model = Mock(spec=["model_name", "max_tokens"])
    model.model_name = "gpt-3.5-turbo"
    model.max_tokens = initial_max_tokens

    service.calculate_max_tokens(model, "Hello")

    assert model.max_tokens == initial_max_tokens


def test_calculate_max_tokens_with_high_completion_tokens() -> None:
    service = TokenService(encoding)
    prompt_tokens = service.count(LONG_TEXT)
    model = Mock(spec=["model_name", "max_tokens"])
    model.model_name = "gpt-3.5-turbo"
    model.max_tokens = 8000

    service.calculate_max_tokens(model, LONG_TEXT)

    assert model.max_tokens == (
        LLM_MODEL_MAX_TOKENS.get("gpt-3.5-turbo") - prompt_tokens
    )


def test_calculate_max_tokens_with_negative_result() -> None:
    service = TokenService(encoding)
    model = Mock(spec=["model_name", "max_tokens"])
    model.model_name = "gpt-3.5-turbo"
    model.max_tokens = 8000

    service.calculate_max_tokens(model, *([LONG_TEXT] * 100))

    # We use the minimum length of 1
    assert model.max_tokens == 1


LONG_TEXT = """
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
"""
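The arithmetic these tests pin down is: completion budget = model window - prompt tokens, clamped to at least 1. A standalone illustration with tiktoken (the 4096 window here is an assumed value for the example; the real limits come from `LLM_MODEL_MAX_TOKENS`):

```python
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
prompt_tokens = len(enc.encode("Hello world!"))  # 3 tokens, per test_happy_path
window = 4096  # assumed window for this example
completion_budget = max(window - prompt_tokens, 1)  # clamp mirrors the negative-result test
assert completion_budget == 4093
```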
34
platform/reworkd_platform/timer.py
Normal file
@ -0,0 +1,34 @@
from functools import wraps
from time import time
from typing import Any, Callable, Literal

from loguru import logger

Log_Level = Literal[
    "TRACE",
    "DEBUG",
    "INFO",
    "SUCCESS",
    "WARNING",
    "ERROR",
    "CRITICAL",
]


def timed_function(level: Log_Level = "INFO") -> Callable[..., Any]:
    def decorator(func: Any) -> Callable[..., Any]:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            start_time = time()
            result = func(*args, **kwargs)
            execution_time = time() - start_time
            logger.log(
                level,
                f"Function '{func.__qualname__}' executed in {execution_time:.4f} seconds",
            )

            return result

        return wrapper

    return decorator
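Example usage of the decorator above (illustrative):

```python
from reworkd_platform.timer import timed_function


@timed_function(level="DEBUG")
def slow_add(a: int, b: int) -> int:
    return a + b


slow_add(1, 2)  # logs: Function 'slow_add' executed in 0.0000 seconds
```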
1
platform/reworkd_platform/web/__init__.py
Normal file
@ -0,0 +1 @@
"""WEB API for reworkd_platform."""
1
platform/reworkd_platform/web/api/__init__.py
Normal file
@ -0,0 +1 @@
"""reworkd_platform API package."""
4
platform/reworkd_platform/web/api/agent/__init__.py
Normal file
@ -0,0 +1,4 @@
"""API for checking running agents"""
from reworkd_platform.web.api.agent.views import router

__all__ = ["router"]
@ -0,0 +1,51 @@
from typing import List, Optional, Protocol

from fastapi.responses import StreamingResponse as FastAPIStreamingResponse

from reworkd_platform.web.api.agent.analysis import Analysis


class AgentService(Protocol):
    async def start_goal_agent(self, *, goal: str) -> List[str]:
        pass

    async def analyze_task_agent(
        self, *, goal: str, task: str, tool_names: List[str]
    ) -> Analysis:
        pass

    async def execute_task_agent(
        self,
        *,
        goal: str,
        task: str,
        analysis: Analysis,
    ) -> FastAPIStreamingResponse:
        pass

    async def create_tasks_agent(
        self,
        *,
        goal: str,
        tasks: List[str],
        last_task: str,
        result: str,
        completed_tasks: Optional[List[str]] = None,
    ) -> List[str]:
        pass

    async def summarize_task_agent(
        self,
        *,
        goal: str,
        results: List[str],
    ) -> FastAPIStreamingResponse:
        pass

    async def chat(
        self,
        *,
        message: str,
        results: List[str],
    ) -> FastAPIStreamingResponse:
        pass
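Since `AgentService` is a `typing.Protocol`, implementations are matched structurally rather than by inheritance; any object exposing these async methods can be passed where an `AgentService` is expected. A small sketch of a consumer:

```python
from typing import List


async def run_start(service: AgentService, goal: str) -> List[str]:
    # Accepts MockAgentService, OpenAIAgentService, or any structural match
    return await service.start_goal_agent(goal=goal)
```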
@ -0,0 +1,53 @@
from typing import Any, Callable, Coroutine, Optional

from fastapi import Depends

from reworkd_platform.db.crud.oauth import OAuthCrud
from reworkd_platform.schemas.agent import AgentRun, LLM_Model
from reworkd_platform.schemas.user import UserBase
from reworkd_platform.services.tokenizer.dependencies import get_token_service
from reworkd_platform.services.tokenizer.token_service import TokenService
from reworkd_platform.settings import settings
from reworkd_platform.web.api.agent.agent_service.agent_service import AgentService
from reworkd_platform.web.api.agent.agent_service.mock_agent_service import (
    MockAgentService,
)
from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import (
    OpenAIAgentService,
)
from reworkd_platform.web.api.agent.model_factory import create_model
from reworkd_platform.web.api.dependencies import get_current_user


def get_agent_service(
    validator: Callable[..., Coroutine[Any, Any, AgentRun]],
    streaming: bool = False,
    llm_model: Optional[LLM_Model] = None,
) -> Callable[..., AgentService]:
    def func(
        run: AgentRun = Depends(validator),
        user: UserBase = Depends(get_current_user),
        token_service: TokenService = Depends(get_token_service),
        oauth_crud: OAuthCrud = Depends(OAuthCrud.inject),
    ) -> AgentService:
        if settings.ff_mock_mode_enabled:
            return MockAgentService()

        model = create_model(
            settings,
            run.model_settings,
            user,
            streaming=streaming,
            force_model=llm_model,
        )

        return OpenAIAgentService(
            model,
            run.model_settings,
            token_service,
            callbacks=None,
            user=user,
            oauth_crud=oauth_crud,
        )

    return func
@ -0,0 +1,74 @@
import time
from typing import Any, List

from fastapi.responses import StreamingResponse as FastAPIStreamingResponse

from reworkd_platform.web.api.agent.agent_service.agent_service import (
    AgentService,
    Analysis,
)
from reworkd_platform.web.api.agent.stream_mock import stream_string


class MockAgentService(AgentService):
    async def start_goal_agent(self, **kwargs: Any) -> List[str]:
        time.sleep(1)
        return ["Task X", "Task Y", "Task Z"]

    async def create_tasks_agent(self, **kwargs: Any) -> List[str]:
        time.sleep(1)
        return ["Some random task that doesn't exist"]

    async def analyze_task_agent(self, **kwargs: Any) -> Analysis:
        time.sleep(1.5)
        return Analysis(
            action="reason",
            arg="Mock analysis",
            reasoning="Mock to avoid wasting money calling the OpenAI API.",
        )

    async def execute_task_agent(self, **kwargs: Any) -> FastAPIStreamingResponse:
        time.sleep(0.5)
        return stream_string(
            """ This is going to be a longer task result such that
            We make the stream of this string take time and feel long. The reality is... this is a mock!

            Lorem Ipsum is simply dummy text of the printing and typesetting industry.
            Lorem Ipsum has been the industry's standard dummy text ever since the 1500s,
            when an unknown printer took a galley of type and scrambled it to make a type specimen book.
            It has survived not only five centuries, but also the leap into electronic typesetting, remaining unchanged.
            """
            + kwargs.get("task", "task"),
            True,
        )

    async def summarize_task_agent(
        self,
        *,
        goal: str,
        results: List[str],
    ) -> FastAPIStreamingResponse:
        time.sleep(0.5)
        return stream_string(
            """ This is going to be a longer task result such that
            We make the stream of this string take time and feel long. The reality is... this is a mock!

            Lorem Ipsum is simply dummy text of the printing and typesetting industry.
            Lorem Ipsum has been the industry's standard dummy text ever since the 1500s,
            when an unknown printer took a galley of type and scrambled it to make a type specimen book.
            It has survived not only five centuries, but also the leap into electronic typesetting, remaining unchanged.
            """,
            True,
        )

    async def chat(
        self,
        *,
        message: str,
        results: List[str],
    ) -> FastAPIStreamingResponse:
        time.sleep(0.5)
        return stream_string(
            "What do you want dude?",
            True,
        )
@ -0,0 +1,227 @@
from typing import List, Optional

from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from lanarky.responses import StreamingResponse
from langchain import LLMChain
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
from langchain.schema import HumanMessage
from loguru import logger
from pydantic import ValidationError

from reworkd_platform.db.crud.oauth import OAuthCrud
from reworkd_platform.schemas.agent import ModelSettings
from reworkd_platform.schemas.user import UserBase
from reworkd_platform.services.tokenizer.token_service import TokenService
from reworkd_platform.web.api.agent.agent_service.agent_service import AgentService
from reworkd_platform.web.api.agent.analysis import Analysis, AnalysisArguments
from reworkd_platform.web.api.agent.helpers import (
    call_model_with_handling,
    openai_error_handler,
    parse_with_handling,
)
from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
from reworkd_platform.web.api.agent.prompts import (
    analyze_task_prompt,
    chat_prompt,
    create_tasks_prompt,
    start_goal_prompt,
)
from reworkd_platform.web.api.agent.task_output_parser import TaskOutputParser
from reworkd_platform.web.api.agent.tools.open_ai_function import get_tool_function
from reworkd_platform.web.api.agent.tools.tools import (
    get_default_tool,
    get_tool_from_name,
    get_tool_name,
    get_user_tools,
)
from reworkd_platform.web.api.agent.tools.utils import summarize
from reworkd_platform.web.api.errors import OpenAIError


class OpenAIAgentService(AgentService):
    def __init__(
        self,
        model: WrappedChatOpenAI,
        settings: ModelSettings,
        token_service: TokenService,
        callbacks: Optional[List[AsyncCallbackHandler]],
        user: UserBase,
        oauth_crud: OAuthCrud,
    ):
        self.model = model
        self.settings = settings
        self.token_service = token_service
        self.callbacks = callbacks
        self.user = user
        self.oauth_crud = oauth_crud

    async def start_goal_agent(self, *, goal: str) -> List[str]:
        prompt = ChatPromptTemplate.from_messages(
            [SystemMessagePromptTemplate(prompt=start_goal_prompt)]
        )

        self.token_service.calculate_max_tokens(
            self.model,
            prompt.format_prompt(
                goal=goal,
                language=self.settings.language,
            ).to_string(),
        )

        completion = await call_model_with_handling(
            self.model,
            ChatPromptTemplate.from_messages(
                [SystemMessagePromptTemplate(prompt=start_goal_prompt)]
            ),
            {"goal": goal, "language": self.settings.language},
            settings=self.settings,
            callbacks=self.callbacks,
        )

        task_output_parser = TaskOutputParser(completed_tasks=[])
        tasks = parse_with_handling(task_output_parser, completion)

        return tasks

    async def analyze_task_agent(
        self, *, goal: str, task: str, tool_names: List[str]
    ) -> Analysis:
        user_tools = await get_user_tools(tool_names, self.user, self.oauth_crud)
        functions = list(map(get_tool_function, user_tools))
        prompt = analyze_task_prompt.format_prompt(
            goal=goal,
            task=task,
            language=self.settings.language,
        )

        self.token_service.calculate_max_tokens(
            self.model,
            prompt.to_string(),
            str(functions),
        )

        message = await openai_error_handler(
            func=self.model.apredict_messages,
            messages=prompt.to_messages(),
            functions=functions,
            settings=self.settings,
            callbacks=self.callbacks,
        )

        function_call = message.additional_kwargs.get("function_call", {})
        completion = function_call.get("arguments", "")

        try:
            pydantic_parser = PydanticOutputParser(pydantic_object=AnalysisArguments)
            analysis_arguments = parse_with_handling(pydantic_parser, completion)
            return Analysis(
                action=function_call.get("name", get_tool_name(get_default_tool())),
                **analysis_arguments.dict(),
            )
        except (OpenAIError, ValidationError):
            return Analysis.get_default_analysis(task)

    async def execute_task_agent(
        self,
        *,
        goal: str,
        task: str,
        analysis: Analysis,
    ) -> StreamingResponse:
        # TODO: More mature way of calculating max_tokens
        if self.model.max_tokens > 3000:
            self.model.max_tokens = max(self.model.max_tokens - 1000, 3000)

        tool_class = get_tool_from_name(analysis.action)
        return await tool_class(self.model, self.settings.language).call(
            goal,
            task,
            analysis.arg,
            self.user,
            self.oauth_crud,
        )

    async def create_tasks_agent(
        self,
        *,
        goal: str,
        tasks: List[str],
        last_task: str,
        result: str,
        completed_tasks: Optional[List[str]] = None,
    ) -> List[str]:
        prompt = ChatPromptTemplate.from_messages(
            [SystemMessagePromptTemplate(prompt=create_tasks_prompt)]
        )

        args = {
            "goal": goal,
            "language": self.settings.language,
            "tasks": "\n".join(tasks),
            "lastTask": last_task,
            "result": result,
        }

        self.token_service.calculate_max_tokens(
            self.model, prompt.format_prompt(**args).to_string()
        )

        completion = await call_model_with_handling(
            self.model, prompt, args, settings=self.settings, callbacks=self.callbacks
        )

        previous_tasks = (completed_tasks or []) + tasks
        return [completion] if completion not in previous_tasks else []

    async def summarize_task_agent(
        self,
        *,
        goal: str,
        results: List[str],
    ) -> FastAPIStreamingResponse:
        self.model.model_name = "gpt-4o"
        self.model.max_tokens = 8000  # Total tokens = prompt tokens + completion tokens

        snippet_max_tokens = 7000  # Leave room for the rest of the prompt
        text_tokens = self.token_service.tokenize("".join(results))
        text = self.token_service.detokenize(text_tokens[0:snippet_max_tokens])
        logger.info(f"Summarizing text: {text}")

        return summarize(
            model=self.model,
            language=self.settings.language,
            goal=goal,
            text=text,
        )

    async def chat(
        self,
        *,
        message: str,
        results: List[str],
    ) -> FastAPIStreamingResponse:
        self.model.model_name = "gpt-4o"
        prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate(prompt=chat_prompt),
                *[HumanMessage(content=result) for result in results],
                HumanMessage(content=message),
            ]
        )

        self.token_service.calculate_max_tokens(
            self.model,
            prompt.format_prompt(
                language=self.settings.language,
            ).to_string(),
        )

        chain = LLMChain(llm=self.model, prompt=prompt)

        return StreamingResponse.from_chain(
            chain,
            {"language": self.settings.language},
            media_type="text/event-stream",
        )
45
platform/reworkd_platform/web/api/agent/analysis.py
Normal file
@ -0,0 +1,45 @@
from typing import Dict

from pydantic import BaseModel, validator


class AnalysisArguments(BaseModel):
    """
    Arguments for the analysis function of a tool. OpenAI functions will resolve these values but leave out the action.
    """

    reasoning: str
    arg: str


class Analysis(AnalysisArguments):
    action: str

    @validator("action")
    def action_must_be_valid_tool(cls, v: str) -> str:
        # TODO: Remove circular import
        from reworkd_platform.web.api.agent.tools.tools import get_available_tools_names

        if v not in get_available_tools_names():
            raise ValueError(f"Analysis action '{v}' is not a valid tool")
        return v

    @validator("action")
    def search_action_must_have_arg(cls, v: str, values: Dict[str, str]) -> str:
        from reworkd_platform.web.api.agent.tools.search import Search
        from reworkd_platform.web.api.agent.tools.tools import get_tool_name

        if v == get_tool_name(Search) and not values["arg"]:
            raise ValueError("Analysis arg cannot be empty if action is 'search'")
        return v

    @classmethod
    def get_default_analysis(cls, task: str) -> "Analysis":
        # TODO: Remove circular import
        from reworkd_platform.web.api.agent.tools.tools import get_default_tool_name

        return cls(
            reasoning="Hmm... I'll try searching it up",
            action=get_default_tool_name(),
            arg=task,
        )
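An illustration of how the validators above interact with the fallback (assuming a "search" tool is registered in the tools module):

```python
from pydantic import ValidationError

from reworkd_platform.web.api.agent.analysis import Analysis

try:
    # Rejected by search_action_must_have_arg: empty arg on a search action
    Analysis(reasoning="find it", action="search", arg="")
except ValidationError:
    analysis = Analysis.get_default_analysis("research Nike history")
```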
95
platform/reworkd_platform/web/api/agent/dependancies.py
Normal file
@ -0,0 +1,95 @@
from typing import TypeVar

from fastapi import Body, Depends
from sqlalchemy.ext.asyncio import AsyncSession

from reworkd_platform.db.crud.agent import AgentCRUD
from reworkd_platform.db.dependencies import get_db_session
from reworkd_platform.schemas.agent import (
    AgentChat,
    AgentRun,
    AgentRunCreate,
    AgentSummarize,
    AgentTaskAnalyze,
    AgentTaskCreate,
    AgentTaskExecute,
    Loop_Step,
)
from reworkd_platform.schemas.user import UserBase
from reworkd_platform.web.api.dependencies import get_current_user

T = TypeVar(
    "T", AgentTaskAnalyze, AgentTaskExecute, AgentTaskCreate, AgentSummarize, AgentChat
)


def agent_crud(
    user: UserBase = Depends(get_current_user),
    session: AsyncSession = Depends(get_db_session),
) -> AgentCRUD:
    return AgentCRUD(session, user)


async def agent_start_validator(
    body: AgentRunCreate = Body(
        example={
            "goal": "Create business plan for a bagel company",
            "modelSettings": {
                "customModelName": "gpt-4o",
            },
        },
    ),
    crud: AgentCRUD = Depends(agent_crud),
) -> AgentRun:
    id_ = (await crud.create_run(body.goal)).id
    return AgentRun(**body.dict(), run_id=str(id_))


async def validate(body: T, crud: AgentCRUD, type_: Loop_Step) -> T:
    body.run_id = (await crud.create_task(body.run_id, type_)).id
    return body


async def agent_analyze_validator(
    body: AgentTaskAnalyze = Body(),
    crud: AgentCRUD = Depends(agent_crud),
) -> AgentTaskAnalyze:
    return await validate(body, crud, "analyze")


async def agent_execute_validator(
    body: AgentTaskExecute = Body(
        example={
            "goal": "Perform tasks accurately",
            "task": "Write code to make a platformer",
            "analysis": {
                "reasoning": "I like to write code.",
                "action": "code",
                "arg": "",
            },
        },
    ),
    crud: AgentCRUD = Depends(agent_crud),
) -> AgentTaskExecute:
    return await validate(body, crud, "execute")


async def agent_create_validator(
    body: AgentTaskCreate = Body(),
    crud: AgentCRUD = Depends(agent_crud),
) -> AgentTaskCreate:
    return await validate(body, crud, "create")


async def agent_summarize_validator(
    body: AgentSummarize = Body(),
    crud: AgentCRUD = Depends(agent_crud),
) -> AgentSummarize:
    return await validate(body, crud, "summarize")


async def agent_chat_validator(
    body: AgentChat = Body(),
    crud: AgentCRUD = Depends(agent_crud),
) -> AgentChat:
    return await validate(body, crud, "chat")
76
platform/reworkd_platform/web/api/agent/helpers.py
Normal file
@ -0,0 +1,76 @@
from typing import Any, Callable, Dict, TypeVar

from langchain import BasePromptTemplate, LLMChain
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseOutputParser, OutputParserException
from openai.error import (
    AuthenticationError,
    InvalidRequestError,
    RateLimitError,
    ServiceUnavailableError,
)

from reworkd_platform.schemas.agent import ModelSettings
from reworkd_platform.web.api.errors import OpenAIError

T = TypeVar("T")


def parse_with_handling(parser: BaseOutputParser[T], completion: str) -> T:
    try:
        return parser.parse(completion)
    except OutputParserException as e:
        raise OpenAIError(
            e, "There was an issue parsing the response from the AI model."
        )


async def openai_error_handler(
    func: Callable[..., Any], *args: Any, settings: ModelSettings, **kwargs: Any
) -> Any:
    try:
        return await func(*args, **kwargs)
    except ServiceUnavailableError as e:
        raise OpenAIError(
            e,
            "OpenAI is experiencing issues. Visit "
            "https://status.openai.com/ for more info.",
            should_log=not settings.custom_api_key,
        )
    except InvalidRequestError as e:
        if e.user_message.startswith("The model:"):
            raise OpenAIError(
                e,
                "Your API key does not have access to your current model. Please use a different model.",
                should_log=not settings.custom_api_key,
            )
        raise OpenAIError(e, e.user_message)
    except AuthenticationError as e:
        raise OpenAIError(
            e,
            "Authentication error: Ensure a valid API key is being used.",
            should_log=not settings.custom_api_key,
        )
    except RateLimitError as e:
        if e.user_message.startswith("You exceeded your current quota"):
            raise OpenAIError(
                e,
                "Your API key exceeded your current quota, please check your plan and billing details.",
                should_log=not settings.custom_api_key,
            )
        raise OpenAIError(e, e.user_message)
    except Exception as e:
        raise OpenAIError(
            e, "There was an unexpected issue getting a response from the AI model."
        )


async def call_model_with_handling(
    model: BaseChatModel,
    prompt: BasePromptTemplate,
    args: Dict[str, str],
    settings: ModelSettings,
    **kwargs: Any,
) -> str:
    chain = LLMChain(llm=model, prompt=prompt)
    return await openai_error_handler(chain.arun, args, settings=settings, **kwargs)
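A self-contained demonstration of `parse_with_handling`: a parser failure is re-raised as the platform's `OpenAIError` with a user-facing message.

```python
from reworkd_platform.web.api.agent.helpers import parse_with_handling
from reworkd_platform.web.api.agent.task_output_parser import TaskOutputParser
from reworkd_platform.web.api.errors import OpenAIError

parser = TaskOutputParser(completed_tasks=[])
try:
    parse_with_handling(parser, "no array and no numbered list here")
except OpenAIError:
    pass  # surfaced to the client as a friendly parsing error
```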
97
platform/reworkd_platform/web/api/agent/model_factory.py
Normal file
@ -0,0 +1,97 @@
from typing import Any, Dict, Optional, Tuple, Type, Union

from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from pydantic import Field

from reworkd_platform.schemas.agent import LLM_Model, ModelSettings
from reworkd_platform.schemas.user import UserBase
from reworkd_platform.settings import Settings


class WrappedChatOpenAI(ChatOpenAI):
    client: Any = Field(
        default=None,
        description="Meta private value but mypy will complain its missing",
    )
    max_tokens: int
    model_name: LLM_Model = Field(alias="model")


class WrappedAzureChatOpenAI(AzureChatOpenAI, WrappedChatOpenAI):
    openai_api_base: str
    openai_api_version: str
    deployment_name: str


WrappedChat = Union[WrappedAzureChatOpenAI, WrappedChatOpenAI]


def create_model(
    settings: Settings,
    model_settings: ModelSettings,
    user: UserBase,
    streaming: bool = False,
    force_model: Optional[LLM_Model] = None,
) -> WrappedChat:
    use_azure = (
        not model_settings.custom_api_key and "azure" in settings.openai_api_base
    )

    llm_model = force_model or model_settings.model
    model: Type[WrappedChat] = WrappedChatOpenAI
    base, headers, use_helicone = get_base_and_headers(settings, model_settings, user)
    kwargs = {
        "openai_api_base": base,
        "openai_api_key": model_settings.custom_api_key or settings.openai_api_key,
        "temperature": model_settings.temperature,
        "model": llm_model,
        "max_tokens": model_settings.max_tokens,
        "streaming": streaming,
        "max_retries": 5,
        "model_kwargs": {"user": user.email, "headers": headers},
    }

    if use_azure:
        model = WrappedAzureChatOpenAI
        deployment_name = llm_model.replace(".", "")
        kwargs.update(
            {
                "openai_api_version": settings.openai_api_version,
                "deployment_name": deployment_name,
                "openai_api_type": "azure",
                "openai_api_base": base.rstrip("v1"),
            }
        )

        if use_helicone:
            kwargs["model"] = deployment_name

    return model(**kwargs)  # type: ignore


def get_base_and_headers(
    settings_: Settings, model_settings: ModelSettings, user: UserBase
) -> Tuple[str, Optional[Dict[str, str]], bool]:
    use_helicone = settings_.helicone_enabled and not model_settings.custom_api_key
    base = (
        settings_.helicone_api_base
        if use_helicone
        else (
            "https://api.openai.com/v1"
            if model_settings.custom_api_key
            else settings_.openai_api_base
        )
    )

    headers = (
        {
            "Helicone-Auth": f"Bearer {settings_.helicone_api_key}",
            "Helicone-Cache-Enabled": "true",
            "Helicone-User-Id": user.id,
            "Helicone-OpenAI-Api-Base": settings_.openai_api_base,
        }
        if use_helicone
        else None
    )

    return base, headers, use_helicone
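The routing decision inside `get_base_and_headers` can be restated as a small pure function for clarity (same logic, plain arguments; the example URLs are illustrative):

```python
from typing import Optional


def choose_base(
    helicone_enabled: bool,
    custom_api_key: Optional[str],
    helicone_api_base: str,
    openai_api_base: str,
) -> str:
    # A user-supplied key always goes straight to api.openai.com;
    # otherwise Helicone is used whenever it is enabled.
    use_helicone = helicone_enabled and not custom_api_key
    if use_helicone:
        return helicone_api_base
    return "https://api.openai.com/v1" if custom_api_key else openai_api_base


assert choose_base(True, "sk-user", "https://helicone.example", "https://azure.example") == "https://api.openai.com/v1"
assert choose_base(True, None, "https://helicone.example", "https://azure.example") == "https://helicone.example"
assert choose_base(False, None, "https://helicone.example", "https://azure.example") == "https://azure.example"
```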
185
platform/reworkd_platform/web/api/agent/prompts.py
Normal file
@ -0,0 +1,185 @@
from langchain import PromptTemplate

# Create initial tasks using plan and solve prompting
# https://github.com/AGI-Edgerunners/Plan-and-Solve-Prompting
start_goal_prompt = PromptTemplate(
    template="""You are a task creation AI called Allinix.
You answer in the "{language}" language. You have the following objective "{goal}".
Return a list of search queries that would be required to answer the entirety of the objective.
Limit the list to a maximum of 5 queries. Ensure the queries are as succinct as possible.
For simple questions use a single query.

Return the response as a JSON array of strings. Examples:

query: "Who is considered the best NBA player in the current season?", answer: ["current NBA MVP candidates"]
query: "How does the Olympicpayroll brand currently stand in the market, and what are its prospects and strategies for expansion in NJ, NY, and PA?", answer: ["Olympicpayroll brand comprehensive analysis 2023", "customer reviews of Olympicpayroll.com", "Olympicpayroll market position analysis", "payroll industry trends forecast 2023-2025", "payroll services expansion strategies in NJ, NY, PA"]
query: "How can I create a function to add weight to edges in a digraph using {language}?", answer: ["algorithm to add weight to digraph edge in {language}"]
query: "What is the current weather in New York?", answer: ["current weather in New York"]
query: "5 + 5?", answer: ["Sum of 5 and 5"]
query: "What is a good homemade recipe for KFC-style chicken?", answer: ["KFC style chicken recipe at home"]
query: "What are the nutritional values of almond milk and soy milk?", answer: ["nutritional information of almond milk", "nutritional information of soy milk"]""",
    input_variables=["goal", "language"],
)

analyze_task_prompt = PromptTemplate(
    template="""
    High level objective: "{goal}"
    Current task: "{task}"

    Based on this information, use the best function to make progress or accomplish the task entirely.
    Select the correct function by being smart and efficient. Ensure "reasoning" and only "reasoning" is in the
    {language} language.

    Note you MUST select a function.
    """,
    input_variables=["goal", "task", "language"],
)

code_prompt = PromptTemplate(
    template="""
    You are a world-class software engineer and an expert in all programming languages,
    software systems, and architecture.

    For reference, your high level goal is {goal}

    Write code in English but explanations/comments in the "{language}" language.

    Provide no information about who you are and focus on writing code.
    Ensure code is bug and error free and explain complex concepts through comments.
    Respond in well-formatted markdown. Ensure code blocks are used for code sections.
    Approach problems step by step and file by file, for each section, use a heading to describe the section.

    Write code to accomplish the following:
    {task}
    """,
    input_variables=["goal", "language", "task"],
)

execute_task_prompt = PromptTemplate(
    template="""Answer in the "{language}" language. Given
    the following overall objective `{goal}` and the following sub-task, `{task}`.

    Perform the task by understanding the problem, extracting variables, and being smart
    and efficient. Write a detailed response that addresses the task.
    When confronted with choices, make a decision yourself with reasoning.
    """,
    input_variables=["goal", "language", "task"],
)

create_tasks_prompt = PromptTemplate(
    template="""You are an AI task creation agent. You must answer in the "{language}"
    language. You have the following objective `{goal}`.

    You have the following incomplete tasks:
    `{tasks}`

    You just completed the following task:
    `{lastTask}`

    And received the following result:
    `{result}`.

    Based on this, create a single new task to be completed by your AI system such that your goal is closer reached.
    If there are no more tasks to be done, return nothing. Do not add quotes to the task.

    Examples:
    Search the web for NBA news
    Create a function to add a new vertex with a specified weight to the digraph.
    Search for any additional information on Bertie W.
    ""
    """,
    input_variables=["goal", "language", "tasks", "lastTask", "result"],
)

summarize_prompt = PromptTemplate(
    template="""You must answer in the "{language}" language.

    Combine the following text into a cohesive document:

    "{text}"

    Write using clear markdown formatting in a style expected of the goal "{goal}".
    Be as clear, informative, and descriptive as necessary.
    You will not make up information or add any information outside of the above text.
    Only use the given information and nothing more.

    If there is no information provided, say "There is nothing to summarize".
    """,
    input_variables=["goal", "language", "text"],
)

company_context_prompt = PromptTemplate(
    template="""You must answer in the "{language}" language.

    Create a short description on "{company_name}".
    Find out what sector it is in and what are their primary products.

    Be as clear, informative, and descriptive as necessary.
    You will not make up information or add any information outside of the above text.
    Only use the given information and nothing more.

    If there is no information provided, say "There is nothing to summarize".
    """,
    input_variables=["company_name", "language"],
)

summarize_pdf_prompt = PromptTemplate(
    template="""You must answer in the "{language}" language.

    For the given text: "{text}", you have the following objective "{query}".

    Be as clear, informative, and descriptive as necessary.
    You will not make up information or add any information outside of the above text.
    Only use the given information and nothing more.
    """,
    input_variables=["query", "language", "text"],
)

summarize_with_sources_prompt = PromptTemplate(
    template="""You must answer in the "{language}" language.

    Answer the following query: "{query}" using the following information: "{snippets}".
    Write using clear markdown formatting and use markdown lists where possible.

    Cite sources for sentences via markdown links using the source link as the link and the index as the text.
    Use in-line sources. Do not separately list sources at the end of the writing.

    If the query cannot be answered with the provided information, mention this and provide a reason why along with what it does mention.
    Also cite the sources of what is actually mentioned.

    Example sentences of the paragraph:
    "So this is a cited sentence at the end of a paragraph[1](https://test.com). This is another sentence."
    "Stephen curry is an american basketball player that plays for the warriors[1](https://www.britannica.com/biography/Stephen-Curry)."
    "The economic growth forecast for the region has been adjusted from 2.5% to 3.1% due to improved trade relations[1](https://economictimes.com), while inflation rates are expected to remain steady at around 1.7% according to financial analysts[2](https://financeworld.com)."
    """,
    input_variables=["language", "query", "snippets"],
)

summarize_sid_prompt = PromptTemplate(
    template="""You must answer in the "{language}" language.

    Parse and summarize the following text snippets "{snippets}".
    Write using clear markdown formatting in a style expected of the goal "{goal}".
    Be as clear, informative, and descriptive as necessary and attempt to
    answer the query: "{query}" as best as possible.
    If any of the snippets are not relevant to the query,
    ignore them, and do not include them in the summary.
    Do not mention that you are ignoring them.

    If there is no information provided, say "There is nothing to summarize".
    """,
    input_variables=["goal", "language", "query", "snippets"],
)

chat_prompt = PromptTemplate(
    template="""You must answer in the "{language}" language.

    You are a helpful AI Assistant that will provide responses based on the current conversation history.

    The human will provide previous messages as context. Use ONLY this information for your responses.
    Do not make anything up and do not add any additional information.
    If you have no information for a given question in the conversation history,
    say "I do not have any information on this".
    """,
    input_variables=["language"],
)
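Substituting variables into one of the templates above is a plain `PromptTemplate.format` call:

```python
from reworkd_platform.web.api.agent.prompts import execute_task_prompt

text = execute_task_prompt.format(
    goal="Plan a product launch",
    language="English",
    task="Draft the announcement email",
)
print(text)  # the fully substituted prompt string sent to the model
```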
26
platform/reworkd_platform/web/api/agent/stream_mock.py
Normal file
@ -0,0 +1,26 @@
import asyncio
from typing import AsyncGenerator

import tiktoken
from fastapi import FastAPI
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse

app = FastAPI()


def stream_string(data: str, delayed: bool = False) -> FastAPIStreamingResponse:
    return FastAPIStreamingResponse(
        stream_generator(data, delayed),
    )


async def stream_generator(data: str, delayed: bool) -> AsyncGenerator[bytes, None]:
    if delayed:
        encoding = tiktoken.get_encoding("cl100k_base")
        token_data = encoding.encode(data)

        for token in token_data:
            yield encoding.decode([token]).encode("utf-8")
            await asyncio.sleep(0.025)  # simulate slow processing
    else:
        yield data.encode()
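Wiring `stream_string` into a route on the module-level `app` above (a hypothetical endpoint, for illustration):

```python
@app.get("/mock-stream")
async def mock_stream() -> FastAPIStreamingResponse:
    # delayed=True streams one token every 25 ms instead of one chunk
    return stream_string("Streaming, one token at a time...", delayed=True)
```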
@ -0,0 +1,88 @@
import ast
import re
from typing import List

from langchain.schema import BaseOutputParser, OutputParserException


class TaskOutputParser(BaseOutputParser[List[str]]):
    """
    Extension of LangChain's BaseOutputParser
    Responsible for parsing task creation output into a list of task strings
    """

    completed_tasks: List[str] = []

    def __init__(self, *, completed_tasks: List[str]):
        super().__init__()
        self.completed_tasks = completed_tasks

    def parse(self, text: str) -> List[str]:
        try:
            array_str = extract_array(text)
            all_tasks = [
                remove_prefix(task) for task in array_str if real_tasks_filter(task)
            ]
            return [task for task in all_tasks if task not in self.completed_tasks]
        except Exception as e:
            msg = f"Failed to parse tasks from completion '{text}'. Exception: {e}"
            raise OutputParserException(msg)

    def get_format_instructions(self) -> str:
        return """
        The response should be a JSON array of strings. Example:

        ["Search the web for NBA news", "Write some code to build a web scraper"]

        This should be parsable by json.loads()
        """


def extract_array(input_str: str) -> List[str]:
    regex = (
        r"\[\s*\]|"  # Empty array check
        r"(\[(?:\s*(?:\"(?:[^\"\\]|\\.)*\"|\'(?:[^\'\\]|\\.)*\')\s*,?)*\s*\])"
    )
    match = re.search(regex, input_str)
    if match is not None:
        return ast.literal_eval(match[0])
    else:
        return handle_multiline_string(input_str)


def handle_multiline_string(input_str: str) -> List[str]:
    # Handle multiline string as a list
    processed_lines = [
        re.sub(r".*?(\d+\..+)", r"\1", line).strip()
        for line in input_str.split("\n")
        if line.strip() != ""
    ]

    # Check if there is at least one line that starts with a digit and a period
    if any(re.match(r"\d+\..+", line) for line in processed_lines):
        return processed_lines
    else:
        raise RuntimeError(f"Failed to extract array from {input_str}")


def remove_prefix(input_str: str) -> str:
    prefix_pattern = (
        r"^(Task\s*\d*\.\s*|Task\s*\d*[-:]?\s*|Step\s*\d*["
        r"-:]?\s*|Step\s*[-:]?\s*|\d+\.\s*|\d+\s*[-:]?\s*|^\.\s*|^\.*)"
    )
    return re.sub(prefix_pattern, "", input_str, flags=re.IGNORECASE)


def real_tasks_filter(input_str: str) -> bool:
    no_task_regex = (
        r"^No( (new|further|additional|extra|other))? tasks? (is )?("
        r"required|needed|added|created|inputted).*"
    )
    task_complete_regex = r"^Task (complete|completed|finished|done|over|success).*"
    do_nothing_regex = r"^(\s*|Do nothing(\s.*)?)$"

    return (
        not re.search(no_task_regex, input_str, re.IGNORECASE)
        and not re.search(task_complete_regex, input_str, re.IGNORECASE)
        and not re.search(do_nothing_regex, input_str, re.IGNORECASE)
    )
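To make the parsing pipeline concrete, a small example with an invented completion string:

```python
# Illustrative only: prefixes are stripped and completed tasks filtered out.
parser = TaskOutputParser(completed_tasks=["Write some code to build a web scraper"])
completion = '["1. Search the web for NBA news", "Write some code to build a web scraper"]'
print(parser.parse(completion))  # -> ['Search the web for NBA news']
```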
25
platform/reworkd_platform/web/api/agent/tools/code.py
Normal file
@ -0,0 +1,25 @@
from typing import Any

from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from lanarky.responses import StreamingResponse
from langchain import LLMChain

from reworkd_platform.web.api.agent.tools.tool import Tool


class Code(Tool):
    description = "Should only be used to write code, refactor code, fix code bugs, and explain programming concepts."
    public_description = "Write and review code."

    async def call(
        self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
    ) -> FastAPIStreamingResponse:
        from reworkd_platform.web.api.agent.prompts import code_prompt

        chain = LLMChain(llm=self.model, prompt=code_prompt)

        return StreamingResponse.from_chain(
            chain,
            {"goal": goal, "language": self.language, "task": task},
            media_type="text/event-stream",
        )
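As a usage sketch, a tool like this would be invoked roughly as follows; the `Tool.__init__(model, language)` signature is assumed from the `Search` fallback later in this diff, and `model` stands in for any LangChain chat model instance:

```python
# Hypothetical call site inside an async request handler.
tool = Code(model, "English")  # model: an LLM instance (assumed)
response = await tool.call(
    goal="Build a web scraper",
    task="Write the page-fetching function",
    input_str="",
)
# response streams the chain output as server-sent events
```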
15
platform/reworkd_platform/web/api/agent/tools/conclude.py
Normal file
@ -0,0 +1,15 @@
from typing import Any

from fastapi.responses import StreamingResponse as FastAPIStreamingResponse

from reworkd_platform.web.api.agent.stream_mock import stream_string
from reworkd_platform.web.api.agent.tools.tool import Tool


class Conclude(Tool):
    description = "Use when there is nothing else to do. The task has been concluded."

    async def call(
        self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
    ) -> FastAPIStreamingResponse:
        return stream_string("Task execution concluded.", delayed=True)
66
platform/reworkd_platform/web/api/agent/tools/image.py
Normal file
@ -0,0 +1,66 @@
from typing import Any

import openai
import replicate
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from replicate.exceptions import ModelError
from replicate.exceptions import ReplicateError as ReplicateAPIError

from reworkd_platform.settings import settings
from reworkd_platform.web.api.agent.stream_mock import stream_string
from reworkd_platform.web.api.agent.tools.tool import Tool
from reworkd_platform.web.api.errors import ReplicateError


async def get_replicate_image(input_str: str) -> str:
    if settings.replicate_api_key is None or settings.replicate_api_key == "":
        raise RuntimeError("Replicate API key not set")

    client = replicate.Client(settings.replicate_api_key)
    try:
        output = client.run(
            "stability-ai/stable-diffusion"
            ":db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
            input={"prompt": input_str},
            image_dimensions="512x512",
        )
    except ModelError as e:
        raise ReplicateError(e, "Image generation failed due to NSFW image.")
    except ReplicateAPIError as e:
        raise ReplicateError(e, "Failed to generate an image.")

    return output[0]


# Use AI to generate an Image based on a prompt
async def get_open_ai_image(input_str: str) -> str:
    response = openai.Image.create(
        api_key=settings.openai_api_key,
        prompt=input_str,
        n=1,
        size="256x256",
    )

    return response["data"][0]["url"]


class Image(Tool):
    description = "Used to sketch, draw, or generate an image."
    public_description = "Generate AI images."
    arg_description = (
        "The input prompt to the image generator. "
        "This should be a detailed description of the image touching on image "
        "style, image focus, color, etc."
    )
    image_url = "/tools/replicate.png"

    async def call(
        self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
    ) -> FastAPIStreamingResponse:
        # Use the Replicate API if it's available, otherwise use DALL-E
        try:
            url = await get_replicate_image(input_str)
        except RuntimeError:
            url = await get_open_ai_image(input_str)

        # Stream the image back as a markdown image tag for the client to render
        return stream_string(f"![{input_str}]({url})")
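For clarity, the tool streams back a markdown image tag; with invented values the payload would look like:

```python
# Invented values, for illustration only.
input_str = "a watercolor fox, soft palette"
url = "https://replicate.delivery/example/output.png"
payload = f"![{input_str}]({url})"
# -> "![a watercolor fox, soft palette](https://replicate.delivery/example/output.png)"
```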
@ -0,0 +1,43 @@
from typing import Type, TypedDict

from reworkd_platform.web.api.agent.tools.tool import Tool
from reworkd_platform.web.api.agent.tools.tools import get_tool_name


class FunctionDescription(TypedDict):
    """Representation of a callable function to the OpenAI API."""

    name: str
    """The name of the function."""
    description: str
    """A description of the function."""
    parameters: dict[str, object]
    """The parameters of the function."""


def get_tool_function(tool: Type[Tool]) -> FunctionDescription:
    """A function that will return the tool's function specification"""
    name = get_tool_name(tool)

    return {
        "name": name,
        "description": tool.description,
        "parameters": {
            "type": "object",
            "properties": {
                "reasoning": {
                    "type": "string",
                    "description": (
                        "Reasoning is how the task will be accomplished with the current function. "
                        "Detail your overall plan along with any concerns you have. "
                        "Ensure this reasoning value is in the user defined language."
                    ),
                },
                "arg": {
                    "type": "string",
                    "description": tool.arg_description,
                },
            },
            "required": ["reasoning", "arg"],
        },
    }
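A hedged sketch of the spec this produces for the `Code` tool above; `get_tool_name` is not shown in this diff, so the lowercase name is an assumption:

```python
spec = get_tool_function(Code)
# Roughly:
# {
#     "name": "code",  # assumed naming scheme from get_tool_name
#     "description": Code.description,
#     "parameters": {
#         "type": "object",
#         "properties": {"reasoning": {...}, "arg": {...}},
#         "required": ["reasoning", "arg"],
#     },
# }
```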
27
platform/reworkd_platform/web/api/agent/tools/reason.py
Normal file
@ -0,0 +1,27 @@
from typing import Any

from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from lanarky.responses import StreamingResponse
from langchain import LLMChain

from reworkd_platform.web.api.agent.tools.tool import Tool


class Reason(Tool):
    description = (
        "Reason about task via existing information or understanding. "
        "Make decisions / selections from options."
    )

    async def call(
        self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
    ) -> FastAPIStreamingResponse:
        from reworkd_platform.web.api.agent.prompts import execute_task_prompt

        chain = LLMChain(llm=self.model, prompt=execute_task_prompt)

        return StreamingResponse.from_chain(
            chain,
            {"goal": goal, "language": self.language, "task": task},
            media_type="text/event-stream",
        )
109
platform/reworkd_platform/web/api/agent/tools/search.py
Normal file
@ -0,0 +1,109 @@
from typing import Any, List
from urllib.parse import quote

import aiohttp
from aiohttp import ClientResponseError
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from loguru import logger

from reworkd_platform.settings import settings
from reworkd_platform.web.api.agent.stream_mock import stream_string
from reworkd_platform.web.api.agent.tools.reason import Reason
from reworkd_platform.web.api.agent.tools.tool import Tool
from reworkd_platform.web.api.agent.tools.utils import (
    CitedSnippet,
    summarize_with_sources,
)

# Search Google via serper.dev. Adapted from LangChain
# https://github.com/hwchase17/langchain/blob/master/langchain/utilities


async def _google_serper_search_results(
    search_term: str, search_type: str = "search"
) -> dict[str, Any]:
    headers = {
        "X-API-KEY": settings.serp_api_key or "",
        "Content-Type": "application/json",
    }
    params = {
        "q": search_term,
    }

    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"https://google.serper.dev/{search_type}", headers=headers, params=params
        ) as response:
            response.raise_for_status()
            search_results = await response.json()

            return search_results


class Search(Tool):
    description = (
        "Search Google for short, up-to-date answers to simple questions about "
        "public information, news, and people.\n"
    )
    public_description = "Search Google for information about current events."
    arg_description = "The query argument to search for. This value is always populated and cannot be an empty string."
    image_url = "/tools/google.png"

    @staticmethod
    def available() -> bool:
        return settings.serp_api_key is not None and settings.serp_api_key != ""

    async def call(
        self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
    ) -> FastAPIStreamingResponse:
        try:
            return await self._call(goal, task, input_str, *args, **kwargs)
        except ClientResponseError:
            logger.exception("Error calling Serper API, falling back to reasoning")
            return await Reason(self.model, self.language).call(
                goal, task, input_str, *args, **kwargs
            )

    async def _call(
        self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
    ) -> FastAPIStreamingResponse:
        results = await _google_serper_search_results(
            input_str,
        )

        k = 5  # Number of results to return
        snippets: List[CitedSnippet] = []

        if results.get("answerBox"):
            answer_values = []
            answer_box = results.get("answerBox", {})
            if answer_box.get("answer"):
                answer_values.append(answer_box.get("answer"))
            elif answer_box.get("snippet"):
                answer_values.append(answer_box.get("snippet").replace("\n", " "))
            elif answer_box.get("snippetHighlighted"):
                answer_values.append(", ".join(answer_box.get("snippetHighlighted")))

            if len(answer_values) > 0:
                snippets.append(
                    CitedSnippet(
                        len(snippets) + 1,
                        "\n".join(answer_values),
                        f"https://www.google.com/search?q={quote(input_str)}",
                    )
                )

        for i, result in enumerate(results["organic"][:k]):
            texts = []
            link = ""
            if "snippet" in result:
                texts.append(result["snippet"])
            if "link" in result:
                link = result["link"]
            for attribute, value in result.get("attributes", {}).items():
                texts.append(f"{attribute}: {value}.")
            snippets.append(CitedSnippet(len(snippets) + 1, "\n".join(texts), link))

        if len(snippets) == 0:
            return stream_string("No good Google Search Result was found", True)

        return summarize_with_sources(self.model, self.language, goal, task, snippets)
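For reference, a minimal sketch of the serper.dev response shape this code reads; the field names come from the accesses above, and the values are invented:

```python
results = {
    "answerBox": {"answer": "Example direct answer"},
    "organic": [
        {
            "snippet": "Example result snippet...",
            "link": "https://example.com/article",
            "attributes": {"Date": "Jan 1, 2023"},
        },
    ],
}
```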
Some files were not shown because too many files have changed in this diff.