Files
PluralKit/src/api_main.py
2019-03-07 16:29:46 +01:00

203 lines
5.4 KiB
Python

import functools
import json
import logging
import os

from aiohttp import web

from pluralkit import db, utils
from pluralkit.errors import PluralKitError
from pluralkit.member import Member
from pluralkit.system import System
# Process-wide log configuration for the API server, plus this module's logger.
_LOG_FORMAT = "[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("pluralkit.api")
def db_handler(f):
    """Decorate an aiohttp handler so it receives a pooled database connection.

    The wrapped coroutine is called as ``f(request, *args, conn=conn, **kwargs)``.
    ``conn`` comes from ``request.app["pool"].acquire()`` and is held only inside
    that ``async with`` block, i.e. for the duration of the call.

    ``functools.wraps`` keeps the handler's ``__name__``/``__doc__`` (and sets
    ``__wrapped__``) so logs, tracebacks and introspection show the real handler
    rather than ``inner``.
    """
    @functools.wraps(f)
    async def inner(request, *args, **kwargs):
        async with request.app["pool"].acquire() as conn:
            return await f(request, *args, conn=conn, **kwargs)
    return inner
def system_auth(f):
    """Decorate a handler so it only runs for requests carrying a valid system token.

    The token is read from the ``X-Token`` header, falling back to the ``token``
    query parameter. It is resolved with ``System.get_by_token``. If no token is
    supplied, or it matches no system, ``web.HTTPUnauthorized`` is raised.
    Otherwise the wrapped coroutine is called with ``conn=`` and ``system=``
    keyword arguments.

    The returned wrapper itself requires a ``conn`` argument, so it expects to be
    stacked underneath a decorator that supplies one (such as ``@db_handler``).
    ``functools.wraps`` preserves the handler's name and docstring.
    """
    @functools.wraps(f)
    async def inner(request: web.Request, conn, *args, **kwargs):
        # Header takes precedence; a missing or empty header falls back to ?token=
        token = request.headers.get("X-Token")
        if not token:
            token = request.query.get("token")
        if not token:
            raise web.HTTPUnauthorized()
        system = await System.get_by_token(conn, token)
        if not system:
            raise web.HTTPUnauthorized()
        return await f(request, *args, conn=conn, system=system, **kwargs)
    return inner
@db_handler
async def get_system(request: web.Request, conn):
    """Return a system and all of its members as JSON.

    The system is looked up from the ``id`` path segment via
    ``db.get_system_by_hid``. Its ``to_json()`` dict gains a ``"members"`` key
    listing each member's ``to_json()``. Responds 404 if no system matches.
    """
    hid = request.match_info["id"]
    found = await db.get_system_by_hid(conn, hid)
    if not found:
        raise web.HTTPNotFound()

    member_rows = await db.get_all_members(conn, found.id)
    payload = found.to_json()
    serialized = []
    for row in member_rows:
        serialized.append(row.to_json())
    payload["members"] = serialized
    return web.json_response(payload)
@db_handler
async def get_member(request: web.Request, conn):
    """Return one member as JSON, looked up via ``db.get_member_by_hid``.

    Responds 404 when the ``id`` path segment matches no member.
    """
    found = await db.get_member_by_hid(conn, request.match_info["id"])
    if found:
        return web.json_response(found.to_json())
    raise web.HTTPNotFound()
@db_handler
async def get_switches(request: web.Request, conn):
    """Return the system's switch history as a JSON list.

    Each entry is ``{"timestamp": <isoformat>, "members": [<member hid>, ...]}``.
    Responds 404 if the system does not exist.
    """
    system = await db.get_system_by_hid(conn, request.match_info["id"])
    if not system:
        raise web.HTTPNotFound()

    # 99999 is presumably a maximum entry count, i.e. "effectively all history";
    # confirm against utils.get_front_history.
    history = await utils.get_front_history(conn, system.id, 99999)
    entries = []
    for timestamp, fronters in history:
        entries.append({
            "timestamp": timestamp.isoformat(),
            "members": [fronter.hid for fronter in fronters],
        })
    return web.json_response(entries)
@db_handler
async def get_message(request: web.Request, conn):
    """Return a stored message as JSON, looked up by its numeric id.

    The ``id`` path segment is parsed as an integer before querying. A
    non-numeric id cannot match any message and yields 404, as does an id with
    no stored message.
    """
    # Path segments are always strings, while message ids are numeric. A strictly
    # typed driver (e.g. asyncpg, presumably behind db.connect) rejects a str for
    # an integer query parameter with an exception, which surfaced as a 500.
    try:
        message_id = int(request.match_info["id"])
    except ValueError:
        raise web.HTTPNotFound() from None
    message = await db.get_message(conn, message_id)
    if not message:
        raise web.HTTPNotFound()
    return web.json_response(message.to_json())
