import logging
import socket
from decimal import Decimal
from typing import Callable, TypeVar
from unittest.mock import MagicMock

from google.protobuf.message import Message

from connection.connection_handler import ConnectionHandler, ConnectionHandlerFactory
from connection.ip_address import IpAddress
from proto.common_pb2 import LoginRequest, LoginResponse, MessageType, Side
from proto.execution_pb2 import (
    InsertOrderRequest,
    InsertOrderResponse,
    CancelOrderRequest,
    CancelOrderResponse,
)
from proto.risk_limits_pb2 import (
    GetUserRiskLimitsRequest,
    GetUserRiskLimitsResponse,
    SetUserRiskLimitsRequest,
    SetUserRiskLimitsResponse,
    GetInstrumentRiskLimitsRequest,
    GetInstrumentRiskLimitsResponse,
    SetInstrumentRiskLimitsRequest,
    SetInstrumentRiskLimitsResponse,
    UserRiskLimits,
    InstrumentRiskLimits,
    RollingWindowLimit,
)
from tests.conftest import DEFAULT_PASSWORD
from tests.common.mock_expectations import CallExpectationsManager, ResponseExpectation

logger = logging.getLogger(__name__)

# Incoming response message type -> protobuf class used to decode its payload.
# Any message type not listed here is rejected by ``handle_message``.
_RESPONSE_TYPES: dict[int, type[Message]] = {
    MessageType.AUTH_LOGIN_RESPONSE: LoginResponse,
    MessageType.EXEC_INSERT_ORDER_RESPONSE: InsertOrderResponse,
    MessageType.EXEC_CANCEL_ORDER_RESPONSE: CancelOrderResponse,
    MessageType.RISK_GET_USER_LIMITS_RESPONSE: GetUserRiskLimitsResponse,
    MessageType.RISK_SET_USER_LIMITS_RESPONSE: SetUserRiskLimitsResponse,
    MessageType.RISK_GET_INSTRUMENT_LIMITS_RESPONSE: GetInstrumentRiskLimitsResponse,
    MessageType.RISK_SET_INSTRUMENT_LIMITS_RESPONSE: SetInstrumentRiskLimitsResponse,
}

# Response protobuf type carried by a ResponseExpectation.
_ResponseT = TypeVar("_ResponseT", bound=Message)


def _message_type_name(message_type: int) -> str:
    """Return the ``MessageType`` name for *message_type*, or a numeric label if unknown.

    ``MessageType.Name`` raises ``ValueError`` for values the enum does not
    define; falling back here keeps logging from masking the real error.
    """
    try:
        return MessageType.Name(message_type)
    except ValueError:
        return f"<unknown message type {message_type}>"


