From eb8fe9f48e4e056213bbd1a1e4cccf933160db86 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Sat, 18 Dec 2021 04:43:43 +0100 Subject: [PATCH 01/22] Change HttpHeaders backing store to an array --- .../System/Net/Http/Headers/HttpHeaders.cs | 305 +++++++++++++----- .../Http/Headers/HttpHeadersNonValidated.cs | 46 +-- .../SocketsHttpHandler/Http2Connection.cs | 9 +- .../SocketsHttpHandler/Http3RequestStream.cs | 9 +- .../Http/SocketsHttpHandler/HttpConnection.cs | 9 +- 5 files changed, 278 insertions(+), 100 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index ee53b9e626c751..53ee582f807c44 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -5,11 +5,19 @@ using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; namespace System.Net.Http.Headers { + /// Key/value pairs of headers. The value is either a raw or a . + internal struct HeaderEntry + { + public HeaderDescriptor Key; + public object Value; + } + public abstract class HttpHeaders : IEnumerable>> { // This type is used to store a collection of headers in 'headerStore': @@ -33,7 +41,7 @@ public abstract class HttpHeaders : IEnumerableKey/value pairs of headers. The value is either a raw or a . - private Dictionary? _headerStore; + private HeaderStore _headerStore; private readonly HttpHeaderType _allowedHeaderTypes; private readonly HttpHeaderType _treatAsCustomHeaderTypes; @@ -52,7 +60,9 @@ internal HttpHeaders(HttpHeaderType allowedHeaderTypes, HttpHeaderType treatAsCu _treatAsCustomHeaderTypes = treatAsCustomHeaderTypes & ~HttpHeaderType.NonTrailing; } - internal Dictionary? HeaderStore => _headerStore; + internal HeaderEntry[]? Entries => _headerStore.Entries; + + internal int Count => _headerStore.Count; /// Gets a view of the contents of this headers collection that does not parse nor validate the data upon access. public HttpHeadersNonValidated NonValidated => new HttpHeadersNonValidated(this); @@ -120,29 +130,25 @@ internal bool TryAddWithoutValidation(HeaderDescriptor descriptor, string? value // values, e.g. adding two null-strings (or empty, or whitespace-only) results in "My-Header: ,". value ??= string.Empty; - // Ensure the header store dictionary has been created. - _headerStore ??= new Dictionary(); + ref object? storeValueRef = ref _headerStore.GetValueRefOrAddDefault(descriptor); + object? currentValue = storeValueRef; - if (_headerStore.TryGetValue(descriptor, out object? currentValue)) + if (currentValue is null) { - if (currentValue is HeaderStoreItemInfo info) - { - // The header store already contained a HeaderStoreItemInfo, so add to it. - AddRawValue(info, value); - } - else - { - // The header store contained a single raw string value, so promote it - // to being a HeaderStoreItemInfo and add to it. - Debug.Assert(currentValue is string); - _headerStore[descriptor] = info = new HeaderStoreItemInfo() { RawValue = currentValue }; - AddRawValue(info, value); - } + storeValueRef = value; + } + else if (currentValue is HeaderStoreItemInfo info) + { + // The header store already contained a HeaderStoreItemInfo, so add to it. + AddRawValue(info, value); } else { - // The header store did not contain the header. Add the raw string. - _headerStore.Add(descriptor, value); + // The header store contained a single raw string value, so promote it + // to being a HeaderStoreItemInfo and add to it. + Debug.Assert(currentValue is string); + storeValueRef = info = new HeaderStoreItemInfo() { RawValue = currentValue }; + AddRawValue(info, value); } return true; @@ -179,7 +185,7 @@ internal bool TryAddWithoutValidation(HeaderDescriptor descriptor, IEnumerable _headerStore?.Clear(); + public void Clear() => _headerStore.Clear(); public IEnumerable GetValues(string name) => GetValues(GetHeaderDescriptor(name)); @@ -206,7 +212,7 @@ public bool TryGetValues(string name, [NotNullWhen(true)] out IEnumerable? values) { - if (_headerStore != null && TryGetAndParseHeaderInfo(descriptor, out HeaderStoreItemInfo? info)) + if (TryGetAndParseHeaderInfo(descriptor, out HeaderStoreItemInfo? info)) { values = GetStoreValuesAsStringArray(descriptor, info); return true; @@ -223,7 +229,7 @@ internal bool Contains(HeaderDescriptor descriptor) // We can't just call headerStore.ContainsKey() since after parsing the value the header may not exist // anymore (if the value contains newline chars, we remove the header). So try to parse the // header value. - return _headerStore != null && TryGetAndParseHeaderInfo(descriptor, out _); + return TryGetAndParseHeaderInfo(descriptor, out _); } public override string ToString() @@ -235,14 +241,19 @@ public override string ToString() var vsb = new ValueStringBuilder(stackalloc char[512]); - if (_headerStore is Dictionary headerStore) + if (Entries is HeaderEntry[] entries) { - foreach (KeyValuePair header in headerStore) + foreach (HeaderEntry entry in entries) { - vsb.Append(header.Key.Name); + if (entry.Value is null) + { + break; + } + + vsb.Append(entry.Key.Name); vsb.Append(": "); - GetStoreValuesAsStringOrStringArray(header.Key, header.Value, out string? singleValue, out string[]? multiValue); + GetStoreValuesAsStringOrStringArray(entry.Key, entry.Value, out string? singleValue, out string[]? multiValue); Debug.Assert(singleValue is not null ^ multiValue is not null); if (singleValue is not null) @@ -253,7 +264,7 @@ public override string ToString() { // Note that if we get multiple values for a header that doesn't support multiple values, we'll // just separate the values using a comma (default separator). - string? separator = header.Key.Parser is HttpHeaderParser parser && parser.SupportsMultipleValues ? parser.Separator : HttpHeaderParser.DefaultSeparator; + string? separator = entry.Key.Parser is HttpHeaderParser parser && parser.SupportsMultipleValues ? parser.Separator : HttpHeaderParser.DefaultSeparator; for (int i = 0; i < multiValue!.Length; i++) { @@ -292,25 +303,30 @@ internal string GetHeaderString(HeaderDescriptor descriptor) #region IEnumerable>> Members - public IEnumerator>> GetEnumerator() => _headerStore != null && _headerStore.Count > 0 ? - GetEnumeratorCore() : - ((IEnumerable>>)Array.Empty>>()).GetEnumerator(); + public IEnumerator>> GetEnumerator() => _headerStore.IsEmpty ? + ((IEnumerable>>)Array.Empty>>()).GetEnumerator() : + GetEnumeratorCore(); private IEnumerator>> GetEnumeratorCore() { - foreach (KeyValuePair header in _headerStore!) + for (int i = 0; i < Entries!.Length; i++) { - HeaderDescriptor descriptor = header.Key; - object value = header.Value; + HeaderEntry entry = Entries[i]; + object value = entry.Value; + if (value is null) + { + break; + } + + HeaderDescriptor descriptor = entry.Key; - HeaderStoreItemInfo? info = value as HeaderStoreItemInfo; - if (info is null) + if (value is not HeaderStoreItemInfo info) { // To retain consistent semantics, we need to upgrade a raw string to a HeaderStoreItemInfo // during enumeration so that we can parse the raw value in order to a) return // the correct set of parsed values, and b) update the instance for subsequent enumerations // to reflect that parsing. - _headerStore[descriptor] = info = new HeaderStoreItemInfo() { RawValue = value }; + Entries[i].Value = info = new HeaderStoreItemInfo() { RawValue = value }; } // Make sure we parse all raw values before returning the result. Note that this has to be @@ -381,13 +397,13 @@ internal void SetOrRemoveParsedValue(HeaderDescriptor descriptor, object? value) public bool Remove(string name) => Remove(GetHeaderDescriptor(name)); - internal bool Remove(HeaderDescriptor descriptor) => _headerStore != null && _headerStore.Remove(descriptor); + internal bool Remove(HeaderDescriptor descriptor) => _headerStore.Remove(descriptor); internal bool RemoveParsedValue(HeaderDescriptor descriptor, object value) { Debug.Assert(value != null); - if (_headerStore == null) + if (Entries is null) { return false; } @@ -462,7 +478,7 @@ internal bool ContainsParsedValue(HeaderDescriptor descriptor, object value) { Debug.Assert(value != null); - if (_headerStore == null) + if (Entries is null) { return false; } @@ -517,41 +533,43 @@ internal virtual void AddHeaders(HttpHeaders sourceHeaders) Debug.Assert(sourceHeaders != null); Debug.Assert(GetType() == sourceHeaders.GetType(), "Can only copy headers from an instance of the same type."); - Dictionary? sourceHeadersStore = sourceHeaders._headerStore; - if (sourceHeadersStore is null || sourceHeadersStore.Count == 0) + if (sourceHeaders.Entries is HeaderEntry[] sourceEntries) { - return; - } - - _headerStore ??= new Dictionary(); - - foreach (KeyValuePair header in sourceHeadersStore) - { - // Only add header values if they're not already set on the message. Note that we don't merge - // collections: If both the default headers and the message have set some values for a certain - // header, then we don't try to merge the values. - if (!_headerStore.ContainsKey(header.Key)) + foreach (HeaderEntry entry in sourceEntries) { - object sourceValue = header.Value; - if (sourceValue is HeaderStoreItemInfo info) + if (entry.Value is null) { - AddHeaderInfo(header.Key, info); + break; } - else + + // Only add header values if they're not already set on the message. Note that we don't merge + // collections: If both the default headers and the message have set some values for a certain + // header, then we don't try to merge the values. + ref object? storeValueRef = ref _headerStore.GetValueRefOrAddDefault(entry.Key); + if (storeValueRef is null) { - Debug.Assert(sourceValue is string); - _headerStore.Add(header.Key, sourceValue); + object sourceValue = entry.Value; + if (sourceValue is HeaderStoreItemInfo info) + { + storeValueRef = CloneHeaderInfo(entry.Key, info); + } + else + { + Debug.Assert(sourceValue is string); + storeValueRef = sourceValue; + } } } } } - private void AddHeaderInfo(HeaderDescriptor descriptor, HeaderStoreItemInfo sourceInfo) + private HeaderStoreItemInfo CloneHeaderInfo(HeaderDescriptor descriptor, HeaderStoreItemInfo sourceInfo) { - HeaderStoreItemInfo destinationInfo = CreateAndAddHeaderToStore(descriptor); - - // Always copy raw values - destinationInfo.RawValue = CloneStringHeaderInfoValues(sourceInfo.RawValue); + var destinationInfo = new HeaderStoreItemInfo + { + // Always copy raw values + RawValue = CloneStringHeaderInfoValues(sourceInfo.RawValue) + }; if (descriptor.Parser == null) { @@ -585,6 +603,8 @@ private void AddHeaderInfo(HeaderDescriptor descriptor, HeaderStoreItemInfo sour } } } + + return destinationInfo; } private static void CloneAndAddValue(HeaderStoreItemInfo destinationInfo, object source) @@ -643,7 +663,8 @@ private HeaderStoreItemInfo GetOrCreateHeaderInfo(HeaderDescriptor descriptor, b else { Debug.Assert(value is string); - _headerStore![descriptor] = result = new HeaderStoreItemInfo { RawValue = value }; + result = new HeaderStoreItemInfo { RawValue = value }; + AddHeaderToStore(descriptor, result); } } } @@ -673,24 +694,32 @@ private HeaderStoreItemInfo CreateAndAddHeaderToStore(HeaderDescriptor descripto private void AddHeaderToStore(HeaderDescriptor descriptor, object value) { Debug.Assert(value is string || value is HeaderStoreItemInfo); - (_headerStore ??= new Dictionary()).Add(descriptor, value); + Debug.Assert(Unsafe.IsNullRef(ref _headerStore.GetValueRefOrNullRef(descriptor))); + _headerStore.GetValueRefOrAddDefault(descriptor) = value; } internal bool TryGetHeaderValue(HeaderDescriptor descriptor, [NotNullWhen(true)] out object? value) { - if (_headerStore == null) + ref object valueRef = ref _headerStore.GetValueRefOrNullRef(descriptor); + + if (Unsafe.IsNullRef(ref valueRef)) { value = null; return false; } - - return _headerStore.TryGetValue(descriptor, out value); + else + { + value = valueRef; + return true; + } } private bool TryGetAndParseHeaderInfo(HeaderDescriptor key, [NotNullWhen(true)] out HeaderStoreItemInfo? info) { - if (TryGetHeaderValue(key, out object? value)) + ref object valueRef = ref _headerStore.GetValueRefOrNullRef(key); + if (!Unsafe.IsNullRef(ref valueRef)) { + object value = valueRef; if (value is HeaderStoreItemInfo hsi) { info = hsi; @@ -698,7 +727,7 @@ private bool TryGetAndParseHeaderInfo(HeaderDescriptor key, [NotNullWhen(true)] else { Debug.Assert(value is string); - _headerStore![key] = info = new HeaderStoreItemInfo() { RawValue = value }; + valueRef = info = new HeaderStoreItemInfo() { RawValue = value }; } return ParseRawHeaderValues(key, info, removeEmptyHeader: true); @@ -737,7 +766,7 @@ private bool ParseRawHeaderValues(HeaderDescriptor descriptor, HeaderStoreItemIn if (removeEmptyHeader) { // After parsing the raw value, no value is left because all values contain newline chars. - Debug.Assert(_headerStore != null); + Debug.Assert(_headerStore.Entries is not null); _headerStore.Remove(descriptor); } return false; @@ -1252,5 +1281,137 @@ internal bool CanAddParsedValue(HttpHeaderParser parser) internal bool IsEmpty => (RawValue == null) && (InvalidValue == null) && (ParsedValue == null); } + + internal struct HeaderStore + { + private const int InitialCapacity = 4; + private const int MaxHeaderCount = 128; + + public HeaderEntry[]? Entries; + + public bool IsEmpty => Entries is not HeaderEntry[] entries || entries[0].Value is null; + + public int Count + { + get + { + if (Entries is HeaderEntry[] store) + { + for (int i = 0; i < store.Length; i++) + { + if (store[i].Value is null) + { + return i; + } + } + return store.Length; + } + + return 0; + } + } + + public ref object GetValueRefOrNullRef(HeaderDescriptor key) + { + if (Entries is HeaderEntry[] entries) + { + for (int i = 0; i < entries.Length; i++) + { + if (entries[i].Key.Equals(key)) + { + return ref entries[i].Value; + } + } + } + + return ref Unsafe.NullRef(); + } + + public ref object? GetValueRefOrAddDefault(HeaderDescriptor key) + { + if (Entries is HeaderEntry[] entries) + { + for (int i = 0; i < entries.Length; i++) + { + ref HeaderEntry entry = ref entries[i]; + if (entry.Value is null) + { + entry.Key = key; + return ref entry.Value!; + } + else if (entry.Key.Equals(key)) + { + return ref entry.Value!; + } + } + + return ref GrowEntriesAndAddDefault(key); + } + else + { + entries = new HeaderEntry[InitialCapacity]; + Entries = entries; + ref HeaderEntry firstElement = ref MemoryMarshal.GetArrayDataReference(entries); + firstElement.Key = key; + return ref firstElement.Value!; + } + } + + private ref object? GrowEntriesAndAddDefault(HeaderDescriptor key) + { + HeaderEntry[] entries = Entries!; + if (entries.Length == MaxHeaderCount) + { + ThrowOnTooManyHeaders(); + } + Array.Resize(ref entries, entries.Length << 1); + Entries = entries; + ref HeaderEntry firstNewEntry = ref entries[entries.Length >> 1]; + firstNewEntry.Key = key; + return ref firstNewEntry.Value!; + } + + public void Clear() + { + if (Entries is HeaderEntry[] entries) + { + Array.Clear(entries); + } + } + + public bool Remove(HeaderDescriptor key) + { + if (Entries is HeaderEntry[] entries) + { + for (int i = 0; i < entries.Length; i++) + { + if (entries[i].Value is null) + { + break; + } + + if (entries[i].Key.Equals(key)) + { + while ((uint)(i + 1) < (uint)entries.Length && entries[i + 1].Value is not null) + { + entries[i] = entries[i + 1]; + i++; + } + entries[i] = default; + + return true; + } + } + } + + return false; + } + + [DoesNotReturn] + private static void ThrowOnTooManyHeaders() + { + throw new HttpRequestException("Too many headers!"); + } + } } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs index 5e67476d116a99..8df78ba716dcb5 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs @@ -25,7 +25,7 @@ namespace System.Net.Http.Headers /// Gets the number of headers stored in the collection. /// Multiple header values associated with the same header name are considered to be one header as far as this count is concerned. - public int Count => _headers?.HeaderStore?.Count ?? 0; + public int Count => _headers?.Count ?? 0; /// Gets whether the collection contains the specified header. /// The name of the header. @@ -83,8 +83,8 @@ public bool TryGetValues(string headerName, out HeaderStringValues values) /// Gets an enumerator that iterates through the . /// An enumerator that iterates through the . public Enumerator GetEnumerator() => - _headers is HttpHeaders headers && headers.HeaderStore is Dictionary store ? - new Enumerator(store.GetEnumerator()) : + _headers is HttpHeaders headers && headers.Entries is HeaderEntry[] entries ? + new Enumerator(entries) : default; /// @@ -120,36 +120,38 @@ IEnumerable IReadOnlyDictionary. /// Enumerates the elements of a . public struct Enumerator : IEnumerator> { - /// The wrapped enumerator for the underlying headers dictionary. - private Dictionary.Enumerator _headerStoreEnumerator; - /// The current value. + private readonly HeaderEntry[] _entries; + private int _index; private KeyValuePair _current; - /// true if the enumerator was constructed via the ctor; otherwise, false./ - private bool _valid; /// Initializes the enumerator. - /// The underlying dictionary enumerator. - internal Enumerator(Dictionary.Enumerator headerStoreEnumerator) + /// The underlying header entries. + internal Enumerator(HeaderEntry[] entries) { - _headerStoreEnumerator = headerStoreEnumerator; + _entries = entries; + _index = 0; _current = default; - _valid = true; } /// public bool MoveNext() { - if (_valid && _headerStoreEnumerator.MoveNext()) + int index = _index; + if (_entries is HeaderEntry[] entries && (uint)index < (uint)entries.Length) { - KeyValuePair current = _headerStoreEnumerator.Current; - - HttpHeaders.GetStoreValuesAsStringOrStringArray(current.Key, current.Value, out string? singleValue, out string[]? multiValue); - Debug.Assert(singleValue is not null ^ multiValue is not null); - - _current = new KeyValuePair( - current.Key.Name, - singleValue is not null ? new HeaderStringValues(current.Key, singleValue) : new HeaderStringValues(current.Key, multiValue!)); - return true; + HeaderEntry entry = entries[index]; + _index++; + + if (entry.Value is not null) + { + HttpHeaders.GetStoreValuesAsStringOrStringArray(entry.Key, entry.Value, out string? singleValue, out string[]? multiValue); + Debug.Assert(singleValue is not null ^ multiValue is not null); + + _current = new KeyValuePair( + entry.Key.Name, + singleValue is not null ? new HeaderStringValues(entry.Key, singleValue) : new HeaderStringValues(entry.Key, multiValue!)); + return true; + } } _current = default; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs index 811d33dbfb4ef5..1a8157c2a3073f 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs @@ -1337,7 +1337,7 @@ private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders heade { if (NetEventSource.Log.IsEnabled()) Trace(""); - if (headers.HeaderStore is null) + if (headers.Entries is null) { return; } @@ -1345,8 +1345,13 @@ private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders heade HeaderEncodingSelector? encodingSelector = _pool.Settings._requestHeaderEncodingSelector; ref string[]? tmpHeaderValuesArray = ref t_headerValues; - foreach (KeyValuePair header in headers.HeaderStore) + foreach (HeaderEntry header in headers.Entries) { + if (header.Value is null) + { + break; + } + int headerValuesCount = HttpHeaders.GetStoreValuesIntoStringArray(header.Key, header.Value, ref tmpHeaderValuesArray); Debug.Assert(headerValuesCount > 0, "No values for header??"); ReadOnlySpan headerValues = tmpHeaderValuesArray.AsSpan(0, headerValuesCount); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index 3a236ccc65ae38..8af7c82c4e7bff 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -622,15 +622,20 @@ private void BufferHeaders(HttpRequestMessage request) // TODO: special-case Content-Type for static table values values? private void BufferHeaderCollection(HttpHeaders headers) { - if (headers.HeaderStore == null) + if (headers.Entries is null) { return; } HeaderEncodingSelector? encodingSelector = _connection.Pool.Settings._requestHeaderEncodingSelector; - foreach (KeyValuePair header in headers.HeaderStore) + foreach (HeaderEntry header in headers.Entries) { + if (header.Value is null) + { + break; + } + int headerValuesCount = HttpHeaders.GetStoreValuesIntoStringArray(header.Key, header.Value, ref _headerValues); Debug.Assert(headerValuesCount > 0, "No values for header??"); ReadOnlySpan headerValues = _headerValues.AsSpan(0, headerValuesCount); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs index 12b6b4bf3c55a0..b0042b50dfdbda 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs @@ -258,10 +258,15 @@ private async ValueTask WriteHeadersAsync(HttpHeaders headers, string? cookiesFr { Debug.Assert(_currentRequest != null); - if (headers.HeaderStore != null) + if (headers.Entries != null) { - foreach (KeyValuePair header in headers.HeaderStore) + foreach (HeaderEntry header in headers.Entries) { + if (header.Value is null) + { + break; + } + if (header.Key.KnownHeader != null) { await WriteBytesAsync(header.Key.KnownHeader.AsciiBytesWithColonSpace, async).ConfigureAwait(false); From d23ad1835d8e67ba55f4623c4d46b8774750646a Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Sat, 18 Dec 2021 09:13:06 +0100 Subject: [PATCH 02/22] Reduce the size of HeaderDescriptor to 1 object --- .../Net/Http/Headers/HeaderDescriptor.cs | 69 +++++++++++------ .../Net/Http/Headers/HeaderStringValues.cs | 2 +- .../System/Net/Http/Headers/HttpHeaders.cs | 75 +++++++++++++------ .../Http/Headers/HttpHeadersNonValidated.cs | 2 +- .../System/Net/Http/Headers/KnownHeader.cs | 2 + .../SocketsHttpHandler/Http2Connection.cs | 24 ++---- .../SocketsHttpHandler/Http3RequestStream.cs | 26 ++----- .../Http/SocketsHttpHandler/HttpConnection.cs | 19 ++--- .../SocketsHttpHandler/HttpConnectionBase.cs | 4 +- 9 files changed, 125 insertions(+), 98 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs index 229490b5432dd9..44a68cd28b2585 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs @@ -4,6 +4,7 @@ using System.Buffers; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Unicode; @@ -15,32 +16,55 @@ namespace System.Net.Http.Headers // Use HeaderDescriptor.TryGet to resolve an arbitrary header name to a HeaderDescriptor. internal readonly struct HeaderDescriptor : IEquatable { - private readonly string _headerName; - private readonly KnownHeader? _knownHeader; + /// + /// Either a or + /// + private readonly object _descriptor; public HeaderDescriptor(KnownHeader knownHeader) { - _knownHeader = knownHeader; - _headerName = knownHeader.Name; + _descriptor = knownHeader; } // This should not be used directly; use static TryGet below internal HeaderDescriptor(string headerName) { - _headerName = headerName; - _knownHeader = null; + _descriptor = headerName; } - public string Name => _headerName; - public HttpHeaderParser? Parser => _knownHeader?.Parser; - public HttpHeaderType HeaderType => _knownHeader == null ? HttpHeaderType.Custom : _knownHeader.HeaderType; - public KnownHeader? KnownHeader => _knownHeader; + public string Name + { + get + { + object? descriptor = _descriptor; + return descriptor is KnownHeader knownHeader ? + knownHeader.Name : + Unsafe.As(descriptor); + } + } + + public object Descriptor => _descriptor; + + public HttpHeaderParser? Parser => _descriptor is KnownHeader knownHeader ? knownHeader.Parser : null; + public HttpHeaderType HeaderType => _descriptor is KnownHeader knownHeader ? knownHeader.HeaderType : HttpHeaderType.Custom; + public string Separator => _descriptor is KnownHeader knownHeader ? knownHeader.Separator : HttpHeaderParser.DefaultSeparator; + + public bool Equals(HeaderDescriptor other) + { + object? descriptor = _descriptor; + object? otherDescriptor = other._descriptor; + + if (descriptor is string headerName) + { + return string.Equals(headerName, otherDescriptor as string, StringComparison.OrdinalIgnoreCase); + } + else + { + return ReferenceEquals(descriptor, otherDescriptor); + } + } - public bool Equals(HeaderDescriptor other) => - _knownHeader == null ? - string.Equals(_headerName, other._headerName, StringComparison.OrdinalIgnoreCase) : - _knownHeader == other._knownHeader; - public override int GetHashCode() => _knownHeader?.GetHashCode() ?? StringComparer.OrdinalIgnoreCase.GetHashCode(_headerName); + public override int GetHashCode() => throw new InvalidOperationException(); // We don't expect this to be called public override bool Equals(object? obj) => throw new InvalidOperationException(); // Ensure this is never called, to avoid boxing // Returns false for invalid header name. @@ -112,9 +136,9 @@ internal static bool TryGetStaticQPackHeader(int index, out HeaderDescriptor des public HeaderDescriptor AsCustomHeader() { - Debug.Assert(_knownHeader != null); - Debug.Assert(_knownHeader.HeaderType != HttpHeaderType.Custom); - return new HeaderDescriptor(_knownHeader.Name); + Debug.Assert(_descriptor is KnownHeader); + Debug.Assert(HeaderType != HttpHeaderType.Custom); + return new HeaderDescriptor(Name); } public string GetHeaderValue(ReadOnlySpan headerValue, Encoding? valueEncoding) @@ -125,10 +149,9 @@ public string GetHeaderValue(ReadOnlySpan headerValue, Encoding? valueEnco } // If it's a known header value, use the known value instead of allocating a new string. - if (_knownHeader != null) + if (_descriptor is KnownHeader knownHeader) { - string[]? knownValues = _knownHeader.KnownValues; - if (knownValues != null) + if (knownHeader.KnownValues is string[] knownValues) { for (int i = 0; i < knownValues.Length; i++) { @@ -139,7 +162,7 @@ public string GetHeaderValue(ReadOnlySpan headerValue, Encoding? valueEnco } } - if (_knownHeader == KnownHeaders.ContentType) + if (knownHeader == KnownHeaders.ContentType) { string? contentType = GetKnownContentType(headerValue); if (contentType != null) @@ -147,7 +170,7 @@ public string GetHeaderValue(ReadOnlySpan headerValue, Encoding? valueEnco return contentType; } } - else if (_knownHeader == KnownHeaders.Location) + else if (knownHeader == KnownHeaders.Location) { // Normally Location should be in ISO-8859-1 but occasionally some servers respond with UTF-8. if (TryDecodeUtf8(headerValue, out string? decoded)) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderStringValues.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderStringValues.cs index a313a2306e78f0..6b5f4d2a666a2a 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderStringValues.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderStringValues.cs @@ -45,7 +45,7 @@ internal HeaderStringValues(HeaderDescriptor descriptor, string[] values) public override string ToString() => _value switch { string value => value, - string[] values => string.Join(_header.Parser is HttpHeaderParser parser && parser.SupportsMultipleValues ? parser.Separator : HttpHeaderParser.DefaultSeparator, values), + string[] values => string.Join(_header.Separator, values), _ => string.Empty, }; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index 53ee582f807c44..467c89d893328a 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -264,7 +264,7 @@ public override string ToString() { // Note that if we get multiple values for a header that doesn't support multiple values, we'll // just separate the values using a comma (default separator). - string? separator = entry.Key.Parser is HttpHeaderParser parser && parser.SupportsMultipleValues ? parser.Separator : HttpHeaderParser.DefaultSeparator; + string separator = entry.Key.Separator; for (int i = 0; i < multiValue!.Length; i++) { @@ -294,8 +294,7 @@ internal string GetHeaderString(HeaderDescriptor descriptor) // Note that if we get multiple values for a header that doesn't support multiple values, we'll // just separate the values using a comma (default separator). - string? separator = descriptor.Parser != null && descriptor.Parser.SupportsMultipleValues ? descriptor.Parser.Separator : HttpHeaderParser.DefaultSeparator; - return string.Join(separator, multiValue!); + return string.Join(descriptor.Separator, multiValue!); } return string.Empty; @@ -701,7 +700,6 @@ private void AddHeaderToStore(HeaderDescriptor descriptor, object value) internal bool TryGetHeaderValue(HeaderDescriptor descriptor, [NotNullWhen(true)] out object? value) { ref object valueRef = ref _headerStore.GetValueRefOrNullRef(descriptor); - if (Unsafe.IsNullRef(ref valueRef)) { value = null; @@ -1289,22 +1287,22 @@ internal struct HeaderStore public HeaderEntry[]? Entries; - public bool IsEmpty => Entries is not HeaderEntry[] entries || entries[0].Value is null; + public bool IsEmpty => Entries is not HeaderEntry[] entries || entries[0].Key.Descriptor is null; public int Count { get { - if (Entries is HeaderEntry[] store) + if (Entries is HeaderEntry[] entries) { - for (int i = 0; i < store.Length; i++) + for (int i = 0; i < entries.Length; i++) { - if (store[i].Value is null) + if (entries[i].Key.Descriptor is null) { return i; } } - return store.Length; + return entries.Length; } return 0; @@ -1315,11 +1313,24 @@ public ref object GetValueRefOrNullRef(HeaderDescriptor key) { if (Entries is HeaderEntry[] entries) { - for (int i = 0; i < entries.Length; i++) + if (key.Descriptor is string) + { + for (int i = 0; i < entries.Length; i++) + { + if (string.Equals(entries[i].Key.Descriptor as string, Unsafe.As(key.Descriptor), StringComparison.OrdinalIgnoreCase)) + { + return ref entries[i].Value; + } + } + } + else { - if (entries[i].Key.Equals(key)) + for (int i = 0; i < entries.Length; i++) { - return ref entries[i].Value; + if (ReferenceEquals(entries[i].Key.Descriptor, key.Descriptor)) + { + return ref entries[i].Value; + } } } } @@ -1331,17 +1342,36 @@ public ref object GetValueRefOrNullRef(HeaderDescriptor key) { if (Entries is HeaderEntry[] entries) { - for (int i = 0; i < entries.Length; i++) + if (key.Descriptor is string) { - ref HeaderEntry entry = ref entries[i]; - if (entry.Value is null) + for (int i = 0; i < entries.Length; i++) { - entry.Key = key; - return ref entry.Value!; + ref HeaderEntry entry = ref entries[i]; + if (entry.Key.Descriptor is null) + { + entry.Key = key; + return ref entry.Value!; + } + else if (string.Equals(entry.Key.Descriptor as string, Unsafe.As(key.Descriptor), StringComparison.OrdinalIgnoreCase)) + { + return ref entry.Value!; + } } - else if (entry.Key.Equals(key)) + } + else + { + for (int i = 0; i < entries.Length; i++) { - return ref entry.Value!; + ref HeaderEntry entry = ref entries[i]; + if (entry.Key.Descriptor is null) + { + entry.Key = key; + return ref entry.Value!; + } + else if (ReferenceEquals(entry.Key.Descriptor, key.Descriptor)) + { + return ref entry.Value!; + } } } @@ -1385,14 +1415,15 @@ public bool Remove(HeaderDescriptor key) { for (int i = 0; i < entries.Length; i++) { - if (entries[i].Value is null) + HeaderDescriptor entryKey = entries[i].Key; + if (entryKey.Descriptor is null) { break; } - if (entries[i].Key.Equals(key)) + if (entryKey.Equals(key)) { - while ((uint)(i + 1) < (uint)entries.Length && entries[i + 1].Value is not null) + while ((uint)(i + 1) < (uint)entries.Length && entries[i + 1].Key.Descriptor is not null) { entries[i] = entries[i + 1]; i++; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs index 8df78ba716dcb5..d3bde0099b1571 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs @@ -142,7 +142,7 @@ public bool MoveNext() HeaderEntry entry = entries[index]; _index++; - if (entry.Value is not null) + if (entry.Key.Descriptor is not null) { HttpHeaders.GetStoreValuesAsStringOrStringArray(entry.Key, entry.Value, out string? singleValue, out string[]? multiValue); Debug.Assert(singleValue is not null ^ multiValue is not null); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/KnownHeader.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/KnownHeader.cs index 0163db073a852e..72705dd19ea8df 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/KnownHeader.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/KnownHeader.cs @@ -47,5 +47,7 @@ public KnownHeader(string name, HttpHeaderType headerType, HttpHeaderParser? par public string[]? KnownValues { get; } public byte[] AsciiBytesWithColonSpace { get; } public HeaderDescriptor Descriptor => new HeaderDescriptor(this); + + public string Separator => Parser is HttpHeaderParser parser && parser.SupportsMultipleValues ? parser.Separator! : HttpHeaderParser.DefaultSeparator; } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs index 1a8157c2a3073f..88ad58d7083fc9 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs @@ -1347,7 +1347,7 @@ private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders heade ref string[]? tmpHeaderValuesArray = ref t_headerValues; foreach (HeaderEntry header in headers.Entries) { - if (header.Value is null) + if (header.Key.Descriptor is null) { break; } @@ -1358,15 +1358,14 @@ private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders heade Encoding? valueEncoding = encodingSelector?.Invoke(header.Key.Name, request); - KnownHeader? knownHeader = header.Key.KnownHeader; - if (knownHeader != null) + if (header.Key.Descriptor is KnownHeader knownHeader) { // The Host header is not sent for HTTP2 because we send the ":authority" pseudo-header instead // (see pseudo-header handling below in WriteHeaders). // The Connection, Upgrade and ProxyConnection headers are also not supported in HTTP2. if (knownHeader != KnownHeaders.Host && knownHeader != KnownHeaders.Connection && knownHeader != KnownHeaders.Upgrade && knownHeader != KnownHeaders.ProxyConnection) { - if (header.Key.KnownHeader == KnownHeaders.TE) + if (knownHeader == KnownHeaders.TE) { // HTTP/2 allows only 'trailers' TE header. rfc7540 8.1.2.2 foreach (string value in headerValues) @@ -1383,19 +1382,8 @@ private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders heade // For all other known headers, send them via their pre-encoded name and the associated value. WriteBytes(knownHeader.Http2EncodedName, ref headerBuffer); - string? separator = null; - if (headerValues.Length > 1) - { - HttpHeaderParser? parser = header.Key.Parser; - if (parser != null && parser.SupportsMultipleValues) - { - separator = parser.Separator; - } - else - { - separator = HttpHeaderParser.DefaultSeparator; - } - } + + string? separator = headerValues.Length > 1 ? knownHeader.Separator : null; WriteLiteralHeaderValues(headerValues, separator, valueEncoding, ref headerBuffer); } @@ -1403,7 +1391,7 @@ private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders heade else { // The header is not known: fall back to just encoding the header name and value(s). - WriteLiteralHeader(header.Key.Name, headerValues, valueEncoding, ref headerBuffer); + WriteLiteralHeader(Unsafe.As(header.Key.Descriptor), headerValues, valueEncoding, ref headerBuffer); } } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index 8af7c82c4e7bff..3d740cbe90abbc 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -631,7 +631,7 @@ private void BufferHeaderCollection(HttpHeaders headers) foreach (HeaderEntry header in headers.Entries) { - if (header.Value is null) + if (header.Key.Descriptor is null) { break; } @@ -642,15 +642,14 @@ private void BufferHeaderCollection(HttpHeaders headers) Encoding? valueEncoding = encodingSelector?.Invoke(header.Key.Name, _request); - KnownHeader? knownHeader = header.Key.KnownHeader; - if (knownHeader != null) + if (header.Key.Descriptor is KnownHeader knownHeader) { // The Host header is not sent for HTTP/3 because we send the ":authority" pseudo-header instead // (see pseudo-header handling below in WriteHeaders). // The Connection, Upgrade and ProxyConnection headers are also not supported in HTTP/3. if (knownHeader != KnownHeaders.Host && knownHeader != KnownHeaders.Connection && knownHeader != KnownHeaders.Upgrade && knownHeader != KnownHeaders.ProxyConnection) { - if (header.Key.KnownHeader == KnownHeaders.TE) + if (knownHeader == KnownHeaders.TE) { // HTTP/2 allows only 'trailers' TE header. rfc7540 8.1.2.2 // HTTP/3 does not mention this one way or another; assume it has the same rule. @@ -667,19 +666,8 @@ private void BufferHeaderCollection(HttpHeaders headers) // For all other known headers, send them via their pre-encoded name and the associated value. BufferBytes(knownHeader.Http3EncodedName); - string? separator = null; - if (headerValues.Length > 1) - { - HttpHeaderParser? parser = header.Key.Parser; - if (parser != null && parser.SupportsMultipleValues) - { - separator = parser.Separator; - } - else - { - separator = HttpHeaderParser.DefaultSeparator; - } - } + + string? separator = headerValues.Length > 1 ? knownHeader.Separator : null; BufferLiteralHeaderValues(headerValues, separator, valueEncoding); } @@ -687,7 +675,7 @@ private void BufferHeaderCollection(HttpHeaders headers) else { // The header is not known: fall back to just encoding the header name and value(s). - BufferLiteralHeaderWithoutNameReference(header.Key.Name, headerValues, HttpHeaderParser.DefaultSeparator, valueEncoding); + BufferLiteralHeaderWithoutNameReference(Unsafe.As(header.Key.Descriptor), headerValues, HttpHeaderParser.DefaultSeparator, valueEncoding); } } } @@ -910,7 +898,7 @@ private void OnHeader(int? staticIndex, HeaderDescriptor descriptor, string? sta { if (descriptor.Name[0] == ':') { - if (descriptor.KnownHeader != KnownHeaders.PseudoStatus) + if (descriptor.Descriptor != KnownHeaders.PseudoStatus) { if (NetEventSource.Log.IsEnabled()) Trace($"Received unknown pseudo-header '{descriptor.Name}'."); throw new Http3ConnectionException(Http3ErrorCode.ProtocolError); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs index b0042b50dfdbda..9c386937c34e72 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs @@ -262,18 +262,18 @@ private async ValueTask WriteHeadersAsync(HttpHeaders headers, string? cookiesFr { foreach (HeaderEntry header in headers.Entries) { - if (header.Value is null) + if (header.Key.Descriptor is null) { break; } - if (header.Key.KnownHeader != null) + if (header.Key.Descriptor is KnownHeader) { - await WriteBytesAsync(header.Key.KnownHeader.AsciiBytesWithColonSpace, async).ConfigureAwait(false); + await WriteBytesAsync(Unsafe.As(header.Key.Descriptor).AsciiBytesWithColonSpace, async).ConfigureAwait(false); } else { - await WriteAsciiStringAsync(header.Key.Name, async).ConfigureAwait(false); + await WriteAsciiStringAsync(Unsafe.As(header.Key.Descriptor), async).ConfigureAwait(false); await WriteTwoBytesAsync((byte)':', (byte)' ', async).ConfigureAwait(false); } @@ -285,7 +285,7 @@ private async ValueTask WriteHeadersAsync(HttpHeaders headers, string? cookiesFr await WriteStringAsync(_headerValues[0], async, valueEncoding).ConfigureAwait(false); - if (cookiesFromContainer != null && header.Key.KnownHeader == KnownHeaders.Cookie) + if (cookiesFromContainer != null && header.Key.Descriptor == KnownHeaders.Cookie) { await WriteTwoBytesAsync((byte)';', (byte)' ', async).ConfigureAwait(false); await WriteStringAsync(cookiesFromContainer, async, valueEncoding).ConfigureAwait(false); @@ -296,12 +296,7 @@ private async ValueTask WriteHeadersAsync(HttpHeaders headers, string? cookiesFr // Some headers such as User-Agent and Server use space as a separator (see: ProductInfoHeaderParser) if (headerValuesCount > 1) { - HttpHeaderParser? parser = header.Key.Parser; - string separator = HttpHeaderParser.DefaultSeparator; - if (parser != null && parser.SupportsMultipleValues) - { - separator = parser.Separator!; - } + string separator = header.Key.Separator; for (int i = 1; i < headerValuesCount; i++) { @@ -1059,7 +1054,7 @@ private static void ParseHeaderNameValue(HttpConnection connection, ReadOnlySpan throw new HttpRequestException(SR.Format(SR.net_http_invalid_response_header_name, Encoding.ASCII.GetString(line.Slice(0, pos)))); } - if (isFromTrailer && descriptor.KnownHeader != null && (descriptor.KnownHeader.HeaderType & HttpHeaderType.NonTrailing) == HttpHeaderType.NonTrailing) + if (isFromTrailer && (descriptor.HeaderType & HttpHeaderType.NonTrailing) == HttpHeaderType.NonTrailing) { // Disallowed trailer fields. // A recipient MUST ignore fields that are forbidden to be sent in a trailer. diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs index 3c266f20da59f5..89a9b61903c2f7 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs @@ -24,8 +24,8 @@ internal abstract class HttpConnectionBase : IDisposable, IHttpTrace public string GetResponseHeaderValueWithCaching(HeaderDescriptor descriptor, ReadOnlySpan value, Encoding? valueEncoding) { return - ReferenceEquals(descriptor.KnownHeader, KnownHeaders.Date) ? GetOrAddCachedValue(ref _lastDateHeaderValue, descriptor, value, valueEncoding) : - ReferenceEquals(descriptor.KnownHeader, KnownHeaders.Server) ? GetOrAddCachedValue(ref _lastServerHeaderValue, descriptor, value, valueEncoding) : + ReferenceEquals(descriptor.Descriptor, KnownHeaders.Date) ? GetOrAddCachedValue(ref _lastDateHeaderValue, descriptor, value, valueEncoding) : + ReferenceEquals(descriptor.Descriptor, KnownHeaders.Server) ? GetOrAddCachedValue(ref _lastServerHeaderValue, descriptor, value, valueEncoding) : descriptor.GetHeaderValue(value, valueEncoding); static string GetOrAddCachedValue([NotNull] ref string? cache, HeaderDescriptor descriptor, ReadOnlySpan value, Encoding? encoding) From 18c2eaf5c8d103bd2d7496aaa403b8f4f7a41b2b Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Sat, 18 Dec 2021 10:42:17 +0100 Subject: [PATCH 03/22] Update UnitTests, fix GetOrCreateHeaderInfo --- .../System/Net/Http/Headers/HttpHeaders.cs | 32 +++++++------------ .../UnitTests/HPack/HPackRoundtripTests.cs | 25 ++++++--------- .../UnitTests/Headers/HeaderEncodingTest.cs | 4 +-- .../UnitTests/Headers/KnownHeadersTest.cs | 2 +- 4 files changed, 24 insertions(+), 39 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index 467c89d893328a..95e7857c91cf83 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -644,37 +644,29 @@ private static void CloneAndAddValue(HeaderStoreItemInfo destinationInfo, object private HeaderStoreItemInfo GetOrCreateHeaderInfo(HeaderDescriptor descriptor, bool parseRawValues) { - HeaderStoreItemInfo? result = null; - bool found; if (parseRawValues) { - found = TryGetAndParseHeaderInfo(descriptor, out result); + if (TryGetAndParseHeaderInfo(descriptor, out HeaderStoreItemInfo? info)) + { + return info; + } } else { - found = TryGetHeaderValue(descriptor, out object? value); - if (found) + ref object valueRef = ref _headerStore.GetValueRefOrNullRef(descriptor); + if (!Unsafe.IsNullRef(ref valueRef)) { - if (value is HeaderStoreItemInfo hsti) - { - result = hsti; - } - else + object value = valueRef; + if (value is not HeaderStoreItemInfo info) { Debug.Assert(value is string); - result = new HeaderStoreItemInfo { RawValue = value }; - AddHeaderToStore(descriptor, result); + valueRef = info = new HeaderStoreItemInfo { RawValue = value }; } + return info; } } - if (!found) - { - result = CreateAndAddHeaderToStore(descriptor); - } - - Debug.Assert(result != null); - return result; + return CreateAndAddHeaderToStore(descriptor); } private HeaderStoreItemInfo CreateAndAddHeaderToStore(HeaderDescriptor descriptor) @@ -692,7 +684,7 @@ private HeaderStoreItemInfo CreateAndAddHeaderToStore(HeaderDescriptor descripto private void AddHeaderToStore(HeaderDescriptor descriptor, object value) { - Debug.Assert(value is string || value is HeaderStoreItemInfo); + Debug.Assert(value is string or HeaderStoreItemInfo); Debug.Assert(Unsafe.IsNullRef(ref _headerStore.GetValueRefOrNullRef(descriptor))); _headerStore.GetValueRefOrAddDefault(descriptor) = value; } diff --git a/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs b/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs index c3b16afb47360d..2fb1827fa2f12a 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs @@ -60,30 +60,23 @@ private static Memory HPackEncode(HttpHeaders headers, Encoding? valueEnco FillAvailableSpaceWithOnes(buffer); string[] headerValues = Array.Empty(); - foreach (KeyValuePair header in headers.HeaderStore) + foreach (HeaderEntry header in headers.Entries) { + if (header.Key.Descriptor is null) + { + break; + } + int headerValuesCount = HttpHeaders.GetStoreValuesIntoStringArray(header.Key, header.Value, ref headerValues); Assert.InRange(headerValuesCount, 0, int.MaxValue); ReadOnlySpan headerValuesSpan = headerValues.AsSpan(0, headerValuesCount); - KnownHeader knownHeader = header.Key.KnownHeader; - if (knownHeader != null) + if (header.Key.Descriptor is KnownHeader knownHeader) { // For all other known headers, send them via their pre-encoded name and the associated value. WriteBytes(knownHeader.Http2EncodedName); - string separator = null; - if (headerValuesSpan.Length > 1) - { - HttpHeaderParser parser = header.Key.Parser; - if (parser != null && parser.SupportsMultipleValues) - { - separator = parser.Separator; - } - else - { - separator = HttpHeaderParser.DefaultSeparator; - } - } + + string? separator = headerValuesSpan.Length > 1 ? knownHeader.Separator : null; WriteLiteralHeaderValues(headerValuesSpan, separator); } diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs index d7f41e0deeee61..9b3ecd7aa9a5f8 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs @@ -24,12 +24,12 @@ public void RoundTripsUtf8(string input) byte[] encoded = Encoding.UTF8.GetBytes(input); Assert.True(HeaderDescriptor.TryGet("custom-header", out HeaderDescriptor descriptor)); - Assert.Null(descriptor.KnownHeader); + Assert.IsType(descriptor.Descriptor); string roundtrip = descriptor.GetHeaderValue(encoded, Encoding.UTF8); Assert.Equal(input, roundtrip); Assert.True(HeaderDescriptor.TryGet("Cache-Control", out descriptor)); - Assert.NotNull(descriptor.KnownHeader); + Assert.IsType(descriptor.Descriptor); roundtrip = descriptor.GetHeaderValue(encoded, Encoding.UTF8); Assert.Equal(input, roundtrip); } diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs index 78fada7218b4ac..b94563f4dce3de 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs @@ -118,7 +118,7 @@ static void Validate(string name, KnownHeader h) Assert.NotNull(h); Assert.Same(h, KnownHeaders.TryGetKnownHeader(name)); - Assert.Same(h, h.Descriptor.KnownHeader); + Assert.Same(h, h.Descriptor.Descriptor); Assert.Equal(name, h.Name, StringComparer.OrdinalIgnoreCase); Assert.Equal(name, h.Descriptor.Name, StringComparer.OrdinalIgnoreCase); } From 59be8e5881c8f297eb0598381a213b39f05fd89c Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Mon, 20 Dec 2021 02:11:38 +0100 Subject: [PATCH 04/22] Switch to a dictionary after ArrayThreshold headers --- .../Net/Http/Headers/HeaderDescriptor.cs | 14 +- .../System/Net/Http/Headers/HttpHeaders.cs | 258 ++++++++++++------ .../Http/Headers/HttpHeadersNonValidated.cs | 2 +- .../SocketsHttpHandler/Http2Connection.cs | 4 +- .../SocketsHttpHandler/Http3RequestStream.cs | 4 +- .../Http/SocketsHttpHandler/HttpConnection.cs | 4 +- .../UnitTests/HPack/HPackRoundtripTests.cs | 2 +- 7 files changed, 195 insertions(+), 93 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs index 44a68cd28b2585..d3fc840edf91f0 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs @@ -64,7 +64,19 @@ public bool Equals(HeaderDescriptor other) } } - public override int GetHashCode() => throw new InvalidOperationException(); // We don't expect this to be called + public override int GetHashCode() + { + object? descriptor = _descriptor; + if (descriptor is string headerName) + { + return StringComparer.OrdinalIgnoreCase.GetHashCode(headerName); + } + else + { + return descriptor.GetHashCode(); + } + } + public override bool Equals(object? obj) => throw new InvalidOperationException(); // Ensure this is never called, to avoid boxing // Returns false for invalid header name. diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index 95e7857c91cf83..dcc2dceca48645 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -60,7 +60,7 @@ internal HttpHeaders(HttpHeaderType allowedHeaderTypes, HttpHeaderType treatAsCu _treatAsCustomHeaderTypes = treatAsCustomHeaderTypes & ~HttpHeaderType.NonTrailing; } - internal HeaderEntry[]? Entries => _headerStore.Entries; + internal HeaderEntry[]? GetEntries() => _headerStore.GetEntries(); internal int Count => _headerStore.Count; @@ -241,11 +241,11 @@ public override string ToString() var vsb = new ValueStringBuilder(stackalloc char[512]); - if (Entries is HeaderEntry[] entries) + if (GetEntries() is HeaderEntry[] entries) { foreach (HeaderEntry entry in entries) { - if (entry.Value is null) + if (entry.Key.Descriptor is null) { break; } @@ -308,16 +308,18 @@ public IEnumerator>> GetEnumerator() => private IEnumerator>> GetEnumeratorCore() { - for (int i = 0; i < Entries!.Length; i++) + HeaderEntry[]? entries = GetEntries()!; + + for (int i = 0; i < entries.Length; i++) { - HeaderEntry entry = Entries[i]; - object value = entry.Value; - if (value is null) + HeaderEntry entry = entries[i]; + HeaderDescriptor descriptor = entry.Key; + if (descriptor.Descriptor is null) { break; } - HeaderDescriptor descriptor = entry.Key; + object value = entry.Value; if (value is not HeaderStoreItemInfo info) { @@ -325,16 +327,21 @@ private IEnumerator>> GetEnumeratorCore // during enumeration so that we can parse the raw value in order to a) return // the correct set of parsed values, and b) update the instance for subsequent enumerations // to reflect that parsing. - Entries[i].Value = info = new HeaderStoreItemInfo() { RawValue = value }; + entries[i].Value = info = new HeaderStoreItemInfo() { RawValue = value }; } // Make sure we parse all raw values before returning the result. Note that this has to be // done before we calculate the array length (next line): A raw value may contain a list of // values. - if (!ParseRawHeaderValues(descriptor, info, removeEmptyHeader: false)) + if (!ParseRawHeaderValues(descriptor, info)) { - // We have an invalid header value (contains newline chars). Delete it. - _headerStore.Remove(descriptor); + // We saw an invalid header value (contains newline chars) and deleted it. + + // If the HeaderEntry[] we are enumerating is the live header store, the entries have shifted. + if (_headerStore.RemoveShiftsEntries) + { + i--; + } } else { @@ -402,11 +409,6 @@ internal bool RemoveParsedValue(HeaderDescriptor descriptor, object value) { Debug.Assert(value != null); - if (Entries is null) - { - return false; - } - // If we have a value for this header, then verify if we have a single value. If so, compare that // value with 'item'. If we have a list of values, then remove 'item' from the list. if (TryGetAndParseHeaderInfo(descriptor, out HeaderStoreItemInfo? info)) @@ -477,11 +479,6 @@ internal bool ContainsParsedValue(HeaderDescriptor descriptor, object value) { Debug.Assert(value != null); - if (Entries is null) - { - return false; - } - // If we have a value for this header, then verify if we have a single value. If so, compare that // value with 'item'. If we have a list of values, then compare each item in the list with 'item'. if (TryGetAndParseHeaderInfo(descriptor, out HeaderStoreItemInfo? info)) @@ -532,7 +529,7 @@ internal virtual void AddHeaders(HttpHeaders sourceHeaders) Debug.Assert(sourceHeaders != null); Debug.Assert(GetType() == sourceHeaders.GetType(), "Can only copy headers from an instance of the same type."); - if (sourceHeaders.Entries is HeaderEntry[] sourceEntries) + if (sourceHeaders.GetEntries() is HeaderEntry[] sourceEntries) { foreach (HeaderEntry entry in sourceEntries) { @@ -720,14 +717,14 @@ private bool TryGetAndParseHeaderInfo(HeaderDescriptor key, [NotNullWhen(true)] valueRef = info = new HeaderStoreItemInfo() { RawValue = value }; } - return ParseRawHeaderValues(key, info, removeEmptyHeader: true); + return ParseRawHeaderValues(key, info); } info = null; return false; } - private bool ParseRawHeaderValues(HeaderDescriptor descriptor, HeaderStoreItemInfo info, bool removeEmptyHeader) + private bool ParseRawHeaderValues(HeaderDescriptor descriptor, HeaderStoreItemInfo info) { // Unlike TryGetHeaderInfo() this method tries to parse all non-validated header values (if any) // before returning to the caller. @@ -753,12 +750,9 @@ private bool ParseRawHeaderValues(HeaderDescriptor descriptor, HeaderStoreItemIn // returning. if ((info.InvalidValue == null) && (info.ParsedValue == null)) { - if (removeEmptyHeader) - { - // After parsing the raw value, no value is left because all values contain newline chars. - Debug.Assert(_headerStore.Entries is not null); - _headerStore.Remove(descriptor); - } + // After parsing the raw value, no value is left because all values contain newline chars. + Debug.Assert(!_headerStore.IsEmpty); + _headerStore.Remove(descriptor); return false; } } @@ -1274,18 +1268,80 @@ internal bool CanAddParsedValue(HttpHeaderParser parser) internal struct HeaderStore { +#pragma warning disable IDE0052 // Remove unread private members + // Used to store the CollectionsMarshal.GetValueRefOrAddDefault out parameter + private static bool s_dictionaryGetValueRefOrAddDefaultExistsDummy; +#pragma warning restore IDE0052 // Remove unread private members + private const int InitialCapacity = 4; - private const int MaxHeaderCount = 128; + private const int ArrayThreshold = 64; // Above this threshold, header ordering will not be preserved + + private object? _store; + + public HeaderEntry[]? GetEntries() + { + object? store = _store; + if (store is null) + { + return null; + } + else if (store is HeaderEntry[] entries) + { + return entries; + } + else + { + return GetEntriesFromDictionary(Unsafe.As>(store)); + } - public HeaderEntry[]? Entries; + static HeaderEntry[] GetEntriesFromDictionary(Dictionary dictionary) + { + var entries = new HeaderEntry[dictionary.Count]; + int i = 0; + foreach (KeyValuePair entry in dictionary) + { + entries[i++] = new HeaderEntry + { + Key = entry.Key, + Value = entry.Value + }; + } + return entries; + } + } - public bool IsEmpty => Entries is not HeaderEntry[] entries || entries[0].Key.Descriptor is null; + public bool IsEmpty + { + get + { + object? store = _store; + if (store is null) + { + return true; + } + else if (store is HeaderEntry[] entries) + { + return 0u >= (uint)entries.Length || entries[0].Key.Descriptor is null; + } + else + { + return Unsafe.As>(store).Count != 0; + } + } + } + + public bool RemoveShiftsEntries => _store is HeaderEntry[]; public int Count { get { - if (Entries is HeaderEntry[] entries) + object? store = _store; + if (store is null) + { + return 0; + } + else if (store is HeaderEntry[] entries) { for (int i = 0; i < entries.Length; i++) { @@ -1296,35 +1352,45 @@ public int Count } return entries.Length; } - - return 0; + else + { + return Unsafe.As>(store).Count; + } } } public ref object GetValueRefOrNullRef(HeaderDescriptor key) { - if (Entries is HeaderEntry[] entries) + object? store = _store; + if (store is not null) { - if (key.Descriptor is string) + if (store is HeaderEntry[] entries) { - for (int i = 0; i < entries.Length; i++) + if (key.Descriptor is string) { - if (string.Equals(entries[i].Key.Descriptor as string, Unsafe.As(key.Descriptor), StringComparison.OrdinalIgnoreCase)) + for (int i = 0; i < entries.Length; i++) { - return ref entries[i].Value; + if (string.Equals(entries[i].Key.Descriptor as string, Unsafe.As(key.Descriptor), StringComparison.OrdinalIgnoreCase)) + { + return ref entries[i].Value; + } } } - } - else - { - for (int i = 0; i < entries.Length; i++) + else { - if (ReferenceEquals(entries[i].Key.Descriptor, key.Descriptor)) + for (int i = 0; i < entries.Length; i++) { - return ref entries[i].Value; + if (ReferenceEquals(entries[i].Key.Descriptor, key.Descriptor)) + { + return ref entries[i].Value; + } } } } + else + { + return ref CollectionsMarshal.GetValueRefOrNullRef(Unsafe.As>(store), key); + } } return ref Unsafe.NullRef(); @@ -1332,7 +1398,16 @@ public ref object GetValueRefOrNullRef(HeaderDescriptor key) public ref object? GetValueRefOrAddDefault(HeaderDescriptor key) { - if (Entries is HeaderEntry[] entries) + object? store = _store; + if (store is null) + { + var entries = new HeaderEntry[InitialCapacity]; + _store = entries; + ref HeaderEntry firstElement = ref MemoryMarshal.GetArrayDataReference(entries); + firstElement.Key = key; + return ref firstElement.Value!; + } + else if (store is HeaderEntry[] entries) { if (key.Descriptor is string) { @@ -1371,70 +1446,85 @@ public ref object GetValueRefOrNullRef(HeaderDescriptor key) } else { - entries = new HeaderEntry[InitialCapacity]; - Entries = entries; - ref HeaderEntry firstElement = ref MemoryMarshal.GetArrayDataReference(entries); - firstElement.Key = key; - return ref firstElement.Value!; + return ref CollectionsMarshal.GetValueRefOrAddDefault(Unsafe.As>(store), key, out s_dictionaryGetValueRefOrAddDefaultExistsDummy); } } private ref object? GrowEntriesAndAddDefault(HeaderDescriptor key) { - HeaderEntry[] entries = Entries!; - if (entries.Length == MaxHeaderCount) + var entries = (HeaderEntry[])_store!; + if (entries.Length == ArrayThreshold) { - ThrowOnTooManyHeaders(); + var dictionary = new Dictionary(ArrayThreshold); + _store = dictionary; + foreach (HeaderEntry entry in entries) + { + dictionary.Add(entry.Key, entry.Value); + } + return ref CollectionsMarshal.GetValueRefOrAddDefault(dictionary, key, out s_dictionaryGetValueRefOrAddDefaultExistsDummy); + } + else + { + Array.Resize(ref entries, entries.Length << 1); + _store = entries; + ref HeaderEntry firstNewEntry = ref entries[entries.Length >> 1]; + firstNewEntry.Key = key; + return ref firstNewEntry.Value!; } - Array.Resize(ref entries, entries.Length << 1); - Entries = entries; - ref HeaderEntry firstNewEntry = ref entries[entries.Length >> 1]; - firstNewEntry.Key = key; - return ref firstNewEntry.Value!; } public void Clear() { - if (Entries is HeaderEntry[] entries) + object? store = _store; + if (store is not null) { - Array.Clear(entries); + if (store is HeaderEntry[] entries) + { + Array.Clear(entries); + } + else + { + Unsafe.As>(store).Clear(); + } } } public bool Remove(HeaderDescriptor key) { - if (Entries is HeaderEntry[] entries) + object? store = _store; + if (store is not null) { - for (int i = 0; i < entries.Length; i++) + if (store is HeaderEntry[] entries) { - HeaderDescriptor entryKey = entries[i].Key; - if (entryKey.Descriptor is null) - { - break; - } - - if (entryKey.Equals(key)) + for (int i = 0; i < entries.Length; i++) { - while ((uint)(i + 1) < (uint)entries.Length && entries[i + 1].Key.Descriptor is not null) + HeaderDescriptor entryKey = entries[i].Key; + if (entryKey.Descriptor is null) { - entries[i] = entries[i + 1]; - i++; + break; } - entries[i] = default; - return true; + if (entryKey.Equals(key)) + { + while ((uint)(i + 1) < (uint)entries.Length && entries[i + 1].Key.Descriptor is not null) + { + entries[i] = entries[i + 1]; + i++; + } + entries[i] = default; + + return true; + } } } + else + { + return Unsafe.As>(store).Remove(key); + } } return false; } - - [DoesNotReturn] - private static void ThrowOnTooManyHeaders() - { - throw new HttpRequestException("Too many headers!"); - } } } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs index d3bde0099b1571..f584452e01fa11 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs @@ -83,7 +83,7 @@ public bool TryGetValues(string headerName, out HeaderStringValues values) /// Gets an enumerator that iterates through the . /// An enumerator that iterates through the . public Enumerator GetEnumerator() => - _headers is HttpHeaders headers && headers.Entries is HeaderEntry[] entries ? + _headers is HttpHeaders headers && headers.GetEntries() is HeaderEntry[] entries ? new Enumerator(entries) : default; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs index 88ad58d7083fc9..bb67a0a1db5f4c 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs @@ -1337,7 +1337,7 @@ private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders heade { if (NetEventSource.Log.IsEnabled()) Trace(""); - if (headers.Entries is null) + if (headers.GetEntries() is not HeaderEntry[] entries) { return; } @@ -1345,7 +1345,7 @@ private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders heade HeaderEncodingSelector? encodingSelector = _pool.Settings._requestHeaderEncodingSelector; ref string[]? tmpHeaderValuesArray = ref t_headerValues; - foreach (HeaderEntry header in headers.Entries) + foreach (HeaderEntry header in entries) { if (header.Key.Descriptor is null) { diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index 3d740cbe90abbc..84abb9d6037945 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -622,14 +622,14 @@ private void BufferHeaders(HttpRequestMessage request) // TODO: special-case Content-Type for static table values values? private void BufferHeaderCollection(HttpHeaders headers) { - if (headers.Entries is null) + if (headers.GetEntries() is not HeaderEntry[] entries) { return; } HeaderEncodingSelector? encodingSelector = _connection.Pool.Settings._requestHeaderEncodingSelector; - foreach (HeaderEntry header in headers.Entries) + foreach (HeaderEntry header in entries) { if (header.Key.Descriptor is null) { diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs index 9c386937c34e72..35d2a2c8225fa5 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs @@ -258,9 +258,9 @@ private async ValueTask WriteHeadersAsync(HttpHeaders headers, string? cookiesFr { Debug.Assert(_currentRequest != null); - if (headers.Entries != null) + if (headers.GetEntries() is HeaderEntry[] entries) { - foreach (HeaderEntry header in headers.Entries) + foreach (HeaderEntry header in entries) { if (header.Key.Descriptor is null) { diff --git a/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs b/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs index 2fb1827fa2f12a..ca607f9e33f1e8 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs @@ -60,7 +60,7 @@ private static Memory HPackEncode(HttpHeaders headers, Encoding? valueEnco FillAvailableSpaceWithOnes(buffer); string[] headerValues = Array.Empty(); - foreach (HeaderEntry header in headers.Entries) + foreach (HeaderEntry header in headers.GetEntries()) { if (header.Key.Descriptor is null) { From 7c0628400089bca250bc4892196e41606ecc0628 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Mon, 20 Dec 2021 04:30:00 +0100 Subject: [PATCH 05/22] Add unit tests --- .../System/Net/Http/Headers/HttpHeaders.cs | 20 ++- .../UnitTests/Headers/HttpHeadersTest.cs | 162 ++++++++++++++++++ 2 files changed, 177 insertions(+), 5 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index dcc2dceca48645..bb800d753ffa89 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -327,7 +327,17 @@ private IEnumerator>> GetEnumeratorCore // during enumeration so that we can parse the raw value in order to a) return // the correct set of parsed values, and b) update the instance for subsequent enumerations // to reflect that parsing. - entries[i].Value = info = new HeaderStoreItemInfo() { RawValue = value }; + info = new HeaderStoreItemInfo() { RawValue = value }; + + if (_headerStore.EntriesAreLiveView) + { + entries[i].Value = info; + } + else + { + Debug.Assert(_headerStore.GetValueRefOrAddDefault(descriptor) is not null); + _headerStore.GetValueRefOrAddDefault(descriptor) = info; + } } // Make sure we parse all raw values before returning the result. Note that this has to be @@ -338,7 +348,7 @@ private IEnumerator>> GetEnumeratorCore // We saw an invalid header value (contains newline chars) and deleted it. // If the HeaderEntry[] we are enumerating is the live header store, the entries have shifted. - if (_headerStore.RemoveShiftsEntries) + if (_headerStore.EntriesAreLiveView) { i--; } @@ -1274,7 +1284,7 @@ internal struct HeaderStore #pragma warning restore IDE0052 // Remove unread private members private const int InitialCapacity = 4; - private const int ArrayThreshold = 64; // Above this threshold, header ordering will not be preserved + internal const int ArrayThreshold = 64; // Above this threshold, header ordering will not be preserved private object? _store; @@ -1325,12 +1335,12 @@ public bool IsEmpty } else { - return Unsafe.As>(store).Count != 0; + return Unsafe.As>(store).Count == 0; } } } - public bool RemoveShiftsEntries => _store is HeaderEntry[]; + public bool EntriesAreLiveView => _store is HeaderEntry[]; public int Count { diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs index 139422163a95bc..7e8cba6e82872c 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs @@ -1710,6 +1710,24 @@ public void GetEnumerator_UseExplicitInterfaceImplementation_EnumeratorReturnsNo Assert.False(enumerator.MoveNext(), "Only 2 values expected, but enumerator returns a third one."); } + [Fact] + public void GetEnumerator_InvalidValueBetweenValidHeaders_EnumeratorReturnsAllValidValuesAndRemovedInvalidValue() + { + MockHeaders headers = new MockHeaders(); + headers.TryAddWithoutValidation("foo", "fooValue"); + headers.TryAddWithoutValidation("invalid", "invalid\nvalue"); + headers.TryAddWithoutValidation("bar", "barValue"); + + Assert.Equal(3, headers.Count); + + IDictionary> dict = headers.ToDictionary(pair => pair.Key, pair => pair.Value); + Assert.Equal("fooValue", Assert.Single(Assert.Contains("foo", dict))); + Assert.Equal("barValue", Assert.Single(Assert.Contains("bar", dict))); + + Assert.Equal(2, headers.Count); + Assert.DoesNotContain("invalid", dict); + } + [Fact] public void AddParsedValue_AddSingleValueToNonExistingHeader_HeaderGetsCreatedAndValueAdded() { @@ -2208,6 +2226,150 @@ public void HeaderStringValues_Constructed_ProducesExpectedResults() } } + [Theory] + [MemberData(nameof(NumberOfHeadersUpToArrayThreshold_AddNonValidated_EnumerateNonValidated))] + public void Add_WithinArrayThresholdHeaders_EnumerationPreservesOrdering(int numberOfHeaders, bool addNonValidated, bool enumerateNonValidated) + { + var headers = new MockHeaders(); + + for (int i = 0; i < numberOfHeaders; i++) + { + if (addNonValidated) + { + headers.TryAddWithoutValidation(i.ToString(), i.ToString()); + } + else + { + headers.Add(i.ToString(), i.ToString()); + } + } + + KeyValuePair[] entries = enumerateNonValidated + ? headers.NonValidated.Select(pair => KeyValuePair.Create(pair.Key, Assert.Single(pair.Value))).ToArray() + : headers.Select(pair => KeyValuePair.Create(pair.Key, Assert.Single(pair.Value))).ToArray(); + + Assert.Equal(numberOfHeaders, entries.Length); + for (int i = 0; i < numberOfHeaders; i++) + { + Assert.Equal(i.ToString(), entries[i].Key); + Assert.Equal(i.ToString(), entries[i].Value); + } + } + + [Fact] + public void Add_Remove_HeaderOrderingIsPreserved() + { + var headers = new MockHeaders(); + headers.Add("a", ""); + headers.Add("b", ""); + headers.Add("c", ""); + + headers.Remove("b"); + + Assert.Equal(new[] { "a", "c" }, headers.Select(pair => pair.Key)); + } + + [Fact] + public void Add_AddToExistingKey_OriginalOrderingIsPreserved() + { + var headers = new MockHeaders(); + headers.Add("a", "a1"); + headers.Add("b", "b1"); + headers.Add("a", "a2"); + + Assert.Equal(new[] { "a", "b" }, headers.Select(pair => pair.Key)); + } + + [Theory] + [InlineData(3)] + [InlineData(4)] + [InlineData(5)] + [InlineData(HttpHeaders.HeaderStore.ArrayThreshold / 4)] + [InlineData(HttpHeaders.HeaderStore.ArrayThreshold / 2)] + [InlineData(HttpHeaders.HeaderStore.ArrayThreshold - 1)] + [InlineData(HttpHeaders.HeaderStore.ArrayThreshold)] + [InlineData(HttpHeaders.HeaderStore.ArrayThreshold + 1)] + [InlineData(HttpHeaders.HeaderStore.ArrayThreshold * 2)] + [InlineData(HttpHeaders.HeaderStore.ArrayThreshold * 4)] + public void Add_LargeNumberOfHeaders_OperationsStillSupported(int numberOfHeaders) + { + string[] keys = Enumerable.Range(1, numberOfHeaders).Select(i => i.ToString()).ToArray(); + + var headers = new MockHeaders(); + foreach (string key in keys) + { + Assert.False(headers.NonValidated.Contains(key)); + headers.TryAddWithoutValidation(key, key); + Assert.True(headers.NonValidated.Contains(key)); + } + + string[] nonValidatedKeys = headers.NonValidated.Select(pair => pair.Key).ToArray(); + Assert.Equal(numberOfHeaders, nonValidatedKeys.Length); + + string[] newKeys = headers.Select(pair => pair.Key).ToArray(); + Assert.Equal(numberOfHeaders, newKeys.Length); + + string[] nonValidatedKeysAfterValidation = headers.NonValidated.Select(pair => pair.Key).ToArray(); + Assert.Equal(numberOfHeaders, nonValidatedKeysAfterValidation.Length); + + if (numberOfHeaders > HttpHeaders.HeaderStore.ArrayThreshold) + { + // Ordering is lost when adding more than ArrayThreshold headers + Array.Sort(nonValidatedKeys, (a, b) => int.Parse(a).CompareTo(int.Parse(b))); + Array.Sort(newKeys, (a, b) => int.Parse(a).CompareTo(int.Parse(b))); + Array.Sort(nonValidatedKeysAfterValidation, (a, b) => int.Parse(a).CompareTo(int.Parse(b))); + } + Assert.Equal(keys, nonValidatedKeys); + Assert.Equal(keys, newKeys); + Assert.Equal(keys, nonValidatedKeysAfterValidation); + + headers.Add("3", "secondValue"); + Assert.True(headers.TryGetValues("3", out IEnumerable valuesFor3)); + Assert.Equal(new[] { "3", "secondValue" }, valuesFor3); + + Assert.True(headers.TryAddWithoutValidation("invalid", "invalid\nvalue")); + Assert.True(headers.TryAddWithoutValidation("valid", "validValue")); + + Assert.Equal(numberOfHeaders + 2, headers.NonValidated.Count); + + // Remove all headers except for "1", "valid", "invalid" + for (int i = 2; i <= numberOfHeaders; i++) + { + Assert.True(headers.Remove(i.ToString())); + } + + Assert.False(headers.Remove("3")); + + // "1", "invalid", "valid" + Assert.True(headers.NonValidated.Contains("invalid")); + Assert.Equal(3, headers.NonValidated.Count); + + Assert.Equal(new[] { "1", "valid" }, headers.Select(pair => pair.Key).OrderBy(i => i)); + + Assert.Equal(2, headers.NonValidated.Count); + + headers.Clear(); + + Assert.Equal(0, headers.NonValidated.Count); + Assert.Empty(headers); + Assert.False(headers.Contains("3")); + + Assert.True(headers.TryAddWithoutValidation("3", "newValue")); + Assert.True(headers.TryGetValues("3", out valuesFor3)); + Assert.Equal(new[] { "newValue" }, valuesFor3); + } + + public static IEnumerable NumberOfHeadersUpToArrayThreshold_AddNonValidated_EnumerateNonValidated() + { + for (int i = 0; i <= HttpHeaders.HeaderStore.ArrayThreshold; i++) + { + yield return new object[] { i, false, false }; + yield return new object[] { i, false, true }; + yield return new object[] { i, true, false }; + yield return new object[] { i, true, true }; + } + } + public static IEnumerable GetInvalidHeaderNames() { yield return new object[] { "invalid header" }; From 344a48d6f707a072875554f70fc1824b78904439 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Mon, 20 Dec 2021 04:43:26 +0100 Subject: [PATCH 06/22] Use storeValueRef naming consistently --- .../System/Net/Http/Headers/HttpHeaders.cs | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index bb800d753ffa89..2b5ca990ebae4e 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -660,14 +660,14 @@ private HeaderStoreItemInfo GetOrCreateHeaderInfo(HeaderDescriptor descriptor, b } else { - ref object valueRef = ref _headerStore.GetValueRefOrNullRef(descriptor); - if (!Unsafe.IsNullRef(ref valueRef)) + ref object storeValueRef = ref _headerStore.GetValueRefOrNullRef(descriptor); + if (!Unsafe.IsNullRef(ref storeValueRef)) { - object value = valueRef; + object value = storeValueRef; if (value is not HeaderStoreItemInfo info) { Debug.Assert(value is string); - valueRef = info = new HeaderStoreItemInfo { RawValue = value }; + storeValueRef = info = new HeaderStoreItemInfo { RawValue = value }; } return info; } @@ -698,25 +698,25 @@ private void AddHeaderToStore(HeaderDescriptor descriptor, object value) internal bool TryGetHeaderValue(HeaderDescriptor descriptor, [NotNullWhen(true)] out object? value) { - ref object valueRef = ref _headerStore.GetValueRefOrNullRef(descriptor); - if (Unsafe.IsNullRef(ref valueRef)) + ref object storeValueRef = ref _headerStore.GetValueRefOrNullRef(descriptor); + if (Unsafe.IsNullRef(ref storeValueRef)) { value = null; return false; } else { - value = valueRef; + value = storeValueRef; return true; } } private bool TryGetAndParseHeaderInfo(HeaderDescriptor key, [NotNullWhen(true)] out HeaderStoreItemInfo? info) { - ref object valueRef = ref _headerStore.GetValueRefOrNullRef(key); - if (!Unsafe.IsNullRef(ref valueRef)) + ref object storeValueRef = ref _headerStore.GetValueRefOrNullRef(key); + if (!Unsafe.IsNullRef(ref storeValueRef)) { - object value = valueRef; + object value = storeValueRef; if (value is HeaderStoreItemInfo hsi) { info = hsi; @@ -724,7 +724,7 @@ private bool TryGetAndParseHeaderInfo(HeaderDescriptor key, [NotNullWhen(true)] else { Debug.Assert(value is string); - valueRef = info = new HeaderStoreItemInfo() { RawValue = value }; + storeValueRef = info = new HeaderStoreItemInfo() { RawValue = value }; } return ParseRawHeaderValues(key, info); From e6920893a2fbb6887aecce52c8a05a43fda14601 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Mon, 20 Dec 2021 11:48:49 +0100 Subject: [PATCH 07/22] Workaround field layout regression (#63005) --- .../System/Net/Http/Headers/HttpHeaders.cs | 372 +++++++++--------- .../UnitTests/Headers/HttpHeadersTest.cs | 18 +- 2 files changed, 190 insertions(+), 200 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index 2b5ca990ebae4e..fe8a776232c743 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -40,8 +40,8 @@ public abstract class HttpHeaders : IEnumerableKey/value pairs of headers. The value is either a raw or a . - private HeaderStore _headerStore; + /// Either a array or a Dictionary<, > + private object? _headerStore; private readonly HttpHeaderType _allowedHeaderTypes; private readonly HttpHeaderType _treatAsCustomHeaderTypes; @@ -60,10 +60,6 @@ internal HttpHeaders(HttpHeaderType allowedHeaderTypes, HttpHeaderType treatAsCu _treatAsCustomHeaderTypes = treatAsCustomHeaderTypes & ~HttpHeaderType.NonTrailing; } - internal HeaderEntry[]? GetEntries() => _headerStore.GetEntries(); - - internal int Count => _headerStore.Count; - /// Gets a view of the contents of this headers collection that does not parse nor validate the data upon access. public HttpHeadersNonValidated NonValidated => new HttpHeadersNonValidated(this); @@ -130,7 +126,7 @@ internal bool TryAddWithoutValidation(HeaderDescriptor descriptor, string? value // values, e.g. adding two null-strings (or empty, or whitespace-only) results in "My-Header: ,". value ??= string.Empty; - ref object? storeValueRef = ref _headerStore.GetValueRefOrAddDefault(descriptor); + ref object? storeValueRef = ref GetValueRefOrAddDefault(descriptor); object? currentValue = storeValueRef; if (currentValue is null) @@ -185,8 +181,6 @@ internal bool TryAddWithoutValidation(HeaderDescriptor descriptor, IEnumerable _headerStore.Clear(); - public IEnumerable GetValues(string name) => GetValues(GetHeaderDescriptor(name)); internal IEnumerable GetValues(HeaderDescriptor descriptor) @@ -302,7 +296,7 @@ internal string GetHeaderString(HeaderDescriptor descriptor) #region IEnumerable>> Members - public IEnumerator>> GetEnumerator() => _headerStore.IsEmpty ? + public IEnumerator>> GetEnumerator() => IsEmpty ? ((IEnumerable>>)Array.Empty>>()).GetEnumerator() : GetEnumeratorCore(); @@ -329,14 +323,14 @@ private IEnumerator>> GetEnumeratorCore // to reflect that parsing. info = new HeaderStoreItemInfo() { RawValue = value }; - if (_headerStore.EntriesAreLiveView) + if (EntriesAreLiveView) { entries[i].Value = info; } else { - Debug.Assert(_headerStore.GetValueRefOrAddDefault(descriptor) is not null); - _headerStore.GetValueRefOrAddDefault(descriptor) = info; + Debug.Assert(GetValueRefOrAddDefault(descriptor) is not null); + GetValueRefOrAddDefault(descriptor) = info; } } @@ -348,7 +342,7 @@ private IEnumerator>> GetEnumeratorCore // We saw an invalid header value (contains newline chars) and deleted it. // If the HeaderEntry[] we are enumerating is the live header store, the entries have shifted. - if (_headerStore.EntriesAreLiveView) + if (EntriesAreLiveView) { i--; } @@ -413,8 +407,6 @@ internal void SetOrRemoveParsedValue(HeaderDescriptor descriptor, object? value) public bool Remove(string name) => Remove(GetHeaderDescriptor(name)); - internal bool Remove(HeaderDescriptor descriptor) => _headerStore.Remove(descriptor); - internal bool RemoveParsedValue(HeaderDescriptor descriptor, object value) { Debug.Assert(value != null); @@ -551,7 +543,7 @@ internal virtual void AddHeaders(HttpHeaders sourceHeaders) // Only add header values if they're not already set on the message. Note that we don't merge // collections: If both the default headers and the message have set some values for a certain // header, then we don't try to merge the values. - ref object? storeValueRef = ref _headerStore.GetValueRefOrAddDefault(entry.Key); + ref object? storeValueRef = ref GetValueRefOrAddDefault(entry.Key); if (storeValueRef is null) { object sourceValue = entry.Value; @@ -660,7 +652,7 @@ private HeaderStoreItemInfo GetOrCreateHeaderInfo(HeaderDescriptor descriptor, b } else { - ref object storeValueRef = ref _headerStore.GetValueRefOrNullRef(descriptor); + ref object storeValueRef = ref GetValueRefOrNullRef(descriptor); if (!Unsafe.IsNullRef(ref storeValueRef)) { object value = storeValueRef; @@ -692,13 +684,13 @@ private HeaderStoreItemInfo CreateAndAddHeaderToStore(HeaderDescriptor descripto private void AddHeaderToStore(HeaderDescriptor descriptor, object value) { Debug.Assert(value is string or HeaderStoreItemInfo); - Debug.Assert(Unsafe.IsNullRef(ref _headerStore.GetValueRefOrNullRef(descriptor))); - _headerStore.GetValueRefOrAddDefault(descriptor) = value; + Debug.Assert(Unsafe.IsNullRef(ref GetValueRefOrNullRef(descriptor))); + GetValueRefOrAddDefault(descriptor) = value; } internal bool TryGetHeaderValue(HeaderDescriptor descriptor, [NotNullWhen(true)] out object? value) { - ref object storeValueRef = ref _headerStore.GetValueRefOrNullRef(descriptor); + ref object storeValueRef = ref GetValueRefOrNullRef(descriptor); if (Unsafe.IsNullRef(ref storeValueRef)) { value = null; @@ -713,7 +705,7 @@ internal bool TryGetHeaderValue(HeaderDescriptor descriptor, [NotNullWhen(true)] private bool TryGetAndParseHeaderInfo(HeaderDescriptor key, [NotNullWhen(true)] out HeaderStoreItemInfo? info) { - ref object storeValueRef = ref _headerStore.GetValueRefOrNullRef(key); + ref object storeValueRef = ref GetValueRefOrNullRef(key); if (!Unsafe.IsNullRef(ref storeValueRef)) { object value = storeValueRef; @@ -761,8 +753,8 @@ private bool ParseRawHeaderValues(HeaderDescriptor descriptor, HeaderStoreItemIn if ((info.InvalidValue == null) && (info.ParsedValue == null)) { // After parsing the raw value, no value is left because all values contain newline chars. - Debug.Assert(!_headerStore.IsEmpty); - _headerStore.Remove(descriptor); + Debug.Assert(!IsEmpty); + Remove(descriptor); return false; } } @@ -1276,197 +1268,193 @@ internal bool CanAddParsedValue(HttpHeaderParser parser) internal bool IsEmpty => (RawValue == null) && (InvalidValue == null) && (ParsedValue == null); } - internal struct HeaderStore - { -#pragma warning disable IDE0052 // Remove unread private members - // Used to store the CollectionsMarshal.GetValueRefOrAddDefault out parameter - private static bool s_dictionaryGetValueRefOrAddDefaultExistsDummy; -#pragma warning restore IDE0052 // Remove unread private members - private const int InitialCapacity = 4; - internal const int ArrayThreshold = 64; // Above this threshold, header ordering will not be preserved + #region Low-level implementation details that work with _headerStore directly + + // Used to store the CollectionsMarshal.GetValueRefOrAddDefault out parameter + private static bool s_dictionaryGetValueRefOrAddDefaultExistsDummy; + + private const int InitialCapacity = 4; + internal const int ArrayThreshold = 64; // Above this threshold, header ordering will not be preserved + + internal HeaderEntry[]? GetEntries() + { + object? store = _headerStore; + if (store is null) + { + return null; + } + else if (store is HeaderEntry[] entries) + { + return entries; + } + else + { + return GetEntriesFromDictionary(Unsafe.As>(store)); + } - private object? _store; + static HeaderEntry[] GetEntriesFromDictionary(Dictionary dictionary) + { + var entries = new HeaderEntry[dictionary.Count]; + int i = 0; + foreach (KeyValuePair entry in dictionary) + { + entries[i++] = new HeaderEntry + { + Key = entry.Key, + Value = entry.Value + }; + } + return entries; + } + } - public HeaderEntry[]? GetEntries() + private bool IsEmpty + { + get { - object? store = _store; + object? store = _headerStore; if (store is null) { - return null; + return true; } else if (store is HeaderEntry[] entries) { - return entries; + return 0u >= (uint)entries.Length || entries[0].Key.Descriptor is null; } else { - return GetEntriesFromDictionary(Unsafe.As>(store)); + return Unsafe.As>(store).Count == 0; } + } + } + + private bool EntriesAreLiveView => _headerStore is HeaderEntry[]; - static HeaderEntry[] GetEntriesFromDictionary(Dictionary dictionary) + internal int Count + { + get + { + object? store = _headerStore; + if (store is null) + { + return 0; + } + else if (store is HeaderEntry[] entries) { - var entries = new HeaderEntry[dictionary.Count]; - int i = 0; - foreach (KeyValuePair entry in dictionary) + for (int i = 0; i < entries.Length; i++) { - entries[i++] = new HeaderEntry + if (entries[i].Key.Descriptor is null) { - Key = entry.Key, - Value = entry.Value - }; + return i; + } } - return entries; + return entries.Length; } - } - - public bool IsEmpty - { - get + else { - object? store = _store; - if (store is null) - { - return true; - } - else if (store is HeaderEntry[] entries) - { - return 0u >= (uint)entries.Length || entries[0].Key.Descriptor is null; - } - else - { - return Unsafe.As>(store).Count == 0; - } + return Unsafe.As>(store).Count; } } + } - public bool EntriesAreLiveView => _store is HeaderEntry[]; - - public int Count + private ref object GetValueRefOrNullRef(HeaderDescriptor key) + { + object? store = _headerStore; + if (store is not null) { - get + if (store is HeaderEntry[] entries) { - object? store = _store; - if (store is null) - { - return 0; - } - else if (store is HeaderEntry[] entries) + if (key.Descriptor is string) { for (int i = 0; i < entries.Length; i++) { - if (entries[i].Key.Descriptor is null) + if (string.Equals(entries[i].Key.Descriptor as string, Unsafe.As(key.Descriptor), StringComparison.OrdinalIgnoreCase)) { - return i; + return ref entries[i].Value; } } - return entries.Length; } else { - return Unsafe.As>(store).Count; + for (int i = 0; i < entries.Length; i++) + { + if (ReferenceEquals(entries[i].Key.Descriptor, key.Descriptor)) + { + return ref entries[i].Value; + } + } } } + else + { + return ref CollectionsMarshal.GetValueRefOrNullRef(Unsafe.As>(store), key); + } } - public ref object GetValueRefOrNullRef(HeaderDescriptor key) + return ref Unsafe.NullRef(); + } + + private ref object? GetValueRefOrAddDefault(HeaderDescriptor key) + { + object? store = _headerStore; + if (store is null) + { + var entries = new HeaderEntry[InitialCapacity]; + _headerStore = entries; + ref HeaderEntry firstElement = ref MemoryMarshal.GetArrayDataReference(entries); + firstElement.Key = key; + return ref firstElement.Value!; + } + else if (store is HeaderEntry[] entries) { - object? store = _store; - if (store is not null) + if (key.Descriptor is string) { - if (store is HeaderEntry[] entries) + for (int i = 0; i < entries.Length; i++) { - if (key.Descriptor is string) + ref HeaderEntry entry = ref entries[i]; + if (entry.Key.Descriptor is null) { - for (int i = 0; i < entries.Length; i++) - { - if (string.Equals(entries[i].Key.Descriptor as string, Unsafe.As(key.Descriptor), StringComparison.OrdinalIgnoreCase)) - { - return ref entries[i].Value; - } - } + entry.Key = key; + return ref entry.Value!; } - else + else if (string.Equals(entry.Key.Descriptor as string, Unsafe.As(key.Descriptor), StringComparison.OrdinalIgnoreCase)) { - for (int i = 0; i < entries.Length; i++) - { - if (ReferenceEquals(entries[i].Key.Descriptor, key.Descriptor)) - { - return ref entries[i].Value; - } - } + return ref entry.Value!; } } - else - { - return ref CollectionsMarshal.GetValueRefOrNullRef(Unsafe.As>(store), key); - } } - - return ref Unsafe.NullRef(); - } - - public ref object? GetValueRefOrAddDefault(HeaderDescriptor key) - { - object? store = _store; - if (store is null) - { - var entries = new HeaderEntry[InitialCapacity]; - _store = entries; - ref HeaderEntry firstElement = ref MemoryMarshal.GetArrayDataReference(entries); - firstElement.Key = key; - return ref firstElement.Value!; - } - else if (store is HeaderEntry[] entries) + else { - if (key.Descriptor is string) + for (int i = 0; i < entries.Length; i++) { - for (int i = 0; i < entries.Length; i++) + ref HeaderEntry entry = ref entries[i]; + if (entry.Key.Descriptor is null) { - ref HeaderEntry entry = ref entries[i]; - if (entry.Key.Descriptor is null) - { - entry.Key = key; - return ref entry.Value!; - } - else if (string.Equals(entry.Key.Descriptor as string, Unsafe.As(key.Descriptor), StringComparison.OrdinalIgnoreCase)) - { - return ref entry.Value!; - } + entry.Key = key; + return ref entry.Value!; } - } - else - { - for (int i = 0; i < entries.Length; i++) + else if (ReferenceEquals(entry.Key.Descriptor, key.Descriptor)) { - ref HeaderEntry entry = ref entries[i]; - if (entry.Key.Descriptor is null) - { - entry.Key = key; - return ref entry.Value!; - } - else if (ReferenceEquals(entry.Key.Descriptor, key.Descriptor)) - { - return ref entry.Value!; - } + return ref entry.Value!; } } - - return ref GrowEntriesAndAddDefault(key); - } - else - { - return ref CollectionsMarshal.GetValueRefOrAddDefault(Unsafe.As>(store), key, out s_dictionaryGetValueRefOrAddDefaultExistsDummy); } + + return ref GrowEntriesAndAddDefault(key); + } + else + { + return ref CollectionsMarshal.GetValueRefOrAddDefault(Unsafe.As>(store), key, out s_dictionaryGetValueRefOrAddDefaultExistsDummy); } - private ref object? GrowEntriesAndAddDefault(HeaderDescriptor key) + ref object? GrowEntriesAndAddDefault(HeaderDescriptor key) { - var entries = (HeaderEntry[])_store!; + var entries = (HeaderEntry[])_headerStore!; if (entries.Length == ArrayThreshold) { var dictionary = new Dictionary(ArrayThreshold); - _store = dictionary; + _headerStore = dictionary; foreach (HeaderEntry entry in entries) { dictionary.Add(entry.Key, entry.Value); @@ -1476,65 +1464,67 @@ public ref object GetValueRefOrNullRef(HeaderDescriptor key) else { Array.Resize(ref entries, entries.Length << 1); - _store = entries; + _headerStore = entries; ref HeaderEntry firstNewEntry = ref entries[entries.Length >> 1]; firstNewEntry.Key = key; return ref firstNewEntry.Value!; } } + } - public void Clear() + public void Clear() + { + object? store = _headerStore; + if (store is not null) { - object? store = _store; - if (store is not null) + if (store is HeaderEntry[] entries) { - if (store is HeaderEntry[] entries) - { - Array.Clear(entries); - } - else - { - Unsafe.As>(store).Clear(); - } + Array.Clear(entries); + } + else + { + Unsafe.As>(store).Clear(); } } + } - public bool Remove(HeaderDescriptor key) + internal bool Remove(HeaderDescriptor key) + { + object? store = _headerStore; + if (store is not null) { - object? store = _store; - if (store is not null) + if (store is HeaderEntry[] entries) { - if (store is HeaderEntry[] entries) + for (int i = 0; i < entries.Length; i++) { - for (int i = 0; i < entries.Length; i++) + HeaderDescriptor entryKey = entries[i].Key; + if (entryKey.Descriptor is null) { - HeaderDescriptor entryKey = entries[i].Key; - if (entryKey.Descriptor is null) - { - break; - } + break; + } - if (entryKey.Equals(key)) + if (entryKey.Equals(key)) + { + while ((uint)(i + 1) < (uint)entries.Length && entries[i + 1].Key.Descriptor is not null) { - while ((uint)(i + 1) < (uint)entries.Length && entries[i + 1].Key.Descriptor is not null) - { - entries[i] = entries[i + 1]; - i++; - } - entries[i] = default; - - return true; + entries[i] = entries[i + 1]; + i++; } + entries[i] = default; + + return true; } } - else - { - return Unsafe.As>(store).Remove(key); - } } - - return false; + else + { + return Unsafe.As>(store).Remove(key); + } } + + return false; } + + #endregion // _headerStore implementation } } diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs index 7e8cba6e82872c..b9d5abb5d450ca 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs @@ -2284,13 +2284,13 @@ public void Add_AddToExistingKey_OriginalOrderingIsPreserved() [InlineData(3)] [InlineData(4)] [InlineData(5)] - [InlineData(HttpHeaders.HeaderStore.ArrayThreshold / 4)] - [InlineData(HttpHeaders.HeaderStore.ArrayThreshold / 2)] - [InlineData(HttpHeaders.HeaderStore.ArrayThreshold - 1)] - [InlineData(HttpHeaders.HeaderStore.ArrayThreshold)] - [InlineData(HttpHeaders.HeaderStore.ArrayThreshold + 1)] - [InlineData(HttpHeaders.HeaderStore.ArrayThreshold * 2)] - [InlineData(HttpHeaders.HeaderStore.ArrayThreshold * 4)] + [InlineData(HttpHeaders.ArrayThreshold / 4)] + [InlineData(HttpHeaders.ArrayThreshold / 2)] + [InlineData(HttpHeaders.ArrayThreshold - 1)] + [InlineData(HttpHeaders.ArrayThreshold)] + [InlineData(HttpHeaders.ArrayThreshold + 1)] + [InlineData(HttpHeaders.ArrayThreshold * 2)] + [InlineData(HttpHeaders.ArrayThreshold * 4)] public void Add_LargeNumberOfHeaders_OperationsStillSupported(int numberOfHeaders) { string[] keys = Enumerable.Range(1, numberOfHeaders).Select(i => i.ToString()).ToArray(); @@ -2312,7 +2312,7 @@ public void Add_LargeNumberOfHeaders_OperationsStillSupported(int numberOfHeader string[] nonValidatedKeysAfterValidation = headers.NonValidated.Select(pair => pair.Key).ToArray(); Assert.Equal(numberOfHeaders, nonValidatedKeysAfterValidation.Length); - if (numberOfHeaders > HttpHeaders.HeaderStore.ArrayThreshold) + if (numberOfHeaders > HttpHeaders.ArrayThreshold) { // Ordering is lost when adding more than ArrayThreshold headers Array.Sort(nonValidatedKeys, (a, b) => int.Parse(a).CompareTo(int.Parse(b))); @@ -2361,7 +2361,7 @@ public void Add_LargeNumberOfHeaders_OperationsStillSupported(int numberOfHeader public static IEnumerable NumberOfHeadersUpToArrayThreshold_AddNonValidated_EnumerateNonValidated() { - for (int i = 0; i <= HttpHeaders.HeaderStore.ArrayThreshold; i++) + for (int i = 0; i <= HttpHeaders.ArrayThreshold; i++) { yield return new object[] { i, false, false }; yield return new object[] { i, false, true }; From 2e62a231b4971783313e09b004259cc379f28b24 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Mon, 20 Dec 2021 11:49:37 +0100 Subject: [PATCH 08/22] Mark _descriptor on HeaderDescriptor as nullable --- .../src/System/Net/Http/Headers/HeaderDescriptor.cs | 10 ++++++---- .../src/System/Net/Http/Headers/HttpHeaders.cs | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs index d3fc840edf91f0..43e22000c3dc05 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs @@ -19,7 +19,7 @@ namespace System.Net.Http.Headers /// /// Either a or /// - private readonly object _descriptor; + private readonly object? _descriptor; public HeaderDescriptor(KnownHeader knownHeader) { @@ -36,14 +36,15 @@ public string Name { get { - object? descriptor = _descriptor; + Debug.Assert(_descriptor is not null); + object descriptor = _descriptor; return descriptor is KnownHeader knownHeader ? knownHeader.Name : Unsafe.As(descriptor); } } - public object Descriptor => _descriptor; + public object? Descriptor => _descriptor; public HttpHeaderParser? Parser => _descriptor is KnownHeader knownHeader ? knownHeader.Parser : null; public HttpHeaderType HeaderType => _descriptor is KnownHeader knownHeader ? knownHeader.HeaderType : HttpHeaderType.Custom; @@ -66,7 +67,8 @@ public bool Equals(HeaderDescriptor other) public override int GetHashCode() { - object? descriptor = _descriptor; + Debug.Assert(_descriptor is not null); + object descriptor = _descriptor; if (descriptor is string headerName) { return StringComparer.OrdinalIgnoreCase.GetHashCode(headerName); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index fe8a776232c743..1acd6b243b4472 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -302,7 +302,7 @@ public IEnumerator>> GetEnumerator() => private IEnumerator>> GetEnumeratorCore() { - HeaderEntry[]? entries = GetEntries()!; + HeaderEntry[] entries = GetEntries()!; for (int i = 0; i < entries.Length; i++) { From 36626b7f02f663d37fa08f357a1e24d8670c1e84 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Tue, 21 Dec 2021 15:03:32 +0100 Subject: [PATCH 09/22] Remove HeaderDescriptor.Descriptor and add HasValue, IsKnownHeader, Equals --- .../Net/Http/Headers/HeaderDescriptor.cs | 24 +++++++++- .../System/Net/Http/Headers/HttpHeaders.cs | 44 ++++++++++++------- .../Http/Headers/HttpHeadersNonValidated.cs | 2 +- .../SocketsHttpHandler/Http2Connection.cs | 6 +-- .../SocketsHttpHandler/Http3RequestStream.cs | 8 ++-- .../Http/SocketsHttpHandler/HttpConnection.cs | 10 ++--- .../SocketsHttpHandler/HttpConnectionBase.cs | 4 +- .../UnitTests/HPack/HPackRoundtripTests.cs | 6 +-- .../UnitTests/Headers/HeaderEncodingTest.cs | 7 ++- .../UnitTests/Headers/KnownHeadersTest.cs | 3 +- 10 files changed, 75 insertions(+), 39 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs index 43e22000c3dc05..2df2800081e7a8 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs @@ -32,6 +32,8 @@ internal HeaderDescriptor(string headerName) _descriptor = headerName; } + public bool HasValue => _descriptor is not null; + public string Name { get @@ -44,12 +46,32 @@ public string Name } } - public object? Descriptor => _descriptor; + public bool IsKnownHeader([NotNullWhen(true)] out KnownHeader? knownHeader, [NotNullWhen(false)] out string? headerName) + { + Debug.Assert(_descriptor is not null); + object descriptor = _descriptor; + if (descriptor.GetType() == typeof(KnownHeader)) + { + knownHeader = Unsafe.As(descriptor); + Unsafe.SkipInit(out headerName); + return true; + } + else + { + Unsafe.SkipInit(out knownHeader); + headerName = Unsafe.As(descriptor); + return false; + } + } public HttpHeaderParser? Parser => _descriptor is KnownHeader knownHeader ? knownHeader.Parser : null; public HttpHeaderType HeaderType => _descriptor is KnownHeader knownHeader ? knownHeader.HeaderType : HttpHeaderType.Custom; public string Separator => _descriptor is KnownHeader knownHeader ? knownHeader.Separator : HttpHeaderParser.DefaultSeparator; + public bool Equals(KnownHeader other) => ReferenceEquals(_descriptor, other); + + public bool Equals(string other) => string.Equals(_descriptor as string, other, StringComparison.OrdinalIgnoreCase); + public bool Equals(HeaderDescriptor other) { object? descriptor = _descriptor; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index 1acd6b243b4472..2107cca8fd2015 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -239,7 +239,7 @@ public override string ToString() { foreach (HeaderEntry entry in entries) { - if (entry.Key.Descriptor is null) + if (!entry.Key.HasValue) { break; } @@ -308,7 +308,7 @@ private IEnumerator>> GetEnumeratorCore { HeaderEntry entry = entries[i]; HeaderDescriptor descriptor = entry.Key; - if (descriptor.Descriptor is null) + if (!descriptor.HasValue) { break; } @@ -535,7 +535,7 @@ internal virtual void AddHeaders(HttpHeaders sourceHeaders) { foreach (HeaderEntry entry in sourceEntries) { - if (entry.Value is null) + if (!entry.Key.HasValue) { break; } @@ -1320,7 +1320,7 @@ private bool IsEmpty } else if (store is HeaderEntry[] entries) { - return 0u >= (uint)entries.Length || entries[0].Key.Descriptor is null; + return 0u >= (uint)entries.Length || !entries[0].Key.HasValue; } else { @@ -1344,7 +1344,7 @@ internal int Count { for (int i = 0; i < entries.Length; i++) { - if (entries[i].Key.Descriptor is null) + if (!entries[i].Key.HasValue) { return i; } @@ -1365,13 +1365,18 @@ private ref object GetValueRefOrNullRef(HeaderDescriptor key) { if (store is HeaderEntry[] entries) { - if (key.Descriptor is string) + if (key.IsKnownHeader(out KnownHeader? knownHeader, out string? headerName)) { for (int i = 0; i < entries.Length; i++) { - if (string.Equals(entries[i].Key.Descriptor as string, Unsafe.As(key.Descriptor), StringComparison.OrdinalIgnoreCase)) + ref HeaderEntry entry = ref entries[i]; + if (!entry.Key.HasValue) { - return ref entries[i].Value; + break; + } + else if (entry.Key.Equals(knownHeader)) + { + return ref entry.Value; } } } @@ -1379,9 +1384,14 @@ private ref object GetValueRefOrNullRef(HeaderDescriptor key) { for (int i = 0; i < entries.Length; i++) { - if (ReferenceEquals(entries[i].Key.Descriptor, key.Descriptor)) + ref HeaderEntry entry = ref entries[i]; + if (!entry.Key.HasValue) + { + break; + } + else if (entry.Key.Equals(headerName)) { - return ref entries[i].Value; + return ref entry.Value; } } } @@ -1408,17 +1418,17 @@ private ref object GetValueRefOrNullRef(HeaderDescriptor key) } else if (store is HeaderEntry[] entries) { - if (key.Descriptor is string) + if (key.IsKnownHeader(out KnownHeader? knownHeader, out string? headerName)) { for (int i = 0; i < entries.Length; i++) { ref HeaderEntry entry = ref entries[i]; - if (entry.Key.Descriptor is null) + if (!entry.Key.HasValue) { entry.Key = key; return ref entry.Value!; } - else if (string.Equals(entry.Key.Descriptor as string, Unsafe.As(key.Descriptor), StringComparison.OrdinalIgnoreCase)) + else if (entry.Key.Equals(knownHeader)) { return ref entry.Value!; } @@ -1429,12 +1439,12 @@ private ref object GetValueRefOrNullRef(HeaderDescriptor key) for (int i = 0; i < entries.Length; i++) { ref HeaderEntry entry = ref entries[i]; - if (entry.Key.Descriptor is null) + if (!entry.Key.HasValue) { entry.Key = key; return ref entry.Value!; } - else if (ReferenceEquals(entry.Key.Descriptor, key.Descriptor)) + else if (entry.Key.Equals(headerName)) { return ref entry.Value!; } @@ -1498,14 +1508,14 @@ internal bool Remove(HeaderDescriptor key) for (int i = 0; i < entries.Length; i++) { HeaderDescriptor entryKey = entries[i].Key; - if (entryKey.Descriptor is null) + if (!entryKey.HasValue) { break; } if (entryKey.Equals(key)) { - while ((uint)(i + 1) < (uint)entries.Length && entries[i + 1].Key.Descriptor is not null) + while ((uint)(i + 1) < (uint)entries.Length && entries[i + 1].Key.HasValue) { entries[i] = entries[i + 1]; i++; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs index f584452e01fa11..2bef3d6ac51fb2 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs @@ -142,7 +142,7 @@ public bool MoveNext() HeaderEntry entry = entries[index]; _index++; - if (entry.Key.Descriptor is not null) + if (entry.Key.HasValue) { HttpHeaders.GetStoreValuesAsStringOrStringArray(entry.Key, entry.Value, out string? singleValue, out string[]? multiValue); Debug.Assert(singleValue is not null ^ multiValue is not null); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs index bb67a0a1db5f4c..4a52beea6da6ae 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs @@ -1347,7 +1347,7 @@ private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders heade ref string[]? tmpHeaderValuesArray = ref t_headerValues; foreach (HeaderEntry header in entries) { - if (header.Key.Descriptor is null) + if (!header.Key.HasValue) { break; } @@ -1358,7 +1358,7 @@ private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders heade Encoding? valueEncoding = encodingSelector?.Invoke(header.Key.Name, request); - if (header.Key.Descriptor is KnownHeader knownHeader) + if (header.Key.IsKnownHeader(out KnownHeader? knownHeader, out string? headerName)) { // The Host header is not sent for HTTP2 because we send the ":authority" pseudo-header instead // (see pseudo-header handling below in WriteHeaders). @@ -1391,7 +1391,7 @@ private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders heade else { // The header is not known: fall back to just encoding the header name and value(s). - WriteLiteralHeader(Unsafe.As(header.Key.Descriptor), headerValues, valueEncoding, ref headerBuffer); + WriteLiteralHeader(headerName, headerValues, valueEncoding, ref headerBuffer); } } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index 84abb9d6037945..e179c7c1ef9783 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -631,7 +631,7 @@ private void BufferHeaderCollection(HttpHeaders headers) foreach (HeaderEntry header in entries) { - if (header.Key.Descriptor is null) + if (!header.Key.HasValue) { break; } @@ -642,7 +642,7 @@ private void BufferHeaderCollection(HttpHeaders headers) Encoding? valueEncoding = encodingSelector?.Invoke(header.Key.Name, _request); - if (header.Key.Descriptor is KnownHeader knownHeader) + if (header.Key.IsKnownHeader(out KnownHeader? knownHeader, out string? headerName)) { // The Host header is not sent for HTTP/3 because we send the ":authority" pseudo-header instead // (see pseudo-header handling below in WriteHeaders). @@ -675,7 +675,7 @@ private void BufferHeaderCollection(HttpHeaders headers) else { // The header is not known: fall back to just encoding the header name and value(s). - BufferLiteralHeaderWithoutNameReference(Unsafe.As(header.Key.Descriptor), headerValues, HttpHeaderParser.DefaultSeparator, valueEncoding); + BufferLiteralHeaderWithoutNameReference(headerName, headerValues, HttpHeaderParser.DefaultSeparator, valueEncoding); } } } @@ -898,7 +898,7 @@ private void OnHeader(int? staticIndex, HeaderDescriptor descriptor, string? sta { if (descriptor.Name[0] == ':') { - if (descriptor.Descriptor != KnownHeaders.PseudoStatus) + if (!descriptor.Equals(KnownHeaders.PseudoStatus)) { if (NetEventSource.Log.IsEnabled()) Trace($"Received unknown pseudo-header '{descriptor.Name}'."); throw new Http3ConnectionException(Http3ErrorCode.ProtocolError); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs index 35d2a2c8225fa5..6077a25138e5cc 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs @@ -262,18 +262,18 @@ private async ValueTask WriteHeadersAsync(HttpHeaders headers, string? cookiesFr { foreach (HeaderEntry header in entries) { - if (header.Key.Descriptor is null) + if (!header.Key.HasValue) { break; } - if (header.Key.Descriptor is KnownHeader) + if (header.Key.IsKnownHeader(out KnownHeader? knownHeader, out string? headerName)) { - await WriteBytesAsync(Unsafe.As(header.Key.Descriptor).AsciiBytesWithColonSpace, async).ConfigureAwait(false); + await WriteBytesAsync(knownHeader.AsciiBytesWithColonSpace, async).ConfigureAwait(false); } else { - await WriteAsciiStringAsync(Unsafe.As(header.Key.Descriptor), async).ConfigureAwait(false); + await WriteAsciiStringAsync(headerName, async).ConfigureAwait(false); await WriteTwoBytesAsync((byte)':', (byte)' ', async).ConfigureAwait(false); } @@ -285,7 +285,7 @@ private async ValueTask WriteHeadersAsync(HttpHeaders headers, string? cookiesFr await WriteStringAsync(_headerValues[0], async, valueEncoding).ConfigureAwait(false); - if (cookiesFromContainer != null && header.Key.Descriptor == KnownHeaders.Cookie) + if (cookiesFromContainer != null && header.Key.Equals(KnownHeaders.Cookie)) { await WriteTwoBytesAsync((byte)';', (byte)' ', async).ConfigureAwait(false); await WriteStringAsync(cookiesFromContainer, async, valueEncoding).ConfigureAwait(false); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs index 89a9b61903c2f7..45eae0fa57649f 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs @@ -24,8 +24,8 @@ internal abstract class HttpConnectionBase : IDisposable, IHttpTrace public string GetResponseHeaderValueWithCaching(HeaderDescriptor descriptor, ReadOnlySpan value, Encoding? valueEncoding) { return - ReferenceEquals(descriptor.Descriptor, KnownHeaders.Date) ? GetOrAddCachedValue(ref _lastDateHeaderValue, descriptor, value, valueEncoding) : - ReferenceEquals(descriptor.Descriptor, KnownHeaders.Server) ? GetOrAddCachedValue(ref _lastServerHeaderValue, descriptor, value, valueEncoding) : + descriptor.Equals(KnownHeaders.Date) ? GetOrAddCachedValue(ref _lastDateHeaderValue, descriptor, value, valueEncoding) : + descriptor.Equals(KnownHeaders.Server) ? GetOrAddCachedValue(ref _lastServerHeaderValue, descriptor, value, valueEncoding) : descriptor.GetHeaderValue(value, valueEncoding); static string GetOrAddCachedValue([NotNull] ref string? cache, HeaderDescriptor descriptor, ReadOnlySpan value, Encoding? encoding) diff --git a/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs b/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs index ca607f9e33f1e8..23b7a8e8cd9fc3 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs @@ -62,7 +62,7 @@ private static Memory HPackEncode(HttpHeaders headers, Encoding? valueEnco foreach (HeaderEntry header in headers.GetEntries()) { - if (header.Key.Descriptor is null) + if (!header.Key.HasValue) { break; } @@ -71,7 +71,7 @@ private static Memory HPackEncode(HttpHeaders headers, Encoding? valueEnco Assert.InRange(headerValuesCount, 0, int.MaxValue); ReadOnlySpan headerValuesSpan = headerValues.AsSpan(0, headerValuesCount); - if (header.Key.Descriptor is KnownHeader knownHeader) + if (header.Key.IsKnownHeader(out KnownHeader? knownHeader, out string? headerName)) { // For all other known headers, send them via their pre-encoded name and the associated value. WriteBytes(knownHeader.Http2EncodedName); @@ -83,7 +83,7 @@ private static Memory HPackEncode(HttpHeaders headers, Encoding? valueEnco else { // The header is not known: fall back to just encoding the header name and value(s). - WriteLiteralHeader(header.Key.Name, headerValuesSpan); + WriteLiteralHeader(headerName, headerValuesSpan); } } diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs index 9b3ecd7aa9a5f8..57a5b37d91a5cf 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs @@ -24,12 +24,15 @@ public void RoundTripsUtf8(string input) byte[] encoded = Encoding.UTF8.GetBytes(input); Assert.True(HeaderDescriptor.TryGet("custom-header", out HeaderDescriptor descriptor)); - Assert.IsType(descriptor.Descriptor); + Assert.False(descriptor.IsKnownHeader(out _, out string? headerName)); + Assert.Equal("custom-header", headerName); string roundtrip = descriptor.GetHeaderValue(encoded, Encoding.UTF8); Assert.Equal(input, roundtrip); Assert.True(HeaderDescriptor.TryGet("Cache-Control", out descriptor)); - Assert.IsType(descriptor.Descriptor); + Assert.True(descriptor.IsKnownHeader(out KnownHeader? knownHeader, out _)); + Assert.NotNull(knownHeader); + Assert.Equal("Cache-Control", knownHeader.Name); roundtrip = descriptor.GetHeaderValue(encoded, Encoding.UTF8); Assert.Equal(input, roundtrip); } diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs index b94563f4dce3de..419495e1fc3e0b 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs @@ -118,7 +118,8 @@ static void Validate(string name, KnownHeader h) Assert.NotNull(h); Assert.Same(h, KnownHeaders.TryGetKnownHeader(name)); - Assert.Same(h, h.Descriptor.Descriptor); + Assert.True(h.Descriptor.IsKnownHeader(out KnownHeader? knownHeader, out _)); + Assert.Same(h, knownHeader); Assert.Equal(name, h.Name, StringComparer.OrdinalIgnoreCase); Assert.Equal(name, h.Descriptor.Name, StringComparer.OrdinalIgnoreCase); } From 0a700363a2156a53986e0a0acfebe04b62ee2275 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Tue, 21 Dec 2021 15:13:31 +0100 Subject: [PATCH 10/22] Simplify HttpHeaderParser.Separator logic --- .../Net/Http/Headers/HttpHeaderParser.cs | 21 ++++--------------- .../System/Net/Http/Headers/KnownHeader.cs | 2 +- 2 files changed, 5 insertions(+), 18 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaderParser.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaderParser.cs index 2fa79f1f41812e..86d689256e6a59 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaderParser.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaderParser.cs @@ -12,21 +12,14 @@ internal abstract class HttpHeaderParser internal const string DefaultSeparator = ", "; private readonly bool _supportsMultipleValues; - private readonly string? _separator; + private readonly string _separator; public bool SupportsMultipleValues { get { return _supportsMultipleValues; } } - public string? Separator - { - get - { - Debug.Assert(_supportsMultipleValues); - return _separator; - } - } + public string Separator => _separator; // If ValueType implements Equals() as required, there is no need to provide a comparer. A comparer is needed // e.g. if we want to compare strings using case-insensitive comparison. @@ -36,14 +29,8 @@ public virtual IEqualityComparer? Comparer } protected HttpHeaderParser(bool supportsMultipleValues) - { - _supportsMultipleValues = supportsMultipleValues; - - if (supportsMultipleValues) - { - _separator = DefaultSeparator; - } - } + : this(supportsMultipleValues, DefaultSeparator) + { } protected HttpHeaderParser(bool supportsMultipleValues, string separator) { diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/KnownHeader.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/KnownHeader.cs index 72705dd19ea8df..d835451a15e290 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/KnownHeader.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/KnownHeader.cs @@ -48,6 +48,6 @@ public KnownHeader(string name, HttpHeaderType headerType, HttpHeaderParser? par public byte[] AsciiBytesWithColonSpace { get; } public HeaderDescriptor Descriptor => new HeaderDescriptor(this); - public string Separator => Parser is HttpHeaderParser parser && parser.SupportsMultipleValues ? parser.Separator! : HttpHeaderParser.DefaultSeparator; + public string Separator => Parser is HttpHeaderParser parser ? parser.Separator : HttpHeaderParser.DefaultSeparator; } } From 3f1275bd558453a1177d7a4e9bd3e460c006fc56 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Tue, 21 Dec 2021 15:23:34 +0100 Subject: [PATCH 11/22] Add comments on HasValue checks --- .../src/System/Net/Http/Headers/HeaderDescriptor.cs | 4 +++- .../src/System/Net/Http/Headers/HttpHeaders.cs | 3 +++ .../src/System/Net/Http/Headers/HttpHeadersNonValidated.cs | 2 +- .../src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs | 1 + .../System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs | 1 + .../src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs | 1 + .../tests/UnitTests/HPack/HPackRoundtripTests.cs | 1 + 7 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs index 2df2800081e7a8..1085da25e75bd8 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs @@ -17,7 +17,9 @@ namespace System.Net.Http.Headers internal readonly struct HeaderDescriptor : IEquatable { /// - /// Either a or + /// Either a or . + /// Marked as nullable since a default (uninitialized) instance of this struct is also used in practice + /// to indicate the end of the header collection. /// private readonly object? _descriptor; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index 2107cca8fd2015..3d55996f400947 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -241,6 +241,7 @@ public override string ToString() { if (!entry.Key.HasValue) { + // An entry without a value indicates the end of the header collection break; } @@ -310,6 +311,7 @@ private IEnumerator>> GetEnumeratorCore HeaderDescriptor descriptor = entry.Key; if (!descriptor.HasValue) { + // An entry without a value indicates the end of the header collection break; } @@ -537,6 +539,7 @@ internal virtual void AddHeaders(HttpHeaders sourceHeaders) { if (!entry.Key.HasValue) { + // An entry without a value indicates the end of the header collection break; } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs index 2bef3d6ac51fb2..53939253c240db 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs @@ -142,7 +142,7 @@ public bool MoveNext() HeaderEntry entry = entries[index]; _index++; - if (entry.Key.HasValue) + if (entry.Key.HasValue) // An entry without a value indicates the end of the header collection { HttpHeaders.GetStoreValuesAsStringOrStringArray(entry.Key, entry.Value, out string? singleValue, out string[]? multiValue); Debug.Assert(singleValue is not null ^ multiValue is not null); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs index 4a52beea6da6ae..b159573f012f02 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs @@ -1349,6 +1349,7 @@ private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders heade { if (!header.Key.HasValue) { + // An entry without a value indicates the end of the header collection break; } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index e179c7c1ef9783..c2998c654c7581 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -633,6 +633,7 @@ private void BufferHeaderCollection(HttpHeaders headers) { if (!header.Key.HasValue) { + // An entry without a value indicates the end of the header collection break; } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs index 6077a25138e5cc..139ebc46a0668c 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs @@ -264,6 +264,7 @@ private async ValueTask WriteHeadersAsync(HttpHeaders headers, string? cookiesFr { if (!header.Key.HasValue) { + // An entry without a value indicates the end of the header collection break; } diff --git a/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs b/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs index 23b7a8e8cd9fc3..2c41c4ce2f1fdb 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs @@ -64,6 +64,7 @@ private static Memory HPackEncode(HttpHeaders headers, Encoding? valueEnco { if (!header.Key.HasValue) { + // An entry without a value indicates the end of the header collection break; } From af6811351bde634078125182764463d7090ef5a4 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Tue, 4 Jan 2022 13:59:30 +0100 Subject: [PATCH 12/22] Lazily group headers by name --- .../Net/Http/Headers/HeaderDescriptor.cs | 2 - .../System/Net/Http/Headers/HttpHeaders.cs | 442 ++++++++++-------- .../Http/Headers/HttpHeadersNonValidated.cs | 4 +- 3 files changed, 250 insertions(+), 198 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs index 1085da25e75bd8..d3d15b0ad28ec0 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs @@ -72,8 +72,6 @@ public bool IsKnownHeader([NotNullWhen(true)] out KnownHeader? knownHeader, [Not public bool Equals(KnownHeader other) => ReferenceEquals(_descriptor, other); - public bool Equals(string other) => string.Equals(_descriptor as string, other, StringComparison.OrdinalIgnoreCase); - public bool Equals(HeaderDescriptor other) { object? descriptor = _descriptor; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index 3d55996f400947..0a540f335f3985 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -16,6 +16,12 @@ internal struct HeaderEntry { public HeaderDescriptor Key; public object Value; + + public HeaderEntry(HeaderDescriptor key, object value) + { + Key = key; + Value = value; + } } public abstract class HttpHeaders : IEnumerable>> @@ -42,6 +48,8 @@ public abstract class HttpHeaders : IEnumerableEither a array or a Dictionary<, > private object? _headerStore; + private int _count; + private bool _entriesMayBeUngrouped; // Indicates that a non-validating Add was performed and some entries may not be grouped by name private readonly HttpHeaderType _allowedHeaderTypes; private readonly HttpHeaderType _treatAsCustomHeaderTypes; @@ -77,7 +85,8 @@ internal void Add(HeaderDescriptor descriptor, string? value) // it to the store if we added at least one value. if (addToStore && (info.ParsedValue != null)) { - AddHeaderToStore(descriptor, info); + Debug.Assert(!ContainsEntry(descriptor)); + AddEntryToStore(new HeaderEntry(descriptor, info)); } } @@ -110,7 +119,8 @@ internal void Add(HeaderDescriptor descriptor, IEnumerable values) // However, if all values for a _new_ header were invalid, then don't add the header. if (addToStore && (info.ParsedValue != null)) { - AddHeaderToStore(descriptor, info); + Debug.Assert(!ContainsEntry(descriptor)); + AddEntryToStore(new HeaderEntry(descriptor, info)); } } } @@ -126,26 +136,8 @@ internal bool TryAddWithoutValidation(HeaderDescriptor descriptor, string? value // values, e.g. adding two null-strings (or empty, or whitespace-only) results in "My-Header: ,". value ??= string.Empty; - ref object? storeValueRef = ref GetValueRefOrAddDefault(descriptor); - object? currentValue = storeValueRef; - - if (currentValue is null) - { - storeValueRef = value; - } - else if (currentValue is HeaderStoreItemInfo info) - { - // The header store already contained a HeaderStoreItemInfo, so add to it. - AddRawValue(info, value); - } - else - { - // The header store contained a single raw string value, so promote it - // to being a HeaderStoreItemInfo and add to it. - Debug.Assert(currentValue is string); - storeValueRef = info = new HeaderStoreItemInfo() { RawValue = currentValue }; - AddRawValue(info, value); - } + AddEntryToStore(new HeaderEntry(descriptor, value)); + _entriesMayBeUngrouped = true; return true; } @@ -161,21 +153,33 @@ internal bool TryAddWithoutValidation(HeaderDescriptor descriptor, IEnumerable enumerator = values.GetEnumerator()) + using IEnumerator enumerator = values.GetEnumerator(); + if (enumerator.MoveNext()) { + string firstValue = enumerator.Current ?? string.Empty; if (enumerator.MoveNext()) { - TryAddWithoutValidation(descriptor, enumerator.Current); - if (enumerator.MoveNext()) + var valuesList = new List { - HeaderStoreItemInfo info = GetOrCreateHeaderInfo(descriptor, parseRawValues: false); - do - { - AddRawValue(info, enumerator.Current ?? string.Empty); - } - while (enumerator.MoveNext()); + firstValue + }; + + do + { + valuesList.Add(enumerator.Current ?? string.Empty); } + while (enumerator.MoveNext()); + + AddEntryToStore(new HeaderEntry(descriptor, new HeaderStoreItemInfo + { + RawValue = valuesList + })); } + else + { + AddEntryToStore(new HeaderEntry(descriptor, firstValue)); + } + _entriesMayBeUngrouped = true; } return true; @@ -297,13 +301,13 @@ internal string GetHeaderString(HeaderDescriptor descriptor) #region IEnumerable>> Members - public IEnumerator>> GetEnumerator() => IsEmpty ? + public IEnumerator>> GetEnumerator() => _count == 0 ? ((IEnumerable>>)Array.Empty>>()).GetEnumerator() : GetEnumeratorCore(); private IEnumerator>> GetEnumeratorCore() { - HeaderEntry[] entries = GetEntries()!; + HeaderEntry[] entries = GetEntriesGroupedByName()!; for (int i = 0; i < entries.Length; i++) { @@ -331,8 +335,8 @@ private IEnumerator>> GetEnumeratorCore } else { - Debug.Assert(GetValueRefOrAddDefault(descriptor) is not null); - GetValueRefOrAddDefault(descriptor) = info; + Debug.Assert(ContainsEntry(descriptor)); + ((Dictionary)_headerStore!)[descriptor] = info; } } @@ -370,7 +374,7 @@ internal void AddParsedValue(HeaderDescriptor descriptor, object value) Debug.Assert(value != null); Debug.Assert(descriptor.Parser != null, "Can't add parsed value if there is no parser available."); - HeaderStoreItemInfo info = GetOrCreateHeaderInfo(descriptor, parseRawValues: true); + HeaderStoreItemInfo info = GetOrCreateHeaderInfo(descriptor); // If the current header has only one value, we can't add another value. The strongly typed property // must not call AddParsedValue(), but SetParsedValue(). E.g. for headers like 'Date', 'Host'. @@ -386,7 +390,7 @@ internal void SetParsedValue(HeaderDescriptor descriptor, object value) // This method will first clear all values. This is used e.g. when setting the 'Date' or 'Host' header. // i.e. headers not supporting collections. - HeaderStoreItemInfo info = GetOrCreateHeaderInfo(descriptor, parseRawValues: true); + HeaderStoreItemInfo info = GetOrCreateHeaderInfo(descriptor); info.InvalidValue = null; info.ParsedValue = null; @@ -533,7 +537,7 @@ internal virtual void AddHeaders(HttpHeaders sourceHeaders) Debug.Assert(sourceHeaders != null); Debug.Assert(GetType() == sourceHeaders.GetType(), "Can only copy headers from an instance of the same type."); - if (sourceHeaders.GetEntries() is HeaderEntry[] sourceEntries) + if (sourceHeaders.GetEntriesGroupedByName() is HeaderEntry[] sourceEntries) { foreach (HeaderEntry entry in sourceEntries) { @@ -546,19 +550,15 @@ internal virtual void AddHeaders(HttpHeaders sourceHeaders) // Only add header values if they're not already set on the message. Note that we don't merge // collections: If both the default headers and the message have set some values for a certain // header, then we don't try to merge the values. - ref object? storeValueRef = ref GetValueRefOrAddDefault(entry.Key); - if (storeValueRef is null) + if (!ContainsEntry(entry.Key)) { - object sourceValue = entry.Value; - if (sourceValue is HeaderStoreItemInfo info) - { - storeValueRef = CloneHeaderInfo(entry.Key, info); - } - else + object value = entry.Value; + if (value is HeaderStoreItemInfo info) { - Debug.Assert(sourceValue is string); - storeValueRef = sourceValue; + value = CloneHeaderInfo(entry.Key, info); } + + AddEntryToStore(new HeaderEntry(entry.Key, value)); } } } @@ -644,56 +644,36 @@ private static void CloneAndAddValue(HeaderStoreItemInfo destinationInfo, object } } - private HeaderStoreItemInfo GetOrCreateHeaderInfo(HeaderDescriptor descriptor, bool parseRawValues) + private HeaderStoreItemInfo GetOrCreateHeaderInfo(HeaderDescriptor descriptor) { - if (parseRawValues) + if (TryGetAndParseHeaderInfo(descriptor, out HeaderStoreItemInfo? info)) { - if (TryGetAndParseHeaderInfo(descriptor, out HeaderStoreItemInfo? info)) - { - return info; - } + return info; } else { - ref object storeValueRef = ref GetValueRefOrNullRef(descriptor); - if (!Unsafe.IsNullRef(ref storeValueRef)) - { - object value = storeValueRef; - if (value is not HeaderStoreItemInfo info) - { - Debug.Assert(value is string); - storeValueRef = info = new HeaderStoreItemInfo { RawValue = value }; - } - return info; - } + return CreateAndAddHeaderToStore(descriptor); } - - return CreateAndAddHeaderToStore(descriptor); } private HeaderStoreItemInfo CreateAndAddHeaderToStore(HeaderDescriptor descriptor) { + Debug.Assert(!ContainsEntry(descriptor)); + // If we don't have the header in the store yet, add it now. HeaderStoreItemInfo result = new HeaderStoreItemInfo(); // If the descriptor header type is in _treatAsCustomHeaderTypes, it must be converted to a custom header before calling this method Debug.Assert((descriptor.HeaderType & _treatAsCustomHeaderTypes) == 0); - AddHeaderToStore(descriptor, result); + AddEntryToStore(new HeaderEntry(descriptor, result)); return result; } - private void AddHeaderToStore(HeaderDescriptor descriptor, object value) - { - Debug.Assert(value is string or HeaderStoreItemInfo); - Debug.Assert(Unsafe.IsNullRef(ref GetValueRefOrNullRef(descriptor))); - GetValueRefOrAddDefault(descriptor) = value; - } - internal bool TryGetHeaderValue(HeaderDescriptor descriptor, [NotNullWhen(true)] out object? value) { - ref object storeValueRef = ref GetValueRefOrNullRef(descriptor); + ref object storeValueRef = ref GetValueRefGroupedByNameOrNullRef(descriptor); if (Unsafe.IsNullRef(ref storeValueRef)) { value = null; @@ -708,7 +688,7 @@ internal bool TryGetHeaderValue(HeaderDescriptor descriptor, [NotNullWhen(true)] private bool TryGetAndParseHeaderInfo(HeaderDescriptor key, [NotNullWhen(true)] out HeaderStoreItemInfo? info) { - ref object storeValueRef = ref GetValueRefOrNullRef(key); + ref object storeValueRef = ref GetValueRefGroupedByNameOrNullRef(key); if (!Unsafe.IsNullRef(ref storeValueRef)) { object value = storeValueRef; @@ -756,7 +736,7 @@ private bool ParseRawHeaderValues(HeaderDescriptor descriptor, HeaderStoreItemIn if ((info.InvalidValue == null) && (info.ParsedValue == null)) { // After parsing the raw value, no value is left because all values contain newline chars. - Debug.Assert(!IsEmpty); + Debug.Assert(_count > 0); Remove(descriptor); return false; } @@ -826,7 +806,8 @@ internal bool TryParseAndAddValue(HeaderDescriptor descriptor, string? value) { // If we get here, then the value could be parsed correctly. If we created a new HeaderStoreItemInfo, add // it to the store if we added at least one value. - AddHeaderToStore(descriptor, info); + Debug.Assert(!ContainsEntry(descriptor)); + AddEntryToStore(new HeaderEntry(descriptor, info)); } return result; @@ -1079,12 +1060,14 @@ internal bool TryGetHeaderDescriptor(string name, out HeaderDescriptor descripto if (HeaderDescriptor.TryGet(name, out descriptor)) { - if ((descriptor.HeaderType & _allowedHeaderTypes) != 0) + HttpHeaderType headerType = descriptor.HeaderType; + + if ((headerType & _allowedHeaderTypes) != 0) { return true; } - if ((descriptor.HeaderType & _treatAsCustomHeaderTypes) != 0) + if ((headerType & _treatAsCustomHeaderTypes) != 0) { descriptor = descriptor.AsCustomHeader(); return true; @@ -1312,24 +1295,10 @@ static HeaderEntry[] GetEntriesFromDictionary(Dictionary= (uint)entries.Length || !entries[0].Key.HasValue; - } - else - { - return Unsafe.As>(store).Count == 0; - } - } + GroupAllEntriesByName(); + return GetEntries(); } private bool EntriesAreLiveView => _headerStore is HeaderEntry[]; @@ -1338,64 +1307,112 @@ internal int Count { get { - object? store = _headerStore; - if (store is null) - { - return 0; - } - else if (store is HeaderEntry[] entries) + GroupAllEntriesByName(); + return _count; + } + } + + private void GroupAllEntriesByName() + { + if (_entriesMayBeUngrouped && _headerStore is HeaderEntry[] entries) + { + _entriesMayBeUngrouped = false; + + for (int i = 0; i < entries.Length; i++) { - for (int i = 0; i < entries.Length; i++) + HeaderDescriptor key = entries[i].Key; + if (!key.HasValue) + { + break; + } + + for (int j = i + 1; j < entries.Length; j++) { - if (!entries[i].Key.HasValue) + HeaderDescriptor secondKey = entries[j].Key; + if (!secondKey.HasValue) { - return i; + break; + } + + if (secondKey.Equals(key)) + { + MergeEntryValues(ref entries[i].Value, entries[j].Value); + RemoveAt(entries, j); + j--; } } - return entries.Length; } - else + } + } + + private static void MergeEntryValues(ref object existingValueRef, object newValue) + { + Debug.Assert(existingValueRef is not null); + Debug.Assert(newValue is not null); + + object existingValue = existingValueRef; + if (existingValue is not HeaderStoreItemInfo info) + { + Debug.Assert(existingValue is string); + existingValueRef = info = new HeaderStoreItemInfo { RawValue = existingValue }; + } + + if (newValue is HeaderStoreItemInfo newInfo) + { + // This should only happen if the TryAddWithoutValidation(string, IEnumerable) was called when a header was already present + Debug.Assert(newInfo.ParsedValue is null); + Debug.Assert(newInfo.InvalidValue is null); + Debug.Assert(newInfo.RawValue is not null); + Debug.Assert(newInfo.RawValue is List); + foreach (string value in (List)newInfo.RawValue) { - return Unsafe.As>(store).Count; + AddRawValue(info, value); } } + else + { + Debug.Assert(newValue is string); + AddRawValue(info, Unsafe.As(newValue)); + } } - private ref object GetValueRefOrNullRef(HeaderDescriptor key) + private ref object GetValueRefGroupedByNameOrNullRef(HeaderDescriptor key) { object? store = _headerStore; if (store is not null) { if (store is HeaderEntry[] entries) { - if (key.IsKnownHeader(out KnownHeader? knownHeader, out string? headerName)) + for (int i = 0; i < entries.Length; i++) { - for (int i = 0; i < entries.Length; i++) + ref HeaderEntry entry = ref entries[i]; + if (!entry.Key.HasValue) { - ref HeaderEntry entry = ref entries[i]; - if (!entry.Key.HasValue) - { - break; - } - else if (entry.Key.Equals(knownHeader)) - { - return ref entry.Value; - } + break; } - } - else - { - for (int i = 0; i < entries.Length; i++) + + if (entry.Key.Equals(key)) { - ref HeaderEntry entry = ref entries[i]; - if (!entry.Key.HasValue) + if (_entriesMayBeUngrouped) { - break; - } - else if (entry.Key.Equals(headerName)) - { - return ref entry.Value; + for (int j = i + 1; j < entries.Length; j++) + { + HeaderDescriptor secondKey = entries[j].Key; + if (!secondKey.HasValue) + { + break; + } + + if (secondKey.Equals(key)) + { + MergeEntryValues(ref entry.Value, entries[j].Value); + RemoveAt(entries, j); + j--; + } + } } + + return ref entry.Value; } } } @@ -1408,83 +1425,108 @@ private ref object GetValueRefOrNullRef(HeaderDescriptor key) return ref Unsafe.NullRef(); } - private ref object? GetValueRefOrAddDefault(HeaderDescriptor key) + private void AddEntryToStore(HeaderEntry entry) { - object? store = _headerStore; - if (store is null) - { - var entries = new HeaderEntry[InitialCapacity]; - _headerStore = entries; - ref HeaderEntry firstElement = ref MemoryMarshal.GetArrayDataReference(entries); - firstElement.Key = key; - return ref firstElement.Value!; - } - else if (store is HeaderEntry[] entries) + if (_headerStore is HeaderEntry[] entries) { - if (key.IsKnownHeader(out KnownHeader? knownHeader, out string? headerName)) + int count = _count; + if ((uint)count < (uint)entries.Length) { - for (int i = 0; i < entries.Length; i++) - { - ref HeaderEntry entry = ref entries[i]; - if (!entry.Key.HasValue) - { - entry.Key = key; - return ref entry.Value!; - } - else if (entry.Key.Equals(knownHeader)) - { - return ref entry.Value!; - } - } + entries[count] = entry; + _count++; } else { - for (int i = 0; i < entries.Length; i++) - { - ref HeaderEntry entry = ref entries[i]; - if (!entry.Key.HasValue) - { - entry.Key = key; - return ref entry.Value!; - } - else if (entry.Key.Equals(headerName)) - { - return ref entry.Value!; - } - } + GrowAndAddEntry(ref entry); } - - return ref GrowEntriesAndAddDefault(key); } else { - return ref CollectionsMarshal.GetValueRefOrAddDefault(Unsafe.As>(store), key, out s_dictionaryGetValueRefOrAddDefaultExistsDummy); + InitializeOrAddToDictionary(ref entry); + } + + void InitializeOrAddToDictionary(ref HeaderEntry entry) + { + if (_headerStore is Dictionary) + { + AddEntryToDictionary(ref entry); + } + else + { + var entries = new HeaderEntry[InitialCapacity]; + _headerStore = entries; + MemoryMarshal.GetArrayDataReference(entries) = entry; + _count = 1; + } } - ref object? GrowEntriesAndAddDefault(HeaderDescriptor key) + void GrowAndAddEntry(ref HeaderEntry entry) { var entries = (HeaderEntry[])_headerStore!; if (entries.Length == ArrayThreshold) { - var dictionary = new Dictionary(ArrayThreshold); - _headerStore = dictionary; - foreach (HeaderEntry entry in entries) + _headerStore = new Dictionary(ArrayThreshold); + _count = 0; + for (int i = 0; i < entries.Length; i++) { - dictionary.Add(entry.Key, entry.Value); + AddEntryToDictionary(ref entries[i]); } - return ref CollectionsMarshal.GetValueRefOrAddDefault(dictionary, key, out s_dictionaryGetValueRefOrAddDefaultExistsDummy); + AddEntryToDictionary(ref entry); } else { Array.Resize(ref entries, entries.Length << 1); + entries[entries.Length >> 1] = entry; _headerStore = entries; - ref HeaderEntry firstNewEntry = ref entries[entries.Length >> 1]; - firstNewEntry.Key = key; - return ref firstNewEntry.Value!; + _count++; + } + } + + void AddEntryToDictionary(ref HeaderEntry entry) + { + var dictionary = (Dictionary)_headerStore!; + ref object? valueRef = ref CollectionsMarshal.GetValueRefOrAddDefault(dictionary, entry.Key, out s_dictionaryGetValueRefOrAddDefaultExistsDummy); + if (valueRef is null) + { + valueRef = entry.Value; + _count++; + } + else + { + MergeEntryValues(ref valueRef, entry.Value); } } } + internal bool ContainsEntry(HeaderDescriptor key) + { + object? store = _headerStore; + if (store is not null) + { + if (store is HeaderEntry[] entries) + { + for (int i = 0; i < entries.Length; i++) + { + HeaderDescriptor entryKey = entries[i].Key; + if (!entryKey.HasValue) + { + break; + } + else if (entryKey.Equals(key)) + { + return true; + } + } + } + else + { + return Unsafe.As>(store).ContainsKey(key); + } + } + + return false; + } + public void Clear() { object? store = _headerStore; @@ -1492,12 +1534,13 @@ public void Clear() { if (store is HeaderEntry[] entries) { - Array.Clear(entries); + Array.Clear(entries, 0, _count); } else { Unsafe.As>(store).Clear(); } + _count = 0; } } @@ -1508,6 +1551,8 @@ internal bool Remove(HeaderDescriptor key) { if (store is HeaderEntry[] entries) { + bool removedAny = false; + for (int i = 0; i < entries.Length; i++) { HeaderDescriptor entryKey = entries[i].Key; @@ -1518,26 +1563,35 @@ internal bool Remove(HeaderDescriptor key) if (entryKey.Equals(key)) { - while ((uint)(i + 1) < (uint)entries.Length && entries[i + 1].Key.HasValue) - { - entries[i] = entries[i + 1]; - i++; - } - entries[i] = default; - - return true; + removedAny = true; + RemoveAt(entries, i); + i--; } } + + return removedAny; } - else + else if (Unsafe.As>(store).Remove(key)) { - return Unsafe.As>(store).Remove(key); + _count--; + return true; } } return false; } + private void RemoveAt(HeaderEntry[] entries, int index) + { + int count = _count--; + while ((uint)(index + 1) < (uint)entries.Length && index + 1 < count) + { + entries[index] = entries[index + 1]; + index++; + } + entries[index] = default; + } + #endregion // _headerStore implementation } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs index 53939253c240db..4660f08438d758 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs @@ -33,7 +33,7 @@ namespace System.Net.Http.Headers public bool Contains(string headerName) => _headers is HttpHeaders headers && headers.TryGetHeaderDescriptor(headerName, out HeaderDescriptor descriptor) && - headers.TryGetHeaderValue(descriptor, out _); + headers.ContainsEntry(descriptor); /// Gets the values for the specified header name. /// The name of the header. @@ -83,7 +83,7 @@ public bool TryGetValues(string headerName, out HeaderStringValues values) /// Gets an enumerator that iterates through the . /// An enumerator that iterates through the . public Enumerator GetEnumerator() => - _headers is HttpHeaders headers && headers.GetEntries() is HeaderEntry[] entries ? + _headers is HttpHeaders headers && headers.GetEntriesGroupedByName() is HeaderEntry[] entries ? new Enumerator(entries) : default; From 2787afcf3cedfef3f246fd092da52bc41c13925f Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Tue, 4 Jan 2022 16:01:14 +0100 Subject: [PATCH 13/22] Add a header ordering+grouping test --- .../Net/Http/Http2LoopbackConnection.cs | 18 +++- .../System/Net/Http/HttpClientHandlerTest.cs | 91 +++++++++++++++++++ 2 files changed, 104 insertions(+), 5 deletions(-) diff --git a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs index b3bb701e540f36..01ac5cf7417254 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs @@ -28,6 +28,7 @@ public class Http2LoopbackConnection : GenericLoopbackConnection private readonly TimeSpan _timeout; private int _lastStreamId; private bool _expectClientDisconnect; + private List _dynamicHeaderTable = new(); private readonly byte[] _prefix = new byte[24]; public string PrefixString => Encoding.UTF8.GetString(_prefix, 0, _prefix.Length); @@ -490,12 +491,15 @@ private static (int bytesConsumed, string value) DecodeString(ReadOnlySpan new HttpHeaderData("www-authenticate", "") }; - private static HttpHeaderData GetHeaderForIndex(int index) + private HttpHeaderData GetHeaderForIndex(int index) { - return s_staticTable[index - 1]; + index--; + return index < s_staticTable.Length + ? s_staticTable[index] + : _dynamicHeaderTable[index - s_staticTable.Length]; } - private static (int bytesConsumed, HttpHeaderData headerData) DecodeLiteralHeader(ReadOnlySpan headerBlock, byte prefixMask) + private (int bytesConsumed, HttpHeaderData headerData) DecodeLiteralHeader(ReadOnlySpan headerBlock, byte prefixMask) { int i = 0; @@ -520,7 +524,7 @@ private static (int bytesConsumed, HttpHeaderData headerData) DecodeLiteralHeade return (i, new HttpHeaderData(name, value)); } - private static (int bytesConsumed, HttpHeaderData headerData) DecodeHeader(ReadOnlySpan headerBlock) + private (int bytesConsumed, HttpHeaderData headerData) DecodeHeader(ReadOnlySpan headerBlock) { int i = 0; @@ -536,7 +540,11 @@ private static (int bytesConsumed, HttpHeaderData headerData) DecodeHeader(ReadO else if ((b & 0b11000000) == 0b01000000) { // Literal with indexing - return DecodeLiteralHeader(headerBlock, 0b00111111); + (int bytesConsumed, HttpHeaderData headerData) = DecodeLiteralHeader(headerBlock, 0b00111111); + + _dynamicHeaderTable.Insert(0, headerData); + + return (bytesConsumed, headerData); } else if ((b & 0b11100000) == 0b00100000) { diff --git a/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.cs b/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.cs index 48379e2b505509..9e67637af35570 100644 --- a/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.cs +++ b/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.cs @@ -357,6 +357,97 @@ await LoopbackServer.CreateClientAndServerAsync(async uri => }, server => server.AcceptConnectionSendCustomResponseAndCloseAsync($"HTTP/1.1 200 OK\r\n{invalidHeader}\r\n{LoopbackServer.CorsHeaders}Content-Length: 11\r\n\r\nhello world")); } + [Theory] + [InlineData(false, false, false)] + [InlineData(false, true, true)] + [InlineData(true, false, true)] + public async Task SendAsync_MultipleEntriesPerHeaderName_ValuesMayNotBeGrouped(bool manyHeaders, bool enumerateHeadersBeforeSend, bool valuesShouldBeGrouped) + { + Assert.True(valuesShouldBeGrouped || !manyHeaders, "Values will be grouped if we go over the HttpHeaders.ArrayThreshold"); + Assert.True(valuesShouldBeGrouped || !enumerateHeadersBeforeSend, "Enumerating the values forces them to be grouped by name"); + + if (PlatformDetection.IsBrowser) + { + valuesShouldBeGrouped = true; + } + +#if !NETCOREAPP + valuesShouldBeGrouped = true; +#endif + + await LoopbackServerFactory.CreateClientAndServerAsync(async uri => + { + var request = new HttpRequestMessage(HttpMethod.Get, uri) + { + Version = UseVersion + }; + + request.Headers.TryAddWithoutValidation("foo", "foo-single-1"); + request.Headers.TryAddWithoutValidation("bar", "bar-single-1"); + request.Headers.TryAddWithoutValidation("foo", new[] { "foo-multi-1", "foo-multi-2" }); + request.Headers.TryAddWithoutValidation("bar", "bar-single-2"); + request.Headers.TryAddWithoutValidation("foo", "foo-single-2"); + + if (manyHeaders) + { + for (int i = 0; i < 100; i++) + { + request.Headers.TryAddWithoutValidation($"dummy-{i}", "dummy"); + } + } + + if (enumerateHeadersBeforeSend) + { + _ = request.Headers.ToArray(); + } + + using HttpClient client = CreateHttpClient(); + + (await client.SendAsync(TestAsync, request)).Dispose(); + }, + async server => + { + HttpRequestData requestData = await server.HandleRequestAsync(HttpStatusCode.OK); + HttpHeaderData[] headers = requestData.Headers.Where(h => h.Name == "foo" || h.Name == "bar").ToArray(); + + if (valuesShouldBeGrouped) + { + Assert.Equal(2, headers.Length); + + if (manyHeaders) + { + // Ordering is not preserved after HttpHeaders.ArrayThreshold + headers = headers.OrderByDescending(h => h.Name).ToArray(); + } + + Assert.Equal("foo", headers[0].Name); + Assert.Equal("foo-single-1, foo-multi-1, foo-multi-2, foo-single-2", headers[0].Value); + + Assert.Equal("bar", headers[1].Name); + Assert.Equal("bar-single-1, bar-single-2", headers[1].Value); + } + else + { + Assert.Equal(5, headers.Length); + + Assert.Equal("foo", headers[0].Name); + Assert.Equal("foo-single-1", headers[0].Value); + + Assert.Equal("bar", headers[1].Name); + Assert.Equal("bar-single-1", headers[1].Value); + + Assert.Equal("foo", headers[2].Name); + Assert.Equal("foo-multi-1, foo-multi-2", headers[2].Value); + + Assert.Equal("bar", headers[3].Name); + Assert.Equal("bar-single-2", headers[3].Value); + + Assert.Equal("foo", headers[4].Name); + Assert.Equal("foo-single-2", headers[4].Value); + } + }); + } + [Theory] [InlineData(false, false)] [InlineData(true, false)] From 17164113af245c371355530c0ec345c5376333eb Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Thu, 6 Jan 2022 20:26:59 +0100 Subject: [PATCH 14/22] Make use of the _count field --- .../Net/Http/Headers/HeaderDescriptor.cs | 29 +-- .../System/Net/Http/Headers/HttpHeaders.cs | 227 ++++++++---------- .../Http/Headers/HttpHeadersNonValidated.cs | 27 +-- .../SocketsHttpHandler/Http2Connection.cs | 13 +- .../SocketsHttpHandler/Http3RequestStream.cs | 13 +- .../Http/SocketsHttpHandler/HttpConnection.cs | 14 +- .../UnitTests/HPack/HPackRoundtripTests.cs | 6 - .../UnitTests/Headers/HttpHeadersTest.cs | 4 +- 8 files changed, 127 insertions(+), 206 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs index d3d15b0ad28ec0..e2f73aa50a28c1 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs @@ -18,10 +18,8 @@ namespace System.Net.Http.Headers { /// /// Either a or . - /// Marked as nullable since a default (uninitialized) instance of this struct is also used in practice - /// to indicate the end of the header collection. /// - private readonly object? _descriptor; + private readonly object _descriptor; public HeaderDescriptor(KnownHeader knownHeader) { @@ -34,23 +32,19 @@ internal HeaderDescriptor(string headerName) _descriptor = headerName; } - public bool HasValue => _descriptor is not null; - public string Name { get { - Debug.Assert(_descriptor is not null); object descriptor = _descriptor; - return descriptor is KnownHeader knownHeader ? - knownHeader.Name : + return descriptor.GetType() == typeof(KnownHeader) ? + Unsafe.As(descriptor).Name : Unsafe.As(descriptor); } } public bool IsKnownHeader([NotNullWhen(true)] out KnownHeader? knownHeader, [NotNullWhen(false)] out string? headerName) { - Debug.Assert(_descriptor is not null); object descriptor = _descriptor; if (descriptor.GetType() == typeof(KnownHeader)) { @@ -66,7 +60,7 @@ public bool IsKnownHeader([NotNullWhen(true)] out KnownHeader? knownHeader, [Not } } - public HttpHeaderParser? Parser => _descriptor is KnownHeader knownHeader ? knownHeader.Parser : null; + public HttpHeaderParser? Parser => (_descriptor as KnownHeader)?.Parser; public HttpHeaderType HeaderType => _descriptor is KnownHeader knownHeader ? knownHeader.HeaderType : HttpHeaderType.Custom; public string Separator => _descriptor is KnownHeader knownHeader ? knownHeader.Separator : HttpHeaderParser.DefaultSeparator; @@ -74,26 +68,23 @@ public bool IsKnownHeader([NotNullWhen(true)] out KnownHeader? knownHeader, [Not public bool Equals(HeaderDescriptor other) { - object? descriptor = _descriptor; - object? otherDescriptor = other._descriptor; - - if (descriptor is string headerName) + object descriptor = _descriptor; + if (descriptor.GetType() == typeof(string)) { - return string.Equals(headerName, otherDescriptor as string, StringComparison.OrdinalIgnoreCase); + return string.Equals(Unsafe.As(descriptor), other._descriptor as string, StringComparison.OrdinalIgnoreCase); } else { - return ReferenceEquals(descriptor, otherDescriptor); + return ReferenceEquals(descriptor, other._descriptor); } } public override int GetHashCode() { - Debug.Assert(_descriptor is not null); object descriptor = _descriptor; - if (descriptor is string headerName) + if (descriptor.GetType() == typeof(string)) { - return StringComparer.OrdinalIgnoreCase.GetHashCode(headerName); + return StringComparer.OrdinalIgnoreCase.GetHashCode(Unsafe.As(descriptor)); } else { diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index 0a540f335f3985..b892dc69728e55 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -239,41 +239,32 @@ public override string ToString() var vsb = new ValueStringBuilder(stackalloc char[512]); - if (GetEntries() is HeaderEntry[] entries) + foreach (HeaderEntry entry in GetEntries()) { - foreach (HeaderEntry entry in entries) - { - if (!entry.Key.HasValue) - { - // An entry without a value indicates the end of the header collection - break; - } + vsb.Append(entry.Key.Name); + vsb.Append(": "); - vsb.Append(entry.Key.Name); - vsb.Append(": "); + GetStoreValuesAsStringOrStringArray(entry.Key, entry.Value, out string? singleValue, out string[]? multiValue); + Debug.Assert(singleValue is not null ^ multiValue is not null); - GetStoreValuesAsStringOrStringArray(entry.Key, entry.Value, out string? singleValue, out string[]? multiValue); - Debug.Assert(singleValue is not null ^ multiValue is not null); + if (singleValue is not null) + { + vsb.Append(singleValue); + } + else + { + // Note that if we get multiple values for a header that doesn't support multiple values, we'll + // just separate the values using a comma (default separator). + string separator = entry.Key.Separator; - if (singleValue is not null) + for (int i = 0; i < multiValue!.Length; i++) { - vsb.Append(singleValue); + if (i != 0) vsb.Append(separator); + vsb.Append(multiValue[i]); } - else - { - // Note that if we get multiple values for a header that doesn't support multiple values, we'll - // just separate the values using a comma (default separator). - string separator = entry.Key.Separator; - - for (int i = 0; i < multiValue!.Length; i++) - { - if (i != 0) vsb.Append(separator); - vsb.Append(multiValue[i]); - } - } - - vsb.Append(Environment.NewLine); } + + vsb.Append(Environment.NewLine); } return vsb.ToString(); @@ -307,27 +298,19 @@ public IEnumerator>> GetEnumerator() => private IEnumerator>> GetEnumeratorCore() { - HeaderEntry[] entries = GetEntriesGroupedByName()!; + HeaderEntry[] entries = GetEntriesGroupedByName(out int numberOfEntries)!; - for (int i = 0; i < entries.Length; i++) + for (int i = 0; i < numberOfEntries && i < entries.Length; i++) { HeaderEntry entry = entries[i]; - HeaderDescriptor descriptor = entry.Key; - if (!descriptor.HasValue) - { - // An entry without a value indicates the end of the header collection - break; - } - - object value = entry.Value; - if (value is not HeaderStoreItemInfo info) + if (entry.Value is not HeaderStoreItemInfo info) { // To retain consistent semantics, we need to upgrade a raw string to a HeaderStoreItemInfo // during enumeration so that we can parse the raw value in order to a) return // the correct set of parsed values, and b) update the instance for subsequent enumerations // to reflect that parsing. - info = new HeaderStoreItemInfo() { RawValue = value }; + info = new HeaderStoreItemInfo() { RawValue = entry.Value }; if (EntriesAreLiveView) { @@ -335,15 +318,15 @@ private IEnumerator>> GetEnumeratorCore } else { - Debug.Assert(ContainsEntry(descriptor)); - ((Dictionary)_headerStore!)[descriptor] = info; + Debug.Assert(ContainsEntry(entry.Key)); + ((Dictionary)_headerStore!)[entry.Key] = info; } } // Make sure we parse all raw values before returning the result. Note that this has to be // done before we calculate the array length (next line): A raw value may contain a list of // values. - if (!ParseRawHeaderValues(descriptor, info)) + if (!ParseRawHeaderValues(entry.Key, info)) { // We saw an invalid header value (contains newline chars) and deleted it. @@ -351,12 +334,13 @@ private IEnumerator>> GetEnumeratorCore if (EntriesAreLiveView) { i--; + numberOfEntries--; } } else { - string[] values = GetStoreValuesAsStringArray(descriptor, info); - yield return new KeyValuePair>(descriptor.Name, values); + string[] values = GetStoreValuesAsStringArray(entry.Key, info); + yield return new KeyValuePair>(entry.Key.Name, values); } } } @@ -537,29 +521,20 @@ internal virtual void AddHeaders(HttpHeaders sourceHeaders) Debug.Assert(sourceHeaders != null); Debug.Assert(GetType() == sourceHeaders.GetType(), "Can only copy headers from an instance of the same type."); - if (sourceHeaders.GetEntriesGroupedByName() is HeaderEntry[] sourceEntries) + foreach (HeaderEntry entry in sourceHeaders.GetEntriesGroupedByName()) { - foreach (HeaderEntry entry in sourceEntries) + // Only add header values if they're not already set on the message. Note that we don't merge + // collections: If both the default headers and the message have set some values for a certain + // header, then we don't try to merge the values. + if (!ContainsEntry(entry.Key)) { - if (!entry.Key.HasValue) + object value = entry.Value; + if (value is HeaderStoreItemInfo info) { - // An entry without a value indicates the end of the header collection - break; + value = CloneHeaderInfo(entry.Key, info); } - // Only add header values if they're not already set on the message. Note that we don't merge - // collections: If both the default headers and the message have set some values for a certain - // header, then we don't try to merge the values. - if (!ContainsEntry(entry.Key)) - { - object value = entry.Value; - if (value is HeaderStoreItemInfo info) - { - value = CloneHeaderInfo(entry.Key, info); - } - - AddEntryToStore(new HeaderEntry(entry.Key, value)); - } + AddEntryToStore(new HeaderEntry(entry.Key, value)); } } } @@ -1263,8 +1238,10 @@ internal bool CanAddParsedValue(HttpHeaderParser parser) private const int InitialCapacity = 4; internal const int ArrayThreshold = 64; // Above this threshold, header ordering will not be preserved - internal HeaderEntry[]? GetEntries() + internal HeaderEntry[]? GetEntries(out int numberOfEntries) { + numberOfEntries = _count; + object? store = _headerStore; if (store is null) { @@ -1276,11 +1253,12 @@ internal bool CanAddParsedValue(HttpHeaderParser parser) } else { - return GetEntriesFromDictionary(Unsafe.As>(store)); + return GetEntriesFromDictionary(); } - static HeaderEntry[] GetEntriesFromDictionary(Dictionary dictionary) + HeaderEntry[] GetEntriesFromDictionary() { + var dictionary = (Dictionary)_headerStore!; var entries = new HeaderEntry[dictionary.Count]; int i = 0; foreach (KeyValuePair entry in dictionary) @@ -1295,7 +1273,18 @@ static HeaderEntry[] GetEntriesFromDictionary(Dictionary GetEntries() + { + return new ReadOnlySpan(GetEntries(out int numberOfEntries), 0, numberOfEntries); + } + + internal HeaderEntry[]? GetEntriesGroupedByName(out int numberOfEntries) + { + GroupAllEntriesByName(); + return GetEntries(out numberOfEntries); + } + + internal ReadOnlySpan GetEntriesGroupedByName() { GroupAllEntriesByName(); return GetEntries(); @@ -1318,23 +1307,13 @@ private void GroupAllEntriesByName() { _entriesMayBeUngrouped = false; - for (int i = 0; i < entries.Length; i++) + for (int i = 0; i < _count && i < entries.Length; i++) { HeaderDescriptor key = entries[i].Key; - if (!key.HasValue) - { - break; - } - for (int j = i + 1; j < entries.Length; j++) + for (int j = i + 1; j < _count && (uint)j < (uint)entries.Length; j++) { - HeaderDescriptor secondKey = entries[j].Key; - if (!secondKey.HasValue) - { - break; - } - - if (secondKey.Equals(key)) + if (key.Equals(entries[j].Key)) { MergeEntryValues(ref entries[i].Value, entries[j].Value); RemoveAt(entries, j); @@ -1383,36 +1362,16 @@ private ref object GetValueRefGroupedByNameOrNullRef(HeaderDescriptor key) { if (store is HeaderEntry[] entries) { - for (int i = 0; i < entries.Length; i++) + for (int i = 0; i < _count && i < entries.Length; i++) { - ref HeaderEntry entry = ref entries[i]; - if (!entry.Key.HasValue) - { - break; - } - - if (entry.Key.Equals(key)) + if (key.Equals(entries[i].Key)) { if (_entriesMayBeUngrouped) { - for (int j = i + 1; j < entries.Length; j++) - { - HeaderDescriptor secondKey = entries[j].Key; - if (!secondKey.HasValue) - { - break; - } - - if (secondKey.Equals(key)) - { - MergeEntryValues(ref entry.Value, entries[j].Value); - RemoveAt(entries, j); - j--; - } - } + MergeEntriesAfter(entries, i); } - return ref entry.Value; + return ref entries[i].Value; } } } @@ -1423,6 +1382,25 @@ private ref object GetValueRefGroupedByNameOrNullRef(HeaderDescriptor key) } return ref Unsafe.NullRef(); + + void MergeEntriesAfter(HeaderEntry[] entries, int i) + { + if ((uint)i < (uint)entries.Length) + { + ref HeaderEntry entry = ref entries[i]; + HeaderDescriptor key = entry.Key; + + for (i++; i < _count && (uint)i < (uint)entries.Length; i++) + { + if (key.Equals(entries[i].Key)) + { + MergeEntryValues(ref entry.Value, entries[i].Value); + RemoveAt(entries, i); + i--; + } + } + } + } } private void AddEntryToStore(HeaderEntry entry) @@ -1463,6 +1441,7 @@ void InitializeOrAddToDictionary(ref HeaderEntry entry) void GrowAndAddEntry(ref HeaderEntry entry) { var entries = (HeaderEntry[])_headerStore!; + Debug.Assert(entries.Length == _count); if (entries.Length == ArrayThreshold) { _headerStore = new Dictionary(ArrayThreshold); @@ -1505,14 +1484,9 @@ internal bool ContainsEntry(HeaderDescriptor key) { if (store is HeaderEntry[] entries) { - for (int i = 0; i < entries.Length; i++) + for (int i = 0; i < _count && i < entries.Length; i++) { - HeaderDescriptor entryKey = entries[i].Key; - if (!entryKey.HasValue) - { - break; - } - else if (entryKey.Equals(key)) + if (key.Equals(entries[i].Key)) { return true; } @@ -1546,50 +1520,41 @@ public void Clear() internal bool Remove(HeaderDescriptor key) { + bool removedAny = false; + object? store = _headerStore; if (store is not null) { if (store is HeaderEntry[] entries) { - bool removedAny = false; - - for (int i = 0; i < entries.Length; i++) + for (int i = _count - 1; i >= 0 && (uint)i < (uint)entries.Length; i--) { - HeaderDescriptor entryKey = entries[i].Key; - if (!entryKey.HasValue) - { - break; - } - - if (entryKey.Equals(key)) + if (key.Equals(entries[i].Key)) { removedAny = true; RemoveAt(entries, i); - i--; } } - - return removedAny; } else if (Unsafe.As>(store).Remove(key)) { _count--; - return true; + removedAny = true; } } - return false; + return removedAny; } - private void RemoveAt(HeaderEntry[] entries, int index) + private void RemoveAt(HeaderEntry[] entries, int i) { int count = _count--; - while ((uint)(index + 1) < (uint)entries.Length && index + 1 < count) + while ((uint)(i + 1) < (uint)entries.Length && i + 1 < count) { - entries[index] = entries[index + 1]; - index++; + entries[i] = entries[i + 1]; + i++; } - entries[index] = default; + entries[i] = default; } #endregion // _headerStore implementation diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs index 4660f08438d758..418f4219643f1f 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs @@ -83,8 +83,8 @@ public bool TryGetValues(string headerName, out HeaderStringValues values) /// Gets an enumerator that iterates through the . /// An enumerator that iterates through the . public Enumerator GetEnumerator() => - _headers is HttpHeaders headers && headers.GetEntriesGroupedByName() is HeaderEntry[] entries ? - new Enumerator(entries) : + _headers is HttpHeaders headers && headers.GetEntriesGroupedByName(out int numberOfEntries) is HeaderEntry[] entries ? + new Enumerator(entries, numberOfEntries) : default; /// @@ -121,14 +121,14 @@ IEnumerable IReadOnlyDictionary. public struct Enumerator : IEnumerator> { private readonly HeaderEntry[] _entries; + private readonly int _numberOfEntries; private int _index; private KeyValuePair _current; - /// Initializes the enumerator. - /// The underlying header entries. - internal Enumerator(HeaderEntry[] entries) + internal Enumerator(HeaderEntry[] entries, int numberOfEntries) { _entries = entries; + _numberOfEntries = numberOfEntries; _index = 0; _current = default; } @@ -137,21 +137,18 @@ internal Enumerator(HeaderEntry[] entries) public bool MoveNext() { int index = _index; - if (_entries is HeaderEntry[] entries && (uint)index < (uint)entries.Length) + if (_entries is HeaderEntry[] entries && index < _numberOfEntries && (uint)index < (uint)entries.Length) { HeaderEntry entry = entries[index]; _index++; - if (entry.Key.HasValue) // An entry without a value indicates the end of the header collection - { - HttpHeaders.GetStoreValuesAsStringOrStringArray(entry.Key, entry.Value, out string? singleValue, out string[]? multiValue); - Debug.Assert(singleValue is not null ^ multiValue is not null); + HttpHeaders.GetStoreValuesAsStringOrStringArray(entry.Key, entry.Value, out string? singleValue, out string[]? multiValue); + Debug.Assert(singleValue is not null ^ multiValue is not null); - _current = new KeyValuePair( - entry.Key.Name, - singleValue is not null ? new HeaderStringValues(entry.Key, singleValue) : new HeaderStringValues(entry.Key, multiValue!)); - return true; - } + _current = new KeyValuePair( + entry.Key.Name, + singleValue is not null ? new HeaderStringValues(entry.Key, singleValue) : new HeaderStringValues(entry.Key, multiValue!)); + return true; } _current = default; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs index b159573f012f02..7ac224492519df 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs @@ -1337,22 +1337,11 @@ private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders heade { if (NetEventSource.Log.IsEnabled()) Trace(""); - if (headers.GetEntries() is not HeaderEntry[] entries) - { - return; - } - HeaderEncodingSelector? encodingSelector = _pool.Settings._requestHeaderEncodingSelector; ref string[]? tmpHeaderValuesArray = ref t_headerValues; - foreach (HeaderEntry header in entries) + foreach (HeaderEntry header in headers.GetEntries()) { - if (!header.Key.HasValue) - { - // An entry without a value indicates the end of the header collection - break; - } - int headerValuesCount = HttpHeaders.GetStoreValuesIntoStringArray(header.Key, header.Value, ref tmpHeaderValuesArray); Debug.Assert(headerValuesCount > 0, "No values for header??"); ReadOnlySpan headerValues = tmpHeaderValuesArray.AsSpan(0, headerValuesCount); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index c2998c654c7581..e87a155d5f044e 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -622,21 +622,10 @@ private void BufferHeaders(HttpRequestMessage request) // TODO: special-case Content-Type for static table values values? private void BufferHeaderCollection(HttpHeaders headers) { - if (headers.GetEntries() is not HeaderEntry[] entries) - { - return; - } - HeaderEncodingSelector? encodingSelector = _connection.Pool.Settings._requestHeaderEncodingSelector; - foreach (HeaderEntry header in entries) + foreach (HeaderEntry header in headers.GetEntries()) { - if (!header.Key.HasValue) - { - // An entry without a value indicates the end of the header collection - break; - } - int headerValuesCount = HttpHeaders.GetStoreValuesIntoStringArray(header.Key, header.Value, ref _headerValues); Debug.Assert(headerValuesCount > 0, "No values for header??"); ReadOnlySpan headerValues = _headerValues.AsSpan(0, headerValuesCount); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs index 139ebc46a0668c..18d3f0cd331838 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs @@ -258,15 +258,11 @@ private async ValueTask WriteHeadersAsync(HttpHeaders headers, string? cookiesFr { Debug.Assert(_currentRequest != null); - if (headers.GetEntries() is HeaderEntry[] entries) + if (headers.GetEntries(out int numberOfEntries) is HeaderEntry[] entries) { - foreach (HeaderEntry header in entries) + for (int i = 0; i < numberOfEntries; i++) { - if (!header.Key.HasValue) - { - // An entry without a value indicates the end of the header collection - break; - } + HeaderEntry header = entries[i]; if (header.Key.IsKnownHeader(out KnownHeader? knownHeader, out string? headerName)) { @@ -299,10 +295,10 @@ private async ValueTask WriteHeadersAsync(HttpHeaders headers, string? cookiesFr { string separator = header.Key.Separator; - for (int i = 1; i < headerValuesCount; i++) + for (int j = 1; j < headerValuesCount; j++) { await WriteAsciiStringAsync(separator, async).ConfigureAwait(false); - await WriteStringAsync(_headerValues[i], async, valueEncoding).ConfigureAwait(false); + await WriteStringAsync(_headerValues[j], async, valueEncoding).ConfigureAwait(false); } } } diff --git a/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs b/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs index 2c41c4ce2f1fdb..47d4507ca1081a 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs @@ -62,12 +62,6 @@ private static Memory HPackEncode(HttpHeaders headers, Encoding? valueEnco foreach (HeaderEntry header in headers.GetEntries()) { - if (!header.Key.HasValue) - { - // An entry without a value indicates the end of the header collection - break; - } - int headerValuesCount = HttpHeaders.GetStoreValuesIntoStringArray(header.Key, header.Value, ref headerValues); Assert.InRange(headerValuesCount, 0, int.MaxValue); ReadOnlySpan headerValuesSpan = headerValues.AsSpan(0, headerValuesCount); diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs index b9d5abb5d450ca..79196e77c75ea9 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs @@ -1711,7 +1711,7 @@ public void GetEnumerator_UseExplicitInterfaceImplementation_EnumeratorReturnsNo } [Fact] - public void GetEnumerator_InvalidValueBetweenValidHeaders_EnumeratorReturnsAllValidValuesAndRemovedInvalidValue() + public void GetEnumerator_InvalidValueBetweenValidHeaders_EnumeratorReturnsAllValidValuesAndRemovesInvalidValue() { MockHeaders headers = new MockHeaders(); headers.TryAddWithoutValidation("foo", "fooValue"); @@ -1725,7 +1725,7 @@ public void GetEnumerator_InvalidValueBetweenValidHeaders_EnumeratorReturnsAllVa Assert.Equal("barValue", Assert.Single(Assert.Contains("bar", dict))); Assert.Equal(2, headers.Count); - Assert.DoesNotContain("invalid", dict); + Assert.False(headers.NonValidated.Contains("invalid")); } [Fact] From 58fb231c98ab9e23c85e659fec2c99b8cf50ae81 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Thu, 6 Jan 2022 21:03:18 +0100 Subject: [PATCH 15/22] Revert all HeaderDescriptor changes from PR --- .../Net/Http/Headers/HeaderDescriptor.cs | 96 +++++-------------- .../Net/Http/Headers/HeaderStringValues.cs | 2 +- .../Net/Http/Headers/HttpHeaderParser.cs | 21 +++- .../System/Net/Http/Headers/HttpHeaders.cs | 5 +- .../System/Net/Http/Headers/KnownHeader.cs | 2 - .../SocketsHttpHandler/Http2Connection.cs | 20 +++- .../SocketsHttpHandler/Http3RequestStream.cs | 22 ++++- .../Http/SocketsHttpHandler/HttpConnection.cs | 17 ++-- .../SocketsHttpHandler/HttpConnectionBase.cs | 4 +- .../UnitTests/HPack/HPackRoundtripTests.cs | 20 +++- .../UnitTests/Headers/HeaderEncodingTest.cs | 7 +- .../UnitTests/Headers/KnownHeadersTest.cs | 3 +- 12 files changed, 109 insertions(+), 110 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs index e2f73aa50a28c1..229490b5432dd9 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderDescriptor.cs @@ -4,7 +4,6 @@ using System.Buffers; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; -using System.Runtime.CompilerServices; using System.Text; using System.Text.Unicode; @@ -16,82 +15,32 @@ namespace System.Net.Http.Headers // Use HeaderDescriptor.TryGet to resolve an arbitrary header name to a HeaderDescriptor. internal readonly struct HeaderDescriptor : IEquatable { - /// - /// Either a or . - /// - private readonly object _descriptor; + private readonly string _headerName; + private readonly KnownHeader? _knownHeader; public HeaderDescriptor(KnownHeader knownHeader) { - _descriptor = knownHeader; + _knownHeader = knownHeader; + _headerName = knownHeader.Name; } // This should not be used directly; use static TryGet below internal HeaderDescriptor(string headerName) { - _descriptor = headerName; + _headerName = headerName; + _knownHeader = null; } - public string Name - { - get - { - object descriptor = _descriptor; - return descriptor.GetType() == typeof(KnownHeader) ? - Unsafe.As(descriptor).Name : - Unsafe.As(descriptor); - } - } - - public bool IsKnownHeader([NotNullWhen(true)] out KnownHeader? knownHeader, [NotNullWhen(false)] out string? headerName) - { - object descriptor = _descriptor; - if (descriptor.GetType() == typeof(KnownHeader)) - { - knownHeader = Unsafe.As(descriptor); - Unsafe.SkipInit(out headerName); - return true; - } - else - { - Unsafe.SkipInit(out knownHeader); - headerName = Unsafe.As(descriptor); - return false; - } - } - - public HttpHeaderParser? Parser => (_descriptor as KnownHeader)?.Parser; - public HttpHeaderType HeaderType => _descriptor is KnownHeader knownHeader ? knownHeader.HeaderType : HttpHeaderType.Custom; - public string Separator => _descriptor is KnownHeader knownHeader ? knownHeader.Separator : HttpHeaderParser.DefaultSeparator; - - public bool Equals(KnownHeader other) => ReferenceEquals(_descriptor, other); - - public bool Equals(HeaderDescriptor other) - { - object descriptor = _descriptor; - if (descriptor.GetType() == typeof(string)) - { - return string.Equals(Unsafe.As(descriptor), other._descriptor as string, StringComparison.OrdinalIgnoreCase); - } - else - { - return ReferenceEquals(descriptor, other._descriptor); - } - } - - public override int GetHashCode() - { - object descriptor = _descriptor; - if (descriptor.GetType() == typeof(string)) - { - return StringComparer.OrdinalIgnoreCase.GetHashCode(Unsafe.As(descriptor)); - } - else - { - return descriptor.GetHashCode(); - } - } + public string Name => _headerName; + public HttpHeaderParser? Parser => _knownHeader?.Parser; + public HttpHeaderType HeaderType => _knownHeader == null ? HttpHeaderType.Custom : _knownHeader.HeaderType; + public KnownHeader? KnownHeader => _knownHeader; + public bool Equals(HeaderDescriptor other) => + _knownHeader == null ? + string.Equals(_headerName, other._headerName, StringComparison.OrdinalIgnoreCase) : + _knownHeader == other._knownHeader; + public override int GetHashCode() => _knownHeader?.GetHashCode() ?? StringComparer.OrdinalIgnoreCase.GetHashCode(_headerName); public override bool Equals(object? obj) => throw new InvalidOperationException(); // Ensure this is never called, to avoid boxing // Returns false for invalid header name. @@ -163,9 +112,9 @@ internal static bool TryGetStaticQPackHeader(int index, out HeaderDescriptor des public HeaderDescriptor AsCustomHeader() { - Debug.Assert(_descriptor is KnownHeader); - Debug.Assert(HeaderType != HttpHeaderType.Custom); - return new HeaderDescriptor(Name); + Debug.Assert(_knownHeader != null); + Debug.Assert(_knownHeader.HeaderType != HttpHeaderType.Custom); + return new HeaderDescriptor(_knownHeader.Name); } public string GetHeaderValue(ReadOnlySpan headerValue, Encoding? valueEncoding) @@ -176,9 +125,10 @@ public string GetHeaderValue(ReadOnlySpan headerValue, Encoding? valueEnco } // If it's a known header value, use the known value instead of allocating a new string. - if (_descriptor is KnownHeader knownHeader) + if (_knownHeader != null) { - if (knownHeader.KnownValues is string[] knownValues) + string[]? knownValues = _knownHeader.KnownValues; + if (knownValues != null) { for (int i = 0; i < knownValues.Length; i++) { @@ -189,7 +139,7 @@ public string GetHeaderValue(ReadOnlySpan headerValue, Encoding? valueEnco } } - if (knownHeader == KnownHeaders.ContentType) + if (_knownHeader == KnownHeaders.ContentType) { string? contentType = GetKnownContentType(headerValue); if (contentType != null) @@ -197,7 +147,7 @@ public string GetHeaderValue(ReadOnlySpan headerValue, Encoding? valueEnco return contentType; } } - else if (knownHeader == KnownHeaders.Location) + else if (_knownHeader == KnownHeaders.Location) { // Normally Location should be in ISO-8859-1 but occasionally some servers respond with UTF-8. if (TryDecodeUtf8(headerValue, out string? decoded)) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderStringValues.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderStringValues.cs index 6b5f4d2a666a2a..a313a2306e78f0 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderStringValues.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HeaderStringValues.cs @@ -45,7 +45,7 @@ internal HeaderStringValues(HeaderDescriptor descriptor, string[] values) public override string ToString() => _value switch { string value => value, - string[] values => string.Join(_header.Separator, values), + string[] values => string.Join(_header.Parser is HttpHeaderParser parser && parser.SupportsMultipleValues ? parser.Separator : HttpHeaderParser.DefaultSeparator, values), _ => string.Empty, }; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaderParser.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaderParser.cs index 86d689256e6a59..2fa79f1f41812e 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaderParser.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaderParser.cs @@ -12,14 +12,21 @@ internal abstract class HttpHeaderParser internal const string DefaultSeparator = ", "; private readonly bool _supportsMultipleValues; - private readonly string _separator; + private readonly string? _separator; public bool SupportsMultipleValues { get { return _supportsMultipleValues; } } - public string Separator => _separator; + public string? Separator + { + get + { + Debug.Assert(_supportsMultipleValues); + return _separator; + } + } // If ValueType implements Equals() as required, there is no need to provide a comparer. A comparer is needed // e.g. if we want to compare strings using case-insensitive comparison. @@ -29,8 +36,14 @@ public virtual IEqualityComparer? Comparer } protected HttpHeaderParser(bool supportsMultipleValues) - : this(supportsMultipleValues, DefaultSeparator) - { } + { + _supportsMultipleValues = supportsMultipleValues; + + if (supportsMultipleValues) + { + _separator = DefaultSeparator; + } + } protected HttpHeaderParser(bool supportsMultipleValues, string separator) { diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index b892dc69728e55..dc2c879726f421 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -255,7 +255,7 @@ public override string ToString() { // Note that if we get multiple values for a header that doesn't support multiple values, we'll // just separate the values using a comma (default separator). - string separator = entry.Key.Separator; + string? separator = entry.Key.Parser is HttpHeaderParser parser && parser.SupportsMultipleValues ? parser.Separator : HttpHeaderParser.DefaultSeparator; for (int i = 0; i < multiValue!.Length; i++) { @@ -284,7 +284,8 @@ internal string GetHeaderString(HeaderDescriptor descriptor) // Note that if we get multiple values for a header that doesn't support multiple values, we'll // just separate the values using a comma (default separator). - return string.Join(descriptor.Separator, multiValue!); + string? separator = descriptor.Parser != null && descriptor.Parser.SupportsMultipleValues ? descriptor.Parser.Separator : HttpHeaderParser.DefaultSeparator; + return string.Join(separator, multiValue!); } return string.Empty; diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/KnownHeader.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/KnownHeader.cs index d835451a15e290..0163db073a852e 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/KnownHeader.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/KnownHeader.cs @@ -47,7 +47,5 @@ public KnownHeader(string name, HttpHeaderType headerType, HttpHeaderParser? par public string[]? KnownValues { get; } public byte[] AsciiBytesWithColonSpace { get; } public HeaderDescriptor Descriptor => new HeaderDescriptor(this); - - public string Separator => Parser is HttpHeaderParser parser ? parser.Separator : HttpHeaderParser.DefaultSeparator; } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs index 7ac224492519df..1a1b4b00bbb60f 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs @@ -1348,7 +1348,8 @@ private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders heade Encoding? valueEncoding = encodingSelector?.Invoke(header.Key.Name, request); - if (header.Key.IsKnownHeader(out KnownHeader? knownHeader, out string? headerName)) + KnownHeader? knownHeader = header.Key.KnownHeader; + if (knownHeader != null) { // The Host header is not sent for HTTP2 because we send the ":authority" pseudo-header instead // (see pseudo-header handling below in WriteHeaders). @@ -1372,8 +1373,19 @@ private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders heade // For all other known headers, send them via their pre-encoded name and the associated value. WriteBytes(knownHeader.Http2EncodedName, ref headerBuffer); - - string? separator = headerValues.Length > 1 ? knownHeader.Separator : null; + string? separator = null; + if (headerValues.Length > 1) + { + HttpHeaderParser? parser = header.Key.Parser; + if (parser != null && parser.SupportsMultipleValues) + { + separator = parser.Separator; + } + else + { + separator = HttpHeaderParser.DefaultSeparator; + } + } WriteLiteralHeaderValues(headerValues, separator, valueEncoding, ref headerBuffer); } @@ -1381,7 +1393,7 @@ private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders heade else { // The header is not known: fall back to just encoding the header name and value(s). - WriteLiteralHeader(headerName, headerValues, valueEncoding, ref headerBuffer); + WriteLiteralHeader(header.Key.Name, headerValues, valueEncoding, ref headerBuffer); } } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index e87a155d5f044e..9a7ebccc416a8e 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -632,7 +632,8 @@ private void BufferHeaderCollection(HttpHeaders headers) Encoding? valueEncoding = encodingSelector?.Invoke(header.Key.Name, _request); - if (header.Key.IsKnownHeader(out KnownHeader? knownHeader, out string? headerName)) + KnownHeader? knownHeader = header.Key.KnownHeader; + if (knownHeader != null) { // The Host header is not sent for HTTP/3 because we send the ":authority" pseudo-header instead // (see pseudo-header handling below in WriteHeaders). @@ -656,8 +657,19 @@ private void BufferHeaderCollection(HttpHeaders headers) // For all other known headers, send them via their pre-encoded name and the associated value. BufferBytes(knownHeader.Http3EncodedName); - - string? separator = headerValues.Length > 1 ? knownHeader.Separator : null; + string? separator = null; + if (headerValues.Length > 1) + { + HttpHeaderParser? parser = header.Key.Parser; + if (parser != null && parser.SupportsMultipleValues) + { + separator = parser.Separator; + } + else + { + separator = HttpHeaderParser.DefaultSeparator; + } + } BufferLiteralHeaderValues(headerValues, separator, valueEncoding); } @@ -665,7 +677,7 @@ private void BufferHeaderCollection(HttpHeaders headers) else { // The header is not known: fall back to just encoding the header name and value(s). - BufferLiteralHeaderWithoutNameReference(headerName, headerValues, HttpHeaderParser.DefaultSeparator, valueEncoding); + BufferLiteralHeaderWithoutNameReference(header.Key.Name, headerValues, HttpHeaderParser.DefaultSeparator, valueEncoding); } } } @@ -888,7 +900,7 @@ private void OnHeader(int? staticIndex, HeaderDescriptor descriptor, string? sta { if (descriptor.Name[0] == ':') { - if (!descriptor.Equals(KnownHeaders.PseudoStatus)) + if (descriptor.KnownHeader != KnownHeaders.PseudoStatus) { if (NetEventSource.Log.IsEnabled()) Trace($"Received unknown pseudo-header '{descriptor.Name}'."); throw new Http3ConnectionException(Http3ErrorCode.ProtocolError); diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs index 18d3f0cd331838..d1f058a0c13109 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs @@ -264,13 +264,13 @@ private async ValueTask WriteHeadersAsync(HttpHeaders headers, string? cookiesFr { HeaderEntry header = entries[i]; - if (header.Key.IsKnownHeader(out KnownHeader? knownHeader, out string? headerName)) + if (header.Key.KnownHeader != null) { - await WriteBytesAsync(knownHeader.AsciiBytesWithColonSpace, async).ConfigureAwait(false); + await WriteBytesAsync(header.Key.KnownHeader.AsciiBytesWithColonSpace, async).ConfigureAwait(false); } else { - await WriteAsciiStringAsync(headerName, async).ConfigureAwait(false); + await WriteAsciiStringAsync(header.Key.Name, async).ConfigureAwait(false); await WriteTwoBytesAsync((byte)':', (byte)' ', async).ConfigureAwait(false); } @@ -282,7 +282,7 @@ private async ValueTask WriteHeadersAsync(HttpHeaders headers, string? cookiesFr await WriteStringAsync(_headerValues[0], async, valueEncoding).ConfigureAwait(false); - if (cookiesFromContainer != null && header.Key.Equals(KnownHeaders.Cookie)) + if (cookiesFromContainer != null && header.Key.KnownHeader == KnownHeaders.Cookie) { await WriteTwoBytesAsync((byte)';', (byte)' ', async).ConfigureAwait(false); await WriteStringAsync(cookiesFromContainer, async, valueEncoding).ConfigureAwait(false); @@ -293,7 +293,12 @@ private async ValueTask WriteHeadersAsync(HttpHeaders headers, string? cookiesFr // Some headers such as User-Agent and Server use space as a separator (see: ProductInfoHeaderParser) if (headerValuesCount > 1) { - string separator = header.Key.Separator; + HttpHeaderParser? parser = header.Key.Parser; + string separator = HttpHeaderParser.DefaultSeparator; + if (parser != null && parser.SupportsMultipleValues) + { + separator = parser.Separator!; + } for (int j = 1; j < headerValuesCount; j++) { @@ -1051,7 +1056,7 @@ private static void ParseHeaderNameValue(HttpConnection connection, ReadOnlySpan throw new HttpRequestException(SR.Format(SR.net_http_invalid_response_header_name, Encoding.ASCII.GetString(line.Slice(0, pos)))); } - if (isFromTrailer && (descriptor.HeaderType & HttpHeaderType.NonTrailing) == HttpHeaderType.NonTrailing) + if (isFromTrailer && descriptor.KnownHeader != null && (descriptor.KnownHeader.HeaderType & HttpHeaderType.NonTrailing) == HttpHeaderType.NonTrailing) { // Disallowed trailer fields. // A recipient MUST ignore fields that are forbidden to be sent in a trailer. diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs index 45eae0fa57649f..3c266f20da59f5 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionBase.cs @@ -24,8 +24,8 @@ internal abstract class HttpConnectionBase : IDisposable, IHttpTrace public string GetResponseHeaderValueWithCaching(HeaderDescriptor descriptor, ReadOnlySpan value, Encoding? valueEncoding) { return - descriptor.Equals(KnownHeaders.Date) ? GetOrAddCachedValue(ref _lastDateHeaderValue, descriptor, value, valueEncoding) : - descriptor.Equals(KnownHeaders.Server) ? GetOrAddCachedValue(ref _lastServerHeaderValue, descriptor, value, valueEncoding) : + ReferenceEquals(descriptor.KnownHeader, KnownHeaders.Date) ? GetOrAddCachedValue(ref _lastDateHeaderValue, descriptor, value, valueEncoding) : + ReferenceEquals(descriptor.KnownHeader, KnownHeaders.Server) ? GetOrAddCachedValue(ref _lastServerHeaderValue, descriptor, value, valueEncoding) : descriptor.GetHeaderValue(value, valueEncoding); static string GetOrAddCachedValue([NotNull] ref string? cache, HeaderDescriptor descriptor, ReadOnlySpan value, Encoding? encoding) diff --git a/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs b/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs index 47d4507ca1081a..d2f81304dd912a 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs @@ -66,19 +66,31 @@ private static Memory HPackEncode(HttpHeaders headers, Encoding? valueEnco Assert.InRange(headerValuesCount, 0, int.MaxValue); ReadOnlySpan headerValuesSpan = headerValues.AsSpan(0, headerValuesCount); - if (header.Key.IsKnownHeader(out KnownHeader? knownHeader, out string? headerName)) + KnownHeader knownHeader = header.Key.KnownHeader; + if (knownHeader != null) { // For all other known headers, send them via their pre-encoded name and the associated value. WriteBytes(knownHeader.Http2EncodedName); - - string? separator = headerValuesSpan.Length > 1 ? knownHeader.Separator : null; + string separator = null; + if (headerValuesSpan.Length > 1) + { + HttpHeaderParser parser = header.Key.Parser; + if (parser != null && parser.SupportsMultipleValues) + { + separator = parser.Separator; + } + else + { + separator = HttpHeaderParser.DefaultSeparator; + } + } WriteLiteralHeaderValues(headerValuesSpan, separator); } else { // The header is not known: fall back to just encoding the header name and value(s). - WriteLiteralHeader(headerName, headerValuesSpan); + WriteLiteralHeader(header.Key.Name, headerValuesSpan); } } diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs index 57a5b37d91a5cf..d7f41e0deeee61 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HeaderEncodingTest.cs @@ -24,15 +24,12 @@ public void RoundTripsUtf8(string input) byte[] encoded = Encoding.UTF8.GetBytes(input); Assert.True(HeaderDescriptor.TryGet("custom-header", out HeaderDescriptor descriptor)); - Assert.False(descriptor.IsKnownHeader(out _, out string? headerName)); - Assert.Equal("custom-header", headerName); + Assert.Null(descriptor.KnownHeader); string roundtrip = descriptor.GetHeaderValue(encoded, Encoding.UTF8); Assert.Equal(input, roundtrip); Assert.True(HeaderDescriptor.TryGet("Cache-Control", out descriptor)); - Assert.True(descriptor.IsKnownHeader(out KnownHeader? knownHeader, out _)); - Assert.NotNull(knownHeader); - Assert.Equal("Cache-Control", knownHeader.Name); + Assert.NotNull(descriptor.KnownHeader); roundtrip = descriptor.GetHeaderValue(encoded, Encoding.UTF8); Assert.Equal(input, roundtrip); } diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs index 419495e1fc3e0b..78fada7218b4ac 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/Headers/KnownHeadersTest.cs @@ -118,8 +118,7 @@ static void Validate(string name, KnownHeader h) Assert.NotNull(h); Assert.Same(h, KnownHeaders.TryGetKnownHeader(name)); - Assert.True(h.Descriptor.IsKnownHeader(out KnownHeader? knownHeader, out _)); - Assert.Same(h, knownHeader); + Assert.Same(h, h.Descriptor.KnownHeader); Assert.Equal(name, h.Name, StringComparer.OrdinalIgnoreCase); Assert.Equal(name, h.Descriptor.Name, StringComparer.OrdinalIgnoreCase); } From 2e4b79a69502f79751d9a6678de9f9245747d4ef Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Wed, 12 Jan 2022 17:57:39 +0100 Subject: [PATCH 16/22] Switch back to always grouping by name --- .../Net/Http/Http2LoopbackConnection.cs | 18 +- .../System/Net/Http/HttpClientHandlerTest.cs | 91 ----- .../System/Net/Http/Headers/HttpHeaders.cs | 359 +++++++----------- .../Http/Headers/HttpHeadersNonValidated.cs | 6 +- .../Http/SocketsHttpHandler/HttpConnection.cs | 4 +- 5 files changed, 148 insertions(+), 330 deletions(-) diff --git a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs index 01ac5cf7417254..b3bb701e540f36 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http2LoopbackConnection.cs @@ -28,7 +28,6 @@ public class Http2LoopbackConnection : GenericLoopbackConnection private readonly TimeSpan _timeout; private int _lastStreamId; private bool _expectClientDisconnect; - private List _dynamicHeaderTable = new(); private readonly byte[] _prefix = new byte[24]; public string PrefixString => Encoding.UTF8.GetString(_prefix, 0, _prefix.Length); @@ -491,15 +490,12 @@ private static (int bytesConsumed, string value) DecodeString(ReadOnlySpan new HttpHeaderData("www-authenticate", "") }; - private HttpHeaderData GetHeaderForIndex(int index) + private static HttpHeaderData GetHeaderForIndex(int index) { - index--; - return index < s_staticTable.Length - ? s_staticTable[index] - : _dynamicHeaderTable[index - s_staticTable.Length]; + return s_staticTable[index - 1]; } - private (int bytesConsumed, HttpHeaderData headerData) DecodeLiteralHeader(ReadOnlySpan headerBlock, byte prefixMask) + private static (int bytesConsumed, HttpHeaderData headerData) DecodeLiteralHeader(ReadOnlySpan headerBlock, byte prefixMask) { int i = 0; @@ -524,7 +520,7 @@ private HttpHeaderData GetHeaderForIndex(int index) return (i, new HttpHeaderData(name, value)); } - private (int bytesConsumed, HttpHeaderData headerData) DecodeHeader(ReadOnlySpan headerBlock) + private static (int bytesConsumed, HttpHeaderData headerData) DecodeHeader(ReadOnlySpan headerBlock) { int i = 0; @@ -540,11 +536,7 @@ private HttpHeaderData GetHeaderForIndex(int index) else if ((b & 0b11000000) == 0b01000000) { // Literal with indexing - (int bytesConsumed, HttpHeaderData headerData) = DecodeLiteralHeader(headerBlock, 0b00111111); - - _dynamicHeaderTable.Insert(0, headerData); - - return (bytesConsumed, headerData); + return DecodeLiteralHeader(headerBlock, 0b00111111); } else if ((b & 0b11100000) == 0b00100000) { diff --git a/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.cs b/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.cs index 9e67637af35570..48379e2b505509 100644 --- a/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.cs +++ b/src/libraries/Common/tests/System/Net/Http/HttpClientHandlerTest.cs @@ -357,97 +357,6 @@ await LoopbackServer.CreateClientAndServerAsync(async uri => }, server => server.AcceptConnectionSendCustomResponseAndCloseAsync($"HTTP/1.1 200 OK\r\n{invalidHeader}\r\n{LoopbackServer.CorsHeaders}Content-Length: 11\r\n\r\nhello world")); } - [Theory] - [InlineData(false, false, false)] - [InlineData(false, true, true)] - [InlineData(true, false, true)] - public async Task SendAsync_MultipleEntriesPerHeaderName_ValuesMayNotBeGrouped(bool manyHeaders, bool enumerateHeadersBeforeSend, bool valuesShouldBeGrouped) - { - Assert.True(valuesShouldBeGrouped || !manyHeaders, "Values will be grouped if we go over the HttpHeaders.ArrayThreshold"); - Assert.True(valuesShouldBeGrouped || !enumerateHeadersBeforeSend, "Enumerating the values forces them to be grouped by name"); - - if (PlatformDetection.IsBrowser) - { - valuesShouldBeGrouped = true; - } - -#if !NETCOREAPP - valuesShouldBeGrouped = true; -#endif - - await LoopbackServerFactory.CreateClientAndServerAsync(async uri => - { - var request = new HttpRequestMessage(HttpMethod.Get, uri) - { - Version = UseVersion - }; - - request.Headers.TryAddWithoutValidation("foo", "foo-single-1"); - request.Headers.TryAddWithoutValidation("bar", "bar-single-1"); - request.Headers.TryAddWithoutValidation("foo", new[] { "foo-multi-1", "foo-multi-2" }); - request.Headers.TryAddWithoutValidation("bar", "bar-single-2"); - request.Headers.TryAddWithoutValidation("foo", "foo-single-2"); - - if (manyHeaders) - { - for (int i = 0; i < 100; i++) - { - request.Headers.TryAddWithoutValidation($"dummy-{i}", "dummy"); - } - } - - if (enumerateHeadersBeforeSend) - { - _ = request.Headers.ToArray(); - } - - using HttpClient client = CreateHttpClient(); - - (await client.SendAsync(TestAsync, request)).Dispose(); - }, - async server => - { - HttpRequestData requestData = await server.HandleRequestAsync(HttpStatusCode.OK); - HttpHeaderData[] headers = requestData.Headers.Where(h => h.Name == "foo" || h.Name == "bar").ToArray(); - - if (valuesShouldBeGrouped) - { - Assert.Equal(2, headers.Length); - - if (manyHeaders) - { - // Ordering is not preserved after HttpHeaders.ArrayThreshold - headers = headers.OrderByDescending(h => h.Name).ToArray(); - } - - Assert.Equal("foo", headers[0].Name); - Assert.Equal("foo-single-1, foo-multi-1, foo-multi-2, foo-single-2", headers[0].Value); - - Assert.Equal("bar", headers[1].Name); - Assert.Equal("bar-single-1, bar-single-2", headers[1].Value); - } - else - { - Assert.Equal(5, headers.Length); - - Assert.Equal("foo", headers[0].Name); - Assert.Equal("foo-single-1", headers[0].Value); - - Assert.Equal("bar", headers[1].Name); - Assert.Equal("bar-single-1", headers[1].Value); - - Assert.Equal("foo", headers[2].Name); - Assert.Equal("foo-multi-1, foo-multi-2", headers[2].Value); - - Assert.Equal("bar", headers[3].Name); - Assert.Equal("bar-single-2", headers[3].Value); - - Assert.Equal("foo", headers[4].Name); - Assert.Equal("foo-single-2", headers[4].Value); - } - }); - } - [Theory] [InlineData(false, false)] [InlineData(true, false)] diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index dc2c879726f421..6cb0dca525ed90 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -49,7 +49,6 @@ public abstract class HttpHeaders : IEnumerableEither a array or a Dictionary<, > private object? _headerStore; private int _count; - private bool _entriesMayBeUngrouped; // Indicates that a non-validating Add was performed and some entries may not be grouped by name private readonly HttpHeaderType _allowedHeaderTypes; private readonly HttpHeaderType _treatAsCustomHeaderTypes; @@ -85,7 +84,7 @@ internal void Add(HeaderDescriptor descriptor, string? value) // it to the store if we added at least one value. if (addToStore && (info.ParsedValue != null)) { - Debug.Assert(!ContainsEntry(descriptor)); + Debug.Assert(!ContainsKey(descriptor)); AddEntryToStore(new HeaderEntry(descriptor, info)); } } @@ -119,7 +118,7 @@ internal void Add(HeaderDescriptor descriptor, IEnumerable values) // However, if all values for a _new_ header were invalid, then don't add the header. if (addToStore && (info.ParsedValue != null)) { - Debug.Assert(!ContainsEntry(descriptor)); + Debug.Assert(!ContainsKey(descriptor)); AddEntryToStore(new HeaderEntry(descriptor, info)); } } @@ -136,8 +135,25 @@ internal bool TryAddWithoutValidation(HeaderDescriptor descriptor, string? value // values, e.g. adding two null-strings (or empty, or whitespace-only) results in "My-Header: ,". value ??= string.Empty; - AddEntryToStore(new HeaderEntry(descriptor, value)); - _entriesMayBeUngrouped = true; + ref object? storeValueRef = ref GetValueRefOrAddDefault(descriptor); + object? currentValue = storeValueRef; + + if (currentValue is null) + { + storeValueRef = value; + } + else + { + if (currentValue is not HeaderStoreItemInfo info) + { + // The header store contained a single raw string value, so promote it + // to being a HeaderStoreItemInfo and add to it. + Debug.Assert(currentValue is string); + storeValueRef = info = new HeaderStoreItemInfo() { RawValue = currentValue }; + } + + AddRawValue(info, value); + } return true; } @@ -156,30 +172,25 @@ internal bool TryAddWithoutValidation(HeaderDescriptor descriptor, IEnumerable enumerator = values.GetEnumerator(); if (enumerator.MoveNext()) { - string firstValue = enumerator.Current ?? string.Empty; + TryAddWithoutValidation(descriptor, enumerator.Current); if (enumerator.MoveNext()) { - var valuesList = new List + ref object? storeValueRef = ref GetValueRefOrAddDefault(descriptor); + Debug.Assert(storeValueRef is not null); + + object value = storeValueRef; + if (value is not HeaderStoreItemInfo info) { - firstValue - }; + Debug.Assert(value is string); + storeValueRef = info = new HeaderStoreItemInfo { RawValue = value }; + } do { - valuesList.Add(enumerator.Current ?? string.Empty); + AddRawValue(info, enumerator.Current ?? string.Empty); } while (enumerator.MoveNext()); - - AddEntryToStore(new HeaderEntry(descriptor, new HeaderStoreItemInfo - { - RawValue = valuesList - })); - } - else - { - AddEntryToStore(new HeaderEntry(descriptor, firstValue)); } - _entriesMayBeUngrouped = true; } return true; @@ -299,9 +310,10 @@ public IEnumerator>> GetEnumerator() => private IEnumerator>> GetEnumeratorCore() { - HeaderEntry[] entries = GetEntriesGroupedByName(out int numberOfEntries)!; + HeaderEntry[] entries = GetEntriesArray()!; - for (int i = 0; i < numberOfEntries && i < entries.Length; i++) + int count = _count; + for (int i = 0; i < count; i++) { HeaderEntry entry = entries[i]; @@ -319,7 +331,7 @@ private IEnumerator>> GetEnumeratorCore } else { - Debug.Assert(ContainsEntry(entry.Key)); + Debug.Assert(ContainsKey(entry.Key)); ((Dictionary)_headerStore!)[entry.Key] = info; } } @@ -335,7 +347,7 @@ private IEnumerator>> GetEnumeratorCore if (EntriesAreLiveView) { i--; - numberOfEntries--; + count--; } } else @@ -522,20 +534,24 @@ internal virtual void AddHeaders(HttpHeaders sourceHeaders) Debug.Assert(sourceHeaders != null); Debug.Assert(GetType() == sourceHeaders.GetType(), "Can only copy headers from an instance of the same type."); - foreach (HeaderEntry entry in sourceHeaders.GetEntriesGroupedByName()) + foreach (HeaderEntry entry in sourceHeaders.GetEntries()) { // Only add header values if they're not already set on the message. Note that we don't merge // collections: If both the default headers and the message have set some values for a certain // header, then we don't try to merge the values. - if (!ContainsEntry(entry.Key)) + ref object? storeValueRef = ref GetValueRefOrAddDefault(entry.Key); + if (storeValueRef is null) { - object value = entry.Value; - if (value is HeaderStoreItemInfo info) + object sourceValue = entry.Value; + if (sourceValue is HeaderStoreItemInfo info) { - value = CloneHeaderInfo(entry.Key, info); + storeValueRef = CloneHeaderInfo(entry.Key, info); + } + else + { + Debug.Assert(sourceValue is string); + storeValueRef = sourceValue; } - - AddEntryToStore(new HeaderEntry(entry.Key, value)); } } } @@ -634,7 +650,7 @@ private HeaderStoreItemInfo GetOrCreateHeaderInfo(HeaderDescriptor descriptor) private HeaderStoreItemInfo CreateAndAddHeaderToStore(HeaderDescriptor descriptor) { - Debug.Assert(!ContainsEntry(descriptor)); + Debug.Assert(!ContainsKey(descriptor)); // If we don't have the header in the store yet, add it now. HeaderStoreItemInfo result = new HeaderStoreItemInfo(); @@ -649,7 +665,7 @@ private HeaderStoreItemInfo CreateAndAddHeaderToStore(HeaderDescriptor descripto internal bool TryGetHeaderValue(HeaderDescriptor descriptor, [NotNullWhen(true)] out object? value) { - ref object storeValueRef = ref GetValueRefGroupedByNameOrNullRef(descriptor); + ref object storeValueRef = ref GetValueRefOrNullRef(descriptor); if (Unsafe.IsNullRef(ref storeValueRef)) { value = null; @@ -664,7 +680,7 @@ internal bool TryGetHeaderValue(HeaderDescriptor descriptor, [NotNullWhen(true)] private bool TryGetAndParseHeaderInfo(HeaderDescriptor key, [NotNullWhen(true)] out HeaderStoreItemInfo? info) { - ref object storeValueRef = ref GetValueRefGroupedByNameOrNullRef(key); + ref object storeValueRef = ref GetValueRefOrNullRef(key); if (!Unsafe.IsNullRef(ref storeValueRef)) { object value = storeValueRef; @@ -689,6 +705,7 @@ private bool ParseRawHeaderValues(HeaderDescriptor descriptor, HeaderStoreItemIn { // Unlike TryGetHeaderInfo() this method tries to parse all non-validated header values (if any) // before returning to the caller. + Debug.Assert(!info.IsEmpty); if (info.RawValue != null) { List? rawValues = info.RawValue as List; @@ -707,8 +724,7 @@ private bool ParseRawHeaderValues(HeaderDescriptor descriptor, HeaderStoreItemIn info.RawValue = null; // During parsing, we removed the value since it contains newline chars. Return false to indicate that - // this is an empty header. If the caller specified to remove empty headers, we'll remove the header before - // returning. + // this is an empty header. if ((info.InvalidValue == null) && (info.ParsedValue == null)) { // After parsing the raw value, no value is left because all values contain newline chars. @@ -782,7 +798,7 @@ internal bool TryParseAndAddValue(HeaderDescriptor descriptor, string? value) { // If we get here, then the value could be parsed correctly. If we created a new HeaderStoreItemInfo, add // it to the store if we added at least one value. - Debug.Assert(!ContainsEntry(descriptor)); + Debug.Assert(!ContainsKey(descriptor)); AddEntryToStore(new HeaderEntry(descriptor, info)); } @@ -1239,10 +1255,8 @@ internal bool CanAddParsedValue(HttpHeaderParser parser) private const int InitialCapacity = 4; internal const int ArrayThreshold = 64; // Above this threshold, header ordering will not be preserved - internal HeaderEntry[]? GetEntries(out int numberOfEntries) + internal HeaderEntry[]? GetEntriesArray() { - numberOfEntries = _count; - object? store = _headerStore; if (store is null) { @@ -1276,230 +1290,137 @@ HeaderEntry[] GetEntriesFromDictionary() internal ReadOnlySpan GetEntries() { - return new ReadOnlySpan(GetEntries(out int numberOfEntries), 0, numberOfEntries); - } - - internal HeaderEntry[]? GetEntriesGroupedByName(out int numberOfEntries) - { - GroupAllEntriesByName(); - return GetEntries(out numberOfEntries); + return new ReadOnlySpan(GetEntriesArray(), 0, _count); } - internal ReadOnlySpan GetEntriesGroupedByName() - { - GroupAllEntriesByName(); - return GetEntries(); - } + internal int Count => _count; private bool EntriesAreLiveView => _headerStore is HeaderEntry[]; - internal int Count + private ref object GetValueRefOrNullRef(HeaderDescriptor key) { - get - { - GroupAllEntriesByName(); - return _count; - } - } + ref object valueRef = ref Unsafe.NullRef(); - private void GroupAllEntriesByName() - { - if (_entriesMayBeUngrouped && _headerStore is HeaderEntry[] entries) + object? store = _headerStore; + if (store is HeaderEntry[] entries) { - _entriesMayBeUngrouped = false; - for (int i = 0; i < _count && i < entries.Length; i++) { - HeaderDescriptor key = entries[i].Key; - - for (int j = i + 1; j < _count && (uint)j < (uint)entries.Length; j++) + if (key.Equals(entries[i].Key)) { - if (key.Equals(entries[j].Key)) - { - MergeEntryValues(ref entries[i].Value, entries[j].Value); - RemoveAt(entries, j); - j--; - } + valueRef = ref entries[i].Value; + break; } } } - } - - private static void MergeEntryValues(ref object existingValueRef, object newValue) - { - Debug.Assert(existingValueRef is not null); - Debug.Assert(newValue is not null); - - object existingValue = existingValueRef; - if (existingValue is not HeaderStoreItemInfo info) + else if (store is not null) { - Debug.Assert(existingValue is string); - existingValueRef = info = new HeaderStoreItemInfo { RawValue = existingValue }; + valueRef = ref CollectionsMarshal.GetValueRefOrNullRef(Unsafe.As>(store), key); } - if (newValue is HeaderStoreItemInfo newInfo) - { - // This should only happen if the TryAddWithoutValidation(string, IEnumerable) was called when a header was already present - Debug.Assert(newInfo.ParsedValue is null); - Debug.Assert(newInfo.InvalidValue is null); - Debug.Assert(newInfo.RawValue is not null); - Debug.Assert(newInfo.RawValue is List); - foreach (string value in (List)newInfo.RawValue) - { - AddRawValue(info, value); - } - } - else - { - Debug.Assert(newValue is string); - AddRawValue(info, Unsafe.As(newValue)); - } + return ref valueRef; } - private ref object GetValueRefGroupedByNameOrNullRef(HeaderDescriptor key) + private ref object? GetValueRefOrAddDefault(HeaderDescriptor key) { object? store = _headerStore; - if (store is not null) + if (store is HeaderEntry[] entries) { - if (store is HeaderEntry[] entries) + for (int i = 0; i < _count && i < entries.Length; i++) { - for (int i = 0; i < _count && i < entries.Length; i++) + if (key.Equals(entries[i].Key)) { - if (key.Equals(entries[i].Key)) - { - if (_entriesMayBeUngrouped) - { - MergeEntriesAfter(entries, i); - } - - return ref entries[i].Value; - } + return ref entries[i].Value!; } } - else - { - return ref CollectionsMarshal.GetValueRefOrNullRef(Unsafe.As>(store), key); - } - } - return ref Unsafe.NullRef(); - - void MergeEntriesAfter(HeaderEntry[] entries, int i) - { - if ((uint)i < (uint)entries.Length) - { - ref HeaderEntry entry = ref entries[i]; - HeaderDescriptor key = entry.Key; - - for (i++; i < _count && (uint)i < (uint)entries.Length; i++) - { - if (key.Equals(entries[i].Key)) - { - MergeEntryValues(ref entry.Value, entries[i].Value); - RemoveAt(entries, i); - i--; - } - } - } - } - } - - private void AddEntryToStore(HeaderEntry entry) - { - if (_headerStore is HeaderEntry[] entries) - { int count = _count; + _count++; if ((uint)count < (uint)entries.Length) { - entries[count] = entry; - _count++; - } - else - { - GrowAndAddEntry(ref entry); + entries[count].Key = key; + return ref entries[count].Value!; } + + return ref GrowEntriesAndAddDefault(key); + } + else if (store is null) + { + _count++; + entries = new HeaderEntry[InitialCapacity]; + _headerStore = entries; + ref HeaderEntry firstEntry = ref MemoryMarshal.GetArrayDataReference(entries); + firstEntry.Key = key; + return ref firstEntry.Value!; } else { - InitializeOrAddToDictionary(ref entry); + return ref DictionaryGetValueRefOrAddDefault(key); } - void InitializeOrAddToDictionary(ref HeaderEntry entry) + ref object? GrowEntriesAndAddDefault(HeaderDescriptor key) { - if (_headerStore is Dictionary) + var entries = (HeaderEntry[])_headerStore!; + if (entries.Length == ArrayThreshold) { - AddEntryToDictionary(ref entry); + return ref ConvertToDictionaryAndAddDefault(key); } else { - var entries = new HeaderEntry[InitialCapacity]; + Array.Resize(ref entries, entries.Length << 1); _headerStore = entries; - MemoryMarshal.GetArrayDataReference(entries) = entry; - _count = 1; + ref HeaderEntry firstNewEntry = ref entries[entries.Length >> 1]; + firstNewEntry.Key = key; + return ref firstNewEntry.Value!; } } - void GrowAndAddEntry(ref HeaderEntry entry) + ref object? ConvertToDictionaryAndAddDefault(HeaderDescriptor key) { var entries = (HeaderEntry[])_headerStore!; - Debug.Assert(entries.Length == _count); - if (entries.Length == ArrayThreshold) - { - _headerStore = new Dictionary(ArrayThreshold); - _count = 0; - for (int i = 0; i < entries.Length; i++) - { - AddEntryToDictionary(ref entries[i]); - } - AddEntryToDictionary(ref entry); - } - else + var dictionary = new Dictionary(ArrayThreshold); + _headerStore = dictionary; + foreach (HeaderEntry entry in entries) { - Array.Resize(ref entries, entries.Length << 1); - entries[entries.Length >> 1] = entry; - _headerStore = entries; - _count++; + dictionary.Add(entry.Key, entry.Value); } + Debug.Assert(dictionary.Count == _count - 1); + return ref CollectionsMarshal.GetValueRefOrAddDefault(dictionary, key, out s_dictionaryGetValueRefOrAddDefaultExistsDummy); } - void AddEntryToDictionary(ref HeaderEntry entry) + ref object? DictionaryGetValueRefOrAddDefault(HeaderDescriptor key) { var dictionary = (Dictionary)_headerStore!; - ref object? valueRef = ref CollectionsMarshal.GetValueRefOrAddDefault(dictionary, entry.Key, out s_dictionaryGetValueRefOrAddDefaultExistsDummy); - if (valueRef is null) + ref object? value = ref CollectionsMarshal.GetValueRefOrAddDefault(dictionary, key, out s_dictionaryGetValueRefOrAddDefaultExistsDummy); + if (value is null) { - valueRef = entry.Value; _count++; } - else - { - MergeEntryValues(ref valueRef, entry.Value); - } + return ref value; } } - internal bool ContainsEntry(HeaderDescriptor key) + private void AddEntryToStore(HeaderEntry entry) { - object? store = _headerStore; - if (store is not null) + Debug.Assert(!ContainsKey(entry.Key)); + + if (_headerStore is HeaderEntry[] entries) { - if (store is HeaderEntry[] entries) - { - for (int i = 0; i < _count && i < entries.Length; i++) - { - if (key.Equals(entries[i].Key)) - { - return true; - } - } - } - else + int count = _count; + if ((uint)count < (uint)entries.Length) { - return Unsafe.As>(store).ContainsKey(key); + entries[count] = entry; + _count++; + return; } } - return false; + GetValueRefOrAddDefault(entry.Key) = entry.Value; + } + + internal bool ContainsKey(HeaderDescriptor key) + { + return !Unsafe.IsNullRef(ref GetValueRefOrNullRef(key)); } public void Clear() @@ -1521,41 +1442,37 @@ public void Clear() internal bool Remove(HeaderDescriptor key) { - bool removedAny = false; + bool removed = false; object? store = _headerStore; - if (store is not null) + if (store is HeaderEntry[] entries) { - if (store is HeaderEntry[] entries) + for (int i = 0; i < _count && i < entries.Length; i++) { - for (int i = _count - 1; i >= 0 && (uint)i < (uint)entries.Length; i--) + if (key.Equals(entries[i].Key)) { - if (key.Equals(entries[i].Key)) + while (i + 1 < _count && (uint)(i + 1) < (uint)entries.Length) { - removedAny = true; - RemoveAt(entries, i); + entries[i] = entries[i + 1]; + i++; } + entries[i] = default; + removed = true; + break; } } - else if (Unsafe.As>(store).Remove(key)) - { - _count--; - removedAny = true; - } + } + else if (store is not null) + { + removed = Unsafe.As>(store).Remove(key); } - return removedAny; - } - - private void RemoveAt(HeaderEntry[] entries, int i) - { - int count = _count--; - while ((uint)(i + 1) < (uint)entries.Length && i + 1 < count) + if (removed) { - entries[i] = entries[i + 1]; - i++; + _count--; } - entries[i] = default; + + return removed; } #endregion // _headerStore implementation diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs index 418f4219643f1f..b2d2f5c3a8727b 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs @@ -33,7 +33,7 @@ namespace System.Net.Http.Headers public bool Contains(string headerName) => _headers is HttpHeaders headers && headers.TryGetHeaderDescriptor(headerName, out HeaderDescriptor descriptor) && - headers.ContainsEntry(descriptor); + headers.ContainsKey(descriptor); /// Gets the values for the specified header name. /// The name of the header. @@ -83,8 +83,8 @@ public bool TryGetValues(string headerName, out HeaderStringValues values) /// Gets an enumerator that iterates through the . /// An enumerator that iterates through the . public Enumerator GetEnumerator() => - _headers is HttpHeaders headers && headers.GetEntriesGroupedByName(out int numberOfEntries) is HeaderEntry[] entries ? - new Enumerator(entries, numberOfEntries) : + _headers is HttpHeaders headers && headers.GetEntriesArray() is HeaderEntry[] entries ? + new Enumerator(entries, headers.Count) : default; /// diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs index d1f058a0c13109..deb0f5d9373d19 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs @@ -258,9 +258,9 @@ private async ValueTask WriteHeadersAsync(HttpHeaders headers, string? cookiesFr { Debug.Assert(_currentRequest != null); - if (headers.GetEntries(out int numberOfEntries) is HeaderEntry[] entries) + if (headers.GetEntriesArray() is HeaderEntry[] entries) { - for (int i = 0; i < numberOfEntries; i++) + for (int i = 0; i < headers.Count; i++) { HeaderEntry header = entries[i]; From 1864797c0f7bdd0cf93d64b272be79f6646c3f37 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Wed, 19 Jan 2022 19:50:15 +0100 Subject: [PATCH 17/22] Assert that the collection is not empty in GetEnumeratorCore --- .../System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index 6cb0dca525ed90..a377d218d48440 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -310,7 +310,8 @@ public IEnumerator>> GetEnumerator() => private IEnumerator>> GetEnumeratorCore() { - HeaderEntry[] entries = GetEntriesArray()!; + HeaderEntry[]? entries = GetEntriesArray(); + Debug.Assert(_count != 0 && entries is not null, "Caller should have validated the collection is not empty"); int count = _count; for (int i = 0; i < count; i++) From 076b219ee1bb3798d3b5d7b64d1c6b328a1d843a Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Wed, 19 Jan 2022 22:06:30 +0100 Subject: [PATCH 18/22] Optimize AddHeaders for empty collections --- .../System/Net/Http/Headers/HttpHeaders.cs | 49 ++++++++++++++----- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index a377d218d48440..675edd7e230094 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -535,23 +535,46 @@ internal virtual void AddHeaders(HttpHeaders sourceHeaders) Debug.Assert(sourceHeaders != null); Debug.Assert(GetType() == sourceHeaders.GetType(), "Can only copy headers from an instance of the same type."); - foreach (HeaderEntry entry in sourceHeaders.GetEntries()) - { - // Only add header values if they're not already set on the message. Note that we don't merge - // collections: If both the default headers and the message have set some values for a certain - // header, then we don't try to merge the values. - ref object? storeValueRef = ref GetValueRefOrAddDefault(entry.Key); - if (storeValueRef is null) + // Only add header values if they're not already set on the message. Note that we don't merge + // collections: If both the default headers and the message have set some values for a certain + // header, then we don't try to merge the values. + if (_count == 0 && sourceHeaders._headerStore is HeaderEntry[] sourceEntries) + { + // If the target collection is empty, we don't have to search for existing values + _count = sourceHeaders._count; + if (_headerStore is not HeaderEntry[] entries || entries.Length < _count) + { + entries = new HeaderEntry[sourceEntries.Length]; + _headerStore = entries; + } + + for (int i = 0; i < _count && i < sourceEntries.Length; i++) { - object sourceValue = entry.Value; - if (sourceValue is HeaderStoreItemInfo info) + HeaderEntry entry = sourceEntries[i]; + if (entry.Value is HeaderStoreItemInfo info) { - storeValueRef = CloneHeaderInfo(entry.Key, info); + entry.Value = CloneHeaderInfo(entry.Key, info); } - else + entries[i] = entry; + } + } + else + { + foreach (HeaderEntry entry in sourceHeaders.GetEntries()) + { + ref object? storeValueRef = ref GetValueRefOrAddDefault(entry.Key); + if (storeValueRef is null) { - Debug.Assert(sourceValue is string); - storeValueRef = sourceValue; + object sourceValue = entry.Value; + if (sourceValue is HeaderStoreItemInfo info) + { + storeValueRef = CloneHeaderInfo(entry.Key, info); + } + else + { + Debug.Assert(sourceValue is string); + storeValueRef = sourceValue; + } } } } From 5d681d688f148f93d5377a8ae9ec52f32f3a9ef6 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Wed, 19 Jan 2022 23:05:06 +0100 Subject: [PATCH 19/22] Reference the Roslyn bug issue --- .../src/System/Net/Http/Headers/HttpHeaders.cs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index 675edd7e230094..db76b9659bcb85 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -1273,7 +1273,9 @@ internal bool CanAddParsedValue(HttpHeaderParser parser) #region Low-level implementation details that work with _headerStore directly - // Used to store the CollectionsMarshal.GetValueRefOrAddDefault out parameter + // Used to store the CollectionsMarshal.GetValueRefOrAddDefault out parameter. + // This is a workaround for the Roslyn bug where we can't use a discard instead: + // https://github.com/dotnet/roslyn/issues/56587#issuecomment-934955526 private static bool s_dictionaryGetValueRefOrAddDefaultExistsDummy; private const int InitialCapacity = 4; From 5babbda901ab5ad4a463e1d8de484173c0b693f6 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Thu, 20 Jan 2022 00:38:21 +0100 Subject: [PATCH 20/22] Assert that multiValues are never empty --- .../src/System/Net/Http/Headers/HttpHeaders.cs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index db76b9659bcb85..627c64ebc790b8 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -268,9 +268,11 @@ public override string ToString() // just separate the values using a comma (default separator). string? separator = entry.Key.Parser is HttpHeaderParser parser && parser.SupportsMultipleValues ? parser.Separator : HttpHeaderParser.DefaultSeparator; - for (int i = 0; i < multiValue!.Length; i++) + Debug.Assert(multiValue is not null && multiValue.Length > 0); + vsb.Append(multiValue[0]); + for (int i = 1; i < multiValue.Length; i++) { - if (i != 0) vsb.Append(separator); + vsb.Append(separator); vsb.Append(multiValue[i]); } } @@ -1145,7 +1147,8 @@ internal static void GetStoreValuesAsStringOrStringArray(HeaderDescriptor descri } else { - values = multiValue = length != 0 ? new string[length] : Array.Empty(); + Debug.Assert(length > 1, "The header should have been removed when it became empty"); + values = multiValue = new string[length]; } int currentIndex = 0; From 468cf3dbeadde2239ff1253a9ab5408bdbba80b8 Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Thu, 20 Jan 2022 00:57:15 +0100 Subject: [PATCH 21/22] Don't preserve a Dictionary across Clear --- .../src/System/Net/Http/Headers/HttpHeaders.cs | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index 627c64ebc790b8..2bb8fc94d80b03 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -1454,19 +1454,15 @@ internal bool ContainsKey(HeaderDescriptor key) public void Clear() { - object? store = _headerStore; - if (store is not null) + if (_headerStore is HeaderEntry[] entries) { - if (store is HeaderEntry[] entries) - { - Array.Clear(entries, 0, _count); - } - else - { - Unsafe.As>(store).Clear(); - } - _count = 0; + Array.Clear(entries, 0, _count); + } + else + { + _headerStore = null; } + _count = 0; } internal bool Remove(HeaderDescriptor key) From b3aee48a920f2d65a9acdb12f10666574ef8a7aa Mon Sep 17 00:00:00 2001 From: Miha Zupan Date: Thu, 20 Jan 2022 00:59:34 +0100 Subject: [PATCH 22/22] Add comment about why a custom HeaderEntry type is used --- .../src/System/Net/Http/Headers/HttpHeaders.cs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index 2bb8fc94d80b03..3eae142edcd4cb 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -11,7 +11,10 @@ namespace System.Net.Http.Headers { - /// Key/value pairs of headers. The value is either a raw or a . + /// + /// Key/value pairs of headers. The value is either a raw or a . + /// We're using a custom type instead of because we need ref access to fields. + /// internal struct HeaderEntry { public HeaderDescriptor Key;