Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public IEnumerable<WriteModel<EntityBucket<TGroup, TSubEntity>>> GetModel(WriteM
.Min(b => b.Min, itemTimeValue)
.Max(b => b.Max, itemTimeValue)
.SetOnInsert(b => b.BucketSize, BucketSize)
.SetOnInsert(b => b.Id, entityDefinition.Key.KeyGenerator.Generate());
.SetOnInsert(b => b.Id, entityDefinition.FindNearestKey().KeyGenerator.Generate());

yield return new UpdateOneModel<EntityBucket<TGroup, TSubEntity>>(filter, updateDefinition)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,24 @@ namespace MongoFramework.Infrastructure.Mapping
{
public static class EntityDefinitionExtensions
{
public static PropertyDefinition GetIdProperty(this EntityDefinition definition)
/// <summary>
/// Finds the nearest <see cref="KeyDefinition"/> from <paramref name="definition"/>, recursively searching the base <see cref="EntityDefinition"/> if one exists.
/// </summary>
/// <param name="definition">The <see cref="EntityDefinition"/> to start the search from.</param>
/// <returns>The key definition; otherwise <see langword="null"/> if one can not be found.</returns>
public static KeyDefinition FindNearestKey(this EntityDefinition definition)
{
if (definition.Key is null)
{
return EntityMapping.GetOrCreateDefinition(definition.EntityType.BaseType).GetIdProperty();
return EntityMapping.GetOrCreateDefinition(definition.EntityType.BaseType).FindNearestKey();
}

return definition.Key?.Property;
return definition.Key;
}

public static PropertyDefinition GetIdProperty(this EntityDefinition definition)
{
return definition.FindNearestKey()?.Property;
}

public static string GetIdName(this EntityDefinition definition)
Expand Down
3 changes: 2 additions & 1 deletion src/MongoFramework/Linq/LinqExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ public static string ToQuery(this IQueryable queryable)

public static IQueryable<TEntity> WhereIdMatches<TEntity>(this IQueryable<TEntity> queryable, IEnumerable entityIds) where TEntity : class
{
var idProperty = EntityMapping.GetOrCreateDefinition(typeof(TEntity)).Key.Property;
var idProperty = EntityMapping.GetOrCreateDefinition(typeof(TEntity)).GetIdProperty()
?? throw new ArgumentException($"No Id property was found on entity type {typeof(TEntity)} or any base types");
return queryable.WherePropertyMatches(idProperty, entityIds);
}

Expand Down
36 changes: 36 additions & 0 deletions tests/MongoFramework.Tests/Linq/LinqExtensionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ public class LinqExtensionsModel
{
public string Id { get; set; }
}
public class WhereIdMatchesInheritanceBaseModel
{
public string Id { get; set; }
}
public class WhereIdMatchesInheritanceDerivedModel : WhereIdMatchesInheritanceBaseModel
{
public string Description { get; set; }
}
public class WhereIdMatchesGuidModel
{
public Guid Id { get; set; }
Expand Down Expand Up @@ -70,6 +78,34 @@ public void InvalidToQuery()
LinqExtensions.ToQuery(null);
}

[TestMethod]
public void WhereIdMatches_BaseTypeWithId()
{
var connection = TestConfiguration.GetConnection();
var context = new MongoDbContext(connection);
var dbSet = new MongoDbSet<WhereIdMatchesInheritanceDerivedModel>(context);

var entityCollection = new[]
{
new WhereIdMatchesInheritanceDerivedModel { Description = "1" },
new WhereIdMatchesInheritanceDerivedModel { Description = "2" },
new WhereIdMatchesInheritanceDerivedModel { Description = "3" },
new WhereIdMatchesInheritanceDerivedModel { Description = "4" }
};
dbSet.AddRange(entityCollection);
context.SaveChanges();

var provider = new MongoFrameworkQueryProvider<WhereIdMatchesInheritanceDerivedModel>(connection);
var queryable = new MongoFrameworkQueryable<WhereIdMatchesInheritanceDerivedModel>(provider);

var entityIds = entityCollection.Select(e => e.Id).Take(2);

var idMatchQueryable = LinqExtensions.WhereIdMatches(queryable, entityIds);

Assert.AreEqual(2, idMatchQueryable.Count());
Assert.IsTrue(idMatchQueryable.ToList().All(e => entityIds.Contains(e.Id)));
}

[TestMethod]
public void WhereIdMatchesGuids()
{
Expand Down