refactor: add SqlKata for SQL generation, move connection handling into ModelRepository
This commit is contained in:
@@ -2,6 +2,7 @@ using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Data;
|
||||
using System.IO;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
using App.Metrics;
|
||||
@@ -14,6 +15,9 @@ using Npgsql;
|
||||
|
||||
using Serilog;
|
||||
|
||||
using SqlKata;
|
||||
using SqlKata.Compilers;
|
||||
|
||||
namespace PluralKit.Core
|
||||
{
|
||||
internal class Database: IDatabase
|
||||
@@ -46,6 +50,8 @@ namespace PluralKit.Core
|
||||
}.ConnectionString;
|
||||
}
|
||||
|
||||
private static readonly PostgresCompiler _compiler = new();
|
||||
|
||||
public static void InitStatic()
|
||||
{
|
||||
DefaultTypeMap.MatchNamesWithUnderscores = true;
|
||||
@@ -151,28 +157,83 @@ namespace PluralKit.Core
|
||||
|
||||
public override T[] Parse(object value) => Array.ConvertAll((TInner[])value, v => _factory(v));
|
||||
}
|
||||
}
|
||||
|
||||
public static class DatabaseExt
|
||||
{
|
||||
public static async Task Execute(this IDatabase db, Func<IPKConnection, Task> func)
|
||||
public async Task Execute(Func<IPKConnection, Task> func)
|
||||
{
|
||||
await using var conn = await db.Obtain();
|
||||
await using var conn = await Obtain();
|
||||
await func(conn);
|
||||
}
|
||||
|
||||
public static async Task<T> Execute<T>(this IDatabase db, Func<IPKConnection, Task<T>> func)
|
||||
public async Task<T> Execute<T>(Func<IPKConnection, Task<T>> func)
|
||||
{
|
||||
await using var conn = await db.Obtain();
|
||||
await using var conn = await Obtain();
|
||||
return await func(conn);
|
||||
}
|
||||
|
||||
public static async IAsyncEnumerable<T> Execute<T>(this IDatabase db, Func<IPKConnection, IAsyncEnumerable<T>> func)
|
||||
public async IAsyncEnumerable<T> Execute<T>(Func<IPKConnection, IAsyncEnumerable<T>> func)
|
||||
{
|
||||
await using var conn = await db.Obtain();
|
||||
await using var conn = await Obtain();
|
||||
|
||||
await foreach (var val in func(conn))
|
||||
yield return val;
|
||||
}
|
||||
|
||||
public async Task<int> ExecuteQuery(Query q, string extraSql = "", [CallerMemberName] string queryName = "")
|
||||
{
|
||||
var query = _compiler.Compile(q);
|
||||
using var conn = await Obtain();
|
||||
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
|
||||
return await conn.ExecuteAsync(query.Sql + $" {extraSql}", query.NamedBindings);
|
||||
}
|
||||
|
||||
public async Task<T> QueryFirst<T>(Query q, string extraSql = "", [CallerMemberName] string queryName = "")
|
||||
{
|
||||
var query = _compiler.Compile(q);
|
||||
using var conn = await Obtain();
|
||||
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
|
||||
return await conn.QueryFirstOrDefaultAsync<T>(query.Sql + $" {extraSql}", query.NamedBindings);
|
||||
}
|
||||
|
||||
public async Task<T> QueryFirst<T>(IPKConnection? conn, Query q, string extraSql = "", [CallerMemberName] string queryName = "")
|
||||
{
|
||||
if (conn == null)
|
||||
return await QueryFirst<T>(q, extraSql, queryName);
|
||||
|
||||
var query = _compiler.Compile(q);
|
||||
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
|
||||
return await conn.QueryFirstOrDefaultAsync<T>(query.Sql + $" {extraSql}", query.NamedBindings);
|
||||
}
|
||||
|
||||
public async Task<IEnumerable<T>> Query<T>(Query q, [CallerMemberName] string queryName = "")
|
||||
{
|
||||
var query = _compiler.Compile(q);
|
||||
using var conn = await Obtain();
|
||||
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
|
||||
return await conn.QueryAsync<T>(query.Sql, query.NamedBindings);
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<T> QueryStream<T>(Query q, [CallerMemberName] string queryName = "")
|
||||
{
|
||||
var query = _compiler.Compile(q);
|
||||
using var conn = await Obtain();
|
||||
using (_metrics.Measure.Timer.Time(CoreMetrics.DatabaseQuery, new MetricTags("Query", queryName)))
|
||||
await foreach (var val in conn.QueryStreamAsync<T>(query.Sql, query.NamedBindings))
|
||||
yield return val;
|
||||
}
|
||||
|
||||
// the procedures (message_context and proxy_members, as of writing) have their own metrics tracking elsewhere
|
||||
// still, including them here for consistency
|
||||
|
||||
public async Task<T> QuerySingleProcedure<T>(string queryName, object param)
|
||||
{
|
||||
using var conn = await Obtain();
|
||||
return await conn.QueryFirstAsync<T>(queryName, param, commandType: CommandType.StoredProcedure);
|
||||
}
|
||||
|
||||
public async Task<IEnumerable<T>> QueryProcedure<T>(string queryName, object param)
|
||||
{
|
||||
using var conn = await Obtain();
|
||||
return await conn.QueryAsync<T>(queryName, param, commandType: CommandType.StoredProcedure);
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user