From 6da7436aed78e21cfe36a8f7d160c3b59c15400b Mon Sep 17 00:00:00 2001 From: Ske Date: Tue, 13 Nov 2018 12:34:19 +0100 Subject: [PATCH] Add commands for API token retrieval/refreshing --- src/pluralkit/bot/commands/__init__.py | 6 ++++- src/pluralkit/bot/commands/api_commands.py | 31 ++++++++++++++++++++++ src/pluralkit/db.py | 8 +++++- src/pluralkit/system.py | 22 +++++++++++---- 4 files changed, 60 insertions(+), 7 deletions(-) create mode 100644 src/pluralkit/bot/commands/api_commands.py diff --git a/src/pluralkit/bot/commands/__init__.py b/src/pluralkit/bot/commands/__init__.py index b5940189..e0b10cae 100644 --- a/src/pluralkit/bot/commands/__init__.py +++ b/src/pluralkit/bot/commands/__init__.py @@ -129,6 +129,7 @@ class CommandContext: raise CommandError("Timed out - try again.") +import pluralkit.bot.commands.api_commands import pluralkit.bot.commands.import_commands import pluralkit.bot.commands.member_commands import pluralkit.bot.commands.message_commands @@ -179,7 +180,10 @@ async def command_dispatch(client: discord.Client, message: discord.Message, con (r"switch move", switch_commands.switch_move), (r"switch out", switch_commands.switch_out), - (r"switch", switch_commands.switch_member) + (r"switch", switch_commands.switch_member), + + (r"token (refresh|expire|update)", api_commands.refresh_token), + (r"token", api_commands.get_token) ] for pattern, func in commands: diff --git a/src/pluralkit/bot/commands/api_commands.py b/src/pluralkit/bot/commands/api_commands.py new file mode 100644 index 00000000..5670acd3 --- /dev/null +++ b/src/pluralkit/bot/commands/api_commands.py @@ -0,0 +1,31 @@ +import logging +from discord import DMChannel + +from pluralkit.bot.commands import CommandContext, CommandSuccess + +logger = logging.getLogger("pluralkit.commands") +disclaimer = "Please note that this grants access to modify (and delete!) all your system data, so keep it safe and secure. If it leaks or you need a new one, you can invalidate this one with `pk;token refresh`." + +async def reply_dm(ctx: CommandContext, message: str): + await ctx.message.author.send(message) + + if not isinstance(ctx.message.channel, DMChannel): + return CommandSuccess("DM'd!") + +async def get_token(ctx: CommandContext): + system = await ctx.ensure_system() + + if system.token: + token = system.token + else: + token = await system.refresh_token(ctx.conn) + + token_message = "Here's your API token: \n**`{}`**\n{}".format(token, disclaimer) + return await reply_dm(ctx, token_message) + +async def refresh_token(ctx: CommandContext): + system = await ctx.ensure_system() + + token = await system.refresh_token(ctx.conn) + token_message = "Your previous API token has been invalidated. You will need to change it anywhere it's currently used.\nHere's your new API token:\n**`{}`**\n{}".format(token, disclaimer) + return await reply_dm(ctx, token_message) \ No newline at end of file diff --git a/src/pluralkit/db.py b/src/pluralkit/db.py index 5b4faf73..dd640f05 100644 --- a/src/pluralkit/db.py +++ b/src/pluralkit/db.py @@ -1,7 +1,7 @@ from collections import namedtuple from datetime import datetime import logging -from typing import List +from typing import List, Optional import time import asyncpg @@ -85,6 +85,11 @@ async def get_system_by_account(conn, account_id: int) -> System: row = await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", account_id) return System(**row) if row else None +@db_wrap +async def get_system_by_token(conn, token: str) -> Optional[System]: + row = await conn.fetchrow("select * from systems where token = $1", token) + return System(**row) if row else None + @db_wrap async def get_system_by_hid(conn, system_hid: str) -> System: row = await conn.fetchrow("select * from systems where hid = $1", system_hid) @@ -323,6 +328,7 @@ async def create_tables(conn): description text, tag text, avatar_url text, + token text, created timestamp not null default (current_timestamp at time zone 'utc') )""") await conn.execute("""create table if not exists members ( diff --git a/src/pluralkit/system.py b/src/pluralkit/system.py index 70e47f69..ba98f082 100644 --- a/src/pluralkit/system.py +++ b/src/pluralkit/system.py @@ -1,3 +1,5 @@ +import random +import string from datetime import datetime from collections.__init__ import namedtuple @@ -9,21 +11,26 @@ from pluralkit.switch import Switch from pluralkit.utils import generate_hid, contains_custom_emoji, validate_avatar_url_or_raise -class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "avatar_url", "created"])): +class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "avatar_url", "token", "created"])): id: int hid: str name: str description: str tag: str avatar_url: str + token: str created: datetime @staticmethod - async def get_by_account(conn, account_id: str) -> "System": + async def get_by_account(conn, account_id: int) -> Optional["System"]: return await db.get_system_by_account(conn, account_id) @staticmethod - async def create_system(conn, account_id: str, system_name: Optional[str] = None) -> "System": + async def get_by_token(conn, token: str) -> Optional["System"]: + return await db.get_system_by_token(conn, token) + + @staticmethod + async def create_system(conn, account_id: int, system_name: Optional[str] = None) -> "System": async with conn.transaction(): existing_system = await System.get_by_account(conn, account_id) if existing_system: @@ -66,7 +73,7 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a await db.update_system_field(conn, self.id, "avatar_url", new_avatar_url) - async def link_account(self, conn, new_account_id: str): + async def link_account(self, conn, new_account_id: int): async with conn.transaction(): existing_system = await System.get_by_account(conn, new_account_id) @@ -78,7 +85,7 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a await db.link_account(conn, self.id, new_account_id) - async def unlink_account(self, conn, account_id: str): + async def unlink_account(self, conn, account_id: int): async with conn.transaction(): linked_accounts = await db.get_linked_accounts(conn, self.id) if len(linked_accounts) == 1: @@ -92,6 +99,11 @@ class System(namedtuple("System", ["id", "hid", "name", "description", "tag", "a async def delete(self, conn): await db.remove_system(conn, self.id) + async def refresh_token(self, conn) -> str: + new_token = "".join(random.choices(string.ascii_letters + string.digits, k=64)) + await db.update_system_field(conn, self.id, "token", new_token) + return new_token + async def create_member(self, conn, member_name: str) -> Member: # TODO: figure out what to do if this errors out on collision on generate_hid new_hid = generate_hid()