Format code
This commit is contained in:
@@ -5,6 +5,7 @@ import asyncpg.exceptions
|
||||
|
||||
from pluralkit.bot import logger
|
||||
|
||||
|
||||
async def connect():
|
||||
while True:
|
||||
try:
|
||||
@@ -12,32 +13,40 @@ async def connect():
|
||||
except (ConnectionError, asyncpg.exceptions.CannotConnectNowError):
|
||||
pass
|
||||
|
||||
|
||||
def db_wrap(func):
|
||||
async def inner(*args, **kwargs):
|
||||
before = time.perf_counter()
|
||||
res = await func(*args, **kwargs)
|
||||
after = time.perf_counter()
|
||||
|
||||
|
||||
logger.debug(" - DB took {:.2f} ms".format((after - before) * 1000))
|
||||
return res
|
||||
return inner
|
||||
|
||||
|
||||
webhook_cache = {}
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def create_system(conn, system_name: str, system_hid: str):
|
||||
logger.debug("Creating system (name={}, hid={})".format(system_name, system_hid))
|
||||
logger.debug("Creating system (name={}, hid={})".format(
|
||||
system_name, system_hid))
|
||||
return await conn.fetchrow("insert into systems (name, hid) values ($1, $2) returning *", system_name, system_hid)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def remove_system(conn, system_id: int):
|
||||
logger.debug("Deleting system (id={})".format(system_id))
|
||||
await conn.execute("delete from systems where id = $1", system_id)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def create_member(conn, system_id: int, member_name: str, member_hid: str):
|
||||
logger.debug("Creating member (system={}, name={}, hid={})".format(system_id, member_name, member_hid))
|
||||
return await conn.fetchrow("insert into members (name, system, hid) values ($1, $2, $3) returning *", member_name, system_id, member_hid)
|
||||
logger.debug("Creating member (system={}, name={}, hid={})".format(
|
||||
system_id, member_name, member_hid))
|
||||
return await conn.fetchrow("insert into members (name, system, hid) values ($1, $2, $3) returning *", member_name, system_id, member_hid)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def delete_member(conn, member_id: int):
|
||||
@@ -45,70 +54,90 @@ async def delete_member(conn, member_id: int):
|
||||
await conn.execute("update switches set member = null, member_del = true where member = $1", member_id)
|
||||
await conn.execute("delete from members where id = $1", member_id)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def link_account(conn, system_id: int, account_id: str):
|
||||
logger.debug("Linking account (account_id={}, system_id={})".format(account_id, system_id))
|
||||
logger.debug("Linking account (account_id={}, system_id={})".format(
|
||||
account_id, system_id))
|
||||
await conn.execute("insert into accounts (uid, system) values ($1, $2)", int(account_id), system_id)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def unlink_account(conn, system_id: int, account_id: str):
|
||||
logger.debug("Unlinking account (account_id={}, system_id={})".format(account_id, system_id))
|
||||
logger.debug("Unlinking account (account_id={}, system_id={})".format(
|
||||
account_id, system_id))
|
||||
await conn.execute("delete from accounts where uid = $1 and system = $2", int(account_id), system_id)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_linked_accounts(conn, system_id: int):
|
||||
return [row["uid"] for row in await conn.fetch("select uid from accounts where system = $1", system_id)]
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_system_by_account(conn, account_id: str):
|
||||
return await conn.fetchrow("select systems.* from systems, accounts where accounts.uid = $1 and accounts.system = systems.id", int(account_id))
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_system_by_hid(conn, system_hid: str):
|
||||
return await conn.fetchrow("select * from systems where hid = $1", system_hid)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_system(conn, system_id: int):
|
||||
return await conn.fetchrow("select * from systems where id = $1", system_id)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_member_by_name(conn, system_id: int, member_name: str):
|
||||
return await conn.fetchrow("select * from members where system = $1 and name = $2", system_id, member_name)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_member_by_hid_in_system(conn, system_id: int, member_hid: str):
|
||||
return await conn.fetchrow("select * from members where system = $1 and hid = $2", system_id, member_hid)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_member_by_hid(conn, member_hid: str):
|
||||
return await conn.fetchrow("select * from members where hid = $1", member_hid)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_member(conn, member_id: int):
|
||||
return await conn.fetchrow("select * from members where id = $1", member_id)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_message(conn, message_id: str):
|
||||
return await conn.fetchrow("select * from messages where mid = $1", message_id)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def update_system_field(conn, system_id: int, field: str, value):
|
||||
logger.debug("Updating system field (id={}, {}={})".format(system_id, field, value))
|
||||
logger.debug("Updating system field (id={}, {}={})".format(
|
||||
system_id, field, value))
|
||||
await conn.execute("update systems set {} = $1 where id = $2".format(field), value, system_id)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def update_member_field(conn, member_id: int, field: str, value):
|
||||
logger.debug("Updating member field (id={}, {}={})".format(member_id, field, value))
|
||||
logger.debug("Updating member field (id={}, {}={})".format(
|
||||
member_id, field, value))
|
||||
await conn.execute("update members set {} = $1 where id = $2".format(field), value, member_id)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_all_members(conn, system_id: int):
|
||||
return await conn.fetch("select * from members where system = $1", system_id)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_members_exceeding(conn, system_id: int, length: int):
|
||||
return await conn.fetch("select * from members where system = $1 and length(name) >= $2", system_id, length)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_webhook(conn, channel_id: str):
|
||||
if channel_id in webhook_cache:
|
||||
@@ -117,30 +146,38 @@ async def get_webhook(conn, channel_id: str):
|
||||
webhook_cache[channel_id] = res
|
||||
return res
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def add_webhook(conn, channel_id: str, webhook_id: str, webhook_token: str):
|
||||
logger.debug("Adding new webhook (channel={}, webhook={}, token={})".format(channel_id, webhook_id, webhook_token))
|
||||
logger.debug("Adding new webhook (channel={}, webhook={}, token={})".format(
|
||||
channel_id, webhook_id, webhook_token))
|
||||
await conn.execute("insert into webhooks (channel, webhook, token) values ($1, $2, $3)", int(channel_id), int(webhook_id), webhook_token)
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def add_message(conn, message_id: str, channel_id: str, member_id: int, sender_id: str):
|
||||
logger.debug("Adding new message (id={}, channel={}, member={}, sender={})".format(message_id, channel_id, member_id, sender_id))
|
||||
logger.debug("Adding new message (id={}, channel={}, member={}, sender={})".format(
|
||||
message_id, channel_id, member_id, sender_id))
|
||||
await conn.execute("insert into messages (mid, channel, member, sender) values ($1, $2, $3, $4)", int(message_id), int(channel_id), member_id, int(sender_id))
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_members_by_account(conn, account_id: str):
|
||||
# Returns a "chimera" object
|
||||
return await conn.fetch("select members.id, members.hid, members.prefix, members.suffix, members.name, members.avatar_url, systems.tag from systems, members, accounts where accounts.uid = $1 and systems.id = accounts.system and members.system = systems.id", int(account_id))
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def get_message_by_sender_and_id(conn, message_id: str, sender_id: str):
|
||||
await conn.fetchrow("select * from messages where mid = $1 and sender = $2", int(message_id), int(sender_id))
|
||||
|
||||
|
||||
@db_wrap
|
||||
async def delete_message(conn, message_id: str):
|
||||
logger.debug("Deleting message (id={})".format(message_id))
|
||||
await conn.execute("delete from messages where mid = $1", int(message_id))
|
||||
|
||||
|
||||
async def create_tables(conn):
|
||||
await conn.execute("""create table if not exists systems (
|
||||
id serial primary key,
|
||||
@@ -190,4 +227,4 @@ async def create_tables(conn):
|
||||
id bigint primary key,
|
||||
cmd_chans bigint[],
|
||||
proxy_chans bigint[]
|
||||
)""")
|
||||
)""")
|
||||
|
||||
Reference in New Issue
Block a user