diff --git a/PluralKit.API/Controllers/SystemController.cs b/PluralKit.API/Controllers/SystemController.cs index 4108e352..f5fadfff 100644 --- a/PluralKit.API/Controllers/SystemController.cs +++ b/PluralKit.API/Controllers/SystemController.cs @@ -74,7 +74,7 @@ namespace PluralKit.API.Controllers var system = await _systems.GetByHid(hid); if (system == null) return NotFound("System not found."); - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) { var res = await conn.QueryAsync( @"select *, array( @@ -146,7 +146,7 @@ namespace PluralKit.API.Controllers // Resolve member objects for all given IDs IEnumerable membersList; - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) membersList = (await conn.QueryAsync("select * from members where hid = any(@Hids)", new {Hids = param.Members})).ToList(); foreach (var member in membersList) diff --git a/PluralKit.Bot/Bot.cs b/PluralKit.Bot/Bot.cs index f9f76270..a51d77ea 100644 --- a/PluralKit.Bot/Bot.cs +++ b/PluralKit.Bot/Bot.cs @@ -34,7 +34,7 @@ namespace PluralKit.Bot using (var services = BuildServiceProvider()) { Console.WriteLine("- Connecting to database..."); - using (var conn = services.GetRequiredService().Obtain()) + using (var conn = await services.GetRequiredService().Obtain()) await Schema.CreateTables(conn); Console.WriteLine("- Connecting to Discord..."); @@ -179,7 +179,7 @@ namespace PluralKit.Bot // and start command execution // Note system may be null if user has no system, hence `OrDefault` PKSystem system; - using (var conn = serviceScope.ServiceProvider.GetService().Obtain()) + using (var conn = await serviceScope.ServiceProvider.GetService().Obtain()) system = await conn.QueryFirstOrDefaultAsync("select systems.* from systems, accounts where accounts.uid = @Id and systems.id = accounts.system", new { Id = arg.Author.Id }); await _commands.ExecuteAsync(new PKCommandContext(_client, arg, system), argPos, serviceScope.ServiceProvider); } diff --git a/PluralKit.Bot/Services/LogChannelService.cs b/PluralKit.Bot/Services/LogChannelService.cs index 6915c526..8d2ca896 100644 --- a/PluralKit.Bot/Services/LogChannelService.cs +++ b/PluralKit.Bot/Services/LogChannelService.cs @@ -30,7 +30,7 @@ namespace PluralKit.Bot { } public async Task GetLogChannel(IGuild guild) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) { var server = await conn.QueryFirstOrDefaultAsync("select * from servers where id = @Id", @@ -46,7 +46,7 @@ namespace PluralKit.Bot { LogChannel = newLogChannel?.Id }; - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) { await conn.QueryAsync( "insert into servers (id, log_channel) values (@Id, @LogChannel) on conflict (id) do update set log_channel = @LogChannel", diff --git a/PluralKit.Bot/Services/ProxyService.cs b/PluralKit.Bot/Services/ProxyService.cs index 17a61c1d..f62dc06e 100644 --- a/PluralKit.Bot/Services/ProxyService.cs +++ b/PluralKit.Bot/Services/ProxyService.cs @@ -79,7 +79,7 @@ namespace PluralKit.Bot public async Task HandleMessageAsync(IMessage message) { IEnumerable results; - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) { results = await conn.QueryAsync( "select members.*, systems.* from members, systems, accounts where members.system = systems.id and accounts.system = systems.id and accounts.uid = @Uid", @@ -168,21 +168,29 @@ namespace PluralKit.Bot return HandleMessageDeletionByReaction(message, reaction.UserId); case "\u2753": // Red question mark case "\u2754": // White question mark - return HandleMessageQueryByReaction(message, reaction.UserId); + return HandleMessageQueryByReaction(message, reaction.UserId, reaction.Emote); default: return Task.CompletedTask; } } - private async Task HandleMessageQueryByReaction(Cacheable message, ulong userWhoReacted) + private async Task HandleMessageQueryByReaction(Cacheable message, ulong userWhoReacted, IEmote reactedEmote) { + // Find the user who sent the reaction, so we can DM them var user = await _client.GetUserAsync(userWhoReacted); if (user == null) return; + // Find the message in the DB var msg = await _messageStorage.Get(message.Id); if (msg == null) return; + // DM them the message card await user.SendMessageAsync(embed: await _embeds.CreateMessageInfoEmbed(msg)); + + // And finally remove the original reaction (if we can) + var msgObj = await message.GetOrDownloadAsync(); + if (await msgObj.Channel.HasPermission(ChannelPermission.ManageMessages)) + await msgObj.RemoveReactionAsync(reactedEmote, user); } public async Task HandleMessageDeletionByReaction(Cacheable message, ulong userWhoReacted) diff --git a/PluralKit.Bot/Utils.cs b/PluralKit.Bot/Utils.cs index c2ef2def..62c1dffd 100644 --- a/PluralKit.Bot/Utils.cs +++ b/PluralKit.Bot/Utils.cs @@ -86,6 +86,25 @@ namespace PluralKit.Bot public static string Sanitize(this string input) => Regex.Replace(Regex.Replace(input, "<@[!&]?(\\d{17,19})>", "<\\@$1>"), "@(everyone|here)", "@\u200B$1"); + + public static async Task PermissionsIn(this IChannel channel) + { + switch (channel) + { + case IDMChannel _: + return ChannelPermissions.DM; + case IGroupChannel _: + return ChannelPermissions.Group; + case IGuildChannel gc: + var currentUser = await gc.Guild.GetCurrentUserAsync(); + return currentUser.GetPermissions(gc); + default: + return ChannelPermissions.None; + } + } + + public static async Task HasPermission(this IChannel channel, ChannelPermission permission) => + (await PermissionsIn(channel)).Has(permission); } class PKSystemTypeReader : TypeReader diff --git a/PluralKit.Core/Stores.cs b/PluralKit.Core/Stores.cs index 6ed9f466..a3fab10a 100644 --- a/PluralKit.Core/Stores.cs +++ b/PluralKit.Core/Stores.cs @@ -20,56 +20,56 @@ namespace PluralKit { // TODO: handle HID collision case var hid = Utils.GenerateHid(); - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) return await conn.QuerySingleAsync("insert into systems (hid, name) values (@Hid, @Name) returning *", new { Hid = hid, Name = systemName }); } public async Task Link(PKSystem system, ulong accountId) { // We have "on conflict do nothing" since linking an account when it's already linked to the same system is idempotent // This is used in import/export, although the pk;link command checks for this case beforehand - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) await conn.ExecuteAsync("insert into accounts (uid, system) values (@Id, @SystemId) on conflict do nothing", new { Id = accountId, SystemId = system.Id }); } public async Task Unlink(PKSystem system, ulong accountId) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) await conn.ExecuteAsync("delete from accounts where uid = @Id and system = @SystemId", new { Id = accountId, SystemId = system.Id }); } public async Task GetByAccount(ulong accountId) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) return await conn.QuerySingleOrDefaultAsync("select systems.* from systems, accounts where accounts.system = systems.id and accounts.uid = @Id", new { Id = accountId }); } public async Task GetByHid(string hid) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) return await conn.QuerySingleOrDefaultAsync("select * from systems where systems.hid = @Hid", new { Hid = hid.ToLower() }); } public async Task GetByToken(string token) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) return await conn.QuerySingleOrDefaultAsync("select * from systems where token = @Token", new { Token = token }); } public async Task GetById(int id) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) return await conn.QuerySingleOrDefaultAsync("select * from systems where id = @Id", new { Id = id }); } public async Task Save(PKSystem system) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) await conn.ExecuteAsync("update systems set name = @Name, description = @Description, tag = @Tag, avatar_url = @AvatarUrl, token = @Token, ui_tz = @UiTz where id = @Id", system); } public async Task Delete(PKSystem system) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) await conn.ExecuteAsync("delete from systems where id = @Id", system); } public async Task> GetLinkedAccountIds(PKSystem system) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) return await conn.QueryAsync("select uid from accounts where system = @Id", new { Id = system.Id }); } } @@ -85,7 +85,7 @@ namespace PluralKit { // TODO: handle collision var hid = Utils.GenerateHid(); - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) return await conn.QuerySingleAsync("insert into members (hid, system, name) values (@Hid, @SystemId, @Name) returning *", new { Hid = hid, SystemID = system.Id, @@ -94,13 +94,13 @@ namespace PluralKit { } public async Task GetByHid(string hid) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) return await conn.QuerySingleOrDefaultAsync("select * from members where hid = @Hid", new { Hid = hid.ToLower() }); } public async Task GetByName(PKSystem system, string name) { // QueryFirst, since members can (in rare cases) share names - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) return await conn.QueryFirstOrDefaultAsync("select * from members where lower(name) = lower(@Name) and system = @SystemID", new { Name = name, SystemID = system.Id }); } @@ -113,23 +113,23 @@ namespace PluralKit { } public async Task> GetBySystem(PKSystem system) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) return await conn.QueryAsync("select * from members where system = @SystemID", new { SystemID = system.Id }); } public async Task Save(PKMember member) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) await conn.ExecuteAsync("update members set name = @Name, description = @Description, color = @Color, avatar_url = @AvatarUrl, birthday = @Birthday, pronouns = @Pronouns, prefix = @Prefix, suffix = @Suffix where id = @Id", member); } public async Task Delete(PKMember member) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) await conn.ExecuteAsync("delete from members where id = @Id", member); } public async Task MessageCount(PKMember member) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) return await conn.QuerySingleAsync("select count(*) from messages where member = @Id", member); } } @@ -155,7 +155,7 @@ namespace PluralKit { } public async Task Store(ulong senderId, ulong messageId, ulong channelId, PKMember member) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) await conn.ExecuteAsync("insert into messages(mid, channel, member, sender) values(@MessageId, @ChannelId, @MemberId, @SenderId)", new { MessageId = messageId, ChannelId = channelId, @@ -166,7 +166,7 @@ namespace PluralKit { public async Task Get(ulong id) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) return (await conn.QueryAsync("select messages.*, members.*, systems.* from messages, members, systems where mid = @Id and messages.member = members.id and systems.id = members.system", (msg, member, system) => new StoredMessage { Message = msg, @@ -176,7 +176,7 @@ namespace PluralKit { } public async Task Delete(ulong id) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) await conn.ExecuteAsync("delete from messages where mid = @Id", new { Id = id }); } } @@ -193,7 +193,7 @@ namespace PluralKit { public async Task RegisterSwitch(PKSystem system, IEnumerable members) { // Use a transaction here since we're doing multiple executed commands in one - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) using (var tx = conn.BeginTransaction()) { // First, we insert the switch itself @@ -218,20 +218,20 @@ namespace PluralKit { { // TODO: refactor the PKSwitch data structure to somehow include a hydrated member list // (maybe when we get caching in?) - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) return await conn.QueryAsync("select * from switches where system = @System order by timestamp desc limit @Count", new {System = system.Id, Count = count}); } public async Task> GetSwitchMemberIds(PKSwitch sw) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) return await conn.QueryAsync("select member from switch_members where switch = @Switch", new {Switch = sw.Id}); } public async Task> GetSwitchMembers(PKSwitch sw) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) return await conn.QueryAsync( "select * from switch_members, members where switch_members.member = members.id and switch_members.switch = @Switch", new {Switch = sw.Id}); @@ -241,14 +241,14 @@ namespace PluralKit { public async Task MoveSwitch(PKSwitch sw, Instant time) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) await conn.ExecuteAsync("update switches set timestamp = @Time where id = @Id", new {Time = time, Id = sw.Id}); } public async Task DeleteSwitch(PKSwitch sw) { - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) await conn.ExecuteAsync("delete from switches where id = @Id", new {Id = sw.Id}); } @@ -273,7 +273,7 @@ namespace PluralKit { // this makes sure the return list has the same instances of PKMember throughout, which is important for the dictionary // key used in GetPerMemberSwitchDuration below Dictionary memberObjects; - using (var conn = _conn.Obtain()) + using (var conn = await _conn.Obtain()) { memberObjects = (await conn.QueryAsync( "select distinct members.* from members, switch_members where switch_members.switch = any(@Switches) and switch_members.member = members.id", // lol postgres specific `= any()` syntax diff --git a/PluralKit.Core/Utils.cs b/PluralKit.Core/Utils.cs index f50cf248..b0c89c22 100644 --- a/PluralKit.Core/Utils.cs +++ b/PluralKit.Core/Utils.cs @@ -5,6 +5,7 @@ using System.IO; using System.Linq; using System.Security.Cryptography; using System.Text.RegularExpressions; +using System.Threading.Tasks; using Dapper; using Microsoft.Extensions.Configuration; using Newtonsoft.Json; @@ -343,9 +344,11 @@ namespace PluralKit _connectionString = connectionString; } - public IDbConnection Obtain() + public async Task Obtain() { - return new NpgsqlConnection(_connectionString); + var conn = new NpgsqlConnection(_connectionString); + await conn.OpenAsync(); + return conn; } } }