from decimal import Decimal
import socket
from typing import Callable
from unittest.mock import ANY, _Call, MagicMock
import logging

from common.info_client import BaseInfoClient
from common.utils import decimal_from_float
from connection.connection_handler import ConnectionHandlerFactory
from connection.ip_address import IpAddress
from proto.info_pb2 import *
from proto.common_pb2 import Instrument, LoginRequest, LoginResponse, Side
from tests.conftest import DEFAULT_PASSWORD
from tests.common.mock_expectations import CallExpectations, CallExpectationsManager, ResponseExpectation

logger = logging.getLogger(__name__)

# Absolute tolerance used when comparing floating-point prices and tick sizes.
_FLOAT_TOLERANCE = 1e-5


class InfoTestClient(BaseInfoClient):
    """Info client for tests whose response and feed callbacks are mocks.

    Every callback is created through the shared ``CallExpectationsManager``,
    and the ``test_*`` / ``expect_*`` helpers register expectations against
    those mocks with the same manager.
    """

    def __init__(
        self,
        socket_fd: socket.socket,
        ip_address: IpAddress,
        on_close: Callable[[], None],
        call_expectations_manager: CallExpectationsManager,
    ) -> None:
        """Initialise the base client and create one mock per callback.

        Args:
            socket_fd: Socket forwarded to ``BaseInfoClient``.
            ip_address: Peer address forwarded to ``BaseInfoClient``.
            on_close: Close callback forwarded to ``BaseInfoClient``.
            call_expectations_manager: Manager that creates the callback mocks
                and later receives the expectations registered by this client.
        """
        super().__init__(socket_fd, ip_address, on_close)
        self._last_order_book_id: int | None = None
        self._last_order_id: int | None = None
        self.call_expectations_manager = call_expectations_manager

        # Callbacks handed to the ``send_*`` request methods.
        self.on_login_response = self.call_expectations_manager.create_mock("on_login_response")
        self.on_create_instrument_response = self.call_expectations_manager.create_mock(
            "on_create_instrument_response"
        )
        self.on_order_book_subscribe_response = self.call_expectations_manager.create_mock(
            "on_order_book_subscribe_response"
        )

        # These instance attributes shadow the same-named methods defined on
        # the class, hence the ``type: ignore`` markers.
        self.on_instrument = self.call_expectations_manager.create_mock("on_instrument")  # type: ignore
        self.on_top_of_book = self.call_expectations_manager.create_mock("on_top_of_book")  # type: ignore
        self.on_price_depth_book = self.call_expectations_manager.create_mock("on_price_depth_book")  # type: ignore
        self.on_trade = self.call_expectations_manager.create_mock("on_trade")  # type: ignore

    def _verify_response_received(self, response_expectation: ResponseExpectation) -> None:
        """Run expectation verification and assert the given response arrived.

        Verification is invoked with ``assert_no_pending_calls=False``.

        Raises:
            AssertionError: If ``response_expectation`` is not fulfilled
                after verification.
        """
        self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
        assert response_expectation.fulfilled, f"Expected response not received: {response_expectation}"

    def test_login(
        self,
        username: str,
        password: str = DEFAULT_PASSWORD,
        expect_success: bool = True,
    ) -> ResponseExpectation[LoginResponse]:
        """Send a login request and assert that its response is received.

        Args:
            username: User name placed in the ``LoginRequest``.
            password: Password placed in the ``LoginRequest``.
            expect_success: Forwarded to the ``ResponseExpectation``.

        Returns:
            The fulfilled response expectation.
        """
        request = LoginRequest(username=username, password=password)
        request_id = self.send_login(request, callback=self.on_login_response)
        response_expectation = ResponseExpectation(
            self.on_login_response, LoginResponse, request_id, expect_success
        )
        self.call_expectations_manager.add_expectation(response_expectation)
        self._verify_response_received(response_expectation)
        return response_expectation

    def test_create_instrument(
        self,
        instrument: Instrument,
        tick_size: float,
        expect_success: bool = True,
        expect_public_feed: bool = True,
    ) -> ResponseExpectation[CreateInstrumentResponse]:
        """Send a create-instrument request and register its expectations.

        When ``expect_success`` is true the response is verified immediately,
        its ``error_message`` must be empty, and the returned order book id is
        remembered in ``_last_order_book_id``. If ``expect_public_feed`` is
        also true, an ``on_instrument`` event expectation is registered.

        When ``expect_success`` is false the response expectation is only
        registered; this method does not verify it.

        Args:
            instrument: Instrument placed in the request.
            tick_size: Tick size placed in the request.
            expect_success: Forwarded to the ``ResponseExpectation`` and gates
                the immediate verification described above.
            expect_public_feed: Whether to also expect an ``on_instrument``
                event (only considered when ``expect_success`` is true).

        Returns:
            The response expectation that was registered.
        """
        request = CreateInstrumentRequest(instrument=instrument, tick_size=tick_size)
        request_id = self.send_create_instrument(request, callback=self.on_create_instrument_response)
        response_expectation = ResponseExpectation(
            self.on_create_instrument_response, CreateInstrumentResponse, request_id, expect_success
        )
        self.call_expectations_manager.add_expectation(response_expectation)

        if expect_success:
            self._verify_response_received(response_expectation)
            response = response_expectation.get_response()
            assert response.error_message == ""
            self._last_order_book_id = response.order_book_id
            if expect_public_feed:
                self.expect_on_instrument_event(instrument, tick_size, response.order_book_id)

        return response_expectation

    def test_subscribe_to_order_book(
        self,
        instrument_symbol: str,
        type: SubscriptionType.ValueType,
        expect_success: bool = True,
    ) -> ResponseExpectation[OrderBookSubscribeResponse]:
        """Send an order-book subscribe request and assert its response arrives.

        Args:
            instrument_symbol: Symbol placed in the request.
            type: Subscription type placed in the request. The name shadows
                the builtin but is kept so keyword callers keep working.
            expect_success: Forwarded to the ``ResponseExpectation``.

        Returns:
            The fulfilled response expectation.
        """
        request = OrderBookSubscribeRequest(instrument_symbol=instrument_symbol, subscription_type=type)
        request_id = self.send_order_book_subscribe(request, callback=self.on_order_book_subscribe_response)
        response_expectation = ResponseExpectation(
            self.on_order_book_subscribe_response, OrderBookSubscribeResponse, request_id, expect_success
        )
        self.call_expectations_manager.add_expectation(response_expectation)
        self._verify_response_received(response_expectation)
        return response_expectation

    def expect_on_instrument_event(self, instrument: Instrument, tick_size: float, order_book_id: int) -> None:
        """Register an expectation for an ``OnInstrument`` event.

        A call matches when it has a single ``OnInstrument`` argument whose
        instrument symbol equals ``instrument.symbol``.

        Args:
            instrument: Instrument the event must carry.
            tick_size: Expected tick size, compared with an absolute tolerance;
                pass ``unittest.mock.ANY`` to skip the tick-size check.
            order_book_id: Expected order book id.
        """

        class OnInstrumentExpectation(CallExpectations):
            def is_expected_call(self, mock_call: _Call) -> bool:
                if len(mock_call.args) != 1:
                    return False
                call_arg = mock_call.args[0]
                if not isinstance(call_arg, OnInstrument):
                    return False
                return call_arg.instrument.symbol == instrument.symbol

            def validate_call(self, mock_call: _Call) -> None:
                call_arg = mock_call.args[0]
                assert call_arg.instrument == instrument
                # ``ANY`` compares equal to every object, so ``tick_size == ANY``
                # would be true for any float and silently disable this check.
                # The wildcard has to be detected by identity instead.
                assert tick_size is ANY or abs(call_arg.tick_size - tick_size) < _FLOAT_TOLERANCE, \
                    f"Expected tick size {tick_size}, but got {call_arg.tick_size}"
                assert call_arg.order_book_id == order_book_id

        assert isinstance(self.on_instrument, MagicMock)
        self.call_expectations_manager.add_expectation(OnInstrumentExpectation(self.on_instrument))

    def expect_on_top_of_book_event(
        self,
        instrument_symbol: str,
        best_bid: PriceLevel | None,
        best_ask: PriceLevel | None,
    ) -> None:
        """Register an expectation for an ``OnTopOfBook`` event.

        A call matches when it has a single ``OnTopOfBook`` argument for
        ``instrument_symbol`` whose ``best_bid`` / ``best_ask`` presence agrees
        with whether the corresponding expected level is ``None``.

        Args:
            instrument_symbol: Symbol the event must carry.
            best_bid: Expected best bid, or ``None`` if the field must be unset.
            best_ask: Expected best ask, or ``None`` if the field must be unset.
        """

        class OnTopOfBookExpectation(CallExpectations):
            def is_expected_call(self, mock_call: _Call) -> bool:
                if len(mock_call.args) != 1:
                    return False
                call_arg = mock_call.args[0]
                if not isinstance(call_arg, OnTopOfBook):
                    return False
                if call_arg.instrument_symbol != instrument_symbol:
                    return False
                if (best_bid is not None) != call_arg.HasField("best_bid"):
                    return False
                if (best_ask is not None) != call_arg.HasField("best_ask"):
                    return False
                return True

            def validate_call(self, mock_call: _Call) -> None:
                call_arg = mock_call.args[0]
                assert isinstance(call_arg, OnTopOfBook)
                actual_bid = call_arg.best_bid if call_arg.HasField("best_bid") else None
                actual_ask = call_arg.best_ask if call_arg.HasField("best_ask") else None
                self._compare_price_levels(best_bid, actual_bid)
                self._compare_price_levels(best_ask, actual_ask)

            def _compare_price_levels(self, expected: PriceLevel | None, actual: PriceLevel | None) -> None:
                """Assert two optional price levels agree in presence, quantity and price."""
                if expected is None or actual is None:
                    # Either both sides are absent or the comparison fails.
                    assert expected is None and actual is None, \
                        f"Expected price level {expected}, but got {actual}"
                    return
                assert actual.quantity == expected.quantity, \
                    f"Expected quantity {expected.quantity}, but got {actual.quantity}"
                assert abs(actual.price - expected.price) < _FLOAT_TOLERANCE, \
                    f"Expected price {expected.price}, but got {actual.price}"

        assert isinstance(self.on_top_of_book, MagicMock)
        self.call_expectations_manager.add_expectation(OnTopOfBookExpectation(self.on_top_of_book))

    def expect_on_price_depth_book_event(
        self,
        instrument_symbol: str,
        bids: list[PriceLevel],
        asks: list[PriceLevel],
    ) -> None:
        """Register an expectation for an ``OnPriceDepthBook`` event.

        A call matches when it has a single ``OnPriceDepthBook`` argument for
        ``instrument_symbol`` whose bid and ask sides are empty exactly when
        the expected sides are empty.

        Args:
            instrument_symbol: Symbol the event must carry.
            bids: Expected bid levels.
            asks: Expected ask levels.
        """

        class OnPriceDepthBookExpectation(CallExpectations):
            def is_expected_call(self, mock_call: _Call) -> bool:
                if len(mock_call.args) != 1:
                    return False
                call_arg = mock_call.args[0]
                if not isinstance(call_arg, OnPriceDepthBook):
                    return False
                if call_arg.instrument_symbol != instrument_symbol:
                    return False
                if (len(bids) > 0) != (len(call_arg.bids) > 0):
                    return False
                if (len(asks) > 0) != (len(call_arg.asks) > 0):
                    return False
                return True

            def validate_call(self, mock_call: _Call) -> None:
                call_arg = mock_call.args[0]
                self._compare_price_levels(bids, list(call_arg.bids))
                self._compare_price_levels(asks, list(call_arg.asks))

            def _compare_price_levels(self, expected: list[PriceLevel], actual: list[PriceLevel]) -> None:
                """Assert both lists hold the same prices with equal quantities.

                All mismatches are collected and reported together in one
                assertion message.
                """
                # Index by price for direct lookup; this assumes prices are
                # unique within each list (duplicates would overwrite each other).
                expected_by_price = {decimal_from_float(pl.price): pl for pl in expected}
                actual_by_price = {decimal_from_float(pl.price): pl for pl in actual}

                differences = []

                # Missing prices and quantity mismatches.
                for price, exp in expected_by_price.items():
                    act = actual_by_price.get(price)
                    if act is None:
                        differences.append(f"Missing actual price level at price {price}")
                    elif exp.quantity != act.quantity:
                        differences.append(
                            f"For price {price}: expected quantity {exp.quantity}, but got {act.quantity}"
                        )

                # Price levels present in the event but not expected.
                for price in actual_by_price:
                    if price not in expected_by_price:
                        differences.append(f"Extra actual price level at price {price}")

                assert not differences, "\n".join(differences)

        assert isinstance(self.on_price_depth_book, MagicMock)
        self.call_expectations_manager.add_expectation(OnPriceDepthBookExpectation(self.on_price_depth_book))

    def expect_on_trade_event(
        self,
        instrument_symbol: str,
        price: float,
        quantity: int,
        aggressor_side: Side.ValueType,
    ) -> None:
        """Register an expectation for an ``OnTrade`` event.

        A call matches when it has a single ``OnTrade`` argument for
        ``instrument_symbol``.

        Args:
            instrument_symbol: Symbol the event must carry.
            price: Expected trade price, compared with an absolute tolerance.
            quantity: Expected trade quantity.
            aggressor_side: Expected aggressor side.
        """

        class OnTradeExpectation(CallExpectations):
            def is_expected_call(self, mock_call: _Call) -> bool:
                if len(mock_call.args) != 1:
                    return False
                call_arg = mock_call.args[0]
                if not isinstance(call_arg, OnTrade):
                    return False
                return call_arg.instrument_symbol == instrument_symbol

            def validate_call(self, mock_call: _Call) -> None:
                call_arg = mock_call.args[0]
                assert abs(call_arg.price - price) < _FLOAT_TOLERANCE, \
                    f"Expected price {price}, got {call_arg.price}"
                assert call_arg.quantity == quantity, \
                    f"Expected quantity {quantity}, got {call_arg.quantity}"
                assert call_arg.aggressor_side == aggressor_side, (
                    f"Expected aggressor side {Side.Name(aggressor_side)}, "
                    f"got {Side.Name(call_arg.aggressor_side)}"
                )

        assert isinstance(self.on_trade, MagicMock)
        self.call_expectations_manager.add_expectation(OnTradeExpectation(self.on_trade))

    def on_instrument(self, message: OnInstrument) -> None:
        """Placeholder; ``__init__`` replaces this with a mock on each instance."""

    def on_top_of_book(self, message: OnTopOfBook) -> None:
        """Placeholder; ``__init__`` replaces this with a mock on each instance."""

    def on_price_depth_book(self, message: OnPriceDepthBook) -> None:
        """Placeholder; ``__init__`` replaces this with a mock on each instance."""

    def on_trade(self, message: OnTrade) -> None:
        """Placeholder; ``__init__`` replaces this with a mock on each instance."""


class InfoClientConnectionHandlerFactory(ConnectionHandlerFactory[InfoTestClient]):
    """Factory building ``InfoTestClient`` handlers that share one expectations manager."""

    def __init__(self, call_expectations_manager: CallExpectationsManager) -> None:
        """Store the manager handed to every client this factory creates."""
        self.call_expectations_manager = call_expectations_manager

    def on_new_connection(
        self,
        socket_fd: socket.socket,
        ip_address: IpAddress,
        close_callback: Callable[[], None],
    ) -> InfoTestClient:
        """Return a new ``InfoTestClient`` for the given connection."""
        return InfoTestClient(socket_fd, ip_address, close_callback, self.call_expectations_manager)

    def on_connection_closed(self, connection_handler: InfoTestClient) -> None:
        """Do nothing when a connection closes."""