class RiskGatewayTestClient(ConnectionHandler):
    """Test client that speaks both the execution and risk_limits protocols.

    Each ``test_*`` method:

    1. sends one request tagged with a fresh ``request_id``;
    2. registers a ``ResponseExpectation`` with the shared
       ``CallExpectationsManager``;
    3. asks the manager to verify expectations;
    4. asserts that the response was received.

    Responses are routed back through ``handle_message`` to the per-request
    callback, which is one of the mocks created in ``__init__``.
    """

    def __init__(self, socket_fd: socket.socket, ip_address: IpAddress,
                 on_close: Callable[[], None],
                 call_expectations_manager: CallExpectationsManager) -> None:
        """Create the client and one response mock per supported request kind.

        Args:
            socket_fd: Connected socket, forwarded to ``ConnectionHandler``.
            ip_address: Peer address, forwarded to ``ConnectionHandler``.
            on_close: Close callback, forwarded to ``ConnectionHandler``.
            call_expectations_manager: Creates the response mocks and tracks
                the expectations registered by the ``test_*`` methods.
        """
        super().__init__(socket_fd, ip_address, on_close)
        self.next_request_id = 1
        # Pending request_id -> callback to invoke with the decoded response.
        self.callbacks: dict[int, Callable] = {}
        self.call_expectations_manager = call_expectations_manager
        # Order id from the most recent successful insert; the default for cancels.
        self._last_order_id: int | None = None

        create_mock = call_expectations_manager.create_mock
        self.on_login_response = create_mock("on_login_response")
        self.on_insert_order_response = create_mock("on_insert_order_response")
        self.on_cancel_order_response = create_mock("on_cancel_order_response")
        self.on_get_user_limits_response = create_mock("on_get_user_limits_response")
        self.on_set_user_limits_response = create_mock("on_set_user_limits_response")
        self.on_get_instrument_limits_response = create_mock("on_get_instrument_limits_response")
        self.on_set_instrument_limits_response = create_mock("on_set_instrument_limits_response")

    def on_disconnect(self) -> None:
        """Log that the connection to the peer was closed."""
        logger.info("Disconnected from %s", self.ip_address)

    def handle_message(self, message_type: int, raw_message: bytes) -> None:
        """Decode an incoming response and pass it to its request's callback.

        Args:
            message_type: Numeric ``MessageType`` of the incoming frame.
            raw_message: Serialized protobuf payload.

        Raises:
            ValueError: If ``message_type`` is not a response type this client
                understands, or if no callback is pending for the response's
                ``request_id``.
        """
        type_name = _message_type_name(message_type)
        logger.info("Handling message of type %s", type_name)

        response_type = _RESPONSE_TYPES.get(message_type)
        if response_type is None:
            raise ValueError(f"Unexpected message type: {type_name}")

        msg = response_type.FromString(raw_message)
        assert hasattr(msg, "request_id")

        # Callbacks are one-shot: popping means a second response for the same
        # request_id is reported as an error instead of being dispatched twice.
        callback = self.callbacks.pop(msg.request_id, None)
        if callback is None:
            raise ValueError(f"No callback for request_id: {msg.request_id}")
        callback(msg)

    # ------------------------------------------------------------------
    # Low-level helpers
    # ------------------------------------------------------------------

    def _send_request(self, message_type: int, message: Message, callback: Callable) -> int:
        """Assign a fresh request_id to *message*, register *callback*, and send it.

        Returns:
            The request_id stamped on the outgoing message.
        """
        request_id = self.next_request_id
        self.next_request_id += 1
        message.request_id = request_id  # type: ignore[union-attr]
        self.callbacks[request_id] = callback
        self.send_message(message_type, message)
        return request_id

    def _request_and_await(
        self,
        message_type: int,
        request: Message,
        response_callback: Callable,
        response_type: type[_ResponseT],
        expect_success: bool,
        missing_response_error: str,
    ) -> ResponseExpectation[_ResponseT]:
        """Send *request* and verify that the matching response was received.

        This is the step sequence shared by every ``test_*`` method.

        Args:
            message_type: ``MessageType`` of the outgoing request.
            request: Request message; its ``request_id`` is overwritten.
            response_callback: Mock invoked with the decoded response.
            response_type: Protobuf class the response is expected to be.
            expect_success: Forwarded to ``ResponseExpectation``.
            missing_response_error: Assertion message if the expectation is
                not fulfilled after verification.

        Returns:
            The fulfilled expectation.
        """
        request_id = self._send_request(message_type, request, response_callback)
        expectation = ResponseExpectation(
            response_callback, response_type, request_id, expect_success)
        self.call_expectations_manager.add_expectation(expectation)
        # NOTE(review): presumably verify_expectations processes incoming
        # traffic until the expectation is met or times out — confirm in
        # tests.common.mock_expectations.
        self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
        assert expectation.fulfilled, missing_response_error
        return expectation

    # ------------------------------------------------------------------
    # Auth
    # ------------------------------------------------------------------

    def test_login(self, username: str, password: str = DEFAULT_PASSWORD,
                   expect_success: bool = True) -> ResponseExpectation[LoginResponse]:
        """Log in as *username* and return the fulfilled login expectation."""
        request = LoginRequest(username=username, password=password)
        return self._request_and_await(
            MessageType.AUTH_LOGIN_REQUEST, request, self.on_login_response,
            LoginResponse, expect_success, "Login response not received")

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------

    def test_insert_order(
        self,
        instrument_symbol: str,
        side: Side.ValueType,
        price: Decimal,
        quantity: int,
        expect_success: bool = True,
    ) -> ResponseExpectation[InsertOrderResponse]:
        """Insert an order and, on expected success, remember its order id.

        The remembered id is used by ``test_cancel_order`` when no explicit
        ``order_id`` is given.
        """
        request = InsertOrderRequest(
            instrument_symbol=instrument_symbol,
            side=side,
            price=float(price),  # the request's price field is set from a float
            quantity=quantity,
        )
        expectation = self._request_and_await(
            MessageType.EXEC_INSERT_ORDER_REQUEST, request, self.on_insert_order_response,
            InsertOrderResponse, expect_success, "Insert order response not received")
        if expect_success:
            self._last_order_id = expectation.get_response().order_id
        return expectation

    def test_cancel_order(
        self,
        instrument_symbol: str,
        order_id: int | None = None,
        expect_success: bool = True,
    ) -> ResponseExpectation[CancelOrderResponse]:
        """Cancel *order_id*, defaulting to the last successfully inserted order."""
        if order_id is None:
            assert self._last_order_id is not None, "No order to cancel"
            order_id = self._last_order_id
        request = CancelOrderRequest(
            instrument_symbol=instrument_symbol, order_id=order_id)
        return self._request_and_await(
            MessageType.EXEC_CANCEL_ORDER_REQUEST, request, self.on_cancel_order_response,
            CancelOrderResponse, expect_success, "Cancel order response not received")

    # ------------------------------------------------------------------
    # Risk Limits – user level
    # ------------------------------------------------------------------

    def test_set_user_risk_limits(
        self,
        user_risk_limits: UserRiskLimits | None = None,
        expect_success: bool = True,
    ) -> ResponseExpectation[SetUserRiskLimitsResponse]:
        """Send a set-user-limits request.

        The ``user_risk_limits`` field is left unset when *user_risk_limits*
        is None.
        """
        request = SetUserRiskLimitsRequest()
        if user_risk_limits is not None:
            request.user_risk_limits.CopyFrom(user_risk_limits)
        return self._request_and_await(
            MessageType.RISK_SET_USER_LIMITS_REQUEST, request, self.on_set_user_limits_response,
            SetUserRiskLimitsResponse, expect_success,
            "Set user risk limits response not received")

    def test_get_user_risk_limits(
        self,
        expect_success: bool = True,
    ) -> ResponseExpectation[GetUserRiskLimitsResponse]:
        """Request the user's risk limits and return the fulfilled expectation."""
        return self._request_and_await(
            MessageType.RISK_GET_USER_LIMITS_REQUEST, GetUserRiskLimitsRequest(),
            self.on_get_user_limits_response, GetUserRiskLimitsResponse, expect_success,
            "Get user risk limits response not received")

    # ------------------------------------------------------------------
    # Risk Limits – instrument level
    # ------------------------------------------------------------------

    def test_set_instrument_risk_limits(
        self,
        instrument_symbol: str,
        instrument_risk_limits: InstrumentRiskLimits | None = None,
        expect_success: bool = True,
    ) -> ResponseExpectation[SetInstrumentRiskLimitsResponse]:
        """Send a set-instrument-limits request for *instrument_symbol*.

        The ``instrument_risk_limits`` field is left unset when
        *instrument_risk_limits* is None.
        """
        request = SetInstrumentRiskLimitsRequest(instrument_symbol=instrument_symbol)
        if instrument_risk_limits is not None:
            request.instrument_risk_limits.CopyFrom(instrument_risk_limits)
        return self._request_and_await(
            MessageType.RISK_SET_INSTRUMENT_LIMITS_REQUEST, request,
            self.on_set_instrument_limits_response, SetInstrumentRiskLimitsResponse,
            expect_success, "Set instrument risk limits response not received")

    def test_get_instrument_risk_limits(
        self,
        expect_success: bool = True,
    ) -> ResponseExpectation[GetInstrumentRiskLimitsResponse]:
        """Request instrument risk limits and return the fulfilled expectation."""
        # NOTE(review): unlike the setter, no instrument_symbol is sent here —
        # confirm against GetInstrumentRiskLimitsRequest whether one is expected.
        return self._request_and_await(
            MessageType.RISK_GET_INSTRUMENT_LIMITS_REQUEST, GetInstrumentRiskLimitsRequest(),
            self.on_get_instrument_limits_response, GetInstrumentRiskLimitsResponse,
            expect_success, "Get instrument risk limits response not received")


class RiskGatewayClientConnectionHandlerFactory(ConnectionHandlerFactory[RiskGatewayTestClient]):
    """Creates a ``RiskGatewayTestClient`` for each new connection.

    All clients share one ``CallExpectationsManager``.
    """

    def __init__(self, call_expectations_manager: CallExpectationsManager) -> None:
        self.call_expectations_manager = call_expectations_manager

    def on_new_connection(self, socket_fd: socket.socket, ip_address: IpAddress,
                          close_callback: Callable[[], None]) -> RiskGatewayTestClient:
        """Wrap a newly established connection in a ``RiskGatewayTestClient``."""
        return RiskGatewayTestClient(
            socket_fd, ip_address, close_callback, self.call_expectations_manager)

    def on_connection_closed(self, connection_handler: RiskGatewayTestClient) -> None:
        """No cleanup is performed when a connection closes."""
        pass