From ac225c84f5687a5585f8b748378983f7958722b2 Mon Sep 17 00:00:00 2001 From: Jeremy Zhang Date: Tue, 10 Jul 2018 07:04:19 +0000 Subject: [PATCH] Initial switch to gino for discordbot db --- discordbot/titanembeds/bot.py | 2 +- discordbot/titanembeds/database/__init__.py | 359 +++++++----------- .../titanembeds/database/guild_members.py | 20 +- discordbot/titanembeds/database/guilds.py | 27 +- discordbot/titanembeds/database/messages.py | 21 +- .../database/unauthenticated_bans.py | 19 +- .../database/unauthenticated_users.py | 13 +- requirements.txt | 3 +- 8 files changed, 153 insertions(+), 311 deletions(-) diff --git a/discordbot/titanembeds/bot.py b/discordbot/titanembeds/bot.py index e5a22d0..b26349d 100644 --- a/discordbot/titanembeds/bot.py +++ b/discordbot/titanembeds/bot.py @@ -68,7 +68,7 @@ class Titan(discord.AutoShardedClient): await self.change_presence(status=discord.Status.online, activity=game) try: - self.database.connect(config["database-uri"]) + await self.database.connect(config["database-uri"]) except Exception: self.logger.error("Unable to connect to specified database!") traceback.print_exc() diff --git a/discordbot/titanembeds/database/__init__.py b/discordbot/titanembeds/database/__init__.py index 76a6e31..6b05953 100644 --- a/discordbot/titanembeds/database/__init__.py +++ b/discordbot/titanembeds/database/__init__.py @@ -1,13 +1,8 @@ -from contextlib import contextmanager -import sqlalchemy as db -from sqlalchemy.engine import Engine, create_engine -from sqlalchemy.orm import sessionmaker, Session -from sqlalchemy.ext.declarative import declarative_base - +from gino import Gino import json import discord -Base = declarative_base() +db = Gino() from titanembeds.database.guilds import Guilds from titanembeds.database.messages import Messages @@ -18,87 +13,47 @@ from titanembeds.database.unauthenticated_bans import UnauthenticatedBans from titanembeds.utils import get_message_author, get_message_mentions, get_webhooks_list, get_emojis_list, get_roles_list, get_channels_list, list_role_ids, get_attachments_list, get_embeds_list class DatabaseInterface(object): - # Courtesy of https://github.com/SunDwarf/Jokusoramame def __init__(self, bot): self.bot = bot - self.engine = None # type: Engine - self._sessionmaker = None # type: sessionmaker - - def connect(self, dburi): - self.engine = create_engine(dburi, pool_recycle=10) - - @contextmanager - def get_session(self): - Session = sessionmaker(bind=self.engine) - session = Session() - try: - yield session - except: - session.rollback() - finally: - session.commit() + async def connect(self, dburi): + await db.set_bind(dburi) async def push_message(self, message): - self.bot.loop.run_in_executor(None, self._push_message, message) - - def _push_message(self, message): if message.guild: - with self.get_session() as session: - edit_ts = message.edited_at - if not edit_ts: - edit_ts = None - else: - edit_ts = str(edit_ts) - - msg = Messages( - int(message.guild.id), - int(message.channel.id), - int(message.id), - message.content, - json.dumps(get_message_author(message)), - str(message.created_at), - edit_ts, - json.dumps(get_message_mentions(message.mentions)), - json.dumps(get_attachments_list(message.attachments)), - json.dumps(get_embeds_list(message.embeds)) - ) - session.add(msg) - session.commit() + edit_ts = message.edited_at + if not edit_ts: + edit_ts = None + else: + edit_ts = str(edit_ts) + await Messages.create( + message_id = int(message.id), + guild_id = int(message.guild.id), + channel_id = int(message.channel.id), + content = message.content, + author = json.dumps(get_message_author(message)), + timestamp = str(message.created_at), + edited_timestamp = edit_ts, + mentions = json.dumps(get_message_mentions(message.mentions)), + attachments = json.dumps(get_attachments_list(message.attachments)), + embeds = json.dumps(get_embeds_list(message.embeds)) + ) async def update_message(self, message): - self.bot.loop.run_in_executor(None, self._update_message, message) - - def _update_message(self, message): if message.guild: - with self.get_session() as session: - msg = session.query(Messages) \ - .filter(Messages.guild_id == message.guild.id) \ - .filter(Messages.channel_id == message.channel.id) \ - .filter(Messages.message_id == message.id).first() - if msg: - msg.content = message.content - msg.timestamp = message.created_at - msg.edited_timestamp = message.edited_at - msg.mentions = json.dumps(get_message_mentions(message.mentions)) - msg.attachments = json.dumps(get_attachments_list(message.attachments)) - msg.embeds = json.dumps(get_embeds_list(message.embeds)) - msg.author = json.dumps(get_message_author(message)) - session.commit() + await Messages.get(int(message.id)).update( + content = message.content, + timestamp = message.created_at, + edited_timestamp = message.edited_at, + mentions = json.dumps(get_message_mentions(message.mentions)), + attachments = json.dumps(get_attachments_list(message.attachments)), + embeds = json.dumps(get_embeds_list(message.embeds)), + author = json.dumps(get_message_author(message)) + ).apply() async def delete_message(self, message): - self.bot.loop.run_in_executor(None, self._delete_message, message) - - def _delete_message(self, message): if message.guild: - with self.get_session() as session: - msg = session.query(Messages) \ - .filter(Messages.guild_id == int(message.guild.id)) \ - .filter(Messages.channel_id == int(message.channel.id)) \ - .filter(Messages.message_id == int(message.id)).first() - if msg: - session.delete(msg) - session.commit() + await Messages.get(int(message.id)).delete() async def update_guild(self, guild): if guild.me.guild_permissions.manage_webhooks: @@ -108,163 +63,124 @@ class DatabaseInterface(object): server_webhooks = [] else: server_webhooks = [] - self.bot.loop.run_in_executor(None, self._update_guild, guild, server_webhooks) - - def _update_guild(self, guild, server_webhooks): - with self.get_session() as session: - gui = session.query(Guilds).filter(Guilds.guild_id == guild.id).first() - if not gui: - gui = Guilds( - int(guild.id), - guild.name, - json.dumps(get_roles_list(guild.roles)), - json.dumps(get_channels_list(guild.channels)), - json.dumps(get_webhooks_list(server_webhooks)), - json.dumps(get_emojis_list(guild.emojis)), - int(guild.owner_id), - guild.icon - ) - session.add(gui) - else: - gui.name = guild.name - gui.roles = json.dumps(get_roles_list(guild.roles)) - gui.channels = json.dumps(get_channels_list(guild.channels)) - gui.webhooks = json.dumps(get_webhooks_list(server_webhooks)) - gui.emojis = json.dumps(get_emojis_list(guild.emojis)) - gui.owner_id = int(guild.owner_id) - gui.icon = guild.icon - session.commit() + gui = await Guilds.get(guild.id) + if not gui: + await Guilds.create( + guild_id = int(guild.id), + name = guild.name, + unauth_users = True, + visitor_view = False, + webhook_messages = False, + guest_icon = None, + chat_links = True, + bracket_links = True, + unauth_captcha = True, + mentions_limit = -1, + roles = json.dumps(get_roles_list(guild.roles)), + channels = json.dumps(get_channels_list(guild.channels)), + webhooks = json.dumps(get_webhooks_list(server_webhooks)), + emojis = json.dumps(get_emojis_list(guild.emojis)), + owner_id = int(guild.owner_id), + icon = guild.icon + ) + else: + await gui.update( + name = guild.name, + roles = json.dumps(get_roles_list(guild.roles)), + channels = json.dumps(get_channels_list(guild.channels)), + webhooks = json.dumps(get_webhooks_list(server_webhooks)), + emojis = json.dumps(get_emojis_list(guild.emojis)), + owner_id = int(guild.owner_id), + icon = guild.icon + ).apply() async def remove_unused_guilds(self, guilds): - self.bot.loop.run_in_executor(None, self._remove_unused_guilds, guilds) - - def _remove_unused_guilds(self, guilds): - with self.get_session() as session: - dbguilds = session.query(Guilds).all() - changed = False - for guild in dbguilds: - disguild = discord.utils.get(guilds, id=guild.guild_id) - if not disguild: - changed = True - dbmsgs = session.query(Messages).filter(Messages.guild_id == int(guild.guild_id)).all() - for msg in dbmsgs: - session.delete(msg) - session.delete(guild) - if changed: - session.commit() + dbguilds = await Guilds.query.gino.all() + for guild in dbguilds: + disguild = discord.utils.get(guilds, id=guild.guild_id) + if not disguild: + await Messages.delete.where(Messages.guild_id == int(guild.guild_id)).gino.status() async def remove_guild(self, guild): - self.bot.loop.run_in_executor(None, self._remove_guild, guild) - - def _remove_guild(self, guild): - with self.get_session() as session: - gui = session.query(Guilds).filter(Guilds.guild_id == int(guild.id)).first() - if gui: - dbmsgs = session.query(Messages).filter(Messages.guild_id == int(guild.id)).delete() - session.delete(gui) - session.commit() + gui = await Guilds.get(int(guild.id)) + if gui: + await Messages.delete.where(Messages.guild_id == int(guild.id)).gino.status() + await gui.delete() async def update_guild_member(self, member, active=True, banned=False, guild=None): - self.bot.loop.run_in_executor(None, self._update_guild_member, member, active, banned, guild) - - def _update_guild_member(self, member, active=True, banned=False, guild=None): - with self.get_session() as session: - if guild: - dbmember = session.query(GuildMembers) \ - .filter(GuildMembers.guild_id == int(guild.id)) \ - .filter(GuildMembers.user_id == int(member.id)) \ - .order_by(GuildMembers.id).all() - else: - dbmember = session.query(GuildMembers) \ - .filter(GuildMembers.guild_id == int(member.guild.id)) \ - .filter(GuildMembers.user_id == int(member.id)) \ - .order_by(GuildMembers.id).all() - if not dbmember: - dbmember = GuildMembers( - int(member.guild.id), - int(member.id), - member.name, - member.discriminator, - member.nick, - member.avatar, - active, - banned, - json.dumps(list_role_ids(member.roles)) - ) - session.add(dbmember) - else: - if len(dbmember) > 1: - for mem in dbmember[1:]: - session.delete(mem) + if guild: + dbmember = await GuildMembers.query \ + .where(GuildMembers.guild_id == int(guild.id)) \ + .where(GuildMembers.user_id == int(member.id)) \ + .order_by(GuildMembers.id).gino.all() + else: + dbmember = await GuildMembers.query \ + .where(GuildMembers.guild_id == int(member.guild.id)) \ + .where(GuildMembers.user_id == int(member.id)) \ + .order_by(GuildMembers.id).gino.all() + if not dbmember: + await GuildMembers.create( + guild_id = int(member.guild.id), + user_id = int(member.id), + username = member.name, + discriminator = member.discriminator, + nickname = member.nick, + avatar = member.avatar, + active = active, + banned = banned, + roles = json.dumps(list_role_ids(member.roles)) + ) + else: + if len(dbmember) > 1: + for mem in dbmember[1:]: + await mem.delete() dbmember = dbmember[0] - if dbmember.banned != banned or dbmember.active != active or dbmember.username != member.name or dbmember.discriminator != int(member.discriminator) or dbmember.nickname != member.nick or dbmember.avatar != member.avatar or set(json.loads(dbmember.roles)) != set(list_role_ids(member.roles)): - dbmember.banned = banned - dbmember.active = active - dbmember.username = member.name - dbmember.discriminator = member.discriminator - dbmember.nickname = member.nick - dbmember.avatar = member.avatar - dbmember.roles = json.dumps(list_role_ids(member.roles)) - session.commit() + if dbmember.banned != banned or dbmember.active != active or dbmember.username != member.name or dbmember.discriminator != int(member.discriminator) or dbmember.nickname != member.nick or dbmember.avatar != member.avatar or set(json.loads(dbmember.roles)) != set(list_role_ids(member.roles)): + await dbmember.update( + banned = banned, + active = active, + username = member.name, + discriminator = member.discriminator, + nickname = member.nick, + avatar = member.avatar, + roles = json.dumps(list_role_ids(member.roles)) + ).apply() async def unban_server_user(self, user, server): - self.bot.loop.run_in_executor(None, self._unban_server_user, user, server) - - def _unban_server_user(self, user, server): - with self.get_session() as session: - dbmember = session.query(GuildMembers) \ - .filter(GuildMembers.guild_id == int(server.id)) \ - .filter(GuildMembers.user_id == int(user.id)).first() - if dbmember: - dbmember.banned = False - session.commit() + await GuildMembers.query \ + .where(GuildMembers.guild_id == int(server.id)) \ + .where(GuildMembers.user_id == int(user.id)) \ + .update(banned = False).apply() async def flag_unactive_guild_members(self, guild_id, guild_members): - self.bot.loop.run_in_executor(None, self._flag_unactive_guild_members, guild_id, guild_members) - - def _flag_unactive_guild_members(self, guild_id, guild_members): - with self.get_session() as session: - changed = False - dbmembers = session.query(GuildMembers) \ - .filter(GuildMembers.guild_id == int(guild_id)) \ - .filter(GuildMembers.active == True).all() - for member in dbmembers: + async with db.transaction(): + async for member in GuildMembers.query \ + .where(GuildMembers.guild_id == int(guild_id)) \ + .where(GuildMembers.active == True).gino.iterate(): dismember = discord.utils.get(guild_members, id=member.user_id) if not dismember: - changed = True - member.active = False - if changed: - session.commit() + await member.update(active = False).apply() async def flag_unactive_bans(self, guild_id, guildbans): - self.bot.loop.run_in_executor(None, self._flag_unactive_bans, guild_id, guildbans) - - def _flag_unactive_bans(self, guild_id, guildbans): - with self.get_session() as session: - changed = False - for usr in guildbans: - dbusr = session.query(GuildMembers) \ - .filter(GuildMembers.guild_id == int(guild_id)) \ - .filter(GuildMembers.user_id == int(usr.id)) \ - .filter(GuildMembers.active == False).first() - changed = True - if dbusr: - dbusr.banned = True - else: - dbusr = GuildMembers( - int(guild_id), - int(usr.id), - usr.name, - usr.discriminator, - None, - usr.avatar, - False, - True, - "[]" - ) - session.add(dbusr) - if changed: - session.commit() + for usr in guildbans: + dbusr = await GuildMembers.query \ + .where(GuildMembers.guild_id == int(guild_id)) \ + .where(GuildMembers.user_id == int(usr.id)) \ + .where(GuildMembers.active == False).gino.first() + if dbusr: + dbusr.update(banned=True).apply() + else: + await GuildMembers.create( + guild_id = int(guild_id), + user_id = int(usr.id), + username = usr.name, + discriminator = usr.discriminator, + nickname = None, + avatar = usr.avatar, + active = False, + banned = True, + roles = "[]" + ) async def ban_unauth_user_by_query(self, guild_id, placer_id, username, discriminator): self.bot.loop.run_in_executor(None, self._ban_unauth_user_by_query, guild_id, placer_id, username, discriminator) @@ -324,9 +240,4 @@ class DatabaseInterface(object): return "Successfully kicked **{}#{}**!".format(dbuser.username, dbuser.discriminator) async def delete_all_messages_from_channel(self, channel_id): - self.bot.loop.run_in_executor(None, self._delete_all_messages_from_channel, channel_id) - - def _delete_all_messages_from_channel(self, channel_id): - with self.get_session() as session: - session.query(Messages).filter(Messages.channel_id == int(channel_id)).delete() - session.commit() \ No newline at end of file + await Messages.delete.where(Messages.channel_id == int(channel_id)).gino.status() \ No newline at end of file diff --git a/discordbot/titanembeds/database/guild_members.py b/discordbot/titanembeds/database/guild_members.py index 2b4494e..5197039 100644 --- a/discordbot/titanembeds/database/guild_members.py +++ b/discordbot/titanembeds/database/guild_members.py @@ -1,6 +1,6 @@ -from titanembeds.database import db, Base +from titanembeds.database import db -class GuildMembers(Base): +class GuildMembers(db.Model): __tablename__ = "guild_members" id = db.Column(db.Integer, primary_key=True) # Auto incremented id guild_id = db.Column(db.BigInteger) # Discord guild id @@ -11,18 +11,4 @@ class GuildMembers(Base): avatar = db.Column(db.String(255)) # The avatar str of the user active = db.Column(db.Boolean()) # If the user is a member of the guild banned = db.Column(db.Boolean()) # If the user is banned in the guild - roles = db.Column(db.Text()) # Member roles - - def __init__(self, guild_id, user_id, username, discriminator, nickname, avatar, active, banned, roles): - self.guild_id = guild_id - self.user_id = user_id - self.username = username - self.discriminator = discriminator - self.nickname = nickname - self.avatar = avatar - self.active = active - self.banned = banned - self.roles = roles - - def __repr__(self): - return ''.format(self.id, self.guild_id, self.user_id, self.username, self.discriminator) + roles = db.Column(db.Text()) # Member roles \ No newline at end of file diff --git a/discordbot/titanembeds/database/guilds.py b/discordbot/titanembeds/database/guilds.py index 298eeb6..df77de0 100644 --- a/discordbot/titanembeds/database/guilds.py +++ b/discordbot/titanembeds/database/guilds.py @@ -1,6 +1,6 @@ -from titanembeds.database import db, Base +from titanembeds.database import db -class Guilds(Base): +class Guilds(db.Model): __tablename__ = "guilds" guild_id = db.Column(db.BigInteger, primary_key=True) # Discord guild id name = db.Column(db.String(255)) # Name @@ -23,25 +23,4 @@ class Guilds(Base): max_message_length = db.Column(db.Integer, nullable=False, server_default="300") # Chars length the message should be before being rejected by the server banned_words_enabled = db.Column(db.Boolean(), nullable=False, server_default="0") # If banned words are enforced banned_words_global_included = db.Column(db.Boolean(), nullable=False, server_default="0") # Add global banned words to the list - banned_words = db.Column(db.Text(), nullable=False, server_default="[]") # JSON list of strings to block from sending - - def __init__(self, guild_id, name, roles, channels, webhooks, emojis, owner_id, icon): - self.guild_id = guild_id - self.name = name - self.unauth_users = True # defaults to true - self.visitor_view = False - self.webhook_messages = False - self.guest_icon = None - self.chat_links = True - self.bracket_links = True - self.unauth_captcha = True - self.mentions_limit = -1 # -1 = unlimited mentions - self.roles = roles - self.channels = channels - self.webhooks = webhooks - self.emojis = emojis - self.owner_id = owner_id - self.icon = icon - - def __repr__(self): - return ''.format(self.id, self.guild_id) + banned_words = db.Column(db.Text(), nullable=False, server_default="[]") # JSON list of strings to block from sending \ No newline at end of file diff --git a/discordbot/titanembeds/database/messages.py b/discordbot/titanembeds/database/messages.py index 91b469a..f8ba868 100644 --- a/discordbot/titanembeds/database/messages.py +++ b/discordbot/titanembeds/database/messages.py @@ -1,6 +1,6 @@ -from titanembeds.database import db, Base +from titanembeds.database import db -class Messages(Base): +class Messages(db.Model): __tablename__ = "messages" message_id = db.Column(db.BigInteger, primary_key=True) # Message snowflake guild_id = db.Column(db.BigInteger) # Discord guild id @@ -11,19 +11,4 @@ class Messages(Base): edited_timestamp = db.Column(db.TIMESTAMP) # Timestamp of when content is edited mentions = db.Column(db.Text()) # Mentions serialized attachments = db.Column(db.Text()) # serialized attachments - embeds = db.Column(db.Text().with_variant(db.Text(length=4294967295), 'mysql')) # message embeds - - def __init__(self, guild_id, channel_id, message_id, content, author, timestamp, edited_timestamp, mentions, attachments, embeds): - self.guild_id = guild_id - self.channel_id = channel_id - self.message_id = message_id - self.content = content - self.author = author - self.timestamp = timestamp - self.edited_timestamp = edited_timestamp - self.mentions = mentions - self.attachments = attachments - self.embeds = embeds - - def __repr__(self): - return ''.format(self.id, self.guild_id, self.guild_id, self.channel_id, self.message_id) + embeds = db.Column(db.Text().with_variant(db.Text(length=4294967295), 'mysql')) # message embeds \ No newline at end of file diff --git a/discordbot/titanembeds/database/unauthenticated_bans.py b/discordbot/titanembeds/database/unauthenticated_bans.py index 5acd51d..adfad16 100644 --- a/discordbot/titanembeds/database/unauthenticated_bans.py +++ b/discordbot/titanembeds/database/unauthenticated_bans.py @@ -1,8 +1,8 @@ -from titanembeds.database import db, Base +from titanembeds.database import db import datetime import time -class UnauthenticatedBans(Base): +class UnauthenticatedBans(db.Model): __tablename__ = "unauthenticated_bans" id = db.Column(db.Integer, primary_key=True) # Auto increment id guild_id = db.Column(db.String(255)) # Guild pretaining to the unauthenticated user @@ -12,17 +12,4 @@ class UnauthenticatedBans(Base): timestamp = db.Column(db.TIMESTAMP) # The timestamp of when the user got banned reason = db.Column(db.Text()) # The reason of the ban set by the guild moderators lifter_id = db.Column(db.BigInteger) # Discord Client ID of the user who lifted the ban - placer_id = db.Column(db.BigInteger) # The id of who placed the ban - - def __init__(self, guild_id, ip_address, last_username, last_discriminator, reason, placer_id): - self.guild_id = guild_id - self.ip_address = ip_address - self.last_username = last_username - self.last_discriminator = last_discriminator - self.timestamp = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S') - self.reason = reason - self.lifter_id = None - self.placer_id = placer_id - - def __repr__(self): - return ''.format(self.id, self.guild_id, self.username, self.discriminator, self.user_key, self.ip_address, self.revoked) + revoked = db.Column(db.Boolean()) # If the user's key has been revoked and a new one is required to be generated \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index a0f111d..65fc425 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,5 @@ redis aioredis Flask-Babel patreon -flask-redis \ No newline at end of file +flask-redis +gino \ No newline at end of file