Skip to content

Commit 2eb556c

Browse files
update connect timeout
1 parent dcccdd0 commit 2eb556c

3 files changed

Lines changed: 90 additions & 25 deletions

File tree

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Diagnostics;
67
using System.Net;
78
using System.Net.Security;
89
using System.Security.Cryptography.X509Certificates;
10+
using System.Threading;
11+
using System.Threading.Tasks;
912

1013
namespace Microsoft.Data.SqlClient.SNI
1114
{
@@ -194,7 +197,7 @@ internal static bool ValidateSslServerCertificate(string targetServerName, X509C
194197
return true;
195198
}
196199
}
197-
200+
198201
/// <summary>
199202
/// We validate the provided certificate provided by the client with the one from the server to see if it matches.
200203
/// Certificate validation and chain trust validations are done by SSLStream class [System.Net.Security.SecureChannel.VerifyRemoteCertificate method]
@@ -239,6 +242,22 @@ internal static bool ValidateSslServerCertificate(X509Certificate clientCert, X5
239242
}
240243
}
241244

245+
internal static IPAddress[] GetDnsIpAddresses(string serverName, ref TimeSpan timeout)
246+
{
247+
using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses)))
248+
{
249+
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "Getting DNS host entries for serverName {0} with {1} timeout.", args0: serverName, args1: timeout);
250+
using CancellationTokenSource cts = new CancellationTokenSource(timeout);
251+
Stopwatch stopwatch = Stopwatch.StartNew();
252+
// using this overload to support netstandard
253+
Task<IPAddress[]> task = Dns.GetHostAddressesAsync(serverName);
254+
task.ConfigureAwait(false);
255+
task.Wait(cts.Token);
256+
timeout -= stopwatch.Elapsed;
257+
return task.Result;
258+
}
259+
}
260+
242261
internal static IPAddress[] GetDnsIpAddresses(string serverName)
243262
{
244263
using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses)))

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs

Lines changed: 64 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ public SNITCPHandle(
164164
ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts;
165165
}
166166

167+
Stopwatch stopwatch = Stopwatch.StartNew();
168+
167169
bool reportError = true;
168170

