diff --git a/CHANGELOG.md b/CHANGELOG.md index b59b032..eeb543c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,10 +5,22 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - ## Unreleased ### Added +- `skip_utf8_validation` parameter (https://github.com/bybit-exchange/pybit/issues/224) - [WebSocket Trading](https://bybit-exchange.github.io/docs/v5/websocket/trade/guideline) support +- added "disconnect_on_exception" argument to `WebSocket` constructor. Pass `False` to keep the connection open on exception. + +### Changed +- Now utf-8 validation is disabled by default. To enable pass `skip_utf8_validation=False` to `WebSocket()` +- Made all exceptions inherit from `PybitException` +- replaced base "Exception" exceptions with "PybitException" inherited ones +- rework on _on_error callback. Won't raise the exception when "disconnect_on_exception" is False +- now exceptions handled in `WebSocket`'s `_on_error` callback will be logged along with the stacktrace. + +### Deprecated +- `restart_on_error` in `WebSocket` constructor. Use `restart_on_ws_disconnect` instead +- *Error exceptions. Replaced with *Exception exceptions(Ex.: `InvalidRequestError` replaced with `InvalidRequestException`) ## [5.7.0] - 2024-04-11 diff --git a/pybit/_utils.py b/pybit/_utils.py new file mode 100644 index 0000000..4342b3c --- /dev/null +++ b/pybit/_utils.py @@ -0,0 +1,243 @@ +""" Utility classes and functions. """ + +import inspect +from abc import abstractmethod, ABC +from typing import Optional, Any +from warnings import warn +from packaging import version +from . import VERSION + +DEPRECATION_CONFIG = 'deprecation_config' + + +class DeprecationConfig(ABC): + """ Base deprecation configuration class model """ + + modification_version: str + details: Optional[str] + + def __init__( + self, + modification_version: str, + details: Optional[str] = None, + ) -> None: + self.modification_version = modification_version + self.details = details + + @property + def should_be_modified(self) -> bool: + """ Check if there is a need to modify the function/class(remove, or replace some arguments).""" + return version.parse(VERSION) >= version.parse(self.modification_version) + + @property + @abstractmethod + def warn_message(self) -> Optional[str]: + """ Return the deprecation message. + This method should be implemented in the subclass. + + Returns: + Optional[str]: + The deprecation message. + If there is no message to be shown, return None. + """ + + def warn(self): + """ Warn the user about the deprecation. """ + msg = self.warn_message + if msg: + warn(msg, DeprecationWarning, 2) + + +class ClassDeprecationConfig(DeprecationConfig): + """ Configuration class model for deprecated classes. + + Args: + remove_version (str): The version in which the class will be removed. + cls (type): The class to be deprecated. + details (Optional[str]): Additional details about the deprecation. + replacement (Optional[str | type]): The class to be used as a replacement(if any). + """ + + cls: type + replacement: Optional[str | type] + + def __init__( + self, + remove_version: str, + cls: type, + details: Optional[str] = None, + replacement: Optional[str] = None, + ) -> None: + self.cls = cls + self.replacement = replacement + super().__init__(remove_version, details) + + @property + def warn_message(self) -> str: + message = f'"{self.cls.__name__}" is deprecated and will be removed in version {self.modification_version}.' + if self.replacement: + replacement = self.replacement if isinstance( + self.replacement, str) else self.replacement.__name__ + message += f' Use "{replacement}" instead.' + if self.details: + message += f' {self.details}' + return message + + +class FunctionArgumentsDeprecationConfig(DeprecationConfig): + """ Configuration class model for deprecated function arguments. + + Args: + modification_version (str): + The version in which the arguments will be removed/replaced. + to_be_removed (Optional[list[str] | str]): + The arguments to be removed. Either a list of arguments + or a single argument(if there is only one arg to be removed). + to_be_replaced (Optional[list[tuple[str, str]] | tuple[str, str]]): + The arguments to be replaced. + Either a list of tuples of arguments to be replaced or a single tuple + (if there is only one arg to be replaced). First element of the tuple is + the argument to be replaced and the second element is the argument to be replaced with. + function_name (str): The name of the function. + kwargs (dict[str, Any]): The keyword arguments of the function. + details (Optional[str]): Additional details about the deprecation. + """ + kwargs: dict[str, Any] + function_name: str + to_be_removed: list[str] + to_be_replaced: list[tuple[str, str]] + + def __init__( + self, + modification_version: str, + to_be_removed: list[str] | str, + to_be_replaced: list[tuple[str, str]] | tuple[str, str], + function_name: str, + kwargs: dict[str, Any], + details: Optional[str] = None, + ) -> None: + self.kwargs = kwargs + self.function_name = function_name + self.to_be_removed = to_be_removed if isinstance( + to_be_removed, list) else [to_be_removed] + self.to_be_replaced = to_be_replaced if isinstance( + to_be_replaced, list) else [to_be_replaced] + super().__init__(modification_version, details) + + @property + def warn_message(self) -> Optional[str]: + replace_args = list( + filter(lambda x: x[0] in self.kwargs, self.to_be_replaced) + ) + if len(self.to_be_removed) + len(replace_args) == 0: + return None + + message = ( + f'The following arguments from function "{self.function_name}" ' + 'are deprecated and will be removed/replaced in version ' + f'{self.modification_version}:' + ) + if len(self.to_be_removed) > 0: + message += '\nArguments to be removed:\n\t' + message += '\n\t'.join( + [f'- "{x}"' for x in self.to_be_removed] + ) + if len(replace_args) > 0: + message += '\nArguments to be replaced:\n\t' + message += '\n\t'.join( + [f'- "{x[0]}"(Replace with "{x[1]}")' for x in replace_args] + ) + if self.details: + message += f' {self.details}' + + return message + + +def deprecate_class( + remove_version: str, + details: Optional[str] = None, + replacement: Optional[str | type] = None +): + """ Decorator to deprecate a class. + + Args: + remove_version (str): The version in which the class will be removed. + details (Optional[str]): Additional details about the deprecation. + replacement (Optional[str | type]): The class to be used as a replacement(if any). + """ + def decorator(cls): + if not inspect.isclass(cls): + raise AssertionError( + "This decorator can only be applied to classes.") + setattr( + cls, + DEPRECATION_CONFIG, + ClassDeprecationConfig( + remove_version=remove_version, + cls=cls, + details=details, + replacement=replacement, + ) + ) + init = cls.__init__ + + def __init__(self, *args, **kwargs): + if cls is self.__class__: + getattr(self, DEPRECATION_CONFIG).warn() + init(self, *args, **kwargs) + cls.__init__ = __init__ + return cls + return decorator + + +def deprecate_function_arguments( + modification_version: str, + to_be_removed: Optional[list[str] | str] = None, + to_be_replaced: Optional[list[tuple[str, str]] | tuple[str, str]] = None, + details: Optional[str] = None, +): + """ Decorator to deprecate function arguments. + + Args: + modification_version (str): + The version in which the arguments will be removed/replaced. + to_be_removed (Optional[list[str] | str]): + The arguments to be removed. Either a list of arguments + or a single argument(if there is only one arg to be removed). + to_be_replaced (Optional[list[tuple[str, str]] | tuple[str, str]]): + The arguments to be replaced. + Either a list of tuples of arguments to be replaced or a single tuple + (if there is only one arg to be replaced). First element of the tuple is + the argument to be replaced and the second element is the argument to be replaced with. + details (Optional[str]): Additional details about the deprecation. + """ + if to_be_removed is None and to_be_replaced is None: + raise ValueError( + 'At least one of "to_be_removed" or "to_be_replaced" must be provided.' + ) + + def decorator(func): + if not inspect.isfunction(func): + raise AssertionError( + "This decorator can only be applied to functions.") + config = FunctionArgumentsDeprecationConfig( + modification_version=modification_version, + to_be_removed=to_be_removed or [], + to_be_replaced=to_be_replaced or [], + function_name=func.__qualname__, + kwargs={}, + details=details, + ) + + def wrapper(*args, **kwargs): + config.kwargs = kwargs + config.warn() + return func(*args, **kwargs) + + setattr( + wrapper, + DEPRECATION_CONFIG, + config, + ) + return wrapper + return decorator diff --git a/pybit/_websocket_stream.py b/pybit/_websocket_stream.py index ff9714d..90ca35d 100644 --- a/pybit/_websocket_stream.py +++ b/pybit/_websocket_stream.py @@ -1,11 +1,21 @@ -import websocket +""" Websocket stream manager for Bybit API. """ + +from typing import Optional, Union, Callable import threading import time import json -from ._http_manager import generate_signature import logging import copy from uuid import uuid4 +import websocket +from ._http_manager import generate_signature +from .exceptions import ( + PybitException, + AlreadySubscribedTopicException, + AuthorizationFailedException, + WSConnectionNotEstablishedException +) +from ._utils import deprecate_function_arguments from . import _helpers @@ -21,6 +31,17 @@ class _WebSocketManager: + ws: Optional[websocket.WebSocketApp] + wst: Optional[threading.Thread] + auth: bool + exited: bool + attempting_connection: bool + data: dict + + @deprecate_function_arguments( + "6.0", + to_be_replaced=("restart_on_error", "restart_on_ws_disconnect"), + ) def __init__( self, callback_function, @@ -34,9 +55,12 @@ def __init__( ping_interval=20, ping_timeout=10, retries=10, - restart_on_error=True, + restart_on_error: Optional[bool] = None, + restart_on_ws_disconnect: bool = True, + disconnect_on_exception: bool = True, trace_logging=False, private_auth_expire=1, + skip_utf8_validation=False, ): self.testnet = testnet self.domain = domain @@ -50,7 +74,7 @@ def __init__( self.ws_name = ws_name if api_key: self.ws_name += " (Auth)" - + # Delta time for private auth expiration in seconds self.private_auth_expire = private_auth_expire @@ -58,11 +82,11 @@ def __init__( # { # "topic_name": function # } - self.callback_directory = {} + self.callback_directory: dict[str, Callable] = {} # Record the subscriptions made so that we can resubscribe if the WSS # connection is broken. - self.subscriptions = [] + self.subscriptions: dict[str, str] = {} # Set ping settings. self.ping_interval = ping_interval @@ -71,23 +95,41 @@ def __init__( self.retries = retries # Other optional data handling settings. - self.handle_error = restart_on_error + self.restart_on_ws_disconnect = restart_on_error or restart_on_ws_disconnect + # If True, disconnects the websocket connection when a non-websocket + # exception is raised(for example, a broken ws message was received, which + # caused wrong handling, and, therefore, an exception was thrown). + # If False, the websocket connection will not be closed, and the exception + # will be ignored(will be only logged). + self.disconnect_on_exception = disconnect_on_exception # Enable websocket-client's trace logging for extra debug information # on the websocket connection, including the raw sent & recv messages websocket.enableTrace(trace_logging) + # Set the skip_utf8_validation parameter to True to skip the utf-8 + # validation of incoming messages. + # Could be useful if incoming messages contain invalid utf-8 characters. + # Also disabling utf-8 validation could improve performance + # (more about performance: https://github.com/websocket-client/websocket-client). + self.skip_utf8_validation = skip_utf8_validation + # Set initial state, initialize dictionary and connect. - self._reset() + self.auth = False + self.exited = False + self.data = {} + self.endpoint = None + self.ws = None + self.wst = None self.attempting_connection = False - def _on_open(self): + def _on_open(self, *_): """ Log WS open. """ - logger.debug(f"WebSocket {self.ws_name} opened.") + logger.debug("WebSocket %s opened.", self.ws_name) - def _on_message(self, message): + def _on_message(self, _, message): """ Parse incoming messages. """ @@ -98,6 +140,7 @@ def _on_message(self, message): self.callback(message) def is_connected(self): + """ Check if the websocket is connected. """ try: if self.ws.sock.connected: return True @@ -106,6 +149,11 @@ def is_connected(self): except AttributeError: return False + def _send(self, message): + if self.ws is None: + raise WSConnectionNotEstablishedException() + self.ws.send(message) + def _connect(self, url): """ Open websocket in a thread. @@ -118,8 +166,8 @@ def resubscribe_to_topics(): # no previous WSS connection. return - for req_id, subscription_message in self.subscriptions.items(): - self.ws.send(subscription_message) + for _, subscription_message in self.subscriptions.items(): + self._send(subscription_message) self.attempting_connection = True @@ -144,14 +192,14 @@ def resubscribe_to_topics(): while ( infinitely_reconnect or retries > 0 ) and not self.is_connected(): - logger.info(f"WebSocket {self.ws_name} attempting connection...") + logger.info("WebSocket %s attempting connection...", self.ws_name) self.ws = websocket.WebSocketApp( url=url, - on_message=lambda ws, msg: self._on_message(msg), - on_close=lambda ws, *args: self._on_close(), - on_open=lambda ws, *args: self._on_open(), - on_error=lambda ws, err: self._on_error(err), - on_pong=lambda ws, *args: self._on_pong(), + on_message=self._on_message, + on_close=self._on_close, + on_open=self._on_open, + on_error=self._on_error, + on_pong=self._on_pong, ) # Setup the thread running WebSocketApp. @@ -159,6 +207,7 @@ def resubscribe_to_topics(): target=lambda: self.ws.run_forever( ping_interval=self.ping_interval, ping_timeout=self.ping_timeout, + skip_utf8_validation=self.skip_utf8_validation, ) ) @@ -176,11 +225,12 @@ def resubscribe_to_topics(): self.exit() raise websocket.WebSocketTimeoutException( f"WebSocket {self.ws_name} ({self.endpoint}) connection " - f"failed. Too many connection attempts. pybit will no " - f"longer try to reconnect." + "failed. Too many connection attempts. pybit will no " + "longer try to reconnect." + ) - logger.info(f"WebSocket {self.ws_name} connected") + logger.info("WebSocket %s connected", self.ws_name) # If given an api_key, authenticate. if self.api_key and self.api_secret: @@ -205,44 +255,61 @@ def _auth(self): ) # Authenticate with API. - self.ws.send( + self._send( json.dumps( {"op": "auth", "args": [self.api_key, expires, signature]} ) ) - def _on_error(self, error): + def _on_error(self, _, error): """ Exit on errors and raise exception, or attempt reconnect. """ - if type(error).__name__ not in [ - "WebSocketConnectionClosedException", - "ConnectionResetError", - "WebSocketTimeoutException", - ]: - # Raises errors not related to websocket disconnection. - self.exit() - raise error - - if not self.exited: - logger.error( - f"WebSocket {self.ws_name} ({self.endpoint}) " - f"encountered error: {error}." + is_ws_disconnect = any( + map( + lambda exception: isinstance(error, exception), + [ + websocket.WebSocketConnectionClosedException, + websocket.WebSocketTimeoutException, + ] ) - self.exit() + ) + should_raise = isinstance(error, PybitException) or \ + (is_ws_disconnect and not self.restart_on_ws_disconnect) or \ + (not is_ws_disconnect and self.disconnect_on_exception) + + log_callback = logger.error if is_ws_disconnect else logger.exception + log_callback( + "WebSocket %(ws_name)s (%(endpoint)s) encountered error: %(error)s.", + {"ws_name": self.ws_name, "endpoint": self.endpoint, "error": error}, + ) - # Reconnect. - if self.handle_error and not self.attempting_connection: + if is_ws_disconnect and self.restart_on_ws_disconnect and not self.attempting_connection: + if not self.exited: + self.exit() + logger.info( + "Attempting to reconnect WebSocket %s...", + self.ws_name + ) self._reset() self._connect(self.endpoint) - def _on_close(self): + if should_raise: + self.exit() + logger.info( + "WebSocket %s closed because an exception was raised. " + "If you want to keep the connection open, set disconnect_on_exception=False", + self.ws_name + ) + raise error + + def _on_close(self, *args): """ Log WS close. """ - logger.debug(f"WebSocket {self.ws_name} closed.") + logger.debug("WebSocket %s closed.", self.ws_name) - def _on_pong(self): + def _on_pong(self, *args): """ Sends a custom ping upon the receipt of the pong frame. @@ -255,7 +322,7 @@ def _on_pong(self): self._send_custom_ping() def _send_custom_ping(self): - self.ws.send(self.custom_ping_message) + self._send(self.custom_ping_message) def _send_initial_ping(self): """https://github.com/bybit-exchange/pybit/issues/164""" @@ -314,8 +381,10 @@ def subscribe( self, topic: str, callback, - symbol: (str, list) = False + symbol: Union[str, list[str], None] = None, ): + """ Subscribe to a topic on the websocket stream. """ + symbol = symbol or [] def prepare_subscription_args(list_of_symbols): """ @@ -332,7 +401,7 @@ def prepare_subscription_args(list_of_symbols): topics.append(topic.format(symbol=single_symbol)) return topics - if type(symbol) == str: + if isinstance(symbol, str): symbol = [symbol] subscription_args = prepare_subscription_args(symbol) @@ -346,7 +415,7 @@ def prepare_subscription_args(list_of_symbols): while not self.is_connected(): # Wait until the connection is open before subscribing. time.sleep(0.1) - self.ws.send(subscription_message) + self._send(subscription_message) self.subscriptions[req_id] = subscription_message for topic in subscription_args: self._set_callback(topic, callback) @@ -368,8 +437,8 @@ def _process_delta_orderbook(self, message, topic): # Make updates according to delta response. book_sides = {"b": message["data"]["b"], "a": message["data"]["a"]} - self.data[topic]["u"]=message["data"]["u"] - self.data[topic]["seq"]=message["data"]["seq"] + self.data[topic]["u"] = message["data"]["u"] + self.data[topic]["seq"] = message["data"]["seq"] for side, entries in book_sides.items(): for entry in entries: @@ -417,13 +486,13 @@ def _process_delta_ticker(self, message, topic): def _process_auth_message(self, message): # If we get successful futures auth, notify user if message.get("success") is True: - logger.debug(f"Authorization for {self.ws_name} successful.") + logger.debug("Authorization for %s successful.", self.ws_name) self.auth = True # If we get unsuccessful auth, notify user. elif message.get("success") is False or message.get("type") == "error": - raise Exception( - f"Authorization for {self.ws_name} failed. Please check your " - f"API keys and resync your system time. Raw error: {message}" + raise AuthorizationFailedException( + ws_name=self.ws_name, + raw_message=message, ) def _process_subscription_message(self, message): @@ -436,11 +505,11 @@ def _process_subscription_message(self, message): # If we get successful futures subscription, notify user if message.get("success") is True: - logger.debug(f"Subscription to {sub} successful.") + logger.debug("Subscription to %s successful.", sub) # Futures subscription fail elif message.get("success") is False: response = message["ret_msg"] - logger.error("Couldn't subscribe to topic." f"Error: {response}.") + logger.error("Couldn't subscribe to topic. Error: %s.", response) self._pop_callback(sub[0]) def _process_normal_message(self, message): @@ -489,9 +558,7 @@ def is_subscription_message(): def _check_callback_directory(self, topics): for topic in topics: if topic in self.callback_directory: - raise Exception( - f"You have already subscribed to this topic: " f"{topic}" - ) + raise AlreadySubscribedTopicException(topic) def _set_callback(self, topic, callback_function): self.callback_directory[topic] = callback_function diff --git a/pybit/_websocket_trading.py b/pybit/_websocket_trading.py index f7523ee..4dcf517 100644 --- a/pybit/_websocket_trading.py +++ b/pybit/_websocket_trading.py @@ -1,9 +1,11 @@ -from dataclasses import dataclass, field +"""Module for the WebSocket Trading API.""" + import json import uuid import logging from ._websocket_stream import _WebSocketManager from . import _helpers +from .exceptions import AuthorizationFailedException logger = logging.getLogger(__name__) @@ -23,38 +25,29 @@ def __init__(self, recv_window, referral_id, **kwargs): def _process_auth_message(self, message): # If we get successful auth, notify user if message.get("retCode") == 0: - logger.debug(f"Authorization for {self.ws_name} successful.") + logger.debug("Authorization for %s successful.", self.ws_name) self.auth = True # If we get unsuccessful auth, notify user. else: - raise Exception( - f"Authorization for {self.ws_name} failed. Please check your " - f"API keys and resync your system time. Raw error: {message}" + raise AuthorizationFailedException( + raw_message=message, + ws_name=self.ws_name, ) def _process_error_message(self, message): logger.error( - f"WebSocket request {message['reqId']} hit an error. Enabling " - f"traceLogging to reproduce the issue. Raw error: {message}" + "WebSocket request %s hit an error. Enabling " + "traceLogging to reproduce the issue. Raw error: %s", + message['reqId'], + message, + ) self._pop_callback(message["reqId"]) def _handle_incoming_message(self, message): - def is_auth_message(): - if message.get("op") == "auth": - return True - else: - return False - - def is_error_message(): - if message.get("retCode") != 0: - return True - else: - return False - - if is_auth_message(): + if message.get("op") == "auth": self._process_auth_message(message) - elif is_error_message(): + elif message.get("retCode") != 0: self._process_error_message(message) else: callback_function = self._pop_callback(message["reqId"]) diff --git a/pybit/exceptions.py b/pybit/exceptions.py index 67aaa02..16e2138 100644 --- a/pybit/exceptions.py +++ b/pybit/exceptions.py @@ -1,16 +1,126 @@ -class UnauthorizedExceptionError(Exception): - pass +""" This module contains the exceptions for the Pybit package. """ -class InvalidChannelTypeError(Exception): - pass +from ._utils import deprecate_class -class TopicMismatchError(Exception): - pass +class PybitException(Exception): + """ + Base exception class for all exceptions. + """ + + +class WSConnectionNotEstablishedException(PybitException): + """ + Exception raised when connection is not established. + """ + + def __init__(self): + super().__init__("WebSocket connection is not established. Please connect first.") + + +# =================== Authorization Exceptions =================== +class AuthorizationException(PybitException): + """ + Base exception class for all authorization exceptions. + """ + + +@deprecate_class( + '6.0', + replacement='NoCredentialsAuthorizationException', +) +class UnauthorizedExceptionError(AuthorizationException): + """ + Exception raised for unauthorized requests. + """ + + +class NoCredentialsAuthorizationException(UnauthorizedExceptionError): + """ + Exception raised when no credentials are provided. + """ + + def __init__(self): + super().__init__('"api_key" and/or "api_secret" are not set. They both are needed in order to access private sources') -class FailedRequestError(Exception): +class AuthorizationFailedException(AuthorizationException): + """ + Exception raised when authorization fails. + """ + + def __init__(self, ws_name: str, raw_message: str): + super().__init__( + f"Authorization for {ws_name} failed. Please check your " + f"API keys and resync your system time. Raw error: {raw_message}" + ) +# =============================================================== + + +@deprecate_class( + '6.0', + replacement='InvalidChannelTypeException', +) +class InvalidChannelTypeError(PybitException): + """ + Exception raised for invalid channel types. + """ + + +class InvalidChannelTypeException(InvalidChannelTypeError): + """ + Exception raised for invalid channel types. + """ + + def __init__(self, provided_channel: str, available_channels: list[str]): + super().__init__( + f'Invalid channel type("{provided_channel}"). Available: {available_channels}') + + +# ================ Topic Exceptions ======================== +class TopicException(PybitException): + """ + Base exception class for all topic exceptions. + """ + + +@deprecate_class( + '6.0', + replacement='TopicMismatchException', +) +class TopicMismatchError(TopicException): + """ + Exception raised for topic mismatch. + """ + + +class TopicMismatchException(TopicMismatchError): + """ + Exception raised for topic mismatch. + """ + + def __init__(self): + super().__init__("Requested topic does not match channel_type") + + +class AlreadySubscribedTopicException(TopicException): + """ + Exception raised for already subscribed topics. + """ + + def __init__(self, topic: str): + super().__init__(f"Already subscribed to topic: {topic}") +# ============================================================ + + +@deprecate_class( + '6.0', + replacement='FailedRequestException', +) +class FailedRequestError(PybitException): + # TODO: Remove this class in the next major release + # and copy-paste the __init__ method from this class to the replacement class """ Exception raised for failed requests. @@ -34,7 +144,26 @@ def __init__(self, request, message, status_code, time, resp_headers): ) -class InvalidRequestError(Exception): +class FailedRequestException(FailedRequestError): + """ + Exception raised for failed requests. + + Attributes: + request -- The original request that caused the error. + message -- Explanation of the error. + status_code -- The code number returned. + time -- The time of the error. + resp_headers -- The response headers from API. None, if the request caused an error locally. + """ + + +@deprecate_class( + '6.0', + replacement='InvalidRequestException', +) +class InvalidRequestError(PybitException): + # TODO: Remove this class in the next major release + # and copy-paste the __init__ method from this class to the replacement class """ Exception raised for returned Bybit errors. @@ -56,3 +185,16 @@ def __init__(self, request, message, status_code, time, resp_headers): f"{message} (ErrCode: {status_code}) (ErrTime: {time})" f".\nRequest → {request}." ) + + +class InvalidRequestException(FailedRequestError): + """ + Exception raised for returned Bybit errors. + + Attributes: + request -- The original request that caused the error. + message -- Explanation of the error. + status_code -- The code number returned. + time -- The time of the error. + resp_headers -- The response headers from API. None, if the request caused an error locally. + """ diff --git a/setup.py b/setup.py index ce055df..e77389c 100644 --- a/setup.py +++ b/setup.py @@ -31,5 +31,6 @@ "requests", "websocket-client", "websockets", + "packaging", ], ) diff --git a/tests/test_pybit.py b/tests/test_pybit.py index 67159d1..070f7fc 100644 --- a/tests/test_pybit.py +++ b/tests/test_pybit.py @@ -1,6 +1,21 @@ -import unittest, time -from pybit.exceptions import InvalidChannelTypeError, TopicMismatchError +"""Tests for the Pybit API wrapper.""" + +import sys +import inspect +import unittest +from unittest import mock +import time +from websocket import ( + WebSocketConnectionClosedException, + WebSocketTimeoutException +) +from pybit.exceptions import ( + InvalidChannelTypeError, + TopicMismatchError, + NoCredentialsAuthorizationException +) from pybit.unified_trading import HTTP, WebSocket +from pybit._utils import DEPRECATION_CONFIG # session uses Bybit's mainnet endpoint session = HTTP() @@ -52,7 +67,7 @@ def test_place_active_order(self): class WebSocketTest(unittest.TestCase): # A very simple test to ensure we're getting something from WS. - def _callback_function(msg): + def _callback_function(self, msg): print(msg) def test_websocket(self): @@ -80,12 +95,13 @@ def test_topic_category_mismatch(self): ) ws.order_stream(callback=self._callback_function) - + + class PrivateWebSocketTest(unittest.TestCase): # Connect to private websocket and see if we can auth. - def _callback_function(msg): + def _callback_function(self, msg): print(msg) - + def test_private_websocket_connect(self): ws_private = WebSocket( testnet=True, @@ -93,8 +109,145 @@ def test_private_websocket_connect(self): api_key="...", api_secret="...", trace_logging=True, - #private_auth_expire=10 + # private_auth_expire=10 ) - + ws_private.position_stream(callback=self._callback_function) - #time.sleep(10) + # time.sleep(10) + + +class WSOnErrorCallbackTest(unittest.TestCase): + """ Test WebSocket on_error callback. """ + + def test_tries_to_reconnect(self): + """ Test if WebSocket tries to reconnect on connection error. """ + ws.restart_on_ws_disconnect = True + ws.attempting_connection = False + ws._reset = mock.MagicMock() + ws._connect = mock.MagicMock() + ws.exit = mock.MagicMock() + # WebSocketConnectionClosedException + ws._on_error(ws, WebSocketConnectionClosedException()) + ws._reset.assert_called_once() + ws._connect.assert_called_once() + ws.exit.assert_called_once() + # WebSocketTimeoutException + ws._reset.reset_mock() + ws._connect.reset_mock() + ws.exit.reset_mock() + ws._on_error(ws, WebSocketTimeoutException()) + ws._reset.assert_called_once() + ws._connect.assert_called_once() + ws.exit.assert_called_once() + + def test_doesnt_try_to_reconnect_when_restart_on_ws_disconnect_is_false(self): + """ Test if WebSocket doesn't try to reconnect when restart_on_ws_disconnect is False. """ + ws.restart_on_ws_disconnect = False + ws.attempting_connection = False + ws._reset = mock.MagicMock() + ws._connect = mock.MagicMock() + ws.exit = mock.MagicMock() + self.assertRaises( + WebSocketConnectionClosedException, + ws._on_error, ws, WebSocketConnectionClosedException() + ) + ws._reset.assert_not_called() + ws._connect.assert_not_called() + ws.exit.assert_called_once() + + def test_disconnects_on_pybit_exception(self): + """ Test if WebSocket disconnects on Pybit exception. """ + ws.restart_on_ws_disconnect = False + ws.attempting_connection = False + ws._connect = mock.MagicMock() + ws.exit = mock.MagicMock() + self.assertRaises( + NoCredentialsAuthorizationException, + ws._on_error, ws, NoCredentialsAuthorizationException() + ) + ws._connect.assert_not_called() + ws.exit.assert_called_once() + + def test_ignores_exceptions_when_disconnect_on_exception_is_false(self): + """ Test if WebSocket ignores exceptions when disconnect_on_exception is False. """ + ws.disconnect_on_exception = False + ws.restart_on_ws_disconnect = False + ws.attempting_connection = False + ws._connect = mock.MagicMock() + ws.exit = mock.MagicMock() + ws._on_error(ws, Exception()) + ws._connect.assert_not_called() + ws.exit.assert_not_called() + + def test_raises_exception_when_disconnect_on_exception_is_true(self): + """ Test if WebSocket raises exception when disconnect_on_exception is True. """ + ws.disconnect_on_exception = True + ws.restart_on_ws_disconnect = False + ws.attempting_connection = False + ws._connect = mock.MagicMock() + ws.exit = mock.MagicMock() + self.assertRaises( + Exception, + ws._on_error, ws, Exception() + ) + ws._connect.assert_not_called() + ws.exit.assert_called_once() + + def test_doesn_nothing_when_attempting_connection_is_true(self): + """ Test if WebSocket does nothing when attempting_connection is True. """ + ws.restart_on_ws_disconnect = True + ws.attempting_connection = True + ws._reset = mock.MagicMock() + ws._connect = mock.MagicMock() + ws.exit = mock.MagicMock() + ws._on_error(ws, WebSocketConnectionClosedException()) + ws._reset.assert_not_called() + ws._connect.assert_not_called() + ws.exit.assert_not_called() + + +class DeprecatedMembersTest(unittest.TestCase): + """ Test deprecated members. """ + + def _check_deprecated_function(self, func): + config = func.__dict__.get(DEPRECATION_CONFIG) + if config and config.should_be_modified: + message = ( + f'There are arguments from function "{config.function_name}" ' + 'that are deprecated and must be removed in version ' + f'{config.modification_version}:\n' + + ', '.join([f'"{x}"' for x in config.to_be_removed]) + + (', ' if len(config.to_be_removed) > 0 else '') + + ', '.join( + [f'"{x[0]}"(Replaced with "{x[1]}")' for x in config.to_be_replaced]) + ) + raise AssertionError(message) + + def _check_deprecated_class(self, cls): + config = cls[1].__dict__.get(DEPRECATION_CONFIG) + if config: + self.assertFalse( + config.should_be_modified, + f'Class "{cls[0]}" should be removed in version {config.modification_version}!' + ) + + def test_should_modify_deprecated_members(self): + """ Test if deprecated members(classes/functions) + should be modified(removed/renamed) in current version. + """ + pybit_modules = [e[1] + for e in sys.modules.items() if e[0].startswith('pybit.')] + for module in pybit_modules: + # Getting all classes from the module + classes = inspect.getmembers(module, inspect.isclass) + for cls in classes: + # Check all classes in the module + self._check_deprecated_class(cls) + # Check all functions in the class + for item in cls[1].__dict__.items(): + if inspect.isfunction(item[1]): + self._check_deprecated_function(item[1]) + # Check all functions in the module + functions = inspect.getmembers(module, inspect.isfunction) + for func in functions: + self._check_deprecated_function(func[1])