Add API endpoint for logging new switches (+ refactor)
This commit is contained in:
@@ -1,19 +1,40 @@
|
||||
import os
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from pluralkit import db, utils
|
||||
from pluralkit.errors import PluralKitError
|
||||
from pluralkit.member import Member
|
||||
from pluralkit.system import System
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
|
||||
logger = logging.getLogger("pluralkit.api")
|
||||
|
||||
|
||||
def db_handler(f):
|
||||
async def inner(request):
|
||||
async def inner(request, *args, **kwargs):
|
||||
async with request.app["pool"].acquire() as conn:
|
||||
return await f(request, conn)
|
||||
return await f(request, conn=conn, *args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def system_auth(f):
|
||||
async def inner(request: web.Request, conn, *args, **kwargs):
|
||||
token = request.headers.get("X-Token")
|
||||
if not token:
|
||||
token = request.query.get("token")
|
||||
|
||||
if not token:
|
||||
raise web.HTTPUnauthorized()
|
||||
|
||||
system = await System.get_by_token(conn, token)
|
||||
if not system:
|
||||
raise web.HTTPUnauthorized()
|
||||
|
||||
return await f(request, conn=conn, system=system, *args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
@@ -58,6 +79,7 @@ async def get_switches(request: web.Request, conn):
|
||||
|
||||
return web.json_response(data)
|
||||
|
||||
|
||||
@db_handler
|
||||
async def get_message(request: web.Request, conn):
|
||||
message = await db.get_message(conn, request.match_info["id"])
|
||||
@@ -66,6 +88,7 @@ async def get_message(request: web.Request, conn):
|
||||
|
||||
return web.json_response(message.to_json())
|
||||
|
||||
|
||||
@db_handler
|
||||
async def get_switch(request: web.Request, conn):
|
||||
system = await db.get_system_by_hid(conn, request.match_info["id"])
|
||||
@@ -84,6 +107,7 @@ async def get_switch(request: web.Request, conn):
|
||||
}
|
||||
return web.json_response(data)
|
||||
|
||||
|
||||
@db_handler
|
||||
async def get_switch_name(request: web.Request, conn):
|
||||
system = await db.get_system_by_hid(conn, request.match_info["id"])
|
||||
@@ -94,6 +118,7 @@ async def get_switch_name(request: web.Request, conn):
|
||||
members, stamp = await utils.get_fronters(conn, system.id)
|
||||
return web.Response(text=members[0].name if members else "(nobody)")
|
||||
|
||||
|
||||
@db_handler
|
||||
async def get_switch_color(request: web.Request, conn):
|
||||
system = await db.get_system_by_hid(conn, request.match_info["id"])
|
||||
@@ -104,6 +129,34 @@ async def get_switch_color(request: web.Request, conn):
|
||||
members, stamp = await utils.get_fronters(conn, system.id)
|
||||
return web.Response(text=members[0].color if members else "#ffffff")
|
||||
|
||||
|
||||
@db_handler
|
||||
@system_auth
|
||||
async def put_switch(request: web.Request, system: System, conn):
|
||||
try:
|
||||
req = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
raise web.HTTPBadRequest(body="Invalid JSON")
|
||||
|
||||
if isinstance(req, str):
|
||||
req = [req]
|
||||
elif not isinstance(req, list):
|
||||
raise web.HTTPBadRequest(body="Body must be JSON string or list")
|
||||
|
||||
members = []
|
||||
for member_name in req:
|
||||
if not isinstance(member_name, str):
|
||||
raise web.HTTPBadRequest(body="List value must be string")
|
||||
|
||||
member = await Member.get_member_fuzzy(conn, system.id, member_name)
|
||||
if not member:
|
||||
raise web.HTTPBadRequest(body="Member '{}' not found".format(member_name))
|
||||
members.append(member)
|
||||
|
||||
switch = await system.add_switch(conn, members)
|
||||
return web.json_response(await switch.to_json(conn))
|
||||
|
||||
|
||||
@db_handler
|
||||
async def get_stats(request: web.Request, conn):
|
||||
system_count = await db.system_count(conn)
|
||||
@@ -116,11 +169,21 @@ async def get_stats(request: web.Request, conn):
|
||||
"messages": message_count
|
||||
})
|
||||
|
||||
app = web.Application()
|
||||
|
||||
@web.middleware
|
||||
async def render_pk_errors(request, handler):
|
||||
try:
|
||||
return await handler(request)
|
||||
except PluralKitError as e:
|
||||
raise web.HTTPBadRequest(body=e.message)
|
||||
|
||||
|
||||
app = web.Application(middlewares=[render_pk_errors])
|
||||
app.add_routes([
|
||||
web.get("/systems/{id}", get_system),
|
||||
web.get("/systems/{id}/switches", get_switches),
|
||||
web.get("/systems/{id}/switch", get_switch),
|
||||
web.put("/systems/{id}/switch", put_switch),
|
||||
web.get("/systems/{id}/switch/name", get_switch_name),
|
||||
web.get("/systems/{id}/switch/color", get_switch_color),
|
||||
web.get("/members/{id}", get_member),
|
||||
|
||||
Reference in New Issue
Block a user