diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs index ee53b9e626c751..3eae142edcd4cb 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeaders.cs @@ -5,11 +5,28 @@ using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; namespace System.Net.Http.Headers { + /// + /// Key/value pairs of headers. The value is either a raw or a . + /// We're using a custom type instead of because we need ref access to fields. + /// + internal struct HeaderEntry + { + public HeaderDescriptor Key; + public object Value; + + public HeaderEntry(HeaderDescriptor key, object value) + { + Key = key; + Value = value; + } + } + public abstract class HttpHeaders : IEnumerable>> { // This type is used to store a collection of headers in 'headerStore': @@ -32,8 +49,9 @@ public abstract class HttpHeaders : IEnumerableKey/value pairs of headers. The value is either a raw or a . - private Dictionary? _headerStore; + /// Either a array or a Dictionary<, > + private object? _headerStore; + private int _count; private readonly HttpHeaderType _allowedHeaderTypes; private readonly HttpHeaderType _treatAsCustomHeaderTypes; @@ -52,8 +70,6 @@ internal HttpHeaders(HttpHeaderType allowedHeaderTypes, HttpHeaderType treatAsCu _treatAsCustomHeaderTypes = treatAsCustomHeaderTypes & ~HttpHeaderType.NonTrailing; } - internal Dictionary? HeaderStore => _headerStore; - /// Gets a view of the contents of this headers collection that does not parse nor validate the data upon access. public HttpHeadersNonValidated NonValidated => new HttpHeadersNonValidated(this); @@ -71,7 +87,8 @@ internal void Add(HeaderDescriptor descriptor, string? value) // it to the store if we added at least one value. if (addToStore && (info.ParsedValue != null)) { - AddHeaderToStore(descriptor, info); + Debug.Assert(!ContainsKey(descriptor)); + AddEntryToStore(new HeaderEntry(descriptor, info)); } } @@ -104,7 +121,8 @@ internal void Add(HeaderDescriptor descriptor, IEnumerable values) // However, if all values for a _new_ header were invalid, then don't add the header. if (addToStore && (info.ParsedValue != null)) { - AddHeaderToStore(descriptor, info); + Debug.Assert(!ContainsKey(descriptor)); + AddEntryToStore(new HeaderEntry(descriptor, info)); } } } @@ -120,29 +138,24 @@ internal bool TryAddWithoutValidation(HeaderDescriptor descriptor, string? value // values, e.g. adding two null-strings (or empty, or whitespace-only) results in "My-Header: ,". value ??= string.Empty; - // Ensure the header store dictionary has been created. - _headerStore ??= new Dictionary(); + ref object? storeValueRef = ref GetValueRefOrAddDefault(descriptor); + object? currentValue = storeValueRef; - if (_headerStore.TryGetValue(descriptor, out object? currentValue)) + if (currentValue is null) { - if (currentValue is HeaderStoreItemInfo info) - { - // The header store already contained a HeaderStoreItemInfo, so add to it. - AddRawValue(info, value); - } - else + storeValueRef = value; + } + else + { + if (currentValue is not HeaderStoreItemInfo info) { // The header store contained a single raw string value, so promote it // to being a HeaderStoreItemInfo and add to it. Debug.Assert(currentValue is string); - _headerStore[descriptor] = info = new HeaderStoreItemInfo() { RawValue = currentValue }; - AddRawValue(info, value); + storeValueRef = info = new HeaderStoreItemInfo() { RawValue = currentValue }; } - } - else - { - // The header store did not contain the header. Add the raw string. - _headerStore.Add(descriptor, value); + + AddRawValue(info, value); } return true; @@ -159,28 +172,33 @@ internal bool TryAddWithoutValidation(HeaderDescriptor descriptor, IEnumerable enumerator = values.GetEnumerator()) + using IEnumerator enumerator = values.GetEnumerator(); + if (enumerator.MoveNext()) { + TryAddWithoutValidation(descriptor, enumerator.Current); if (enumerator.MoveNext()) { - TryAddWithoutValidation(descriptor, enumerator.Current); - if (enumerator.MoveNext()) + ref object? storeValueRef = ref GetValueRefOrAddDefault(descriptor); + Debug.Assert(storeValueRef is not null); + + object value = storeValueRef; + if (value is not HeaderStoreItemInfo info) { - HeaderStoreItemInfo info = GetOrCreateHeaderInfo(descriptor, parseRawValues: false); - do - { - AddRawValue(info, enumerator.Current ?? string.Empty); - } - while (enumerator.MoveNext()); + Debug.Assert(value is string); + storeValueRef = info = new HeaderStoreItemInfo { RawValue = value }; } + + do + { + AddRawValue(info, enumerator.Current ?? string.Empty); + } + while (enumerator.MoveNext()); } } return true; } - public void Clear() => _headerStore?.Clear(); - public IEnumerable GetValues(string name) => GetValues(GetHeaderDescriptor(name)); internal IEnumerable GetValues(HeaderDescriptor descriptor) @@ -206,7 +224,7 @@ public bool TryGetValues(string name, [NotNullWhen(true)] out IEnumerable? values) { - if (_headerStore != null && TryGetAndParseHeaderInfo(descriptor, out HeaderStoreItemInfo? info)) + if (TryGetAndParseHeaderInfo(descriptor, out HeaderStoreItemInfo? info)) { values = GetStoreValuesAsStringArray(descriptor, info); return true; @@ -223,7 +241,7 @@ internal bool Contains(HeaderDescriptor descriptor) // We can't just call headerStore.ContainsKey() since after parsing the value the header may not exist // anymore (if the value contains newline chars, we remove the header). So try to parse the // header value. - return _headerStore != null && TryGetAndParseHeaderInfo(descriptor, out _); + return TryGetAndParseHeaderInfo(descriptor, out _); } public override string ToString() @@ -235,35 +253,34 @@ public override string ToString() var vsb = new ValueStringBuilder(stackalloc char[512]); - if (_headerStore is Dictionary headerStore) + foreach (HeaderEntry entry in GetEntries()) { - foreach (KeyValuePair header in headerStore) - { - vsb.Append(header.Key.Name); - vsb.Append(": "); + vsb.Append(entry.Key.Name); + vsb.Append(": "); - GetStoreValuesAsStringOrStringArray(header.Key, header.Value, out string? singleValue, out string[]? multiValue); - Debug.Assert(singleValue is not null ^ multiValue is not null); + GetStoreValuesAsStringOrStringArray(entry.Key, entry.Value, out string? singleValue, out string[]? multiValue); + Debug.Assert(singleValue is not null ^ multiValue is not null); - if (singleValue is not null) - { - vsb.Append(singleValue); - } - else - { - // Note that if we get multiple values for a header that doesn't support multiple values, we'll - // just separate the values using a comma (default separator). - string? separator = header.Key.Parser is HttpHeaderParser parser && parser.SupportsMultipleValues ? parser.Separator : HttpHeaderParser.DefaultSeparator; + if (singleValue is not null) + { + vsb.Append(singleValue); + } + else + { + // Note that if we get multiple values for a header that doesn't support multiple values, we'll + // just separate the values using a comma (default separator). + string? separator = entry.Key.Parser is HttpHeaderParser parser && parser.SupportsMultipleValues ? parser.Separator : HttpHeaderParser.DefaultSeparator; - for (int i = 0; i < multiValue!.Length; i++) - { - if (i != 0) vsb.Append(separator); - vsb.Append(multiValue[i]); - } + Debug.Assert(multiValue is not null && multiValue.Length > 0); + vsb.Append(multiValue[0]); + for (int i = 1; i < multiValue.Length; i++) + { + vsb.Append(separator); + vsb.Append(multiValue[i]); } - - vsb.Append(Environment.NewLine); } + + vsb.Append(Environment.NewLine); } return vsb.ToString(); @@ -292,39 +309,57 @@ internal string GetHeaderString(HeaderDescriptor descriptor) #region IEnumerable>> Members - public IEnumerator>> GetEnumerator() => _headerStore != null && _headerStore.Count > 0 ? - GetEnumeratorCore() : - ((IEnumerable>>)Array.Empty>>()).GetEnumerator(); + public IEnumerator>> GetEnumerator() => _count == 0 ? + ((IEnumerable>>)Array.Empty>>()).GetEnumerator() : + GetEnumeratorCore(); private IEnumerator>> GetEnumeratorCore() { - foreach (KeyValuePair header in _headerStore!) + HeaderEntry[]? entries = GetEntriesArray(); + Debug.Assert(_count != 0 && entries is not null, "Caller should have validated the collection is not empty"); + + int count = _count; + for (int i = 0; i < count; i++) { - HeaderDescriptor descriptor = header.Key; - object value = header.Value; + HeaderEntry entry = entries[i]; - HeaderStoreItemInfo? info = value as HeaderStoreItemInfo; - if (info is null) + if (entry.Value is not HeaderStoreItemInfo info) { // To retain consistent semantics, we need to upgrade a raw string to a HeaderStoreItemInfo // during enumeration so that we can parse the raw value in order to a) return // the correct set of parsed values, and b) update the instance for subsequent enumerations // to reflect that parsing. - _headerStore[descriptor] = info = new HeaderStoreItemInfo() { RawValue = value }; + info = new HeaderStoreItemInfo() { RawValue = entry.Value }; + + if (EntriesAreLiveView) + { + entries[i].Value = info; + } + else + { + Debug.Assert(ContainsKey(entry.Key)); + ((Dictionary)_headerStore!)[entry.Key] = info; + } } // Make sure we parse all raw values before returning the result. Note that this has to be // done before we calculate the array length (next line): A raw value may contain a list of // values. - if (!ParseRawHeaderValues(descriptor, info, removeEmptyHeader: false)) + if (!ParseRawHeaderValues(entry.Key, info)) { - // We have an invalid header value (contains newline chars). Delete it. - _headerStore.Remove(descriptor); + // We saw an invalid header value (contains newline chars) and deleted it. + + // If the HeaderEntry[] we are enumerating is the live header store, the entries have shifted. + if (EntriesAreLiveView) + { + i--; + count--; + } } else { - string[] values = GetStoreValuesAsStringArray(descriptor, info); - yield return new KeyValuePair>(descriptor.Name, values); + string[] values = GetStoreValuesAsStringArray(entry.Key, info); + yield return new KeyValuePair>(entry.Key.Name, values); } } } @@ -342,7 +377,7 @@ internal void AddParsedValue(HeaderDescriptor descriptor, object value) Debug.Assert(value != null); Debug.Assert(descriptor.Parser != null, "Can't add parsed value if there is no parser available."); - HeaderStoreItemInfo info = GetOrCreateHeaderInfo(descriptor, parseRawValues: true); + HeaderStoreItemInfo info = GetOrCreateHeaderInfo(descriptor); // If the current header has only one value, we can't add another value. The strongly typed property // must not call AddParsedValue(), but SetParsedValue(). E.g. for headers like 'Date', 'Host'. @@ -358,7 +393,7 @@ internal void SetParsedValue(HeaderDescriptor descriptor, object value) // This method will first clear all values. This is used e.g. when setting the 'Date' or 'Host' header. // i.e. headers not supporting collections. - HeaderStoreItemInfo info = GetOrCreateHeaderInfo(descriptor, parseRawValues: true); + HeaderStoreItemInfo info = GetOrCreateHeaderInfo(descriptor); info.InvalidValue = null; info.ParsedValue = null; @@ -381,17 +416,10 @@ internal void SetOrRemoveParsedValue(HeaderDescriptor descriptor, object? value) public bool Remove(string name) => Remove(GetHeaderDescriptor(name)); - internal bool Remove(HeaderDescriptor descriptor) => _headerStore != null && _headerStore.Remove(descriptor); - internal bool RemoveParsedValue(HeaderDescriptor descriptor, object value) { Debug.Assert(value != null); - if (_headerStore == null) - { - return false; - } - // If we have a value for this header, then verify if we have a single value. If so, compare that // value with 'item'. If we have a list of values, then remove 'item' from the list. if (TryGetAndParseHeaderInfo(descriptor, out HeaderStoreItemInfo? info)) @@ -462,11 +490,6 @@ internal bool ContainsParsedValue(HeaderDescriptor descriptor, object value) { Debug.Assert(value != null); - if (_headerStore == null) - { - return false; - } - // If we have a value for this header, then verify if we have a single value. If so, compare that // value with 'item'. If we have a list of values, then compare each item in the list with 'item'. if (TryGetAndParseHeaderInfo(descriptor, out HeaderStoreItemInfo? info)) @@ -517,41 +540,58 @@ internal virtual void AddHeaders(HttpHeaders sourceHeaders) Debug.Assert(sourceHeaders != null); Debug.Assert(GetType() == sourceHeaders.GetType(), "Can only copy headers from an instance of the same type."); - Dictionary? sourceHeadersStore = sourceHeaders._headerStore; - if (sourceHeadersStore is null || sourceHeadersStore.Count == 0) + // Only add header values if they're not already set on the message. Note that we don't merge + // collections: If both the default headers and the message have set some values for a certain + // header, then we don't try to merge the values. + if (_count == 0 && sourceHeaders._headerStore is HeaderEntry[] sourceEntries) { - return; - } - - _headerStore ??= new Dictionary(); + // If the target collection is empty, we don't have to search for existing values + _count = sourceHeaders._count; + if (_headerStore is not HeaderEntry[] entries || entries.Length < _count) + { + entries = new HeaderEntry[sourceEntries.Length]; + _headerStore = entries; + } - foreach (KeyValuePair header in sourceHeadersStore) - { - // Only add header values if they're not already set on the message. Note that we don't merge - // collections: If both the default headers and the message have set some values for a certain - // header, then we don't try to merge the values. - if (!_headerStore.ContainsKey(header.Key)) + for (int i = 0; i < _count && i < sourceEntries.Length; i++) { - object sourceValue = header.Value; - if (sourceValue is HeaderStoreItemInfo info) + HeaderEntry entry = sourceEntries[i]; + if (entry.Value is HeaderStoreItemInfo info) { - AddHeaderInfo(header.Key, info); + entry.Value = CloneHeaderInfo(entry.Key, info); } - else + entries[i] = entry; + } + } + else + { + foreach (HeaderEntry entry in sourceHeaders.GetEntries()) + { + ref object? storeValueRef = ref GetValueRefOrAddDefault(entry.Key); + if (storeValueRef is null) { - Debug.Assert(sourceValue is string); - _headerStore.Add(header.Key, sourceValue); + object sourceValue = entry.Value; + if (sourceValue is HeaderStoreItemInfo info) + { + storeValueRef = CloneHeaderInfo(entry.Key, info); + } + else + { + Debug.Assert(sourceValue is string); + storeValueRef = sourceValue; + } } } } } - private void AddHeaderInfo(HeaderDescriptor descriptor, HeaderStoreItemInfo sourceInfo) + private HeaderStoreItemInfo CloneHeaderInfo(HeaderDescriptor descriptor, HeaderStoreItemInfo sourceInfo) { - HeaderStoreItemInfo destinationInfo = CreateAndAddHeaderToStore(descriptor); - - // Always copy raw values - destinationInfo.RawValue = CloneStringHeaderInfoValues(sourceInfo.RawValue); + var destinationInfo = new HeaderStoreItemInfo + { + // Always copy raw values + RawValue = CloneStringHeaderInfoValues(sourceInfo.RawValue) + }; if (descriptor.Parser == null) { @@ -585,6 +625,8 @@ private void AddHeaderInfo(HeaderDescriptor descriptor, HeaderStoreItemInfo sour } } } + + return destinationInfo; } private static void CloneAndAddValue(HeaderStoreItemInfo destinationInfo, object source) @@ -623,74 +665,54 @@ private static void CloneAndAddValue(HeaderStoreItemInfo destinationInfo, object } } - private HeaderStoreItemInfo GetOrCreateHeaderInfo(HeaderDescriptor descriptor, bool parseRawValues) + private HeaderStoreItemInfo GetOrCreateHeaderInfo(HeaderDescriptor descriptor) { - HeaderStoreItemInfo? result = null; - bool found; - if (parseRawValues) + if (TryGetAndParseHeaderInfo(descriptor, out HeaderStoreItemInfo? info)) { - found = TryGetAndParseHeaderInfo(descriptor, out result); + return info; } else { - found = TryGetHeaderValue(descriptor, out object? value); - if (found) - { - if (value is HeaderStoreItemInfo hsti) - { - result = hsti; - } - else - { - Debug.Assert(value is string); - _headerStore![descriptor] = result = new HeaderStoreItemInfo { RawValue = value }; - } - } - } - - if (!found) - { - result = CreateAndAddHeaderToStore(descriptor); + return CreateAndAddHeaderToStore(descriptor); } - - Debug.Assert(result != null); - return result; } private HeaderStoreItemInfo CreateAndAddHeaderToStore(HeaderDescriptor descriptor) { + Debug.Assert(!ContainsKey(descriptor)); + // If we don't have the header in the store yet, add it now. HeaderStoreItemInfo result = new HeaderStoreItemInfo(); // If the descriptor header type is in _treatAsCustomHeaderTypes, it must be converted to a custom header before calling this method Debug.Assert((descriptor.HeaderType & _treatAsCustomHeaderTypes) == 0); - AddHeaderToStore(descriptor, result); + AddEntryToStore(new HeaderEntry(descriptor, result)); return result; } - private void AddHeaderToStore(HeaderDescriptor descriptor, object value) - { - Debug.Assert(value is string || value is HeaderStoreItemInfo); - (_headerStore ??= new Dictionary()).Add(descriptor, value); - } - internal bool TryGetHeaderValue(HeaderDescriptor descriptor, [NotNullWhen(true)] out object? value) { - if (_headerStore == null) + ref object storeValueRef = ref GetValueRefOrNullRef(descriptor); + if (Unsafe.IsNullRef(ref storeValueRef)) { value = null; return false; } - - return _headerStore.TryGetValue(descriptor, out value); + else + { + value = storeValueRef; + return true; + } } private bool TryGetAndParseHeaderInfo(HeaderDescriptor key, [NotNullWhen(true)] out HeaderStoreItemInfo? info) { - if (TryGetHeaderValue(key, out object? value)) + ref object storeValueRef = ref GetValueRefOrNullRef(key); + if (!Unsafe.IsNullRef(ref storeValueRef)) { + object value = storeValueRef; if (value is HeaderStoreItemInfo hsi) { info = hsi; @@ -698,20 +720,21 @@ private bool TryGetAndParseHeaderInfo(HeaderDescriptor key, [NotNullWhen(true)] else { Debug.Assert(value is string); - _headerStore![key] = info = new HeaderStoreItemInfo() { RawValue = value }; + storeValueRef = info = new HeaderStoreItemInfo() { RawValue = value }; } - return ParseRawHeaderValues(key, info, removeEmptyHeader: true); + return ParseRawHeaderValues(key, info); } info = null; return false; } - private bool ParseRawHeaderValues(HeaderDescriptor descriptor, HeaderStoreItemInfo info, bool removeEmptyHeader) + private bool ParseRawHeaderValues(HeaderDescriptor descriptor, HeaderStoreItemInfo info) { // Unlike TryGetHeaderInfo() this method tries to parse all non-validated header values (if any) // before returning to the caller. + Debug.Assert(!info.IsEmpty); if (info.RawValue != null) { List? rawValues = info.RawValue as List; @@ -730,16 +753,12 @@ private bool ParseRawHeaderValues(HeaderDescriptor descriptor, HeaderStoreItemIn info.RawValue = null; // During parsing, we removed the value since it contains newline chars. Return false to indicate that - // this is an empty header. If the caller specified to remove empty headers, we'll remove the header before - // returning. + // this is an empty header. if ((info.InvalidValue == null) && (info.ParsedValue == null)) { - if (removeEmptyHeader) - { - // After parsing the raw value, no value is left because all values contain newline chars. - Debug.Assert(_headerStore != null); - _headerStore.Remove(descriptor); - } + // After parsing the raw value, no value is left because all values contain newline chars. + Debug.Assert(_count > 0); + Remove(descriptor); return false; } } @@ -808,7 +827,8 @@ internal bool TryParseAndAddValue(HeaderDescriptor descriptor, string? value) { // If we get here, then the value could be parsed correctly. If we created a new HeaderStoreItemInfo, add // it to the store if we added at least one value. - AddHeaderToStore(descriptor, info); + Debug.Assert(!ContainsKey(descriptor)); + AddEntryToStore(new HeaderEntry(descriptor, info)); } return result; @@ -1061,12 +1081,14 @@ internal bool TryGetHeaderDescriptor(string name, out HeaderDescriptor descripto if (HeaderDescriptor.TryGet(name, out descriptor)) { - if ((descriptor.HeaderType & _allowedHeaderTypes) != 0) + HttpHeaderType headerType = descriptor.HeaderType; + + if ((headerType & _allowedHeaderTypes) != 0) { return true; } - if ((descriptor.HeaderType & _treatAsCustomHeaderTypes) != 0) + if ((headerType & _treatAsCustomHeaderTypes) != 0) { descriptor = descriptor.AsCustomHeader(); return true; @@ -1128,7 +1150,8 @@ internal static void GetStoreValuesAsStringOrStringArray(HeaderDescriptor descri } else { - values = multiValue = length != 0 ? new string[length] : Array.Empty(); + Debug.Assert(length > 1, "The header should have been removed when it became empty"); + values = multiValue = new string[length]; } int currentIndex = 0; @@ -1252,5 +1275,234 @@ internal bool CanAddParsedValue(HttpHeaderParser parser) internal bool IsEmpty => (RawValue == null) && (InvalidValue == null) && (ParsedValue == null); } + + + #region Low-level implementation details that work with _headerStore directly + + // Used to store the CollectionsMarshal.GetValueRefOrAddDefault out parameter. + // This is a workaround for the Roslyn bug where we can't use a discard instead: + // https://github.com/dotnet/roslyn/issues/56587#issuecomment-934955526 + private static bool s_dictionaryGetValueRefOrAddDefaultExistsDummy; + + private const int InitialCapacity = 4; + internal const int ArrayThreshold = 64; // Above this threshold, header ordering will not be preserved + + internal HeaderEntry[]? GetEntriesArray() + { + object? store = _headerStore; + if (store is null) + { + return null; + } + else if (store is HeaderEntry[] entries) + { + return entries; + } + else + { + return GetEntriesFromDictionary(); + } + + HeaderEntry[] GetEntriesFromDictionary() + { + var dictionary = (Dictionary)_headerStore!; + var entries = new HeaderEntry[dictionary.Count]; + int i = 0; + foreach (KeyValuePair entry in dictionary) + { + entries[i++] = new HeaderEntry + { + Key = entry.Key, + Value = entry.Value + }; + } + return entries; + } + } + + internal ReadOnlySpan GetEntries() + { + return new ReadOnlySpan(GetEntriesArray(), 0, _count); + } + + internal int Count => _count; + + private bool EntriesAreLiveView => _headerStore is HeaderEntry[]; + + private ref object GetValueRefOrNullRef(HeaderDescriptor key) + { + ref object valueRef = ref Unsafe.NullRef(); + + object? store = _headerStore; + if (store is HeaderEntry[] entries) + { + for (int i = 0; i < _count && i < entries.Length; i++) + { + if (key.Equals(entries[i].Key)) + { + valueRef = ref entries[i].Value; + break; + } + } + } + else if (store is not null) + { + valueRef = ref CollectionsMarshal.GetValueRefOrNullRef(Unsafe.As>(store), key); + } + + return ref valueRef; + } + + private ref object? GetValueRefOrAddDefault(HeaderDescriptor key) + { + object? store = _headerStore; + if (store is HeaderEntry[] entries) + { + for (int i = 0; i < _count && i < entries.Length; i++) + { + if (key.Equals(entries[i].Key)) + { + return ref entries[i].Value!; + } + } + + int count = _count; + _count++; + if ((uint)count < (uint)entries.Length) + { + entries[count].Key = key; + return ref entries[count].Value!; + } + + return ref GrowEntriesAndAddDefault(key); + } + else if (store is null) + { + _count++; + entries = new HeaderEntry[InitialCapacity]; + _headerStore = entries; + ref HeaderEntry firstEntry = ref MemoryMarshal.GetArrayDataReference(entries); + firstEntry.Key = key; + return ref firstEntry.Value!; + } + else + { + return ref DictionaryGetValueRefOrAddDefault(key); + } + + ref object? GrowEntriesAndAddDefault(HeaderDescriptor key) + { + var entries = (HeaderEntry[])_headerStore!; + if (entries.Length == ArrayThreshold) + { + return ref ConvertToDictionaryAndAddDefault(key); + } + else + { + Array.Resize(ref entries, entries.Length << 1); + _headerStore = entries; + ref HeaderEntry firstNewEntry = ref entries[entries.Length >> 1]; + firstNewEntry.Key = key; + return ref firstNewEntry.Value!; + } + } + + ref object? ConvertToDictionaryAndAddDefault(HeaderDescriptor key) + { + var entries = (HeaderEntry[])_headerStore!; + var dictionary = new Dictionary(ArrayThreshold); + _headerStore = dictionary; + foreach (HeaderEntry entry in entries) + { + dictionary.Add(entry.Key, entry.Value); + } + Debug.Assert(dictionary.Count == _count - 1); + return ref CollectionsMarshal.GetValueRefOrAddDefault(dictionary, key, out s_dictionaryGetValueRefOrAddDefaultExistsDummy); + } + + ref object? DictionaryGetValueRefOrAddDefault(HeaderDescriptor key) + { + var dictionary = (Dictionary)_headerStore!; + ref object? value = ref CollectionsMarshal.GetValueRefOrAddDefault(dictionary, key, out s_dictionaryGetValueRefOrAddDefaultExistsDummy); + if (value is null) + { + _count++; + } + return ref value; + } + } + + private void AddEntryToStore(HeaderEntry entry) + { + Debug.Assert(!ContainsKey(entry.Key)); + + if (_headerStore is HeaderEntry[] entries) + { + int count = _count; + if ((uint)count < (uint)entries.Length) + { + entries[count] = entry; + _count++; + return; + } + } + + GetValueRefOrAddDefault(entry.Key) = entry.Value; + } + + internal bool ContainsKey(HeaderDescriptor key) + { + return !Unsafe.IsNullRef(ref GetValueRefOrNullRef(key)); + } + + public void Clear() + { + if (_headerStore is HeaderEntry[] entries) + { + Array.Clear(entries, 0, _count); + } + else + { + _headerStore = null; + } + _count = 0; + } + + internal bool Remove(HeaderDescriptor key) + { + bool removed = false; + + object? store = _headerStore; + if (store is HeaderEntry[] entries) + { + for (int i = 0; i < _count && i < entries.Length; i++) + { + if (key.Equals(entries[i].Key)) + { + while (i + 1 < _count && (uint)(i + 1) < (uint)entries.Length) + { + entries[i] = entries[i + 1]; + i++; + } + entries[i] = default; + removed = true; + break; + } + } + } + else if (store is not null) + { + removed = Unsafe.As>(store).Remove(key); + } + + if (removed) + { + _count--; + } + + return removed; + } + + #endregion // _headerStore implementation } } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs index 5e67476d116a99..b2d2f5c3a8727b 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/Headers/HttpHeadersNonValidated.cs @@ -25,7 +25,7 @@ namespace System.Net.Http.Headers /// Gets the number of headers stored in the collection. /// Multiple header values associated with the same header name are considered to be one header as far as this count is concerned. - public int Count => _headers?.HeaderStore?.Count ?? 0; + public int Count => _headers?.Count ?? 0; /// Gets whether the collection contains the specified header. /// The name of the header. @@ -33,7 +33,7 @@ namespace System.Net.Http.Headers public bool Contains(string headerName) => _headers is HttpHeaders headers && headers.TryGetHeaderDescriptor(headerName, out HeaderDescriptor descriptor) && - headers.TryGetHeaderValue(descriptor, out _); + headers.ContainsKey(descriptor); /// Gets the values for the specified header name. /// The name of the header. @@ -83,8 +83,8 @@ public bool TryGetValues(string headerName, out HeaderStringValues values) /// Gets an enumerator that iterates through the . /// An enumerator that iterates through the . public Enumerator GetEnumerator() => - _headers is HttpHeaders headers && headers.HeaderStore is Dictionary store ? - new Enumerator(store.GetEnumerator()) : + _headers is HttpHeaders headers && headers.GetEntriesArray() is HeaderEntry[] entries ? + new Enumerator(entries, headers.Count) : default; /// @@ -120,35 +120,34 @@ IEnumerable IReadOnlyDictionary. /// Enumerates the elements of a . public struct Enumerator : IEnumerator> { - /// The wrapped enumerator for the underlying headers dictionary. - private Dictionary.Enumerator _headerStoreEnumerator; - /// The current value. + private readonly HeaderEntry[] _entries; + private readonly int _numberOfEntries; + private int _index; private KeyValuePair _current; - /// true if the enumerator was constructed via the ctor; otherwise, false./ - private bool _valid; - /// Initializes the enumerator. - /// The underlying dictionary enumerator. - internal Enumerator(Dictionary.Enumerator headerStoreEnumerator) + internal Enumerator(HeaderEntry[] entries, int numberOfEntries) { - _headerStoreEnumerator = headerStoreEnumerator; + _entries = entries; + _numberOfEntries = numberOfEntries; + _index = 0; _current = default; - _valid = true; } /// public bool MoveNext() { - if (_valid && _headerStoreEnumerator.MoveNext()) + int index = _index; + if (_entries is HeaderEntry[] entries && index < _numberOfEntries && (uint)index < (uint)entries.Length) { - KeyValuePair current = _headerStoreEnumerator.Current; + HeaderEntry entry = entries[index]; + _index++; - HttpHeaders.GetStoreValuesAsStringOrStringArray(current.Key, current.Value, out string? singleValue, out string[]? multiValue); + HttpHeaders.GetStoreValuesAsStringOrStringArray(entry.Key, entry.Value, out string? singleValue, out string[]? multiValue); Debug.Assert(singleValue is not null ^ multiValue is not null); _current = new KeyValuePair( - current.Key.Name, - singleValue is not null ? new HeaderStringValues(current.Key, singleValue) : new HeaderStringValues(current.Key, multiValue!)); + entry.Key.Name, + singleValue is not null ? new HeaderStringValues(entry.Key, singleValue) : new HeaderStringValues(entry.Key, multiValue!)); return true; } diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs index 811d33dbfb4ef5..1a1b4b00bbb60f 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http2Connection.cs @@ -1337,15 +1337,10 @@ private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders heade { if (NetEventSource.Log.IsEnabled()) Trace(""); - if (headers.HeaderStore is null) - { - return; - } - HeaderEncodingSelector? encodingSelector = _pool.Settings._requestHeaderEncodingSelector; ref string[]? tmpHeaderValuesArray = ref t_headerValues; - foreach (KeyValuePair header in headers.HeaderStore) + foreach (HeaderEntry header in headers.GetEntries()) { int headerValuesCount = HttpHeaders.GetStoreValuesIntoStringArray(header.Key, header.Value, ref tmpHeaderValuesArray); Debug.Assert(headerValuesCount > 0, "No values for header??"); @@ -1361,7 +1356,7 @@ private void WriteHeaderCollection(HttpRequestMessage request, HttpHeaders heade // The Connection, Upgrade and ProxyConnection headers are also not supported in HTTP2. if (knownHeader != KnownHeaders.Host && knownHeader != KnownHeaders.Connection && knownHeader != KnownHeaders.Upgrade && knownHeader != KnownHeaders.ProxyConnection) { - if (header.Key.KnownHeader == KnownHeaders.TE) + if (knownHeader == KnownHeaders.TE) { // HTTP/2 allows only 'trailers' TE header. rfc7540 8.1.2.2 foreach (string value in headerValues) diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index 3a236ccc65ae38..9a7ebccc416a8e 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -622,14 +622,9 @@ private void BufferHeaders(HttpRequestMessage request) // TODO: special-case Content-Type for static table values values? private void BufferHeaderCollection(HttpHeaders headers) { - if (headers.HeaderStore == null) - { - return; - } - HeaderEncodingSelector? encodingSelector = _connection.Pool.Settings._requestHeaderEncodingSelector; - foreach (KeyValuePair header in headers.HeaderStore) + foreach (HeaderEntry header in headers.GetEntries()) { int headerValuesCount = HttpHeaders.GetStoreValuesIntoStringArray(header.Key, header.Value, ref _headerValues); Debug.Assert(headerValuesCount > 0, "No values for header??"); @@ -645,7 +640,7 @@ private void BufferHeaderCollection(HttpHeaders headers) // The Connection, Upgrade and ProxyConnection headers are also not supported in HTTP/3. if (knownHeader != KnownHeaders.Host && knownHeader != KnownHeaders.Connection && knownHeader != KnownHeaders.Upgrade && knownHeader != KnownHeaders.ProxyConnection) { - if (header.Key.KnownHeader == KnownHeaders.TE) + if (knownHeader == KnownHeaders.TE) { // HTTP/2 allows only 'trailers' TE header. rfc7540 8.1.2.2 // HTTP/3 does not mention this one way or another; assume it has the same rule. diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs index 12b6b4bf3c55a0..deb0f5d9373d19 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs @@ -258,10 +258,12 @@ private async ValueTask WriteHeadersAsync(HttpHeaders headers, string? cookiesFr { Debug.Assert(_currentRequest != null); - if (headers.HeaderStore != null) + if (headers.GetEntriesArray() is HeaderEntry[] entries) { - foreach (KeyValuePair header in headers.HeaderStore) + for (int i = 0; i < headers.Count; i++) { + HeaderEntry header = entries[i]; + if (header.Key.KnownHeader != null) { await WriteBytesAsync(header.Key.KnownHeader.AsciiBytesWithColonSpace, async).ConfigureAwait(false); @@ -298,10 +300,10 @@ private async ValueTask WriteHeadersAsync(HttpHeaders headers, string? cookiesFr separator = parser.Separator!; } - for (int i = 1; i < headerValuesCount; i++) + for (int j = 1; j < headerValuesCount; j++) { await WriteAsciiStringAsync(separator, async).ConfigureAwait(false); - await WriteStringAsync(_headerValues[i], async, valueEncoding).ConfigureAwait(false); + await WriteStringAsync(_headerValues[j], async, valueEncoding).ConfigureAwait(false); } } } diff --git a/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs b/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs index c3b16afb47360d..d2f81304dd912a 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/HPack/HPackRoundtripTests.cs @@ -60,7 +60,7 @@ private static Memory HPackEncode(HttpHeaders headers, Encoding? valueEnco FillAvailableSpaceWithOnes(buffer); string[] headerValues = Array.Empty(); - foreach (KeyValuePair header in headers.HeaderStore) + foreach (HeaderEntry header in headers.GetEntries()) { int headerValuesCount = HttpHeaders.GetStoreValuesIntoStringArray(header.Key, header.Value, ref headerValues); Assert.InRange(headerValuesCount, 0, int.MaxValue); diff --git a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs index 139422163a95bc..79196e77c75ea9 100644 --- a/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs +++ b/src/libraries/System.Net.Http/tests/UnitTests/Headers/HttpHeadersTest.cs @@ -1710,6 +1710,24 @@ public void GetEnumerator_UseExplicitInterfaceImplementation_EnumeratorReturnsNo Assert.False(enumerator.MoveNext(), "Only 2 values expected, but enumerator returns a third one."); } + [Fact] + public void GetEnumerator_InvalidValueBetweenValidHeaders_EnumeratorReturnsAllValidValuesAndRemovesInvalidValue() + { + MockHeaders headers = new MockHeaders(); + headers.TryAddWithoutValidation("foo", "fooValue"); + headers.TryAddWithoutValidation("invalid", "invalid\nvalue"); + headers.TryAddWithoutValidation("bar", "barValue"); + + Assert.Equal(3, headers.Count); + + IDictionary> dict = headers.ToDictionary(pair => pair.Key, pair => pair.Value); + Assert.Equal("fooValue", Assert.Single(Assert.Contains("foo", dict))); + Assert.Equal("barValue", Assert.Single(Assert.Contains("bar", dict))); + + Assert.Equal(2, headers.Count); + Assert.False(headers.NonValidated.Contains("invalid")); + } + [Fact] public void AddParsedValue_AddSingleValueToNonExistingHeader_HeaderGetsCreatedAndValueAdded() { @@ -2208,6 +2226,150 @@ public void HeaderStringValues_Constructed_ProducesExpectedResults() } } + [Theory] + [MemberData(nameof(NumberOfHeadersUpToArrayThreshold_AddNonValidated_EnumerateNonValidated))] + public void Add_WithinArrayThresholdHeaders_EnumerationPreservesOrdering(int numberOfHeaders, bool addNonValidated, bool enumerateNonValidated) + { + var headers = new MockHeaders(); + + for (int i = 0; i < numberOfHeaders; i++) + { + if (addNonValidated) + { + headers.TryAddWithoutValidation(i.ToString(), i.ToString()); + } + else + { + headers.Add(i.ToString(), i.ToString()); + } + } + + KeyValuePair[] entries = enumerateNonValidated + ? headers.NonValidated.Select(pair => KeyValuePair.Create(pair.Key, Assert.Single(pair.Value))).ToArray() + : headers.Select(pair => KeyValuePair.Create(pair.Key, Assert.Single(pair.Value))).ToArray(); + + Assert.Equal(numberOfHeaders, entries.Length); + for (int i = 0; i < numberOfHeaders; i++) + { + Assert.Equal(i.ToString(), entries[i].Key); + Assert.Equal(i.ToString(), entries[i].Value); + } + } + + [Fact] + public void Add_Remove_HeaderOrderingIsPreserved() + { + var headers = new MockHeaders(); + headers.Add("a", ""); + headers.Add("b", ""); + headers.Add("c", ""); + + headers.Remove("b"); + + Assert.Equal(new[] { "a", "c" }, headers.Select(pair => pair.Key)); + } + + [Fact] + public void Add_AddToExistingKey_OriginalOrderingIsPreserved() + { + var headers = new MockHeaders(); + headers.Add("a", "a1"); + headers.Add("b", "b1"); + headers.Add("a", "a2"); + + Assert.Equal(new[] { "a", "b" }, headers.Select(pair => pair.Key)); + } + + [Theory] + [InlineData(3)] + [InlineData(4)] + [InlineData(5)] + [InlineData(HttpHeaders.ArrayThreshold / 4)] + [InlineData(HttpHeaders.ArrayThreshold / 2)] + [InlineData(HttpHeaders.ArrayThreshold - 1)] + [InlineData(HttpHeaders.ArrayThreshold)] + [InlineData(HttpHeaders.ArrayThreshold + 1)] + [InlineData(HttpHeaders.ArrayThreshold * 2)] + [InlineData(HttpHeaders.ArrayThreshold * 4)] + public void Add_LargeNumberOfHeaders_OperationsStillSupported(int numberOfHeaders) + { + string[] keys = Enumerable.Range(1, numberOfHeaders).Select(i => i.ToString()).ToArray(); + + var headers = new MockHeaders(); + foreach (string key in keys) + { + Assert.False(headers.NonValidated.Contains(key)); + headers.TryAddWithoutValidation(key, key); + Assert.True(headers.NonValidated.Contains(key)); + } + + string[] nonValidatedKeys = headers.NonValidated.Select(pair => pair.Key).ToArray(); + Assert.Equal(numberOfHeaders, nonValidatedKeys.Length); + + string[] newKeys = headers.Select(pair => pair.Key).ToArray(); + Assert.Equal(numberOfHeaders, newKeys.Length); + + string[] nonValidatedKeysAfterValidation = headers.NonValidated.Select(pair => pair.Key).ToArray(); + Assert.Equal(numberOfHeaders, nonValidatedKeysAfterValidation.Length); + + if (numberOfHeaders > HttpHeaders.ArrayThreshold) + { + // Ordering is lost when adding more than ArrayThreshold headers + Array.Sort(nonValidatedKeys, (a, b) => int.Parse(a).CompareTo(int.Parse(b))); + Array.Sort(newKeys, (a, b) => int.Parse(a).CompareTo(int.Parse(b))); + Array.Sort(nonValidatedKeysAfterValidation, (a, b) => int.Parse(a).CompareTo(int.Parse(b))); + } + Assert.Equal(keys, nonValidatedKeys); + Assert.Equal(keys, newKeys); + Assert.Equal(keys, nonValidatedKeysAfterValidation); + + headers.Add("3", "secondValue"); + Assert.True(headers.TryGetValues("3", out IEnumerable valuesFor3)); + Assert.Equal(new[] { "3", "secondValue" }, valuesFor3); + + Assert.True(headers.TryAddWithoutValidation("invalid", "invalid\nvalue")); + Assert.True(headers.TryAddWithoutValidation("valid", "validValue")); + + Assert.Equal(numberOfHeaders + 2, headers.NonValidated.Count); + + // Remove all headers except for "1", "valid", "invalid" + for (int i = 2; i <= numberOfHeaders; i++) + { + Assert.True(headers.Remove(i.ToString())); + } + + Assert.False(headers.Remove("3")); + + // "1", "invalid", "valid" + Assert.True(headers.NonValidated.Contains("invalid")); + Assert.Equal(3, headers.NonValidated.Count); + + Assert.Equal(new[] { "1", "valid" }, headers.Select(pair => pair.Key).OrderBy(i => i)); + + Assert.Equal(2, headers.NonValidated.Count); + + headers.Clear(); + + Assert.Equal(0, headers.NonValidated.Count); + Assert.Empty(headers); + Assert.False(headers.Contains("3")); + + Assert.True(headers.TryAddWithoutValidation("3", "newValue")); + Assert.True(headers.TryGetValues("3", out valuesFor3)); + Assert.Equal(new[] { "newValue" }, valuesFor3); + } + + public static IEnumerable NumberOfHeadersUpToArrayThreshold_AddNonValidated_EnumerateNonValidated() + { + for (int i = 0; i <= HttpHeaders.ArrayThreshold; i++) + { + yield return new object[] { i, false, false }; + yield return new object[] { i, false, true }; + yield return new object[] { i, true, false }; + yield return new object[] { i, true, true }; + } + } + public static IEnumerable GetInvalidHeaderNames() { yield return new object[] { "invalid header" };