From 26105073d9c484a4a228cec47911891e5f5f4508 Mon Sep 17 00:00:00 2001 From: John Campion Date: Thu, 17 Dec 2020 21:23:36 -0500 Subject: [PATCH 1/4] Added NoTracking support Closes #113 --- src/MongoFramework/IMongoDbSet.cs | 3 + src/MongoFramework/MongoDbSet.cs | 62 +++++- src/MongoFramework/MongoDbTenantSet.cs | 27 ++- tests/MongoFramework.Tests/MongoDbSetTests.cs | 119 +++++++++++ .../MongoDbTenantSetTests.cs | 187 ++++++++++++++++++ 5 files changed, 383 insertions(+), 15 deletions(-) diff --git a/src/MongoFramework/IMongoDbSet.cs b/src/MongoFramework/IMongoDbSet.cs index a3f04f5b..acd5637c 100644 --- a/src/MongoFramework/IMongoDbSet.cs +++ b/src/MongoFramework/IMongoDbSet.cs @@ -23,11 +23,14 @@ public interface IMongoDbSet : IMongoDbSet, IQueryable where T TEntity Create(); void Add(TEntity entity); void AddRange(IEnumerable entities); + void Attach(TEntity entity); + void AttachRange(IEnumerable entities); void Update(TEntity entity); void UpdateRange(IEnumerable entities); void Remove(TEntity entity); void RemoveRange(IEnumerable entities); void RemoveRange(Expression> predicate); void RemoveById(object entityId); + IQueryable AsNoTracking(); } } diff --git a/src/MongoFramework/MongoDbSet.cs b/src/MongoFramework/MongoDbSet.cs index b2019853..cb9c9609 100644 --- a/src/MongoFramework/MongoDbSet.cs +++ b/src/MongoFramework/MongoDbSet.cs @@ -129,8 +129,32 @@ public virtual void AddRange(IEnumerable entities) { Context.ChangeTracker.SetEntityState(entity, EntityEntryState.Added); } + } + + /// + /// Marks the entity as unchanged in the change tracker and starts tracking. + /// + /// + public virtual void Attach(TEntity entity) + { + Check.NotNull(entity, nameof(entity)); + + Context.ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges); } + /// + /// Marks the collection of entities as unchanged in the change tracker and starts tracking. + /// + /// + public virtual void AttachRange(IEnumerable entities) + { + Check.NotNull(entities, nameof(entities)); + foreach (var entity in entities) + { + Context.ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges); + } + } + /// /// Marks the entity for updating. /// @@ -209,30 +233,50 @@ public async Task SaveChangesAsync(CancellationToken cancellationToken = default #region IQueryable Implementation - protected virtual IQueryable GetQueryable() + protected virtual IQueryable GetQueryable(bool trackEntities) { var queryable = Context.Query(); - var provider = queryable.Provider as IMongoFrameworkQueryProvider; - provider.EntityProcessors.Add(new EntityTrackingProcessor(Context)); + if (trackEntities) + { + var provider = queryable.Provider as IMongoFrameworkQueryProvider; + provider.EntityProcessors.Add(new EntityTrackingProcessor(Context)); + } return queryable; } - public Expression Expression => GetQueryable().Expression; + public Expression Expression => GetQueryable(true).Expression; - public Type ElementType => GetQueryable().ElementType; + public Type ElementType => GetQueryable(true).ElementType; - public IQueryProvider Provider => GetQueryable().Provider; + public IQueryProvider Provider => GetQueryable(true).Provider; public IEnumerator GetEnumerator() { - return GetQueryable().GetEnumerator(); + return GetQueryable(true).GetEnumerator(); } IEnumerator IEnumerable.GetEnumerator() { return GetEnumerator(); - } - + } + + public virtual IQueryable AsNoTracking() + { + return GetQueryable(false); + } + #endregion + + } + + public static class DbSetExtensions + { + public static IQueryable NoTracking(this IMongoDbSet dbSet) where TEntity : class + { + var queryable = dbSet.Context.Query(); + var provider = queryable.Provider as IMongoFrameworkQueryProvider; + return queryable; + } } + } \ No newline at end of file diff --git a/src/MongoFramework/MongoDbTenantSet.cs b/src/MongoFramework/MongoDbTenantSet.cs index 42872142..c572fb9a 100644 --- a/src/MongoFramework/MongoDbTenantSet.cs +++ b/src/MongoFramework/MongoDbTenantSet.cs @@ -138,7 +138,19 @@ public override void AddRange(IEnumerable entities) entity.TenantId = Context.TenantId; } base.AddRange(entities); - } + } + + public override void Attach(TEntity entity) + { + CheckEntity(entity); + base.Attach(entity); + } + + public override void AttachRange(IEnumerable entities) + { + CheckEntities(entities); + base.AttachRange(entities); + } public override void Update(TEntity entity) { @@ -173,12 +185,15 @@ public override void RemoveRange(Expression> predicate) #region IQueryable Implementation - protected override IQueryable GetQueryable() + protected override IQueryable GetQueryable(bool trackEntities) { var key = Context.TenantId; var queryable = Context.Query().Where(c => c.TenantId == key); - var provider = queryable.Provider as IMongoFrameworkQueryProvider; - provider.EntityProcessors.Add(new EntityTrackingProcessor(Context)); + if (trackEntities) + { + var provider = queryable.Provider as IMongoFrameworkQueryProvider; + provider.EntityProcessors.Add(new EntityTrackingProcessor(Context)); + } return queryable; } @@ -186,8 +201,8 @@ public IQueryable GetSearchTextQueryable(string search) { var key = Context.TenantId; var queryable = Context.Query().WhereFilter(b => b.Text(search)).Where(c => c.TenantId == key); - var provider = queryable.Provider as IMongoFrameworkQueryProvider; - provider.EntityProcessors.Add(new EntityTrackingProcessor(Context)); + var provider = queryable.Provider as IMongoFrameworkQueryProvider; + provider.EntityProcessors.Add(new EntityTrackingProcessor(Context)); return queryable; } diff --git a/tests/MongoFramework.Tests/MongoDbSetTests.cs b/tests/MongoFramework.Tests/MongoDbSetTests.cs index b99a9ab1..801d688d 100644 --- a/tests/MongoFramework.Tests/MongoDbSetTests.cs +++ b/tests/MongoFramework.Tests/MongoDbSetTests.cs @@ -417,6 +417,125 @@ public async Task SuccessfullyLinqFindTrackedAsync() Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.Updated, context.ChangeTracker.GetEntry(result).State); } + [TestMethod] + public void SuccessfullyLinqFindNoTracking() + { + var connection = TestConfiguration.GetConnection(); + var context = new MongoDbContext(connection); + var dbSet = new MongoDbSet(context); + + var model = new TestModel + { + Id = "abcd", + Description = "SuccessfullyLinqFindNoTracking.1" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbContext(connection); + dbSet = new MongoDbSet(context); + + var result = dbSet.AsNoTracking().FirstOrDefault(); + + Assert.IsNull(context.ChangeTracker.GetEntry(result)); + } + + [TestMethod] + public async Task SuccessfullyLinqFindNoTrackingAsync() + { + var connection = TestConfiguration.GetConnection(); + var context = new MongoDbContext(connection); + var dbSet = new MongoDbSet(context); + + var model = new TestModel + { + Id = "abcd", + Description = "SuccessfullyFindTracked.1" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbContext(connection); + dbSet = new MongoDbSet(context); + + var result = await dbSet.AsNoTracking().FirstOrDefaultAsync(); + + Assert.IsNull(context.ChangeTracker.GetEntry(result)); + } + + [TestMethod] + public void SuccessfullyAttachUntrackedEntity() + { + var connection = TestConfiguration.GetConnection(); + var context = new MongoDbContext(connection); + var dbSet = new MongoDbSet(context); + + var model = new TestModel + { + Id = "abcd", + Description = "SuccessfullyAttachUntrackedEntity.1" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbContext(connection); + dbSet = new MongoDbSet(context); + + var result = dbSet.AsNoTracking().FirstOrDefault(); + + dbSet.Attach(result); + + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result).State); + } + + [TestMethod] + public void SuccessfullyAttachUntrackedEntities() + { + var connection = TestConfiguration.GetConnection(); + var context = new MongoDbContext(connection); + var dbSet = new MongoDbSet(context); + + var entities = new[] { + new TestModel + { + Description = "SuccessfullyAttachUntrackedEntities.1" + }, + new TestModel + { + Description = "SuccessfullyAttachUntrackedEntities.2", + BooleanField = true + } + }; + + dbSet.AddRange(entities); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbContext(connection); + dbSet = new MongoDbSet(context); + + var result = dbSet.AsNoTracking().ToList(); + + dbSet.AttachRange(result); + + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result[0]).State); + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result[1]).State); + } + } } \ No newline at end of file diff --git a/tests/MongoFramework.Tests/MongoDbTenantSetTests.cs b/tests/MongoFramework.Tests/MongoDbTenantSetTests.cs index 57cad153..166bc768 100644 --- a/tests/MongoFramework.Tests/MongoDbTenantSetTests.cs +++ b/tests/MongoFramework.Tests/MongoDbTenantSetTests.cs @@ -875,5 +875,192 @@ public async Task SuccessfullyLinqFindTrackedAsync() Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.Updated, context.ChangeTracker.GetEntry(result).State); } + [TestMethod] + public void SuccessfullyLinqFindNoTracking() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var model = new TestModel + { + Id = "abcd", + Description = "SuccessfullyLinqFindNoTracking.1" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbTenantContext(connection, tenantId); + dbSet = new MongoDbTenantSet(context); + + var result = dbSet.AsNoTracking().FirstOrDefault(); + + Assert.IsNull(context.ChangeTracker.GetEntry(result)); + } + + [TestMethod] + public async Task SuccessfullyLinqFindNoTrackingAsync() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var model = new TestModel + { + Id = "abcd", + Description = "SuccessfullyFindTracked.1" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbTenantContext(connection, tenantId); + dbSet = new MongoDbTenantSet(context); + + var result = await dbSet.AsNoTracking().FirstOrDefaultAsync(); + + Assert.IsNull(context.ChangeTracker.GetEntry(result)); + } + + [TestMethod] + public void SuccessfullyAttachUntrackedEntity() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var model = new TestModel + { + Id = "abcd", + Description = "SuccessfullyAttachUntrackedEntity.1" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbTenantContext(connection, tenantId); + dbSet = new MongoDbTenantSet(context); + + var result = dbSet.AsNoTracking().FirstOrDefault(); + + dbSet.Attach(result); + + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result).State); + } + + [TestMethod] + public void SuccessfullyAttachUntrackedEntities() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var entities = new[] { + new TestModel + { + Description = "SuccessfullyAttachUntrackedEntities.1" + }, + new TestModel + { + Description = "SuccessfullyAttachUntrackedEntities.2", + BooleanField = true + } + }; + + dbSet.AddRange(entities); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbTenantContext(connection, tenantId); + dbSet = new MongoDbTenantSet(context); + + var result = dbSet.AsNoTracking().ToList(); + + dbSet.AttachRange(result); + + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result[0]).State); + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result[1]).State); + } + + [TestMethod] + public void AttachRejectsMismatchedEntity() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var model = new TestModel + { + Id = "abcd", + Description = "AttachRejectsMismatchedEntity.1" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbTenantContext(connection, tenantId); + dbSet = new MongoDbTenantSet(context); + + var result = dbSet.AsNoTracking().FirstOrDefault(); + result.TenantId = tenantId + "a"; + + Assert.ThrowsException(() => dbSet.Attach(result)); + } + + [TestMethod] + public void AttachRejectsMismatchedEntities() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var entities = new[] { + new TestModel + { + Description = "AttachRejectsMismatchedEntities.1" + }, + new TestModel + { + Description = "AttachRejectsMismatchedEntities.2", + BooleanField = true + } + }; + + dbSet.AddRange(entities); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbTenantContext(connection, tenantId); + dbSet = new MongoDbTenantSet(context); + + var result = dbSet.AsNoTracking().ToList(); + result[0].TenantId = tenantId + "a"; + + Assert.ThrowsException(() => dbSet.AttachRange(result)); + } + } } \ No newline at end of file From c6ea9a0ad871e869b050c73f7ea3a28d50f20ff4 Mon Sep 17 00:00:00 2001 From: John Campion Date: Thu, 17 Dec 2020 21:28:51 -0500 Subject: [PATCH 2/4] removed bit of code that didn't belong --- src/MongoFramework/MongoDbSet.cs | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/MongoFramework/MongoDbSet.cs b/src/MongoFramework/MongoDbSet.cs index cb9c9609..43f48117 100644 --- a/src/MongoFramework/MongoDbSet.cs +++ b/src/MongoFramework/MongoDbSet.cs @@ -268,15 +268,5 @@ public virtual IQueryable AsNoTracking() #endregion } - - public static class DbSetExtensions - { - public static IQueryable NoTracking(this IMongoDbSet dbSet) where TEntity : class - { - var queryable = dbSet.Context.Query(); - var provider = queryable.Provider as IMongoFrameworkQueryProvider; - return queryable; - } - } } \ No newline at end of file From 9021b74bc5bb2eb3e8cbbd72ca4f138d6b417cbe Mon Sep 17 00:00:00 2001 From: John Campion Date: Thu, 17 Dec 2020 22:38:18 -0500 Subject: [PATCH 3/4] Move Attach methods to context To match EF --- src/MongoFramework/IMongoDbContext.cs | 3 + src/MongoFramework/IMongoDbSet.cs | 2 - src/MongoFramework/IMongoDbTenantContext.cs | 10 +- src/MongoFramework/MongoDbContext.cs | 25 ++++ src/MongoFramework/MongoDbSet.cs | 26 +--- src/MongoFramework/MongoDbTenantContext.cs | 54 +++++++- src/MongoFramework/MongoDbTenantSet.cs | 66 +++------ .../MongoDbContextTenantTests.cs | 128 +++++++++++++++++ .../MongoDbContextTests.cs | 64 +++++++++ tests/MongoFramework.Tests/MongoDbSetTests.cs | 69 +-------- .../MongoDbTenantSetTests.cs | 131 ------------------ 11 files changed, 298 insertions(+), 280 deletions(-) diff --git a/src/MongoFramework/IMongoDbContext.cs b/src/MongoFramework/IMongoDbContext.cs index 60b3ad82..46ffb118 100644 --- a/src/MongoFramework/IMongoDbContext.cs +++ b/src/MongoFramework/IMongoDbContext.cs @@ -19,5 +19,8 @@ public interface IMongoDbContext void SaveChanges(); Task SaveChangesAsync(CancellationToken cancellationToken = default(CancellationToken)); + + void Attach(TEntity entity) where TEntity : class; + void AttachRange(IEnumerable entities) where TEntity : class; } } diff --git a/src/MongoFramework/IMongoDbSet.cs b/src/MongoFramework/IMongoDbSet.cs index acd5637c..bfecfb53 100644 --- a/src/MongoFramework/IMongoDbSet.cs +++ b/src/MongoFramework/IMongoDbSet.cs @@ -23,8 +23,6 @@ public interface IMongoDbSet : IMongoDbSet, IQueryable where T TEntity Create(); void Add(TEntity entity); void AddRange(IEnumerable entities); - void Attach(TEntity entity); - void AttachRange(IEnumerable entities); void Update(TEntity entity); void UpdateRange(IEnumerable entities); void Remove(TEntity entity); diff --git a/src/MongoFramework/IMongoDbTenantContext.cs b/src/MongoFramework/IMongoDbTenantContext.cs index 6d9ef304..7fd81cf3 100644 --- a/src/MongoFramework/IMongoDbTenantContext.cs +++ b/src/MongoFramework/IMongoDbTenantContext.cs @@ -1,7 +1,11 @@ -namespace MongoFramework +using System.Collections.Generic; + +namespace MongoFramework { - public interface IMongoDbTenantContext : IMongoDbContext - { + public interface IMongoDbTenantContext : IMongoDbContext + { string TenantId { get; } + void CheckEntity(IHaveTenantId entity); + void CheckEntities(IEnumerable entity); } } diff --git a/src/MongoFramework/MongoDbContext.cs b/src/MongoFramework/MongoDbContext.cs index f0395477..4d638a99 100644 --- a/src/MongoFramework/MongoDbContext.cs +++ b/src/MongoFramework/MongoDbContext.cs @@ -9,6 +9,7 @@ using System.Reflection; using System.Threading; using System.Threading.Tasks; +using MongoFramework.Utilities; namespace MongoFramework { @@ -107,6 +108,7 @@ public virtual async Task SaveChangesAsync(CancellationToken cancellationToken = ChangeTracker.CommitChanges(); CommandStaging.CommitChanges(); } + private static async Task InternalSaveChangesAsync(IMongoDbConnection connection, IEnumerable commands, WriteModelOptions options, CancellationToken cancellationToken) where TEntity : class { await EntityIndexWriter.ApplyIndexingAsync(connection); @@ -131,7 +133,30 @@ public IQueryable Query() where TEntity : class var provider = new MongoFrameworkQueryProvider(Connection); return new MongoFrameworkQueryable(provider); } + + /// + /// Marks the entity as unchanged in the change tracker and starts tracking. + /// + /// + public virtual void Attach(TEntity entity) where TEntity : class + { + Check.NotNull(entity, nameof(entity)); + ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges); + } + /// + /// Marks the collection of entities as unchanged in the change tracker and starts tracking. + /// + /// + public virtual void AttachRange(IEnumerable entities) where TEntity : class + { + Check.NotNull(entities, nameof(entities)); + foreach (var entity in entities) + { + ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges); + } + } + public void Dispose() { Dispose(true); diff --git a/src/MongoFramework/MongoDbSet.cs b/src/MongoFramework/MongoDbSet.cs index 43f48117..86d3741f 100644 --- a/src/MongoFramework/MongoDbSet.cs +++ b/src/MongoFramework/MongoDbSet.cs @@ -129,32 +129,8 @@ public virtual void AddRange(IEnumerable entities) { Context.ChangeTracker.SetEntityState(entity, EntityEntryState.Added); } - } - - /// - /// Marks the entity as unchanged in the change tracker and starts tracking. - /// - /// - public virtual void Attach(TEntity entity) - { - Check.NotNull(entity, nameof(entity)); - - Context.ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges); } - /// - /// Marks the collection of entities as unchanged in the change tracker and starts tracking. - /// - /// - public virtual void AttachRange(IEnumerable entities) - { - Check.NotNull(entities, nameof(entities)); - - foreach (var entity in entities) - { - Context.ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges); - } - } - + /// /// Marks the entity for updating. /// diff --git a/src/MongoFramework/MongoDbTenantContext.cs b/src/MongoFramework/MongoDbTenantContext.cs index 62a838e5..48e5a921 100644 --- a/src/MongoFramework/MongoDbTenantContext.cs +++ b/src/MongoFramework/MongoDbTenantContext.cs @@ -1,6 +1,7 @@ -using MongoFramework.Infrastructure.Commands; -using MongoFramework.Utilities; -using System; +using System.Collections.Generic; +using MongoFramework.Infrastructure; +using MongoFramework.Infrastructure.Commands; +using MongoFramework.Utilities; namespace MongoFramework { @@ -23,5 +24,52 @@ protected override WriteModelOptions GetWriteModelOptions() { return new WriteModelOptions { TenantId = TenantId }; } + + public virtual void CheckEntity(IHaveTenantId entity) + { + Check.NotNull(entity, nameof(entity)); + + if (entity.TenantId != TenantId) + { + throw new MultiTenantException($"Entity type {entity.GetType().Name}, tenant ID does not match. Expected: {TenantId}, Entity has: {entity.TenantId}"); + } + } + + public virtual void CheckEntities(IEnumerable entities) + { + Check.NotNull(entities, nameof(entities)); + + foreach (var entity in entities) + { + CheckEntity(entity); + } + } + + /// + /// Marks the entity as unchanged in the change tracker and starts tracking. + /// + /// + public override void Attach(TEntity entity) where TEntity : class + { + if (typeof(IHaveTenantId).IsAssignableFrom(typeof(TEntity))) + { + CheckEntity(entity as IHaveTenantId); + } + base.Attach(entity); + } + + /// + /// Marks the collection of entities as unchanged in the change tracker and starts tracking. + /// + /// + public override void AttachRange(IEnumerable entities) where TEntity : class + { + if (typeof(IHaveTenantId).IsAssignableFrom(typeof(TEntity))) + { + CheckEntities(entities as IEnumerable); + } + base.AttachRange(entities); + } + } } diff --git a/src/MongoFramework/MongoDbTenantSet.cs b/src/MongoFramework/MongoDbTenantSet.cs index c572fb9a..8eb8dda1 100644 --- a/src/MongoFramework/MongoDbTenantSet.cs +++ b/src/MongoFramework/MongoDbTenantSet.cs @@ -27,28 +27,8 @@ public MongoDbTenantSet(IMongoDbContext context) : base(context) { Context = context as IMongoDbTenantContext ?? throw new ArgumentException("Context provided to a MongoDbTenantSet must be IMongoDbTenantContext",nameof(context)); } - - protected virtual void CheckEntity(TEntity entity) - { - Check.NotNull(entity, nameof(entity)); - - if (entity.TenantId != Context.TenantId) - { - throw new MultiTenantException($"Entity type {entity.GetType().Name}, tenant ID does not match. Expected: {Context.TenantId}, Entity has: {entity.TenantId}"); - } - } - - protected virtual void CheckEntities(IEnumerable entities) - { - Check.NotNull(entities, nameof(entities)); - - foreach (var entity in entities) - { - CheckEntity(entity); - } - } - - /// + + /// /// Finds an entity with the given primary key value. If an entity with the given primary key value /// is being tracked by the context, then it is returned immediately without making a request to the /// database. Otherwise, a query is made to the database for an entity with the given primary key value @@ -138,41 +118,29 @@ public override void AddRange(IEnumerable entities) entity.TenantId = Context.TenantId; } base.AddRange(entities); - } - - public override void Attach(TEntity entity) - { - CheckEntity(entity); - base.Attach(entity); - } - - public override void AttachRange(IEnumerable entities) - { - CheckEntities(entities); - base.AttachRange(entities); - } - + } + public override void Update(TEntity entity) { - CheckEntity(entity); + Context.CheckEntity(entity); base.Update(entity); } public override void UpdateRange(IEnumerable entities) { - CheckEntities(entities); + Context.CheckEntities(entities); base.UpdateRange(entities); } public override void Remove(TEntity entity) { - CheckEntity(entity); + Context.CheckEntity(entity); base.Remove(entity); } public override void RemoveRange(IEnumerable entities) { - CheckEntities(entities); + Context.CheckEntities(entities); base.RemoveRange(entities); } @@ -189,10 +157,10 @@ protected override IQueryable GetQueryable(bool trackEntities) { var key = Context.TenantId; var queryable = Context.Query().Where(c => c.TenantId == key); - if (trackEntities) - { - var provider = queryable.Provider as IMongoFrameworkQueryProvider; - provider.EntityProcessors.Add(new EntityTrackingProcessor(Context)); + if (trackEntities) + { + var provider = queryable.Provider as IMongoFrameworkQueryProvider; + provider.EntityProcessors.Add(new EntityTrackingProcessor(Context)); } return queryable; } @@ -200,12 +168,12 @@ protected override IQueryable GetQueryable(bool trackEntities) public IQueryable GetSearchTextQueryable(string search) { var key = Context.TenantId; - var queryable = Context.Query().WhereFilter(b => b.Text(search)).Where(c => c.TenantId == key); - var provider = queryable.Provider as IMongoFrameworkQueryProvider; - provider.EntityProcessors.Add(new EntityTrackingProcessor(Context)); + var queryable = Context.Query().WhereFilter(b => b.Text(search)).Where(c => c.TenantId == key); + var provider = queryable.Provider as IMongoFrameworkQueryProvider; + provider.EntityProcessors.Add(new EntityTrackingProcessor(Context)); return queryable; - } - + } + #endregion } } \ No newline at end of file diff --git a/tests/MongoFramework.Tests/MongoDbContextTenantTests.cs b/tests/MongoFramework.Tests/MongoDbContextTenantTests.cs index 2cd8df14..9d726788 100644 --- a/tests/MongoFramework.Tests/MongoDbContextTenantTests.cs +++ b/tests/MongoFramework.Tests/MongoDbContextTenantTests.cs @@ -111,5 +111,133 @@ public async Task ContextSavesDbSetsAsync() Assert.AreEqual(TestConfiguration.GetTenantId(), context.DbSet.First().TenantId); } } + + [TestMethod] + public void SuccessfullyAttachUntrackedEntity() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var model = new DbSetModel + { + Id = "abcd" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbTenantContext(connection, tenantId); + dbSet = new MongoDbTenantSet(context); + + var result = dbSet.AsNoTracking().FirstOrDefault(); + + context.Attach(result); + + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result).State); + } + + [TestMethod] + public void SuccessfullyAttachUntrackedEntities() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var entities = new[] { + new DbSetModel + { + Id = "abcd" + }, + new DbSetModel + { + Id = "efgh" + } + }; + + dbSet.AddRange(entities); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbTenantContext(connection, tenantId); + dbSet = new MongoDbTenantSet(context); + + var result = dbSet.AsNoTracking().ToList(); + + context.AttachRange(result); + + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result[0]).State); + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result[1]).State); + } + + [TestMethod] + public void AttachRejectsMismatchedEntity() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var model = new DbSetModel + { + Id = "abcd" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbTenantContext(connection, tenantId); + dbSet = new MongoDbTenantSet(context); + + var result = dbSet.AsNoTracking().FirstOrDefault(); + result.TenantId = tenantId + "a"; + + Assert.ThrowsException(() => context.Attach(result)); + } + + [TestMethod] + public void AttachRejectsMismatchedEntities() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var entities = new[] { + new DbSetModel + { + Id = "abcd" + }, + new DbSetModel + { + Id = "efgh" + } + }; + + dbSet.AddRange(entities); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbTenantContext(connection, tenantId); + dbSet = new MongoDbTenantSet(context); + + var result = dbSet.AsNoTracking().ToList(); + result[0].TenantId = tenantId + "a"; + + Assert.ThrowsException(() => context.AttachRange(result)); + } + } } \ No newline at end of file diff --git a/tests/MongoFramework.Tests/MongoDbContextTests.cs b/tests/MongoFramework.Tests/MongoDbContextTests.cs index c4d1b28e..ed263efc 100644 --- a/tests/MongoFramework.Tests/MongoDbContextTests.cs +++ b/tests/MongoFramework.Tests/MongoDbContextTests.cs @@ -85,5 +85,69 @@ public async Task ContextSavesDbSetsAsync() Assert.IsTrue(context.DbSet.Any()); } } + + [TestMethod] + public void SuccessfullyAttachUntrackedEntity() + { + var connection = TestConfiguration.GetConnection(); + var context = new MongoDbContext(connection); + var dbSet = new MongoDbSet(context); + + var model = new DbSetModel + { + Id = "abcd" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbContext(connection); + dbSet = new MongoDbSet(context); + + var result = dbSet.AsNoTracking().FirstOrDefault(); + + context.Attach(result); + + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result).State); + } + + [TestMethod] + public void SuccessfullyAttachUntrackedEntities() + { + var connection = TestConfiguration.GetConnection(); + var context = new MongoDbContext(connection); + var dbSet = new MongoDbSet(context); + + var entities = new[] { + new DbSetModel + { + Id = "abcd" + }, + new DbSetModel + { + Id = "efgh" + } + }; + + dbSet.AddRange(entities); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbContext(connection); + dbSet = new MongoDbSet(context); + + var result = dbSet.AsNoTracking().ToList(); + + context.AttachRange(result); + + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result[0]).State); + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result[1]).State); + } + } } \ No newline at end of file diff --git a/tests/MongoFramework.Tests/MongoDbSetTests.cs b/tests/MongoFramework.Tests/MongoDbSetTests.cs index 801d688d..9b26cd6e 100644 --- a/tests/MongoFramework.Tests/MongoDbSetTests.cs +++ b/tests/MongoFramework.Tests/MongoDbSetTests.cs @@ -470,72 +470,7 @@ public async Task SuccessfullyLinqFindNoTrackingAsync() Assert.IsNull(context.ChangeTracker.GetEntry(result)); } - - [TestMethod] - public void SuccessfullyAttachUntrackedEntity() - { - var connection = TestConfiguration.GetConnection(); - var context = new MongoDbContext(connection); - var dbSet = new MongoDbSet(context); - - var model = new TestModel - { - Id = "abcd", - Description = "SuccessfullyAttachUntrackedEntity.1" - }; - - dbSet.Add(model); - - context.SaveChanges(); - - ResetMongoDb(); - - context = new MongoDbContext(connection); - dbSet = new MongoDbSet(context); - - var result = dbSet.AsNoTracking().FirstOrDefault(); - - dbSet.Attach(result); - - Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result).State); - } - - [TestMethod] - public void SuccessfullyAttachUntrackedEntities() - { - var connection = TestConfiguration.GetConnection(); - var context = new MongoDbContext(connection); - var dbSet = new MongoDbSet(context); - - var entities = new[] { - new TestModel - { - Description = "SuccessfullyAttachUntrackedEntities.1" - }, - new TestModel - { - Description = "SuccessfullyAttachUntrackedEntities.2", - BooleanField = true - } - }; - - dbSet.AddRange(entities); - - context.SaveChanges(); - - ResetMongoDb(); - - context = new MongoDbContext(connection); - dbSet = new MongoDbSet(context); - - var result = dbSet.AsNoTracking().ToList(); - - dbSet.AttachRange(result); - - Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result[0]).State); - Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result[1]).State); - } - - } + + } } \ No newline at end of file diff --git a/tests/MongoFramework.Tests/MongoDbTenantSetTests.cs b/tests/MongoFramework.Tests/MongoDbTenantSetTests.cs index 166bc768..5fc6d403 100644 --- a/tests/MongoFramework.Tests/MongoDbTenantSetTests.cs +++ b/tests/MongoFramework.Tests/MongoDbTenantSetTests.cs @@ -931,136 +931,5 @@ public async Task SuccessfullyLinqFindNoTrackingAsync() Assert.IsNull(context.ChangeTracker.GetEntry(result)); } - [TestMethod] - public void SuccessfullyAttachUntrackedEntity() - { - var connection = TestConfiguration.GetConnection(); - var tenantId = TestConfiguration.GetTenantId(); - var context = new MongoDbTenantContext(connection, tenantId); - var dbSet = new MongoDbTenantSet(context); - - var model = new TestModel - { - Id = "abcd", - Description = "SuccessfullyAttachUntrackedEntity.1" - }; - - dbSet.Add(model); - - context.SaveChanges(); - - ResetMongoDb(); - - context = new MongoDbTenantContext(connection, tenantId); - dbSet = new MongoDbTenantSet(context); - - var result = dbSet.AsNoTracking().FirstOrDefault(); - - dbSet.Attach(result); - - Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result).State); - } - - [TestMethod] - public void SuccessfullyAttachUntrackedEntities() - { - var connection = TestConfiguration.GetConnection(); - var tenantId = TestConfiguration.GetTenantId(); - var context = new MongoDbTenantContext(connection, tenantId); - var dbSet = new MongoDbTenantSet(context); - - var entities = new[] { - new TestModel - { - Description = "SuccessfullyAttachUntrackedEntities.1" - }, - new TestModel - { - Description = "SuccessfullyAttachUntrackedEntities.2", - BooleanField = true - } - }; - - dbSet.AddRange(entities); - - context.SaveChanges(); - - ResetMongoDb(); - - context = new MongoDbTenantContext(connection, tenantId); - dbSet = new MongoDbTenantSet(context); - - var result = dbSet.AsNoTracking().ToList(); - - dbSet.AttachRange(result); - - Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result[0]).State); - Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result[1]).State); - } - - [TestMethod] - public void AttachRejectsMismatchedEntity() - { - var connection = TestConfiguration.GetConnection(); - var tenantId = TestConfiguration.GetTenantId(); - var context = new MongoDbTenantContext(connection, tenantId); - var dbSet = new MongoDbTenantSet(context); - - var model = new TestModel - { - Id = "abcd", - Description = "AttachRejectsMismatchedEntity.1" - }; - - dbSet.Add(model); - - context.SaveChanges(); - - ResetMongoDb(); - - context = new MongoDbTenantContext(connection, tenantId); - dbSet = new MongoDbTenantSet(context); - - var result = dbSet.AsNoTracking().FirstOrDefault(); - result.TenantId = tenantId + "a"; - - Assert.ThrowsException(() => dbSet.Attach(result)); - } - - [TestMethod] - public void AttachRejectsMismatchedEntities() - { - var connection = TestConfiguration.GetConnection(); - var tenantId = TestConfiguration.GetTenantId(); - var context = new MongoDbTenantContext(connection, tenantId); - var dbSet = new MongoDbTenantSet(context); - - var entities = new[] { - new TestModel - { - Description = "AttachRejectsMismatchedEntities.1" - }, - new TestModel - { - Description = "AttachRejectsMismatchedEntities.2", - BooleanField = true - } - }; - - dbSet.AddRange(entities); - - context.SaveChanges(); - - ResetMongoDb(); - - context = new MongoDbTenantContext(connection, tenantId); - dbSet = new MongoDbTenantSet(context); - - var result = dbSet.AsNoTracking().ToList(); - result[0].TenantId = tenantId + "a"; - - Assert.ThrowsException(() => dbSet.AttachRange(result)); - } - } } \ No newline at end of file From 5803d1994f3be46c9cf70af959039af9d3844fd4 Mon Sep 17 00:00:00 2001 From: John Campion Date: Thu, 17 Dec 2020 22:41:47 -0500 Subject: [PATCH 4/4] formatting typo --- tests/MongoFramework.Tests/MongoDbContextTenantTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/MongoFramework.Tests/MongoDbContextTenantTests.cs b/tests/MongoFramework.Tests/MongoDbContextTenantTests.cs index 9d726788..ba7e9e29 100644 --- a/tests/MongoFramework.Tests/MongoDbContextTenantTests.cs +++ b/tests/MongoFramework.Tests/MongoDbContextTenantTests.cs @@ -112,7 +112,7 @@ public async Task ContextSavesDbSetsAsync() } } - [TestMethod] + [TestMethod] public void SuccessfullyAttachUntrackedEntity() { var connection = TestConfiguration.GetConnection();