Too many refactors in one:
- Allowed adding ephemeral(ish) views and functions - Moved message_count to a concrete database field - Moved most proxy logic to a stored procedure - Moved database files around and refactored schema manager
This commit is contained in:
0
PluralKit.Core/Migrations/6.sql → PluralKit.Core/Database/Migrations/6.sql
Executable file → Normal file
0
PluralKit.Core/Migrations/6.sql → PluralKit.Core/Database/Migrations/6.sql
Executable file → Normal file
33
PluralKit.Core/Database/Migrations/7.sql
Normal file
33
PluralKit.Core/Database/Migrations/7.sql
Normal file
@@ -0,0 +1,33 @@
|
||||
-- SCHEMA VERSION 7: 2020-06-12
|
||||
-- (in-db message count row)
|
||||
|
||||
-- Add message count row to members table, initialize it with the correct data
|
||||
alter table members add column message_count int not null default 0;
|
||||
update members set message_count = (select count(*) from messages where messages.member = members.id);
|
||||
|
||||
|
||||
-- Create a trigger function to increment the message count on inserting to the messages table
|
||||
create function trg_msgcount_increment() returns trigger as $$
|
||||
begin
|
||||
update members set message_count = message_count + 1 where id = NEW.member;
|
||||
return NEW;
|
||||
end;
|
||||
$$ language plpgsql;
|
||||
|
||||
create trigger increment_member_message_count before insert on messages for each row execute procedure trg_msgcount_increment();
|
||||
|
||||
|
||||
-- Create a trigger function to decrement the message count on deleting from the messages table
|
||||
create function trg_msgcount_decrement() returns trigger as $$
|
||||
begin
|
||||
-- Don't decrement if count <= zero (shouldn't happen, but we don't want negative message counts)
|
||||
update members set message_count = message_count - 1 where id = OLD.member and message_count > 0;
|
||||
return OLD;
|
||||
end;
|
||||
$$ language plpgsql;
|
||||
|
||||
create trigger decrement_member_message_count before delete on messages for each row execute procedure trg_msgcount_decrement();
|
||||
|
||||
|
||||
-- (update schema ver)
|
||||
update info set schema_version = 7;
|
||||
26
PluralKit.Core/Database/ProxyMember.cs
Normal file
26
PluralKit.Core/Database/ProxyMember.cs
Normal file
@@ -0,0 +1,26 @@
|
||||
#nullable enable
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace PluralKit.Core
|
||||
{
|
||||
/// <summary>
|
||||
/// Model for the `proxy_info` PL/pgSQL function in `functions.sql`
|
||||
/// </summary>
|
||||
public class ProxyMember
|
||||
{
|
||||
public int SystemId { get; set; }
|
||||
public int MemberId { get; set; }
|
||||
public bool ProxyEnabled { get; set; }
|
||||
public AutoproxyMode AutoproxyMode { get; set; }
|
||||
public bool IsAutoproxyMember { get; set; }
|
||||
public ulong? LatchMessage { get; set; }
|
||||
public string ProxyName { get; set; } = "";
|
||||
public string? ProxyAvatar { get; set; }
|
||||
public IReadOnlyCollection<ProxyTag> ProxyTags { get; set; } = new ProxyTag[0];
|
||||
public bool KeepProxy { get; set; }
|
||||
|
||||
public IReadOnlyCollection<ulong> ChannelBlacklist { get; set; } = new ulong[0];
|
||||
public IReadOnlyCollection<ulong> LogBlacklist { get; set; } = new ulong[0];
|
||||
public ulong? LogChannel { get; set; }
|
||||
}
|
||||
}
|
||||
99
PluralKit.Core/Database/Schemas.cs
Normal file
99
PluralKit.Core/Database/Schemas.cs
Normal file
@@ -0,0 +1,99 @@
|
||||
using System;
|
||||
using System.Data;
|
||||
using System.IO;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using Dapper;
|
||||
|
||||
using Npgsql;
|
||||
|
||||
using Serilog;
|
||||
|
||||
namespace PluralKit.Core
|
||||
{
|
||||
public class Schemas
|
||||
{
|
||||
private const string RootPath = "PluralKit.Core.Database"; // "resource path" root for SQL files
|
||||
private const int TargetSchemaVersion = 7;
|
||||
|
||||
private DbConnectionFactory _conn;
|
||||
private ILogger _logger;
|
||||
|
||||
public Schemas(DbConnectionFactory conn, ILogger logger)
|
||||
{
|
||||
_conn = conn;
|
||||
_logger = logger.ForContext<Schemas>();
|
||||
}
|
||||
|
||||
public static void Initialize()
|
||||
{
|
||||
// Without these it'll still *work* but break at the first launch + probably cause other small issues
|
||||
NpgsqlConnection.GlobalTypeMapper.MapComposite<ProxyTag>("proxy_tag");
|
||||
NpgsqlConnection.GlobalTypeMapper.MapEnum<PrivacyLevel>("privacy_level");
|
||||
}
|
||||
|
||||
public async Task InitializeDatabase()
|
||||
{
|
||||
// Run everything in a transaction
|
||||
await using var conn = await _conn.Obtain();
|
||||
using var tx = conn.BeginTransaction();
|
||||
|
||||
// Before applying migrations, clean out views/functions to prevent type errors
|
||||
await ExecuteSqlFile($"{RootPath}.clean.sql", conn, tx);
|
||||
|
||||
// Apply all migrations between the current database version and the target version
|
||||
await ApplyMigrations(conn, tx);
|
||||
|
||||
// Now, reapply views/functions (we deleted them above, no need to worry about conflicts)
|
||||
await ExecuteSqlFile($"{RootPath}.views.sql", conn, tx);
|
||||
await ExecuteSqlFile($"{RootPath}.functions.sql", conn, tx);
|
||||
|
||||
// Finally, commit tx
|
||||
tx.Commit();
|
||||
}
|
||||
|
||||
private async Task ApplyMigrations(IAsyncDbConnection conn, IDbTransaction tx)
|
||||
{
|
||||
var currentVersion = await GetCurrentDatabaseVersion(conn);
|
||||
_logger.Information("Current schema version: {CurrentVersion}", currentVersion);
|
||||
for (var migration = currentVersion + 1; migration <= TargetSchemaVersion; migration++)
|
||||
{
|
||||
_logger.Information("Applying schema migration {MigrationId}", migration);
|
||||
await ExecuteSqlFile($"{RootPath}.Migrations.{migration}.sql", conn, tx);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task ExecuteSqlFile(string resourceName, IDbConnection conn, IDbTransaction tx = null)
|
||||
{
|
||||
await using var stream = typeof(Schemas).Assembly.GetManifestResourceStream(resourceName);
|
||||
if (stream == null) throw new ArgumentException($"Invalid resource name '{resourceName}'");
|
||||
|
||||
using var reader = new StreamReader(stream);
|
||||
var query = await reader.ReadToEndAsync();
|
||||
|
||||
await conn.ExecuteAsync(query, transaction: tx);
|
||||
|
||||
// If the above creates new enum/composite types, we must tell Npgsql to reload the internal type caches
|
||||
// This will propagate to every other connection as well, since it marks the global type mapper collection dirty.
|
||||
// TODO: find a way to get around the cast to our internal tracker wrapper... this could break if that ever changes
|
||||
((PerformanceTrackingConnection) conn)._impl.ReloadTypes();
|
||||
}
|
||||
|
||||
private async Task<int> GetCurrentDatabaseVersion(IDbConnection conn)
|
||||
{
|
||||
// First, check if the "info" table exists (it may not, if this is a *really* old database)
|
||||
var hasInfoTable =
|
||||
await conn.QuerySingleOrDefaultAsync<int>(
|
||||
"select count(*) from information_schema.tables where table_name = 'info'") == 1;
|
||||
|
||||
// If we have the table, read the schema version
|
||||
if (hasInfoTable)
|
||||
return await conn.QuerySingleOrDefaultAsync<int>("select schema_version from info");
|
||||
|
||||
// If not, we return version "-1"
|
||||
// This means migration 0 will get executed, getting us into a consistent state
|
||||
// Then, migration 1 gets executed, which creates the info table and sets version to 1
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
3
PluralKit.Core/Database/clean.sql
Normal file
3
PluralKit.Core/Database/clean.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
drop view if exists system_last_switch;
|
||||
drop view if exists member_list;
|
||||
drop function if exists proxy_info;
|
||||
85
PluralKit.Core/Database/functions.sql
Normal file
85
PluralKit.Core/Database/functions.sql
Normal file
@@ -0,0 +1,85 @@
|
||||
-- Giant "mega-function" to find all information relevant for message proxying
|
||||
-- Returns one row per member, computes several properties from others
|
||||
create function proxy_info(account_id bigint, guild_id bigint)
|
||||
returns table
|
||||
(
|
||||
-- Note: table type gets matched *by index*, not *by name* (make sure order here and in `select` match)
|
||||
system_id int, -- from: systems.id
|
||||
member_id int, -- from: members.id
|
||||
proxy_tags proxy_tag[], -- from: members.proxy_tags
|
||||
keep_proxy bool, -- from: members.keep_proxy
|
||||
proxy_enabled bool, -- from: system_guild.proxy_enabled
|
||||
proxy_name text, -- calculated: name we should proxy under
|
||||
proxy_avatar text, -- calculated: avatar we should proxy with
|
||||
autoproxy_mode int, -- from: system_guild.autoproxy_mode
|
||||
is_autoproxy_member bool, -- calculated: should this member be used for AP?
|
||||
latch_message bigint, -- calculated: last message from this account in this guild
|
||||
channel_blacklist bigint[], -- from: servers.blacklist
|
||||
log_blacklist bigint[], -- from: servers.log_blacklist
|
||||
log_channel bigint -- from: servers.log_channel
|
||||
)
|
||||
as
|
||||
$$
|
||||
select
|
||||
-- Basic data
|
||||
systems.id as system_id,
|
||||
members.id as member_id,
|
||||
members.proxy_tags as proxy_tags,
|
||||
members.keep_proxy as keep_proxy,
|
||||
|
||||
-- Proxy info
|
||||
coalesce(system_guild.proxy_enabled, true) as proxy_enabled,
|
||||
case
|
||||
when systems.tag is not null then (coalesce(member_guild.display_name, members.display_name, members.name) || ' ' || systems.tag)
|
||||
else coalesce(member_guild.display_name, members.display_name, members.name)
|
||||
end as proxy_name,
|
||||
coalesce(member_guild.avatar_url, members.avatar_url, systems.avatar_url) as proxy_avatar,
|
||||
|
||||
-- Autoproxy data
|
||||
coalesce(system_guild.autoproxy_mode, 1) as autoproxy_mode,
|
||||
|
||||
-- Autoproxy logic is essentially: "is this member the one we should autoproxy?"
|
||||
case
|
||||
-- Front mode: check if this is the first fronter
|
||||
when system_guild.autoproxy_mode = 2 then members.id = (select sls.members[1]
|
||||
from system_last_switch as sls
|
||||
where sls.system = systems.id)
|
||||
|
||||
-- Latch mode: check if this is the last proxier
|
||||
when system_guild.autoproxy_mode = 3 then members.id = last_message_in_guild.member
|
||||
|
||||
-- Member mode: check if this is the selected memebr
|
||||
when system_guild.autoproxy_mode = 4 then members.id = system_guild.autoproxy_member
|
||||
|
||||
-- no autoproxy: then this member definitely shouldn't be autoproxied :)
|
||||
else false end as is_autoproxy_member,
|
||||
|
||||
last_message_in_guild.mid as latch_message,
|
||||
|
||||
-- Guild info
|
||||
coalesce(servers.blacklist, array[]::bigint[]) as channel_blacklist,
|
||||
coalesce(servers.log_blacklist, array[]::bigint[]) as log_blacklist,
|
||||
servers.log_channel as log_channel
|
||||
from accounts
|
||||
-- Fetch guild info
|
||||
left join servers on servers.id = guild_id
|
||||
|
||||
-- Fetch the system for this account (w/ guild config)
|
||||
inner join systems on systems.id = accounts.system
|
||||
left join system_guild on system_guild.system = accounts.system and system_guild.guild = guild_id
|
||||
|
||||
-- Fetch all members from this system (w/ guild config)
|
||||
inner join members on members.system = systems.id
|
||||
left join member_guild on member_guild.member = members.id and member_guild.guild = guild_id
|
||||
|
||||
-- Find ID and member for the last message sent in this guild
|
||||
left join lateral (select mid, member
|
||||
from messages
|
||||
where messages.guild = guild_id
|
||||
and messages.sender = account_id
|
||||
and system_guild.autoproxy_mode = 3
|
||||
order by mid desc
|
||||
limit 1) as last_message_in_guild on true
|
||||
where accounts.uid = account_id;
|
||||
$$ language sql stable
|
||||
rows 10;
|
||||
29
PluralKit.Core/Database/views.sql
Normal file
29
PluralKit.Core/Database/views.sql
Normal file
@@ -0,0 +1,29 @@
|
||||
create view system_last_switch as
|
||||
select systems.id as system,
|
||||
last_switch.id as switch,
|
||||
last_switch.timestamp as timestamp,
|
||||
array(select member from switch_members where switch_members.switch = last_switch.id) as members
|
||||
from systems
|
||||
inner join lateral (select * from switches where switches.system = systems.id order by timestamp desc limit 1) as last_switch on true;
|
||||
|
||||
create view member_list as
|
||||
select members.*,
|
||||
-- Find last message ID
|
||||
(select max(messages.mid) from messages where messages.member = members.id) as last_message,
|
||||
|
||||
-- Find last switch timestamp
|
||||
(
|
||||
select max(switches.timestamp)
|
||||
from switch_members
|
||||
inner join switches on switches.id = switch_members.switch
|
||||
where switch_members.member = members.id
|
||||
) as last_switch_time,
|
||||
|
||||
-- Extract month/day from birthday and "force" the year identical (just using 4) -> month/day only sorting!
|
||||
case when members.birthday is not null then
|
||||
make_date(
|
||||
4,
|
||||
extract(month from members.birthday)::integer,
|
||||
extract(day from members.birthday)::integer
|
||||
) end as birthday_md
|
||||
from members;
|
||||
@@ -22,6 +22,7 @@ namespace PluralKit.Core {
|
||||
[JsonProperty("proxy_tags")] public ICollection<ProxyTag> ProxyTags { get; set; }
|
||||
[JsonProperty("keep_proxy")] public bool KeepProxy { get; set; }
|
||||
[JsonProperty("created")] public Instant Created { get; set; }
|
||||
[JsonProperty("message_count")] public int MessageCount { get; set; }
|
||||
|
||||
public PrivacyLevel MemberPrivacy { get; set; }
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ namespace PluralKit.Core
|
||||
builder.RegisterType<DbConnectionCountHolder>().SingleInstance();
|
||||
builder.RegisterType<DbConnectionFactory>().AsSelf().SingleInstance();
|
||||
builder.RegisterType<PostgresDataStore>().AsSelf().As<IDataStore>();
|
||||
builder.RegisterType<SchemaService>().AsSelf();
|
||||
builder.RegisterType<Schemas>().AsSelf();
|
||||
|
||||
builder.Populate(new ServiceCollection().AddMemoryCache());
|
||||
builder.RegisterType<ProxyCache>().AsSelf().SingleInstance();
|
||||
|
||||
@@ -33,7 +33,6 @@
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<EmbeddedResource Include="Migrations\*.sql" />
|
||||
<EmbeddedResource Include="Database/**/*.sql" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
<wpf:ResourceDictionary xml:space="preserve" xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:s="clr-namespace:System;assembly=mscorlib" xmlns:ss="urn:shemas-jetbrains-com:settings-storage-xaml" xmlns:wpf="http://schemas.microsoft.com/winfx/2006/xaml/presentation">
|
||||
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=migrations/@EntryIndexedValue">True</s:Boolean>
|
||||
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=models/@EntryIndexedValue">True</s:Boolean>
|
||||
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=schema/@EntryIndexedValue">True</s:Boolean>
|
||||
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=services/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary>
|
||||
@@ -30,7 +30,6 @@ namespace PluralKit.Core
|
||||
// Export members
|
||||
var members = new List<DataFileMember>();
|
||||
var pkMembers = _data.GetSystemMembers(system); // Read all members in the system
|
||||
var messageCounts = await _data.GetMemberMessageCountBulk(system); // Count messages proxied by all members in the system
|
||||
|
||||
await foreach (var member in pkMembers.Select(m => new DataFileMember
|
||||
{
|
||||
@@ -45,7 +44,7 @@ namespace PluralKit.Core
|
||||
ProxyTags = m.ProxyTags,
|
||||
KeepProxy = m.KeepProxy,
|
||||
Created = DateTimeFormats.TimestampExportFormat.Format(m.Created),
|
||||
MessageCount = messageCounts.Where(x => x.Member == m.Id).Select(x => x.MessageCount).FirstOrDefault()
|
||||
MessageCount = m.MessageCount
|
||||
})) members.Add(member);
|
||||
|
||||
// Export switches
|
||||
|
||||
@@ -41,12 +41,6 @@ namespace PluralKit.Core {
|
||||
public Instant TimespanEnd;
|
||||
}
|
||||
|
||||
public struct MemberMessageCount
|
||||
{
|
||||
public int Member;
|
||||
public int MessageCount;
|
||||
}
|
||||
|
||||
public struct FrontBreakdown
|
||||
{
|
||||
public Dictionary<PKMember, Duration> MemberSwitchDurations;
|
||||
@@ -208,18 +202,7 @@ namespace PluralKit.Core {
|
||||
/// </summary>
|
||||
/// <returns>An enumerable of <see cref="PKMember"/> structs representing each member in the system, in no particular order.</returns>
|
||||
IAsyncEnumerable<PKMember> GetSystemMembers(PKSystem system, bool orderByName = false);
|
||||
/// <summary>
|
||||
/// Gets the amount of messages proxied by a given member.
|
||||
/// </summary>
|
||||
/// <returns>The message count of the given member.</returns>
|
||||
Task<ulong> GetMemberMessageCount(PKMember member);
|
||||
|
||||
/// <summary>
|
||||
/// Collects a breakdown of each member in a system's message count.
|
||||
/// </summary>
|
||||
/// <returns>An enumerable of members along with their message counts.</returns>
|
||||
Task<IEnumerable<MemberMessageCount>> GetMemberMessageCountBulk(PKSystem system);
|
||||
|
||||
|
||||
/// <summary>
|
||||
/// Creates a member, auto-generating its corresponding IDs.
|
||||
/// </summary>
|
||||
@@ -267,9 +250,9 @@ namespace PluralKit.Core {
|
||||
/// <param name="channelId">The ID of the channel the message was posted to.</param>
|
||||
/// <param name="postedMessageId">The ID of the message posted by the webhook.</param>
|
||||
/// <param name="triggerMessageId">The ID of the original trigger message containing the proxy tags.</param>
|
||||
/// <param name="proxiedMember">The member (and by extension system) that was proxied.</param>
|
||||
/// <param name="proxiedMemberId">The member (and by extension system) that was proxied.</param>
|
||||
/// <returns></returns>
|
||||
Task AddMessage(ulong senderAccount, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, PKMember proxiedMember);
|
||||
Task AddMessage(ulong senderAccount, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, int proxiedMemberId);
|
||||
|
||||
/// <summary>
|
||||
/// Deletes a message from the data store.
|
||||
|
||||
@@ -231,25 +231,6 @@ namespace PluralKit.Core {
|
||||
await _cache.InvalidateSystem(member.System);
|
||||
}
|
||||
|
||||
public async Task<ulong> GetMemberMessageCount(PKMember member)
|
||||
{
|
||||
using (var conn = await _conn.Obtain())
|
||||
return await conn.QuerySingleAsync<ulong>("select count(*) from messages where member = @Id", member);
|
||||
}
|
||||
|
||||
public async Task<IEnumerable<MemberMessageCount>> GetMemberMessageCountBulk(PKSystem system)
|
||||
{
|
||||
using (var conn = await _conn.Obtain())
|
||||
return await conn.QueryAsync<MemberMessageCount>(
|
||||
@"SELECT messages.member, COUNT(messages.member) messagecount
|
||||
FROM members
|
||||
JOIN messages
|
||||
ON members.id = messages.member
|
||||
WHERE members.system = @System
|
||||
GROUP BY messages.member",
|
||||
new { System = system.Id });
|
||||
}
|
||||
|
||||
public async Task<int> GetSystemMemberCount(PKSystem system, bool includePrivate)
|
||||
{
|
||||
var query = "select count(*) from members where system = @Id";
|
||||
@@ -264,19 +245,19 @@ namespace PluralKit.Core {
|
||||
using (var conn = await _conn.Obtain())
|
||||
return await conn.ExecuteScalarAsync<ulong>("select count(id) from members");
|
||||
}
|
||||
public async Task AddMessage(ulong senderId, ulong messageId, ulong guildId, ulong channelId, ulong originalMessage, PKMember member) {
|
||||
public async Task AddMessage(ulong senderId, ulong guildId, ulong channelId, ulong postedMessageId, ulong triggerMessageId, int proxiedMemberId) {
|
||||
using (var conn = await _conn.Obtain())
|
||||
// "on conflict do nothing" in the (pretty rare) case of duplicate events coming in from Discord, which would lead to a DB error before
|
||||
await conn.ExecuteAsync("insert into messages(mid, guild, channel, member, sender, original_mid) values(@MessageId, @GuildId, @ChannelId, @MemberId, @SenderId, @OriginalMid) on conflict do nothing", new {
|
||||
MessageId = messageId,
|
||||
MessageId = postedMessageId,
|
||||
GuildId = guildId,
|
||||
ChannelId = channelId,
|
||||
MemberId = member.Id,
|
||||
MemberId = proxiedMemberId,
|
||||
SenderId = senderId,
|
||||
OriginalMid = originalMessage
|
||||
OriginalMid = triggerMessageId
|
||||
});
|
||||
|
||||
_logger.Information("Stored message {Message} in channel {Channel}", messageId, channelId);
|
||||
_logger.Debug("Stored message {Message} in channel {Channel}", postedMessageId, channelId);
|
||||
}
|
||||
|
||||
public async Task<FullMessage> GetMessage(ulong id)
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
using System;
|
||||
using System.IO;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using Dapper;
|
||||
|
||||
using Npgsql;
|
||||
|
||||
using Serilog;
|
||||
|
||||
namespace PluralKit.Core {
|
||||
public class SchemaService
|
||||
{
|
||||
private const int TargetSchemaVersion = 6;
|
||||
|
||||
private DbConnectionFactory _conn;
|
||||
private ILogger _logger;
|
||||
|
||||
public SchemaService(DbConnectionFactory conn, ILogger logger)
|
||||
{
|
||||
_conn = conn;
|
||||
_logger = logger.ForContext<SchemaService>();
|
||||
}
|
||||
|
||||
public static void Initialize()
|
||||
{
|
||||
// Without these it'll still *work* but break at the first launch + probably cause other small issues
|
||||
NpgsqlConnection.GlobalTypeMapper.MapComposite<ProxyTag>("proxy_tag");
|
||||
NpgsqlConnection.GlobalTypeMapper.MapEnum<PrivacyLevel>("privacy_level");
|
||||
}
|
||||
|
||||
public async Task ApplyMigrations()
|
||||
{
|
||||
for (var version = 0; version <= TargetSchemaVersion; version++)
|
||||
await ApplyMigration(version);
|
||||
}
|
||||
|
||||
private async Task ApplyMigration(int migrationId)
|
||||
{
|
||||
// migrationId is the *target* version
|
||||
using var conn = await _conn.Obtain();
|
||||
using var tx = conn.BeginTransaction();
|
||||
|
||||
// See if we even have the info table... if not, we implicitly define the version as -1
|
||||
// This means migration 0 will get executed, which ensures we're at a consistent state.
|
||||
// *Technically* this also means schema version 0 will be identified as -1, but since we're only doing these
|
||||
// checks in the above for loop, this doesn't matter.
|
||||
var hasInfoTable = await conn.QuerySingleOrDefaultAsync<int>("select count(*) from information_schema.tables where table_name = 'info'") == 1;
|
||||
|
||||
int currentVersion;
|
||||
if (hasInfoTable)
|
||||
currentVersion = await conn.QuerySingleOrDefaultAsync<int>("select schema_version from info");
|
||||
else currentVersion = -1;
|
||||
|
||||
if (currentVersion >= migrationId)
|
||||
return; // Don't execute the migration if we're already at the target version.
|
||||
|
||||
using var stream = typeof(SchemaService).Assembly.GetManifestResourceStream($"PluralKit.Core.Migrations.{migrationId}.sql");
|
||||
if (stream == null) throw new ArgumentException("Invalid migration ID");
|
||||
|
||||
using var reader = new StreamReader(stream);
|
||||
var migrationQuery = await reader.ReadToEndAsync();
|
||||
|
||||
_logger.Information("Current schema version is {CurrentVersion}, applying migration {MigrationId}", currentVersion, migrationId);
|
||||
await conn.ExecuteAsync(migrationQuery, transaction: tx);
|
||||
tx.Commit();
|
||||
|
||||
// If the above migration creates new enum/composite types, we must tell Npgsql to reload the internal type caches
|
||||
// This will propagate to every other connection as well, since it marks the global type mapper collection dirty.
|
||||
// TODO: find a way to get around the cast to our internal tracker wrapper... this could break if that ever changes
|
||||
((PerformanceTrackingConnection) conn)._impl.ReloadTypes();
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user