/*
* QUANTCONNECT.COM - Democratizing Finance, Empowering Individuals.
* Lean Algorithmic Trading Engine v2.0. Copyright 2014 QuantConnect Corporation.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
using System;
using System.Linq;
using System.Threading;
using QuantConnect.Data;
using QuantConnect.Util;
using QuantConnect.Logging;
using System.Threading.Tasks;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
namespace QuantConnect.Brokerages
{
    /// <summary>
    /// Handles brokerage data subscriptions with multiple websocket connections, with optional symbol weighting
    /// </summary>
    public class BrokerageMultiWebSocketSubscriptionManager : EventBasedDataQueueHandlerSubscriptionManager, IDisposable
    {
        private readonly string _webSocketUrl;
        private readonly int _maximumSymbolsPerWebSocket;
        private readonly int _maximumWebSocketConnections;
        private readonly Func<IWebSocket> _webSocketFactory;
        private readonly Func<IWebSocket, Symbol, bool> _subscribeFunc;
        private readonly Func<IWebSocket, Symbol, bool> _unsubscribeFunc;
        private readonly Action<WebSocketMessage> _messageHandler;
        private readonly RateGate _connectionRateLimiter;
        // true when no rate limiter was supplied and we created the default one, so Dispose must release it
        private readonly bool _ownsConnectionRateLimiter;
        private readonly System.Timers.Timer _reconnectTimer;

        // milliseconds to wait for a websocket Open event after calling Connect
        private const int ConnectionTimeout = 30000;

        // guards _webSocketEntries
        private readonly object _locker = new();
        private readonly List<BrokerageMultiWebSocketEntry> _webSocketEntries = new();

        /// <summary>
        /// Initializes a new instance of the <see cref="BrokerageMultiWebSocketSubscriptionManager"/> class
        /// </summary>
        /// <param name="webSocketUrl">The URL for websocket connections</param>
        /// <param name="maximumSymbolsPerWebSocket">The maximum number of symbols per websocket connection</param>
        /// <param name="maximumWebSocketConnections">The maximum number of websocket connections allowed.
        /// When greater than zero, symbol weighting is enabled and all connections are created up front;
        /// when zero (or less), symbol weighting is disabled and connections are created on demand</param>
        /// <param name="symbolWeights">A dictionary for the symbol weights, only used when
        /// <paramref name="maximumWebSocketConnections"/> is greater than zero</param>
        /// <param name="webSocketFactory">A function which returns a new websocket instance</param>
        /// <param name="subscribeFunc">A function which subscribes a symbol</param>
        /// <param name="unsubscribeFunc">A function which unsubscribes a symbol</param>
        /// <param name="messageHandler">The websocket message handler</param>
        /// <param name="webSocketConnectionDuration">The maximum duration of the websocket connection, TimeSpan.Zero for no duration limit</param>
        /// <param name="connectionRateLimiter">The rate limiter for creating new websocket connections. When null, a default
        /// limiter (5 connections every 12 seconds) is created and disposed together with this instance</param>
        public BrokerageMultiWebSocketSubscriptionManager(
            string webSocketUrl,
            int maximumSymbolsPerWebSocket,
            int maximumWebSocketConnections,
            Dictionary<Symbol, int> symbolWeights,
            Func<IWebSocket> webSocketFactory,
            Func<IWebSocket, Symbol, bool> subscribeFunc,
            Func<IWebSocket, Symbol, bool> unsubscribeFunc,
            Action<WebSocketMessage> messageHandler,
            TimeSpan webSocketConnectionDuration,
            RateGate connectionRateLimiter = null)
        {
            _webSocketUrl = webSocketUrl;
            _maximumSymbolsPerWebSocket = maximumSymbolsPerWebSocket;
            _maximumWebSocketConnections = maximumWebSocketConnections;
            _webSocketFactory = webSocketFactory;
            _subscribeFunc = subscribeFunc;
            _unsubscribeFunc = unsubscribeFunc;
            _messageHandler = messageHandler;

            // let's use a reasonable default, no API will like to get DOS on reconnections. 50 WS will take 120s
            _ownsConnectionRateLimiter = connectionRateLimiter == null;
            _connectionRateLimiter = connectionRateLimiter ?? new RateGate(5, TimeSpan.FromSeconds(12));

            if (_maximumWebSocketConnections > 0)
            {
                // symbol weighting enabled, create all websocket instances
                for (var i = 0; i < _maximumWebSocketConnections; i++)
                {
                    var webSocket = CreateWebSocket();
                    _webSocketEntries.Add(new BrokerageMultiWebSocketEntry(symbolWeights, webSocket));
                }
            }

            // Some exchanges (e.g. Binance) require a daily restart for websocket connections
            if (webSocketConnectionDuration != TimeSpan.Zero)
            {
                _reconnectTimer = new System.Timers.Timer
                {
                    Interval = webSocketConnectionDuration.TotalMilliseconds
                };
                _reconnectTimer.Elapsed += (_, _) => RestartWebSocketConnections();
                _reconnectTimer.Start();

                Log.Trace($"BrokerageMultiWebSocketSubscriptionManager(): WebSocket connections will be restarted every: {webSocketConnectionDuration}");
            }
        }

        /// <summary>
        /// Subscribes to the symbols
        /// </summary>
        /// <param name="symbols">Symbols to subscribe</param>
        /// <param name="tickType">Type of tick data</param>
        /// <returns>True if every symbol was subscribed successfully</returns>
        /// <remarks>Propagates <see cref="NotSupportedException"/> when every connection is full and the connection
        /// count is capped, and <see cref="TimeoutException"/> when a websocket fails to open in time</remarks>
        protected override bool Subscribe(IEnumerable<Symbol> symbols, TickType tickType)
        {
            // materialize once: the sequence is enumerated both for logging and for subscribing
            var symbolList = symbols.ToList();
            Log.Trace($"BrokerageMultiWebSocketSubscriptionManager.Subscribe(): {string.Join(",", symbolList.Select(x => x.Value))}");

            var success = true;
            foreach (var symbol in symbolList)
            {
                var webSocket = GetWebSocketForSymbol(symbol);
                success &= _subscribeFunc(webSocket, symbol);
            }

            return success;
        }

        /// <summary>
        /// Unsubscribes from the symbols
        /// </summary>
        /// <param name="symbols">Symbols to unsubscribe</param>
        /// <param name="tickType">Type of tick data</param>
        /// <returns>True if every symbol found on a websocket was unsubscribed successfully</returns>
        protected override bool Unsubscribe(IEnumerable<Symbol> symbols, TickType tickType)
        {
            // materialize once: the sequence is enumerated both for logging and for unsubscribing
            var symbolList = symbols.ToList();
            Log.Trace($"BrokerageMultiWebSocketSubscriptionManager.Unsubscribe(): {string.Join(",", symbolList.Select(x => x.Value))}");

            var success = true;
            foreach (var symbol in symbolList)
            {
                var entry = GetWebSocketEntryBySymbol(symbol);
                if (entry != null)
                {
                    entry.RemoveSymbol(symbol);
                    success &= _unsubscribeFunc(entry.WebSocket, symbol);
                }
            }

            return success;
        }

        /// <summary>
        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
        /// </summary>
        public override void Dispose()
        {
            _reconnectTimer?.Stop();
            _reconnectTimer.DisposeSafely();

            lock (_locker)
            {
                foreach (var entry in _webSocketEntries)
                {
                    try
                    {
                        entry.WebSocket.Open -= OnOpen;
                        entry.WebSocket.Message -= OnWebSocketMessage;
                        entry.WebSocket.Close();
                    }
                    catch (Exception ex)
                    {
                        Log.Error(ex);
                    }
                }
                _webSocketEntries.Clear();
            }

            // only release the rate limiter we created ourselves; a caller-supplied one may be shared
            if (_ownsConnectionRateLimiter)
            {
                _connectionRateLimiter.DisposeSafely();
            }
        }

        /// <summary>
        /// Closes and reopens every currently open websocket connection; invoked by the reconnect timer
        /// </summary>
        private void RestartWebSocketConnections()
        {
            List<BrokerageMultiWebSocketEntry> webSocketEntries;
            lock (_locker)
            {
                // let's make a copy so we don't hold the lock
                webSocketEntries = _webSocketEntries.ToList();
            }

            Log.Trace($"BrokerageMultiWebSocketSubscriptionManager(): Restarting {webSocketEntries.Count} websocket connections");

            Parallel.ForEach(webSocketEntries, new ParallelOptions { MaxDegreeOfParallelism = 4 }, entry =>
            {
                if (!entry.WebSocket.IsOpen)
                {
                    return;
                }

                try
                {
                    Log.Trace($"BrokerageMultiWebSocketSubscriptionManager(): Websocket restart - disconnect: ({entry.WebSocket.GetHashCode()})");
                    Disconnect(entry.WebSocket);

                    Log.Trace($"BrokerageMultiWebSocketSubscriptionManager(): Websocket restart - connect: ({entry.WebSocket.GetHashCode()})");
                    Connect(entry.WebSocket);
                }
                catch (Exception exception)
                {
                    // System.Timers.Timer silently swallows exceptions thrown from Elapsed handlers, and an unhandled
                    // exception here would also stop Parallel.ForEach from restarting the remaining connections
                    Log.Error(exception, $"BrokerageMultiWebSocketSubscriptionManager(): Websocket restart failed: ({entry.WebSocket.GetHashCode()})");
                }
            });
        }

        /// <summary>
        /// Returns the first websocket entry currently holding the given symbol, or null if none does
        /// </summary>
        private BrokerageMultiWebSocketEntry GetWebSocketEntryBySymbol(Symbol symbol)
        {
            lock (_locker)
            {
                return _webSocketEntries.FirstOrDefault(entry => entry.Contains(symbol));
            }
        }

        /// <summary>
        /// Adds a symbol to an existing or new websocket connection
        /// </summary>
        /// <remarks>Picks the non-full connection with the lowest total weight. When every connection is full, a new one
        /// is created if connections are uncapped, otherwise <see cref="NotSupportedException"/> is thrown.
        /// The selected websocket is connected (outside the lock, since connecting can block) if it is not open yet.</remarks>
        private IWebSocket GetWebSocketForSymbol(Symbol symbol)
        {
            BrokerageMultiWebSocketEntry entry;
            lock (_locker)
            {
                entry = GetLeastWeightedAvailableEntry();
                if (entry == null)
                {
                    if (_maximumWebSocketConnections > 0)
                    {
                        throw new NotSupportedException($"Maximum symbol count reached for the current configuration [MaxSymbolsPerWebSocket={_maximumSymbolsPerWebSocket}, MaxWebSocketConnections:{_maximumWebSocketConnections}]");
                    }

                    // symbol limit reached on all, create new websocket instance
                    var webSocket = CreateWebSocket();
                    entry = new BrokerageMultiWebSocketEntry(webSocket);
                    _webSocketEntries.Add(entry);
                }
            }

            if (!entry.WebSocket.IsOpen)
            {
                Connect(entry.WebSocket);
            }

            // added after connecting so the OnOpen resubscription does not subscribe this symbol a second time
            entry.AddSymbol(symbol);

            Log.Trace($"BrokerageMultiWebSocketSubscriptionManager.GetWebSocketForSymbol(): added symbol: {symbol} to websocket: {entry.WebSocket.GetHashCode()} - Count: {entry.SymbolCount}");

            return entry.WebSocket;
        }

        /// <summary>
        /// Returns the entry below the per-websocket symbol limit with the lowest total weight (first one on ties),
        /// or null when every entry is full. Must be called while holding <see cref="_locker"/>.
        /// </summary>
        private BrokerageMultiWebSocketEntry GetLeastWeightedAvailableEntry()
        {
            BrokerageMultiWebSocketEntry best = null;
            foreach (var candidate in _webSocketEntries)
            {
                if (candidate.SymbolCount >= _maximumSymbolsPerWebSocket)
                {
                    continue;
                }

                if (best == null || candidate.TotalWeight < best.TotalWeight)
                {
                    best = candidate;
                }
            }

            return best;
        }

        /// <summary>
        /// When we create a websocket we will subscribe to it's events once and initialize it
        /// </summary>
        /// <remarks>Note that the websocket is not connected yet</remarks>
        private IWebSocket CreateWebSocket()
        {
            var webSocket = _webSocketFactory();
            webSocket.Open += OnOpen;
            webSocket.Message += OnWebSocketMessage;
            webSocket.Initialize(_webSocketUrl);

            return webSocket;
        }

        /// <summary>
        /// Forwards every websocket message to the user supplied message handler
        /// </summary>
        private void OnWebSocketMessage(object _, WebSocketMessage message)
        {
            _messageHandler(message);
        }

        /// <summary>
        /// Connects the websocket, honoring the connection rate limiter, and blocks until its Open event fires
        /// </summary>
        /// <exception cref="TimeoutException">The Open event did not fire within <see cref="ConnectionTimeout"/> milliseconds</exception>
        private void Connect(IWebSocket webSocket)
        {
            using var connectedEvent = new ManualResetEvent(false);
            EventHandler onOpenAction = (_, _) =>
            {
                connectedEvent.Set();
            };

            webSocket.Open += onOpenAction;
            try
            {
                // waiting inside the try guarantees the temporary handler is removed even if the rate gate throws
                _connectionRateLimiter.WaitToProceed();

                webSocket.Connect();

                if (!connectedEvent.WaitOne(ConnectionTimeout))
                {
                    throw new TimeoutException($"BrokerageMultiWebSocketSubscriptionManager.Connect(): WebSocket connection timeout: {webSocket.GetHashCode()}");
                }
            }
            finally
            {
                webSocket.Open -= onOpenAction;
            }
        }

        /// <summary>
        /// Closes the websocket connection
        /// </summary>
        private void Disconnect(IWebSocket webSocket)
        {
            webSocket.Close();
        }

        /// <summary>
        /// Handles the Open event of every managed websocket by resubscribing, on a background task,
        /// all symbols already assigned to that websocket
        /// </summary>
        private void OnOpen(object sender, EventArgs e)
        {
            var webSocket = (IWebSocket)sender;

            List<Symbol> symbolsToResubscribe = null;
            lock (_locker)
            {
                foreach (var entry in _webSocketEntries)
                {
                    if (entry.WebSocket == webSocket)
                    {
                        // snapshot now: the resubscription runs on another thread while
                        // subscribe/unsubscribe calls may keep modifying the entry
                        symbolsToResubscribe = entry.Symbols.ToList();
                        break;
                    }
                }
            }

            if (symbolsToResubscribe == null || symbolsToResubscribe.Count == 0)
            {
                return;
            }

            Log.Trace($"BrokerageMultiWebSocketSubscriptionManager.Connect(): WebSocket opened: {webSocket.GetHashCode()} - Resubscribing existing symbols: {symbolsToResubscribe.Count}");

            Task.Factory.StartNew(() =>
            {
                foreach (var symbol in symbolsToResubscribe)
                {
                    try
                    {
                        _subscribeFunc(webSocket, symbol);
                    }
                    catch (Exception exception)
                    {
                        // exceptions in this fire-and-forget task would otherwise go unobserved,
                        // and one failure should not prevent resubscribing the remaining symbols
                        Log.Error(exception, $"BrokerageMultiWebSocketSubscriptionManager.OnOpen(): Failed to resubscribe {symbol} on websocket: {webSocket.GetHashCode()}");
                    }
                }
            });
        }
    }
}