@db_handler
async def get_switch(request: web.Request, conn):
    """Return the system's current fronters as JSON.

    Shape: ``{"timestamp": <isoformat>, "members": [<member JSON>, ...]}``.
    Responds 404 if the system does not exist or has never registered a switch.
    """
    system = await db.get_system_by_hid(conn, request.match_info["id"])
    if not system:
        raise web.HTTPNotFound()

    fronters, switched_at = await utils.get_fronters(conn, system.id)
    if not switched_at:
        # No switch has ever been logged for this system.
        raise web.HTTPNotFound()

    body = {}
    body["timestamp"] = switched_at.isoformat()
    body["members"] = [fronter.to_json() for fronter in fronters]
    return web.json_response(body)
@db_handler
async def get_switch_name(request: web.Request, conn):
    """Return the name of the first current fronter as plain text.

    Returns ``"(nobody)"`` when no one is fronting. Responds 404 if the system
    does not exist.
    """
    system = await db.get_system_by_hid(conn, request.match_info["id"])
    if not system:
        raise web.HTTPNotFound()
    fronters, _ = await utils.get_fronters(conn, system.id)
    if not fronters:
        return web.Response(text="(nobody)")
    return web.Response(text=fronters[0].name)
@db_handler
async def get_switch_color(request: web.Request, conn):
    """Return the color of the first current fronter as plain text.

    Falls back to ``"#ffffff"`` when no one is fronting or when the first
    fronter's color is unset (falsy). Responds 404 if the system does not exist.
    """
    system = await db.get_system_by_hid(conn, request.match_info["id"])
    if not system:
        raise web.HTTPNotFound()
    members, _ = await utils.get_fronters(conn, system.id)
    color = members[0].color if members else None
    # If a member can have no color (None / ""), web.Response(text=None) would
    # send an empty body; use the same default as the nobody-fronting case.
    # NOTE(review): confirm stored colors use the same "#rrggbb" form as the default.
    return web.Response(text=color or "#ffffff")
@db_handler
@system_auth
async def put_switch(request: web.Request, system: System, conn):
    """Register a new switch for the token-authenticated system.

    The body must be a JSON string (one member) or a JSON list of strings. Each
    string is resolved with ``Member.get_member_fuzzy`` within the authenticated
    system. Responds with the new switch's JSON. Responds 400 for an unreadable
    or non-JSON body, a wrong body type, or an unknown member. Token checks
    (401) are done by ``@system_auth``.

    NOTE(review): the ``{id}`` path segment is never consulted; the switch is
    always recorded for the system owning the token.
    """
    try:
        req = await request.json()
    except (ValueError, LookupError):
        # request.json() decodes the body with the request charset before
        # parsing. Invalid bytes raise UnicodeDecodeError, an unknown charset
        # raises LookupError, and bad JSON raises json.JSONDecodeError (both
        # ValueError subclasses). All are client errors, not server errors.
        raise web.HTTPBadRequest(body="Invalid JSON")

    if isinstance(req, str):
        req = [req]
    elif not isinstance(req, list):
        raise web.HTTPBadRequest(body="Body must be JSON string or list")

    members = []
    for member_name in req:
        if not isinstance(member_name, str):
            raise web.HTTPBadRequest(body="List value must be string")
        member = await Member.get_member_fuzzy(conn, system.id, member_name)
        if not member:
            raise web.HTTPBadRequest(body="Member '{}' not found".format(member_name))
        members.append(member)

    switch = await system.add_switch(conn, members)
    return web.json_response(await switch.to_json(conn))
@db_handler
async def get_stats(request: web.Request, conn):
    """Return global system, member and message counts as a JSON object."""
    stats = {}
    stats["systems"] = await db.system_count(conn)
    stats["members"] = await db.member_count(conn)
    stats["messages"] = await db.message_count(conn)
    return web.json_response(stats)
@web.middleware
async def render_pk_errors(request, handler):
    """aiohttp middleware that turns a ``PluralKitError`` from a handler into a 400.

    The error's ``message`` attribute becomes the response body. Any other
    exception propagates unchanged.
    """
    try:
        response = await handler(request)
    except PluralKitError as err:
        raise web.HTTPBadRequest(body=err.message)
    return response
# Application wiring. Handlers run behind render_pk_errors, which maps
# PluralKitError to a 400 response.
app = web.Application(middlewares=[render_pk_errors])
app.router.add_get("/systems/{id}", get_system)
app.router.add_get("/systems/{id}/switches", get_switches)
app.router.add_get("/systems/{id}/switch", get_switch)
app.router.add_put("/systems/{id}/switch", put_switch)
app.router.add_get("/systems/{id}/switch/name", get_switch_name)
app.router.add_get("/systems/{id}/switch/color", get_switch_color)
app.router.add_get("/members/{id}", get_member)
app.router.add_get("/messages/{id}", get_message)
app.router.add_get("/stats", get_stats)
async def run():
    """Open the database pool, store it as ``app["pool"]``, and return the app.

    The connection string is read from the ``DATABASE_URI`` environment
    variable; a missing variable raises ``KeyError``. This coroutine is handed
    to ``web.run_app``, which awaits it on the server's event loop before
    serving.
    """
    app["pool"] = await db.connect(
        os.environ["DATABASE_URI"]
    )
    return app


# Start serving only when run as a script, so importing this module (tests,
# tooling) does not block forever inside run_app().
if __name__ == "__main__":
    web.run_app(run())