169171
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Connecting to serverName {1} and port {2}", args0: _connectionId, args1: serverName, args2: port);
@@ -183,6 +185,11 @@ public SNITCPHandle(
183185
}
184186
catch (Exception ex)
185187
{
188+
TimeSpan timeLeft = ts - stopwatch.Elapsed;
189+
if (!isInfiniteTimeOut && timeLeft <= TimeSpan.Zero)
190+
{
191+
throw;
192+
}
186193
// Retry with cached IP address
187194
if (ex is SocketException || ex is ArgumentException || ex is AggregateException)
188195
{
@@ -214,26 +221,31 @@ public SNITCPHandle(
214221
{
215222
if (parallel)
216223
{
217-
_socket = TryConnectParallel(firstCachedIP, portRetry, ts, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo);
224+
_socket = TryConnectParallel(firstCachedIP, portRetry, timeLeft, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo);
218225
}
219226
else
220227
{
221-
_socket = Connect(firstCachedIP, portRetry, ts, isInfiniteTimeOut, ipPreference, cachedFQDN, ref pendingDNSInfo);
228+
_socket = Connect(firstCachedIP, portRetry, timeLeft, isInfiniteTimeOut, ipPreference, cachedFQDN, ref pendingDNSInfo);
222229
}
223230
}
224231
catch (Exception exRetry)
225232
{
233+
timeLeft = ts - stopwatch.Elapsed;
234+
if (!isInfiniteTimeOut && timeLeft <= TimeSpan.Zero)
235+
{
236+
throw;
237+
}
226238
if (exRetry is SocketException || exRetry is ArgumentNullException
227239
|| exRetry is ArgumentException || exRetry is ArgumentOutOfRangeException || exRetry is AggregateException)
228240
{
229241
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Retrying exception {1}", args0: _connectionId, args1: exRetry?.Message);
230242
if (parallel)
231243
{
232-
_socket = TryConnectParallel(secondCachedIP, portRetry, ts, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo);
244+
_socket = TryConnectParallel(secondCachedIP, portRetry, timeLeft, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo);
233245
}
234246
else
235247
{
236-
_socket = Connect(secondCachedIP, portRetry, ts, isInfiniteTimeOut, ipPreference, cachedFQDN, ref pendingDNSInfo);
248+
_socket = Connect(secondCachedIP, portRetry, timeLeft, isInfiniteTimeOut, ipPreference, cachedFQDN, ref pendingDNSInfo);
237249
}
238250
}
239251
else
@@ -249,6 +261,10 @@ public SNITCPHandle(
249261
throw;
250262
}
251263
}
264+
finally
265+
{
266+
stopwatch.Stop();
267+
}
252268

253269
if (_socket == null || !_socket.Connected)
254270
{
@@ -304,8 +320,11 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i
304320
{
305321
Socket availableSocket = null;
306322
Task<Socket> connectTask;
323+
TimeSpan timeout = ts;
307324

308-
IPAddress[] serverAddresses = SNICommon.GetDnsIpAddresses(hostName);
325+
IPAddress[] serverAddresses = isInfiniteTimeOut
326+
? SNICommon.GetDnsIpAddresses(hostName)
327+
: SNICommon.GetDnsIpAddresses(hostName, ref timeout);
309328

310329
if (serverAddresses.Length > MaxParallelIpAddresses)
311330
{
@@ -338,7 +357,7 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i
338357

339358
connectTask = ParallelConnectAsync(serverAddresses, port);
340359

341-
if (!(isInfiniteTimeOut ? connectTask.Wait(-1) : connectTask.Wait(ts)))
360+
if (!(isInfiniteTimeOut ? connectTask.Wait(-1) : connectTask.Wait(timeout)))
342361
{
343362
callerReportError = false;
344363
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "Connection Id {0} Connection timed out, Exception: {1}", args0: _connectionId, args1: Strings.SNI_ERROR_40);
@@ -349,7 +368,7 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i
349368
availableSocket = connectTask.Result;
350369
return availableSocket;
351370
}
352-
371+
353372
/// <summary>
354373
/// Returns array of IP addresses for the given server name, sorted according to the given preference.
355374
/// </summary>
@@ -389,7 +408,7 @@ private static IEnumerable<IPAddress> GetHostAddressesSortedByPreference(string
389408
}
390409
}
391410
}
392-
411+
393412
// Connect to server with hostName and port.
394413
// The IP information will be collected temporarily as the pendingDNSInfo but is not stored in the DNS cache at this point.
395414
// Only write to the DNS cache when we receive IsSupported flag as true in the Feature Ext Ack from server.
@@ -422,26 +441,44 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
422441
port,
423442
ipAddress.AddressFamily,
424443
isInfiniteTimeout);
425-
444+
426445
bool isConnected;
427446
try // catching SocketException with SocketErrorCode == WouldBlock to run Socket.Select
428447
{
429-
socket.Connect(ipAddress, port);
430-
if (!isInfiniteTimeout)
448+
if (isInfiniteTimeout)
449+
{
450+
socket.Connect(ipAddress, port);
451+
}
452+
else
431453
{
454+
TimeSpan timeLeft = timeout - timeTaken.Elapsed;
455+
if (timeLeft <= TimeSpan.Zero)
456+
{
457+
return null;
458+
}
459+
// Socket.Connect does not support infinite timeouts, so we use Task to simulate it
460+
Task socketConnectTask = new Task(() => socket.Connect(ipAddress, port));
461+
socketConnectTask.ConfigureAwait(false);
462+
socketConnectTask.Start();
463+
if (!socketConnectTask.Wait(timeLeft))
464+
{
465+
throw ADP.TimeoutException($"The socket couldn't connect during the expected {timeLeft} remaining time to connect.");
466+
}
432467
throw SQL.SocketDidNotThrow();
433468
}
434-
469+
435470
isConnected = true;
436471
}
437-
catch (SocketException socketException) when (!isInfiniteTimeout &&
438-
socketException.SocketErrorCode ==
439-
SocketError.WouldBlock)
472+
catch (AggregateException aggregateException) when (!isInfiniteTimeout
473+
&& aggregateException.InnerException is SocketException socketException
474+
&& socketException.SocketErrorCode == SocketError.WouldBlock)
440475
{
441476
// https://github.com/dotnet/SqlClient/issues/826#issuecomment-736224118
442477
// Socket.Select is used because it supports timeouts, while Socket.Connect does not
443478

444-
List<Socket> checkReadLst; List<Socket> checkWriteLst; List<Socket> checkErrorLst;
479+
List<Socket> checkReadLst;
480+
List<Socket> checkWriteLst;
481+
List<Socket> checkErrorLst;
445482

446483
// Repeating Socket.Select several times if our timeout is greater
447484
// than int.MaxValue microseconds because of
@@ -450,9 +487,10 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
450487
do
451488
{
452489
TimeSpan timeLeft = timeout - timeTaken.Elapsed;
453-
454-
if (timeLeft <= TimeSpan.Zero)
490+
if (!isInfiniteTimeout && timeLeft <= TimeSpan.Zero)
491+
{
455492
return null;
493+
}
456494

457495
int socketSelectTimeout =
458496
checked((int)(Math.Min(timeLeft.TotalMilliseconds, int.MaxValue / 1000) * 1000));
@@ -487,11 +525,15 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
487525
return socket;
488526
}
489527
}
490-
catch (SocketException e)
528+
catch (AggregateException aggregateException) when (aggregateException.InnerException is SocketException socketException)
491529
{
492-
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: e?.Message);
530+
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: socketException?.Message);
493531
SqlClientEventSource.Log.TryAdvancedTraceEvent(
494-
$"{nameof(SNITCPHandle)}.{nameof(Connect)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {e}");
532+
$"{nameof(SNITCPHandle)}.{nameof(Connect)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {socketException}");
533+
}
534+
catch (AggregateException aggregateException) when (aggregateException.InnerException is TimeoutException timeoutException)
535+
{
536+
Console.WriteLine(timeoutException); // temporary for testing
495537
}
496538
finally
497539
{
@@ -675,7 +717,7 @@ private bool ValidateServerCertificate(object sender, X509Certificate serverCert
675717
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Certificate will not be validated.", args0: _connectionId);
676718
return true;
677719
}
678-
720+
679721
string serverNameToValidate;
680722
if (!string.IsNullOrEmpty(_hostNameInCertificate))
681723
{

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,13 +189,17 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re
189189
TimeSpan ts = default;
190190
// In case the Timeout is Infinite, we will receive the max value of Int64 as the tick count
191191
// The infinite Timeout is a function of ConnectionString Timeout=0
192-
if (long.MaxValue != timerExpire)
192+
bool isInfiniteTimeout = long.MaxValue == timerExpire;
193+
if (!isInfiniteTimeout)
193194
{
194195
ts = DateTime.FromFileTime(timerExpire) - DateTime.Now;
195196
ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts;
196197
}
197198

198-
IPAddress[] ipAddresses = SNICommon.GetDnsIpAddresses(browserHostname);
199+
IPAddress[] ipAddresses = isInfiniteTimeout
200+
? SNICommon.GetDnsIpAddresses(browserHostname)
201+
: SNICommon.GetDnsIpAddresses(browserHostname, ref ts);
202+
199203
Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve");
200204
IPAddress[] ipv4Addresses = null;
201205
IPAddress[] ipv6Addresses = null;

0 commit comments

Comments
 (0)