Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 59 additions & 7 deletions src/StreamJsonRpc/JsonRpc.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,6 @@ public class JsonRpc : IDisposableObservable, IJsonRpcFormatterCallbacks, IJsonR
/// </summary>
private readonly RpcTargetInfo rpcTargetInfo;

/// <summary>
/// Carries the value from a <see cref="JoinableTaskTokenHeaderName"/> when <see cref="JoinableTaskFactory"/> has not been set.
/// </summary>
private readonly System.Threading.AsyncLocal<string?> joinableTaskTokenWithoutJtf = new();

/// <summary>
/// List of remote RPC targets to call if connection should be relayed.
/// </summary>
Expand Down Expand Up @@ -122,6 +117,12 @@ public class JsonRpc : IDisposableObservable, IJsonRpcFormatterCallbacks, IJsonR
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private JoinableTaskFactory? joinableTaskFactory;

/// <summary>
/// Backing field for the <see cref="JoinableTaskTokenTracker"/> property.
/// </summary>
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private JoinableTaskTokenTracker joinableTaskTracker = JoinableTaskTokenTracker.Default;

/// <summary>
/// Backing field for the <see cref="CancellationStrategy"/> property.
/// </summary>
Expand Down Expand Up @@ -460,6 +461,33 @@ public JoinableTaskFactory? JoinableTaskFactory
}
}

/// <summary>
/// Gets or sets the <see cref="JoinableTaskTokenTracker"/> to use to correlate <see cref="JoinableTask"/> tokens.
/// This property is only applicable when <see cref="JoinableTaskFactory"/> is <see langword="null" />.
/// </summary>
/// <value>Defaults to an instance shared with all other <see cref="JsonRpc"/> instances that do not otherwise set this value explicitly.</value>
/// <remarks>
/// <para>
/// This property is ignored when <see cref="JoinableTaskFactory"/> is set to a non-<see langword="null" /> value.
/// </para>
/// <para>
/// This property should only be set explicitly when in an advanced scenario where one process has many <see cref="JsonRpc"/> instances
/// that interact with multiple remote processes such that avoiding correlating <see cref="JoinableTask"/> tokens across <see cref="JsonRpc"/> instances
/// is undesirable.
/// </para>
/// </remarks>
public JoinableTaskTokenTracker JoinableTaskTracker
{
get => this.joinableTaskTracker;

set
{
Requires.NotNull(value, nameof(value));
this.ThrowIfConfigurationLocked();
this.joinableTaskTracker = value;
}
}

/// <summary>
/// Gets a <see cref="Task"/> that completes when this instance is disposed or when listening has stopped
/// whether by error, disposal or the stream closing.
Expand Down Expand Up @@ -1928,7 +1956,7 @@ private JsonRpcError CreateCancellationResponse(JsonRpcRequest request)
JsonRpcEventSource.Instance.SendingRequest(request.RequestId.NumberIfPossibleForEvent, request.Method, JsonRpcEventSource.GetArgumentsString(request));
}

string? parentToken = this.JoinableTaskFactory is not null ? this.JoinableTaskFactory.Context.Capture() : this.joinableTaskTokenWithoutJtf.Value;
string? parentToken = this.JoinableTaskFactory is not null ? this.JoinableTaskFactory.Context.Capture() : this.JoinableTaskTracker.Token;
if (parentToken is not null)
{
request.TrySetTopLevelProperty(JoinableTaskTokenHeaderName, parentToken);
Expand Down Expand Up @@ -2087,7 +2115,7 @@ private async ValueTask<JsonRpcMessage> DispatchIncomingRequestAsync(JsonRpcRequ
request.TryGetTopLevelProperty<string>(JoinableTaskTokenHeaderName, out string? parentToken);
if (this.JoinableTaskFactory is null)
{
this.joinableTaskTokenWithoutJtf.Value = parentToken;
this.JoinableTaskTracker.Token = parentToken;
}

if (this.JoinableTaskFactory is null || parentToken is null)
Expand Down Expand Up @@ -2693,6 +2721,30 @@ private void ThrowIfConfigurationLocked()
}
}

