diff --git a/src/MongoFramework/IMongoDbContext.cs b/src/MongoFramework/IMongoDbContext.cs index 60b3ad82..46ffb118 100644 --- a/src/MongoFramework/IMongoDbContext.cs +++ b/src/MongoFramework/IMongoDbContext.cs @@ -19,5 +19,8 @@ public interface IMongoDbContext void SaveChanges(); Task SaveChangesAsync(CancellationToken cancellationToken = default(CancellationToken)); + + void Attach(TEntity entity) where TEntity : class; + void AttachRange(IEnumerable entities) where TEntity : class; } } diff --git a/src/MongoFramework/IMongoDbSet.cs b/src/MongoFramework/IMongoDbSet.cs index a3f04f5b..bfecfb53 100644 --- a/src/MongoFramework/IMongoDbSet.cs +++ b/src/MongoFramework/IMongoDbSet.cs @@ -29,5 +29,6 @@ public interface IMongoDbSet : IMongoDbSet, IQueryable where T void RemoveRange(IEnumerable entities); void RemoveRange(Expression> predicate); void RemoveById(object entityId); + IQueryable AsNoTracking(); } } diff --git a/src/MongoFramework/IMongoDbTenantContext.cs b/src/MongoFramework/IMongoDbTenantContext.cs index 6d9ef304..7fd81cf3 100644 --- a/src/MongoFramework/IMongoDbTenantContext.cs +++ b/src/MongoFramework/IMongoDbTenantContext.cs @@ -1,7 +1,11 @@ -namespace MongoFramework +using System.Collections.Generic; + +namespace MongoFramework { - public interface IMongoDbTenantContext : IMongoDbContext - { + public interface IMongoDbTenantContext : IMongoDbContext + { string TenantId { get; } + void CheckEntity(IHaveTenantId entity); + void CheckEntities(IEnumerable entity); } } diff --git a/src/MongoFramework/MongoDbContext.cs b/src/MongoFramework/MongoDbContext.cs index f0395477..4d638a99 100644 --- a/src/MongoFramework/MongoDbContext.cs +++ b/src/MongoFramework/MongoDbContext.cs @@ -9,6 +9,7 @@ using System.Reflection; using System.Threading; using System.Threading.Tasks; +using MongoFramework.Utilities; namespace MongoFramework { @@ -107,6 +108,7 @@ public virtual async Task SaveChangesAsync(CancellationToken cancellationToken = ChangeTracker.CommitChanges(); CommandStaging.CommitChanges(); } + private static async Task InternalSaveChangesAsync(IMongoDbConnection connection, IEnumerable commands, WriteModelOptions options, CancellationToken cancellationToken) where TEntity : class { await EntityIndexWriter.ApplyIndexingAsync(connection); @@ -131,7 +133,30 @@ public IQueryable Query() where TEntity : class var provider = new MongoFrameworkQueryProvider(Connection); return new MongoFrameworkQueryable(provider); } + + /// + /// Marks the entity as unchanged in the change tracker and starts tracking. + /// + /// + public virtual void Attach(TEntity entity) where TEntity : class + { + Check.NotNull(entity, nameof(entity)); + ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges); + } + /// + /// Marks the collection of entities as unchanged in the change tracker and starts tracking. + /// + /// + public virtual void AttachRange(IEnumerable entities) where TEntity : class + { + Check.NotNull(entities, nameof(entities)); + foreach (var entity in entities) + { + ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges); + } + } + public void Dispose() { Dispose(true); diff --git a/src/MongoFramework/MongoDbSet.cs b/src/MongoFramework/MongoDbSet.cs index b2019853..86d3741f 100644 --- a/src/MongoFramework/MongoDbSet.cs +++ b/src/MongoFramework/MongoDbSet.cs @@ -130,7 +130,7 @@ public virtual void AddRange(IEnumerable entities) Context.ChangeTracker.SetEntityState(entity, EntityEntryState.Added); } } - + /// /// Marks the entity for updating. /// @@ -209,30 +209,40 @@ public async Task SaveChangesAsync(CancellationToken cancellationToken = default #region IQueryable Implementation - protected virtual IQueryable GetQueryable() + protected virtual IQueryable GetQueryable(bool trackEntities) { var queryable = Context.Query(); - var provider = queryable.Provider as IMongoFrameworkQueryProvider; - provider.EntityProcessors.Add(new EntityTrackingProcessor(Context)); + if (trackEntities) + { + var provider = queryable.Provider as IMongoFrameworkQueryProvider; + provider.EntityProcessors.Add(new EntityTrackingProcessor(Context)); + } return queryable; } - public Expression Expression => GetQueryable().Expression; + public Expression Expression => GetQueryable(true).Expression; - public Type ElementType => GetQueryable().ElementType; + public Type ElementType => GetQueryable(true).ElementType; - public IQueryProvider Provider => GetQueryable().Provider; + public IQueryProvider Provider => GetQueryable(true).Provider; public IEnumerator GetEnumerator() { - return GetQueryable().GetEnumerator(); + return GetQueryable(true).GetEnumerator(); } IEnumerator IEnumerable.GetEnumerator() { return GetEnumerator(); - } - + } + + public virtual IQueryable AsNoTracking() + { + return GetQueryable(false); + } + #endregion - } + + } + } \ No newline at end of file diff --git a/src/MongoFramework/MongoDbTenantContext.cs b/src/MongoFramework/MongoDbTenantContext.cs index 62a838e5..48e5a921 100644 --- a/src/MongoFramework/MongoDbTenantContext.cs +++ b/src/MongoFramework/MongoDbTenantContext.cs @@ -1,6 +1,7 @@ -using MongoFramework.Infrastructure.Commands; -using MongoFramework.Utilities; -using System; +using System.Collections.Generic; +using MongoFramework.Infrastructure; +using MongoFramework.Infrastructure.Commands; +using MongoFramework.Utilities; namespace MongoFramework { @@ -23,5 +24,52 @@ protected override WriteModelOptions GetWriteModelOptions() { return new WriteModelOptions { TenantId = TenantId }; } + + public virtual void CheckEntity(IHaveTenantId entity) + { + Check.NotNull(entity, nameof(entity)); + + if (entity.TenantId != TenantId) + { + throw new MultiTenantException($"Entity type {entity.GetType().Name}, tenant ID does not match. Expected: {TenantId}, Entity has: {entity.TenantId}"); + } + } + + public virtual void CheckEntities(IEnumerable entities) + { + Check.NotNull(entities, nameof(entities)); + + foreach (var entity in entities) + { + CheckEntity(entity); + } + } + + /// + /// Marks the entity as unchanged in the change tracker and starts tracking. + /// + /// + public override void Attach(TEntity entity) where TEntity : class + { + if (typeof(IHaveTenantId).IsAssignableFrom(typeof(TEntity))) + { + CheckEntity(entity as IHaveTenantId); + } + base.Attach(entity); + } + + /// + /// Marks the collection of entities as unchanged in the change tracker and starts tracking. + /// + /// + public override void AttachRange(IEnumerable entities) where TEntity : class + { + if (typeof(IHaveTenantId).IsAssignableFrom(typeof(TEntity))) + { + CheckEntities(entities as IEnumerable); + } + base.AttachRange(entities); + } + } } diff --git a/src/MongoFramework/MongoDbTenantSet.cs b/src/MongoFramework/MongoDbTenantSet.cs index 42872142..8eb8dda1 100644 --- a/src/MongoFramework/MongoDbTenantSet.cs +++ b/src/MongoFramework/MongoDbTenantSet.cs @@ -27,28 +27,8 @@ public MongoDbTenantSet(IMongoDbContext context) : base(context) { Context = context as IMongoDbTenantContext ?? throw new ArgumentException("Context provided to a MongoDbTenantSet must be IMongoDbTenantContext",nameof(context)); } - - protected virtual void CheckEntity(TEntity entity) - { - Check.NotNull(entity, nameof(entity)); - - if (entity.TenantId != Context.TenantId) - { - throw new MultiTenantException($"Entity type {entity.GetType().Name}, tenant ID does not match. Expected: {Context.TenantId}, Entity has: {entity.TenantId}"); - } - } - - protected virtual void CheckEntities(IEnumerable entities) - { - Check.NotNull(entities, nameof(entities)); - - foreach (var entity in entities) - { - CheckEntity(entity); - } - } - - /// + + /// /// Finds an entity with the given primary key value. If an entity with the given primary key value /// is being tracked by the context, then it is returned immediately without making a request to the /// database. Otherwise, a query is made to the database for an entity with the given primary key value @@ -139,28 +119,28 @@ public override void AddRange(IEnumerable entities) } base.AddRange(entities); } - + public override void Update(TEntity entity) { - CheckEntity(entity); + Context.CheckEntity(entity); base.Update(entity); } public override void UpdateRange(IEnumerable entities) { - CheckEntities(entities); + Context.CheckEntities(entities); base.UpdateRange(entities); } public override void Remove(TEntity entity) { - CheckEntity(entity); + Context.CheckEntity(entity); base.Remove(entity); } public override void RemoveRange(IEnumerable entities) { - CheckEntities(entities); + Context.CheckEntities(entities); base.RemoveRange(entities); } @@ -173,24 +153,27 @@ public override void RemoveRange(Expression> predicate) #region IQueryable Implementation - protected override IQueryable GetQueryable() + protected override IQueryable GetQueryable(bool trackEntities) { var key = Context.TenantId; var queryable = Context.Query().Where(c => c.TenantId == key); - var provider = queryable.Provider as IMongoFrameworkQueryProvider; - provider.EntityProcessors.Add(new EntityTrackingProcessor(Context)); + if (trackEntities) + { + var provider = queryable.Provider as IMongoFrameworkQueryProvider; + provider.EntityProcessors.Add(new EntityTrackingProcessor(Context)); + } return queryable; } public IQueryable GetSearchTextQueryable(string search) { var key = Context.TenantId; - var queryable = Context.Query().WhereFilter(b => b.Text(search)).Where(c => c.TenantId == key); + var queryable = Context.Query().WhereFilter(b => b.Text(search)).Where(c => c.TenantId == key); var provider = queryable.Provider as IMongoFrameworkQueryProvider; provider.EntityProcessors.Add(new EntityTrackingProcessor(Context)); return queryable; - } - + } + #endregion } } \ No newline at end of file diff --git a/tests/MongoFramework.Tests/MongoDbContextTenantTests.cs b/tests/MongoFramework.Tests/MongoDbContextTenantTests.cs index 2cd8df14..ba7e9e29 100644 --- a/tests/MongoFramework.Tests/MongoDbContextTenantTests.cs +++ b/tests/MongoFramework.Tests/MongoDbContextTenantTests.cs @@ -111,5 +111,133 @@ public async Task ContextSavesDbSetsAsync() Assert.AreEqual(TestConfiguration.GetTenantId(), context.DbSet.First().TenantId); } } + + [TestMethod] + public void SuccessfullyAttachUntrackedEntity() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var model = new DbSetModel + { + Id = "abcd" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbTenantContext(connection, tenantId); + dbSet = new MongoDbTenantSet(context); + + var result = dbSet.AsNoTracking().FirstOrDefault(); + + context.Attach(result); + + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result).State); + } + + [TestMethod] + public void SuccessfullyAttachUntrackedEntities() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var entities = new[] { + new DbSetModel + { + Id = "abcd" + }, + new DbSetModel + { + Id = "efgh" + } + }; + + dbSet.AddRange(entities); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbTenantContext(connection, tenantId); + dbSet = new MongoDbTenantSet(context); + + var result = dbSet.AsNoTracking().ToList(); + + context.AttachRange(result); + + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result[0]).State); + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result[1]).State); + } + + [TestMethod] + public void AttachRejectsMismatchedEntity() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var model = new DbSetModel + { + Id = "abcd" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbTenantContext(connection, tenantId); + dbSet = new MongoDbTenantSet(context); + + var result = dbSet.AsNoTracking().FirstOrDefault(); + result.TenantId = tenantId + "a"; + + Assert.ThrowsException(() => context.Attach(result)); + } + + [TestMethod] + public void AttachRejectsMismatchedEntities() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var entities = new[] { + new DbSetModel + { + Id = "abcd" + }, + new DbSetModel + { + Id = "efgh" + } + }; + + dbSet.AddRange(entities); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbTenantContext(connection, tenantId); + dbSet = new MongoDbTenantSet(context); + + var result = dbSet.AsNoTracking().ToList(); + result[0].TenantId = tenantId + "a"; + + Assert.ThrowsException(() => context.AttachRange(result)); + } + } } \ No newline at end of file diff --git a/tests/MongoFramework.Tests/MongoDbContextTests.cs b/tests/MongoFramework.Tests/MongoDbContextTests.cs index c4d1b28e..ed263efc 100644 --- a/tests/MongoFramework.Tests/MongoDbContextTests.cs +++ b/tests/MongoFramework.Tests/MongoDbContextTests.cs @@ -85,5 +85,69 @@ public async Task ContextSavesDbSetsAsync() Assert.IsTrue(context.DbSet.Any()); } } + + [TestMethod] + public void SuccessfullyAttachUntrackedEntity() + { + var connection = TestConfiguration.GetConnection(); + var context = new MongoDbContext(connection); + var dbSet = new MongoDbSet(context); + + var model = new DbSetModel + { + Id = "abcd" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbContext(connection); + dbSet = new MongoDbSet(context); + + var result = dbSet.AsNoTracking().FirstOrDefault(); + + context.Attach(result); + + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result).State); + } + + [TestMethod] + public void SuccessfullyAttachUntrackedEntities() + { + var connection = TestConfiguration.GetConnection(); + var context = new MongoDbContext(connection); + var dbSet = new MongoDbSet(context); + + var entities = new[] { + new DbSetModel + { + Id = "abcd" + }, + new DbSetModel + { + Id = "efgh" + } + }; + + dbSet.AddRange(entities); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbContext(connection); + dbSet = new MongoDbSet(context); + + var result = dbSet.AsNoTracking().ToList(); + + context.AttachRange(result); + + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result[0]).State); + Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.NoChanges, context.ChangeTracker.GetEntry(result[1]).State); + } + } } \ No newline at end of file diff --git a/tests/MongoFramework.Tests/MongoDbSetTests.cs b/tests/MongoFramework.Tests/MongoDbSetTests.cs index b99a9ab1..9b26cd6e 100644 --- a/tests/MongoFramework.Tests/MongoDbSetTests.cs +++ b/tests/MongoFramework.Tests/MongoDbSetTests.cs @@ -417,6 +417,60 @@ public async Task SuccessfullyLinqFindTrackedAsync() Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.Updated, context.ChangeTracker.GetEntry(result).State); } - } + [TestMethod] + public void SuccessfullyLinqFindNoTracking() + { + var connection = TestConfiguration.GetConnection(); + var context = new MongoDbContext(connection); + var dbSet = new MongoDbSet(context); + + var model = new TestModel + { + Id = "abcd", + Description = "SuccessfullyLinqFindNoTracking.1" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbContext(connection); + dbSet = new MongoDbSet(context); + + var result = dbSet.AsNoTracking().FirstOrDefault(); + + Assert.IsNull(context.ChangeTracker.GetEntry(result)); + } + + [TestMethod] + public async Task SuccessfullyLinqFindNoTrackingAsync() + { + var connection = TestConfiguration.GetConnection(); + var context = new MongoDbContext(connection); + var dbSet = new MongoDbSet(context); + + var model = new TestModel + { + Id = "abcd", + Description = "SuccessfullyFindTracked.1" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbContext(connection); + dbSet = new MongoDbSet(context); + + var result = await dbSet.AsNoTracking().FirstOrDefaultAsync(); + + Assert.IsNull(context.ChangeTracker.GetEntry(result)); + } + + } } \ No newline at end of file diff --git a/tests/MongoFramework.Tests/MongoDbTenantSetTests.cs b/tests/MongoFramework.Tests/MongoDbTenantSetTests.cs index 57cad153..5fc6d403 100644 --- a/tests/MongoFramework.Tests/MongoDbTenantSetTests.cs +++ b/tests/MongoFramework.Tests/MongoDbTenantSetTests.cs @@ -875,5 +875,61 @@ public async Task SuccessfullyLinqFindTrackedAsync() Assert.AreEqual(MongoFramework.Infrastructure.EntityEntryState.Updated, context.ChangeTracker.GetEntry(result).State); } + [TestMethod] + public void SuccessfullyLinqFindNoTracking() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var model = new TestModel + { + Id = "abcd", + Description = "SuccessfullyLinqFindNoTracking.1" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbTenantContext(connection, tenantId); + dbSet = new MongoDbTenantSet(context); + + var result = dbSet.AsNoTracking().FirstOrDefault(); + + Assert.IsNull(context.ChangeTracker.GetEntry(result)); + } + + [TestMethod] + public async Task SuccessfullyLinqFindNoTrackingAsync() + { + var connection = TestConfiguration.GetConnection(); + var tenantId = TestConfiguration.GetTenantId(); + var context = new MongoDbTenantContext(connection, tenantId); + var dbSet = new MongoDbTenantSet(context); + + var model = new TestModel + { + Id = "abcd", + Description = "SuccessfullyFindTracked.1" + }; + + dbSet.Add(model); + + context.SaveChanges(); + + ResetMongoDb(); + + context = new MongoDbTenantContext(connection, tenantId); + dbSet = new MongoDbTenantSet(context); + + var result = await dbSet.AsNoTracking().FirstOrDefaultAsync(); + + Assert.IsNull(context.ChangeTracker.GetEntry(result)); + } + } } \ No newline at end of file