227 lines
12 KiB
Python
227 lines
12 KiB
Python
from decimal import Decimal
|
|
import socket
|
|
from typing import Callable
|
|
from unittest.mock import ANY, _Call, MagicMock
|
|
import logging
|
|
|
|
from common.info_client import BaseInfoClient
|
|
from common.utils import decimal_from_float
|
|
from connection.connection_handler import ConnectionHandlerFactory
|
|
from connection.ip_address import IpAddress
|
|
from proto.info_pb2 import *
|
|
from proto.common_pb2 import Instrument, LoginRequest, LoginResponse, Side
|
|
|
|
from tests.conftest import DEFAULT_PASSWORD
|
|
from tests.common.mock_expectations import CallExpectations, CallExpectationsManager, ResponseExpectation
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class InfoTestClient(BaseInfoClient):
    """Info client for tests whose callbacks are mocks tracked by an expectations manager.

    ``__init__`` replaces every response and event callback with a mock created
    through the shared ``CallExpectationsManager``. The ``test_*`` helpers send a
    request and register a ``ResponseExpectation`` for its response; the
    ``expect_*`` helpers register an expectation for a public-feed event.
    """

    def __init__(
        self,
        socket_fd: socket.socket,
        ip_address: IpAddress,
        on_close: Callable[[], None],
        call_expectations_manager: CallExpectationsManager,
    ) -> None:
        """Initialise the base client and install one mock per callback.

        Args:
            socket_fd: Socket forwarded to ``BaseInfoClient``.
            ip_address: Peer address forwarded to ``BaseInfoClient``.
            on_close: Close callback forwarded to ``BaseInfoClient``.
            call_expectations_manager: Manager that creates the callback mocks
                and later verifies the expectations registered against them.
        """
        super().__init__(socket_fd, ip_address, on_close)
        self._last_order_book_id: int | None = None
        self._last_order_id: int | None = None
        self.call_expectations_manager = call_expectations_manager

        # Response callbacks, passed explicitly as ``callback=`` when sending.
        create_mock = self.call_expectations_manager.create_mock
        self.on_login_response = create_mock("on_login_response")
        self.on_create_instrument_response = create_mock("on_create_instrument_response")
        self.on_order_book_subscribe_response = create_mock("on_order_book_subscribe_response")

        # Event callbacks: these instance attributes shadow the placeholder
        # methods defined at the bottom of the class.
        self.on_instrument = create_mock("on_instrument")  # type: ignore
        self.on_top_of_book = create_mock("on_top_of_book")  # type: ignore
        self.on_price_depth_book = create_mock("on_price_depth_book")  # type: ignore
        self.on_trade = create_mock("on_trade")  # type: ignore

    def _assert_response_received(self, response_expectation: ResponseExpectation) -> None:
        """Verify registered expectations and assert that this one was fulfilled.

        Pending calls are tolerated (``assert_no_pending_calls=False``).

        Raises:
            AssertionError: If ``response_expectation`` is not fulfilled.
        """
        self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
        assert response_expectation.fulfilled, f"Expected response not received: {response_expectation}"

    def test_login(
        self,
        username: str,
        password: str = DEFAULT_PASSWORD,
        expect_success: bool = True,
    ) -> ResponseExpectation[LoginResponse]:
        """Send a login request and require that its response arrives.

        Args:
            username: User name placed in the ``LoginRequest``.
            password: Password placed in the ``LoginRequest``.
            expect_success: Forwarded to the ``ResponseExpectation``.

        Returns:
            The fulfilled response expectation.

        Raises:
            AssertionError: If the response expectation is not fulfilled.
        """
        request = LoginRequest(username=username, password=password)
        request_id = self.send_login(request, callback=self.on_login_response)

        response_expectation = ResponseExpectation(self.on_login_response, LoginResponse, request_id, expect_success)
        self.call_expectations_manager.add_expectation(response_expectation)
        self._assert_response_received(response_expectation)
        return response_expectation

    def test_create_instrument(
        self,
        instrument: Instrument,
        tick_size: float,
        expect_success: bool = True,
        expect_public_feed: bool = True,
    ) -> ResponseExpectation[CreateInstrumentResponse]:
        """Send a create-instrument request and register its response expectation.

        On the success path the response is verified immediately, its
        ``error_message`` must be empty, and its ``order_book_id`` is remembered
        in ``_last_order_book_id``. If ``expect_public_feed`` is also set, an
        ``OnInstrument`` event expectation is registered for the new book.

        Args:
            instrument: Instrument placed in the request.
            tick_size: Tick size placed in the request.
            expect_success: Forwarded to the ``ResponseExpectation``; also
                selects whether the response is verified here.
            expect_public_feed: Whether to expect an ``OnInstrument`` event
                after a successful creation.

        Returns:
            The response expectation (fulfilled only on the success path).

        Raises:
            AssertionError: On the success path, if the response is missing or
                carries an error message.
        """
        request = CreateInstrumentRequest(instrument=instrument, tick_size=tick_size)
        request_id = self.send_create_instrument(request, callback=self.on_create_instrument_response)

        response_expectation = ResponseExpectation(
            self.on_create_instrument_response, CreateInstrumentResponse, request_id, expect_success
        )
        self.call_expectations_manager.add_expectation(response_expectation)

        if not expect_success:
            # NOTE(review): unlike test_login, the failure path only registers
            # the expectation and leaves verification to the caller -- confirm
            # this is intentional.
            return response_expectation

        self._assert_response_received(response_expectation)
        response = response_expectation.get_response()
        assert response.error_message == ""
        self._last_order_book_id = response.order_book_id

        if expect_public_feed:
            self.expect_on_instrument_event(instrument, tick_size, response.order_book_id)

        return response_expectation

    def test_subscribe_to_order_book(
        self,
        instrument_symbol: str,
        type: SubscriptionType.ValueType,
        expect_success: bool = True,
    ) -> ResponseExpectation[OrderBookSubscribeResponse]:
        """Send an order-book subscribe request and require that its response arrives.

        Args:
            instrument_symbol: Symbol placed in the request.
            type: Subscription type placed in the request (the name shadows the
                builtin but is kept for keyword-argument compatibility).
            expect_success: Forwarded to the ``ResponseExpectation``.

        Returns:
            The fulfilled response expectation.

        Raises:
            AssertionError: If the response expectation is not fulfilled.
        """
        request = OrderBookSubscribeRequest(instrument_symbol=instrument_symbol, subscription_type=type)
        request_id = self.send_order_book_subscribe(request, callback=self.on_order_book_subscribe_response)

        response_expectation = ResponseExpectation(
            self.on_order_book_subscribe_response, OrderBookSubscribeResponse, request_id, expect_success
        )
        self.call_expectations_manager.add_expectation(response_expectation)
        self._assert_response_received(response_expectation)

        return response_expectation

    def expect_on_instrument_event(self, instrument: Instrument, tick_size: float, order_book_id: int) -> None:
        """Register an expectation for an ``OnInstrument`` event.

        A call matches when it has exactly one positional ``OnInstrument``
        argument whose instrument symbol equals ``instrument.symbol``.

        Args:
            instrument: Instrument the event must carry.
            tick_size: Expected tick size (compared with a 1e-5 tolerance).
                Pass ``unittest.mock.ANY`` to skip the tick-size check.
            order_book_id: Order book id the event must carry.
        """

        class OnInstrumentExpectation(CallExpectations):
            def is_expected_call(self, mock_call: _Call) -> bool:
                if len(mock_call.args) != 1:
                    return False
                call_arg = mock_call.args[0]
                return isinstance(call_arg, OnInstrument) and call_arg.instrument.symbol == instrument.symbol

            def validate_call(self, mock_call: _Call) -> None:
                call_arg = mock_call.args[0]
                assert call_arg.instrument == instrument
                # Identity test on purpose: ``ANY`` compares equal to every
                # object, so ``tick_size == ANY`` would always be true and the
                # tolerance check below would never run.
                assert tick_size is ANY or abs(call_arg.tick_size - tick_size) < 1e-5, f"Expected tick size {tick_size}, but got {call_arg.tick_size}"
                assert call_arg.order_book_id == order_book_id

        assert isinstance(self.on_instrument, MagicMock)
        self.call_expectations_manager.add_expectation(OnInstrumentExpectation(self.on_instrument))

    def expect_on_top_of_book_event(self, instrument_symbol: str, best_bid: PriceLevel | None, best_ask: PriceLevel | None) -> None:
        """Register an expectation for an ``OnTopOfBook`` event.

        A call matches when it has exactly one positional ``OnTopOfBook``
        argument for ``instrument_symbol`` whose ``best_bid`` / ``best_ask``
        presence agrees with the expected values.

        Args:
            instrument_symbol: Symbol the event must carry.
            best_bid: Expected best bid, or ``None`` if the field must be unset.
            best_ask: Expected best ask, or ``None`` if the field must be unset.
        """

        class OnTopOfBookExpectation(CallExpectations):
            def is_expected_call(self, mock_call: _Call) -> bool:
                if len(mock_call.args) != 1:
                    return False
                call_arg = mock_call.args[0]
                return (
                    isinstance(call_arg, OnTopOfBook)
                    and call_arg.instrument_symbol == instrument_symbol
                    and (best_bid is not None) == call_arg.HasField("best_bid")
                    and (best_ask is not None) == call_arg.HasField("best_ask")
                )

            def validate_call(self, mock_call: _Call) -> None:
                call_arg = mock_call.args[0]
                assert isinstance(call_arg, OnTopOfBook)
                actual_bid = call_arg.best_bid if call_arg.HasField("best_bid") else None
                actual_ask = call_arg.best_ask if call_arg.HasField("best_ask") else None
                self._compare_price_levels(best_bid, actual_bid)
                self._compare_price_levels(best_ask, actual_ask)

            def _compare_price_levels(self, expected: PriceLevel | None, actual: PriceLevel | None) -> None:
                """Assert both levels are absent, or both present and matching."""
                if expected is None or actual is None:
                    assert expected is None and actual is None, f"Expected price level {expected!r}, but got {actual!r}"
                    return
                assert actual.quantity == expected.quantity, f"Expected quantity {expected.quantity}, but got {actual.quantity}"
                assert abs(actual.price - expected.price) < 1e-5, f"Expected price {expected.price}, but got {actual.price}"

        assert isinstance(self.on_top_of_book, MagicMock)
        self.call_expectations_manager.add_expectation(OnTopOfBookExpectation(self.on_top_of_book))

    def expect_on_price_depth_book_event(self, instrument_symbol: str, bids: list[PriceLevel], asks: list[PriceLevel]) -> None:
        """Register an expectation for an ``OnPriceDepthBook`` event.

        A call matches when it has exactly one positional ``OnPriceDepthBook``
        argument for ``instrument_symbol`` whose bid and ask sides are empty or
        non-empty in agreement with the expected lists.

        Args:
            instrument_symbol: Symbol the event must carry.
            bids: Expected bid levels.
            asks: Expected ask levels.
        """

        class OnPriceDepthBookExpectation(CallExpectations):
            def is_expected_call(self, mock_call: _Call) -> bool:
                if len(mock_call.args) != 1:
                    return False
                call_arg = mock_call.args[0]
                return (
                    isinstance(call_arg, OnPriceDepthBook)
                    and call_arg.instrument_symbol == instrument_symbol
                    and (len(bids) > 0) == (len(call_arg.bids) > 0)
                    and (len(asks) > 0) == (len(call_arg.asks) > 0)
                )

            def validate_call(self, mock_call: _Call) -> None:
                call_arg = mock_call.args[0]
                self._compare_price_levels(bids, list(call_arg.bids))
                self._compare_price_levels(asks, list(call_arg.asks))

            def _compare_price_levels(self, expected: list[PriceLevel], actual: list[PriceLevel]) -> None:
                """Assert both lists hold the same prices with the same quantities.

                All differences are collected first so that one failure reports
                every mismatching level.
                """
                # Index each side by price (prices are unique per list).
                expected_by_price = {decimal_from_float(pl.price): pl for pl in expected}
                actual_by_price = {decimal_from_float(pl.price): pl for pl in actual}

                differences = []
                # Missing levels and quantity mismatches.
                for price, exp in expected_by_price.items():
                    act = actual_by_price.get(price)
                    if act is None:
                        differences.append(f"Missing actual price level at price {price}")
                    elif exp.quantity != act.quantity:
                        differences.append(f"For price {price}: expected quantity {exp.quantity}, but got {act.quantity}")
                # Levels present in the event but not expected.
                differences.extend(
                    f"Extra actual price level at price {price}"
                    for price in actual_by_price
                    if price not in expected_by_price
                )
                assert not differences, "\n".join(differences)

        assert isinstance(self.on_price_depth_book, MagicMock)
        self.call_expectations_manager.add_expectation(OnPriceDepthBookExpectation(self.on_price_depth_book))

    def expect_on_trade_event(self, instrument_symbol: str, price: float, quantity: int,
                              aggressor_side: Side.ValueType) -> None:
        """Register an expectation for an ``OnTrade`` event.

        A call matches when it has exactly one positional ``OnTrade`` argument
        for ``instrument_symbol``.

        Args:
            instrument_symbol: Symbol the event must carry.
            price: Expected trade price (compared with a 1e-5 tolerance).
            quantity: Expected trade quantity.
            aggressor_side: Expected aggressor side.
        """

        class OnTradeExpectation(CallExpectations):
            def is_expected_call(self, mock_call: _Call) -> bool:
                if len(mock_call.args) != 1:
                    return False
                call_arg = mock_call.args[0]
                return isinstance(call_arg, OnTrade) and call_arg.instrument_symbol == instrument_symbol

            def validate_call(self, mock_call: _Call) -> None:
                call_arg = mock_call.args[0]
                assert abs(call_arg.price - price) < 1e-5, f"Expected price {price}, got {call_arg.price}"
                assert call_arg.quantity == quantity, f"Expected quantity {quantity}, got {call_arg.quantity}"
                assert call_arg.aggressor_side == aggressor_side, \
                    f"Expected aggressor side {Side.Name(aggressor_side)}, got {Side.Name(call_arg.aggressor_side)}"

        assert isinstance(self.on_trade, MagicMock)
        self.call_expectations_manager.add_expectation(OnTradeExpectation(self.on_trade))

    def on_instrument(self, message: OnInstrument) -> None:
        """Placeholder; shadowed per instance by the mock installed in ``__init__``."""

    def on_top_of_book(self, message: OnTopOfBook) -> None:
        """Placeholder; shadowed per instance by the mock installed in ``__init__``."""

    def on_price_depth_book(self, message: OnPriceDepthBook) -> None:
        """Placeholder; shadowed per instance by the mock installed in ``__init__``."""

    def on_trade(self, message: OnTrade) -> None:
        """Placeholder; shadowed per instance by the mock installed in ``__init__``."""
|
|
|
|
|
|
class InfoClientConnectionHandlerFactory(ConnectionHandlerFactory[InfoTestClient]):
    """Connection-handler factory that builds an ``InfoTestClient`` per connection.

    Every client it creates shares the ``CallExpectationsManager`` given to the
    factory.
    """

    def __init__(self, call_expectations_manager: CallExpectationsManager) -> None:
        """Remember the expectations manager handed to each created client."""
        self.call_expectations_manager = call_expectations_manager

    def on_new_connection(
        self,
        socket_fd: socket.socket,
        ip_address: IpAddress,
        close_callback: Callable[[], None],
    ) -> InfoTestClient:
        """Create the test client for a newly established connection.

        Args:
            socket_fd: Socket of the new connection.
            ip_address: Address associated with the connection.
            close_callback: Callback passed to the client as its ``on_close``.

        Returns:
            A new ``InfoTestClient`` bound to the factory's expectations manager.
        """
        client = InfoTestClient(
            socket_fd,
            ip_address,
            close_callback,
            self.call_expectations_manager,
        )
        return client

    def on_connection_closed(self, connection_handler: InfoTestClient) -> None:
        """Do nothing when a connection closes."""
        return None
|