Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/MongoFramework/IMongoDbContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,8 @@ public interface IMongoDbContext

void SaveChanges();
Task SaveChangesAsync(CancellationToken cancellationToken = default(CancellationToken));

void Attach<TEntity>(TEntity entity) where TEntity : class;
void AttachRange<TEntity>(IEnumerable<TEntity> entities) where TEntity : class;
}
}
1 change: 1 addition & 0 deletions src/MongoFramework/IMongoDbSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ public interface IMongoDbSet<TEntity> : IMongoDbSet, IQueryable<TEntity> where T
void RemoveRange(IEnumerable<TEntity> entities);
void RemoveRange(Expression<Func<TEntity, bool>> predicate);
void RemoveById(object entityId);
IQueryable<TEntity> AsNoTracking();
}
}
10 changes: 7 additions & 3 deletions src/MongoFramework/IMongoDbTenantContext.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
namespace MongoFramework
using System.Collections.Generic;

namespace MongoFramework
{
public interface IMongoDbTenantContext : IMongoDbContext
{
public interface IMongoDbTenantContext : IMongoDbContext
{
string TenantId { get; }
void CheckEntity(IHaveTenantId entity);
void CheckEntities(IEnumerable<IHaveTenantId> entity);
}
}
25 changes: 25 additions & 0 deletions src/MongoFramework/MongoDbContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using MongoFramework.Utilities;

namespace MongoFramework
{
Expand Down Expand Up @@ -107,6 +108,7 @@ public virtual async Task SaveChangesAsync(CancellationToken cancellationToken =
ChangeTracker.CommitChanges();
CommandStaging.CommitChanges();
}

private static async Task InternalSaveChangesAsync<TEntity>(IMongoDbConnection connection, IEnumerable<IWriteCommand> commands, WriteModelOptions options, CancellationToken cancellationToken) where TEntity : class
{
await EntityIndexWriter.ApplyIndexingAsync<TEntity>(connection);
Expand All @@ -131,7 +133,30 @@ public IQueryable<TEntity> Query<TEntity>() where TEntity : class
var provider = new MongoFrameworkQueryProvider<TEntity>(Connection);
return new MongoFrameworkQueryable<TEntity>(provider);
}

/// <summary>
/// Marks the entity as unchanged in the change tracker and starts tracking.
/// </summary>
/// <param name="entity"></param>
public virtual void Attach<TEntity>(TEntity entity) where TEntity : class
{
Check.NotNull(entity, nameof(entity));
ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges);
}

/// <summary>
/// Marks the collection of entities as unchanged in the change tracker and starts tracking.
/// </summary>
/// <param name="entities"></param>
public virtual void AttachRange<TEntity>(IEnumerable<TEntity> entities) where TEntity : class
{
Check.NotNull(entities, nameof(entities));
foreach (var entity in entities)
{
ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges);
}
}

public void Dispose()
{
Dispose(true);
Expand Down
32 changes: 21 additions & 11 deletions src/MongoFramework/MongoDbSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public virtual void AddRange(IEnumerable<TEntity> entities)
Context.ChangeTracker.SetEntityState(entity, EntityEntryState.Added);
}
}

/// <summary>
/// Marks the entity for updating.
/// </summary>
Expand Down Expand Up @@ -209,30 +209,40 @@ public async Task SaveChangesAsync(CancellationToken cancellationToken = default

#region IQueryable Implementation

protected virtual IQueryable<TEntity> GetQueryable()
protected virtual IQueryable<TEntity> GetQueryable(bool trackEntities)
{
var queryable = Context.Query<TEntity>();
var provider = queryable.Provider as IMongoFrameworkQueryProvider<TEntity>;
provider.EntityProcessors.Add(new EntityTrackingProcessor<TEntity>(Context));
if (trackEntities)
{
var provider = queryable.Provider as IMongoFrameworkQueryProvider<TEntity>;
provider.EntityProcessors.Add(new EntityTrackingProcessor<TEntity>(Context));
}
return queryable;
}

public Expression Expression => GetQueryable().Expression;
public Expression Expression => GetQueryable(true).Expression;

public Type ElementType => GetQueryable().ElementType;
public Type ElementType => GetQueryable(true).ElementType;

public IQueryProvider Provider => GetQueryable().Provider;
public IQueryProvider Provider => GetQueryable(true).Provider;

public IEnumerator<TEntity> GetEnumerator()
{
return GetQueryable().GetEnumerator();
return GetQueryable(true).GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}

}

public virtual IQueryable<TEntity> AsNoTracking()
{
return GetQueryable(false);
}

#endregion
}

}

}
54 changes: 51 additions & 3 deletions src/MongoFramework/MongoDbTenantContext.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using MongoFramework.Infrastructure.Commands;
using MongoFramework.Utilities;
using System;
using System.Collections.Generic;
using MongoFramework.Infrastructure;
using MongoFramework.Infrastructure.Commands;
using MongoFramework.Utilities;

