Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace Microsoft.Azure.Cosmos
internal class CrossRegionHedgingAvailabilityStrategy : AvailabilityStrategyInternal
{
private const string HedgeContext = "Hedge Context";
private const string ResponseRegion = "Response Region";
private const string HedgeConfig = "Hedge Config";

/// <summary>
/// Latency threshold which activates the first region hedging
Expand All @@ -44,6 +44,8 @@ internal class CrossRegionHedgingAvailabilityStrategy : AvailabilityStrategyInte
/// </summary>
public bool EnableMultiWriteRegionHedge { get; private set; }

private readonly string HedgeConfigText;

/// <summary>
/// Constructor for hedging availability strategy
/// </summary>
Expand All @@ -68,6 +70,8 @@ public CrossRegionHedgingAvailabilityStrategy(
this.Threshold = threshold;
this.ThresholdStep = thresholdStep ?? TimeSpan.FromMilliseconds(-1);
this.EnableMultiWriteRegionHedge = enableMultiWriteRegionHedge;

this.HedgeConfigText = $"t:{this.Threshold.TotalMilliseconds}ms, s:{this.ThresholdStep.TotalMilliseconds}ms, w:{this.EnableMultiWriteRegionHedge}";
}

/// <inheritdoc/>
Expand Down Expand Up @@ -133,121 +137,125 @@ internal override async Task<ResponseMessage> ExecuteAvailabilityStrategyAsync(
? null
: await StreamExtension.AsClonableStreamAsync(request.Content)))
{
IReadOnlyCollection<string> hedgeRegions = client.DocumentClient.GlobalEndpointManager
.GetApplicableRegions(
request.RequestOptions?.ExcludeRegions,
OperationTypeExtensions.IsReadOperation(request.OperationType));
using (RequestMessage nonModifiedRequestClone = request.Clone(trace, clonedBody))
{
IReadOnlyCollection<string> hedgeRegions = client.DocumentClient.GlobalEndpointManager
.GetApplicableRegions(
request.RequestOptions?.ExcludeRegions,
OperationTypeExtensions.IsReadOperation(request.OperationType));

List<Task> requestTasks = new List<Task>(hedgeRegions.Count + 1);
List<Task> requestTasks = new List<Task>(hedgeRegions.Count + 1);

Task<HedgingResponse> primaryRequest = null;
HedgingResponse hedgeResponse = null;

//Send out hedged requests
for (int requestNumber = 0; requestNumber < hedgeRegions.Count; requestNumber++)
{
TimeSpan awaitTime = requestNumber == 0 ? this.Threshold : this.ThresholdStep;
Task<HedgingResponse> primaryRequest = null;
HedgingResponse hedgeResponse = null;

using (CancellationTokenSource timerTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
//Send out hedged requests
for (int requestNumber = 0; requestNumber < hedgeRegions.Count; requestNumber++)
{
CancellationToken timerToken = timerTokenSource.Token;
using (Task hedgeTimer = Task.Delay(awaitTime, timerToken))
{
if (requestNumber == 0)
{
primaryRequest = this.RequestSenderAndResultCheckAsync(
sender,
request,
hedgeRegions.ElementAt(requestNumber),
cancellationToken,
cancellationTokenSource,
trace);

requestTasks.Add(primaryRequest);
}
else
{
Task<HedgingResponse> requestTask = this.CloneAndSendAsync(
sender: sender,
request: request,
clonedBody: clonedBody,
hedgeRegions: hedgeRegions,
requestNumber: requestNumber,
trace: trace,
cancellationToken: cancellationToken,
cancellationTokenSource: cancellationTokenSource);

requestTasks.Add(requestTask);
}

requestTasks.Add(hedgeTimer);

Task completedTask = await Task.WhenAny(requestTasks);
requestTasks.Remove(completedTask);

if (completedTask == hedgeTimer)
{
continue;
}

timerTokenSource.Cancel();
requestTasks.Remove(hedgeTimer);
TimeSpan awaitTime = requestNumber == 0 ? this.Threshold : this.ThresholdStep;

if (completedTask.IsFaulted)
{
AggregateException innerExceptions = completedTask.Exception.Flatten();
}

hedgeResponse = await (Task<HedgingResponse>)completedTask;
if (hedgeResponse.IsNonTransient)
using (CancellationTokenSource timerTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
{
CancellationToken timerToken = timerTokenSource.Token;
using (Task hedgeTimer = Task.Delay(awaitTime, timerToken))
{
cancellationTokenSource.Cancel();
//Take is not inclusive, so we need to add 1 to the request number which starts at 0
((CosmosTraceDiagnostics)hedgeResponse.ResponseMessage.Diagnostics).Value.AddOrUpdateDatum(
HedgeContext,
hedgeRegions.Take(requestNumber + 1));
((CosmosTraceDiagnostics)hedgeResponse.ResponseMessage.Diagnostics).Value.AddOrUpdateDatum(
ResponseRegion,
hedgeResponse.ResponseRegion);
return hedgeResponse.ResponseMessage;
if (requestNumber == 0)
{
primaryRequest = this.RequestSenderAndResultCheckAsync(
sender,
request,
hedgeRegions.ElementAt(requestNumber),
cancellationToken,
cancellationTokenSource,
trace);

requestTasks.Add(primaryRequest);
}
else
{
Task<HedgingResponse> requestTask = this.CloneAndSendAsync(
sender: sender,
request: nonModifiedRequestClone,
clonedBody: clonedBody,
hedgeRegions: hedgeRegions,
requestNumber: requestNumber,
trace: trace,
cancellationToken: cancellationToken,
cancellationTokenSource: cancellationTokenSource);

requestTasks.Add(requestTask);
}

requestTasks.Add(hedgeTimer);

Task completedTask = await Task.WhenAny(requestTasks);
requestTasks.Remove(completedTask);

if (completedTask == hedgeTimer)
{
continue;
}

timerTokenSource.Cancel();
requestTasks.Remove(hedgeTimer);

if (completedTask.IsFaulted)
{
AggregateException innerExceptions = completedTask.Exception.Flatten();
}

hedgeResponse = await (Task<HedgingResponse>)completedTask;
if (hedgeResponse.IsNonTransient)
{
cancellationTokenSource.Cancel();

((CosmosTraceDiagnostics)hedgeResponse.ResponseMessage.Diagnostics).Value.AddOrUpdateDatum(
HedgeConfig,
this.HedgeConfigText);
//Take is not inclusive, so we need to add 1 to the request number which starts at 0
((CosmosTraceDiagnostics)hedgeResponse.ResponseMessage.Diagnostics).Value.AddOrUpdateDatum(
HedgeContext,
hedgeRegions.Take(requestNumber + 1));
return hedgeResponse.ResponseMessage;
}
}
}
}
}

//Wait for a good response from the hedged requests/primary request
Exception lastException = null;
while (requestTasks.Any())
{
Task completedTask = await Task.WhenAny(requestTasks);
requestTasks.Remove(completedTask);
if (completedTask.IsFaulted)
//Wait for a good response from the hedged requests/primary request
Exception lastException = null;
while (requestTasks.Any())
{
AggregateException innerExceptions = completedTask.Exception.Flatten();
lastException = innerExceptions.InnerExceptions.FirstOrDefault();
Task completedTask = await Task.WhenAny(requestTasks);
requestTasks.Remove(completedTask);
if (completedTask.IsFaulted)
{
AggregateException innerExceptions = completedTask.Exception.Flatten();
lastException = innerExceptions.InnerExceptions.FirstOrDefault();
}

hedgeResponse = await (Task<HedgingResponse>)completedTask;
if (hedgeResponse.IsNonTransient || requestTasks.Count == 0)
{
cancellationTokenSource.Cancel();
((CosmosTraceDiagnostics)hedgeResponse.ResponseMessage.Diagnostics).Value.AddOrUpdateDatum(
HedgeConfig,
this.HedgeConfigText);
((CosmosTraceDiagnostics)hedgeResponse.ResponseMessage.Diagnostics).Value.AddOrUpdateDatum(
HedgeContext,
hedgeRegions);
return hedgeResponse.ResponseMessage;
}
}

hedgeResponse = await (Task<HedgingResponse>)completedTask;
if (hedgeResponse.IsNonTransient || requestTasks.Count == 0)
if (lastException != null)
{
cancellationTokenSource.Cancel();
((CosmosTraceDiagnostics)hedgeResponse.ResponseMessage.Diagnostics).Value.AddOrUpdateDatum(
HedgeContext,
hedgeRegions);
((CosmosTraceDiagnostics)hedgeResponse.ResponseMessage.Diagnostics).Value.AddOrUpdateDatum(
ResponseRegion,
hedgeResponse.ResponseRegion);
return hedgeResponse.ResponseMessage;
throw lastException;
}
}

if (lastException != null)
{
throw lastException;
Debug.Assert(hedgeResponse != null);
return hedgeResponse.ResponseMessage;
}

Debug.Assert(hedgeResponse != null);
return hedgeResponse.ResponseMessage;
}
}
}
Expand Down Expand Up @@ -303,12 +311,12 @@ private async Task<HedgingResponse> RequestSenderAndResultCheckAsync(
cancellationTokenSource.Cancel();
}

return new HedgingResponse(true, response, hedgedRegion);
return new HedgingResponse(true, response);
}

return new HedgingResponse(false, response, hedgedRegion);
return new HedgingResponse(false, response);
}
catch (OperationCanceledException oce ) when (cancellationTokenSource.IsCancellationRequested)
catch (OperationCanceledException oce) when (cancellationTokenSource.IsCancellationRequested)
{
throw new CosmosOperationCanceledException(oce, trace);
}
Expand Down Expand Up @@ -348,13 +356,11 @@ private sealed class HedgingResponse
{
public readonly bool IsNonTransient;
public readonly ResponseMessage ResponseMessage;
public readonly string ResponseRegion;

public HedgingResponse(bool isNonTransient, ResponseMessage responseMessage, string responseRegion)
public HedgingResponse(bool isNonTransient, ResponseMessage responseMessage)
{
this.IsNonTransient = isNonTransient;
this.ResponseMessage = responseMessage;
this.ResponseRegion = responseRegion;
}
}
}
Expand Down
Loading
Loading