diff --git a/src/MongoFramework/Infrastructure/Commands/AddToBucketCommand.cs b/src/MongoFramework/Infrastructure/Commands/AddToBucketCommand.cs index 67168c73..1498aada 100644 --- a/src/MongoFramework/Infrastructure/Commands/AddToBucketCommand.cs +++ b/src/MongoFramework/Infrastructure/Commands/AddToBucketCommand.cs @@ -42,7 +42,7 @@ public IEnumerable>> GetModel(WriteM .Min(b => b.Min, itemTimeValue) .Max(b => b.Max, itemTimeValue) .SetOnInsert(b => b.BucketSize, BucketSize) - .SetOnInsert(b => b.Id, entityDefinition.Key.KeyGenerator.Generate()); + .SetOnInsert(b => b.Id, entityDefinition.FindNearestKey().KeyGenerator.Generate()); yield return new UpdateOneModel>(filter, updateDefinition) { diff --git a/src/MongoFramework/Infrastructure/Mapping/EntityDefinitionExtensions.cs b/src/MongoFramework/Infrastructure/Mapping/EntityDefinitionExtensions.cs index 650c508b..f39eb9c3 100644 --- a/src/MongoFramework/Infrastructure/Mapping/EntityDefinitionExtensions.cs +++ b/src/MongoFramework/Infrastructure/Mapping/EntityDefinitionExtensions.cs @@ -6,14 +6,24 @@ namespace MongoFramework.Infrastructure.Mapping { public static class EntityDefinitionExtensions { - public static PropertyDefinition GetIdProperty(this EntityDefinition definition) + /// + /// Finds the nearest from , recursively searching the base if one exists. + /// + /// The to start the search from. + /// The key definition; otherwise if one can not be found. + public static KeyDefinition FindNearestKey(this EntityDefinition definition) { if (definition.Key is null) { - return EntityMapping.GetOrCreateDefinition(definition.EntityType.BaseType).GetIdProperty(); + return EntityMapping.GetOrCreateDefinition(definition.EntityType.BaseType).FindNearestKey(); } - return definition.Key?.Property; + return definition.Key; + } + + public static PropertyDefinition GetIdProperty(this EntityDefinition definition) + { + return definition.FindNearestKey()?.Property; } public static string GetIdName(this EntityDefinition definition) diff --git a/src/MongoFramework/Linq/LinqExtensions.cs b/src/MongoFramework/Linq/LinqExtensions.cs index 52f94b76..549c0b0d 100644 --- a/src/MongoFramework/Linq/LinqExtensions.cs +++ b/src/MongoFramework/Linq/LinqExtensions.cs @@ -27,7 +27,8 @@ public static string ToQuery(this IQueryable queryable) public static IQueryable WhereIdMatches(this IQueryable queryable, IEnumerable entityIds) where TEntity : class { - var idProperty = EntityMapping.GetOrCreateDefinition(typeof(TEntity)).Key.Property; + var idProperty = EntityMapping.GetOrCreateDefinition(typeof(TEntity)).GetIdProperty() + ?? throw new ArgumentException($"No Id property was found on entity type {typeof(TEntity)} or any base types"); return queryable.WherePropertyMatches(idProperty, entityIds); } diff --git a/tests/MongoFramework.Tests/Linq/LinqExtensionsTests.cs b/tests/MongoFramework.Tests/Linq/LinqExtensionsTests.cs index 012a6cc3..b3d716b0 100644 --- a/tests/MongoFramework.Tests/Linq/LinqExtensionsTests.cs +++ b/tests/MongoFramework.Tests/Linq/LinqExtensionsTests.cs @@ -17,6 +17,14 @@ public class LinqExtensionsModel { public string Id { get; set; } } + public class WhereIdMatchesInheritanceBaseModel + { + public string Id { get; set; } + } + public class WhereIdMatchesInheritanceDerivedModel : WhereIdMatchesInheritanceBaseModel + { + public string Description { get; set; } + } public class WhereIdMatchesGuidModel { public Guid Id { get; set; } @@ -70,6 +78,34 @@ public void InvalidToQuery() LinqExtensions.ToQuery(null); } + [TestMethod] + public void WhereIdMatches_BaseTypeWithId() + { + var connection = TestConfiguration.GetConnection(); + var context = new MongoDbContext(connection); + var dbSet = new MongoDbSet(context); + + var entityCollection = new[] + { + new WhereIdMatchesInheritanceDerivedModel { Description = "1" }, + new WhereIdMatchesInheritanceDerivedModel { Description = "2" }, + new WhereIdMatchesInheritanceDerivedModel { Description = "3" }, + new WhereIdMatchesInheritanceDerivedModel { Description = "4" } + }; + dbSet.AddRange(entityCollection); + context.SaveChanges(); + + var provider = new MongoFrameworkQueryProvider(connection); + var queryable = new MongoFrameworkQueryable(provider); + + var entityIds = entityCollection.Select(e => e.Id).Take(2); + + var idMatchQueryable = LinqExtensions.WhereIdMatches(queryable, entityIds); + + Assert.AreEqual(2, idMatchQueryable.Count()); + Assert.IsTrue(idMatchQueryable.ToList().All(e => entityIds.Contains(e.Id))); + } + [TestMethod] public void WhereIdMatchesGuids() {