Refactored config file loading

This commit is contained in:
Ske
2019-03-08 17:22:05 +01:00
parent abda846ca3
commit 560b79c2ae
2 changed files with 41 additions and 17 deletions

View File

@@ -2,8 +2,10 @@ import asyncio
import sys
import asyncpg
from collections import namedtuple
import discord
import logging
import json
import os
import traceback
@@ -12,13 +14,45 @@ from pluralkit.bot import commands, proxy, channel_logger, embeds
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
class Config(namedtuple("Config", ["database_uri", "token", "log_channel"])):
required_fields = ["database_uri", "token"]
database_uri: str
token: str
log_channel: str
@staticmethod
def from_file_and_env(filename: str) -> "Config":
try:
with open(filename, "r") as f:
config = json.load(f)
except IOError as e:
# If all the required fields are specified as environment variables, it's OK to
# not raise the IOError, we can just construct the dict from these
if all([rf.upper() in os.environ for rf in Config.required_fields]):
config = {}
else:
# If they aren't, though, then rethrow
raise e
# Override with environment variables
for f in Config._fields:
if f.upper() in os.environ:
config[f] = os.environ[f.upper()]
# If we currently don't have all the required fields, then raise
if not all([rf in config for rf in Config.required_fields]):
raise RuntimeError("Some required config fields were missing: " + ", ".join(filter(lambda rf: rf not in config, Config.required_fields)))
return Config(**config)
def connect_to_database(uri: str) -> asyncpg.pool.Pool:
return asyncio.get_event_loop().run_until_complete(db.connect(uri))
def run(token: str, db_uri: str, log_channel_id: int):
pool = connect_to_database(db_uri)
def run(config: Config):
pool = connect_to_database(config.database_uri)
async def create_tables():
async with pool.acquire() as conn:
@@ -78,9 +112,9 @@ def run(token: str, db_uri: str, log_channel_id: int):
# Then log it to the given log channel
# TODO: replace this with Sentry or something
if not log_channel_id:
if not config.log_channel:
return
log_channel = client.get_channel(log_channel_id)
log_channel = client.get_channel(int(config.log_channel))
# If this is a message event, we can attach additional information in an event
# ie. username, channel, content, etc
@@ -102,4 +136,4 @@ def run(token: str, db_uri: str, log_channel_id: int):
if len(traceback.format_exc()) >= (2000 - len("```python\n```")):
traceback_str = "```python\n...{}```".format(traceback.format_exc()[- (2000 - len("```python\n...```")):])
await log_channel.send(content=traceback_str, embed=embed)
client.run(token)
client.run(config.token)