2018-07-21 01:10:24 +02:00
|
|
|
from titanembeds.utils import get_formatted_message, get_formatted_user, get_formatted_guild
|
2018-07-16 05:50:31 +02:00
|
|
|
from urllib.parse import urlparse
|
|
|
|
import asyncio_redis
|
|
|
|
import json
|
|
|
|
import discord
|
|
|
|
import asyncio
|
|
|
|
import traceback
|
|
|
|
import sys
|
|
|
|
import re
|
|
|
|
|
|
|
|
class RedisQueue:
    """Serve Discord data requests published on Redis and keep cached answers fresh.

    ``subscribe`` listens on the ``discord-api-req`` pub/sub channel; every
    request names a ``resource`` which is dispatched to the matching
    ``on_<resource>`` coroutine of this class, and that coroutine writes its
    answer into Redis under the request's ``key``. The other coroutines
    (``push_message``, ``update_member``, ``update_guild`` ...) patch or drop
    keys that are already cached; presumably they are driven by the bot's
    gateway event handlers -- confirm against the callers.

    Keys built directly by this class:
        Queue/channels/<channel_id>/messages       set of JSON messages
        Queue/guilds/<guild_id>                    JSON guild
        Queue/guilds/<guild_id>/members            set of {"user_id": ...}
        Queue/guilds/<guild_id>/members/<user_id>  JSON member
    """

    def __init__(self, bot, redis_uri):
        """
        Args:
            bot: the discord client; this class uses its ``loop``, ``user``,
                ``wait_until_ready``, ``is_ready``, ``is_closed``,
                ``get_channel``, ``get_guild`` and ``get_user``.
            redis_uri: ``redis://`` URI, parsed by ``connect``.
        """
        self.bot = bot
        self.redis_uri = redis_uri
        # Strong references to in-flight handler tasks (see ``dispatch``).
        self._tasks = set()

    @staticmethod
    def _to_json(value):
        """Serialize ``value`` to compact JSON (no whitespace after separators)."""
        return json.dumps(value, separators=(',', ':'))

    @staticmethod
    def _connection_kwargs(redis_uri):
        """Translate ``redis://[:password@]host[:port][/db]`` into asyncio_redis kwargs.

        Missing parts fall back to host ``localhost``, port ``6379`` and db ``0``.
        """
        url_parsed = urlparse(redis_uri)
        # The path looks like "/<db>". Strip the slashes instead of checking the
        # length so that single-digit databases such as "/1" are honoured.
        db_path = url_parsed.path.strip("/")
        return {
            "host": url_parsed.hostname or "localhost",
            "port": url_parsed.port or 6379,
            "password": url_parsed.password,
            "db": int(db_path) if db_path else 0,
        }

    async def connect(self, poolsize=10):
        """Open the pub/sub connection and the command connection pool.

        Args:
            poolsize: number of connections in the pool used for regular commands.
        """
        conn_kwargs = self._connection_kwargs(self.redis_uri)
        # A Redis connection in subscribe mode cannot issue regular commands,
        # so pub/sub gets its own connection separate from the pool.
        self.sub_connection = await asyncio_redis.Connection.create(**conn_kwargs)
        self.connection = await asyncio_redis.Pool.create(poolsize=poolsize, **conn_kwargs)

    async def subscribe(self):
        """Consume requests from ``discord-api-req`` forever and dispatch them.

        Each published message is expected to be a JSON object with
        ``resource``, ``key`` and ``params``. Malformed messages are reported
        on stderr and skipped, so one bad publish cannot stop the consumer.
        """
        await self.bot.wait_until_ready()
        subscriber = await self.sub_connection.start_subscribe()
        await subscriber.subscribe(["discord-api-req"])
        while True:
            if not self.bot.is_ready() or self.bot.is_closed():
                await asyncio.sleep(1)
                continue
            reply = await subscriber.next_published()
            try:
                request = json.loads(reply.value)
                resource = request["resource"]
                key = request["key"]
                params = request["params"]
            except (ValueError, KeyError, TypeError):
                print("Ignoring malformed request on discord-api-req: {!r}".format(reply.value), file=sys.stderr)
                traceback.print_exc()
                continue
            self.dispatch(resource, key, params)
            # Yield so the scheduled handler tasks get a chance to run.
            await asyncio.sleep(0)

    def dispatch(self, event, key, params):
        """Schedule ``on_<event>(key, params)`` as a task if such a handler exists.

        Unknown events are ignored.
        """
        method = "on_" + event
        if not hasattr(self, method):
            return
        task = self.bot.loop.create_task(self._run_event(method, key, params))
        # The event loop only keeps weak references to tasks; hold a strong one
        # until the handler finishes so it cannot be garbage-collected mid-run.
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    async def _run_event(self, event, key, params):
        """Run handler ``event``, routing any exception except cancellation to ``on_error``."""
        try:
            await getattr(self, event)(key, params)
        except asyncio.CancelledError:
            pass
        except Exception:
            try:
                await self.on_error(event)
            except asyncio.CancelledError:
                pass

    async def on_error(self, event_method):
        """Report the exception currently being handled to stderr."""
        print('Ignoring exception in {}'.format(event_method), file=sys.stderr)
        traceback.print_exc()

    async def set_scan_json(self, key, dict_key, dict_value_pattern):
        """Find the first JSON member of the Redis set ``key`` whose ``dict_key`` matches.

        ``str(dict_value_pattern)`` is applied with ``re.match`` to
        ``str(member[dict_key])``. ``re.match`` only anchors at the start, so
        callers needing an exact match must anchor the end themselves. Empty
        members and JSON values that are not objects containing ``dict_key``
        are skipped.

        Returns:
            ``(raw_member, parsed_member)`` for the first match, or
            ``(None, None)`` when nothing matches or the set does not exist.
        """
        pattern = str(dict_value_pattern)
        # SMEMBERS on a missing key yields an empty set, so no EXISTS round trip is needed.
        members = await self.connection.smembers(key)
        for member in members:
            raw = await member
            if not raw:
                continue
            parsed = json.loads(raw)
            if not isinstance(parsed, dict) or dict_key not in parsed:
                continue
            if re.match(pattern, str(parsed[dict_key])):
                return (raw, parsed)
        return (None, None)

    async def enforce_expiring_key(self, key, default_ttl=60 * 5):
        """Give ``key`` an expiry of ``default_ttl`` seconds if it has none.

        Keys that already expire keep their remaining TTL, and missing keys are
        left alone.
        """
        ttl = await self.connection.ttl(key)
        # TTL is -1 for a key without expiry (and -2 for a missing key on modern
        # Redis). Only touch keys without expiry: an "EXPIRE key 0" issued for a
        # missing key would delete it if another task created it in between.
        if ttl == -1:
            await self.connection.expire(key, default_ttl)

    async def on_get_channel_messages(self, key, params):
        """Cache the latest 50 messages of text channel ``params["channel_id"]`` in the set ``key``.

        The set always receives an empty-string member, so it exists even when
        the bot cannot read the channel or the channel has no messages.
        """
        channel = self.bot.get_channel(int(params["channel_id"]))
        if not channel or not isinstance(channel, discord.channel.TextChannel):
            return
        await self.connection.delete([key])
        messages = []
        me = channel.guild.get_member(self.bot.user.id)
        # The bot's own member may be missing from the cache; treat that as "cannot read".
        if me is not None and channel.permissions_for(me).read_messages:
            async for message in channel.history(limit=50):
                messages.append(self._to_json(get_formatted_message(message)))
        await self.connection.sadd(key, [""] + messages)

    async def push_message(self, message):
        """Add ``message`` to its channel's message set, but only if that set is cached."""
        if not message.guild:
            return
        key = "Queue/channels/{}/messages".format(message.channel.id)
        if await self.connection.exists(key):
            formatted = get_formatted_message(message)
            await self.connection.sadd(key, [self._to_json(formatted)])

    async def delete_message(self, message):
        """Remove ``message`` from its channel's cached message set, if present."""
        if not message.guild:
            return
        key = "Queue/channels/{}/messages".format(message.channel.id)
        if not await self.connection.exists(key):
            return
        # set_scan_json anchors only at the start; anchor the end as well so
        # that id 123 cannot match a cached message whose id is 1234.
        id_pattern = re.escape(str(message.id)) + r"\Z"
        unformatted_item, formatted_item = await self.set_scan_json(key, "id", id_pattern)
        if formatted_item:
            await self.connection.srem(key, [unformatted_item])

    async def update_message(self, message):
        """Replace the cached copy of ``message`` (delete, then push)."""
        await self.delete_message(message)
        await self.push_message(message)

    async def on_get_guild_member(self, key, params):
        """Cache guild member ``params["user_id"]`` of ``params["guild_id"]`` at ``key``.

        A member missing from the local cache is looked up with
        ``guild.query_members``. If that returns nothing, any cached entries
        for that user in this guild are removed instead and ``key`` is not written.
        """
        guild = self.bot.get_guild(int(params["guild_id"]))
        if not guild:
            return
        user_id = int(params["user_id"])
        member = guild.get_member(user_id)
        if not member:
            members = await guild.query_members(user_ids=[user_id], cache=True)
            if not members:
                # ``member`` is None here; remove_member only needs ``.id``
                # from it, so pass a stand-in carrying the id plus the guild.
                await self.remove_member(discord.Object(id=user_id), guild)
                return
            member = members[0]
        user = get_formatted_user(member)
        await self.connection.set(key, self._to_json(user))
        await self.enforce_expiring_key(key)

    async def on_get_guild_member_named(self, key, params):
        """Resolve a ``name#1234`` query to a member of guild ``params["guild_id"]``.

        Usernames are tried before nicknames. On a match, ``{"user_id": ...}``
        is stored at ``key`` and that member's own key is refreshed through
        ``on_get_guild_member``. Otherwise an empty string is stored.
        """
        guild = self.bot.get_guild(int(params["guild_id"]))
        if not guild:
            return
        query = params["query"]
        members = guild.members
        result = None
        if members and len(query) > 5 and query[-5] == '#':
            name, discriminator = query[:-5], query[-4:]
            result = discord.utils.get(members, name=name, discriminator=discriminator)
            if not result:
                result = discord.utils.get(members, nick=name, discriminator=discriminator)
        if not result:
            payload = ""
        else:
            payload = self._to_json({"user_id": result.id})
            member_key = "Queue/guilds/{}/members/{}".format(guild.id, result.id)
            await self.on_get_guild_member(member_key, {"guild_id": guild.id, "user_id": result.id})
        await self.connection.set(key, payload)
        await self.enforce_expiring_key(key)

    async def on_list_guild_members(self, key, params):
        """Add every member in ``guild.members`` to the set ``key`` as ``{"user_id": ...}``.

        Each member's ``Queue/guilds/<guild>/members/<user>`` key is refreshed too.
        """
        guild = self.bot.get_guild(int(params["guild_id"]))
        if not guild:
            return
        member_ids = []
        for member in guild.members:
            member_ids.append(self._to_json({"user_id": member.id}))
            member_key = "Queue/guilds/{}/members/{}".format(guild.id, member.id)
            await self.on_get_guild_member(member_key, {"guild_id": guild.id, "user_id": member.id})
        # Redis rejects SADD without any member, so skip the call for an empty list.
        if member_ids:
            await self.connection.sadd(key, member_ids)

    async def add_member(self, member):
        """Add ``member`` to its guild's cached member set, if that set exists."""
        key = "Queue/guilds/{}/members".format(member.guild.id)
        if await self.connection.exists(key):
            await self.connection.sadd(key, [self._to_json({"user_id": member.id})])

    async def remove_member(self, member, guild=None):
        """Drop ``member`` from the guild member set and delete its member key.

        Args:
            member: anything with an ``id``; ``member.guild`` is used when
                ``guild`` is not given.
            guild: the guild to remove the member from. Required when
                ``member`` has no ``guild`` attribute (e.g. a plain user).
        """
        if not guild:
            guild = member.guild
        guild_member_key = "Queue/guilds/{}/members/{}".format(guild.id, member.id)
        list_members_key = "Queue/guilds/{}/members".format(guild.id)
        await self.connection.srem(list_members_key, [self._to_json({"user_id": member.id})])
        await self.connection.delete([guild_member_key])

    async def update_member(self, member):
        """Refresh ``member``'s cache entries (remove, then add)."""
        await self.remove_member(member)
        await self.add_member(member)

    async def ban_member(self, guild, user):
        """Forget the cached entries of ``user`` banned from ``guild``."""
        await self.remove_member(user, guild)

    async def on_get_guild(self, key, params):
        """Cache the formatted guild ``params["guild_id"]`` (including webhooks when allowed) at ``key``."""
        guild = self.bot.get_guild(int(params["guild_id"]))
        if not guild:
            return
        server_webhooks = []
        if guild.me and guild.me.guild_permissions.manage_webhooks:
            try:
                server_webhooks = await guild.webhooks()
            except asyncio.CancelledError:
                # Never swallow cancellation (it is an Exception before Python 3.8).
                raise
            except Exception:
                # Webhooks are optional extra data: deliberately fall back to none.
                server_webhooks = []
        guild_fmtted = get_formatted_guild(guild, server_webhooks)
        await self.connection.set(key, self._to_json(guild_fmtted))
        await self.enforce_expiring_key(key)

    async def delete_guild(self, guild):
        """Delete the cached ``Queue/guilds/<guild_id>`` entry."""
        key = "Queue/guilds/{}".format(guild.id)
        await self.connection.delete([key])

    async def update_guild(self, guild):
        """Rebuild the cached guild entry, but only if it is currently cached."""
        key = "Queue/guilds/{}".format(guild.id)
        if await self.connection.exists(key):
            await self.delete_guild(guild)
            # on_get_guild already applies the expiry to the freshly written key.
            await self.on_get_guild(key, {"guild_id": guild.id})

    async def on_get_user(self, key, params):
        """Cache basic fields of user ``params["user_id"]`` as JSON at ``key``."""
        user = self.bot.get_user(int(params["user_id"]))
        if not user:
            return
        user_formatted = {
            "id": user.id,
            "username": user.name,
            "discriminator": user.discriminator,
            "avatar": user.avatar,
            "bot": user.bot
        }
        await self.connection.set(key, self._to_json(user_formatted))
        await self.enforce_expiring_key(key)
|