Option to pass in args to handle shard manually

This commit is contained in:
Jeremy Zhang 2020-02-17 18:00:19 -08:00
parent 575d066f9f
commit 51ff9b48b6
3 changed files with 76 additions and 25 deletions

View File

@ -1,9 +1,56 @@
from titanembeds import Titan from titanembeds import Titan
from config import config
import argparse
import gc import gc
import requests
def print_shards():
token = config["bot-token"]
url = "https://discordapp.com/api/v6/gateway/bot"
headers = {"Authorization": "Bot {}".format(token)}
r = requests.get(url, headers=headers)
if r.status_code >= 200 and r.status_code < 300:
print("Suggested number of shards: {}".format(r.json().get("shards", 0)))
else:
print("Status Code: " + r.status_code)
print(r.text)
def main(): def main():
parser = argparse.ArgumentParser(
description="Embed Discord like a True Titan (Discord Bot portion)"
)
parser.add_argument(
"-sid",
"--shard_id",
help="ID of the shard",
type=int,
default=None
)
parser.add_argument(
"-sc",
"--shard_count",
help="Number of total shards",
type=int,
default=None
)
parser.add_argument(
"-s",
"--shards",
help="Prints the reccomended number of shards to spawn",
action="store_true"
)
args = parser.parse_args()
if args.shards:
print_shards()
return
print("Starting...") print("Starting...")
te = Titan() te = Titan(
shard_ids = [args.shard_id] if args.shard_id is not None else None,
shard_count = args.shard_count
)
te.run() te.run()
gc.collect() gc.collect()

View File