/// <summary>
/// An object that correlates <see cref="JoinableTask"/> tokens within and between <see cref="JsonRpc"/> instances
/// within a process that does <em>not</em> use <see cref="JoinableTaskFactory"/>,
/// for purposes of mitigating deadlocks in processes that <em>do</em> use <see cref="JoinableTaskFactory"/>.
/// </summary>
public class JoinableTaskTokenTracker
{
/// <summary>
/// The default instance to use.
/// </summary>
internal static readonly JoinableTaskTokenTracker Default = new JoinableTaskTokenTracker();

/// <summary>
/// Carries the value from a <see cref="JoinableTaskTokenHeaderName"/> when <see cref="JoinableTaskFactory"/> has not been set.
/// </summary>
private readonly System.Threading.AsyncLocal<string?> joinableTaskTokenWithoutJtf = new();

internal string? Token
{
get => this.joinableTaskTokenWithoutJtf.Value;
set => this.joinableTaskTokenWithoutJtf.Value = value;
}
}

private class OutstandingCallData
{
internal OutstandingCallData(object taskCompletionSource, Action<JsonRpcMessage?> completionHandler, Type? expectedResultType)
Expand Down
4 changes: 4 additions & 0 deletions src/StreamJsonRpc/netstandard2.0/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
StreamJsonRpc.JsonRpc.JoinableTaskTokenTracker
StreamJsonRpc.JsonRpc.JoinableTaskTokenTracker.JoinableTaskTokenTracker() -> void
StreamJsonRpc.JsonRpc.JoinableTaskTracker.get -> StreamJsonRpc.JsonRpc.JoinableTaskTokenTracker!
StreamJsonRpc.JsonRpc.JoinableTaskTracker.set -> void
4 changes: 4 additions & 0 deletions src/StreamJsonRpc/netstandard2.1/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
StreamJsonRpc.JsonRpc.JoinableTaskTokenTracker
StreamJsonRpc.JsonRpc.JoinableTaskTokenTracker.JoinableTaskTokenTracker() -> void
StreamJsonRpc.JsonRpc.JoinableTaskTracker.get -> StreamJsonRpc.JsonRpc.JoinableTaskTokenTracker!
StreamJsonRpc.JsonRpc.JoinableTaskTracker.set -> void
21 changes: 14 additions & 7 deletions test/StreamJsonRpc.Tests/JsonRpcJsonHeadersTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,16 @@ public async Task Completion_FaultsOnFatalError()
Assert.Same(completion, this.serverRpc.Completion);
}

