From ac225c84f5687a5585f8b748378983f7958722b2 Mon Sep 17 00:00:00 2001 From: Jeremy Zhang Date: Tue, 10 Jul 2018 07:04:19 +0000 Subject: [PATCH 1/2] Initial switch to gino for discordbot db --- discordbot/titanembeds/bot.py | 2 +- discordbot/titanembeds/database/__init__.py | 359 +++++++----------- .../titanembeds/database/guild_members.py | 20 +- discordbot/titanembeds/database/guilds.py | 27 +- discordbot/titanembeds/database/messages.py | 21 +- .../database/unauthenticated_bans.py | 19 +- .../database/unauthenticated_users.py | 13 +- requirements.txt | 3 +- 8 files changed, 153 insertions(+), 311 deletions(-) diff --git a/discordbot/titanembeds/bot.py b/discordbot/titanembeds/bot.py index e5a22d0..b26349d 100644 --- a/discordbot/titanembeds/bot.py +++ b/discordbot/titanembeds/bot.py @@ -68,7 +68,7 @@ class Titan(discord.AutoShardedClient): await self.change_presence(status=discord.Status.online, activity=game) try: - self.database.connect(config["database-uri"]) + await self.database.connect(config["database-uri"]) except Exception: self.logger.error("Unable to connect to specified database!") traceback.print_exc() diff --git a/discordbot/titanembeds/database/__init__.py b/discordbot/titanembeds/database/__init__.py index 76a6e31..6b05953 100644 --- a/discordbot/titanembeds/database/__init__.py +++ b/discordbot/titanembeds/database/__init__.py @@ -1,13 +1,8 @@ -from contextlib import contextmanager -import sqlalchemy as db -from sqlalchemy.engine import Engine, create_engine -from sqlalchemy.orm import sessionmaker, Session -from sqlalchemy.ext.declarative import declarative_base - +from gino import Gino import json import discord -Base = declarative_base() +db = Gino() from titanembeds.database.guilds import Guilds from titanembeds.database.messages import Messages @@ -18,87 +13,47 @@ from titanembeds.database.unauthenticated_bans import UnauthenticatedBans from titanembeds.utils import get_message_author, get_message_mentions, get_webhooks_list, get_emojis_list, get_roles_list, get_channels_list, list_role_ids, get_attachments_list, get_embeds_list class DatabaseInterface(object): - # Courtesy of https://github.com/SunDwarf/Jokusoramame def __init__(self, bot): self.bot = bot - self.engine = None # type: Engine - self._sessionmaker = None # type: sessionmaker - - def connect(self, dburi): - self.engine = create_engine(dburi, pool_recycle=10) - - @contextmanager - def get_session(self): - Session = sessionmaker(bind=self.engine) - session = Session() - try: - yield session - except: - session.rollback() - finally: - session.commit() + async def connect(self, dburi): + await db.set_bind(dburi) async def push_message(self, message): - self.bot.loop.run_in_executor(None, self._push_message, message) - - def _push_message(self, message): if message.guild: - with self.get_session() as session: - edit_ts = message.edited_at - if not edit_ts: - edit_ts = None - else: - edit_ts = str(edit_ts) - - msg = Messages( - int(message.guild.id), - int(message.channel.id), - int(message.id), - message.content, - json.dumps(get_message_author(message)), - str(message.created_at), - edit_ts, - json.dumps(get_message_mentions(message.mentions)), - json.dumps(get_attachments_list(message.attachments)), - json.dumps(get_embeds_list(message.embeds)) - ) - session.add(msg) - session.commit() + edit_ts = message.edited_at + if not edit_ts: + edit_ts = None + else: + edit_ts = str(edit_ts) + await Messages.create( + message_id = int(message.id), + guild_id = int(message.guild.id), + channel_id = int(message.channel.id), + content = message.content, + author = json.dumps(get_message_author(message)), + timestamp = str(message.created_at), + edited_timestamp = edit_ts, + mentions = json.dumps(get_message_mentions(message.mentions)), + attachments = json.dumps(get_attachments_list(message.attachments)), + embeds = json.dumps(get_embeds_list(message.embeds)) + ) async def update_message(self, message): - self.bot.loop.run_in_executor(None, self._update_message, message) - - def _update_message(self, message): if message.guild: - with self.get_session() as session: - msg = session.query(Messages) \ - .filter(Messages.guild_id == message.guild.id) \ - .filter(Messages.channel_id == message.channel.id) \ - .filter(Messages.message_id == message.id).first() - if msg: - msg.content = message.content - msg.timestamp = message.created_at - msg.edited_timestamp = message.edited_at - msg.mentions = json.dumps(get_message_mentions(message.mentions)) - msg.attachments = json.dumps(get_attachments_list(message.attachments)) - msg.embeds = json.dumps(get_embeds_list(message.embeds)) - msg.author = json.dumps(get_message_author(message)) - session.commit() + await Messages.get(int(message.id)).update( + content = message.content, + timestamp = message.created_at, + edited_timestamp = message.edited_at, + mentions = json.dumps(get_message_mentions(message.mentions)), + attachments = json.dumps(get_attachments_list(message.attachments)), + embeds = json.dumps(get_embeds_list(message.embeds)), + author = json.dumps(get_message_author(message)) + ).apply() async def delete_message(self, message): - self.bot.loop.run_in_executor(None, self._delete_message, message) - - def _delete_message(self, message): if message.guild: - with self.get_session() as session: - msg = session.query(Messages) \ - .filter(Messages.guild_id == int(message.guild.id)) \ - .filter(Messages.channel_id == int(message.channel.id)) \ - .filter(Messages.message_id == int(message.id)).first() - if msg: - session.delete(msg) - session.commit() + await Messages.get(int(message.id)).delete() async def update_guild(self, guild): if guild.me.guild_permissions.manage_webhooks: @@ -108,163 +63,124 @@ class DatabaseInterface(object): server_webhooks = [] else: server_webhooks = [] - self.bot.loop.run_in_executor(None, self._update_guild, guild, server_webhooks) - - def _update_guild(self, guild, server_webhooks): - with self.get_session() as session: - gui = session.query(Guilds).filter(Guilds.guild_id == guild.id).first() - if not gui: - gui = Guilds( - int(guild.id), - guild.name, - json.dumps(get_roles_list(guild.roles)), - json.dumps(get_channels_list(guild.channels)), - json.dumps(get_webhooks_list(server_webhooks)), - json.dumps(get_emojis_list(guild.emojis)), - int(guild.owner_id), - guild.icon - ) - session.add(gui) - else: - gui.name = guild.name - gui.roles = json.dumps(get_roles_list(guild.roles)) - gui.channels = json.dumps(get_channels_list(guild.channels)) - gui.webhooks = json.dumps(get_webhooks_list(server_webhooks)) - gui.emojis = json.dumps(get_emojis_list(guild.emojis)) - gui.owner_id = int(guild.owner_id) - gui.icon = guild.icon - session.commit() + gui = await Guilds.get(guild.id) + if not gui: + await Guilds.create( + guild_id = int(guild.id), + name = guild.name, + unauth_users = True, + visitor_view = False, + webhook_messages = False, + guest_icon = None, + chat_links = True, + bracket_links = True, + unauth_captcha = True, + mentions_limit = -1, + roles = json.dumps(get_roles_list(guild.roles)), + channels = json.dumps(get_channels_list(guild.channels)), + webhooks = json.dumps(get_webhooks_list(server_webhooks)), + emojis = json.dumps(get_emojis_list(guild.emojis)), + owner_id = int(guild.owner_id), + icon = guild.icon + ) + else: + await gui.update( + name = guild.name, + roles = json.dumps(get_roles_list(guild.roles)), + channels = json.dumps(get_channels_list(guild.channels)), + webhooks = json.dumps(get_webhooks_list(server_webhooks)), + emojis = json.dumps(get_emojis_list(guild.emojis)), + owner_id = int(guild.owner_id), + icon = guild.icon + ).apply() async def remove_unused_guilds(self, guilds): - self.bot.loop.run_in_executor(None, self._remove_unused_guilds, guilds) - - def _remove_unused_guilds(self, guilds): - with self.get_session() as session: - dbguilds = session.query(Guilds).all() - changed = False - for guild in dbguilds: - disguild = discord.utils.get(guilds, id=guild.guild_id) - if not disguild: - changed = True - dbmsgs = session.query(Messages).filter(Messages.guild_id == int(guild.guild_id)).all() - for msg in dbmsgs: - session.delete(msg) - session.delete(guild) - if changed: - session.commit() + dbguilds = await Guilds.query.gino.all() + for guild in dbguilds: + disguild = discord.utils.get(guilds, id=guild.guild_id) + if not disguild: + await Messages.delete.where(Messages.guild_id == int(guild.guild_id)).gino.status() async def remove_guild(self, guild): - self.bot.loop.run_in_executor(None, self._remove_guild, guild) - - def _remove_guild(self, guild): - with self.get_session() as session: - gui = session.query(Guilds).filter(Guilds.guild_id == int(guild.id)).first() - if gui: - dbmsgs = session.query(Messages).filter(Messages.guild_id == int(guild.id)).delete() - session.delete(gui) - session.commit() + gui = await Guilds.get(int(guild.id)) + if gui: + await Messages.delete.where(Messages.guild_id == int(guild.id)).gino.status() + await gui.delete() async def update_guild_member(self, member, active=True, banned=False, guild=None): - self.bot.loop.run_in_executor(None, self._update_guild_member, member, active, banned, guild) - - def _update_guild_member(self, member, active=True, banned=False, guild=None): - with self.get_session() as session: - if guild: - dbmember = session.query(GuildMembers) \ - .filter(GuildMembers.guild_id == int(guild.id)) \ - .filter(GuildMembers.user_id == int(member.id)) \ - .order_by(GuildMembers.id).all() - else: - dbmember = session.query(GuildMembers) \ - .filter(GuildMembers.guild_id == int(member.guild.id)) \ - .filter(GuildMembers.user_id == int(member.id)) \ - .order_by(GuildMembers.id).all() - if not dbmember: - dbmember = GuildMembers( - int(member.guild.id), - int(member.id), - member.name, - member.discriminator, - member.nick, - member.avatar, - active, - banned, - json.dumps(list_role_ids(member.roles)) - ) - session.add(dbmember) - else: - if len(dbmember) > 1: - for mem in dbmember[1:]: - session.delete(mem) + if guild: + dbmember = await GuildMembers.query \ + .where(GuildMembers.guild_id == int(guild.id)) \ + .where(GuildMembers.user_id == int(member.id)) \ + .order_by(GuildMembers.id).gino.all() + else: + dbmember = await GuildMembers.query \ + .where(GuildMembers.guild_id == int(member.guild.id)) \ + .where(GuildMembers.user_id == int(member.id)) \ + .order_by(GuildMembers.id).gino.all() + if not dbmember: + await GuildMembers.create( + guild_id = int(member.guild.id), + user_id = int(member.id), + username = member.name, + discriminator = member.discriminator, + nickname = member.nick, + avatar = member.avatar, + active = active, + banned = banned, + roles = json.dumps(list_role_ids(member.roles)) + ) + else: + if len(dbmember) > 1: + for mem in dbmember[1:]: + await mem.delete() dbmember = dbmember[0] - if dbmember.banned != banned or dbmember.active != active or dbmember.username != member.name or dbmember.discriminator != int(member.discriminator) or dbmember.nickname != member.nick or dbmember.avatar != member.avatar or set(json.loads(dbmember.roles)) != set(list_role_ids(member.roles)): - dbmember.banned = banned - dbmember.active = active - dbmember.username = member.name - dbmember.discriminator = member.discriminator - dbmember.nickname = member.nick - dbmember.avatar = member.avatar - dbmember.roles = json.dumps(list_role_ids(member.roles)) - session.commit() + if dbmember.banned != banned or dbmember.active != active or dbmember.username != member.name or dbmember.discriminator != int(member.discriminator) or dbmember.nickname != member.nick or dbmember.avatar != member.avatar or set(json.loads(dbmember.roles)) != set(list_role_ids(member.roles)): + await dbmember.update( + banned = banned, + active = active, + username = member.name, + discriminator = member.discriminator, + nickname = member.nick, + avatar = member.avatar, + roles = json.dumps(list_role_ids(member.roles)) + ).apply() async def unban_server_user(self, user, server): - self.bot.loop.run_in_executor(None, self._unban_server_user, user, server) - - def _unban_server_user(self, user, server): - with self.get_session() as session: - dbmember = session.query(GuildMembers) \ - .filter(GuildMembers.guild_id == int(server.id)) \ - .filter(GuildMembers.user_id == int(user.id)).first() - if dbmember: - dbmember.banned = False - session.commit() + await GuildMembers.query \ + .where(GuildMembers.guild_id == int(server.id)) \ + .where(GuildMembers.user_id == int(user.id)) \ + .update(banned = False).apply() async def flag_unactive_guild_members(self, guild_id, guild_members): - self.bot.loop.run_in_executor(None, self._flag_unactive_guild_members, guild_id, guild_members) - - def _flag_unactive_guild_members(self, guild_id, guild_members): - with self.get_session() as session: - changed = False - dbmembers = session.query(GuildMembers) \ - .filter(GuildMembers.guild_id == int(guild_id)) \ - .filter(GuildMembers.active == True).all() - for member in dbmembers: + async with db.transaction(): + async for member in GuildMembers.query \ + .where(GuildMembers.guild_id == int(guild_id)) \ + .where(GuildMembers.active == True).gino.iterate(): dismember = discord.utils.get(guild_members, id=member.user_id) if not dismember: - changed = True - member.active = False - if changed: - session.commit() + await member.update(active = False).apply() async def flag_unactive_bans(self, guild_id, guildbans): - self.bot.loop.run_in_executor(None, self._flag_unactive_bans, guild_id, guildbans) - - def _flag_unactive_bans(self, guild_id, guildbans): - with self.get_session() as session: - changed = False - for usr in guildbans: - dbusr = session.query(GuildMembers) \ - .filter(GuildMembers.guild_id == int(guild_id)) \ - .filter(GuildMembers.user_id == int(usr.id)) \ - .filter(GuildMembers.active == False).first() - changed = True - if dbusr: - dbusr.banned = True - else: - dbusr = GuildMembers( - int(guild_id), - int(usr.id), - usr.name, - usr.discriminator, - None, - usr.avatar, - False, - True, - "[]" - ) - session.add(dbusr) - if changed: - session.commit() + for usr in guildbans: + dbusr = await GuildMembers.query \ + .where(GuildMembers.guild_id == int(guild_id)) \ + .where(GuildMembers.user_id == int(usr.id)) \ + .where(GuildMembers.active == False).gino.first() + if dbusr: + dbusr.update(banned=True).apply() + else: + await GuildMembers.create( + guild_id = int(guild_id), + user_id = int(usr.id), + username = usr.name, + discriminator = usr.discriminator, + nickname = None, + avatar = usr.avatar, + active = False, + banned = True, + roles = "[]" + ) async def ban_unauth_user_by_query(self, guild_id, placer_id, username, discriminator): self.bot.loop.run_in_executor(None, self._ban_unauth_user_by_query, guild_id, placer_id, username, discriminator) @@ -324,9 +240,4 @@ class DatabaseInterface(object): return "Successfully kicked **{}#{}**!".format(dbuser.username, dbuser.discriminator) async def delete_all_messages_from_channel(self, channel_id): - self.bot.loop.run_in_executor(None, self._delete_all_messages_from_channel, channel_id) - - def _delete_all_messages_from_channel(self, channel_id): - with self.get_session() as session: - session.query(Messages).filter(Messages.channel_id == int(channel_id)).delete() - session.commit() \ No newline at end of file + await Messages.delete.where(Messages.channel_id == int(channel_id)).gino.status() \ No newline at end of file diff --git a/discordbot/titanembeds/database/guild_members.py b/discordbot/titanembeds/database/guild_members.py index 2b4494e..5197039 100644 --- a/discordbot/titanembeds/database/guild_members.py +++ b/discordbot/titanembeds/database/guild_members.py @@ -1,6 +1,6 @@ -from titanembeds.database import db, Base +from titanembeds.database import db -class GuildMembers(Base): +class GuildMembers(db.Model): __tablename__ = "guild_members" id = db.Column(db.Integer, primary_key=True) # Auto incremented id guild_id = db.Column(db.BigInteger) # Discord guild id @@ -11,18 +11,4 @@ class GuildMembers(Base): avatar = db.Column(db.String(255)) # The avatar str of the user active = db.Column(db.Boolean()) # If the user is a member of the guild banned = db.Column(db.Boolean()) # If the user is banned in the guild - roles = db.Column(db.Text()) # Member roles - - def __init__(self, guild_id, user_id, username, discriminator, nickname, avatar, active, banned, roles): - self.guild_id = guild_id - self.user_id = user_id - self.username = username - self.discriminator = discriminator - self.nickname = nickname - self.avatar = avatar - self.active = active - self.banned = banned - self.roles = roles - - def __repr__(self): - return ''.format(self.id, self.guild_id, self.user_id, self.username, self.discriminator) + roles = db.Column(db.Text()) # Member roles \ No newline at end of file diff --git a/discordbot/titanembeds/database/guilds.py b/discordbot/titanembeds/database/guilds.py index 298eeb6..df77de0 100644 --- a/discordbot/titanembeds/database/guilds.py +++ b/discordbot/titanembeds/database/guilds.py @@ -1,6 +1,6 @@ -from titanembeds.database import db, Base +from titanembeds.database import db -class Guilds(Base): +class Guilds(db.Model): __tablename__ = "guilds" guild_id = db.Column(db.BigInteger, primary_key=True) # Discord guild id name = db.Column(db.String(255)) # Name @@ -23,25 +23,4 @@ class Guilds(Base): max_message_length = db.Column(db.Integer, nullable=False, server_default="300") # Chars length the message should be before being rejected by the server banned_words_enabled = db.Column(db.Boolean(), nullable=False, server_default="0") # If banned words are enforced banned_words_global_included = db.Column(db.Boolean(), nullable=False, server_default="0") # Add global banned words to the list - banned_words = db.Column(db.Text(), nullable=False, server_default="[]") # JSON list of strings to block from sending - - def __init__(self, guild_id, name, roles, channels, webhooks, emojis, owner_id, icon): - self.guild_id = guild_id - self.name = name - self.unauth_users = True # defaults to true - self.visitor_view = False - self.webhook_messages = False - self.guest_icon = None - self.chat_links = True - self.bracket_links = True - self.unauth_captcha = True - self.mentions_limit = -1 # -1 = unlimited mentions - self.roles = roles - self.channels = channels - self.webhooks = webhooks - self.emojis = emojis - self.owner_id = owner_id - self.icon = icon - - def __repr__(self): - return ''.format(self.id, self.guild_id) + banned_words = db.Column(db.Text(), nullable=False, server_default="[]") # JSON list of strings to block from sending \ No newline at end of file diff --git a/discordbot/titanembeds/database/messages.py b/discordbot/titanembeds/database/messages.py index 91b469a..f8ba868 100644 --- a/discordbot/titanembeds/database/messages.py +++ b/discordbot/titanembeds/database/messages.py @@ -1,6 +1,6 @@ -from titanembeds.database import db, Base +from titanembeds.database import db -class Messages(Base): +class Messages(db.Model): __tablename__ = "messages" message_id = db.Column(db.BigInteger, primary_key=True) # Message snowflake guild_id = db.Column(db.BigInteger) # Discord guild id @@ -11,19 +11,4 @@ class Messages(Base): edited_timestamp = db.Column(db.TIMESTAMP) # Timestamp of when content is edited mentions = db.Column(db.Text()) # Mentions serialized attachments = db.Column(db.Text()) # serialized attachments - embeds = db.Column(db.Text().with_variant(db.Text(length=4294967295), 'mysql')) # message embeds - - def __init__(self, guild_id, channel_id, message_id, content, author, timestamp, edited_timestamp, mentions, attachments, embeds): - self.guild_id = guild_id - self.channel_id = channel_id - self.message_id = message_id - self.content = content - self.author = author - self.timestamp = timestamp - self.edited_timestamp = edited_timestamp - self.mentions = mentions - self.attachments = attachments - self.embeds = embeds - - def __repr__(self): - return ''.format(self.id, self.guild_id, self.guild_id, self.channel_id, self.message_id) + embeds = db.Column(db.Text().with_variant(db.Text(length=4294967295), 'mysql')) # message embeds \ No newline at end of file diff --git a/discordbot/titanembeds/database/unauthenticated_bans.py b/discordbot/titanembeds/database/unauthenticated_bans.py index 5acd51d..adfad16 100644 --- a/discordbot/titanembeds/database/unauthenticated_bans.py +++ b/discordbot/titanembeds/database/unauthenticated_bans.py @@ -1,8 +1,8 @@ -from titanembeds.database import db, Base +from titanembeds.database import db import datetime import time -class UnauthenticatedBans(Base): +class UnauthenticatedBans(db.Model): __tablename__ = "unauthenticated_bans" id = db.Column(db.Integer, primary_key=True) # Auto increment id guild_id = db.Column(db.String(255)) # Guild pretaining to the unauthenticated user @@ -12,17 +12,4 @@ class UnauthenticatedBans(Base): timestamp = db.Column(db.TIMESTAMP) # The timestamp of when the user got banned reason = db.Column(db.Text()) # The reason of the ban set by the guild moderators lifter_id = db.Column(db.BigInteger) # Discord Client ID of the user who lifted the ban - placer_id = db.Column(db.BigInteger) # The id of who placed the ban - - def __init__(self, guild_id, ip_address, last_username, last_discriminator, reason, placer_id): - self.guild_id = guild_id - self.ip_address = ip_address - self.last_username = last_username - self.last_discriminator = last_discriminator - self.timestamp = datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H:%M:%S') - self.reason = reason - self.lifter_id = None - self.placer_id = placer_id - - def __repr__(self): - return ''.format(self.id, self.guild_id, self.username, self.discriminator, self.user_key, self.ip_address, self.revoked) + revoked = db.Column(db.Boolean()) # If the user's key has been revoked and a new one is required to be generated \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index a0f111d..65fc425 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,5 @@ redis aioredis Flask-Babel patreon -flask-redis \ No newline at end of file +flask-redis +gino \ No newline at end of file From de94f87c456ec4a713e24911243a881e72c9f030 Mon Sep 17 00:00:00 2001 From: Jeremy Zhang Date: Tue, 10 Jul 2018 19:06:54 +0000 Subject: [PATCH 2/2] Fix gino bugs, reimplement ban and revoke cmds, enforce postgresql --- README.md | 1 + cloud9_install.sh | 4 +- discordbot/titanembeds/bot.py | 14 +-- discordbot/titanembeds/database/__init__.py | 133 ++++++++++---------- 4 files changed, 72 insertions(+), 80 deletions(-) diff --git a/README.md b/README.md index 384362e..06788ac 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ If you happen to have a copy of Ubuntu on your server, you may head onto our [An # Database installation To set up the database for it to work with the webapp and the discordbot, one must use **alembic** to *migrate* their databases to the current database state. To do so, please follow these instructions. +**PostgreSQL supports proper indexing and suitable for Titan needs. For this reason, Titan only supports using a PostgreSQL database.** 1. Install alembic with **Python 3.5's pip** `pip install alembic` 2. Change your directory to the webapp where the alembic files are located `cd webapp` 3. Clone `alembic.example.ini` into your own `alembic.ini` file to find and edit the following line `sqlalchemy.url` to equal your database uri. [See here](http://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls) if you need help understanding how database uri works in SQLalchemy. diff --git a/cloud9_install.sh b/cloud9_install.sh index 40336ce..a562143 100644 --- a/cloud9_install.sh +++ b/cloud9_install.sh @@ -23,8 +23,8 @@ sed -i '32s/.*/sqlalchemy.url = postgresql:\/\/\/titan/' ~/workspace/webapp/ale alembic upgrade head echo "[C9Setup] Setting database uri for discordbot/config.py" -#'database-uri': "mysql+psycopg2:///titan?client_encoding=utf8", -sed -i "4s/.*/\'database-uri\': \"postgresql+psycopg2:\/\/\/titan?client_encoding=utf8\",/" ~/workspace/discordbot/config.py +#'database-uri': "postgresql:///titan", +sed -i "4s/.*/\'database-uri\': \"postgresql:\/\/\/titan\",/" ~/workspace/discordbot/config.py echo "[C9Setup] Setting database uri and app location for webapp/config.py" sed -i "19s/.*/\'database-uri\': \"postgresql+psycopg2:\/\/\/titan?client_encoding=utf8\",/" ~/workspace/webapp/config.py diff --git a/discordbot/titanembeds/bot.py b/discordbot/titanembeds/bot.py index b26349d..76ba965 100644 --- a/discordbot/titanembeds/bot.py +++ b/discordbot/titanembeds/bot.py @@ -45,7 +45,7 @@ class Titan(discord.AutoShardedClient): def run(self): try: - self.loop.run_until_complete(self.start(config["bot-token"])) + self.loop.run_until_complete(self.start()) except discord.errors.LoginFailure: print("Invalid bot token in config!") finally: @@ -54,6 +54,10 @@ class Titan(discord.AutoShardedClient): except Exception as e: print("Error in cleanup:", e) self.loop.close() + + async def start(self): + await self.database.connect(config["database-uri"]) + await super().start(config["bot-token"]) async def on_ready(self): print('Titan [DiscordBot]') @@ -66,14 +70,6 @@ class Titan(discord.AutoShardedClient): game = discord.Game(name="Embed your Discord server! Visit https://TitanEmbeds.com/") await self.change_presence(status=discord.Status.online, activity=game) - - try: - await self.database.connect(config["database-uri"]) - except Exception: - self.logger.error("Unable to connect to specified database!") - traceback.print_exc() - await self.logout() - return self.discordBotsOrg = DiscordBotsOrg(self.user.id, config.get("discord-bots-org-token", None)) self.botsDiscordPw = BotsDiscordPw(self.user.id, config.get("bots-discord-pw-token", None)) diff --git a/discordbot/titanembeds/database/__init__.py b/discordbot/titanembeds/database/__init__.py index 6b05953..da62dce 100644 --- a/discordbot/titanembeds/database/__init__.py +++ b/discordbot/titanembeds/database/__init__.py @@ -1,6 +1,7 @@ from gino import Gino import json import discord +import datetime db = Gino() @@ -21,19 +22,14 @@ class DatabaseInterface(object): async def push_message(self, message): if message.guild: - edit_ts = message.edited_at - if not edit_ts: - edit_ts = None - else: - edit_ts = str(edit_ts) await Messages.create( message_id = int(message.id), guild_id = int(message.guild.id), channel_id = int(message.channel.id), content = message.content, author = json.dumps(get_message_author(message)), - timestamp = str(message.created_at), - edited_timestamp = edit_ts, + timestamp = message.created_at, + edited_timestamp = message.edited_at, mentions = json.dumps(get_message_mentions(message.mentions)), attachments = json.dumps(get_attachments_list(message.attachments)), embeds = json.dumps(get_embeds_list(message.embeds)) @@ -41,7 +37,7 @@ class DatabaseInterface(object): async def update_message(self, message): if message.guild: - await Messages.get(int(message.id)).update( + await Messages.update.values( content = message.content, timestamp = message.created_at, edited_timestamp = message.edited_at, @@ -49,11 +45,11 @@ class DatabaseInterface(object): attachments = json.dumps(get_attachments_list(message.attachments)), embeds = json.dumps(get_embeds_list(message.embeds)), author = json.dumps(get_message_author(message)) - ).apply() + ).where(Messages.message_id == int(message.id)).gino.status() async def delete_message(self, message): if message.guild: - await Messages.get(int(message.id)).delete() + await Messages.delete.where(Messages.message_id == int(message.id)).gino.status() async def update_guild(self, guild): if guild.me.guild_permissions.manage_webhooks: @@ -123,7 +119,7 @@ class DatabaseInterface(object): guild_id = int(member.guild.id), user_id = int(member.id), username = member.name, - discriminator = member.discriminator, + discriminator = int(member.discriminator), nickname = member.nick, avatar = member.avatar, active = active, @@ -134,23 +130,24 @@ class DatabaseInterface(object): if len(dbmember) > 1: for mem in dbmember[1:]: await mem.delete() - dbmember = dbmember[0] + dbmember = dbmember[0] if dbmember.banned != banned or dbmember.active != active or dbmember.username != member.name or dbmember.discriminator != int(member.discriminator) or dbmember.nickname != member.nick or dbmember.avatar != member.avatar or set(json.loads(dbmember.roles)) != set(list_role_ids(member.roles)): await dbmember.update( banned = banned, active = active, username = member.name, - discriminator = member.discriminator, + discriminator = int(member.discriminator), nickname = member.nick, avatar = member.avatar, roles = json.dumps(list_role_ids(member.roles)) ).apply() async def unban_server_user(self, user, server): - await GuildMembers.query \ + await GuildMembers.update.values(banned = False) \ .where(GuildMembers.guild_id == int(server.id)) \ .where(GuildMembers.user_id == int(user.id)) \ - .update(banned = False).apply() + .gino.status() + async def flag_unactive_guild_members(self, guild_id, guild_members): async with db.transaction(): @@ -174,7 +171,7 @@ class DatabaseInterface(object): guild_id = int(guild_id), user_id = int(usr.id), username = usr.name, - discriminator = usr.discriminator, + discriminator = int(usr.discriminator), nickname = None, avatar = usr.avatar, active = False, @@ -183,61 +180,59 @@ class DatabaseInterface(object): ) async def ban_unauth_user_by_query(self, guild_id, placer_id, username, discriminator): - self.bot.loop.run_in_executor(None, self._ban_unauth_user_by_query, guild_id, placer_id, username, discriminator) - - def _ban_unauth_user_by_query(self, guild_id, placer_id, username, discriminator): - with self.get_session() as session: - dbuser = None - if discriminator: - dbuser = session.query(UnauthenticatedUsers) \ - .filter(UnauthenticatedUsers.guild_id == int(guild_id)) \ - .filter(UnauthenticatedUsers.username.ilike("%" + username + "%")) \ - .filter(UnauthenticatedUsers.discriminator == discriminator) \ - .order_by(UnauthenticatedUsers.id.desc()).first() - else: - dbuser = session.query(UnauthenticatedUsers) \ - .filter(UnauthenticatedUsers.guild_id == int(guild_id)) \ - .filter(UnauthenticatedUsers.username.ilike("%" + username + "%")) \ - .order_by(UnauthenticatedUsers.id.desc()).first() - if not dbuser: - return "Ban error! Guest user cannot be found." - dbban = session.query(UnauthenticatedBans) \ - .filter(UnauthenticatedBans.guild_id == int(guild_id)) \ - .filter(UnauthenticatedBans.last_username == dbuser.username) \ - .filter(UnauthenticatedBans.last_discriminator == dbuser.discriminator).first() - if dbban is not None: - if dbban.lifter_id is None: - return "Ban error! Guest user, **{}#{}**, has already been banned.".format(dbban.last_username, dbban.last_discriminator) - session.delete(dbban) - dbban = UnauthenticatedBans(int(guild_id), dbuser.ip_address, dbuser.username, dbuser.discriminator, "", int(placer_id)) - session.add(dbban) - session.commit() - return "Guest user, **{}#{}**, has successfully been added to the ban list!".format(dbban.last_username, dbban.last_discriminator) + dbuser = None + if discriminator: + dbuser = await UnauthenticatedUsers.query \ + .where(UnauthenticatedUsers.guild_id == int(guild_id)) \ + .where(UnauthenticatedUsers.username.ilike("%" + username + "%")) \ + .where(UnauthenticatedUsers.discriminator == discriminator) \ + .order_by(UnauthenticatedUsers.id.desc()).gino.first() + else: + dbuser = await UnauthenticatedUsers.query \ + .where(UnauthenticatedUsers.guild_id == int(guild_id)) \ + .where(UnauthenticatedUsers.username.ilike("%" + username + "%")) \ + .order_by(UnauthenticatedUsers.id.desc()).gino.first() + if not dbuser: + return "Ban error! Guest user cannot be found." + dbban = await UnauthenticatedBans.query \ + .where(UnauthenticatedBans.guild_id == int(guild_id)) \ + .where(UnauthenticatedBans.last_username == dbuser.username) \ + .where(UnauthenticatedBans.last_discriminator == dbuser.discriminator).gino.first() + if dbban is not None: + if dbban.lifter_id is None: + return "Ban error! Guest user, **{}#{}**, has already been banned.".format(dbban.last_username, dbban.last_discriminator) + await dbban.delete() + dbban = await UnauthenticatedBans.create( + guild_id = int(guild_id), + ip_address = dbuser.ip_address, + last_username = dbuser.username, + last_discriminator = dbuser.discriminator, + timestamp = datetime.datetime.now(), + reason = "", + lifter_id = None, + placer_id = int(placer_id) + ) + return "Guest user, **{}#{}**, has successfully been added to the ban list!".format(dbban.last_username, dbban.last_discriminator) async def revoke_unauth_user_by_query(self, guild_id, username, discriminator): - self.bot.loop.run_in_executor(None, self._revoke_unauth_user_by_query, guild_id, username, discriminator) + dbuser = None + if discriminator: + dbuser = await UnauthenticatedUsers.query \ + .where(UnauthenticatedUsers.guild_id == int(guild_id)) \ + .where(UnauthenticatedUsers.username.ilike("%" + username + "%")) \ + .where(UnauthenticatedUsers.discriminator == discriminator) \ + .order_by(UnauthenticatedUsers.id.desc()).gino.first() + else: + dbuser = await UnauthenticatedUsers.query \ + .where(UnauthenticatedUsers.guild_id == int(guild_id)) \ + .where(UnauthenticatedUsers.username.ilike("%" + username + "%")) \ + .order_by(UnauthenticatedUsers.id.desc()).gino.first() + if not dbuser: + return "Kick error! Guest user cannot be found." + elif dbuser.revoked: + return "Kick error! Guest user **{}#{}** has already been kicked!".format(dbuser.username, dbuser.discriminator) + await dbuser.update(revoked = True).apply() + return "Successfully kicked **{}#{}**!".format(dbuser.username, dbuser.discriminator) - def _revoke_unauth_user_by_query(self, guild_id, username, discriminator): - with self.get_session() as session: - dbuser = None - if discriminator: - dbuser = session.query(UnauthenticatedUsers) \ - .filter(UnauthenticatedUsers.guild_id == int(guild_id)) \ - .filter(UnauthenticatedUsers.username.ilike("%" + username + "%")) \ - .filter(UnauthenticatedUsers.discriminator == discriminator) \ - .order_by(UnauthenticatedUsers.id.desc()).first() - else: - dbuser = session.query(UnauthenticatedUsers) \ - .filter(UnauthenticatedUsers.guild_id == int(guild_id)) \ - .filter(UnauthenticatedUsers.username.ilike("%" + username + "%")) \ - .order_by(UnauthenticatedUsers.id.desc()).first() - if not dbuser: - return "Kick error! Guest user cannot be found." - elif dbuser.revoked: - return "Kick error! Guest user **{}#{}** has already been kicked!".format(dbuser.username, dbuser.discriminator) - dbuser.revoked = True - session.commit() - return "Successfully kicked **{}#{}**!".format(dbuser.username, dbuser.discriminator) - async def delete_all_messages_from_channel(self, channel_id): await Messages.delete.where(Messages.channel_id == int(channel_id)).gino.status() \ No newline at end of file