@ -12,21 +12,20 @@ import asyncio
import sys import sys
import logging import logging
import json import json
logging.basicConfig(filename='titanbot.log',level=logging.INFO,format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')
handler = logging.FileHandler(config.get("logging-location", "titanbot.log"))
logging.getLogger('TitanBot')
logging.getLogger('sqlalchemy')
# try: # try:
# raven_client = RavenClient(config["sentry-dsn"]) # raven_client = RavenClient(config["sentry-dsn"])
# except raven.exceptions.InvalidDsn: # except raven.exceptions.InvalidDsn:
# pass # pass
class Titan(discord.AutoShardedClient): class Titan(discord.AutoShardedClient):
def __init__(self): def __init__(self, shard_ids=None, shard_count=None):
super().__init__( super().__init__(
shard_ids=shard_ids,
shard_count=shard_count,
max_messages=10000, max_messages=10000,
activity=discord.Game(name="Embed your Discord server! Visit https://TitanEmbeds.com/") activity=discord.Game(name="Embed your Discord server! Visit https://TitanEmbeds.com/")
) )
self.setup_logger(shard_ids)
self.aiosession = aiohttp.ClientSession(loop=self.loop) self.aiosession = aiohttp.ClientSession(loop=self.loop)
self.http.user_agent += ' TitanEmbeds-Bot' self.http.user_agent += ' TitanEmbeds-Bot'
self.redisqueue = RedisQueue(self, config["redis-uri"]) self.redisqueue = RedisQueue(self, config["redis-uri"])
@ -38,6 +37,16 @@ class Titan(discord.AutoShardedClient):
self.discordBotsOrg = None self.discordBotsOrg = None
self.botsDiscordPw = None self.botsDiscordPw = None
def setup_logger(self, shard_ids=None):
shard_ids = '-'.join(str(x) for x in shard_ids) if shard_ids is not None else ''
logging.basicConfig(
filename='titanbot{}.log'.format(shard_ids),
level=logging.INFO,
format='%(asctime)s %(message)s',
datefmt='%m/%d/%Y %I:%M:%S %p'
)
logging.getLogger('TitanBot')
def _cleanup(self): def _cleanup(self):
try: try:
self.loop.run_until_complete(self.logout()) self.loop.run_until_complete(self.logout())
@ -54,18 +63,18 @@ class Titan(discord.AutoShardedClient):
async def start(self): async def start(self):
await self.redisqueue.connect() await self.redisqueue.connect()
self.loop.create_task(self.redisqueue.subscribe())
await super().start(config["bot-token"]) await super().start(config["bot-token"])
async def on_shard_ready(self, shard_id): async def on_shard_ready(self, shard_id):
print('Titan [DiscordBot]') logging.info('Titan [DiscordBot]')
print('Logged in as the following user:') logging.info('Logged in as the following user:')
print(self.user.name) logging.info(self.user.name)
print(self.user.id) logging.info(self.user.id)
print('------') logging.info('------')
print("Shard count: " + str(self.shard_count)) logging.info("Shard count: " + str(self.shard_count))
print("Shard id: "+ str(shard_id)) logging.info("Shard id: "+ str(shard_id))
print("------") logging.info("------")
self.loop.create_task(self.redisqueue.subscribe())
self.discordBotsOrg = DiscordBotsOrg(self.user.id, config.get("discord-bots-org-token", None)) self.discordBotsOrg = DiscordBotsOrg(self.user.id, config.get("discord-bots-org-token", None))
self.botsDiscordPw = BotsDiscordPw(self.user.id, config.get("bots-discord-pw-token", None)) self.botsDiscordPw = BotsDiscordPw(self.user.id, config.get("bots-discord-pw-token", None))

View File

@ -129,8 +129,6 @@ class RedisQueue:
async def on_get_guild_member(self, key, params): async def on_get_guild_member(self, key, params):
guild = self.bot.get_guild(int(params["guild_id"])) guild = self.bot.get_guild(int(params["guild_id"]))
if not guild: if not guild:
await self.connection.set(key, "")
await self.enforce_expiring_key(key)
return return
member = guild.get_member(int(params["user_id"])) member = guild.get_member(int(params["user_id"]))
if not member: if not member:
@ -143,12 +141,11 @@ class RedisQueue:
async def on_get_guild_member_named(self, key, params): async def on_get_guild_member_named(self, key, params):
guild = self.bot.get_guild(int(params["guild_id"])) guild = self.bot.get_guild(int(params["guild_id"]))
if not guild:
return
query = params["query"] query = params["query"]
result = None result = None
if guild: members = guild.members
members = guild.members
else:
members = None
if members and len(query) > 5 and query[-5] == '#': if members and len(query) > 5 and query[-5] == '#':
potential_discriminator = query[-4:] potential_discriminator = query[-4:]
result = discord.utils.get(members, name=query[:-5], discriminator=potential_discriminator) result = discord.utils.get(members, name=query[:-5], discriminator=potential_discriminator)
@ -167,6 +164,8 @@ class RedisQueue:
async def on_list_guild_members(self, key, params): async def on_list_guild_members(self, key, params):
guild = self.bot.get_guild(int(params["guild_id"])) guild = self.bot.get_guild(int(params["guild_id"]))
if not guild:
return
members = guild.members members = guild.members
member_ids = [] member_ids = []
for member in members: for member in members:
@ -200,8 +199,6 @@ class RedisQueue:
async def on_get_guild(self, key, params): async def on_get_guild(self, key, params):
guild = self.bot.get_guild(int(params["guild_id"])) guild = self.bot.get_guild(int(params["guild_id"]))
if not guild: if not guild:
await self.connection.set(key, "")
await self.enforce_expiring_key(key)
return return
if guild.me and guild.me.guild_permissions.manage_webhooks: if guild.me and guild.me.guild_permissions.manage_webhooks:
try: try:
@ -229,8 +226,6 @@ class RedisQueue:
async def on_get_user(self, key, params): async def on_get_user(self, key, params):
user = self.bot.get_user(int(params["user_id"])) user = self.bot.get_user(int(params["user_id"]))
if not user: if not user:
await self.connection.set(key, "")
await self.enforce_expiring_key(key)
return return
user_formatted = { user_formatted = {
"id": user.id, "id": user.id,