2025-02-17 19:44:17 +05:30

148 lines
5.2 KiB
Python

import json
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from urllib.parse import urlencode
import aiohttp
from fastapi import Depends, Path
from reworkd_platform.db.crud.oauth import OAuthCrud
from reworkd_platform.db.models.auth import OauthCredentials
from reworkd_platform.schemas import UserBase
from reworkd_platform.services.security import encryption_service
from reworkd_platform.settings import Settings
from reworkd_platform.settings import settings as platform_settings
from reworkd_platform.web.api.http_responses import forbidden
class OAuthInstaller(ABC):
    """Abstract base for a provider-specific OAuth installation flow.

    Concrete installers implement the install / callback / uninstall steps
    and may use the static helpers below to store tokens encrypted with
    ``encryption_service``.
    """

    def __init__(self, crud: OAuthCrud, settings: Settings):
        # Settings and the CRUD helper are kept for subclasses to use.
        self.settings = settings
        self.crud = crud

    @abstractmethod
    async def install(self, user: UserBase, redirect_uri: str) -> str:
        """Begin an installation for ``user`` and return a URL string."""
        raise NotImplementedError()

    @abstractmethod
    async def install_callback(self, code: str, state: str) -> OauthCredentials:
        """Finish an installation from the provider's ``code`` and ``state``."""
        raise NotImplementedError()

    @abstractmethod
    async def uninstall(self, user: UserBase) -> bool:
        """Remove ``user``'s installation; report the outcome as a bool."""
        raise NotImplementedError()

    @staticmethod
    def store_access_token(creds: OauthCredentials, access_token: str) -> None:
        """Encrypt ``access_token`` and set it on ``creds.access_token_enc``."""
        encrypted = encryption_service.encrypt(access_token)
        creds.access_token_enc = encrypted

    @staticmethod
    def store_refresh_token(creds: OauthCredentials, refresh_token: str) -> None:
        """Encrypt ``refresh_token`` and set it on ``creds.refresh_token_enc``."""
        encrypted = encryption_service.encrypt(refresh_token)
        creds.refresh_token_enc = encrypted
class SIDInstaller(OAuthInstaller):
    """OAuth installer for the SID (sid.ai) integration."""

    PROVIDER = "sid"

    async def install(self, user: UserBase, redirect_uri: str) -> str:
        """Create or reuse ``user``'s SID installation and return the authorize URL.

        Args:
            user: The user starting the installation.
            redirect_uri: Passed to ``create_installation`` when a new record
                is needed. The authorize URL itself uses
                ``settings.sid_redirect_uri``.

        Returns:
            ``https://me.sid.ai/api/oauth/authorize`` with the query parameters
            for the authorization-code flow, including the installation's
            ``state``.
        """
        # Gracefully handle the case where the installation already exists;
        # this can happen if the user starts the process from multiple tabs.
        installation = await self.crud.get_installation_by_user_id(
            user.id, self.PROVIDER
        )
        if not installation:
            installation = await self.crud.create_installation(
                user,
                self.PROVIDER,
                redirect_uri,
            )

        scopes = ["data:query", "offline_access"]
        params = {
            "client_id": self.settings.sid_client_id,
            "redirect_uri": self.settings.sid_redirect_uri,
            "response_type": "code",
            "scope": " ".join(scopes),
            "state": installation.state,
            "audience": "https://api.sid.ai/api/v1/",
        }
        return "https://me.sid.ai/api/oauth/authorize" + "?" + urlencode(params)

    async def install_callback(self, code: str, state: str) -> OauthCredentials:
        """Exchange an authorization ``code`` for tokens and persist them.

        Args:
            code: Authorization code returned by SID.
            state: The ``state`` value issued by :meth:`install`; used to find
                the pending installation.

        Returns:
            The saved credentials record.

        Raises:
            The exception produced by ``forbidden()`` if no installation matches
            ``state``.
            aiohttp.ClientResponseError: if the token endpoint responds with an
                error status.
        """
        creds = await self.crud.get_installation_by_state(state)
        if not creds:
            raise forbidden()

        req = {
            "grant_type": "authorization_code",
            "client_id": self.settings.sid_client_id,
            "client_secret": self.settings.sid_client_secret,
            "redirect_uri": self.settings.sid_redirect_uri,
            "code": code,
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "https://auth.sid.ai/oauth/token",
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                },
                data=json.dumps(req),
            ) as response:
                # Fail with a clear HTTP error rather than a KeyError / JSON
                # decode error when SID rejects the exchange (e.g. bad code).
                response.raise_for_status()
                res_data = await response.json()

        OAuthInstaller.store_access_token(creds, res_data["access_token"])
        # install() requests "offline_access", which should yield a refresh
        # token; don't abort the whole install if the provider omits it.
        refresh_token = res_data.get("refresh_token")
        if refresh_token:
            OAuthInstaller.store_refresh_token(creds, refresh_token)
        creds.access_token_expiration = datetime.now() + timedelta(
            seconds=res_data["expires_in"]
        )
        return await creds.save(self.crud.session)

    async def uninstall(self, user: UserBase) -> bool:
        """Delete ``user``'s SID installation and revoke its refresh token.

        Revocation is best-effort: the revoke endpoint's status is not checked.

        Returns:
            ``False`` if the user has no SID installation, otherwise ``True``.
        """
        creds = await self.crud.get_installation_by_user_id(user.id, self.PROVIDER)
        if not creds:
            return False

        # Decrypt before the row is deleted. An installation whose callback
        # never completed presumably has no stored refresh token, so there is
        # nothing to decrypt or revoke.
        refresh_token = (
            encryption_service.decrypt(creds.refresh_token_enc)
            if creds.refresh_token_enc
            else None
        )

        # Delete credentials from the database.
        await self.crud.session.delete(creds)

        if refresh_token:
            async with aiohttp.ClientSession() as session:
                # Use the response as a context manager so the connection is
                # released; the status is intentionally not inspected.
                async with session.post(
                    "https://auth.sid.ai/oauth/revoke",
                    headers={
                        "Content-Type": "application/json",
                    },
                    data=json.dumps(
                        {
                            "client_id": self.settings.sid_client_id,
                            "client_secret": self.settings.sid_client_secret,
                            "token": refresh_token,
                        }
                    ),
                ):
                    pass
        return True
# Registry of supported OAuth providers, keyed by each installer's PROVIDER id.
integrations = {installer.PROVIDER: installer for installer in (SIDInstaller,)}
def installer_factory(
    provider: str = Path(description="OAuth Provider"),
    crud: OAuthCrud = Depends(OAuthCrud.inject),
) -> OAuthInstaller:
    """Factory for OAuth installers.

    Used as a FastAPI dependency: ``provider`` is read from the path and
    ``crud`` is injected.

    Args:
        provider (str): OAuth provider key; must be a key of ``integrations``
            (e.g. ``"sid"``). (injected)
        crud (OAuthCrud): OAuth Crud (injected)

    Returns:
        OAuthInstaller: An installer for ``provider`` built with the platform
        settings.

    Raises:
        NotImplementedError: if no installer is registered for ``provider``.
    """
    installer_cls = integrations.get(provider)
    if installer_cls is None:
        raise NotImplementedError(f"Unsupported OAuth provider: {provider!r}")
    return installer_cls(crud, platform_settings)