"""Test-side admin client: sends admin requests and routes responses to expectation mocks."""

import logging
import socket
from decimal import Decimal
from typing import Callable

from connection.connection_handler import ConnectionHandler, ConnectionHandlerFactory
from connection.ip_address import IpAddress
from proto.admin_pb2 import CreateInstrumentRequest, CreateInstrumentResponse
from proto.common_pb2 import Instrument, MessageType
from tests.common.mock_expectations import CallExpectationsManager, ResponseExpectation

logger = logging.getLogger(__name__)


def _message_type_name(message_type: int) -> str:
    """Return the ``MessageType`` name for *message_type*, or a placeholder if undefined.

    ``MessageType.Name`` raises ``ValueError`` for numbers the enum does not
    define; using this helper keeps log/error message construction from
    raising on such values.
    """
    try:
        return MessageType.Name(message_type)
    except ValueError:
        return f"<unknown MessageType {message_type}>"


class AdminTestClient(ConnectionHandler):
    """Connection handler acting as an admin client in tests.

    Each outgoing request gets a fresh ``request_id``; the callback registered
    for that id (a mock created by the ``CallExpectationsManager``) is invoked
    when the matching response arrives, so tests can assert on it.
    """

    def __init__(self,
                 socket_fd: socket.socket,
                 ip_address: IpAddress,
                 on_close: Callable[[], None],
                 call_expectations_manager: CallExpectationsManager) -> None:
        super().__init__(socket_fd, ip_address, on_close)
        self.next_request_id = 1
        # Pending response callbacks keyed by the request_id they answer;
        # an entry is removed when its response is handled.
        self.callbacks: dict[int, Callable] = {}
        self.call_expectations_manager = call_expectations_manager
        self.on_create_instrument_response = self.call_expectations_manager.create_mock(
            "on_create_instrument_response")

    def on_disconnect(self) -> None:
        """Log that the connection to ``self.ip_address`` was closed."""
        logger.info("Disconnected from %s", self.ip_address)

    def handle_message(self, message_type: int, raw_message: bytes) -> None:
        """Decode an incoming message and dispatch it to its pending callback.

        Args:
            message_type: numeric ``MessageType`` of the incoming message.
            raw_message: serialized protobuf payload.

        Raises:
            ValueError: if *message_type* is not one this client handles, or if
                the decoded response's ``request_id`` has no pending callback.
        """
        logger.info("Handling message of type %s", _message_type_name(message_type))
        if message_type == MessageType.ADMIN_CREATE_INSTRUMENT_RESPONSE:
            message = CreateInstrumentResponse.FromString(raw_message)
        else:
            raise ValueError(f"Unexpected message type: {_message_type_name(message_type)}")

        # Each request_id is answered once, so pop rather than get.
        callback = self.callbacks.pop(message.request_id, None)
        if callback is None:
            raise ValueError(f"Received response with unknown request_id: {message.request_id}")
        callback(message)

    def test_create_instrument(self,
                               instrument: Instrument,
                               tick_size: Decimal,
                               expect_success: bool = True
                               ) -> ResponseExpectation[CreateInstrumentResponse]:
        """Send a ``CreateInstrumentRequest`` and register an expectation for its response.

        Args:
            instrument: instrument to create.
            tick_size: tick size; converted with ``float()`` for the request.
                NOTE(review): ``float(Decimal)`` can lose precision — confirm the
                proto field type is intended to be a float/double.
            expect_success: when True, verify expectations immediately (without
                requiring that no calls remain pending) and assert the response
                expectation was fulfilled.

        Returns:
            The ``ResponseExpectation`` registered for this request.
        """
        request = CreateInstrumentRequest(instrument=instrument, tick_size=float(tick_size))
        request_id = self._get_next_request_id()
        request.request_id = request_id
        # Register the callback before sending so the response can always be routed.
        self.callbacks[request_id] = self.on_create_instrument_response
        self.send_message(MessageType.ADMIN_CREATE_INSTRUMENT_REQUEST, request)

        response_expectation = ResponseExpectation(
            self.on_create_instrument_response, CreateInstrumentResponse, request_id,
            expect_success)
        self.call_expectations_manager.add_expectation(response_expectation)
        if expect_success:
            self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
            assert response_expectation.fulfilled, "Expected CreateInstrumentResponse not received"
        return response_expectation

    def send_message(self, message_type: int, message) -> None:
        """Log the outgoing message type, then delegate to ``ConnectionHandler.send_message``."""
        logger.info("Sending message of type %s", _message_type_name(message_type))
        super().send_message(message_type, message)

    def _get_next_request_id(self) -> int:
        """Return the next request id (starting at 1) and advance the counter."""
        request_id = self.next_request_id
        self.next_request_id += 1
        return request_id


class AdminClientConnectionHandlerFactory(ConnectionHandlerFactory[AdminTestClient]):
    """Factory that creates an ``AdminTestClient`` for each new connection."""

    def __init__(self, call_expectations_manager: CallExpectationsManager) -> None:
        # Shared by every client this factory creates.
        self.call_expectations_manager = call_expectations_manager

    def on_new_connection(self,
                          socket_fd: socket.socket,
                          ip_address: IpAddress,
                          close_callback: Callable[[], None]) -> AdminTestClient:
        """Build an ``AdminTestClient`` bound to this factory's expectations manager."""
        return AdminTestClient(socket_fd, ip_address, close_callback,
                               self.call_expectations_manager)

    def on_connection_closed(self, connection_handler: AdminTestClient) -> None:
        """No cleanup is needed when a client connection closes."""
        pass