Initial PR to add support for Stack Overflow Teams #78

Open · wants to merge 11 commits into base: main
1 change: 1 addition & 0 deletions README.md
@@ -32,6 +32,7 @@ Coming Soon...
- [X] Bookstack - by [@flifloo](https://github.com/flifloo) :pray:
- [X] Mattermost - by [@itaykal](https://github.com/Itaykal) :pray:
- [X] RocketChat - by [@flifloo](https://github.com/flifloo) :pray:
- [X] Stack Overflow Teams - by [@allen-munsch](https://github.com/allen-munsch) :pray:
- [ ] Gitlab Issues (In PR :pray:)
- [ ] Zendesk (In PR :pray:)
- [ ] Azure DevOps (In PR :pray:)
2 changes: 1 addition & 1 deletion app/api/data_source.py
@@ -85,7 +85,7 @@ async def list_locations(request: Request, data_source_name: str, config: dict)
@router.post("")
async def connect_data_source(request: Request, dto: AddDataSourceDto, background_tasks: BackgroundTasks) -> int:
logger.info(f"Adding data source {dto.name} with config {json.dumps(dto.config)}")
data_source = DataSourceContext.create_data_source(name=dto.name, config=dto.config)
data_source = await DataSourceContext.create_data_source(name=dto.name, config=dto.config)
Posthog.added_data_source(uuid=request.headers.get('uuid'), name=dto.name)
# in main.py we have a background task that runs every 5 minutes and indexes the data source
# but here we want to index the data source immediately
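The comment above describes the design: a periodic task in main.py indexes every source on a 5-minute cadence, while this endpoint also kicks off an immediate index for the newly added source. As a generic illustration of that FastAPI BackgroundTasks pattern (the names below are placeholders, not the repo's actual helpers):

from fastapi import BackgroundTasks, FastAPI

app = FastAPI()

def index_now(name: str) -> None:
    # placeholder for the repo's indexing routine
    print(f"indexing {name}")

@app.post("/sources")
async def add_source(name: str, background_tasks: BackgroundTasks) -> int:
    background_tasks.add_task(index_now, name)  # runs after the response, no 5-minute wait
    return 1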
1 change: 1 addition & 0 deletions app/clear_ack_queue.sh
@@ -0,0 +1 @@
sqlite3 ~/.gerev/storage/tasks.sqlite3/data.db 'delete from ack_queue_task where _id in (select _id from ack_queue_task);'
1 change: 1 addition & 0 deletions app/clear_data_sources.sh
@@ -0,0 +1 @@
sqlite3 ~/.gerev/storage/db.sqlite3 'delete from data_source where id in (select id from data_source);'
2 changes: 1 addition & 1 deletion app/data_source/api/base_data_source.py
@@ -61,7 +61,7 @@ def get_config_fields() -> List[ConfigField]:

@staticmethod
@abstractmethod
def validate_config(config: Dict) -> None:
async def validate_config(config: Dict) -> None:
"""
Validates the config and raises an exception if it's invalid.
"""
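Because validate_config is now a coroutine, async call sites simply await it (as context.py does below), while synchronous call sites such as scripts or tests have to drive it on an event loop. A minimal sketch, assuming the Slack source from later in this diff and a throwaway token:

import asyncio

from data_source.sources.slack.slack import SlackDataSource

# validate_config is a coroutine now; outside async code it must be run on a loop.
# With a bad token this raises from slack.auth_test(), which is exactly what validation is for.
asyncio.run(SlackDataSource.validate_config({"token": "xoxb-example-token"}))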
36 changes: 25 additions & 11 deletions app/data_source/api/context.py
@@ -6,9 +6,10 @@
from data_source.api.base_data_source import BaseDataSource
from data_source.api.dynamic_loader import DynamicLoader, ClassInfo
from data_source.api.exception import KnownException
from db_engine import Session
from db_engine import Session, async_session
from pydantic.error_wrappers import ValidationError
from schemas import DataSourceType, DataSource

from sqlalchemy import select

logger = logging.getLogger(__name__)

@@ -48,22 +49,31 @@ def get_data_source_classes(cls) -> Dict[str, BaseDataSource]:
return cls._data_source_classes

@classmethod
def create_data_source(cls, name: str, config: dict) -> BaseDataSource:
with Session() as session:
data_source_type = session.query(DataSourceType).filter_by(name=name).first()
async def create_data_source(cls, name: str, config: dict) -> BaseDataSource:
async with async_session() as session:
data_source_type = await session.execute(
select(DataSourceType).filter_by(name=name)
)
data_source_type = data_source_type.scalar_one_or_none()
if data_source_type is None:
raise KnownException(message=f"Data source type {name} does not exist")

data_source_class = DynamicLoader.get_data_source_class(name)
logger.info(f"validating config for data source {name}")
data_source_class.validate_config(config)
await data_source_class.validate_config(config)
config_str = json.dumps(config)

data_source_row = DataSource(type_id=data_source_type.id, config=config_str, created_at=datetime.now())
data_source_row = DataSource(
type_id=data_source_type.id,
config=config_str,
created_at=datetime.now(),
)
session.add(data_source_row)
session.commit()
await session.commit()

data_source = data_source_class(config=config, data_source_id=data_source_row.id)
data_source = data_source_class(
config=config, data_source_id=data_source_row.id
)
cls._data_source_instances[data_source_row.id] = data_source

return data_source
@@ -95,8 +105,12 @@ def _load_connected_sources_from_db(cls):
for data_source in data_sources:
data_source_cls = DynamicLoader.get_data_source_class(data_source.type.name)
config = json.loads(data_source.config)
data_source_instance = data_source_cls(config=config, data_source_id=data_source.id,
last_index_time=data_source.last_indexed_at)
try:
data_source_instance = data_source_cls(config=config, data_source_id=data_source.id,
last_index_time=data_source.last_indexed_at)
except ValidationError as e:
logger.error(f"Error loading data source {data_source.id}: {e}")
return
cls._data_source_instances[data_source.id] = data_source_instance

cls._initialized = True
25 changes: 25 additions & 0 deletions app/data_source/api/utils.py
@@ -4,12 +4,36 @@
from functools import lru_cache
from io import BytesIO
from typing import Optional
import time
import threading
from functools import wraps

import requests


logger = logging.getLogger(__name__)


def rate_limit(*, allowed_per_second: int):
max_period = 1.0 / allowed_per_second
last_call = [time.perf_counter()]
lock = threading.Lock()

def decorate(func):
@wraps(func)
def limit(*args, **kwargs):
with lock:
elapsed = time.perf_counter() - last_call[0]
hold = max_period - elapsed
if hold > 0:
time.sleep(hold)
result = func(*args, **kwargs)
last_call[0] = time.perf_counter()
return result
return limit
return decorate


def snake_case_to_pascal_case(snake_case_string: str):
"""Converts a snake case string to a PascalCase string"""
components = snake_case_string.split('_')
@@ -55,3 +79,4 @@ def get_confluence_user_image(image_url: str, token: str) -> Optional[str]:
return f"data:image/jpeg;base64,{base64.b64encode(image_bytes.getvalue()).decode()}"
except:
logger.warning(f"Failed to get confluence user image {image_url}")

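The rate_limit decorator added above serializes calls behind a lock and spaces them at least 1/allowed_per_second seconds apart. A minimal usage sketch, mirroring how rate_limited_get applies it in the Stack Overflow source later in this diff (fetch is an illustrative name):

import requests

from data_source.api.utils import rate_limit


@rate_limit(allowed_per_second=15)
def fetch(url: str, headers: dict) -> requests.Response:
    # successive calls are spaced roughly 66 ms apart, staying well under
    # the 30 requests/sec per-IP throttle documented by the Stack Exchange API
    return requests.get(url, headers=headers)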
2 changes: 1 addition & 1 deletion app/data_source/sources/bookstack/bookstack.py
@@ -132,7 +132,7 @@ def list_books(book_stack: BookStack) -> List[Dict]:
raise e

@staticmethod
def validate_config(config: Dict) -> None:
async def validate_config(config: Dict) -> None:
try:
parsed_config = BookStackConfig(**config)
book_stack = BookStack(url=parsed_config.url, token_id=parsed_config.token_id,
2 changes: 1 addition & 1 deletion app/data_source/sources/confluence/confluence.py
@@ -61,7 +61,7 @@ def list_all_spaces(confluence: Confluence) -> List[Location]:
return spaces

@staticmethod
def validate_config(config: Dict) -> None:
async def validate_config(config: Dict) -> None:
try:
client = ConfluenceDataSource.confluence_client_from_config(config)
ConfluenceDataSource.list_spaces(confluence=client)
2 changes: 1 addition & 1 deletion app/data_source/sources/confluence/confluence_cloud.py
@@ -25,7 +25,7 @@ def get_config_fields() -> List[ConfigField]:
]

@staticmethod
def validate_config(config: Dict) -> None:
async def validate_config(config: Dict) -> None:
try:
client = ConfluenceCloudDataSource.confluence_client_from_config(config)
ConfluenceCloudDataSource.list_spaces(confluence=client)
2 changes: 1 addition & 1 deletion app/data_source/sources/google_drive/google_drive.py
@@ -43,7 +43,7 @@ def get_config_fields() -> List[ConfigField]:
]

@staticmethod
def validate_config(config: Dict) -> None:
async def validate_config(config: Dict) -> None:
try:
scopes = ['https://www.googleapis.com/auth/drive.readonly']
parsed_config = GoogleDriveConfig(**config)
2 changes: 1 addition & 1 deletion app/data_source/sources/mattermost/mattermost.py
@@ -54,7 +54,7 @@ def get_config_fields() -> List[ConfigField]:
]

@staticmethod
def validate_config(config: Dict) -> None:
async def validate_config(config: Dict) -> None:
try:
parsed_config = MattermostConfig(**config)
maattermost = Driver(options=asdict(parsed_config))
2 changes: 1 addition & 1 deletion app/data_source/sources/rocketchat/rocketchat.py
@@ -54,7 +54,7 @@ def get_display_name(cls) -> str:
return "Rocket.Chat"

@staticmethod
def validate_config(config: Dict) -> None:
async def validate_config(config: Dict) -> None:
rocket_chat_config = RocketchatConfig(**config)
should_verify_ssl = os.environ.get('ROCKETCHAT_VERIFY_SSL') is not None
rocket_chat = RocketChat(user_id=rocket_chat_config.token_id, auth_token=rocket_chat_config.token_secret,
2 changes: 1 addition & 1 deletion app/data_source/sources/slack/slack.py
@@ -42,7 +42,7 @@ def get_config_fields() -> List[ConfigField]:
]

@staticmethod
def validate_config(config: Dict) -> None:
async def validate_config(config: Dict) -> None:
slack_config = SlackConfig(**config)
slack = WebClient(token=slack_config.token)
slack.auth_test()
Empty file.
136 changes: 136 additions & 0 deletions app/data_source/sources/stackoverflow/stackoverflow.py
@@ -0,0 +1,136 @@
import logging
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional
import requests

from data_source.api.base_data_source import BaseDataSource, ConfigField, HTMLInputType, BaseDataSourceConfig
from data_source.api.basic_document import DocumentType, BasicDocument
from queues.index_queue import IndexQueue

from data_source.api.utils import rate_limit

logger = logging.getLogger(__name__)


@dataclass
class StackOverflowPost:
link: str
score: int
last_activity_date: int
creation_date: int
post_id: Optional[int] = None
post_type: Optional[str] = None
body_markdown: Optional[str] = None
owner_account_id: Optional[int] = None
owner_reputation: Optional[int] = None
owner_user_id: Optional[int] = None
owner_user_type: Optional[str] = None
owner_profile_image: Optional[str] = None
owner_display_name: Optional[str] = None
owner_link: Optional[str] = None
title: Optional[str] = None
last_edit_date: Optional[str] = None
tags: Optional[List[str]] = None
view_count: Optional[int] = None
article_id: Optional[int] = None
article_type: Optional[str] = None

class StackOverflowConfig(BaseDataSourceConfig):
api_key: str
team_name: str


@rate_limit(allowed_per_second=15)
def rate_limited_get(url, headers):
'''
https://api.stackoverflowteams.com/docs/throttle
https://api.stackexchange.com/docs/throttle
Every application is subject to an IP based concurrent request throttle.
If a single IP is making more than 30 requests a second, new requests will be dropped.
The exact ban period is subject to change, but will be on the order of 30 seconds to a few minutes typically.
Note that exactly what response an application gets (in terms of HTTP code, text, and so on)
is undefined when subject to this ban; we consider > 30 request/sec per IP to be very abusive and thus cut the requests off very harshly.
'''
resp = requests.get(url, headers=headers)
if resp.status_code == 429:
logger.warning('Rate limited, sleeping for 5 minutes')
time.sleep(300)
return rate_limited_get(url, headers)
return resp


class StackOverflowDataSource(BaseDataSource):

@staticmethod
def get_config_fields() -> List[ConfigField]:
return [
ConfigField(label="PAT API Key", name="api_key", type=HTMLInputType.TEXT),
ConfigField(label="Team Name", name="team_name", type=HTMLInputType.TEXT),
]

@staticmethod
async def validate_config(config: Dict) -> None:
so_config = StackOverflowConfig(**config)
url = f'https://api.stackoverflowteams.com/2.3/questions?&team={so_config.team_name}'
response = rate_limited_get(url, headers={'X-API-Access-Token': so_config.api_key})
response.raise_for_status()

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
so_config = StackOverflowConfig(**self._raw_config)
self._api_key = so_config.api_key
self._team_name = so_config.team_name

def _fetch_posts(self, *, api_key: str, team_name: str, page: int, doc_type: str) -> None:
team_fragment = f'&team={team_name}'
# this is a filter for "body markdown" inclusion; all filters are unique and static
# I am not entirely sure if this is per account, or usable by everyone
filter_fragment = '&filter=!nOedRLbqzB'
page_fragment = f'&page={page}'
# the timestamp looks to be 10 digits; let's only look at items newer than the last index time
from_date_fragment = f'&fromdate={self._last_index_time.timestamp():.10n}'
url = f'https://api.stackoverflowteams.com/2.3/{doc_type}?{team_fragment}{filter_fragment}{page_fragment}{from_date_fragment}'
response = rate_limited_get(url, headers={'X-API-Access-Token': api_key})
response.raise_for_status()
response = response.json()
has_more = response['has_more']
items = response['items']
logger.info(f'Fetched {len(items)} {doc_type} from Stack Overflow')
for item_dict in items:
owner_fields = {}
if 'owner' in item_dict:
owner_fields = {f"owner_{k}": v for k, v in item_dict.pop('owner').items()}
if 'title' not in item_dict:
item_dict['title'] = item_dict['link']
post = StackOverflowPost(**item_dict, **owner_fields)
last_modified = datetime.fromtimestamp(post.last_edit_date or post.last_activity_date)
if last_modified < self._last_index_time:
return
logger.info(f'Feeding {doc_type} {post.title}')
post_document = BasicDocument(title=post.title, content=post.body_markdown, author=post.owner_display_name,
timestamp=datetime.fromtimestamp(post.creation_date), id=post.post_id,
data_source_id=self._data_source_id, location=post.link,
url=post.link, author_image_url=post.owner_profile_image,
type=DocumentType.MESSAGE)
IndexQueue.get_instance().put_single(doc=post_document)
if has_more:
# paginate onto the queue
self.add_task_to_queue(self._fetch_posts, api_key=self._api_key, team_name=self._team_name, page=page + 1, doc_type=doc_type)

def _feed_new_documents(self) -> None:
self.add_task_to_queue(self._fetch_posts, api_key=self._api_key, team_name=self._team_name, page=1, doc_type='posts')
# TODO: figure out how to get articles
# self.add_task_to_queue(self._fetch_posts, api_key=self._api_key, team_name=self._team_name, page=1, doc_type='articles')


# def test():
# import os
# config = {"api_key": os.environ['SO_API_KEY'], "team_name": os.environ['SO_TEAM_NAME']}
# so = StackOverflowDataSource(config=config, data_source_id=1)
# so._feed_new_documents()
#
#
# if __name__ == '__main__':
# test()
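For orientation, the JSON shape _fetch_posts expects back from the /posts endpoint, inferred from the parsing code and the StackOverflowPost fields above; the values are made up and the exact field set depends on the !nOedRLbqzB filter:

example_response = {
    "has_more": False,  # when True, _fetch_posts re-queues itself for the next page
    "items": [
        {
            "post_id": 123,
            "post_type": "answer",
            "title": "How do we rotate the service API key?",
            "body_markdown": "Go to team settings and regenerate the key ...",
            "link": "https://stackoverflowteams.com/c/example-team/questions/123",
            "score": 3,
            "creation_date": 1680000000,
            "last_activity_date": 1680003600,
            "owner": {
                # flattened onto StackOverflowPost with an owner_ prefix
                "display_name": "Jane Doe",
                "user_id": 42,
                "profile_image": "https://example.invalid/avatar.png",
            },
        },
    ],
}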
10 changes: 9 additions & 1 deletion app/db_engine.py
@@ -5,11 +5,19 @@

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
# import base document and then register all classes
from schemas.base import Base

from paths import SQLITE_DB_PATH

engine = create_engine(f'sqlite:///{SQLITE_DB_PATH}')
db_url = f'sqlite:///{SQLITE_DB_PATH}'
print('DB engine path:', db_url)
engine = create_engine(db_url)
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)

async_db_url = db_url.replace('sqlite', 'sqlite+aiosqlite', 1)
print('ASYNC DB engine path:', async_db_url)
async_engine = create_async_engine(async_db_url)
async_session = sessionmaker(async_engine, expire_on_commit=False, class_=AsyncSession)
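A minimal sketch of how the new async factory is consumed, mirroring the awaited create_data_source path in context.py above (assumes the aiosqlite driver is installed alongside SQLAlchemy 1.4+):

from sqlalchemy import select

from db_engine import async_session
from schemas import DataSourceType


async def find_data_source_type(name: str):
    # AsyncSession supports "async with"; execute() is awaited and returns a Result
    async with async_session() as session:
        result = await session.execute(select(DataSourceType).filter_by(name=name))
        return result.scalar_one_or_none()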