# gtat-tech-career-kickstarte.../solution/tests/test_client/info_test_client.py
#
# 227 lines
# 12 KiB
# Python

from decimal import Decimal
import socket
from typing import Callable
from unittest.mock import ANY, _Call, MagicMock
import logging
from common.info_client import BaseInfoClient
from common.utils import decimal_from_float
from connection.connection_handler import ConnectionHandlerFactory
from connection.ip_address import IpAddress
from proto.info_pb2 import *
from proto.common_pb2 import Instrument, LoginRequest, LoginResponse, Side
from tests.conftest import DEFAULT_PASSWORD
from tests.common.mock_expectations import CallExpectations, CallExpectationsManager, ResponseExpectation
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class InfoTestClient(BaseInfoClient):
    """Info-service client used by the test suite.

    Every response callback and public-feed handler is a mock created through
    the shared ``CallExpectationsManager``. Tests send a request with one of the
    ``test_*`` helpers, and declare expected feed events with the ``expect_*``
    helpers.
    """

    def __init__(self, socket_fd: socket.socket, ip_address: IpAddress, on_close: Callable[[], None],
                 call_expectations_manager: CallExpectationsManager) -> None:
        super().__init__(socket_fd, ip_address, on_close)
        self._last_order_book_id: int | None = None
        self._last_order_id: int | None = None
        self.call_expectations_manager = call_expectations_manager
        # Response callbacks passed to the send_* calls in the test_* helpers.
        self.on_login_response = self.call_expectations_manager.create_mock("on_login_response")
        self.on_create_instrument_response = self.call_expectations_manager.create_mock("on_create_instrument_response")
        self.on_order_book_subscribe_response = self.call_expectations_manager.create_mock("on_order_book_subscribe_response")
        # Public-feed handlers. These instance attributes shadow the placeholder
        # methods of the same names defined at the bottom of this class.
        self.on_instrument = self.call_expectations_manager.create_mock("on_instrument")  # type: ignore
        self.on_top_of_book = self.call_expectations_manager.create_mock("on_top_of_book")  # type: ignore
        self.on_price_depth_book = self.call_expectations_manager.create_mock("on_price_depth_book")  # type: ignore
        self.on_trade = self.call_expectations_manager.create_mock("on_trade")  # type: ignore

    def _await_response(self, response_expectation: ResponseExpectation) -> None:
        """Run the manager's verification pass and assert the response arrived.

        ``assert_no_pending_calls=False`` is passed, so (as the flag name
        suggests) calls not matched by any expectation do not fail this check.
        """
        self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
        assert response_expectation.fulfilled, f"Expected response not received: {response_expectation}"

    def test_login(self, username: str, password: str = DEFAULT_PASSWORD, expect_success: bool = True) -> ResponseExpectation[LoginResponse]:
        """Send a LoginRequest and assert that its response is received.

        Args:
            username: user name to log in with.
            password: password to send; defaults to ``DEFAULT_PASSWORD``.
            expect_success: forwarded to ``ResponseExpectation``; True when the
                login is expected to succeed.

        Returns:
            The (fulfilled) response expectation.
        """
        request = LoginRequest(username=username, password=password)
        request_id = self.send_login(request, callback=self.on_login_response)
        response_expectation = ResponseExpectation(self.on_login_response, LoginResponse, request_id, expect_success)
        self.call_expectations_manager.add_expectation(response_expectation)
        self._await_response(response_expectation)
        return response_expectation

    def test_create_instrument(self, instrument: Instrument, tick_size: float, expect_success: bool = True, expect_public_feed: bool = True) -> ResponseExpectation[CreateInstrumentResponse]:
        """Send a CreateInstrumentRequest and register the matching expectations.

        On the success path, this method also:
        - waits for the response;
        - asserts that it carries no error message;
        - records its ``order_book_id`` in ``_last_order_book_id``;
        - optionally expects the OnInstrument public-feed event.

        Args:
            instrument: instrument to create.
            tick_size: tick size to request.
            expect_success: forwarded to ``ResponseExpectation``; True when the
                creation is expected to succeed.
            expect_public_feed: when succeeding, also expect an OnInstrument event.

        Returns:
            The response expectation.
        """
        request = CreateInstrumentRequest(instrument=instrument, tick_size=tick_size)
        request_id = self.send_create_instrument(request, callback=self.on_create_instrument_response)
        response_expectation = ResponseExpectation(self.on_create_instrument_response, CreateInstrumentResponse, request_id, expect_success)
        self.call_expectations_manager.add_expectation(response_expectation)
        if not expect_success:
            # NOTE(review): unlike test_login / test_subscribe_to_order_book, the
            # failure path is not verified here. The expectation is only
            # registered, presumably to be checked by a later verify_expectations call.
            return response_expectation

        self._await_response(response_expectation)
        response = response_expectation.get_response()
        assert response.error_message == "", f"Unexpected error creating instrument: {response.error_message!r}"
        self._last_order_book_id = response.order_book_id
        if expect_public_feed:
            self.expect_on_instrument_event(instrument, tick_size, response.order_book_id)
        return response_expectation

    def test_subscribe_to_order_book(self, instrument_symbol: str, type: SubscriptionType.ValueType, expect_success: bool = True) -> ResponseExpectation[OrderBookSubscribeResponse]:
        """Send an OrderBookSubscribeRequest and assert that its response is received.

        Args:
            instrument_symbol: symbol of the order book to subscribe to.
            type: subscription type. The name shadows the builtin but is kept
                for callers that pass it by keyword.
            expect_success: forwarded to ``ResponseExpectation``.

        Returns:
            The (fulfilled) response expectation.
        """
        request = OrderBookSubscribeRequest(instrument_symbol=instrument_symbol, subscription_type=type)
        request_id = self.send_order_book_subscribe(request, callback=self.on_order_book_subscribe_response)
        response_expectation = ResponseExpectation(self.on_order_book_subscribe_response, OrderBookSubscribeResponse, request_id, expect_success)
        self.call_expectations_manager.add_expectation(response_expectation)
        self._await_response(response_expectation)
        return response_expectation

    def expect_on_instrument_event(self, instrument: Instrument, tick_size: float, order_book_id: int) -> None:
        """Expect an OnInstrument event for ``instrument``.

        A call is matched by instrument symbol. Once matched, this asserts:
        - the full instrument;
        - the tick size, within 1e-5 (any value is accepted when ``tick_size``
          is ``unittest.mock.ANY``);
        - the order book id.
        """

        class OnInstrumentExpectation(CallExpectations):
            def is_expected_call(self, mock_call: _Call) -> bool:
                if len(mock_call.args) != 1:
                    return False
                call_arg = mock_call.args[0]
                return isinstance(call_arg, OnInstrument) and call_arg.instrument.symbol == instrument.symbol

            def validate_call(self, mock_call: _Call) -> None:
                call_arg = mock_call.args[0]
                assert isinstance(call_arg, OnInstrument)
                assert call_arg.instrument == instrument, \
                    f"Expected instrument {instrument}, but got {call_arg.instrument}"
                assert tick_size == ANY or abs(call_arg.tick_size - tick_size) < 1e-5, \
                    f"Expected tick size {tick_size}, but got {call_arg.tick_size}"
                assert call_arg.order_book_id == order_book_id, \
                    f"Expected order book id {order_book_id}, but got {call_arg.order_book_id}"

        assert isinstance(self.on_instrument, MagicMock)
        self.call_expectations_manager.add_expectation(OnInstrumentExpectation(self.on_instrument))

    def expect_on_top_of_book_event(self, instrument_symbol: str, best_bid: PriceLevel | None, best_ask: PriceLevel | None) -> None:
        """Expect an OnTopOfBook event for ``instrument_symbol``.

        Pass ``None`` for a side that is expected to be absent. A call is only
        matched when its bid/ask presence agrees with the expectation. Prices
        are compared within 1e-5; quantities are compared exactly.
        """

        class OnTopOfBookExpectation(CallExpectations):
            def is_expected_call(self, mock_call: _Call) -> bool:
                if len(mock_call.args) != 1:
                    return False
                call_arg = mock_call.args[0]
                if not isinstance(call_arg, OnTopOfBook):
                    return False
                if call_arg.instrument_symbol != instrument_symbol:
                    return False
                # Match only events whose side presence agrees with the expectation.
                if (best_bid is not None) != call_arg.HasField("best_bid"):
                    return False
                if (best_ask is not None) != call_arg.HasField("best_ask"):
                    return False
                return True

            def validate_call(self, mock_call: _Call) -> None:
                call_arg = mock_call.args[0]
                assert isinstance(call_arg, OnTopOfBook)
                self._compare_price_levels(
                    "bid", best_bid, call_arg.best_bid if call_arg.HasField("best_bid") else None)
                self._compare_price_levels(
                    "ask", best_ask, call_arg.best_ask if call_arg.HasField("best_ask") else None)

            def _compare_price_levels(self, label: str, expected: PriceLevel | None, actual: PriceLevel | None) -> None:
                if expected is None:
                    assert actual is None, f"Expected no best {label}, but got {actual}"
                    return
                assert actual is not None, f"Expected best {label} {expected}, but got none"
                assert actual.quantity == expected.quantity, \
                    f"Expected best {label} quantity {expected.quantity}, but got {actual.quantity}"
                assert abs(actual.price - expected.price) < 1e-5, \
                    f"Expected best {label} price {expected.price}, but got {actual.price}"

        assert isinstance(self.on_top_of_book, MagicMock)
        self.call_expectations_manager.add_expectation(OnTopOfBookExpectation(self.on_top_of_book))

    def expect_on_price_depth_book_event(self, instrument_symbol: str, bids: list[PriceLevel], asks: list[PriceLevel]) -> None:
        """Expect an OnPriceDepthBook event for ``instrument_symbol``.

        A call is matched when each side's emptiness agrees with the expectation.
        Once matched, both sides are compared level by level, keyed by price
        (converted with ``decimal_from_float``). Missing, extra, duplicate and
        quantity-mismatched levels are all reported in a single assertion.
        """

        class OnPriceDepthBookExpectation(CallExpectations):
            def is_expected_call(self, mock_call: _Call) -> bool:
                if len(mock_call.args) != 1:
                    return False
                call_arg = mock_call.args[0]
                if not isinstance(call_arg, OnPriceDepthBook):
                    return False
                if call_arg.instrument_symbol != instrument_symbol:
                    return False
                # Match only events whose side emptiness agrees with the expectation.
                if (len(bids) > 0) != (len(call_arg.bids) > 0):
                    return False
                if (len(asks) > 0) != (len(call_arg.asks) > 0):
                    return False
                return True

            def validate_call(self, mock_call: _Call) -> None:
                call_arg = mock_call.args[0]
                assert isinstance(call_arg, OnPriceDepthBook)
                self._compare_price_levels("bid", bids, list(call_arg.bids))
                self._compare_price_levels("ask", asks, list(call_arg.asks))

            def _compare_price_levels(self, label: str, expected: list[PriceLevel], actual: list[PriceLevel]) -> None:
                differences: list[str] = []
                # Key levels by price for O(1) lookup; prices must be unique per side.
                expected_by_price = {decimal_from_float(pl.price): pl for pl in expected}
                actual_by_price = {decimal_from_float(pl.price): pl for pl in actual}
                # The dict silently collapses duplicates, so report them explicitly.
                if len(actual_by_price) != len(actual):
                    differences.append(
                        f"Duplicate actual {label} price levels: {len(actual)} levels but "
                        f"{len(actual_by_price)} distinct prices")
                for price, exp in expected_by_price.items():
                    act = actual_by_price.get(price)
                    if act is None:
                        differences.append(f"Missing actual {label} price level at price {price}")
                    elif exp.quantity != act.quantity:
                        differences.append(
                            f"For {label} price {price}: expected quantity {exp.quantity}, but got {act.quantity}")
                differences.extend(
                    f"Extra actual {label} price level at price {price}"
                    for price in actual_by_price
                    if price not in expected_by_price
                )
                assert not differences, "\n".join(differences)

        assert isinstance(self.on_price_depth_book, MagicMock)
        self.call_expectations_manager.add_expectation(OnPriceDepthBookExpectation(self.on_price_depth_book))

    def expect_on_trade_event(self, instrument_symbol: str, price: float, quantity: int,
                              aggressor_side: Side.ValueType) -> None:
        """Expect an OnTrade event for ``instrument_symbol``.

        A call is matched by symbol. Once matched, the price (within 1e-5), the
        quantity and the aggressor side are asserted.
        """

        class OnTradeExpectation(CallExpectations):
            def is_expected_call(self, mock_call: _Call) -> bool:
                if len(mock_call.args) != 1:
                    return False
                call_arg = mock_call.args[0]
                return isinstance(call_arg, OnTrade) and call_arg.instrument_symbol == instrument_symbol

            def validate_call(self, mock_call: _Call) -> None:
                call_arg = mock_call.args[0]
                assert isinstance(call_arg, OnTrade)
                assert abs(call_arg.price - price) < 1e-5, f"Expected price {price}, got {call_arg.price}"
                assert call_arg.quantity == quantity, f"Expected quantity {quantity}, got {call_arg.quantity}"
                assert call_arg.aggressor_side == aggressor_side, \
                    f"Expected aggressor side {Side.Name(aggressor_side)}, got {Side.Name(call_arg.aggressor_side)}"

        assert isinstance(self.on_trade, MagicMock)
        self.call_expectations_manager.add_expectation(OnTradeExpectation(self.on_trade))

    def on_instrument(self, message: OnInstrument) -> None:
        """Placeholder; replaced per instance by a MagicMock in ``__init__``."""
        pass

    def on_top_of_book(self, message: OnTopOfBook) -> None:
        """Placeholder; replaced per instance by a MagicMock in ``__init__``."""
        pass

    def on_price_depth_book(self, message: OnPriceDepthBook) -> None:
        """Placeholder; replaced per instance by a MagicMock in ``__init__``."""
        pass

    def on_trade(self, message: OnTrade) -> None:
        """Placeholder; replaced per instance by a MagicMock in ``__init__``."""
        pass
class InfoClientConnectionHandlerFactory(ConnectionHandlerFactory[InfoTestClient]):
    """Factory that wraps every new connection in an ``InfoTestClient``.

    All clients produced by one factory share the ``CallExpectationsManager``
    given at construction time.
    """

    def __init__(self, call_expectations_manager: CallExpectationsManager) -> None:
        self.call_expectations_manager = call_expectations_manager

    def on_new_connection(self, socket_fd: socket.socket, ip_address: IpAddress, close_callback: Callable[[], None]) -> InfoTestClient:
        """Build a test client bound to this factory's expectations manager."""
        shared_manager = self.call_expectations_manager
        client = InfoTestClient(socket_fd, ip_address, close_callback, shared_manager)
        return client

    def on_connection_closed(self, connection_handler: InfoTestClient) -> None:
        """Nothing to release when a test client disconnects."""
        return None