commit -2

allinix 2025-02-17 19:44:17 +05:30
parent 48da88668c
commit bdfe11f039
124 changed files with 12452 additions and 0 deletions

cli/.gitignore vendored Normal file

@@ -0,0 +1,24 @@
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
# dependencies
/node_modules
/.pnp
.pnp.js
# testing
/coverage
# production
/build
# misc
.DS_Store
*.pem
# debug
npm-debug.log*
yarn-debug.log*
yarn-error.log*
.pnpm-debug.log*
.eslintcache

cli/README.md Normal file

@@ -0,0 +1,27 @@
## AgentGPT CLI
AgentGPT CLI is a utility designed to streamline the setup process of your AgentGPT environment.
It uses Inquirer to interactively build up ENV values while validating that they are correct.
This was first created by @JPDucky on GitHub.
### Running the tool
```
# Running from the root of the project
./setup.sh
```
```
# Running from the cli directory
cd cli/
npm run start
```
### Updating ENV values
To update ENV values (an example sketch follows this list):
- Add a question to the list of questions in `index.js` for the ENV value
- Add a value in the `envDefinition` for the ENV value
- Add the ENV value to the `.env.example` in the root of the project
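For illustration, wiring up a hypothetical `MY_NEW_KEY` value would look roughly like this (the names below are invented for the example and are not part of the codebase):
```
// A new question for the list of questions:
{
  type: "input",
  name: "myNewKey",
  message: "Enter your MY_NEW_KEY or press enter to skip:",
},

// A new entry in the envDefinition in envGenerator.js:
Backend: {
  // ...existing Backend values...
  MY_NEW_KEY: envValues.myNewKey || '""',
},
```
`generateEnv` would then write `MY_NEW_KEY` under its section in both generated `.env` files.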

cli/package-lock.json generated Normal file (2865 lines)

File diff suppressed because it is too large.

cli/package.json Normal file

@@ -0,0 +1,32 @@
{
"name": "agentgpt-cli",
"version": "1.0.0",
"description": "A CLI to create your AgentGPT environment",
"private": true,
"engines": {
"node": ">=18.0.0 <19.0.0"
},
"type": "module",
"main": "index.js",
"scripts": {
"start": "node src/index.js",
"dev": "node src/index.js"
},
"author": "reworkd",
"dependencies": {
"@octokit/auth-basic": "^1.4.8",
"@octokit/rest": "^20.0.2",
"chalk": "^5.3.0",
"clear": "^0.1.0",
"clui": "^0.3.6",
"configstore": "^6.0.0",
"dotenv": "^16.3.1",
"figlet": "^1.7.0",
"inquirer": "^9.2.12",
"lodash": "^4.17.21",
"minimist": "^1.2.8",
"node-fetch": "^3.3.2",
"simple-git": "^3.20.0",
"touch": "^3.1.0"
}
}

cli/src/envGenerator.js Normal file

@@ -0,0 +1,142 @@
import crypto from "crypto";
import fs from "fs";
import chalk from "chalk";
export const generateEnv = (envValues) => {
let isDockerCompose = envValues.runOption === "docker-compose";
let dbPort = isDockerCompose ? 3307 : 3306;
let platformUrl = isDockerCompose
? "http://host.docker.internal:8000"
: "http://localhost:8000";
const envDefinition = getEnvDefinition(
envValues,
isDockerCompose,
dbPort,
platformUrl
);
const envFileContent = generateEnvFileContent(envDefinition);
saveEnvFile(envFileContent);
};
const getEnvDefinition = (envValues, isDockerCompose, dbPort, platformUrl) => {
return {
"Deployment Environment": {
NODE_ENV: "development",
NEXT_PUBLIC_VERCEL_ENV: "${NODE_ENV}",
},
NextJS: {
NEXT_PUBLIC_BACKEND_URL: "http://localhost:8000",
NEXT_PUBLIC_MAX_LOOPS: 100,
},
"Next Auth config": {
NEXTAUTH_SECRET: generateAuthSecret(),
NEXTAUTH_URL: "http://localhost:3000",
},
"Auth providers (Use if you want to get out of development mode sign-in)": {
GOOGLE_CLIENT_ID: "***",
GOOGLE_CLIENT_SECRET: "***",
GITHUB_CLIENT_ID: "***",
GITHUB_CLIENT_SECRET: "***",
DISCORD_CLIENT_SECRET: "***",
DISCORD_CLIENT_ID: "***",
},
Backend: {
REWORKD_PLATFORM_ENVIRONMENT: "${NODE_ENV}",
REWORKD_PLATFORM_FF_MOCK_MODE_ENABLED: false,
REWORKD_PLATFORM_MAX_LOOPS: "${NEXT_PUBLIC_MAX_LOOPS}",
REWORKD_PLATFORM_OPENAI_API_KEY:
envValues.OpenAIApiKey || '"<change me>"',
REWORKD_PLATFORM_FRONTEND_URL: "http://localhost:3000",
REWORKD_PLATFORM_RELOAD: true,
REWORKD_PLATFORM_OPENAI_API_BASE: "https://api.openai.com/v1",
REWORKD_PLATFORM_SERP_API_KEY: envValues.serpApiKey || '""',
REWORKD_PLATFORM_REPLICATE_API_KEY: envValues.replicateApiKey || '""',
},
"Database (Backend)": {
REWORKD_PLATFORM_DATABASE_USER: "reworkd_platform",
REWORKD_PLATFORM_DATABASE_PASSWORD: "reworkd_platform",
REWORKD_PLATFORM_DATABASE_HOST: "agentgpt_db",
REWORKD_PLATFORM_DATABASE_PORT: dbPort,
REWORKD_PLATFORM_DATABASE_NAME: "reworkd_platform",
REWORKD_PLATFORM_DATABASE_URL:
"mysql://${REWORKD_PLATFORM_DATABASE_USER}:${REWORKD_PLATFORM_DATABASE_PASSWORD}@${REWORKD_PLATFORM_DATABASE_HOST}:${REWORKD_PLATFORM_DATABASE_PORT}/${REWORKD_PLATFORM_DATABASE_NAME}",
},
"Database (Frontend)": {
DATABASE_USER: "reworkd_platform",
DATABASE_PASSWORD: "reworkd_platform",
DATABASE_HOST: "agentgpt_db",
DATABASE_PORT: dbPort,
DATABASE_NAME: "reworkd_platform",
DATABASE_URL:
"mysql://${DATABASE_USER}:${DATABASE_PASSWORD}@${DATABASE_HOST}:${DATABASE_PORT}/${DATABASE_NAME}",
},
};
};
const generateEnvFileContent = (config) => {
let configFile = "";
Object.entries(config).forEach(([section, variables]) => {
configFile += `# ${section}:\n`;
Object.entries(variables).forEach(([key, value]) => {
configFile += `${key}=${value}\n`;
});
configFile += "\n";
});
return configFile.trim();
};
const generateAuthSecret = () => {
const length = 32;
const buffer = crypto.randomBytes(length);
return buffer.toString("base64");
};
const ENV_PATH = "../next/.env";
const BACKEND_ENV_PATH = "../platform/.env";
export const doesEnvFileExist = () => {
return fs.existsSync(ENV_PATH);
};
// Read the existing env file and check whether it is missing any keys from the definition
export const testEnvFile = () => {
const data = fs.readFileSync(ENV_PATH, "utf8");
// Build a placeholder definition purely to compare keys against
const envDefinition = getEnvDefinition({}, false, 0, "");
const lines = data
.split("\n")
.filter((line) => !line.startsWith("#") && line.trim() !== "");
const envKeysFromFile = lines.map((line) => line.split("=")[0]);
const envKeysFromDef = Object.entries(envDefinition).flatMap(
([section, entries]) => Object.keys(entries)
);
const missingFromFile = envKeysFromDef.filter(
(key) => !envKeysFromFile.includes(key)
);
if (missingFromFile.length > 0) {
let errorMessage = "\nYour ./next/.env is missing the following keys:\n";
missingFromFile.forEach((key) => {
errorMessage += chalk.whiteBright(`- ❌ ${key}\n`);
});
errorMessage += "\n";
errorMessage += chalk.red(
"We recommend deleting your .env file(s) and restarting this script."
);
throw new Error(errorMessage);
}
};
export const saveEnvFile = (envFileContent) => {
fs.writeFileSync(ENV_PATH, envFileContent);
fs.writeFileSync(BACKEND_ENV_PATH, envFileContent);
};

cli/src/helpers.js Normal file

@@ -0,0 +1,26 @@
import chalk from "chalk";
import figlet from "figlet";
export const printTitle = () => {
console.log(
chalk.red(
figlet.textSync("AgentGPT", {
horizontalLayout: "full",
font: "ANSI Shadow",
})
)
);
console.log(
"Welcome to the AgentGPT CLI! This CLI will generate the required .env files."
);
console.log(
"Copies of the generated envs will be created in `./next/.env` and `./platform/.env`.\n"
);
};
// Check whether the entered API key is empty or matches the expected format
export const isValidKey = (apikey, pattern) => {
return apikey === "" || pattern.test(apikey);
};
export const validKeyErrorMessage = "\nInvalid API key. Please try again.";

cli/src/index.js Normal file

@@ -0,0 +1,60 @@
import inquirer from "inquirer";
import dotenv from "dotenv";
import { printTitle } from "./helpers.js";
import { doesEnvFileExist, generateEnv, testEnvFile } from "./envGenerator.js";
import { newEnvQuestions } from "./questions/newEnvQuestions.js";
import { existingEnvQuestions } from "./questions/existingEnvQuestions.js";
import { spawn } from "child_process";
import chalk from "chalk";
const handleExistingEnv = () => {
console.log(chalk.yellow("Existing ./next/env file found. Validating..."));
try {
testEnvFile();
} catch (e) {
console.log(e.message);
return;
}
inquirer.prompt(existingEnvQuestions).then((answers) => {
handleRunOption(answers.runOption);
});
};
const handleNewEnv = () => {
inquirer.prompt(newEnvQuestions).then((answers) => {
dotenv.config({ path: "./.env" });
generateEnv(answers);
console.log("\nEnv files successfully created!");
handleRunOption(answers.runOption);
});
};
const handleRunOption = (runOption) => {
if (runOption === "docker-compose") {
const dockerComposeUp = spawn("docker-compose", ["up", "--build"], {
stdio: "inherit",
});
}
if (runOption === "manual") {
console.log(
"Please go into the ./next folder and run `npm install && npm run dev`."
);
console.log(
"Please also go into the ./platform folder and run `poetry install && poetry run python -m reworkd_platform`."
);
console.log(
"Please use or update the MySQL database configuration in the env file(s)."
);
}
};
printTitle();
if (doesEnvFileExist()) {
handleExistingEnv();
} else {
handleNewEnv();
}

cli/src/questions/existingEnvQuestions.js Normal file

@@ -0,0 +1,5 @@
import { RUN_OPTION_QUESTION } from "./sharedQuestions.js";
export const existingEnvQuestions = [
RUN_OPTION_QUESTION
];

cli/src/questions/newEnvQuestions.js Normal file

@@ -0,0 +1,87 @@
import { isValidKey, validKeyErrorMessage } from "../helpers.js";
import { RUN_OPTION_QUESTION } from "./sharedQuestions.js";
import fetch from "node-fetch";
export const newEnvQuestions = [
RUN_OPTION_QUESTION,
{
type: "input",
name: "OpenAIApiKey",
message:
"Enter your openai key (eg: sk...) or press enter to continue with no key:",
validate: async(apikey) => {
if(apikey === "") return true;
if(!isValidKey(apikey, /^sk(-proj)?-[a-zA-Z0-9\-\_]+$/)) {
return validKeyErrorMessage
}
const endpoint = "https://api.openai.com/v1/models"
const response = await fetch(endpoint, {
headers: {
"Authorization": `Bearer ${apikey}`,
},
});
if(!response.ok) {
return validKeyErrorMessage
}
return true
},
},
{
type: "input",
name: "serpApiKey",
message:
"What is your SERP API key (https://serper.dev/)? Leave empty to disable web search.",
validate: async(apikey) => {
if(apikey === "") return true;
if(!isValidKey(apikey, /^[a-zA-Z0-9]{40}$/)) {
return validKeyErrorMessage
}
const endpoint = "https://google.serper.dev/search"
const response = await fetch(endpoint, {
method: 'POST',
headers: {
"X-API-KEY": apikey,
"Content-Type": "application/json",
},
body: JSON.stringify({
"q": "apple inc"
}),
});
if(!response.ok) {
return validKeyErrorMessage
}
return true
},
},
{
type: "input",
name: "replicateApiKey",
message:
"What is your Replicate API key (https://replicate.com/)? Leave empty to just use DALL-E for image generation.",
validate: async(apikey) => {
if(apikey === "") return true;
if(!isValidKey(apikey, /^r8_[a-zA-Z0-9]{37}$/)) {
return validKeyErrorMessage
}
const endpoint = "https://api.replicate.com/v1/models/replicate/hello-world"
const response = await fetch(endpoint, {
headers: {
"Authorization": `Token ${apikey}`,
},
});
if(!response.ok) {
return validKeyErrorMessage
}
return true
},
},
];

cli/src/questions/sharedQuestions.js Normal file

@@ -0,0 +1,10 @@
export const RUN_OPTION_QUESTION = {
type: 'list',
name: 'runOption',
choices: [
{ value: "docker-compose", name: "🐋 Docker-compose (Recommended)" },
{ value: "manual", name: "💪 Manual (Not recommended)" },
],
message: 'How will you be running AgentGPT?',
default: "docker-compose",
}

cli/tsconfig.json Normal file

@@ -0,0 +1,115 @@
{
"compilerOptions": {
/* Visit https://aka.ms/tsconfig to read more about this file */
/* Projects */
// "incremental": true, /* Save .tsbuildinfo files to allow for incremental compilation of projects. */
// "composite": true, /* Enable constraints that allow a TypeScript project to be used with project references. */
// "tsBuildInfoFile": "./.tsbuildinfo", /* Specify the path to .tsbuildinfo incremental compilation file. */
// "disableSourceOfProjectReferenceRedirect": true, /* Disable preferring source files instead of declaration files when referencing composite projects. */
// "disableSolutionSearching": true, /* Opt a project out of multi-project reference checking when editing. */
// "disableReferencedProjectLoad": true, /* Reduce the number of projects loaded automatically by TypeScript. */
/* Language and Environment */
"target": "es2016",
/* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */
// "src": [], /* Specify a set of bundled library declaration files that describe the target runtime environment. */
// "jsx": "preserve", /* Specify what JSX code is generated. */
// "experimentalDecorators": true, /* Enable experimental support for legacy experimental decorators. */
// "emitDecoratorMetadata": true, /* Emit design-type metadata for decorated declarations in source files. */
// "jsxFactory": "", /* Specify the JSX factory function used when targeting React JSX emit, e.g. 'React.createElement' or 'h'. */
// "jsxFragmentFactory": "", /* Specify the JSX Fragment reference used for fragments when targeting React JSX emit e.g. 'React.Fragment' or 'Fragment'. */
// "jsxImportSource": "", /* Specify module specifier used to import the JSX factory functions when using 'jsx: react-jsx*'. */
// "reactNamespace": "", /* Specify the object invoked for 'createElement'. This only applies when targeting 'react' JSX emit. */
// "noLib": true, /* Disable including any library files, including the default src.d.ts. */
// "useDefineForClassFields": true, /* Emit ECMAScript-standard-compliant class fields. */
// "moduleDetection": "auto", /* Control what method is used to detect module-format JS files. */
/* Modules */
"module": "commonjs",
/* Specify what module code is generated. */
// "rootDir": "./", /* Specify the root folder within your source files. */
// "moduleResolution": "node10", /* Specify how TypeScript looks up a file from a given module specifier. */
// "baseUrl": "./", /* Specify the base directory to resolve non-relative module names. */
// "paths": {}, /* Specify a set of entries that re-map imports to additional lookup locations. */
// "rootDirs": [], /* Allow multiple folders to be treated as one when resolving modules. */
// "typeRoots": [], /* Specify multiple folders that act like './node_modules/@types'. */
// "types": [], /* Specify type package names to be included without being referenced in a source file. */
// "allowUmdGlobalAccess": true, /* Allow accessing UMD globals from modules. */
// "moduleSuffixes": [], /* List of file name suffixes to search when resolving a module. */
// "allowImportingTsExtensions": true, /* Allow imports to include TypeScript file extensions. Requires '--moduleResolution bundler' and either '--noEmit' or '--emitDeclarationOnly' to be set. */
// "resolvePackageJsonExports": true, /* Use the package.json 'exports' field when resolving package imports. */
// "resolvePackageJsonImports": true, /* Use the package.json 'imports' field when resolving imports. */
// "customConditions": [], /* Conditions to set in addition to the resolver-specific defaults when resolving imports. */
// "resolveJsonModule": true, /* Enable importing .json files. */
// "allowArbitraryExtensions": true, /* Enable importing files with any extension, provided a declaration file is present. */
// "noResolve": true, /* Disallow 'import's, 'require's or '<reference>'s from expanding the number of files TypeScript should add to a project. */
/* JavaScript Support */
// "allowJs": true, /* Allow JavaScript files to be a part of your program. Use the 'checkJS' option to get errors from these files. */
// "checkJs": true, /* Enable error reporting in type-checked JavaScript files. */
// "maxNodeModuleJsDepth": 1, /* Specify the maximum folder depth used for checking JavaScript files from 'node_modules'. Only applicable with 'allowJs'. */
/* Emit */
// "declaration": true, /* Generate .d.ts files from TypeScript and JavaScript files in your project. */
// "declarationMap": true, /* Create sourcemaps for d.ts files. */
// "emitDeclarationOnly": true, /* Only output d.ts files and not JavaScript files. */
// "sourceMap": true, /* Create source map files for emitted JavaScript files. */
// "inlineSourceMap": true, /* Include sourcemap files inside the emitted JavaScript. */
// "outFile": "./", /* Specify a file that bundles all outputs into one JavaScript file. If 'declaration' is true, also designates a file that bundles all .d.ts output. */
// "outDir": "./", /* Specify an output folder for all emitted files. */
// "removeComments": true, /* Disable emitting comments. */
// "noEmit": true, /* Disable emitting files from a compilation. */
// "importHelpers": true, /* Allow importing helper functions from tslib once per project, instead of including them per-file. */
// "importsNotUsedAsValues": "remove", /* Specify emit/checking behavior for imports that are only used for types. */
// "downlevelIteration": true, /* Emit more compliant, but verbose and less performant JavaScript for iteration. */
// "sourceRoot": "", /* Specify the root path for debuggers to find the reference source code. */
// "mapRoot": "", /* Specify the location where debugger should locate map files instead of generated locations. */
// "inlineSources": true, /* Include source code in the sourcemaps inside the emitted JavaScript. */
// "emitBOM": true, /* Emit a UTF-8 Byte Order Mark (BOM) in the beginning of output files. */
// "newLine": "crlf", /* Set the newline character for emitting files. */
// "stripInternal": true, /* Disable emitting declarations that have '@internal' in their JSDoc comments. */
// "noEmitHelpers": true, /* Disable generating custom helper functions like '__extends' in compiled output. */
// "noEmitOnError": true, /* Disable emitting files if any type checking errors are reported. */
// "preserveConstEnums": true, /* Disable erasing 'const enum' declarations in generated code. */
// "declarationDir": "./", /* Specify the output directory for generated declaration files. */
// "preserveValueImports": true, /* Preserve unused imported values in the JavaScript output that would otherwise be removed. */
/* Interop Constraints */
// "isolatedModules": true, /* Ensure that each file can be safely transpiled without relying on other imports. */
// "verbatimModuleSyntax": true, /* Do not transform or elide any imports or exports not marked as type-only, ensuring they are written in the output file's format based on the 'module' setting. */
// "allowSyntheticDefaultImports": true, /* Allow 'import x from y' when a module doesn't have a default export. */
"esModuleInterop": true,
/* Emit additional JavaScript to ease support for importing CommonJS modules. This enables 'allowSyntheticDefaultImports' for type compatibility. */
// "preserveSymlinks": true, /* Disable resolving symlinks to their realpath. This correlates to the same flag in node. */
"forceConsistentCasingInFileNames": true,
/* Ensure that casing is correct in imports. */
/* Type Checking */
"strict": true,
/* Enable all strict type-checking options. */
// "noImplicitAny": true, /* Enable error reporting for expressions and declarations with an implied 'any' type. */
// "strictNullChecks": true, /* When type checking, take into account 'null' and 'undefined'. */
// "strictFunctionTypes": true, /* When assigning functions, check to ensure parameters and the return values are subtype-compatible. */
// "strictBindCallApply": true, /* Check that the arguments for 'bind', 'call', and 'apply' methods match the original function. */
// "strictPropertyInitialization": true, /* Check for class properties that are declared but not set in the constructor. */
// "noImplicitThis": true, /* Enable error reporting when 'this' is given the type 'any'. */
// "useUnknownInCatchVariables": true, /* Default catch clause variables as 'unknown' instead of 'any'. */
// "alwaysStrict": true, /* Ensure 'use strict' is always emitted. */
// "noUnusedLocals": true, /* Enable error reporting when local variables aren't read. */
// "noUnusedParameters": true, /* Raise an error when a function parameter isn't read. */
// "exactOptionalPropertyTypes": true, /* Interpret optional property types as written, rather than adding 'undefined'. */
// "noImplicitReturns": true, /* Enable error reporting for codepaths that do not explicitly return in a function. */
// "noFallthroughCasesInSwitch": true, /* Enable error reporting for fallthrough cases in switch statements. */
// "noUncheckedIndexedAccess": true, /* Add 'undefined' to a type when accessed using an index. */
// "noImplicitOverride": true, /* Ensure overriding members in derived classes are marked with an override modifier. */
// "noPropertyAccessFromIndexSignature": true, /* Enforces using indexed accessors for keys declared using an indexed type. */
// "allowUnusedLabels": true, /* Disable error reporting for unused labels. */
// "allowUnreachableCode": true, /* Disable error reporting for unreachable code. */
/* Completeness */
// "skipDefaultLibCheck": true, /* Skip type checking .d.ts files that are included with TypeScript. */
"skipLibCheck": true
/* Skip type checking all .d.ts files. */
}
}

db/Dockerfile Normal file

@@ -0,0 +1,3 @@
FROM mysql:8.0
ADD setup.sql /docker-entrypoint-initdb.d

db/setup.sql Normal file

@@ -0,0 +1,11 @@
-- Prisma requires DB creation privileges to create a shadow database (https://pris.ly/d/migrate-shadow)
-- This is not available to our user by default, so we must manually add this
-- Create the user
CREATE USER IF NOT EXISTS 'reworkd_platform'@'%' IDENTIFIED BY 'reworkd_platform';
-- Grant the necessary permissions
GRANT CREATE, ALTER, DROP, INSERT, UPDATE, DELETE, SELECT ON *.* TO 'reworkd_platform'@'%';
-- Apply the changes
FLUSH PRIVILEGES;

platform/.dockerignore Normal file

@@ -0,0 +1,145 @@
### Python template
deploy/
.idea/
.vscode/
.git/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
poetry.lock

platform/.editorconfig Normal file

@@ -0,0 +1,31 @@
root = true
[*]
tab_width = 4
end_of_line = lf
max_line_length = 88
ij_visual_guides = 88
insert_final_newline = true
trim_trailing_whitespace = true
[*.{js,py,html}]
charset = utf-8
[*.md]
trim_trailing_whitespace = false
[*.{yml,yaml}]
indent_style = space
indent_size = 2
[Makefile]
indent_style = tab
[.flake8]
indent_style = space
indent_size = 2
[*.py]
indent_style = space
indent_size = 4
ij_python_from_import_parentheses_force_if_multiline = true

platform/.flake8 Normal file

@@ -0,0 +1,115 @@
[flake8]
max-complexity = 6
inline-quotes = double
max-line-length = 88
extend-ignore = E203
docstring_style=sphinx
ignore =
; Found `f` string
WPS305,
; Missing docstring in public module
D100,
; Missing docstring in magic method
D105,
; Missing docstring in __init__
D107,
; Found `__init__.py` module with logic
WPS412,
; Found class without a base class
WPS306,
; Missing docstring in public nested class
D106,
; First line should be in imperative mood
D401,
; Found wrong variable name
WPS110,
; Found `__init__.py` module with logic
WPS326,
; Found string constant over-use
WPS226,
; Found upper-case constant in a class
WPS115,
; Found nested function
WPS602,
; Found method without arguments
WPS605,
; Found overused expression
WPS204,
; Found too many module members
WPS202,
; Found too high module cognitive complexity
WPS232,
; line break before binary operator
W503,
; Found module with too many imports
WPS201,
; Inline strong start-string without end-string.
RST210,
; Found nested class
WPS431,
; Found wrong module name
WPS100,
; Found too many methods
WPS214,
; Found too long ``try`` body
WPS229,
; Found unpythonic getter or setter
WPS615,
; Found a line that starts with a dot
WPS348,
; Found complex default value (for dependency injection)
WPS404,
; not perform function calls in argument defaults (for dependency injection)
B008,
; Model should define verbose_name in its Meta inner class
DJ10,
; Model should define verbose_name_plural in its Meta inner class
DJ11,
; Found mutable module constant.
WPS407,
; Found too many empty lines in `def`
WPS473,
; Found missing trailing comma
C812,
per-file-ignores =
; all tests
test_*.py,tests.py,tests_*.py,*/tests/*,conftest.py:
; Use of assert detected
S101,
; Found outer scope names shadowing
WPS442,
; Found too many local variables
WPS210,
; Found magic number
WPS432,
; Missing parameter(s) in Docstring
DAR101,
; Found too many arguments
WPS211,
; all init files
__init__.py:
; ignore not used imports
F401,
; ignore import with wildcard
F403,
; Found wrong metadata variable
WPS410,
exclude =
./.cache,
./.git,
./.idea,
./.mypy_cache,
./.pytest_cache,
./.venv,
./venv,
./env,
./cached_venv,
./docs,
./deploy,
./var,
./.vscode,
*migrations*,

platform/.gitignore vendored Normal file

@@ -0,0 +1,144 @@
### Python template
.idea/
.vscode/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.pem
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
ssl/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
*.sqlite3
*.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
.env

platform/.pre-commit-config.yaml Normal file

@@ -0,0 +1,63 @@
---
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.4.0
hooks:
- id: check-ast
- id: trailing-whitespace
- id: check-toml
- id: end-of-file-fixer
- repo: https://github.com/asottile/add-trailing-comma
rev: v2.1.0
hooks:
- id: add-trailing-comma
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
rev: v2.1.0
hooks:
- id: pretty-format-yaml
args:
- --autofix
- --preserve-quotes
- --indent=2
- repo: local
hooks:
- id: black
name: Format with Black
entry: poetry run black
language: system
types: [python]
- id: autoflake
name: autoflake
entry: poetry run autoflake
language: system
types: [python]
args: [--in-place, --remove-all-unused-imports, --remove-duplicate-keys]
- id: isort
name: isort
entry: poetry run isort
language: system
types: [python]
- id: flake8
name: Check with Flake8
entry: poetry run flake8
language: system
pass_filenames: false
types: [python]
args: [--count, .]
- id: mypy
name: Validate types with MyPy
entry: poetry run mypy
language: system
types: [python]
pass_filenames: false
args:
- "reworkd_platform"

platform/Dockerfile Normal file

@@ -0,0 +1,37 @@
FROM python:3.11-slim-buster as prod
RUN apt-get update && apt-get install -y \
default-libmysqlclient-dev \
gcc \
pkg-config \
openjdk-11-jdk \
build-essential \
&& rm -rf /var/lib/apt/lists/*
RUN pip install poetry==1.4.2
# Configuring poetry
RUN poetry config virtualenvs.create false
# Copying the project's requirements
COPY pyproject.toml /app/src/
WORKDIR /app/src
# Installing requirements
RUN poetry install --only main
# Removing build-time dependencies
RUN apt-get purge -y \
g++ \
gcc \
pkg-config \
&& rm -rf /var/lib/apt/lists/*
# Copying actual application
COPY . /app/src/
RUN poetry install --only main
CMD ["/usr/local/bin/python", "-m", "reworkd_platform"]
FROM prod as dev
RUN poetry install

platform/README.md Normal file

@@ -0,0 +1,151 @@
# reworkd_platform
This project was generated using fastapi_template.
## Poetry
This project uses poetry. It's a modern dependency management
tool.
To run the project use this set of commands:
```bash
poetry install
poetry run python -m reworkd_platform
```
This will start the server on the configured host.
You can find swagger documentation at `/api/docs`.
You can read more about poetry here: https://python-poetry.org/
## Docker
You can start the project with docker using this command:
```bash
docker-compose -f deploy/docker-compose.yml --project-directory . up --build
```
If you want to develop in docker with autoreload, add `-f deploy/docker-compose.dev.yml` to your docker command, like this:
```bash
docker-compose -f deploy/docker-compose.yml -f deploy/docker-compose.dev.yml --project-directory . up --build
```
This command exposes the web application on port 8000, mounts the current directory, and enables autoreload.
Note that you have to rebuild the image every time you modify `poetry.lock` or `pyproject.toml`, using this command:
```bash
docker-compose -f deploy/docker-compose.yml --project-directory . build
```
## Project structure
```bash
$ tree "reworkd_platform"
reworkd_platform
├── conftest.py # Fixtures for all tests.
├── db # module contains db configurations
│   ├── dao # Data Access Objects. Contains different classes to interact with database.
│   └── models # Package contains different models for ORMs.
├── __main__.py # Startup script. Starts uvicorn.
├── services # Package for different external services such as rabbit or redis etc.
├── settings.py # Main configuration settings for project.
├── static # Static content.
├── tests # Tests for project.
└── web # Package contains web server. Handlers, startup config.
├── api # Package with all handlers.
│   └── router.py # Main router.
├── application.py # FastAPI application configuration.
└── lifetime.py # Contains actions to perform on startup and shutdown.
```
## Configuration
This application can be configured with environment variables.
You can create a `.env` file in the root directory and place all
environment variables there.
All environment variables should start with the "REWORKD_PLATFORM_" prefix.
For example, if "reworkd_platform/settings.py" has a variable named
`random_parameter`, you should set the "REWORKD_PLATFORM_RANDOM_PARAMETER"
environment variable to configure it. This behaviour can be changed by overriding the `env_prefix` property
in `reworkd_platform.settings.Settings.Config`.
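As a rough sketch of how this prefix mapping works with pydantic's `BaseSettings` (the actual `settings.py` is not included in this commit, so the field name below is illustrative):
```python
from pydantic import BaseSettings  # pydantic <2.0 with the "dotenv" extra, as pinned in pyproject.toml


class Settings(BaseSettings):
    # Illustrative field: populated from REWORKD_PLATFORM_RANDOM_PARAMETER when set
    random_parameter: str = "default"

    class Config:
        env_prefix = "REWORKD_PLATFORM_"
        env_file = ".env"


settings = Settings()
```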
An example `.env` file:
```bash
REWORKD_PLATFORM_RELOAD="True"
REWORKD_PLATFORM_PORT="8000"
REWORKD_PLATFORM_ENVIRONMENT="development"
```
You can read more about BaseSettings class here: https://pydantic-docs.helpmanual.io/usage/settings/
## Pre-commit
To install pre-commit simply run inside the shell:
```bash
pre-commit install
```
pre-commit is very useful to check your code before publishing it.
It's configured using .pre-commit-config.yaml file.
By default it runs:
* black (formats your code);
* mypy (validates types);
* isort (sorts imports in all files);
* flake8 (spots possible bugs);
You can read more about pre-commit here: https://pre-commit.com/
## Running tests
If you want to run it in docker, simply run:
```bash
docker-compose -f deploy/docker-compose.yml -f deploy/docker-compose.dev.yml --project-directory . run --build --rm api pytest -vv .
docker-compose -f deploy/docker-compose.yml -f deploy/docker-compose.dev.yml --project-directory . down
```
For running tests on your local machine:
1. Start a database. I prefer doing it with docker:
```bash
docker run -p "3306:3306" -e "MYSQL_PASSWORD=reworkd_platform" -e "MYSQL_USER=reworkd_platform" -e "MYSQL_DATABASE=reworkd_platform" -e ALLOW_EMPTY_PASSWORD=yes bitnami/mysql:8.0.30
```
2. Run pytest:
```bash
pytest -vv .
```
## Running linters
```bash
# Flake
poetry run black .
poetry run autoflake --in-place --remove-duplicate-keys --remove-all-unused-imports -r .
poetry run flake8
poetry run mypy .
# Pytest
poetry run pytest -vv --cov="reworkd_platform" .
# Bump packages
poetry self add poetry-plugin-up
poetry up --latest
```

platform/entrypoint.sh Normal file

@@ -0,0 +1,14 @@
#!/usr/bin/env sh
host=agentgpt_db
port=3306
until echo "SELECT 1;" | nc "$host" "$port" > /dev/null 2>&1; do
>&2 echo "Database is unavailable - Sleeping..."
sleep 2
done
>&2 echo "Database is available! Continuing..."
# Run cmd
exec "$@"

platform/poetry.lock generated Normal file (3413 lines)

File diff suppressed because it is too large.

platform/pyproject.toml Normal file

@@ -0,0 +1,100 @@
[tool.poetry]
name = "reworkd_platform"
version = "0.1.0"
description = ""
authors = [
"awtkns",
"asim-shrestha"
]
maintainers = [
"reworkd"
]
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.11"
fastapi = "^0.98.0"
boto3 = "^1.28.51"
uvicorn = { version = "^0.22.0", extras = ["standard"] }
pydantic = { version = "<2.0", extras = ["dotenv"] }
ujson = "^5.8.0"
sqlalchemy = { version = "^2.0.21", extras = ["mypy", "asyncio"] }
aiomysql = "^0.1.1"
mysqlclient = "^2.2.0"
sentry-sdk = "^1.31.0"
loguru = "^0.7.2"
aiokafka = "^0.8.1"
requests = "^2.31.0"
langchain = "^0.0.295"
openai = "^0.28.0"
wikipedia = "^1.4.0"
replicate = "^0.8.4"
lanarky = "0.7.15"
tiktoken = "^0.5.1"
grpcio = "^1.58.0"
pinecone-client = "^2.2.4"
python-multipart = "^0.0.6"
aws-secretsmanager-caching = "^1.1.1.5"
botocore = "^1.31.51"
stripe = "^5.5.0"
cryptography = "^41.0.4"
httpx = "^0.25.0"
[tool.poetry.dev-dependencies]
autopep8 = "^2.0.4"
pytest = "^7.4.2"
flake8 = "~6.0.0"
mypy = "^1.5.1"
isort = "^5.12.0"
pre-commit = "^3.4.0"
wemake-python-styleguide = "^0.18.0"
black = "^23.9.1"
autoflake = "^2.2.1"
pytest-cov = "^4.1.0"
anyio = "^3.7.1"
pytest-env = "^0.8.2"
[tool.poetry.group.dev.dependencies]
dotmap = "^1.3.30"
pytest-mock = "^3.10.0"
pytest-asyncio = "^0.21.0"
mypy = "^1.4.1"
types-requests = "^2.31.0.1"
types-pytz = "^2023.3.0.0"
[tool.isort]
profile = "black"
multi_line_output = 3
src_paths = ["reworkd_platform"]
[tool.mypy]
strict = true
ignore_missing_imports = true
allow_subclassing_any = true
allow_untyped_calls = true
pretty = true
show_error_codes = true
implicit_reexport = true
allow_untyped_decorators = true
warn_unused_ignores = false
warn_return_any = false
namespace_packages = true
exclude = "tests"
[tool.pytest.ini_options]
filterwarnings = [
"error",
"ignore::DeprecationWarning",
"ignore:.*unclosed.*:ResourceWarning",
"ignore::ImportWarning",
]
env = [
"REWORKD_PLATFORM_DB_BASE=reworkd_platform_test",
]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

platform/reworkd_platform/__init__.py Normal file

@@ -0,0 +1 @@
"""reworkd_platform package."""

platform/reworkd_platform/__main__.py Normal file

@@ -0,0 +1,20 @@
import uvicorn
from reworkd_platform.settings import settings
def main() -> None:
"""Entrypoint of the application."""
uvicorn.run(
"reworkd_platform.web.application:get_app",
workers=settings.workers_count,
host=settings.host,
port=settings.port,
reload=settings.reload,
log_level="debug",
factory=True,
)
if __name__ == "__main__":
main()

platform/reworkd_platform/conftest.py Normal file

@@ -0,0 +1,107 @@
from typing import Any, AsyncGenerator
import pytest
from fastapi import FastAPI
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from reworkd_platform.db.dependencies import get_db_session
from reworkd_platform.db.utils import create_database, drop_database
from reworkd_platform.settings import settings
from reworkd_platform.web.application import get_app
@pytest.fixture(scope="session")
def anyio_backend() -> str:
"""
Backend for anyio pytest plugin.
:return: backend name.
"""
return "asyncio"
@pytest.fixture(scope="session")
async def _engine() -> AsyncGenerator[AsyncEngine, None]:
"""
Create engine and databases.
:yield: new engine.
"""
from reworkd_platform.db.meta import meta # noqa: WPS433
from reworkd_platform.db.models import load_all_models # noqa: WPS433
load_all_models()
await create_database()
engine = create_async_engine(str(settings.db_url))
async with engine.begin() as conn:
await conn.run_sync(meta.create_all)
try:
yield engine
finally:
await engine.dispose()
await drop_database()
@pytest.fixture
async def dbsession(
_engine: AsyncEngine,
) -> AsyncGenerator[AsyncSession, None]:
"""
Get session to database.
Fixture that returns a SQLAlchemy session with a SAVEPOINT, and the rollback to it
after the test completes.
:param _engine: current engine.
:yields: async session.
"""
connection = await _engine.connect()
trans = await connection.begin()
session_maker = async_sessionmaker(
connection,
expire_on_commit=False,
)
session = session_maker()
try:
yield session
finally:
await session.close()
await trans.rollback()
await connection.close()
@pytest.fixture
def fastapi_app(dbsession: AsyncSession) -> FastAPI:
"""
Fixture for creating FastAPI app.
:return: fastapi app with mocked dependencies.
"""
application = get_app()
application.dependency_overrides[get_db_session] = lambda: dbsession
return application # noqa: WPS331
@pytest.fixture
async def client(
fastapi_app: FastAPI, anyio_backend: Any
) -> AsyncGenerator[AsyncClient, None]:
"""
Fixture that creates client for requesting server.
:param fastapi_app: the application.
:yield: client for the app.
"""
async with AsyncClient(app=fastapi_app, base_url="http://test") as ac:
yield ac

platform/reworkd_platform/constants.py Normal file

@@ -0,0 +1 @@
ENV_PREFIX = "REWORKD_PLATFORM_"


platform/reworkd_platform/db/base.py Normal file

@@ -0,0 +1,68 @@
import uuid
from datetime import datetime
from typing import Optional, Type, TypeVar
from sqlalchemy import DateTime, String, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from reworkd_platform.db.meta import meta
from reworkd_platform.web.api.http_responses import not_found
T = TypeVar("T", bound="Base")
class Base(DeclarativeBase):
"""Base for all models."""
metadata = meta
id: Mapped[str] = mapped_column(
String,
primary_key=True,
default=lambda _: str(uuid.uuid4()),
unique=True,
nullable=False,
)
@classmethod
async def get(cls: Type[T], session: AsyncSession, id_: str) -> Optional[T]:
return await session.get(cls, id_)
@classmethod
async def get_or_404(cls: Type[T], session: AsyncSession, id_: str) -> T:
if model := await cls.get(session, id_):
return model
raise not_found(detail=f"{cls.__name__}[{id_}] not found")
async def save(self: T, session: AsyncSession) -> T:
session.add(self)
await session.flush()
return self
async def delete(self: T, session: AsyncSession) -> None:
await session.delete(self)
class TrackedModel(Base):
"""Base for all tracked models."""
__abstract__ = True
create_date = mapped_column(
DateTime, name="create_date", server_default=func.now(), nullable=False
)
update_date = mapped_column(
DateTime, name="update_date", onupdate=func.now(), nullable=True
)
delete_date = mapped_column(DateTime, name="delete_date", nullable=True)
async def delete(self, session: AsyncSession) -> None:
"""Marks the model as deleted."""
self.delete_date = datetime.now()
await self.save(session)
class UserMixin:
user_id = mapped_column(String, name="user_id", nullable=False)
organization_id = mapped_column(String, name="organization_id", nullable=True)

platform/reworkd_platform/db/crud/agent.py Normal file

@@ -0,0 +1,58 @@
from fastapi import HTTPException
from sqlalchemy import and_, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from reworkd_platform.db.crud.base import BaseCrud
from reworkd_platform.db.models.agent import AgentRun, AgentTask
from reworkd_platform.schemas.agent import Loop_Step
from reworkd_platform.schemas.user import UserBase
from reworkd_platform.settings import settings
from reworkd_platform.web.api.errors import MaxLoopsError, MultipleSummaryError
class AgentCRUD(BaseCrud):
def __init__(self, session: AsyncSession, user: UserBase):
super().__init__(session)
self.user = user
async def create_run(self, goal: str) -> AgentRun:
return await AgentRun(
user_id=self.user.id,
goal=goal,
).save(self.session)
async def create_task(self, run_id: str, type_: Loop_Step) -> AgentTask:
await self.validate_task_count(run_id, type_)
return await AgentTask(
run_id=run_id,
type_=type_,
).save(self.session)
async def validate_task_count(self, run_id: str, type_: str) -> None:
if not await AgentRun.get(self.session, run_id):
raise HTTPException(404, f"Run {run_id} not found")
query = select(func.count(AgentTask.id)).where(
and_(
AgentTask.run_id == run_id,
AgentTask.type_ == type_,
)
)
task_count = (await self.session.execute(query)).scalar_one()
max_ = settings.max_loops
if task_count >= max_:
raise MaxLoopsError(
StopIteration(),
f"Max loops of {max_} exceeded, shutting down.",
429,
should_log=False,
)
if type_ == "summarize" and task_count > 1:
raise MultipleSummaryError(
StopIteration(),
"Multiple summary tasks are not allowed",
429,
)

platform/reworkd_platform/db/crud/base.py Normal file

@@ -0,0 +1,10 @@
from typing import TypeVar
from sqlalchemy.ext.asyncio import AsyncSession
T = TypeVar("T", bound="BaseCrud")
class BaseCrud:
def __init__(self, session: AsyncSession):
self.session = session

platform/reworkd_platform/db/crud/oauth.py Normal file

@@ -0,0 +1,77 @@
import secrets
from typing import Dict, Optional
from fastapi import Depends
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from reworkd_platform.db.crud.base import BaseCrud
from reworkd_platform.db.dependencies import get_db_session
from reworkd_platform.db.models.auth import OauthCredentials
from reworkd_platform.schemas import UserBase
class OAuthCrud(BaseCrud):
@classmethod
async def inject(
cls,
session: AsyncSession = Depends(get_db_session),
) -> "OAuthCrud":
return cls(session)
async def create_installation(
self, user: UserBase, provider: str, redirect_uri: Optional[str]
) -> OauthCredentials:
return await OauthCredentials(
user_id=user.id,
organization_id=user.organization_id,
provider=provider,
state=secrets.token_hex(16),
redirect_uri=redirect_uri,
).save(self.session)
async def get_installation_by_state(self, state: str) -> Optional[OauthCredentials]:
query = select(OauthCredentials).filter(OauthCredentials.state == state)
return (await self.session.execute(query)).scalar_one_or_none()
async def get_installation_by_user_id(
self, user_id: str, provider: str
) -> Optional[OauthCredentials]:
query = select(OauthCredentials).filter(
OauthCredentials.user_id == user_id,
OauthCredentials.provider == provider,
OauthCredentials.access_token_enc.isnot(None),
)
return (await self.session.execute(query)).scalars().first()
async def get_installation_by_organization_id(
self, organization_id: str, provider: str
) -> Optional[OauthCredentials]:
query = select(OauthCredentials).filter(
OauthCredentials.organization_id == organization_id,
OauthCredentials.provider == provider,
OauthCredentials.access_token_enc.isnot(None),
OauthCredentials.organization_id.isnot(None),
)
return (await self.session.execute(query)).scalars().first()
async def get_all(self, user: UserBase) -> Dict[str, str]:
query = (
select(
OauthCredentials.provider,
func.any_value(OauthCredentials.access_token_enc),
)
.filter(
OauthCredentials.access_token_enc.isnot(None),
OauthCredentials.organization_id == user.organization_id,
)
.group_by(OauthCredentials.provider)
)
return {
provider: token
for provider, token in (await self.session.execute(query)).all()
}

platform/reworkd_platform/db/crud/organization.py Normal file

@@ -0,0 +1,98 @@
from typing import List, Optional
from fastapi import Depends
from pydantic import BaseModel
from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import aliased
from reworkd_platform.db.crud.base import BaseCrud
from reworkd_platform.db.dependencies import get_db_session
from reworkd_platform.db.models.auth import Organization, OrganizationUser
from reworkd_platform.db.models.user import User
from reworkd_platform.schemas import UserBase
from reworkd_platform.web.api.dependencies import get_current_user
class OrgUser(BaseModel):
id: str
role: str
user: UserBase
class OrganizationUsers(BaseModel):
id: str
name: str
users: List[OrgUser]
class OrganizationCrud(BaseCrud):
def __init__(self, session: AsyncSession, user: UserBase):
super().__init__(session)
self.user = user
@classmethod
def inject(
cls,
session: AsyncSession = Depends(get_db_session),
user: UserBase = Depends(get_current_user),
) -> "OrganizationCrud":
return cls(session, user)
async def create_organization(self, name: str) -> Organization:
return await Organization(
created_by=self.user.id,
name=name,
).save(self.session)
async def get_by_name(self, name: str) -> Optional[OrganizationUsers]:
owner = aliased(OrganizationUser, name="owner")
query = (
select(
Organization,
User,
OrganizationUser,
)
.join(
OrganizationUser,
and_(
Organization.id == OrganizationUser.organization_id,
),
)
.join(
User,
User.id == OrganizationUser.user_id,
)
.join( # Owner
owner,
and_(
OrganizationUser.organization_id == Organization.id,
OrganizationUser.user_id == self.user.id,
),
)
.filter(Organization.name == name)
)
rows = (await self.session.execute(query)).all()
if not rows:
return None
org: Organization = rows[0][0]
return OrganizationUsers(
id=org.id,
name=org.name,
users=[
OrgUser(
id=org_user.user_id,
role=org_user.role,
user=UserBase(
id=user.id,
email=user.email,
name=user.name,
),
)
for [_, user, org_user] in rows
],
)

platform/reworkd_platform/db/crud/user.py Normal file

@@ -0,0 +1,31 @@
from typing import Optional
from sqlalchemy import and_, select
from sqlalchemy.orm import selectinload
from reworkd_platform.db.crud.base import BaseCrud
from reworkd_platform.db.models.auth import OrganizationUser
from reworkd_platform.db.models.user import UserSession
class UserCrud(BaseCrud):
async def get_user_session(self, token: str) -> UserSession:
query = (
select(UserSession)
.filter(UserSession.session_token == token)
.options(selectinload(UserSession.user))
)
return (await self.session.execute(query)).scalar_one()
async def get_user_organization(
self, user_id: str, organization_id: str
) -> Optional[OrganizationUser]:
query = select(OrganizationUser).filter(
and_(
OrganizationUser.user_id == user_id,
OrganizationUser.organization_id == organization_id,
)
)
# TODO: Only returns the first organization
return (await self.session.execute(query)).scalar()

platform/reworkd_platform/db/dependencies.py Normal file

@@ -0,0 +1,20 @@
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.requests import Request
async def get_db_session(request: Request) -> AsyncGenerator[AsyncSession, None]:
"""
Create and get database session.
:param request: current request.
:yield: database session.
"""
session: AsyncSession = request.app.state.db_session_factory()
try: # noqa: WPS501
yield session
await session.commit()
finally:
await session.close()

platform/reworkd_platform/db/meta.py Normal file

@@ -0,0 +1,3 @@
import sqlalchemy as sa
meta = sa.MetaData()

platform/reworkd_platform/db/models/__init__.py Normal file

@@ -0,0 +1,14 @@
"""reworkd_platform models."""
import pkgutil
from pathlib import Path
def load_all_models() -> None:
"""Load all models from this folder."""
package_dir = Path(__file__).resolve().parent
modules = pkgutil.walk_packages(
path=[str(package_dir)],
prefix="reworkd_platform.db.models.",
)
for module in modules:
__import__(module.name) # noqa: WPS421

platform/reworkd_platform/db/models/agent.py Normal file

@@ -0,0 +1,24 @@
from sqlalchemy import DateTime, String, Text, func
from sqlalchemy.orm import mapped_column
from reworkd_platform.db.base import Base
class AgentRun(Base):
__tablename__ = "agent_run"
user_id = mapped_column(String, nullable=False)
goal = mapped_column(Text, nullable=False)
create_date = mapped_column(
DateTime, name="create_date", server_default=func.now(), nullable=False
)
class AgentTask(Base):
__tablename__ = "agent_task"
run_id = mapped_column(String, nullable=False)
type_ = mapped_column(String, nullable=False, name="type")
create_date = mapped_column(
DateTime, name="create_date", server_default=func.now(), nullable=False
)

platform/reworkd_platform/db/models/auth.py Normal file

@@ -0,0 +1,36 @@
from sqlalchemy import DateTime, String
from sqlalchemy.orm import mapped_column
from reworkd_platform.db.base import TrackedModel
class Organization(TrackedModel):
__tablename__ = "organization"
name = mapped_column(String, nullable=False)
created_by = mapped_column(String, nullable=False)
class OrganizationUser(TrackedModel):
__tablename__ = "organization_user"
user_id = mapped_column(String, nullable=False)
organization_id = mapped_column(String, nullable=False)
role = mapped_column(String, nullable=False, default="member")
class OauthCredentials(TrackedModel):
__tablename__ = "oauth_credentials"
user_id = mapped_column(String, nullable=False)
organization_id = mapped_column(String, nullable=True)
provider = mapped_column(String, nullable=False)
state = mapped_column(String, nullable=False)
redirect_uri = mapped_column(String, nullable=False)
# Post-installation
token_type = mapped_column(String, nullable=True)
access_token_enc = mapped_column(String, nullable=True)
access_token_expiration = mapped_column(DateTime, nullable=True)
refresh_token_enc = mapped_column(String, nullable=True)
scope = mapped_column(String, nullable=True)

platform/reworkd_platform/db/models/user.py Normal file

@@ -0,0 +1,38 @@
from typing import List
from sqlalchemy import DateTime, ForeignKey, Index, String, text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from reworkd_platform.db.base import Base
class UserSession(Base):
__tablename__ = "Session"
session_token = mapped_column(String, unique=True, name="sessionToken")
user_id = mapped_column(
String, ForeignKey("User.id", ondelete="CASCADE"), name="userId"
)
expires = mapped_column(DateTime)
user = relationship("User")
__table_args__ = (Index("user_id"),)
class User(Base):
__tablename__ = "User"
name = mapped_column(String, nullable=True)
email = mapped_column(String, nullable=True, unique=True)
email_verified = mapped_column(DateTime, nullable=True, name="emailVerified")
image = mapped_column(String, nullable=True)
create_date = mapped_column(
DateTime, server_default=text("(now())"), name="createDate"
)
sessions: Mapped[List["UserSession"]] = relationship(
"UserSession", back_populates="user"
)
__table_args__ = (Index("email"),)

platform/reworkd_platform/db/utils.py Normal file

@@ -0,0 +1,57 @@
from ssl import CERT_REQUIRED
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from reworkd_platform.services.ssl import get_ssl_context
from reworkd_platform.settings import settings
def create_engine() -> AsyncEngine:
"""
Creates SQLAlchemy engine instance.
:return: SQLAlchemy engine instance.
"""
if settings.environment == "development":
return create_async_engine(
str(settings.db_url),
echo=settings.db_echo,
)
ssl_context = get_ssl_context(settings)
ssl_context.verify_mode = CERT_REQUIRED
connect_args = {"ssl": ssl_context}
return create_async_engine(
str(settings.db_url),
echo=settings.db_echo,
connect_args=connect_args,
)
async def create_database() -> None:
"""Create a database."""
engine = create_async_engine(str(settings.db_url.with_path("/mysql")))
async with engine.connect() as conn:
database_existence = await conn.execute(
text(
"SELECT 1 FROM INFORMATION_SCHEMA.SCHEMATA" # noqa: S608
f" WHERE SCHEMA_NAME='{settings.db_base}';",
)
)
database_exists = database_existence.scalar() == 1
if database_exists:
await drop_database()
async with engine.connect() as conn: # noqa: WPS440
await conn.execute(text(f"CREATE DATABASE {settings.db_base};"))
async def drop_database() -> None:
"""Drop current database."""
engine = create_async_engine(str(settings.db_url.with_path("/mysql")))
async with engine.connect() as conn:
await conn.execute(text(f"DROP DATABASE {settings.db_base};"))

platform/reworkd_platform/logging.py Normal file

@@ -0,0 +1,63 @@
import logging
import sys
from typing import Union
from loguru import logger
from reworkd_platform.settings import settings
class InterceptHandler(logging.Handler):
"""
Default handler from examples in loguru documentation.
This handler intercepts all log requests and
passes them to loguru.
For more info see:
https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging
"""
def emit(self, record: logging.LogRecord) -> None: # pragma: no cover
"""
Propagates logs to loguru.
:param record: record to log.
"""
try:
level: Union[str, int] = logger.level(record.levelname).name
except ValueError:
level = record.levelno
# Find caller from where originated the logged message
frame, depth = logging.currentframe(), 2
while frame.f_code.co_filename == logging.__file__:
frame = frame.f_back # type: ignore
depth += 1
logger.opt(depth=depth, exception=record.exc_info).log(
level,
record.getMessage(),
)
def configure_logging() -> None: # pragma: no cover
"""Configures logging."""
intercept_handler = InterceptHandler()
logging.basicConfig(handlers=[intercept_handler], level=logging.NOTSET)
for logger_name in logging.root.manager.loggerDict:
if logger_name.startswith("uvicorn."):
logging.getLogger(logger_name).handlers = []
# change handler for default uvicorn logger
logging.getLogger("uvicorn").handlers = [intercept_handler]
logging.getLogger("uvicorn.access").handlers = [intercept_handler]
# set logs output, level and format
logger.remove()
logger.add(
sys.stdout,
level=settings.log_level,
)

platform/reworkd_platform/schemas/__init__.py Normal file

@@ -0,0 +1,2 @@
from .agent import ModelSettings
from .user import UserBase

platform/reworkd_platform/schemas/agent.py Normal file

@@ -0,0 +1,87 @@
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field, validator
from reworkd_platform.web.api.agent.analysis import Analysis
LLM_Model = Literal[
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-4o",
]
Loop_Step = Literal[
"start",
"analyze",
"execute",
"create",
"summarize",
"chat",
]
LLM_MODEL_MAX_TOKENS: Dict[LLM_Model, int] = {
"gpt-3.5-turbo": 4000,
"gpt-3.5-turbo-16k": 16000,
"gpt-4o": 8000,
}
class ModelSettings(BaseModel):
model: LLM_Model = Field(default="gpt-4o")
custom_api_key: Optional[str] = Field(default=None)
temperature: float = Field(default=0.9, ge=0.0, le=1.0)
max_tokens: int = Field(default=500, ge=0)
language: str = Field(default="English")
@validator("max_tokens")
def validate_max_tokens(cls, v: float, values: Dict[str, Any]) -> float:
model = values["model"]
if v > (max_tokens := LLM_MODEL_MAX_TOKENS[model]):
raise ValueError(f"Model {model} only supports {max_tokens} tokens")
return v
class AgentRunCreate(BaseModel):
goal: str
model_settings: ModelSettings = Field(default=ModelSettings())
class AgentRun(AgentRunCreate):
run_id: str
class AgentTaskAnalyze(AgentRun):
task: str
tool_names: List[str] = Field(default=[])
model_settings: ModelSettings = Field(default=ModelSettings())
class AgentTaskExecute(AgentRun):
task: str
analysis: Analysis
class AgentTaskCreate(AgentRun):
tasks: List[str] = Field(default=[])
last_task: Optional[str] = Field(default=None)
result: Optional[str] = Field(default=None)
completed_tasks: List[str] = Field(default=[])
class AgentSummarize(AgentRun):
results: List[str] = Field(default=[])
class AgentChat(AgentRun):
message: str
results: List[str] = Field(default=[])
class NewTasksResponse(BaseModel):
run_id: str
new_tasks: List[str] = Field(alias="newTasks")
class RunCount(BaseModel):
count: int
first_run: Optional[datetime]
last_run: Optional[datetime]

View File

@ -0,0 +1,21 @@
from typing import Optional
from pydantic import BaseModel, Field
class OrganizationRole(BaseModel):
id: str
role: str
organization_id: str
class UserBase(BaseModel):
id: str
name: Optional[str]
email: Optional[str]
image: Optional[str] = Field(default=None)
organization: Optional[OrganizationRole] = Field(default=None)
@property
def organization_id(self) -> Optional[str]:
return self.organization.organization_id if self.organization else None

View File

@ -0,0 +1 @@
"""Services for reworkd_platform."""

View File

@ -0,0 +1,42 @@
from typing import Any, Optional
from anthropic import AsyncAnthropic
from pydantic import BaseModel
class AbstractPrompt(BaseModel):
def to_string(self) -> str:
raise NotImplementedError
class HumanAssistantPrompt(AbstractPrompt):
assistant_prompt: str
human_prompt: str
def to_string(self) -> str:
return (
f"""\n\nHuman: {self.human_prompt}\n\nAssistant: {self.assistant_prompt}"""
)
class ClaudeService:
def __init__(self, api_key: Optional[str], model: str = "claude-2"):
self.claude = AsyncAnthropic(api_key=api_key)
self.model = model
async def completion(
self,
prompt: AbstractPrompt,
max_tokens_to_sample: int,
temperature: float = 0,
**kwargs: Any,
) -> str:
return (
await self.claude.completions.create(
model=self.model,
prompt=prompt.to_string(),
max_tokens_to_sample=max_tokens_to_sample,
temperature=temperature,
**kwargs,
)
).completion.strip()
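
A usage sketch for the service above. The import path and API key are assumptions; `claude-2` is served through Anthropic's legacy completions endpoint, which is what `AsyncAnthropic.completions.create` calls.

```python
import asyncio

from reworkd_platform.services.anthropic import (  # import path assumed
    ClaudeService,
    HumanAssistantPrompt,
)

async def main() -> None:
    service = ClaudeService(api_key="sk-ant-...")  # placeholder key
    prompt = HumanAssistantPrompt(
        human_prompt="List three bagel toppings.",
        assistant_prompt="Here are three toppings:",
    )
    # completion() returns the generated text with surrounding whitespace stripped
    print(await service.completion(prompt, max_tokens_to_sample=128))

asyncio.run(main())
```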

View File

@ -0,0 +1,7 @@
import boto3
from botocore.exceptions import ProfileNotFound
try:
boto3.setup_default_session(profile_name="dev")
except ProfileNotFound:
pass

View File

@ -0,0 +1,85 @@
import io
import os
from typing import Dict, List, Optional
from botocore.exceptions import ClientError
from boto3 import client as boto3_client
from loguru import logger
from pydantic import BaseModel
REGION = "us-east-1"
# noinspection SpellCheckingInspection
class PresignedPost(BaseModel):
url: str
fields: Dict[str, str]
class SimpleStorageService:
# TODO: would be great if we could make this async
def __init__(self, bucket: Optional[str]) -> None:
if not bucket:
raise ValueError("Bucket name must be provided")
self._client = boto3_client("s3", region_name=REGION)
self._bucket = bucket
def create_presigned_upload_url(
self,
object_name: str,
) -> PresignedPost:
return PresignedPost(
**self._client.generate_presigned_post(
Bucket=self._bucket,
Key=object_name,
)
)
def create_presigned_download_url(self, object_name: str) -> str:
return self._client.generate_presigned_url(
"get_object",
Params={"Bucket": self._bucket, "Key": object_name},
)
def upload_to_bucket(
self,
object_name: str,
file: io.BytesIO,
) -> None:
try:
self._client.put_object(
Bucket=self._bucket, Key=object_name, Body=file.getvalue()
)
except ClientError as e:
logger.error(e)
raise e
def download_file(self, object_name: str, local_filename: str) -> None:
self._client.download_file(
Bucket=self._bucket, Key=object_name, Filename=local_filename
)
def list_keys(self, prefix: str) -> List[str]:
files = self._client.list_objects_v2(Bucket=self._bucket, Prefix=prefix)
if "Contents" not in files:
return []
return [file["Key"] for file in files["Contents"]]
def download_folder(self, prefix: str, path: str) -> List[str]:
local_files = []
for key in self.list_keys(prefix):
local_filename = os.path.join(path, key.split("/")[-1])
self.download_file(key, local_filename)
local_files.append(local_filename)
return local_files
def delete_folder(self, prefix: str) -> None:
keys = self.list_keys(prefix)
self._client.delete_objects(
Bucket=self._bucket,
Delete={"Objects": [{"Key": key} for key in keys]},
)
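
A sketch of the two upload paths exposed above. The bucket name is hypothetical; credentials come from boto3's default chain (seeded by the `dev` profile in the previous file when available). The import path matches the one patched in the tests later in this commit.

```python
import io

from reworkd_platform.services.aws.s3 import SimpleStorageService

storage = SimpleStorageService(bucket="my-bucket")  # hypothetical bucket

# Server-side upload of an in-memory file
storage.upload_to_bucket("runs/run-1.json", io.BytesIO(b"{}"))

# Presigned POST so a browser can upload directly to S3
post = storage.create_presigned_upload_url("runs/run-2.json")
print(post.url, post.fields)
```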

View File

@ -0,0 +1,147 @@
import json
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from urllib.parse import urlencode
import aiohttp
from fastapi import Depends, Path
from reworkd_platform.db.crud.oauth import OAuthCrud
from reworkd_platform.db.models.auth import OauthCredentials
from reworkd_platform.schemas import UserBase
from reworkd_platform.services.security import encryption_service
from reworkd_platform.settings import Settings
from reworkd_platform.settings import settings as platform_settings
from reworkd_platform.web.api.http_responses import forbidden
class OAuthInstaller(ABC):
def __init__(self, crud: OAuthCrud, settings: Settings):
self.crud = crud
self.settings = settings
@abstractmethod
async def install(self, user: UserBase, redirect_uri: str) -> str:
raise NotImplementedError()
@abstractmethod
async def install_callback(self, code: str, state: str) -> OauthCredentials:
raise NotImplementedError()
@abstractmethod
async def uninstall(self, user: UserBase) -> bool:
raise NotImplementedError()
@staticmethod
def store_access_token(creds: OauthCredentials, access_token: str) -> None:
creds.access_token_enc = encryption_service.encrypt(access_token)
@staticmethod
def store_refresh_token(creds: OauthCredentials, refresh_token: str) -> None:
creds.refresh_token_enc = encryption_service.encrypt(refresh_token)
class SIDInstaller(OAuthInstaller):
PROVIDER = "sid"
async def install(self, user: UserBase, redirect_uri: str) -> str:
# gracefully handle the case where the installation already exists
# this can happen if the user starts the process from multiple tabs
installation = await self.crud.get_installation_by_user_id(
user.id, self.PROVIDER
)
if not installation:
installation = await self.crud.create_installation(
user,
self.PROVIDER,
redirect_uri,
)
scopes = ["data:query", "offline_access"]
params = {
"client_id": self.settings.sid_client_id,
"redirect_uri": self.settings.sid_redirect_uri,
"response_type": "code",
"scope": " ".join(scopes),
"state": installation.state,
"audience": "https://api.sid.ai/api/v1/",
}
auth_url = "https://me.sid.ai/api/oauth/authorize"
auth_url += "?" + urlencode(params)
return auth_url
async def install_callback(self, code: str, state: str) -> OauthCredentials:
creds = await self.crud.get_installation_by_state(state)
if not creds:
raise forbidden()
req = {
"grant_type": "authorization_code",
"client_id": self.settings.sid_client_id,
"client_secret": self.settings.sid_client_secret,
"redirect_uri": self.settings.sid_redirect_uri,
"code": code,
}
async with aiohttp.ClientSession() as session:
async with session.post(
"https://auth.sid.ai/oauth/token",
headers={
"Content-Type": "application/json",
"Accept": "application/json",
},
data=json.dumps(req),
) as response:
res_data = await response.json()
OAuthInstaller.store_access_token(creds, res_data["access_token"])
OAuthInstaller.store_refresh_token(creds, res_data["refresh_token"])
creds.access_token_expiration = datetime.now() + timedelta(
seconds=res_data["expires_in"]
)
return await creds.save(self.crud.session)
async def uninstall(self, user: UserBase) -> bool:
creds = await self.crud.get_installation_by_user_id(user.id, self.PROVIDER)
# check if credentials exist and contain a refresh token
if not creds:
return False
# use refresh token to revoke access
delete_token = encryption_service.decrypt(creds.refresh_token_enc)
# delete credentials from database
await self.crud.session.delete(creds)
# revoke refresh token
async with aiohttp.ClientSession() as session:
await session.post(
"https://auth.sid.ai/oauth/revoke",
headers={
"Content-Type": "application/json",
},
data=json.dumps(
{
"client_id": self.settings.sid_client_id,
"client_secret": self.settings.sid_client_secret,
"token": delete_token,
}
),
)
return True
integrations = {
SIDInstaller.PROVIDER: SIDInstaller,
}
def installer_factory(
provider: str = Path(description="OAuth Provider"),
crud: OAuthCrud = Depends(OAuthCrud.inject),
) -> OAuthInstaller:
"""Factory for OAuth installers
Args:
provider (str): OAuth Provider (can be slack, github, etc.) (injected)
crud (OAuthCrud): OAuth Crud (injected)
"""
if provider in integrations:
return integrations[provider](crud, platform_settings)
raise NotImplementedError()

View File

@ -0,0 +1,11 @@
import pinecone
from reworkd_platform.settings import settings
def init_pinecone() -> None:
if settings.pinecone_api_key and settings.pinecone_environment:
pinecone.init(
api_key=settings.pinecone_api_key,
environment=settings.pinecone_environment,
)

View File

@ -0,0 +1,98 @@
from __future__ import annotations
import uuid
from typing import Any, Dict, List
from langchain.embeddings import OpenAIEmbeddings
from langchain.embeddings.base import Embeddings
from pinecone import Index  # import doesn't work on plane wifi
from pydantic import BaseModel
from reworkd_platform.settings import settings
from reworkd_platform.timer import timed_function
from reworkd_platform.web.api.memory.memory import AgentMemory
OPENAI_EMBEDDING_DIM = 1536
class Row(BaseModel):
id: str
values: List[float]
metadata: Dict[str, Any] = {}
class QueryResult(BaseModel):
id: str
score: float
metadata: Dict[str, Any] = {}
class PineconeMemory(AgentMemory):
"""
Wrapper around pinecone
"""
def __init__(self, index_name: str, namespace: str = ""):
self.index = Index(settings.pinecone_index_name)
self.namespace = namespace or index_name
@timed_function(level="DEBUG")
def __enter__(self) -> AgentMemory:
self.embeddings: Embeddings = OpenAIEmbeddings(
client=None,  # Meta private value, but mypy will complain it's missing
openai_api_key=settings.openai_api_key,
)
return self
def __exit__(self, *args: Any, **kwargs: Any) -> None:
pass
@timed_function(level="DEBUG")
def reset_class(self) -> None:
self.index.delete(delete_all=True, namespace=self.namespace)
@timed_function(level="DEBUG")
def add_tasks(self, tasks: List[str]) -> List[str]:
if len(tasks) == 0:
return []
embeds = self.embeddings.embed_documents(tasks)
if len(tasks) != len(embeds):
raise ValueError("Embeddings and tasks are not the same length")
rows = [
Row(values=vector, metadata={"text": tasks[i]}, id=str(uuid.uuid4()))
for i, vector in enumerate(embeds)
]
self.index.upsert(
vectors=[row.dict() for row in rows], namespace=self.namespace
)
return [row.id for row in rows]
@timed_function(level="DEBUG")
def get_similar_tasks(
self, text: str, score_threshold: float = 0.95
) -> List[QueryResult]:
# Get similar tasks
vector = self.embeddings.embed_query(text)
results = self.index.query(
vector=vector,
top_k=5,
include_metadata=True,
include_values=True,
namespace=self.namespace,
)
return [
QueryResult(id=row.id, score=row.score, metadata=row.metadata)
for row in getattr(results, "matches", [])
if row.score > score_threshold
]
@staticmethod
def should_use() -> bool:
return False
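
A usage sketch, assuming `init_pinecone()` from the previous file has already run and OpenAI credentials are configured (the index name is hypothetical):

```python
with PineconeMemory(index_name="agent-memory") as memory:
    memory.add_tasks(["Research competitors", "Draft a business plan"])
    for row in memory.get_similar_tasks("Write a business plan", score_threshold=0.9):
        print(row.score, row.metadata["text"])
```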

View File

@ -0,0 +1,23 @@
from typing import Union
from cryptography.fernet import Fernet, InvalidToken
from reworkd_platform.settings import settings
from reworkd_platform.web.api.http_responses import forbidden
class EncryptionService:
def __init__(self, secret: bytes):
self.fernet = Fernet(secret)
def encrypt(self, text: str) -> bytes:
return self.fernet.encrypt(text.encode("utf-8"))
def decrypt(self, encoded_bytes: Union[bytes, str]) -> str:
try:
return self.fernet.decrypt(encoded_bytes).decode("utf-8")
except InvalidToken:
raise forbidden()
encryption_service = EncryptionService(settings.secret_signing_key.encode())
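
A round-trip sketch using the module-level instance (the key comes from settings, so encrypt and decrypt must happen under the same deployment):

```python
token = encryption_service.encrypt("refresh-token-123")
assert encryption_service.decrypt(token) == "refresh-token-123"
# A token minted under a different key raises the 403 built by forbidden()
```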

View File

@ -0,0 +1,25 @@
from ssl import SSLContext, create_default_context
from typing import List, Optional
from reworkd_platform.settings import Settings
MACOS_CERT_PATH = "/etc/ssl/cert.pem"
DOCKER_CERT_PATH = "/etc/ssl/certs/ca-certificates.crt"
def get_ssl_context(
settings: Settings, paths: Optional[List[str]] = None
) -> SSLContext:
if settings.db_ca_path:
return create_default_context(cafile=settings.db_ca_path)
for path in paths or [MACOS_CERT_PATH, DOCKER_CERT_PATH]:
try:
return create_default_context(cafile=path)
except FileNotFoundError:
continue
raise ValueError(
"No CA certificates found for your OS. To fix this, please run change "
"db_ca_path in your settings.py to point to a valid CA certificate file."
)
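
A sketch of resolving the context (when `db_ca_path` is unset, the macOS and Docker bundles above are tried in order):

```python
ctx = get_ssl_context(Settings(), paths=[MACOS_CERT_PATH, DOCKER_CERT_PATH])
```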

View File

@ -0,0 +1 @@
"""Token Service"""

View File

@ -0,0 +1,7 @@
from fastapi import Request
from reworkd_platform.services.tokenizer.token_service import TokenService
def get_token_service(request: Request) -> TokenService:
return TokenService(request.app.state.token_encoding)

View File

@ -0,0 +1,16 @@
import tiktoken
from fastapi import FastAPI
ENCODING_NAME = "cl100k_base" # gpt-4, gpt-3.5-turbo, text-embedding-ada-002
def init_tokenizer(app: FastAPI) -> None: # pragma: no cover
"""
Initialize tokenizer.
TikToken downloads the encoding on start. It is then
stored in the state of the application.
:param app: current application.
"""
app.state.token_encoding = tiktoken.get_encoding(ENCODING_NAME)

View File

@ -0,0 +1,33 @@
from tiktoken import Encoding, get_encoding
from reworkd_platform.schemas.agent import LLM_MODEL_MAX_TOKENS, LLM_Model
from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
class TokenService:
def __init__(self, encoding: Encoding):
self.encoding = encoding
@classmethod
def create(cls, encoding: str = "cl100k_base") -> "TokenService":
return cls(get_encoding(encoding))
def tokenize(self, text: str) -> list[int]:
return self.encoding.encode(text)
def detokenize(self, tokens: list[int]) -> str:
return self.encoding.decode(tokens)
def count(self, text: str) -> int:
return len(self.tokenize(text))
def get_completion_space(self, model: LLM_Model, *prompts: str) -> int:
max_allowed_tokens = LLM_MODEL_MAX_TOKENS.get(model, 4000)
prompt_tokens = sum([self.count(p) for p in prompts])
return max_allowed_tokens - prompt_tokens
def calculate_max_tokens(self, model: WrappedChatOpenAI, *prompts: str) -> None:
requested_tokens = self.get_completion_space(model.model_name, *prompts)
model.max_tokens = min(model.max_tokens, requested_tokens)
model.max_tokens = max(model.max_tokens, 1)
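
A sketch of the budget math above (token counts vary with the prompt):

```python
service = TokenService.create()  # cl100k_base
prompt = "You are an autonomous task agent."
space = service.get_completion_space("gpt-3.5-turbo", prompt)
assert space == 4000 - service.count(prompt)  # 4000 is the gpt-3.5-turbo cap
```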

View File

@ -0,0 +1,178 @@
import platform
from pathlib import Path
from tempfile import gettempdir
from typing import List, Literal, Optional, Union
from pydantic import BaseSettings
from yarl import URL
from reworkd_platform.constants import ENV_PREFIX
TEMP_DIR = Path(gettempdir())
LOG_LEVEL = Literal[
"NOTSET",
"DEBUG",
"INFO",
"WARNING",
"ERROR",
"FATAL",
]
SASL_MECHANISM = Literal[
"PLAIN",
"SCRAM-SHA-256",
]
ENVIRONMENT = Literal[
"development",
"production",
]
class Settings(BaseSettings):
"""
Application settings.
These parameters can be configured
with environment variables.
"""
# Application settings
host: str = "0.0.0.0"
port: int = 8000
workers_count: int = 1
reload: bool = True
environment: ENVIRONMENT = "development"
log_level: LOG_LEVEL = "DEBUG"
# Make sure you update this with your own secret key
# Must be 32 url-safe base64-encoded bytes
secret_signing_key: str = "JF52S66x6WMoifP5gZreiguYs9LYMn0lkXqgPYoNMD0="
# OpenAI
openai_api_base: str = "https://api.openai.com/v1"
openai_api_key: str = "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"
openai_api_version: str = "2023-08-01-preview"
azure_openai_deployment_name: str = "<Should be updated via env if using azure>"
# Helicone
helicone_api_base: str = "https://oai.hconeai.com/v1"
helicone_api_key: Optional[str] = None
replicate_api_key: Optional[str] = None
serp_api_key: Optional[str] = None
# Frontend URL for CORS
frontend_url: str = "https://allinix.ai"
allowed_origins_regex: Optional[str] = None
# Variables for the database
db_host: str = "65.19.178.35"
db_port: int = 3307
db_user: str = "reworkd_platform"
db_pass: str = "reworkd_platform"
db_base: str = "default"
db_echo: bool = False
db_ca_path: Optional[str] = None
# Variables for Pinecone DB
pinecone_api_key: Optional[str] = None
pinecone_index_name: Optional[str] = None
pinecone_environment: Optional[str] = None
# Sentry's configuration.
sentry_dsn: Optional[str] = None
sentry_sample_rate: float = 1.0
kafka_bootstrap_servers: Union[str, List[str]] = []
kafka_username: Optional[str] = None
kafka_password: Optional[str] = None
kafka_sasl_mechanism: SASL_MECHANISM = "PLAIN"
# Websocket settings
pusher_app_id: Optional[str] = None
pusher_key: Optional[str] = None
pusher_secret: Optional[str] = None
pusher_cluster: Optional[str] = None
# Application Settings
ff_mock_mode_enabled: bool = False # Controls whether calls are mocked
max_loops: int = 25 # Maximum number of loops to run
# Settings for sid
sid_client_id: Optional[str] = None
sid_client_secret: Optional[str] = None
sid_redirect_uri: Optional[str] = None
@property
def kafka_consumer_group(self) -> str:
"""
Kafka consumer group will be the name of the host in development
mode, making it easier to share a dev cluster.
"""
if self.environment == "development":
return platform.node()
return "platform"
@property
def db_url(self) -> URL:
return URL.build(
scheme="mysql+aiomysql",
host=self.db_host,
port=self.db_port,
user=self.db_user,
password=self.db_pass,
path=f"/{self.db_base}",
)
@property
def pusher_enabled(self) -> bool:
return all(
[
self.pusher_app_id,
self.pusher_key,
self.pusher_secret,
self.pusher_cluster,
]
)
@property
def kafka_enabled(self) -> bool:
return all(
[
self.kafka_bootstrap_servers,
self.kafka_username,
self.kafka_password,
]
)
@property
def helicone_enabled(self) -> bool:
return all(
[
self.helicone_api_base,
self.helicone_api_key,
]
)
@property
def sid_enabled(self) -> bool:
return all(
[
self.sid_client_id,
self.sid_client_secret,
self.sid_redirect_uri,
]
)
class Config:
env_file = ".env"
env_prefix = ENV_PREFIX
env_file_encoding = "utf-8"
settings = Settings()
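
Since this is a pydantic `BaseSettings`, any field can be overridden through a prefixed environment variable. A sketch, assuming `ENV_PREFIX` is `"REWORKD_PLATFORM_"` (the constant itself is defined outside this file):

```python
import os

os.environ["REWORKD_PLATFORM_MAX_LOOPS"] = "10"  # prefix value assumed
assert Settings().max_loops == 10
```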

View File

@ -0,0 +1 @@
"""Tests for reworkd_platform."""

View File

@ -0,0 +1,31 @@
import pytest
from pydantic import ValidationError
from reworkd_platform.web.api.agent.analysis import Analysis
from reworkd_platform.web.api.agent.tools.tools import get_default_tool, get_tool_name
def test_analysis_model() -> None:
valid_tool_name = get_tool_name(get_default_tool())
analysis = Analysis(action=valid_tool_name, arg="arg", reasoning="reasoning")
assert analysis.action == valid_tool_name
assert analysis.arg == "arg"
assert analysis.reasoning == "reasoning"
def test_analysis_model_search_empty_arg() -> None:
with pytest.raises(ValidationError):
Analysis(action="search", arg="", reasoning="reasoning")
def test_analysis_model_search_non_empty_arg() -> None:
analysis = Analysis(action="search", arg="non-empty arg", reasoning="reasoning")
assert analysis.action == "search"
assert analysis.arg == "non-empty arg"
assert analysis.reasoning == "reasoning"
def test_analysis_model_invalid_tool() -> None:
with pytest.raises(ValidationError):
Analysis(action="invalid tool name", arg="test argument", reasoning="reasoning")

View File

@ -0,0 +1,63 @@
from unittest.mock import AsyncMock
import pytest
from fastapi import HTTPException
from pytest_mock import MockerFixture
from reworkd_platform.db.crud.agent import AgentCRUD
from reworkd_platform.settings import settings
from reworkd_platform.web.api.errors import MaxLoopsError, MultipleSummaryError
@pytest.mark.asyncio
async def test_validate_task_count_no_error(mocker: MockerFixture) -> None:
mock_agent_run_exists(mocker, True)
session = mock_session_with_run_count(mocker, 0)
agent_crud: AgentCRUD = AgentCRUD(session, mocker.MagicMock())
# Doesn't throw an exception
await agent_crud.validate_task_count("test", "summarize")
@pytest.mark.asyncio
async def test_validate_task_count_when_run_not_found(mocker: MockerFixture) -> None:
mock_agent_run_exists(mocker, False)
agent_crud: AgentCRUD = AgentCRUD(mocker.AsyncMock(), mocker.MagicMock())
with pytest.raises(HTTPException):
await agent_crud.validate_task_count("test", "test")
@pytest.mark.asyncio
async def test_validate_task_count_max_loops_error(mocker: MockerFixture) -> None:
mock_agent_run_exists(mocker, True)
session = mock_session_with_run_count(mocker, settings.max_loops)
agent_crud: AgentCRUD = AgentCRUD(session, mocker.AsyncMock())
with pytest.raises(MaxLoopsError):
await agent_crud.validate_task_count("test", "test")
@pytest.mark.asyncio
async def test_validate_task_count_multiple_summary_error(
mocker: MockerFixture,
) -> None:
mock_agent_run_exists(mocker, True)
session = mock_session_with_run_count(mocker, 2)
agent_crud: AgentCRUD = AgentCRUD(session, mocker.MagicMock())
with pytest.raises(MultipleSummaryError):
await agent_crud.validate_task_count("test", "summarize")
def mock_agent_run_exists(mocker: MockerFixture, exists: bool) -> None:
mocker.patch("reworkd_platform.db.models.agent.AgentRun.get", return_value=exists)
def mock_session_with_run_count(mocker: MockerFixture, run_count: int) -> AsyncMock:
session = mocker.AsyncMock()
scalar_mock = mocker.MagicMock()
session.execute.return_value = scalar_mock
scalar_mock.scalar_one.return_value = run_count
return session

View File

@ -0,0 +1,138 @@
import itertools
import pytest
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from reworkd_platform.schemas import ModelSettings, UserBase
from reworkd_platform.settings import Settings
from reworkd_platform.web.api.agent.model_factory import (
WrappedAzureChatOpenAI,
WrappedChatOpenAI,
create_model,
get_base_and_headers,
)
def test_helicone_enabled_without_custom_api_key():
model_settings = ModelSettings()
user = UserBase(id="user_id")
settings = Settings(
helicone_api_key="some_key",
helicone_api_base="helicone_base",
openai_api_base="openai_base",
)
base, headers, use_helicone = get_base_and_headers(settings, model_settings, user)
assert use_helicone is True
assert base == "helicone_base"
assert headers == {
"Helicone-Auth": "Bearer some_key",
"Helicone-Cache-Enabled": "true",
"Helicone-User-Id": "user_id",
"Helicone-OpenAI-Api-Base": "openai_base",
}
def test_helicone_disabled():
model_settings = ModelSettings()
user = UserBase(id="user_id")
settings = Settings()
base, headers, use_helicone = get_base_and_headers(settings, model_settings, user)
assert base == "https://api.openai.com/v1"
assert headers is None
assert use_helicone is False
def test_helicone_enabled_with_custom_api_key():
model_settings = ModelSettings(
custom_api_key="custom_key",
)
user = UserBase(id="user_id")
settings = Settings(
openai_api_base="openai_base",
helicone_api_key="some_key",
helicone_api_base="helicone_base",
)
base, headers, use_helicone = get_base_and_headers(settings, model_settings, user)
assert base == "https://api.openai.com/v1"
assert headers is None
assert use_helicone is False
@pytest.mark.parametrize(
"streaming, use_azure",
list(
itertools.product(
[True, False],
[True, False],
)
),
)
def test_create_model(streaming, use_azure):
user = UserBase(id="user_id")
settings = Settings()
model_settings = ModelSettings(
temperature=0.7,
model="gpt-4o",
max_tokens=100,
)
settings.openai_api_base = (
"https://api.openai.com" if not use_azure else "https://oai.azure.com"
)
settings.openai_api_key = "key"
settings.openai_api_version = "version"
result = create_model(settings, model_settings, user, streaming)
assert issubclass(result.__class__, WrappedChatOpenAI)
assert issubclass(result.__class__, ChatOpenAI)
# Check if the required keys are set properly
assert result.openai_api_base == settings.openai_api_base
assert result.openai_api_key == settings.openai_api_key
assert result.temperature == model_settings.temperature
assert result.max_tokens == model_settings.max_tokens
assert result.streaming == streaming
assert result.max_retries == 5
# For Azure specific checks
if use_azure:
assert isinstance(result, WrappedAzureChatOpenAI)
assert issubclass(result.__class__, AzureChatOpenAI)
assert result.openai_api_version == settings.openai_api_version
assert result.deployment_name == "gpt-35-turbo"
assert result.openai_api_type == "azure"
@pytest.mark.parametrize(
"model_settings, streaming",
list(
itertools.product(
[
ModelSettings(
customTemperature=0.222,
customModelName="gpt-4",
maxTokens=1234,
),
ModelSettings(),
],
[True, False],
)
),
)
def test_custom_model_settings(model_settings: ModelSettings, streaming: bool):
model = create_model(
Settings(),
model_settings,
UserBase(id="", email="test@example.com"),
streaming=streaming,
)
assert model.temperature == model_settings.temperature
assert model.model_name.startswith(model_settings.model)
assert model.max_tokens == model_settings.max_tokens
assert model.streaming == streaming

View File

@ -0,0 +1,203 @@
from typing import List, Type
import pytest
from langchain.schema import OutputParserException
from reworkd_platform.web.api.agent.task_output_parser import (
TaskOutputParser,
extract_array,
real_tasks_filter,
remove_prefix,
)
@pytest.mark.parametrize(
"input_text,expected_output",
[
(
'["Task 1: Do something", "Task 2: Do something else", "Task 3: Do '
'another thing"]',
["Do something", "Do something else", "Do another thing"],
),
(
'Some random stuff ["1: Hello"]',
["Hello"],
),
(
"[]",
[],
),
],
)
def test_parse_success(input_text: str, expected_output: List[str]) -> None:
parser = TaskOutputParser(completed_tasks=[])
result = parser.parse(input_text)
assert result == expected_output
def test_parse_with_completed_tasks() -> None:
input_text = '["One", "Two", "Three"]'
completed = ["One"]
expected = ["Two", "Three"]
parser = TaskOutputParser(completed_tasks=completed)
result = parser.parse(input_text)
assert result == expected
@pytest.mark.parametrize(
"input_text, exception",
[
# Test cases for non-array and non-multiline string inputs
("This is not an array", OutputParserException),
("123456", OutputParserException),
("Some random text", OutputParserException),
("[abc]", OutputParserException),
# Test cases for malformed arrays
("[1, 2, 3", OutputParserException),
("'item1', 'item2']", OutputParserException),
("['item1', 'item2", OutputParserException),
# Test case for invalid multiline strings
("This is not\na valid\nmultiline string.", OutputParserException),
# Test case for multiline strings that don't start with digit + period
("Some text\nMore text\nAnd more text.", OutputParserException),
],
)
def test_parse_failure(input_text: str, exception: Type[Exception]) -> None:
parser = TaskOutputParser(completed_tasks=[])
with pytest.raises(exception):
parser.parse(input_text)
@pytest.mark.parametrize(
"input_str, expected",
[
# Test cases for empty array
("[]", []),
# Test cases for arrays with one element
('["One"]', ["One"]),
("['Single quote']", ["Single quote"]),
# Test cases for arrays with multiple elements
('["Research", "Develop", "Integrate"]', ["Research", "Develop", "Integrate"]),
('["Search", "Identify"]', ["Search", "Identify"]),
('["Item 1","Item 2","Item 3"]', ["Item 1", "Item 2", "Item 3"]),
# Test cases for arrays with special characters in elements
("['Single with \"quote\"']", ['Single with "quote"']),
('["Escape \\" within"]', ['Escape " within']),
# Test case for array embedded in other text
("Random stuff ['Search', 'Identify']", ["Search", "Identify"]),
# Test case for array within JSON
('{"array": ["123", "456"]}', ["123", "456"]),
# Multiline string cases
(
"1. Identify the target\n2. Conduct research\n3. Implement the methods",
[
"1. Identify the target",
"2. Conduct research",
"3. Implement the methods",
],
),
("1. Step one.\n2. Step two.", ["1. Step one.", "2. Step two."]),
(
"""1. Review and understand the code to be debugged
2. Identify and address any errors or issues found during the review process
3. Print out debug information and setup initial variables
4. Start necessary threads and execute program logic.""",
[
"1. Review and understand the code to be debugged",
"2. Identify and address any errors or issues found during the review "
"process",
"3. Print out debug information and setup initial variables",
"4. Start necessary threads and execute program logic.",
],
),
# Test cases with sentences before the digit + period pattern
(
"Any text before 1. Identify the task to be repeated\nUnrelated info 2. "
"Determine the frequency of the repetition\nAnother sentence 3. Create a "
"schedule or system to ensure completion of the task at the designated "
"frequency\nMore text 4. Execute the task according to the established "
"schedule or system",
[
"1. Identify the task to be repeated",
"2. Determine the frequency of the repetition",
"3. Create a schedule or system to ensure completion of the task at "
"the designated frequency",
"4. Execute the task according to the established schedule or system",
],
),
],
)
def test_extract_array_success(input_str: str, expected: List[str]) -> None:
print(extract_array(input_str), expected)
assert extract_array(input_str) == expected
@pytest.mark.parametrize(
"input_str, exception",
[
(None, TypeError),
("123", RuntimeError),
("Some random text", RuntimeError),
('"single_string"', RuntimeError),
('{"test": 123}', RuntimeError),
('["Unclosed array", "other"', RuntimeError),
],
)
def test_extract_array_exception(input_str: str, exception: Type[Exception]) -> None:
with pytest.raises(exception):
extract_array(input_str)
@pytest.mark.parametrize(
"task_input, expected_output",
[
("Task: This is a sample task", "This is a sample task"),
(
"Task 1: Perform a comprehensive analysis of system performance.",
"Perform a comprehensive analysis of system performance.",
),
("Task 2. Create a python script", "Create a python script"),
("5 - This is a sample task", "This is a sample task"),
("2: This is a sample task", "This is a sample task"),
(
"This is a sample task without a prefix",
"This is a sample task without a prefix",
),
("Step: This is a sample task", "This is a sample task"),
(
"Step 1: Perform a comprehensive analysis of system performance.",
"Perform a comprehensive analysis of system performance.",
),
("Step 2:Create a python script", "Create a python script"),
("Step:This is a sample task", "This is a sample task"),
(
". Conduct research on the history of Nike",
"Conduct research on the history of Nike",
),
(".This is a sample task", "This is a sample task"),
(
"1. Research the history and background of Nike company.",
"Research the history and background of Nike company.",
),
],
)
def test_remove_task_prefix(task_input: str, expected_output: str) -> None:
output = remove_prefix(task_input)
assert output == expected_output
@pytest.mark.parametrize(
"input_text, expected_result",
[
("Write the report", True),
("No new task needed", False),
("Task completed", False),
("Do nothing", False),
("", False), # empty_string
("no new task needed", False), # case_insensitive
],
)
def test_real_tasks_filter_no_task(input_text: str, expected_result: bool) -> None:
assert real_tasks_filter(input_text) == expected_result

View File

@ -0,0 +1,56 @@
from typing import List, Type
from reworkd_platform.web.api.agent.tools.conclude import Conclude
from reworkd_platform.web.api.agent.tools.image import Image
from reworkd_platform.web.api.agent.tools.reason import Reason
from reworkd_platform.web.api.agent.tools.search import Search
from reworkd_platform.web.api.agent.tools.sidsearch import SID
from reworkd_platform.web.api.agent.tools.tools import (
Tool,
format_tool_name,
get_default_tool,
get_tool_from_name,
get_tool_name,
get_tools_overview,
)
def test_get_tool_name() -> None:
assert get_tool_name(Image) == "image"
assert get_tool_name(Search) == "search"
assert get_tool_name(Reason) == "reason"
def test_format_tool_name() -> None:
assert format_tool_name("Search") == "search"
assert format_tool_name("reason") == "reason"
assert format_tool_name("Conclude") == "conclude"
assert format_tool_name("CoNcLuDe") == "conclude"
def test_get_tools_overview_no_duplicates() -> None:
"""Test to assert that the tools overview doesn't include duplicates."""
tools: List[Type[Tool]] = [Image, Search, Reason, Conclude, Image, Search]
overview = get_tools_overview(tools)
# Check if each unique tool description is included in the overview
for tool in set(tools):
expected_description = f"'{get_tool_name(tool)}': {tool.description}"
assert expected_description in overview
# Check for duplicates in the overview
overview_list = overview.split("\n")
assert len(overview_list) == len(
set(overview_list)
), "Overview includes duplicate entries"
def test_get_default_tool() -> None:
assert get_default_tool() == Search
def test_get_tool_from_name() -> None:
assert get_tool_from_name("Search") == Search
assert get_tool_from_name("CoNcLuDe") == Conclude
assert get_tool_from_name("NonExistingTool") == Search
assert get_tool_from_name("SID") == SID

View File

@ -0,0 +1,43 @@
import pytest
from reworkd_platform.web.api.memory.memory_with_fallback import MemoryWithFallback
@pytest.mark.parametrize(
"method_name, args",
[
("add_tasks", (["task1", "task2"],)),
("get_similar_tasks", ("task1",)),
("reset_class", ()),
],
)
def test_memory_primary(mocker, method_name: str, args) -> None:
primary = mocker.Mock()
secondary = mocker.Mock()
memory_with_fallback = MemoryWithFallback(primary, secondary)
# Use getattr() to call the method on the object with args
getattr(memory_with_fallback, method_name)(*args)
getattr(primary, method_name).assert_called_once_with(*args)
getattr(secondary, method_name).assert_not_called()
@pytest.mark.parametrize(
"method_name, args",
[
("add_tasks", (["task1", "task2"],)),
("get_similar_tasks", ("task1",)),
("reset_class", ()),
],
)
def test_memory_fallback(mocker, method_name: str, args) -> None:
primary = mocker.Mock()
secondary = mocker.Mock()
memory_with_fallback = MemoryWithFallback(primary, secondary)
getattr(primary, method_name).side_effect = Exception("Primary Failed")
# Call the method again, this time it should fall back to secondary
getattr(memory_with_fallback, method_name)(*args)
getattr(primary, method_name).assert_called_once_with(*args)
getattr(secondary, method_name).assert_called_once_with(*args)

View File

@ -0,0 +1,26 @@
import pytest
from reworkd_platform.web.api.agent import dependancies
@pytest.mark.anyio
@pytest.mark.parametrize(
"validator, step",
[
(dependancies.agent_summarize_validator, "summarize"),
(dependancies.agent_chat_validator, "chat"),
(dependancies.agent_analyze_validator, "analyze"),
(dependancies.agent_create_validator, "create"),
(dependancies.agent_execute_validator, "execute"),
],
)
async def test_agent_validate(mocker, validator, step):
run_id = "asim"
crud = mocker.Mock()
body = mocker.Mock()
body.run_id = run_id
crud.create_task = mocker.AsyncMock()
await validator(body, crud)
crud.create_task.assert_called_once_with(run_id, step)

View File

@ -0,0 +1,45 @@
import pytest
from openai.error import InvalidRequestError, ServiceUnavailableError
from reworkd_platform.schemas.agent import ModelSettings
from reworkd_platform.web.api.agent.helpers import openai_error_handler
from reworkd_platform.web.api.errors import OpenAIError
async def act(*args, settings: ModelSettings = ModelSettings(), **kwargs):
return await openai_error_handler(*args, settings=settings, **kwargs)
@pytest.mark.asyncio
async def test_service_unavailable_error():
async def mock_service_unavailable_error():
raise ServiceUnavailableError("Service Unavailable")
with pytest.raises(OpenAIError):
await act(mock_service_unavailable_error)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"settings,should_log",
[
(ModelSettings(custom_api_key="xyz"), False),
(ModelSettings(custom_api_key=None), True),
],
)
async def test_should_log(settings, should_log):
async def mock_invalid_request_error_model_access():
raise InvalidRequestError(
"The model: xyz does not exist or you do not have access to it.",
param="model",
)
with pytest.raises(Exception) as exc_info:
await openai_error_handler(
mock_invalid_request_error_model_access, settings=settings
)
assert isinstance(exc_info.value, OpenAIError)
error: OpenAIError = exc_info.value
assert error.should_log == should_log

View File

@ -0,0 +1,15 @@
import pytest
from reworkd_platform.services.oauth_installers import installer_factory
def test_installer_factory(mocker):
crud = mocker.Mock()
installer_factory("sid", crud)
def test_integration_dne(mocker):
crud = mocker.Mock()
with pytest.raises(NotImplementedError):
installer_factory("asim", crud)

View File

@ -0,0 +1,18 @@
import pytest
from fastapi import FastAPI
from httpx import AsyncClient
from starlette import status
@pytest.mark.skip(reason="Mysql needs to be mocked")
@pytest.mark.anyio
async def test_health(client: AsyncClient, fastapi_app: FastAPI) -> None:
"""
Checks the health endpoint.
:param client: client for the app.
:param fastapi_app: current FastAPI application.
"""
url = fastapi_app.url_path_for("health_check")
response = await client.get(url)
assert response.status_code == status.HTTP_200_OK

View File

@ -0,0 +1,21 @@
from reworkd_platform.services.aws.s3 import SimpleStorageService
def test_create_signed_post(mocker):
post_url = {
"url": "https://my_bucket.s3.amazonaws.com/my_object",
"fields": {"key": "value"},
}
boto3_mock = mocker.Mock()
boto3_mock.generate_presigned_post.return_value = post_url
mocker.patch(
"reworkd_platform.services.aws.s3.boto3_client", return_value=boto3_mock
)
assert (
SimpleStorageService(bucket="my_bucket").create_presigned_upload_url(
object_name="json"
)
== post_url
)

View File

@ -0,0 +1,61 @@
import pytest
from reworkd_platform.schemas.agent import ModelSettings
@pytest.mark.parametrize(
"settings",
[
{
"model": "gpt-4o",
"max_tokens": 7000,
"temperature": 0.5,
"language": "french",
},
{
"model": "gpt-3.5-turbo",
"max_tokens": 3000,
},
{
"model": "gpt-3.5-turbo-16k",
"max_tokens": 16000,
},
],
)
def test_model_settings_valid(settings):
result = ModelSettings(**settings)
assert result.model == settings.get("model", "gpt-4o")
assert result.max_tokens == settings.get("max_tokens", 500)
assert result.temperature == settings.get("temperature", 0.9)
assert result.language == settings.get("language", "English")
@pytest.mark.parametrize(
"settings",
[
{
"model": "gpt-4-32k",
},
{
"temperature": -1,
},
{
"max_tokens": 8000,
},
{
"model": "gpt-4",
"max_tokens": 32000,
},
],
)
def test_model_settings_invalid(settings):
with pytest.raises(Exception):
ModelSettings(**settings)
def test_model_settings_default():
settings = ModelSettings(**{})
assert settings.model == "gpt-3.5-turbo"
assert settings.temperature == 0.9
assert settings.max_tokens == 500
assert settings.language == "English"

View File

@ -0,0 +1,29 @@
import pytest
from cryptography.fernet import Fernet
from fastapi import HTTPException
from reworkd_platform.services.security import EncryptionService
def test_encrypt_decrypt():
key = Fernet.generate_key()
service = EncryptionService(key)
original_text = "Hello, world!"
encrypted = service.encrypt(original_text)
decrypted = service.decrypt(encrypted)
assert original_text == decrypted
def test_invalid_key():
key = Fernet.generate_key()
different_key = Fernet.generate_key()
different_service = EncryptionService(different_key)
original_text = "Hello, world!"
encrypted = Fernet(key).encrypt(original_text.encode())
with pytest.raises(HTTPException):
different_service.decrypt(encrypted)

View File

@ -0,0 +1,5 @@
from reworkd_platform.settings import Settings
def test_settings_create():
assert Settings() is not None

View File

@ -0,0 +1,105 @@
from unittest.mock import Mock
import tiktoken
from reworkd_platform.schemas.agent import LLM_MODEL_MAX_TOKENS
from reworkd_platform.services.tokenizer.token_service import TokenService
encoding = tiktoken.get_encoding("cl100k_base")
def test_happy_path() -> None:
service = TokenService(encoding)
text = "Hello world!"
validate_tokenize_and_detokenize(service, text, 3)
def test_nothing() -> None:
service = TokenService(encoding)
text = ""
validate_tokenize_and_detokenize(service, text, 0)
def validate_tokenize_and_detokenize(
service: TokenService, text: str, expected_token_count: int
) -> None:
tokens = service.tokenize(text)
assert text == service.detokenize(tokens)
assert len(tokens) == service.count(text)
assert len(tokens) == expected_token_count
def test_calculate_max_tokens_with_small_max_tokens() -> None:
initial_max_tokens = 3000
service = TokenService(encoding)
model = Mock(spec=["model_name", "max_tokens"])
model.model_name = "gpt-3.5-turbo"
model.max_tokens = initial_max_tokens
service.calculate_max_tokens(model, "Hello")
assert model.max_tokens == initial_max_tokens
def test_calculate_max_tokens_with_high_completion_tokens() -> None:
service = TokenService(encoding)
prompt_tokens = service.count(LONG_TEXT)
model = Mock(spec=["model_name", "max_tokens"])
model.model_name = "gpt-3.5-turbo"
model.max_tokens = 8000
service.calculate_max_tokens(model, LONG_TEXT)
assert model.max_tokens == (
LLM_MODEL_MAX_TOKENS.get("gpt-3.5-turbo") - prompt_tokens
)
def test_calculate_max_tokens_with_negative_result() -> None:
service = TokenService(encoding)
model = Mock(spec=["model_name", "max_tokens"])
model.model_name = "gpt-3.5-turbo"
model.max_tokens = 8000
service.calculate_max_tokens(model, *([LONG_TEXT] * 100))
# We use the minimum length of 1
assert model.max_tokens == 1
LONG_TEXT = """
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
This is some long text. This is some long text. This is some long text.
"""

View File

@ -0,0 +1,34 @@
from functools import wraps
from time import time
from typing import Any, Callable, Literal
from loguru import logger
Log_Level = Literal[
"TRACE",
"DEBUG",
"INFO",
"SUCCESS",
"WARNING",
"ERROR",
"CRITICAL",
]
def timed_function(level: Log_Level = "INFO") -> Callable[..., Any]:
def decorator(func: Any) -> Callable[..., Any]:
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
start_time = time()
result = func(*args, **kwargs)
execution_time = time() - start_time
logger.log(
level,
f"Function '{func.__qualname__}' executed in {execution_time:.4f} seconds",
)
return result
return wrapper
return decorator
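
A usage sketch for the decorator:

```python
@timed_function(level="DEBUG")
def slow_square(n: int) -> int:
    return n * n

slow_square(12)  # logs: Function 'slow_square' executed in 0.0000 seconds
```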

View File

@ -0,0 +1 @@
"""WEB API for reworkd_platform."""

View File

@ -0,0 +1 @@
"""reworkd_platform API package."""

View File

@ -0,0 +1,4 @@
"""API for checking running agents"""
from reworkd_platform.web.api.agent.views import router
__all__ = ["router"]

View File

@ -0,0 +1,51 @@
from typing import List, Optional, Protocol
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from reworkd_platform.web.api.agent.analysis import Analysis
class AgentService(Protocol):
async def start_goal_agent(self, *, goal: str) -> List[str]:
pass
async def analyze_task_agent(
self, *, goal: str, task: str, tool_names: List[str]
) -> Analysis:
pass
async def execute_task_agent(
self,
*,
goal: str,
task: str,
analysis: Analysis,
) -> FastAPIStreamingResponse:
pass
async def create_tasks_agent(
self,
*,
goal: str,
tasks: List[str],
last_task: str,
result: str,
completed_tasks: Optional[List[str]] = None,
) -> List[str]:
pass
async def summarize_task_agent(
self,
*,
goal: str,
results: List[str],
) -> FastAPIStreamingResponse:
pass
async def chat(
self,
*,
message: str,
results: List[str],
) -> FastAPIStreamingResponse:
pass

View File

@ -0,0 +1,53 @@
from typing import Any, Callable, Coroutine, Optional
from fastapi import Depends
from reworkd_platform.db.crud.oauth import OAuthCrud
from reworkd_platform.schemas.agent import AgentRun, LLM_Model
from reworkd_platform.schemas.user import UserBase
from reworkd_platform.services.tokenizer.dependencies import get_token_service
from reworkd_platform.services.tokenizer.token_service import TokenService
from reworkd_platform.settings import settings
from reworkd_platform.web.api.agent.agent_service.agent_service import AgentService
from reworkd_platform.web.api.agent.agent_service.mock_agent_service import (
MockAgentService,
)
from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import (
OpenAIAgentService,
)
from reworkd_platform.web.api.agent.model_factory import create_model
from reworkd_platform.web.api.dependencies import get_current_user
def get_agent_service(
validator: Callable[..., Coroutine[Any, Any, AgentRun]],
streaming: bool = False,
llm_model: Optional[LLM_Model] = None,
) -> Callable[..., AgentService]:
def func(
run: AgentRun = Depends(validator),
user: UserBase = Depends(get_current_user),
token_service: TokenService = Depends(get_token_service),
oauth_crud: OAuthCrud = Depends(OAuthCrud.inject),
) -> AgentService:
if settings.ff_mock_mode_enabled:
return MockAgentService()
model = create_model(
settings,
run.model_settings,
user,
streaming=streaming,
force_model=llm_model,
)
return OpenAIAgentService(
model,
run.model_settings,
token_service,
callbacks=None,
user=user,
oauth_crud=oauth_crud,
)
return func

View File

@ -0,0 +1,74 @@
import time
from typing import Any, List
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from reworkd_platform.web.api.agent.agent_service.agent_service import (
AgentService,
Analysis,
)
from reworkd_platform.web.api.agent.stream_mock import stream_string
class MockAgentService(AgentService):
async def start_goal_agent(self, **kwargs: Any) -> List[str]:
time.sleep(1)
return ["Task X", "Task Y", "Task Z"]
async def create_tasks_agent(self, **kwargs: Any) -> List[str]:
time.sleep(1)
return ["Some random task that doesn't exist"]
async def analyze_task_agent(self, **kwargs: Any) -> Analysis:
time.sleep(1.5)
return Analysis(
action="reason",
arg="Mock analysis",
reasoning="Mock to avoid wasting money calling the OpenAI API.",
)
async def execute_task_agent(self, **kwargs: Any) -> FastAPIStreamingResponse:
time.sleep(0.5)
return stream_string(
""" This is going to be a longer task result such that
We make the stream of this string take time and feel long. The reality is... this is a mock!
Lorem Ipsum is simply dummy text of the printing and typesetting industry.
Lorem Ipsum has been the industry's standard dummy text ever since the 1500s,
when an unknown printer took a galley of type and scrambled it to make a type specimen book.
It has survived not only five centuries, but also the leap into electronic typesetting, remaining unchanged.
"""
+ kwargs.get("task", "task"),
True,
)
async def summarize_task_agent(
self,
*,
goal: str,
results: List[str],
) -> FastAPIStreamingResponse:
time.sleep(0.5)
return stream_string(
""" This is going to be a longer task result such that
We make the stream of this string take time and feel long. The reality is... this is a mock!
Lorem Ipsum is simply dummy text of the printing and typesetting industry.
Lorem Ipsum has been the industry's standard dummy text ever since the 1500s,
when an unknown printer took a galley of type and scrambled it to make a type specimen book.
It has survived not only five centuries, but also the leap into electronic typesetting, remaining unchanged.
""",
True,
)
async def chat(
self,
*,
message: str,
results: List[str],
) -> FastAPIStreamingResponse:
time.sleep(0.5)
return stream_string(
"What do you want dude?",
True,
)

View File

@ -0,0 +1,227 @@
from typing import List, Optional
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from lanarky.responses import StreamingResponse
from langchain import LLMChain
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
from langchain.schema import HumanMessage
from loguru import logger
from pydantic import ValidationError
from reworkd_platform.db.crud.oauth import OAuthCrud
from reworkd_platform.schemas.agent import ModelSettings
from reworkd_platform.schemas.user import UserBase
from reworkd_platform.services.tokenizer.token_service import TokenService
from reworkd_platform.web.api.agent.agent_service.agent_service import AgentService
from reworkd_platform.web.api.agent.analysis import Analysis, AnalysisArguments
from reworkd_platform.web.api.agent.helpers import (
call_model_with_handling,
openai_error_handler,
parse_with_handling,
)
from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
from reworkd_platform.web.api.agent.prompts import (
analyze_task_prompt,
chat_prompt,
create_tasks_prompt,
start_goal_prompt,
)
from reworkd_platform.web.api.agent.task_output_parser import TaskOutputParser
from reworkd_platform.web.api.agent.tools.open_ai_function import get_tool_function
from reworkd_platform.web.api.agent.tools.tools import (
get_default_tool,
get_tool_from_name,
get_tool_name,
get_user_tools,
)
from reworkd_platform.web.api.agent.tools.utils import summarize
from reworkd_platform.web.api.errors import OpenAIError
class OpenAIAgentService(AgentService):
def __init__(
self,
model: WrappedChatOpenAI,
settings: ModelSettings,
token_service: TokenService,
callbacks: Optional[List[AsyncCallbackHandler]],
user: UserBase,
oauth_crud: OAuthCrud,
):
self.model = model
self.settings = settings
self.token_service = token_service
self.callbacks = callbacks
self.user = user
self.oauth_crud = oauth_crud
async def start_goal_agent(self, *, goal: str) -> List[str]:
prompt = ChatPromptTemplate.from_messages(
[SystemMessagePromptTemplate(prompt=start_goal_prompt)]
)
self.token_service.calculate_max_tokens(
self.model,
prompt.format_prompt(
goal=goal,
language=self.settings.language,
).to_string(),
)
completion = await call_model_with_handling(
self.model,
ChatPromptTemplate.from_messages(
[SystemMessagePromptTemplate(prompt=start_goal_prompt)]
),
{"goal": goal, "language": self.settings.language},
settings=self.settings,
callbacks=self.callbacks,
)
task_output_parser = TaskOutputParser(completed_tasks=[])
tasks = parse_with_handling(task_output_parser, completion)
return tasks
async def analyze_task_agent(
self, *, goal: str, task: str, tool_names: List[str]
) -> Analysis:
user_tools = await get_user_tools(tool_names, self.user, self.oauth_crud)
functions = list(map(get_tool_function, user_tools))
prompt = analyze_task_prompt.format_prompt(
goal=goal,
task=task,
language=self.settings.language,
)
self.token_service.calculate_max_tokens(
self.model,
prompt.to_string(),
str(functions),
)
message = await openai_error_handler(
func=self.model.apredict_messages,
messages=prompt.to_messages(),
functions=functions,
settings=self.settings,
callbacks=self.callbacks,
)
function_call = message.additional_kwargs.get("function_call", {})
completion = function_call.get("arguments", "")
try:
pydantic_parser = PydanticOutputParser(pydantic_object=AnalysisArguments)
analysis_arguments = parse_with_handling(pydantic_parser, completion)
return Analysis(
action=function_call.get("name", get_tool_name(get_default_tool())),
**analysis_arguments.dict(),
)
except (OpenAIError, ValidationError):
return Analysis.get_default_analysis(task)
async def execute_task_agent(
self,
*,
goal: str,
task: str,
analysis: Analysis,
) -> StreamingResponse:
# TODO: More mature way of calculating max_tokens
if self.model.max_tokens > 3000:
self.model.max_tokens = max(self.model.max_tokens - 1000, 3000)
tool_class = get_tool_from_name(analysis.action)
return await tool_class(self.model, self.settings.language).call(
goal,
task,
analysis.arg,
self.user,
self.oauth_crud,
)
async def create_tasks_agent(
self,
*,
goal: str,
tasks: List[str],
last_task: str,
result: str,
completed_tasks: Optional[List[str]] = None,
) -> List[str]:
prompt = ChatPromptTemplate.from_messages(
[SystemMessagePromptTemplate(prompt=create_tasks_prompt)]
)
args = {
"goal": goal,
"language": self.settings.language,
"tasks": "\n".join(tasks),
"lastTask": last_task,
"result": result,
}
self.token_service.calculate_max_tokens(
self.model, prompt.format_prompt(**args).to_string()
)
completion = await call_model_with_handling(
self.model, prompt, args, settings=self.settings, callbacks=self.callbacks
)
previous_tasks = (completed_tasks or []) + tasks
return [completion] if completion not in previous_tasks else []
async def summarize_task_agent(
self,
*,
goal: str,
results: List[str],
) -> FastAPIStreamingResponse:
self.model.model_name = "gpt-4o"
self.model.max_tokens = 8000 # Total tokens = prompt tokens + completion tokens
snippet_max_tokens = 7000 # Leave room for the rest of the prompt
text_tokens = self.token_service.tokenize("".join(results))
text = self.token_service.detokenize(text_tokens[0:snippet_max_tokens])
logger.info(f"Summarizing text: {text}")
return summarize(
model=self.model,
language=self.settings.language,
goal=goal,
text=text,
)
async def chat(
self,
*,
message: str,
results: List[str],
) -> FastAPIStreamingResponse:
self.model.model_name = "gpt-4o"
prompt = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate(prompt=chat_prompt),
*[HumanMessage(content=result) for result in results],
HumanMessage(content=message),
]
)
self.token_service.calculate_max_tokens(
self.model,
prompt.format_prompt(
language=self.settings.language,
).to_string(),
)
chain = LLMChain(llm=self.model, prompt=prompt)
return StreamingResponse.from_chain(
chain,
{"language": self.settings.language},
media_type="text/event-stream",
)

View File

@ -0,0 +1,45 @@
from typing import Dict
from pydantic import BaseModel, validator
class AnalysisArguments(BaseModel):
"""
Arguments for the analysis function of a tool. OpenAI functions will resolve these values but leave out the action.
"""
reasoning: str
arg: str
class Analysis(AnalysisArguments):
action: str
@validator("action")
def action_must_be_valid_tool(cls, v: str) -> str:
# TODO: Remove circular import
from reworkd_platform.web.api.agent.tools.tools import get_available_tools_names
if v not in get_available_tools_names():
raise ValueError(f"Analysis action '{v}' is not a valid tool")
return v
@validator("action")
def search_action_must_have_arg(cls, v: str, values: Dict[str, str]) -> str:
from reworkd_platform.web.api.agent.tools.search import Search
from reworkd_platform.web.api.agent.tools.tools import get_tool_name
if v == get_tool_name(Search) and not values["arg"]:
raise ValueError("Analysis arg cannot be empty if action is 'search'")
return v
@classmethod
def get_default_analysis(cls, task: str) -> "Analysis":
# TODO: Remove circular import
from reworkd_platform.web.api.agent.tools.tools import get_default_tool_name
return cls(
reasoning="Hmm... I'll try searching it up",
action=get_default_tool_name(),
arg=task,
)

View File

@ -0,0 +1,95 @@
from typing import TypeVar
from fastapi import Body, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from reworkd_platform.db.crud.agent import AgentCRUD
from reworkd_platform.db.dependencies import get_db_session
from reworkd_platform.schemas.agent import (
AgentChat,
AgentRun,
AgentRunCreate,
AgentSummarize,
AgentTaskAnalyze,
AgentTaskCreate,
AgentTaskExecute,
Loop_Step,
)
from reworkd_platform.schemas.user import UserBase
from reworkd_platform.web.api.dependencies import get_current_user
T = TypeVar(
"T", AgentTaskAnalyze, AgentTaskExecute, AgentTaskCreate, AgentSummarize, AgentChat
)
def agent_crud(
user: UserBase = Depends(get_current_user),
session: AsyncSession = Depends(get_db_session),
) -> AgentCRUD:
return AgentCRUD(session, user)
async def agent_start_validator(
body: AgentRunCreate = Body(
example={
"goal": "Create business plan for a bagel company",
"modelSettings": {
"customModelName": "gpt-4o",
},
},
),
crud: AgentCRUD = Depends(agent_crud),
) -> AgentRun:
id_ = (await crud.create_run(body.goal)).id
return AgentRun(**body.dict(), run_id=str(id_))
async def validate(body: T, crud: AgentCRUD, type_: Loop_Step) -> T:
body.run_id = (await crud.create_task(body.run_id, type_)).id
return body
async def agent_analyze_validator(
body: AgentTaskAnalyze = Body(),
crud: AgentCRUD = Depends(agent_crud),
) -> AgentTaskAnalyze:
return await validate(body, crud, "analyze")
async def agent_execute_validator(
body: AgentTaskExecute = Body(
example={
"goal": "Perform tasks accurately",
"task": "Write code to make a platformer",
"analysis": {
"reasoning": "I like to write code.",
"action": "code",
"arg": "",
},
},
),
crud: AgentCRUD = Depends(agent_crud),
) -> AgentTaskExecute:
return await validate(body, crud, "execute")
async def agent_create_validator(
body: AgentTaskCreate = Body(),
crud: AgentCRUD = Depends(agent_crud),
) -> AgentTaskCreate:
return await validate(body, crud, "create")
async def agent_summarize_validator(
body: AgentSummarize = Body(),
crud: AgentCRUD = Depends(agent_crud),
) -> AgentSummarize:
return await validate(body, crud, "summarize")
async def agent_chat_validator(
body: AgentChat = Body(),
crud: AgentCRUD = Depends(agent_crud),
) -> AgentChat:
return await validate(body, crud, "chat")
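
These validators are consumed through `Depends` in route signatures; a sketch of the wiring (the route path and handler body are assumptions, not code from this file):

```
from fastapi import APIRouter, Depends

router = APIRouter()

@router.post("/agent/analyze")
async def analyze_endpoint(
    body: AgentTaskAnalyze = Depends(agent_analyze_validator),
) -> None:
    # By this point validate() has created a task row of type "analyze"
    # and rebound body.run_id to the new task's id.
    ...
```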

View File

@ -0,0 +1,76 @@
from typing import Any, Callable, Dict, TypeVar
from langchain import BasePromptTemplate, LLMChain
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseOutputParser, OutputParserException
from openai.error import (
AuthenticationError,
InvalidRequestError,
RateLimitError,
ServiceUnavailableError,
)
from reworkd_platform.schemas.agent import ModelSettings
from reworkd_platform.web.api.errors import OpenAIError
T = TypeVar("T")
def parse_with_handling(parser: BaseOutputParser[T], completion: str) -> T:
try:
return parser.parse(completion)
except OutputParserException as e:
raise OpenAIError(
e, "There was an issue parsing the response from the AI model."
)
async def openai_error_handler(
func: Callable[..., Any], *args: Any, settings: ModelSettings, **kwargs: Any
) -> Any:
try:
return await func(*args, **kwargs)
except ServiceUnavailableError as e:
raise OpenAIError(
e,
"OpenAI is experiencing issues. Visit "
"https://status.openai.com/ for more info.",
should_log=not settings.custom_api_key,
)
except InvalidRequestError as e:
if e.user_message.startswith("The model:"):
raise OpenAIError(
e,
f"Your API key does not have access to your current model. Please use a different model.",
should_log=not settings.custom_api_key,
)
raise OpenAIError(e, e.user_message)
except AuthenticationError as e:
raise OpenAIError(
e,
"Authentication error: Ensure a valid API key is being used.",
should_log=not settings.custom_api_key,
)
except RateLimitError as e:
if e.user_message.startswith("You exceeded your current quota"):
raise OpenAIError(
e,
f"Your API key exceeded your current quota, please check your plan and billing details.",
should_log=not settings.custom_api_key,
)
raise OpenAIError(e, e.user_message)
except Exception as e:
raise OpenAIError(
e, "There was an unexpected issue getting a response from the AI model."
)
async def call_model_with_handling(
model: BaseChatModel,
prompt: BasePromptTemplate,
args: Dict[str, str],
settings: ModelSettings,
**kwargs: Any,
) -> str:
chain = LLMChain(llm=model, prompt=prompt)
return await openai_error_handler(chain.arun, args, settings=settings, **kwargs)
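
Note that `openai_error_handler` can wrap any awaitable, not just `LLMChain.arun`. A sketch with a stand-in coroutine, assuming `ModelSettings` is constructible with defaults:

```
from reworkd_platform.schemas.agent import ModelSettings

async def flaky_completion(prompt: str) -> str:
    ...  # anything that may raise the openai.error exceptions handled above

async def demo() -> str:
    return await openai_error_handler(
        flaky_completion,
        "Hello",
        settings=ModelSettings(),  # assumed default-constructible
    )
```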

View File

@ -0,0 +1,97 @@
from typing import Any, Dict, Optional, Tuple, Type, Union
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from pydantic import Field
from reworkd_platform.schemas.agent import LLM_Model, ModelSettings
from reworkd_platform.schemas.user import UserBase
from reworkd_platform.settings import Settings
class WrappedChatOpenAI(ChatOpenAI):
client: Any = Field(
default=None,
description="Meta private value but mypy will complain its missing",
)
max_tokens: int
model_name: LLM_Model = Field(alias="model")
class WrappedAzureChatOpenAI(AzureChatOpenAI, WrappedChatOpenAI):
openai_api_base: str
openai_api_version: str
deployment_name: str
WrappedChat = Union[WrappedAzureChatOpenAI, WrappedChatOpenAI]
def create_model(
settings: Settings,
model_settings: ModelSettings,
user: UserBase,
streaming: bool = False,
force_model: Optional[LLM_Model] = None,
) -> WrappedChat:
use_azure = (
not model_settings.custom_api_key and "azure" in settings.openai_api_base
)
llm_model = force_model or model_settings.model
model: Type[WrappedChat] = WrappedChatOpenAI
base, headers, use_helicone = get_base_and_headers(settings, model_settings, user)
kwargs = {
"openai_api_base": base,
"openai_api_key": model_settings.custom_api_key or settings.openai_api_key,
"temperature": model_settings.temperature,
"model": llm_model,
"max_tokens": model_settings.max_tokens,
"streaming": streaming,
"max_retries": 5,
"model_kwargs": {"user": user.email, "headers": headers},
}
if use_azure:
model = WrappedAzureChatOpenAI
deployment_name = llm_model.replace(".", "")
kwargs.update(
{
"openai_api_version": settings.openai_api_version,
"deployment_name": deployment_name,
"openai_api_type": "azure",
"openai_api_base": base.rstrip("v1"),
}
)
if use_helicone:
kwargs["model"] = deployment_name
return model(**kwargs) # type: ignore
def get_base_and_headers(
settings_: Settings, model_settings: ModelSettings, user: UserBase
) -> Tuple[str, Optional[Dict[str, str]], bool]:
use_helicone = settings_.helicone_enabled and not model_settings.custom_api_key
base = (
settings_.helicone_api_base
if use_helicone
else (
"https://api.openai.com/v1"
if model_settings.custom_api_key
else settings_.openai_api_base
)
)
headers = (
{
"Helicone-Auth": f"Bearer {settings_.helicone_api_key}",
"Helicone-Cache-Enabled": "true",
"Helicone-User-Id": user.id,
"Helicone-OpenAI-Api-Base": settings_.openai_api_base,
}
if use_helicone
else None
)
return base, headers, use_helicone
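
Typical construction of a streaming model through this factory might look like the following. The `ModelSettings` field names are the ones referenced in this file; treating them as constructor keywords is an assumption:

```
llm = create_model(
    settings,
    ModelSettings(
        model="gpt-4o",
        temperature=0.9,
        max_tokens=500,
        custom_api_key=None,  # None -> server key; Helicone/Azure routing may apply
    ),
    user,
    streaming=True,
)
```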

View File

@ -0,0 +1,185 @@
from langchain import PromptTemplate
# Create initial tasks using plan and solve prompting
# https://github.com/AGI-Edgerunners/Plan-and-Solve-Prompting
start_goal_prompt = PromptTemplate(
template="""You are a task creation AI called Allinix.
You answer in the "{language}" language. You have the following objective "{goal}".
Return a list of search queries that would be required to answer the entirety of the objective.
Limit the list to a maximum of 5 queries. Ensure the queries are as succinct as possible.
For simple questions use a single query.
Return the response as a JSON array of strings. Examples:
query: "Who is considered the best NBA player in the current season?", answer: ["current NBA MVP candidates"]
query: "How does the Olympicpayroll brand currently stand in the market, and what are its prospects and strategies for expansion in NJ, NY, and PA?", answer: ["Olympicpayroll brand comprehensive analysis 2023", "customer reviews of Olympicpayroll.com", "Olympicpayroll market position analysis", "payroll industry trends forecast 2023-2025", "payroll services expansion strategies in NJ, NY, PA"]
query: "How can I create a function to add weight to edges in a digraph using {language}?", answer: ["algorithm to add weight to digraph edge in {language}"]
query: "What is the current weather in New York?", answer: ["current weather in New York"]
query: "5 + 5?", answer: ["Sum of 5 and 5"]
query: "What is a good homemade recipe for KFC-style chicken?", answer: ["KFC style chicken recipe at home"]
query: "What are the nutritional values of almond milk and soy milk?", answer: ["nutritional information of almond milk", "nutritional information of soy milk"]""",
input_variables=["goal", "language"],
)
analyze_task_prompt = PromptTemplate(
template="""
High level objective: "{goal}"
Current task: "{task}"
Based on this information, use the best function to make progress or accomplish the task entirely.
Select the correct function by being smart and efficient. Ensure "reasoning" and only "reasoning" is in the
{language} language.
Note you MUST select a function.
""",
input_variables=["goal", "task", "language"],
)
code_prompt = PromptTemplate(
template="""
You are a world-class software engineer and an expert in all programming languages,
software systems, and architecture.
For reference, your high level goal is {goal}
Write code in English but explanations/comments in the "{language}" language.
Provide no information about who you are and focus on writing code.
Ensure code is bug- and error-free, and explain complex concepts through comments.
Respond in well-formatted markdown. Ensure code blocks are used for code sections.
Approach problems step by step and file by file; for each section, use a heading to describe it.
Write code to accomplish the following:
{task}
""",
input_variables=["goal", "language", "task"],
)
execute_task_prompt = PromptTemplate(
template="""Answer in the "{language}" language. Given
the following overall objective `{goal}` and the following sub-task, `{task}`.
Perform the task by understanding the problem, extracting variables, and being smart
and efficient. Write a detailed response that address the task.
When confronted with choices, make a decision yourself with reasoning.
""",
input_variables=["goal", "language", "task"],
)
create_tasks_prompt = PromptTemplate(
template="""You are an AI task creation agent. You must answer in the "{language}"
language. You have the following objective `{goal}`.
You have the following incomplete tasks:
`{tasks}`
You just completed the following task:
`{lastTask}`
And received the following result:
`{result}`.
Based on this, create a single new task to be completed by your AI system that brings you closer to achieving your goal.
If there are no more tasks to be done, return nothing. Do not add quotes to the task.
Examples:
Search the web for NBA news
Create a function to add a new vertex with a specified weight to the digraph.
Search for any additional information on Bertie W.
""
""",
input_variables=["goal", "language", "tasks", "lastTask", "result"],
)
summarize_prompt = PromptTemplate(
template="""You must answer in the "{language}" language.
Combine the following text into a cohesive document:
"{text}"
Write using clear markdown formatting in a style expected of the goal "{goal}".
Be as clear, informative, and descriptive as necessary.
You will not make up information or add any information outside of the above text.
Only use the given information and nothing more.
If there is no information provided, say "There is nothing to summarize".
""",
input_variables=["goal", "language", "text"],
)
company_context_prompt = PromptTemplate(
template="""You must answer in the "{language}" language.
Create a short description on "{company_name}".
Find out what sector it is in and what are their primary products.
Be as clear, informative, and descriptive as necessary.
You will not make up information or add any information outside of the above text.
Only use the given information and nothing more.
If there is no information provided, say "There is nothing to summarize".
""",
input_variables=["company_name", "language"],
)
summarize_pdf_prompt = PromptTemplate(
template="""You must answer in the "{language}" language.
For the given text: "{text}", you have the following objective "{query}".
Be as clear, informative, and descriptive as necessary.
You will not make up information or add any information outside of the above text.
Only use the given information and nothing more.
""",
input_variables=["query", "language", "text"],
)
summarize_with_sources_prompt = PromptTemplate(
template="""You must answer in the "{language}" language.
Answer the following query: "{query}" using the following information: "{snippets}".
Write using clear markdown formatting and use markdown lists where possible.
Cite sources for sentences via markdown links using the source link as the link and the index as the text.
Use in-line sources. Do not separately list sources at the end of the writing.
If the query cannot be answered with the provided information, mention this and explain why, along with what the information does cover.
Also cite the sources of what is actually mentioned.
Example sentences of the paragraph:
"So this is a cited sentence at the end of a paragraph[1](https://test.com). This is another sentence."
"Stephen curry is an american basketball player that plays for the warriors[1](https://www.britannica.com/biography/Stephen-Curry)."
"The economic growth forecast for the region has been adjusted from 2.5% to 3.1% due to improved trade relations[1](https://economictimes.com), while inflation rates are expected to remain steady at around 1.7% according to financial analysts[2](https://financeworld.com)."
""",
input_variables=["language", "query", "snippets"],
)
summarize_sid_prompt = PromptTemplate(
template="""You must answer in the "{language}" language.
Parse and summarize the following text snippets "{snippets}".
Write using clear markdown formatting in a style expected of the goal "{goal}".
Be as clear, informative, and descriptive as necessary and attempt to
answer the query: "{query}" as best as possible.
If any of the snippets are not relevant to the query,
ignore them, and do not include them in the summary.
Do not mention that you are ignoring them.
If there is no information provided, say "There is nothing to summarize".
""",
input_variables=["goal", "language", "query", "snippets"],
)
chat_prompt = PromptTemplate(
template="""You must answer in the "{language}" language.
You are a helpful AI Assistant that will provide responses based on the current conversation history.
The human will provide previous messages as context. Use ONLY this information for your responses.
Do not make anything up and do not add any additional information.
If you have no information for a given question in the conversation history,
say "I do not have any information on this".
""",
input_variables=["language"],
)
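
Each template renders the same way before being sent to the model; for example (values are illustrative):

```
text = start_goal_prompt.format_prompt(
    goal="Create a business plan for a bagel company",
    language="English",
).to_string()
print(text.splitlines()[0])  # You are a task creation AI called Allinix.
```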

View File

@ -0,0 +1,26 @@
import asyncio
from typing import AsyncGenerator
import tiktoken
from fastapi import FastAPI
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
app = FastAPI()
def stream_string(data: str, delayed: bool = False) -> FastAPIStreamingResponse:
return FastAPIStreamingResponse(
stream_generator(data, delayed),
)
async def stream_generator(data: str, delayed: bool) -> AsyncGenerator[bytes, None]:
if delayed:
encoding = tiktoken.get_encoding("cl100k_base")
token_data = encoding.encode(data)
for token in token_data:
yield encoding.decode([token]).encode("utf-8")
await asyncio.sleep(0.025) # simulate slow processing
else:
yield data.encode()
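
A quick way to watch the token-by-token behavior is to mount a throwaway route on the module's `app`. The `/mock` route is an assumption added for the demo, and `TestClient.stream` relies on the httpx-based test client in recent Starlette/FastAPI versions:

```
from fastapi.testclient import TestClient

@app.get("/mock")
def mock() -> FastAPIStreamingResponse:
    return stream_string("Task execution concluded.", delayed=True)

client = TestClient(app)
with client.stream("GET", "/mock") as response:
    for chunk in response.iter_text():
        print(chunk, end="", flush=True)
```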

View File

@ -0,0 +1,88 @@
import ast
import re
from typing import List
from langchain.schema import BaseOutputParser, OutputParserException
class TaskOutputParser(BaseOutputParser[List[str]]):
"""
Extension of LangChain's BaseOutputParser
Responsible for parsing task creation output into a list of task strings
"""
completed_tasks: List[str] = []
def __init__(self, *, completed_tasks: List[str]):
super().__init__()
self.completed_tasks = completed_tasks
def parse(self, text: str) -> List[str]:
try:
array_str = extract_array(text)
all_tasks = [
remove_prefix(task) for task in array_str if real_tasks_filter(task)
]
return [task for task in all_tasks if task not in self.completed_tasks]
except Exception as e:
msg = f"Failed to parse tasks from completion '{text}'. Exception: {e}"
raise OutputParserException(msg)
def get_format_instructions(self) -> str:
return """
The response should be a JSON array of strings. Example:
["Search the web for NBA news", "Write some code to build a web scraper"]
This should be parsable by json.loads()
"""
def extract_array(input_str: str) -> List[str]:
regex = (
r"\[\s*\]|" # Empty array check
r"(\[(?:\s*(?:\"(?:[^\"\\]|\\.)*\"|\'(?:[^\'\\]|\\.)*\')\s*,?)*\s*\])"
)
match = re.search(regex, input_str)
if match is not None:
return ast.literal_eval(match[0])
else:
return handle_multiline_string(input_str)
def handle_multiline_string(input_str: str) -> List[str]:
# Handle multiline string as a list
processed_lines = [
re.sub(r".*?(\d+\..+)", r"\1", line).strip()
for line in input_str.split("\n")
if line.strip() != ""
]
# Check if there is at least one line that starts with a digit and a period
if any(re.match(r"\d+\..+", line) for line in processed_lines):
return processed_lines
else:
raise RuntimeError(f"Failed to extract array from {input_str}")
def remove_prefix(input_str: str) -> str:
prefix_pattern = (
r"^(Task\s*\d*\.\s*|Task\s*\d*[-:]?\s*|Step\s*\d*["
r"-:]?\s*|Step\s*[-:]?\s*|\d+\.\s*|\d+\s*[-:]?\s*|^\.\s*|^\.*)"
)
return re.sub(prefix_pattern, "", input_str, flags=re.IGNORECASE)
def real_tasks_filter(input_str: str) -> bool:
no_task_regex = (
r"^No( (new|further|additional|extra|other))? tasks? (is )?("
r"required|needed|added|created|inputted).*"
)
task_complete_regex = r"^Task (complete|completed|finished|done|over|success).*"
do_nothing_regex = r"^(\s*|Do nothing(\s.*)?)$"
return (
not re.search(no_task_regex, input_str, re.IGNORECASE)
and not re.search(task_complete_regex, input_str, re.IGNORECASE)
and not re.search(do_nothing_regex, input_str, re.IGNORECASE)
)
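
Two representative completions, one JSON array and one numbered list, exercise both parsing paths:

```
parser = TaskOutputParser(completed_tasks=["Search the web for NBA news"])

# JSON arrays are matched by extract_array; already-completed tasks are dropped:
print(parser.parse('["Search the web for NBA news", "Write a web scraper"]'))
# ['Write a web scraper']

# Numbered lists fall through to handle_multiline_string, then remove_prefix:
print(parser.parse("1. Research bagel suppliers\n2. Draft a budget"))
# ['Research bagel suppliers', 'Draft a budget']
```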

View File

@ -0,0 +1,25 @@
from typing import Any
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from lanarky.responses import StreamingResponse
from langchain import LLMChain
from reworkd_platform.web.api.agent.tools.tool import Tool
class Code(Tool):
description = "Should only be used to write code, refactor code, fix code bugs, and explain programming concepts."
public_description = "Write and review code."
async def call(
self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
) -> FastAPIStreamingResponse:
from reworkd_platform.web.api.agent.prompts import code_prompt
chain = LLMChain(llm=self.model, prompt=code_prompt)
return StreamingResponse.from_chain(
chain,
{"goal": goal, "language": self.language, "task": task},
media_type="text/event-stream",
)

View File

@ -0,0 +1,15 @@
from typing import Any
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from reworkd_platform.web.api.agent.stream_mock import stream_string
from reworkd_platform.web.api.agent.tools.tool import Tool
class Conclude(Tool):
description = "Use when there is nothing else to do. The task has been concluded."
async def call(
self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
) -> FastAPIStreamingResponse:
return stream_string("Task execution concluded.", delayed=True)

View File

@ -0,0 +1,66 @@
from typing import Any
import openai
import replicate
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from replicate.exceptions import ModelError
from replicate.exceptions import ReplicateError as ReplicateAPIError
from reworkd_platform.settings import settings
from reworkd_platform.web.api.agent.stream_mock import stream_string
from reworkd_platform.web.api.agent.tools.tool import Tool
from reworkd_platform.web.api.errors import ReplicateError
async def get_replicate_image(input_str: str) -> str:
if settings.replicate_api_key is None or settings.replicate_api_key == "":
raise RuntimeError("Replicate API key not set")
client = replicate.Client(settings.replicate_api_key)
try:
        output = client.run(
            "stability-ai/stable-diffusion"
            ":db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
            # image_dimensions is a model input, so it belongs in the input payload
            input={"prompt": input_str, "image_dimensions": "512x512"},
        )
except ModelError as e:
raise ReplicateError(e, "Image generation failed due to NSFW image.")
except ReplicateAPIError as e:
raise ReplicateError(e, "Failed to generate an image.")
return output[0]
# Use AI to generate an Image based on a prompt
async def get_open_ai_image(input_str: str) -> str:
response = openai.Image.create(
api_key=settings.openai_api_key,
prompt=input_str,
n=1,
size="256x256",
)
return response["data"][0]["url"]
class Image(Tool):
description = "Used to sketch, draw, or generate an image."
public_description = "Generate AI images."
arg_description = (
"The input prompt to the image generator. "
"This should be a detailed description of the image touching on image "
"style, image focus, color, etc."
)
image_url = "/tools/replicate.png"
async def call(
self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
) -> FastAPIStreamingResponse:
        # Use the Replicate API if it's available, otherwise fall back to DALL-E
try:
url = await get_replicate_image(input_str)
except RuntimeError:
url = await get_open_ai_image(input_str)
return stream_string(f"![{input_str}]({url})")
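
Usage follows the same `Tool` contract as the other tools; a sketch, assuming `model` is an already-constructed chat model instance:

```
# Prefers Replicate; falls back to DALL-E when no Replicate key is configured.
response = await Image(model, "English").call(
    goal="Design branding assets",
    task="Create a logo",
    input_str="minimalist bagel logo, flat design, warm colors",
)
```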

View File

@ -0,0 +1,43 @@
from typing import Type, TypedDict
from reworkd_platform.web.api.agent.tools.tool import Tool
from reworkd_platform.web.api.agent.tools.tools import get_tool_name
class FunctionDescription(TypedDict):
"""Representation of a callable function to the OpenAI API."""
name: str
"""The name of the function."""
description: str
"""A description of the function."""
parameters: dict[str, object]
"""The parameters of the function."""
def get_tool_function(tool: Type[Tool]) -> FunctionDescription:
"""A function that will return the tool's function specification"""
name = get_tool_name(tool)
return {
"name": name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": (
f"Reasoning is how the task will be accomplished with the current function. "
"Detail your overall plan along with any concerns you have."
"Ensure this reasoning value is in the user defined langauge "
),
},
"arg": {
"type": "string",
"description": tool.arg_description,
},
},
"required": ["reasoning", "arg"],
},
}
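
For a tool like `Code` above, the resulting `FunctionDescription` is roughly the following (the exact `name` string depends on `get_tool_name`, which is not shown here):

```
{
    "name": "code",  # assuming get_tool_name derives it from the class name
    "description": "Should only be used to write code, refactor code, ...",
    "parameters": {
        "type": "object",
        "properties": {
            "reasoning": {"type": "string", "description": "..."},
            "arg": {"type": "string", "description": "..."},
        },
        "required": ["reasoning", "arg"],
    },
}
```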

View File

@ -0,0 +1,27 @@
from typing import Any
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from lanarky.responses import StreamingResponse
from langchain import LLMChain
from reworkd_platform.web.api.agent.tools.tool import Tool
class Reason(Tool):
description = (
"Reason about task via existing information or understanding. "
"Make decisions / selections from options."
)
async def call(
self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
) -> FastAPIStreamingResponse:
from reworkd_platform.web.api.agent.prompts import execute_task_prompt
chain = LLMChain(llm=self.model, prompt=execute_task_prompt)
return StreamingResponse.from_chain(
chain,
{"goal": goal, "language": self.language, "task": task},
media_type="text/event-stream",
)

View File

@ -0,0 +1,109 @@
from typing import Any, List
from urllib.parse import quote
import aiohttp
from aiohttp import ClientResponseError
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from loguru import logger
from reworkd_platform.settings import settings
from reworkd_platform.web.api.agent.stream_mock import stream_string
from reworkd_platform.web.api.agent.tools.reason import Reason
from reworkd_platform.web.api.agent.tools.tool import Tool
from reworkd_platform.web.api.agent.tools.utils import (
CitedSnippet,
summarize_with_sources,
)
# Search Google via serper.dev. Adapted from LangChain
# https://github.com/hwchase17/langchain/blob/master/langchain/utilities
async def _google_serper_search_results(
search_term: str, search_type: str = "search"
) -> dict[str, Any]:
headers = {
"X-API-KEY": settings.serp_api_key or "",
"Content-Type": "application/json",
}
params = {
"q": search_term,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"https://google.serper.dev/{search_type}", headers=headers, params=params
) as response:
response.raise_for_status()
search_results = await response.json()
return search_results
class Search(Tool):
    description = (
        "Search Google for short, up-to-date answers to simple questions about "
        "public information, news, and people.\n"
    )
    public_description = "Search Google for information about current events."
arg_description = "The query argument to search for. This value is always populated and cannot be an empty string."
image_url = "/tools/google.png"
@staticmethod
def available() -> bool:
return settings.serp_api_key is not None and settings.serp_api_key != ""
async def call(
self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
) -> FastAPIStreamingResponse:
try:
return await self._call(goal, task, input_str, *args, **kwargs)
except ClientResponseError:
logger.exception("Error calling Serper API, falling back to reasoning")
return await Reason(self.model, self.language).call(
goal, task, input_str, *args, **kwargs
)
async def _call(
self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
) -> FastAPIStreamingResponse:
results = await _google_serper_search_results(
input_str,
)
k = 5 # Number of results to return
snippets: List[CitedSnippet] = []
if results.get("answerBox"):
answer_values = []
answer_box = results.get("answerBox", {})
if answer_box.get("answer"):
answer_values.append(answer_box.get("answer"))
elif answer_box.get("snippet"):
answer_values.append(answer_box.get("snippet").replace("\n", " "))
elif answer_box.get("snippetHighlighted"):
answer_values.append(", ".join(answer_box.get("snippetHighlighted")))
if len(answer_values) > 0:
snippets.append(
CitedSnippet(
len(snippets) + 1,
"\n".join(answer_values),
f"https://www.google.com/search?q={quote(input_str)}",
)
)
for i, result in enumerate(results["organic"][:k]):
texts = []
link = ""
if "snippet" in result:
texts.append(result["snippet"])
if "link" in result:
link = result["link"]
for attribute, value in result.get("attributes", {}).items():
texts.append(f"{attribute}: {value}.")
snippets.append(CitedSnippet(len(snippets) + 1, "\n".join(texts), link))
if len(snippets) == 0:
return stream_string("No good Google Search Result was found", True)
return summarize_with_sources(self.model, self.language, goal, task, snippets)
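
An ad-hoc check of the Serper helper from a REPL might look like this (requires `settings.serp_api_key` to be configured):

```
import asyncio

results = asyncio.run(_google_serper_search_results("current NBA MVP candidates"))
print(results.get("answerBox"))
print(len(results.get("organic", [])), "organic results")
```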