namespace MongoFramework
{
Expand All @@ -23,5 +24,52 @@ protected override WriteModelOptions GetWriteModelOptions()
{
return new WriteModelOptions { TenantId = TenantId };
}

public virtual void CheckEntity(IHaveTenantId entity)
{
Check.NotNull(entity, nameof(entity));

if (entity.TenantId != TenantId)
{
throw new MultiTenantException($"Entity type {entity.GetType().Name}, tenant ID does not match. Expected: {TenantId}, Entity has: {entity.TenantId}");
}
}

public virtual void CheckEntities(IEnumerable<IHaveTenantId> entities)
{
Check.NotNull(entities, nameof(entities));

foreach (var entity in entities)
{
CheckEntity(entity);
}
}

/// <summary>
/// Marks the entity as unchanged in the change tracker and starts tracking.
/// </summary>
/// <param name="entity"></param>
public override void Attach<TEntity>(TEntity entity) where TEntity : class
{
if (typeof(IHaveTenantId).IsAssignableFrom(typeof(TEntity)))
{
CheckEntity(entity as IHaveTenantId);
}
base.Attach(entity);
}

/// <summary>
/// Marks the collection of entities as unchanged in the change tracker and starts tracking.
/// </summary>
/// <param name="entities"></param>
public override void AttachRange<TEntity>(IEnumerable<TEntity> entities) where TEntity : class
{
if (typeof(IHaveTenantId).IsAssignableFrom(typeof(TEntity)))
{
CheckEntities(entities as IEnumerable<IHaveTenantId>);
}
base.AttachRange(entities);
}

}
}
49 changes: 16 additions & 33 deletions src/MongoFramework/MongoDbTenantSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,8 @@ public MongoDbTenantSet(IMongoDbContext context) : base(context)
{
Context = context as IMongoDbTenantContext ?? throw new ArgumentException("Context provided to a MongoDbTenantSet must be IMongoDbTenantContext",nameof(context));
}

protected virtual void CheckEntity(TEntity entity)
{
Check.NotNull(entity, nameof(entity));

if (entity.TenantId != Context.TenantId)
{
throw new MultiTenantException($"Entity type {entity.GetType().Name}, tenant ID does not match. Expected: {Context.TenantId}, Entity has: {entity.TenantId}");
}
}

protected virtual void CheckEntities(IEnumerable<TEntity> entities)
{
Check.NotNull(entities, nameof(entities));

foreach (var entity in entities)
{
CheckEntity(entity);
}
}

/// <summary>

/// <summary>
/// Finds an entity with the given primary key value. If an entity with the given primary key value
/// is being tracked by the context, then it is returned immediately without making a request to the
/// database. Otherwise, a query is made to the database for an entity with the given primary key value
Expand Down Expand Up @@ -139,28 +119,28 @@ public override void AddRange(IEnumerable<TEntity> entities)
}
base.AddRange(entities);
}

public override void Update(TEntity entity)
{
CheckEntity(entity);
Context.CheckEntity(entity);
base.Update(entity);
}

public override void UpdateRange(IEnumerable<TEntity> entities)
{
CheckEntities(entities);
Context.CheckEntities(entities);
base.UpdateRange(entities);
}

public override void Remove(TEntity entity)
{
CheckEntity(entity);
Context.CheckEntity(entity);
base.Remove(entity);
}

public override void RemoveRange(IEnumerable<TEntity> entities)
{
CheckEntities(entities);
Context.CheckEntities(entities);
base.RemoveRange(entities);
}

Expand All @@ -173,24 +153,27 @@ public override void RemoveRange(Expression<Func<TEntity, bool>> predicate)

#region IQueryable Implementation

protected override IQueryable<TEntity> GetQueryable()
protected override IQueryable<TEntity> GetQueryable(bool trackEntities)
{
var key = Context.TenantId;
var queryable = Context.Query<TEntity>().Where(c => c.TenantId == key);
var provider = queryable.Provider as IMongoFrameworkQueryProvider<TEntity>;
provider.EntityProcessors.Add(new EntityTrackingProcessor<TEntity>(Context));
if (trackEntities)
{
var provider = queryable.Provider as IMongoFrameworkQueryProvider<TEntity>;
provider.EntityProcessors.Add(new EntityTrackingProcessor<TEntity>(Context));
}
return queryable;
}

public IQueryable<TEntity> GetSearchTextQueryable(string search)
{
var key = Context.TenantId;
var queryable = Context.Query<TEntity>().WhereFilter(b => b.Text(search)).Where(c => c.TenantId == key);
var queryable = Context.Query<TEntity>().WhereFilter(b => b.Text(search)).Where(c => c.TenantId == key);
var provider = queryable.Provider as IMongoFrameworkQueryProvider<TEntity>;
provider.EntityProcessors.Add(new EntityTrackingProcessor<TEntity>(Context));
return queryable;
}
}

#endregion
}
}
Loading