Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Discord.Net.WebSocket/API/Gateway/ReadyEvent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ public class ReadState
public User User { get; set; }
[JsonProperty("session_id")]
public string SessionId { get; set; }
[JsonProperty("resume_gateway_url")]
public string ResumeGatewayUrl { get; set; }
[JsonProperty("read_state")]
public ReadState[] ReadStates { get; set; }
[JsonProperty("guilds")]
Expand Down
12 changes: 11 additions & 1 deletion src/Discord.Net.WebSocket/DiscordShardedClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ private async Task ResetSemaphoresAsync()

internal override async Task OnLoginAsync(TokenType tokenType, string token)
{
var botGateway = await GetBotGatewayAsync().ConfigureAwait(false);
if (_automaticShards)
{
var botGateway = await GetBotGatewayAsync().ConfigureAwait(false);
_shardIds = Enumerable.Range(0, botGateway.Shards).ToArray();
_totalShards = _shardIds.Length;
_shards = new DiscordSocketClient[_shardIds.Length];
Expand All @@ -163,7 +163,12 @@ internal override async Task OnLoginAsync(TokenType tokenType, string token)

//Assume thread safe: already in a connection lock
for (int i = 0; i < _shards.Length; i++)
{
// Set the gateway URL to the one returned by Discord, if a custom one isn't set.
_shards[i].ApiClient.GatewayUrl = botGateway.Url;

await _shards[i].LoginAsync(tokenType, token);
}

if(_defaultStickers.Length == 0 && _baseConfig.AlwaysDownloadDefaultStickers)
await DownloadDefaultStickersAsync().ConfigureAwait(false);
Expand All @@ -175,7 +180,12 @@ internal override async Task OnLogoutAsync()
if (_shards != null)
{
for (int i = 0; i < _shards.Length; i++)
{
// Reset the gateway URL set for the shard.
_shards[i].ApiClient.GatewayUrl = null;

await _shards[i].LogoutAsync();
}
}

if (_automaticShards)
Expand Down
60 changes: 53 additions & 7 deletions src/Discord.Net.WebSocket/DiscordSocketApiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ internal class DiscordSocketApiClient : DiscordRestApiClient
private readonly bool _isExplicitUrl;
private CancellationTokenSource _connectCancelToken;
private string _gatewayUrl;
private string _resumeGatewayUrl;

//Store our decompression streams for zlib shared state
private MemoryStream _compressed;
Expand All @@ -37,6 +38,32 @@ internal class DiscordSocketApiClient : DiscordRestApiClient

public ConnectionState ConnectionState { get; private set; }

/// <summary>
/// Sets the gateway URL used for identifies.
/// </summary>
/// <remarks>
/// If a custom URL is set, setting this property does nothing.
/// </remarks>
public string GatewayUrl
{
set
{
// Makes the sharded client not override the custom value.
if (_isExplicitUrl)
return;

_gatewayUrl = FormatGatewayUrl(value);
}
}

/// <summary>
/// Sets the gateway URL used for resumes.
/// </summary>
public string ResumeGatewayUrl
{
set => _resumeGatewayUrl = FormatGatewayUrl(value);
}

public DiscordSocketApiClient(RestClientProvider restClientProvider, WebSocketProvider webSocketProvider, string userAgent,
string url = null, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, JsonSerializer serializer = null,
bool useSystemClock = true, Func<IRateLimitInfo, Task> defaultRatelimitCallback = null)
Expand Down Expand Up @@ -157,6 +184,17 @@ internal override ValueTask DisposeAsync(bool disposing)
#endif
}

/// <summary>
/// Appends necessary query parameters to the specified gateway URL.
/// </summary>
private static string FormatGatewayUrl(string gatewayUrl)
{
if (gatewayUrl == null)
return null;

return $"{gatewayUrl}?v={DiscordConfig.APIVersion}&encoding={DiscordSocketConfig.GatewayEncoding}&compress=zlib-stream";
}

public async Task ConnectAsync()
{
await _stateLock.WaitAsync().ConfigureAwait(false);
Expand Down Expand Up @@ -191,24 +229,32 @@ internal override async Task ConnectInternalAsync()
if (WebSocketClient != null)
WebSocketClient.SetCancelToken(_connectCancelToken.Token);

if (!_isExplicitUrl)
string gatewayUrl;
if (_resumeGatewayUrl == null)
{
if (!_isExplicitUrl && _gatewayUrl == null)
{
var gatewayResponse = await GetBotGatewayAsync().ConfigureAwait(false);
_gatewayUrl = FormatGatewayUrl(gatewayResponse.Url);
}

gatewayUrl = _gatewayUrl;
}
else
{
var gatewayResponse = await GetGatewayAsync().ConfigureAwait(false);
_gatewayUrl = $"{gatewayResponse.Url}?v={DiscordConfig.APIVersion}&encoding={DiscordSocketConfig.GatewayEncoding}&compress=zlib-stream";
gatewayUrl = _resumeGatewayUrl;
}

#if DEBUG_PACKETS
Console.WriteLine("Connecting to gateway: " + _gatewayUrl);
Console.WriteLine("Connecting to gateway: " + gatewayUrl);
#endif

await WebSocketClient.ConnectAsync(_gatewayUrl).ConfigureAwait(false);
await WebSocketClient.ConnectAsync(gatewayUrl).ConfigureAwait(false);

ConnectionState = ConnectionState.Connected;
}
catch
{
if (!_isExplicitUrl)
_gatewayUrl = null; //Uncache in case the gateway url changed
await DisconnectInternalAsync().ConfigureAwait(false);
throw;
}
Expand Down
7 changes: 6 additions & 1 deletion src/Discord.Net.WebSocket/DiscordSocketClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,6 @@ private async Task OnConnectingAsync()
}
private async Task OnDisconnectingAsync(Exception ex)
{

await _gatewayLogger.DebugAsync("Disconnecting ApiClient").ConfigureAwait(false);
await ApiClient.DisconnectAsync(ex).ConfigureAwait(false);

Expand Down Expand Up @@ -353,6 +352,10 @@ private async Task OnDisconnectingAsync(Exception ex)
if (guild.IsAvailable)
await GuildUnavailableAsync(guild).ConfigureAwait(false);
}

_sessionId = null;
_lastSeq = 0;
ApiClient.ResumeGatewayUrl = null;
}

/// <inheritdoc />
Expand Down Expand Up @@ -834,6 +837,7 @@ private async Task ProcessMessageAsync(GatewayOpCode opCode, int? seq, string ty

_sessionId = null;
_lastSeq = 0;
ApiClient.ResumeGatewayUrl = null;

if (_shardedClient != null)
{
Expand Down Expand Up @@ -891,6 +895,7 @@ private async Task ProcessMessageAsync(GatewayOpCode opCode, int? seq, string ty
AddPrivateChannel(data.PrivateChannels[i], state);

_sessionId = data.SessionId;
ApiClient.ResumeGatewayUrl = data.ResumeGatewayUrl;
_unavailableGuildCount = unavailableGuilds;
CurrentUser = currentUser;
_previousSessionUser = CurrentUser;
Expand Down