From 9e096fe9301860fd5d3a3264d6e69adbf99da91e Mon Sep 17 00:00:00 2001 From: Jeremy Zhang Date: Mon, 20 Mar 2017 00:37:37 -0700 Subject: [PATCH] Implemented rate limit handling and seperated discord rest api --- titanembeds/blueprints/api/api.py | 49 ++----------- titanembeds/discordrest.py | 114 ++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 42 deletions(-) create mode 100644 titanembeds/discordrest.py diff --git a/titanembeds/blueprints/api/api.py b/titanembeds/blueprints/api/api.py index a235a62..23d806e 100644 --- a/titanembeds/blueprints/api/api.py +++ b/titanembeds/blueprints/api/api.py @@ -1,5 +1,6 @@ from titanembeds.database import db, Guilds, UnauthenticatedUsers, UnauthenticatedBans from titanembeds.decorators import valid_session_required +from titanembeds.discordrest import DiscordREST from flask import Blueprint, abort, jsonify, session, request from sqlalchemy import and_ from werkzeug.contrib.cache import SimpleCache @@ -10,10 +11,9 @@ import time from config import config api = Blueprint("api", __name__) +discord_api = DiscordREST(config['bot-token']) cache = SimpleCache() -_DISCORD_API_BASE = "https://discordapp.com/api/v6" - def user_unauthenticated(): if 'unauthenticated' in session: return session['unauthenticated'] @@ -48,32 +48,13 @@ def get_client_ipaddr(): else: # general return request.remote_addr -def get_all_guilds(): - _endpoint = _DISCORD_API_BASE + "/users/@me/guilds" - payload = {} - guilds = [] - headers = {'Authorization': 'Bot ' + config['bot-token']} - count = 1 #priming the loop - last_guild = "" - while count > 0: - r = requests.get(_endpoint, params=payload, headers=headers) - js = r.json() - if r.status_code == 200: - count = len(js) - guilds.extend(js) - if count > 0: - payload['after'] = js[-1]['id'] - else: - time.sleep(js['retry_after'] / float(1000)) - return guilds - def check_guild_existance(guild_id): - dbGuild = Guild.query.filter_by(guild_id=guild_id).first() + dbGuild = Guilds.query.filter_by(guild_id=guild_id).first() if not dbGuild: return False guilds = cache.get('bot_guilds') if guilds is None: - guilds = get_all_guilds() + guilds = discord_api.get_all_guilds() cache.set('bot_guilds', guilds) for guild in guilds: if guild_id == guild['id']: @@ -103,22 +84,6 @@ def update_user_status(guild_id, username, user_key=None): pass #authenticated user todo return status -def get_channel_messages(channel_id, after_snowflake=None): - _endpoint = _DISCORD_API_BASE + "/channels/{channel_id}/messages".format(channel_id=channel_id) - payload = {} - if after_snowflake is not None: - payload = {'after': after_snowflake} - headers = {'Authorization': 'Bot ' + config['bot-token']} - r = requests.get(_endpoint, params=payload, headers=headers) - return json.loads(r.content) - -def post_create_message(channel_id, content): - _endpoint = _DISCORD_API_BASE + "/channels/{channel_id}/messages".format(channel_id=channel_id) - payload = {'content': session['username'] + ": " + content} - headers = {'Authorization': 'Bot ' + config['bot-token'], 'Content-Type': 'application/json'} - r = requests.post(_endpoint, headers=headers, data=json.dumps(payload)) - return json.loads(r.content) - @api.route("/fetch", methods=["GET"]) @valid_session_required(api=True) def fetch(): @@ -132,7 +97,7 @@ def fetch(): if status['banned'] or status['revoked']: messages = {} else: - messages = get_channel_messages(channel_id, after_snowflake) + messages = discord_api.get_channel_messages(channel_id, after_snowflake) return jsonify(messages=messages, status=status) @api.route("/post", methods=["POST"]) @@ -147,7 +112,7 @@ def post(): status = update_user_status(channel_id, session['username'], key) if status['banned'] or status['revoked']: return jsonify(status=status) - message = post_create_message(channel_id, content) + message = discord_api.create_message(channel_id, content) return jsonify(message=message, status=status) @api.route("/create_unauthenticated_user", methods=["POST"]) @@ -157,7 +122,7 @@ def create_unauthenticated_user(): guild_id = request.form['guild_id'] ip_address = get_client_ipaddr() if not check_guild_existance(guild_id): - abort(404) + abort(400) if not checkUserBanned(guild_id, ip_address): session['username'] = username if 'user_id' not in session: diff --git a/titanembeds/discordrest.py b/titanembeds/discordrest.py new file mode 100644 index 0000000..e966e62 --- /dev/null +++ b/titanembeds/discordrest.py @@ -0,0 +1,114 @@ +import requests +import sys +import time +import json + +_DISCORD_API_BASE = "https://discordapp.com/api/v6" + +def json_or_text(response): + text = response.text + if response.headers['content-type'] == 'application/json': + return response.json() + return text + +class DiscordREST: + def __init__(self, bot_token): + self.bot_token = bot_token + self.user_agent = "TitanEmbeds (https://github.com/EndenDragon/Titan) Python/{} requests/{}".format(sys.version_info, requests.__version__) + self.rate_limit_bucket = {} + self.global_limited = False + self.global_limit_expire = 0 + + def request(self, verb, url, **kwargs): + headers = { + 'User-Agent': self.user_agent, + 'Authorization': 'Bot {}'.format(self.bot_token), + } + params = None + if 'params' in kwargs: + params = kwargs['params'] + data = None + if 'data' in kwargs: + data = kwargs['data'] + if 'json' in kwargs: + headers['Content-Type'] = 'application/json' + data = json.dumps(data) + + for tries in range(5): + curepoch = time.time() + if self.global_limited: + time.sleep(self.global_limit_expire - curepoch) + curepoch = time.time() + + if url in self.rate_limit_bucket and self.rate_limit_bucket[url] > curepoch: + time.sleep(self.rate_limit_bucket[url] - curepoch) + + url_formatted = _DISCORD_API_BASE + url + req = requests.request(verb, url_formatted, params=params, data=data, headers=headers) + + remaining = None + if 'X-RateLimit-Remaining' in req.headers: + remaining = req.headers['X-RateLimit-Remaining'] + if remaining == '0' and req.status_code != 429: + self.rate_limit_bucket[url] = int(req.headers['X-RateLimit-Reset']) + + if 300 > req.status_code >= 200: + self.global_limited = False + return { + 'success': True, + 'content': json_or_text(req), + 'code': req.status_code, + } + + if req.status_code == 429: + if 'X-RateLimit-Global' not in req.headers: + self.rate_limit_bucket[url] = int(req.headers['X-RateLimit-Reset']) + else: + self.global_limit_expire = time.time() + int(req.headers['Retry-After']) + + if req.status_code == 502 and tries <= 5: + time.sleep(1 + tries * 2) + continue + + if req.status_code == 403 or req.status_code == 404: + return { + 'success': False, + 'code': req.status_code, + } + return { + 'success': False, + 'code': req.status_code, + 'content': json_or_text(req), + } + + def get_all_guilds(self): + _endpoint = "/users/@me/guilds" + params = {} + guilds = [] + count = 1 #priming the loop + last_guild = "" + while count > 0: + r = self.request("GET", _endpoint, params=params) + if r['success'] == True: + content = r['content'] + count = len(content) + guilds.extend(content) + if count > 0: + params['after'] = content[-1]['id'] + else: + count = 0 + return guilds + + def get_channel_messages(self, channel_id, after_snowflake=None): + _endpoint = "/channels/{channel_id}/messages".format(channel_id=channel_id) + params = {} + if after_snowflake is not None: + params = {'after': after_snowflake} + r = self.request("GET", _endpoint, params=params) + return r + + def create_message(self, channel_id, content): + _endpoint = "/channels/{channel_id}/messages".format(channel_id=channel_id) + payload = {'content': content} + r = self.request("POST", _endpoint, data=payload) + return r