From b025816abfe2b7473821d6b1c82c7ae3f84fb1dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Luthi?= Date: Wed, 27 Nov 2024 11:32:15 +0100 Subject: [PATCH] Infer the DbAdapter from the DbConnection --- README.md | 2 +- Respawn.DatabaseTests/InformixTests.cs | 8 ---- Respawn.DatabaseTests/MySqlTests.cs | 10 ----- Respawn.DatabaseTests/OracleTests.cs | 13 ------ Respawn.DatabaseTests/PostgresTests.cs | 28 ++----------- Respawn/Respawner.cs | 57 +++++++++++++++++++------- Respawn/RespawnerOptions.cs | 3 +- 7 files changed, 49 insertions(+), 72 deletions(-) diff --git a/README.md b/README.md index 071fb5d..879b62d 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ var respawner = await Respawner.CreateAsync(connection, new RespawnerOptions { "public" }, - DbAdapter = DbAdapter.Postgres + DbAdapter = DbAdapter.Postgres // 👈 optional, inferred from the connection for SQL Server, PostgreSQL, MySQL, Oracle and Informix }); ``` diff --git a/Respawn.DatabaseTests/InformixTests.cs b/Respawn.DatabaseTests/InformixTests.cs index 70127e1..a37a482 100644 --- a/Respawn.DatabaseTests/InformixTests.cs +++ b/Respawn.DatabaseTests/InformixTests.cs @@ -73,7 +73,6 @@ public async Task ShouldDeleteData() var checkPoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Informix, SchemasToInclude = new[] { "informix" } }); await checkPoint.ResetAsync(_connection); @@ -99,7 +98,6 @@ public async Task ShouldIgnoreTables() } var checkPoint = await Respawner.CreateAsync(_connection, new RespawnerOptions() { - DbAdapter = DbAdapter.Informix, SchemasToInclude = new[] { "informix" }, TablesToIgnore = new Table[] { "foo" } }); @@ -140,7 +138,6 @@ FOREIGN KEY (FooValue) REFERENCES Foo(Value) var checkPoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Informix, SchemasToInclude = new[] { "informix" } }); try @@ -203,7 +200,6 @@ ParentId INT NULL var checkPoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Informix, SchemasToInclude = new[] { "informix" } }); try @@ -251,7 +247,6 @@ ParentId INT NULL var checkPoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Informix, SchemasToInclude = new[] { "informix" } }); try @@ -331,7 +326,6 @@ public async Task ShouldHandleComplexCycles() var checkPoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Informix, SchemasToInclude = new[] { "informix" } }); try @@ -383,7 +377,6 @@ public async Task ShouldExcludeSchemas() var checkPoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Informix, SchemasToExclude = new[] { user_1 } }); try @@ -427,7 +420,6 @@ public async Task ShouldIncludeSchemas() var checkPoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Informix, SchemasToInclude = new[] { user_2 } }); try diff --git a/Respawn.DatabaseTests/MySqlTests.cs b/Respawn.DatabaseTests/MySqlTests.cs index 7407d38..e0fbc8b 100644 --- a/Respawn.DatabaseTests/MySqlTests.cs +++ b/Respawn.DatabaseTests/MySqlTests.cs @@ -58,7 +58,6 @@ public async Task ShouldDeleteData() var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.MySql, SchemasToInclude = new[] { "MySqlTests" } }); await checkpoint.ResetAsync(_connection); @@ -116,7 +115,6 @@ PRIMARY KEY (`BarValue`), var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.MySql, SchemasToInclude = new[] { "MySqlTests" } }); await checkpoint.ResetAsync(_connection); @@ -142,7 +140,6 @@ public async Task ShouldHandleSelfRelationships() var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.MySql, SchemasToInclude = new[] { "MySqlTests" } }); try @@ -181,7 +178,6 @@ public async Task ShouldHandleCircularRelationships() var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.MySql, SchemasToInclude = new[] { "MySqlTests" } }); await checkpoint.ResetAsync(_connection); @@ -226,7 +222,6 @@ public async Task ShouldHandleComplexCycles() var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.MySql, SchemasToInclude = new[] { "MySqlTests" } }); try @@ -260,7 +255,6 @@ public async Task ShouldIgnoreTables() var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.MySql, TablesToIgnore = new Table[] { "Foo" }, SchemasToInclude = new[] { "MySqlTests" } }); @@ -283,7 +277,6 @@ public async Task ShouldIncludeTables() var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.MySql, TablesToInclude = new Table[] { "Foo" }, SchemasToInclude = new[] { "MySqlTests" } }); @@ -313,7 +306,6 @@ public async Task ShouldExcludeSchemas() var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.MySql, SchemasToExclude = new[] { "A", "MySqlTests" } }); await checkpoint.ResetAsync(_connection); @@ -342,7 +334,6 @@ public async Task ShouldIncludeSchemas() var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.MySql, SchemasToInclude = new[] { "B" } }); await checkpoint.ResetAsync(_connection); @@ -361,7 +352,6 @@ public async Task ShouldResetSequencesAndIdentities() var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.MySql, WithReseed = true }); diff --git a/Respawn.DatabaseTests/OracleTests.cs b/Respawn.DatabaseTests/OracleTests.cs index 30e081a..eabde91 100644 --- a/Respawn.DatabaseTests/OracleTests.cs +++ b/Respawn.DatabaseTests/OracleTests.cs @@ -57,7 +57,6 @@ public async Task ShouldDeleteData() var respawner = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Oracle, SchemasToInclude = new[] { _createdUser } }); try @@ -87,7 +86,6 @@ public async Task ShouldDeleteMultipleTables() var respawner = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Oracle, SchemasToInclude = new[] { _createdUser }, }); await respawner.ResetAsync(_connection); @@ -113,7 +111,6 @@ public async Task ShouldHandleRelationships() var respawner = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Oracle, SchemasToInclude = new[] { _createdUser }, }); try @@ -147,7 +144,6 @@ public async Task ShouldHandleRelationshipsWithTableNames() var respawner = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Oracle, SchemasToInclude = new[] { _createdUser }, TablesToInclude = new[] { new Table(_createdUser, "foo"), new Table(_createdUser, "baz") }, TablesToIgnore = new[] { new Table(_createdUser, "bar") } @@ -193,7 +189,6 @@ public async Task ShouldHandleRelationshipsWithNamedPrimaryKeyConstraint() var respawner = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Oracle, SchemasToInclude = new[] { userA }, }); try @@ -251,7 +246,6 @@ public async Task ShouldHandleComplexCycles() var respawner = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Oracle, SchemasToInclude = new[] { _createdUser }, }); try @@ -294,7 +288,6 @@ public async Task ShouldHandleCircularRelationships() var respawner = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Oracle, SchemasToInclude = new[] { _createdUser }, }); try @@ -325,7 +318,6 @@ public async Task ShouldIgnoreTables() var respawner = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Oracle, SchemasToInclude = new[] { _createdUser }, TablesToIgnore = new[] { new Table("foo") } }); @@ -349,7 +341,6 @@ public async Task ShouldIgnoreTablesWithSchema() var respawner = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Oracle, SchemasToInclude = new[] { _createdUser }, TablesToIgnore = new[] { new Table(_createdUser, "foo") } }); @@ -373,7 +364,6 @@ public async Task ShouldIncludeTables() var respawner = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Oracle, SchemasToInclude = new[] { _createdUser }, TablesToInclude = new[] { new Table("foo") } }); @@ -397,7 +387,6 @@ public async Task ShouldIncludeTablesWithSchema() var respawner = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Oracle, TablesToInclude = new[] { new Table(_createdUser, "foo") } }); await respawner.ResetAsync(_connection); @@ -425,7 +414,6 @@ public async Task ShouldExcludeSchemas() var respawner = await Respawner.CreateAsync(_connection, new RespawnerOptions { // We must make sure we don't delete all these users that are used by Oracle - DbAdapter = DbAdapter.Oracle, SchemasToExclude = new[] { userA, "ANONYMOUS", "APEX_040000", "APEX_PUBLIC_USER", "APPQOSSYS", @@ -463,7 +451,6 @@ public async Task ShouldIncludeSchemas() var respawner = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Oracle, SchemasToInclude = new[] { userB } }); await respawner.ResetAsync(_connection); diff --git a/Respawn.DatabaseTests/PostgresTests.cs b/Respawn.DatabaseTests/PostgresTests.cs index be8e26d..b2859ae 100644 --- a/Respawn.DatabaseTests/PostgresTests.cs +++ b/Respawn.DatabaseTests/PostgresTests.cs @@ -65,10 +65,7 @@ public async Task ShouldDeleteData() _database.ExecuteScalar("SELECT COUNT(1) FROM \"foo\"").ShouldBe(100); - var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions - { - DbAdapter = DbAdapter.Postgres - }); + var checkpoint = await Respawner.CreateAsync(_connection); await checkpoint.ResetAsync(_connection); _database.ExecuteScalar("SELECT COUNT(1) FROM \"foo\"").ShouldBe(0); @@ -88,7 +85,6 @@ public async Task ShouldIgnoreTables() var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Postgres, TablesToIgnore = new Table[] { "foo" } }); await checkpoint.ResetAsync(_connection); @@ -112,7 +108,6 @@ public async Task ShouldIgnoreTablesIfSchemaSpecified() var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Postgres, TablesToIgnore = new Table[] { new Table("eggs", "foo") } }); await checkpoint.ResetAsync(_connection); @@ -135,7 +130,6 @@ public async Task ShouldIncludeTables() var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Postgres, TablesToInclude = new Table[] { "foo" } }); await checkpoint.ResetAsync(_connection); @@ -159,7 +153,6 @@ public async Task ShouldIncludeTablesIfSchemaSpecified() var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Postgres, TablesToInclude = new Table[] { new Table("eggs", "foo") } }); await checkpoint.ResetAsync(_connection); @@ -185,7 +178,6 @@ public async Task ShouldHandleRelationships() var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Postgres, SchemasToInclude = new [] { "public" } }); try @@ -222,10 +214,7 @@ public async Task ShouldHandleCircularRelationships() _database.ExecuteScalar("SELECT COUNT(1) FROM parent").ShouldBe(100); _database.ExecuteScalar("SELECT COUNT(1) FROM child").ShouldBe(100); - var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions - { - DbAdapter = DbAdapter.Postgres - }); + var checkpoint = await Respawner.CreateAsync(_connection); try { await checkpoint.ResetAsync(_connection); @@ -254,10 +243,7 @@ public async Task ShouldHandleSelfRelationships() _database.ExecuteScalar("SELECT COUNT(1) FROM foo").ShouldBe(100); - var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions - { - DbAdapter = DbAdapter.Postgres - }); + var checkpoint = await Respawner.CreateAsync(_connection); try { await checkpoint.ResetAsync(_connection); @@ -305,10 +291,7 @@ public async Task ShouldHandleComplexCycles() _database.ExecuteScalar("SELECT COUNT(1) FROM e").ShouldBe(1); _database.ExecuteScalar("SELECT COUNT(1) FROM f").ShouldBe(1); - var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions - { - DbAdapter = DbAdapter.Postgres - }); + var checkpoint = await Respawner.CreateAsync(_connection); try { await checkpoint.ResetAsync(_connection); @@ -344,7 +327,6 @@ public async Task ShouldExcludeSchemas() var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Postgres, SchemasToExclude = new [] { "a" } }); await checkpoint.ResetAsync(_connection); @@ -369,7 +351,6 @@ public async Task ShouldIncludeSchemas() var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Postgres, SchemasToInclude = new [] { "b" } }); await checkpoint.ResetAsync(_connection); @@ -388,7 +369,6 @@ public async Task ShouldResetSequencesAndIdentities() var checkpoint = await Respawner.CreateAsync(_connection, new RespawnerOptions { - DbAdapter = DbAdapter.Postgres, WithReseed = true }); diff --git a/Respawn/Respawner.cs b/Respawn/Respawner.cs index 8c35fd3..5fac5f5 100644 --- a/Respawn/Respawner.cs +++ b/Respawn/Respawner.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Collections.Generic; using System.Data.Common; using Microsoft.Data.SqlClient; @@ -10,14 +10,33 @@ namespace Respawn { public class Respawner { + private readonly IDbAdapter _dbAdapter; private IList _temporalTables = new List(); public RespawnerOptions Options { get; } public string? DeleteSql { get; private set; } public string? ReseedSql { get; private set; } - private Respawner(RespawnerOptions options) + private Respawner(RespawnerOptions options, IDbAdapter dbAdapter) { - Options = options; + if (options.DbAdapter != null) + { + Options = options; + } + else + { + Options = new RespawnerOptions + { + TablesToIgnore = options.TablesToIgnore, + TablesToInclude = options.TablesToInclude, + SchemasToInclude = options.SchemasToInclude, + SchemasToExclude = options.SchemasToExclude, + CheckTemporalTables = options.CheckTemporalTables, + WithReseed = options.WithReseed, + CommandTimeout = options.CommandTimeout, + DbAdapter = dbAdapter, + }; + } + _dbAdapter = dbAdapter; } /// @@ -31,7 +50,7 @@ public static async Task CreateAsync(string nameOrConnectionString, R { options ??= new RespawnerOptions(); - if (options.DbAdapter is not SqlServerDbAdapter) + if (options.DbAdapter is not null && options.DbAdapter is not SqlServerDbAdapter) { throw new ArgumentException("This overload only supports the SqlDataAdapter. To use an alternative adapter, use the overload that supplies a DbConnection.", nameof(options.DbAdapter)); } @@ -40,7 +59,7 @@ public static async Task CreateAsync(string nameOrConnectionString, R await connection.OpenAsync(); - var respawner = new Respawner(options); + var respawner = new Respawner(options, DbAdapter.SqlServer); await respawner.BuildDeleteTables(connection); @@ -57,7 +76,17 @@ public static async Task CreateAsync(DbConnection connection, Respawn { options ??= new RespawnerOptions(); - var respawner = new Respawner(options); + var dbAdapter = options.DbAdapter ?? connection.GetType().Name switch + { + "SqlConnection" => DbAdapter.SqlServer, + "NpgsqlConnection" => DbAdapter.Postgres, + "MySqlConnection" => DbAdapter.MySql, + "OracleConnection" => DbAdapter.Oracle, + "DB2Connection" or "IfxConnection" => DbAdapter.Informix, + _ => throw new ArgumentException("The database adapter could not be inferred from the DbConnection. Please pass an explicit database adapter in the options.", nameof(options)) + }; + + var respawner = new Respawner(options, dbAdapter); await respawner.BuildDeleteTables(connection); @@ -78,7 +107,7 @@ public virtual async Task ResetAsync(DbConnection connection) { if (_temporalTables.Any()) { - var turnOffVersioningCommandText = Options.DbAdapter.BuildTurnOffSystemVersioningCommandText(_temporalTables); + var turnOffVersioningCommandText = _dbAdapter.BuildTurnOffSystemVersioningCommandText(_temporalTables); await ExecuteAlterSystemVersioningAsync(connection, turnOffVersioningCommandText); } @@ -90,7 +119,7 @@ public virtual async Task ResetAsync(DbConnection connection) { if (_temporalTables.Any()) { - var turnOnVersioningCommandText = Options.DbAdapter.BuildTurnOnSystemVersioningCommandText(_temporalTables); + var turnOnVersioningCommandText = _dbAdapter.BuildTurnOnSystemVersioningCommandText(_temporalTables); await ExecuteAlterSystemVersioningAsync(connection, turnOnVersioningCommandText); } } @@ -140,7 +169,7 @@ private async Task BuildDeleteTables(DbConnection connection) "No tables found. Ensure your target database has at least one non-ignored table to reset. Consider initializing the database and/or running migrations."); } - if (Options.CheckTemporalTables && await Options.DbAdapter.CheckSupportsTemporalTables(connection)) + if (Options.CheckTemporalTables && await _dbAdapter.CheckSupportsTemporalTables(connection)) { _temporalTables = await GetAllTemporalTables(connection); } @@ -149,14 +178,14 @@ private async Task BuildDeleteTables(DbConnection connection) var graphBuilder = new GraphBuilder(allTables, allRelationships); - DeleteSql = Options.DbAdapter.BuildDeleteCommandText(graphBuilder); - ReseedSql = Options.WithReseed ? Options.DbAdapter.BuildReseedSql(graphBuilder.ToDelete) : null; + DeleteSql = _dbAdapter.BuildDeleteCommandText(graphBuilder); + ReseedSql = Options.WithReseed ? _dbAdapter.BuildReseedSql(graphBuilder.ToDelete) : null; } private async Task> GetRelationships(DbConnection connection) { var relationships = new HashSet(); - var commandText = Options.DbAdapter.BuildRelationshipCommandText(Options); + var commandText = _dbAdapter.BuildRelationshipCommandText(Options); await using var cmd = connection.CreateCommand(); @@ -179,7 +208,7 @@ private async Task> GetAllTables(DbConnection connection) { var tables = new HashSet(); - var commandText = Options.DbAdapter.BuildTableCommandText(Options); + var commandText = _dbAdapter.BuildTableCommandText(Options); await using var cmd = connection.CreateCommand(); @@ -199,7 +228,7 @@ private async Task> GetAllTemporalTables(DbConnection conne { var tables = new List(); - var commandText = Options.DbAdapter.BuildTemporalTableCommandText(Options); + var commandText = _dbAdapter.BuildTemporalTableCommandText(Options); await using var cmd = connection.CreateCommand(); diff --git a/Respawn/RespawnerOptions.cs b/Respawn/RespawnerOptions.cs index 7eb762f..ca64882 100644 --- a/Respawn/RespawnerOptions.cs +++ b/Respawn/RespawnerOptions.cs @@ -12,6 +12,5 @@ public class RespawnerOptions public bool CheckTemporalTables { get; init; } public bool WithReseed { get; init; } public int? CommandTimeout { get; init; } - public IDbAdapter DbAdapter { get; init; } = Respawn.DbAdapter.SqlServer; - + public IDbAdapter? DbAdapter { get; init; } } \ No newline at end of file