diff --git a/src/EFCore/Infrastructure/Internal/LazyLoader.cs b/src/EFCore/Infrastructure/Internal/LazyLoader.cs index 6dd385600d3..8215392c8d8 100644 --- a/src/EFCore/Infrastructure/Internal/LazyLoader.cs +++ b/src/EFCore/Infrastructure/Internal/LazyLoader.cs @@ -4,6 +4,8 @@ using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using Microsoft.EntityFrameworkCore.ChangeTracking; using Microsoft.EntityFrameworkCore.Internal; namespace Microsoft.EntityFrameworkCore.Infrastructure.Internal; @@ -20,7 +22,8 @@ public class LazyLoader : ILazyLoader, IInjectableService private bool _disposed; private bool _detached; private IDictionary? _loadedStates; - private readonly ConcurrentDictionary<(object Entity, string NavigationName), bool> _isLoading = new(NavEntryEqualityComparer.Instance); + private readonly Lock _isLoadingLock = new Lock(); + private readonly Dictionary<(object Entity, string NavigationName), (TaskCompletionSource TaskCompletionSource, AsyncLocal Depth)> _isLoading = new(NavEntryEqualityComparer.Instance); private HashSet? _nonLazyNavigations; /// @@ -107,30 +110,56 @@ public virtual void Load(object entity, [CallerMemberName] string navigationName Check.NotEmpty(navigationName, nameof(navigationName)); var navEntry = (entity, navigationName); - if (_isLoading.TryAdd(navEntry, true)) + + bool exists; + (TaskCompletionSource TaskCompletionSource, AsyncLocal Depth) isLoadingValue; + + lock (_isLoadingLock) + { + ref var refIsLoadingValue = ref CollectionsMarshal.GetValueRefOrAddDefault(_isLoading, navEntry, out exists); + if (!exists) + { + refIsLoadingValue = (new(), new()); + } + isLoadingValue = refIsLoadingValue!; + isLoadingValue.Depth.Value++; + } + + if (exists) + { + // Only waits for the outermost call on the call stack. See #35528. + if (isLoadingValue.Depth.Value == 1) + { + isLoadingValue.TaskCompletionSource.Task.Wait(); + } + return; + } + + try { - try + // ShouldLoad is called after _isLoading.Add because it could attempt to load the property. See #13138. + if (ShouldLoad(entity, navigationName, out var entry)) { - // ShouldLoad is called after _isLoading.Add because it could attempt to load the property. See #13138. - if (ShouldLoad(entity, navigationName, out var entry)) + try { - try - { - entry.Load( - _queryTrackingBehavior == QueryTrackingBehavior.NoTrackingWithIdentityResolution - ? LoadOptions.ForceIdentityResolution - : LoadOptions.None); - } - catch - { - entry.IsLoaded = false; - throw; - } + entry.Load( + _queryTrackingBehavior == QueryTrackingBehavior.NoTrackingWithIdentityResolution + ? LoadOptions.ForceIdentityResolution + : LoadOptions.None); + } + catch + { + entry.IsLoaded = false; + throw; } } - finally + } + finally + { + isLoadingValue.TaskCompletionSource.TrySetResult(); + lock (_isLoadingLock) { - _isLoading.TryRemove(navEntry, out _); + _isLoading.Remove(navEntry); } } } @@ -150,31 +179,57 @@ public virtual async Task LoadAsync( Check.NotEmpty(navigationName, nameof(navigationName)); var navEntry = (entity, navigationName); - if (_isLoading.TryAdd(navEntry, true)) + + bool exists; + (TaskCompletionSource TaskCompletionSource, AsyncLocal Depth) isLoadingValue; + + lock (_isLoadingLock) + { + ref var refIsLoadingValue = ref CollectionsMarshal.GetValueRefOrAddDefault(_isLoading, navEntry, out exists); + if (!exists) + { + refIsLoadingValue = (new(), new()); + } + isLoadingValue = refIsLoadingValue!; + isLoadingValue.Depth.Value++; + } + + if (exists) + { + // Only waits for the outermost call on the call stack. See #35528. + if (isLoadingValue.Depth.Value == 1) + { + await isLoadingValue.TaskCompletionSource.Task.WaitAsync(cancellationToken).ConfigureAwait(false); + } + return; + } + + try { - try + // ShouldLoad is called after _isLoading.Add because it could attempt to load the property. See #13138. + if (ShouldLoad(entity, navigationName, out var entry)) { - // ShouldLoad is called after _isLoading.Add because it could attempt to load the property. See #13138. - if (ShouldLoad(entity, navigationName, out var entry)) + try { - try - { - await entry.LoadAsync( - _queryTrackingBehavior == QueryTrackingBehavior.NoTrackingWithIdentityResolution - ? LoadOptions.ForceIdentityResolution - : LoadOptions.None, - cancellationToken).ConfigureAwait(false); - } - catch - { - entry.IsLoaded = false; - throw; - } + await entry.LoadAsync( + _queryTrackingBehavior == QueryTrackingBehavior.NoTrackingWithIdentityResolution + ? LoadOptions.ForceIdentityResolution + : LoadOptions.None, + cancellationToken).ConfigureAwait(false); + } + catch + { + entry.IsLoaded = false; + throw; } } - finally + } + finally + { + isLoadingValue.TaskCompletionSource.TrySetResult(); + lock (_isLoadingLock) { - _isLoading.TryRemove(navEntry, out _); + _isLoading.Remove(navEntry); } } } diff --git a/test/EFCore.Specification.Tests/LoadTestBase.cs b/test/EFCore.Specification.Tests/LoadTestBase.cs index c9c97a3a2a8..1b216612cc2 100644 --- a/test/EFCore.Specification.Tests/LoadTestBase.cs +++ b/test/EFCore.Specification.Tests/LoadTestBase.cs @@ -5035,6 +5035,60 @@ public virtual void Setting_navigation_to_null_is_detected_by_local_DetectChange Assert.Equal(EntityState.Deleted, childEntry.State); } + [ConditionalTheory] // Issue #35528 + [InlineData(false, false)] + [InlineData(true, false)] + [InlineData(false, true)] + [InlineData(true, true)] + public virtual async Task Lazy_loading_is_thread_safe(bool noTracking, bool async) + { + using var context = CreateContext(lazyLoadingEnabled: true); + + //Creating another context to avoid caches + using var context2 = CreateContext(lazyLoadingEnabled: true); + + IQueryable query = context.Set(); + IQueryable query2 = context2.Set(); + + if (noTracking) + { + query = query.AsNoTracking(); + query2 = query2.AsNoTracking(); + } + + var parent = query.Single(); + + var children = (await parent.LazyLoadChildren(async))?.Select(x => x.Id).OrderBy(x => x).ToList(); + var singlePkToPk = (await parent.LazyLoadSinglePkToPk(async))?.Id; + var single = (await parent.LazyLoadSingle(async))?.Id; + var childrenAk = (await parent.LazyLoadChildrenAk(async))?.Select(x => x.Id).OrderBy(x => x).ToList(); + var singleAk = (await parent.LazyLoadSingleAk(async))?.Id; + var childrenShadowFk = (await parent.LazyLoadChildrenShadowFk(async))?.Select(x => x.Id).OrderBy(x => x).ToList(); + var singleShadowFk = (await parent.LazyLoadSingleShadowFk(async))?.Id; + var childrenCompositeKey = (await parent.LazyLoadChildrenCompositeKey(async))?.Select(x => x.Id).OrderBy(x => x).ToList(); + var singleCompositeKey = (await parent.LazyLoadSingleCompositeKey(async))?.Id; + + var parent2 = query2.Single(); + + var parallelOptions = new ParallelOptions + { + MaxDegreeOfParallelism = Environment.ProcessorCount * 500 + }; + + await Parallel.ForAsync(0, 50000, parallelOptions, async (i, ct) => + { + Assert.Equal(children, (await parent2.LazyLoadChildren(async))?.Select(x => x.Id).OrderBy(x => x).ToList()); + Assert.Equal(singlePkToPk, (await parent2.LazyLoadSinglePkToPk(async))?.Id); + Assert.Equal(single, (await parent2.LazyLoadSingle(async))?.Id); + Assert.Equal(childrenAk, (await parent2.LazyLoadChildrenAk(async))?.Select(x => x.Id).OrderBy(x => x).ToList()); + Assert.Equal(singleAk, (await parent2.LazyLoadSingleAk(async))?.Id); + Assert.Equal(childrenShadowFk, (await parent2.LazyLoadChildrenShadowFk(async))?.Select(x => x.Id).OrderBy(x => x).ToList()); + Assert.Equal(singleShadowFk, (await parent2.LazyLoadSingleShadowFk(async))?.Id); + Assert.Equal(childrenCompositeKey, (await parent2.LazyLoadChildrenCompositeKey(async))?.Select(x => x.Id).OrderBy(x => x).ToList()); + Assert.Equal(singleCompositeKey, (await parent2.LazyLoadSingleCompositeKey(async))?.Id); + }); + } + private static void SetState( DbContext context, object entity, @@ -5092,6 +5146,17 @@ public SinglePkToPk SinglePkToPk set => _singlePkToPk = value; } + public async Task LazyLoadSinglePkToPk(bool async) + { + if (async) + { + await Loader.LoadAsync(this, default, nameof(SinglePkToPk)); + return _singlePkToPk; + } + + return SinglePkToPk; + } + public Single Single { get => Loader.Load(this, ref _single); @@ -5121,35 +5186,101 @@ public IEnumerable ChildrenAk set => _childrenAk = value; } + public async Task> LazyLoadChildrenAk(bool async) + { + if (async) + { + await Loader.LoadAsync(this, default, nameof(ChildrenAk)); + return _childrenAk; + } + + return ChildrenAk; + } + public SingleAk SingleAk { get => Loader.Load(this, ref _singleAk); set => _singleAk = value; } + public async Task LazyLoadSingleAk(bool async) + { + if (async) + { + await Loader.LoadAsync(this, default, nameof(SingleAk)); + return _singleAk; + } + + return SingleAk; + } + public IEnumerable ChildrenShadowFk { get => Loader.Load(this, ref _childrenShadowFk); set => _childrenShadowFk = value; } + public async Task> LazyLoadChildrenShadowFk(bool async) + { + if (async) + { + await Loader.LoadAsync(this, default, nameof(ChildrenShadowFk)); + return _childrenShadowFk; + } + + return ChildrenShadowFk; + } + public SingleShadowFk SingleShadowFk { get => Loader.Load(this, ref _singleShadowFk); set => _singleShadowFk = value; } + public async Task LazyLoadSingleShadowFk(bool async) + { + if (async) + { + await Loader.LoadAsync(this, default, nameof(SingleShadowFk)); + return _singleShadowFk; + } + + return SingleShadowFk; + } + public IEnumerable ChildrenCompositeKey { get => Loader.Load(this, ref _childrenCompositeKey); set => _childrenCompositeKey = value; } + public async Task> LazyLoadChildrenCompositeKey(bool async) + { + if (async) + { + await Loader.LoadAsync(this, default, nameof(ChildrenCompositeKey)); + return _childrenCompositeKey; + } + + return ChildrenCompositeKey; + } + public SingleCompositeKey SingleCompositeKey { get => Loader.Load(this, ref _singleCompositeKey); set => _singleCompositeKey = value; } + + public async Task LazyLoadSingleCompositeKey(bool async) + { + if (async) + { + await Loader.LoadAsync(this, default, nameof(SingleCompositeKey)); + return _singleCompositeKey; + } + + return SingleCompositeKey; + } } protected class Child