protected override void InitializeFormattersAndHandlers(bool controlledFlushingClient)
protected override void InitializeFormattersAndHandlers(
Stream serverStream,
Stream clientStream,
out IJsonRpcMessageFormatter serverMessageFormatter,
out IJsonRpcMessageFormatter clientMessageFormatter,
out IJsonRpcMessageHandler serverMessageHandler,
out IJsonRpcMessageHandler clientMessageHandler,
bool controlledFlushingClient)
{
this.clientMessageFormatter = new JsonMessageFormatter
clientMessageFormatter = new JsonMessageFormatter
{
JsonSerializer =
{
Expand All @@ -138,7 +145,7 @@ protected override void InitializeFormattersAndHandlers(bool controlledFlushingC
},
},
};
this.serverMessageFormatter = new JsonMessageFormatter
serverMessageFormatter = new JsonMessageFormatter
{
JsonSerializer =
{
Expand All @@ -150,10 +157,10 @@ protected override void InitializeFormattersAndHandlers(bool controlledFlushingC
},
};

this.serverMessageHandler = new HeaderDelimitedMessageHandler(this.serverStream, this.serverStream, this.serverMessageFormatter);
this.clientMessageHandler = controlledFlushingClient
? new DelayedFlushingHandler(this.clientStream, this.clientMessageFormatter)
: new HeaderDelimitedMessageHandler(this.clientStream, this.clientStream, this.clientMessageFormatter);
serverMessageHandler = new HeaderDelimitedMessageHandler(serverStream, serverStream, serverMessageFormatter);
clientMessageHandler = controlledFlushingClient
? new DelayedFlushingHandler(clientStream, clientMessageFormatter)
: new HeaderDelimitedMessageHandler(clientStream, clientStream, clientMessageFormatter);
}

protected class UnserializableTypeConverter : JsonConverter
Expand Down
27 changes: 16 additions & 11 deletions test/StreamJsonRpc.Tests/JsonRpcJsonHeadersTypeHandlingTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

using System.Text;
using Newtonsoft.Json;
using StreamJsonRpc;
using Xunit.Abstractions;

public class JsonRpcJsonHeadersTypeHandlingTests : JsonRpcJsonHeadersTests
{
Expand All @@ -13,13 +11,20 @@ public JsonRpcJsonHeadersTypeHandlingTests(ITestOutputHelper logger)
{
}

protected override void InitializeFormattersAndHandlers(bool controlledFlushingClient)
protected override void InitializeFormattersAndHandlers(
Stream serverStream,
Stream clientStream,
out IJsonRpcMessageFormatter serverMessageFormatter,
out IJsonRpcMessageFormatter clientMessageFormatter,
out IJsonRpcMessageHandler serverMessageHandler,
out IJsonRpcMessageHandler clientMessageHandler,
bool controlledFlushingClient)
{
this.serverMessageFormatter = new JsonMessageFormatter(new UTF8Encoding(encoderShouldEmitUTF8Identifier: false))
serverMessageFormatter = new JsonMessageFormatter(new UTF8Encoding(encoderShouldEmitUTF8Identifier: false))
{
JsonSerializer =
{
TypeNameHandling = Newtonsoft.Json.TypeNameHandling.Objects,
TypeNameHandling = TypeNameHandling.Objects,
TypeNameAssemblyFormatHandling = TypeNameAssemblyFormatHandling.Simple,
Converters =
{
Expand All @@ -29,11 +34,11 @@ protected override void InitializeFormattersAndHandlers(bool controlledFlushingC
},
};

this.clientMessageFormatter = new JsonMessageFormatter(new UTF8Encoding(encoderShouldEmitUTF8Identifier: false))
clientMessageFormatter = new JsonMessageFormatter(new UTF8Encoding(encoderShouldEmitUTF8Identifier: false))
{
JsonSerializer =
{
TypeNameHandling = Newtonsoft.Json.TypeNameHandling.Objects,
TypeNameHandling = TypeNameHandling.Objects,
TypeNameAssemblyFormatHandling = TypeNameAssemblyFormatHandling.Simple,
Converters =
{
Expand All @@ -43,9 +48,9 @@ protected override void InitializeFormattersAndHandlers(bool controlledFlushingC
},
};

this.serverMessageHandler = new HeaderDelimitedMessageHandler(this.serverStream, this.serverStream, this.serverMessageFormatter);
this.clientMessageHandler = controlledFlushingClient
? new DelayedFlushingHandler(this.clientStream, this.clientMessageFormatter)
: new HeaderDelimitedMessageHandler(this.clientStream, this.clientStream, this.clientMessageFormatter);
serverMessageHandler = new HeaderDelimitedMessageHandler(serverStream, serverStream, serverMessageFormatter);
clientMessageHandler = controlledFlushingClient
? new DelayedFlushingHandler(clientStream, clientMessageFormatter)
: new HeaderDelimitedMessageHandler(clientStream, clientStream, clientMessageFormatter);
}
}
25 changes: 16 additions & 9 deletions test/StreamJsonRpc.Tests/JsonRpcMessagePackLengthTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -383,22 +383,29 @@ public async Task VerboseLoggingDoesNotFailWhenArgsDoNotDeserializePrimitively(b
Assert.True(await clientProxy.IsExtensionArgNonNull(new CustomExtensionType()));
}

protected override void InitializeFormattersAndHandlers(bool controlledFlushingClient)
protected override void InitializeFormattersAndHandlers(
Stream serverStream,
Stream clientStream,
out IJsonRpcMessageFormatter serverMessageFormatter,
out IJsonRpcMessageFormatter clientMessageFormatter,
out IJsonRpcMessageHandler serverMessageHandler,
out IJsonRpcMessageHandler clientMessageHandler,
bool controlledFlushingClient)
{
this.serverMessageFormatter = new MessagePackFormatter();
this.clientMessageFormatter = new MessagePackFormatter();
serverMessageFormatter = new MessagePackFormatter();
clientMessageFormatter = new MessagePackFormatter();

var options = MessagePackFormatter.DefaultUserDataSerializationOptions
.WithResolver(CompositeResolver.Create(
new IMessagePackFormatter[] { new UnserializableTypeFormatter(), new TypeThrowsWhenDeserializedFormatter(), new CustomExtensionFormatter() },
new IFormatterResolver[] { StandardResolverAllowPrivate.Instance }));
((MessagePackFormatter)this.serverMessageFormatter).SetMessagePackSerializerOptions(options);
((MessagePackFormatter)this.clientMessageFormatter).SetMessagePackSerializerOptions(options);
((MessagePackFormatter)serverMessageFormatter).SetMessagePackSerializerOptions(options);
((MessagePackFormatter)clientMessageFormatter).SetMessagePackSerializerOptions(options);

this.serverMessageHandler = new LengthHeaderMessageHandler(this.serverStream, this.serverStream, this.serverMessageFormatter);
this.clientMessageHandler = controlledFlushingClient
? new DelayedFlushingHandler(this.clientStream, this.clientMessageFormatter)
: new LengthHeaderMessageHandler(this.clientStream, this.clientStream, this.clientMessageFormatter);
serverMessageHandler = new LengthHeaderMessageHandler(serverStream, serverStream, serverMessageFormatter);
clientMessageHandler = controlledFlushingClient
? new DelayedFlushingHandler(clientStream, clientMessageFormatter)
: new LengthHeaderMessageHandler(clientStream, clientStream, clientMessageFormatter);
}

[MessagePackObject]
Expand Down
21 changes: 14 additions & 7 deletions test/StreamJsonRpc.Tests/JsonRpcSystemTextJsonHeadersTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,16 @@ public override async Task CanPassExceptionFromServer_ErrorData()
Assert.StrictEqual(COR_E_UNAUTHORIZEDACCESS, errorData.HResult);
}

protected override void InitializeFormattersAndHandlers(bool controlledFlushingClient)
protected override void InitializeFormattersAndHandlers(
Stream serverStream,
Stream clientStream,
out IJsonRpcMessageFormatter serverMessageFormatter,
out IJsonRpcMessageFormatter clientMessageFormatter,
out IJsonRpcMessageHandler serverMessageHandler,
out IJsonRpcMessageHandler clientMessageHandler,
bool controlledFlushingClient)
{
this.clientMessageFormatter = new SystemTextJsonFormatter
clientMessageFormatter = new SystemTextJsonFormatter
{
JsonSerializerOptions =
{
Expand All @@ -37,7 +44,7 @@ protected override void InitializeFormattersAndHandlers(bool controlledFlushingC
},
},
};
this.serverMessageFormatter = new SystemTextJsonFormatter
serverMessageFormatter = new SystemTextJsonFormatter
{
JsonSerializerOptions =
{
Expand All @@ -48,10 +55,10 @@ protected override void InitializeFormattersAndHandlers(bool controlledFlushingC
},
};

this.serverMessageHandler = new HeaderDelimitedMessageHandler(this.serverStream, this.serverStream, this.serverMessageFormatter);
this.clientMessageHandler = controlledFlushingClient
? new DelayedFlushingHandler(this.clientStream, this.clientMessageFormatter)
: new HeaderDelimitedMessageHandler(this.clientStream, this.clientStream, this.clientMessageFormatter);
serverMessageHandler = new HeaderDelimitedMessageHandler(serverStream, serverStream, serverMessageFormatter);
clientMessageHandler = controlledFlushingClient
? new DelayedFlushingHandler(clientStream, clientMessageFormatter)
: new HeaderDelimitedMessageHandler(clientStream, clientStream, clientMessageFormatter);
}

protected class DelayedFlushingHandler : HeaderDelimitedMessageHandler, IControlledFlushHandler
Expand Down
Loading