249 lines
12 KiB
Python
249 lines
12 KiB
Python
import logging
|
||
import socket
|
||
from decimal import Decimal
|
||
from typing import Callable
|
||
from unittest.mock import MagicMock
|
||
|
||
from connection.connection_handler import ConnectionHandler, ConnectionHandlerFactory
|
||
from connection.ip_address import IpAddress
|
||
from google.protobuf.message import Message
|
||
from proto.common_pb2 import LoginRequest, LoginResponse, MessageType, Side
|
||
from proto.execution_pb2 import (
|
||
InsertOrderRequest, InsertOrderResponse,
|
||
CancelOrderRequest, CancelOrderResponse,
|
||
)
|
||
from proto.risk_limits_pb2 import (
|
||
GetUserRiskLimitsRequest, GetUserRiskLimitsResponse,
|
||
SetUserRiskLimitsRequest, SetUserRiskLimitsResponse,
|
||
GetInstrumentRiskLimitsRequest, GetInstrumentRiskLimitsResponse,
|
||
SetInstrumentRiskLimitsRequest, SetInstrumentRiskLimitsResponse,
|
||
UserRiskLimits, InstrumentRiskLimits, RollingWindowLimit,
|
||
)
|
||
|
||
from tests.conftest import DEFAULT_PASSWORD
|
||
from tests.common.mock_expectations import CallExpectationsManager, ResponseExpectation
|
||
|
||
# Module-level logger; on_disconnect and handle_message log through it.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class RiskGatewayTestClient(ConnectionHandler):
    """Test client that speaks both the execution and risk_limits protocols.

    Each ``test_*`` method sends one request stamped with a fresh
    ``request_id``, registers a ``ResponseExpectation`` for the matching
    response with the shared ``CallExpectationsManager``, runs
    ``verify_expectations`` and asserts the expectation was fulfilled, then
    returns the expectation so the caller can inspect the response.
    """

    # Response MessageTypes this client accepts, mapped to the protobuf class
    # used to decode each payload in handle_message.
    _RESPONSE_TYPES: dict[int, type[Message]] = {
        MessageType.AUTH_LOGIN_RESPONSE: LoginResponse,
        MessageType.EXEC_INSERT_ORDER_RESPONSE: InsertOrderResponse,
        MessageType.EXEC_CANCEL_ORDER_RESPONSE: CancelOrderResponse,
        MessageType.RISK_GET_USER_LIMITS_RESPONSE: GetUserRiskLimitsResponse,
        MessageType.RISK_SET_USER_LIMITS_RESPONSE: SetUserRiskLimitsResponse,
        MessageType.RISK_GET_INSTRUMENT_LIMITS_RESPONSE: GetInstrumentRiskLimitsResponse,
        MessageType.RISK_SET_INSTRUMENT_LIMITS_RESPONSE: SetInstrumentRiskLimitsResponse,
    }

    def __init__(self, socket_fd: socket.socket, ip_address: IpAddress,
                 on_close: Callable[[], None],
                 call_expectations_manager: CallExpectationsManager) -> None:
        super().__init__(socket_fd, ip_address, on_close)
        # Next id to stamp on an outgoing request; incremented per request.
        self.next_request_id = 1
        # Pending response callbacks keyed by request_id; popped on arrival.
        self.callbacks: dict[int, Callable] = {}
        self.call_expectations_manager = call_expectations_manager
        # order_id from the most recent successful insert; the default target
        # of test_cancel_order.
        self._last_order_id: int | None = None

        # One mock per response kind. Each is registered as the request
        # callback and handed to the ResponseExpectation for that response.
        self.on_login_response = call_expectations_manager.create_mock("on_login_response")
        self.on_insert_order_response = call_expectations_manager.create_mock("on_insert_order_response")
        self.on_cancel_order_response = call_expectations_manager.create_mock("on_cancel_order_response")
        self.on_get_user_limits_response = call_expectations_manager.create_mock("on_get_user_limits_response")
        self.on_set_user_limits_response = call_expectations_manager.create_mock("on_set_user_limits_response")
        self.on_get_instrument_limits_response = call_expectations_manager.create_mock("on_get_instrument_limits_response")
        self.on_set_instrument_limits_response = call_expectations_manager.create_mock("on_set_instrument_limits_response")

    def on_disconnect(self) -> None:
        """Log that the connection was closed."""
        logger.info("Disconnected from %s", self.ip_address)

    @staticmethod
    def _message_type_name(message_type: int) -> str:
        """Return the MessageType name for *message_type*, or a numeric fallback.

        ``MessageType.Name`` raises ``ValueError`` for values the enum does not
        define; this keeps logging and error messages from failing on them.
        """
        try:
            return MessageType.Name(message_type)
        except ValueError:
            return f"<unknown MessageType {message_type}>"

    def handle_message(self, message_type: int, raw_message: bytes) -> None:
        """Decode a response and hand it to the callback registered for its request_id.

        Raises:
            ValueError: if *message_type* is not one of the responses in
                ``_RESPONSE_TYPES``, the decoded message has no ``request_id``,
                or no callback is pending for that ``request_id``.
        """
        type_name = self._message_type_name(message_type)
        logger.info("Handling message of type %s", type_name)

        response_cls = self._RESPONSE_TYPES.get(message_type)
        if response_cls is None:
            raise ValueError(f"Unexpected message type: {type_name}")
        msg = response_cls.FromString(raw_message)

        # Explicit check (not assert) so it still runs under `python -O`.
        request_id = getattr(msg, "request_id", None)
        if request_id is None:
            raise ValueError(f"{type_name} has no request_id field")

        callback = self.callbacks.pop(request_id, None)
        if callback is None:
            raise ValueError(f"No callback for request_id: {request_id}")
        callback(msg)

    # ------------------------------------------------------------------
    # Low-level helpers
    # ------------------------------------------------------------------

    def _send_request(self, message_type: int, message: Message,
                      callback: Callable) -> int:
        """Stamp *message* with a fresh request_id, register *callback*, and send it.

        Returns:
            The request_id assigned to the message.
        """
        request_id = self.next_request_id
        self.next_request_id += 1
        message.request_id = request_id  # type: ignore[union-attr]
        self.callbacks[request_id] = callback
        self.send_message(message_type, message)
        return request_id

    def _send_and_expect(self, message_type: int, request: Message,
                         mock: Callable, response_type: type[Message],
                         expect_success: bool,
                         description: str) -> ResponseExpectation:
        """Send *request*, then register and verify an expectation for its response.

        *mock* serves as both the request callback and the mock the
        ``ResponseExpectation`` is built from. *description* prefixes the
        assertion message, e.g. ``"Login"`` -> ``"Login response not received"``.

        Raises:
            AssertionError: if the expectation is not fulfilled after
                ``verify_expectations`` runs.
        """
        request_id = self._send_request(message_type, request, mock)
        expectation = ResponseExpectation(
            mock, response_type, request_id, expect_success)
        self.call_expectations_manager.add_expectation(expectation)
        self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
        assert expectation.fulfilled, f"{description} response not received"
        return expectation

    # ------------------------------------------------------------------
    # Auth
    # ------------------------------------------------------------------

    def test_login(self, username: str, password: str = DEFAULT_PASSWORD,
                   expect_success: bool = True) -> ResponseExpectation[LoginResponse]:
        """Send a LoginRequest for *username* and return the verified expectation."""
        request = LoginRequest(username=username, password=password)
        return self._send_and_expect(
            MessageType.AUTH_LOGIN_REQUEST, request, self.on_login_response,
            LoginResponse, expect_success, "Login")

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------

    def test_insert_order(
        self, instrument_symbol: str, side: Side.ValueType,
        price: Decimal, quantity: int,
        expect_success: bool = True,
    ) -> ResponseExpectation[InsertOrderResponse]:
        """Send an InsertOrderRequest and return the verified expectation.

        On expected success, the returned order_id is remembered so that
        test_cancel_order can default to it.
        """
        # The request's price is set from a float, so Decimal precision
        # beyond a double is lost here.
        request = InsertOrderRequest(
            instrument_symbol=instrument_symbol, side=side,
            price=float(price), quantity=quantity)
        expectation = self._send_and_expect(
            MessageType.EXEC_INSERT_ORDER_REQUEST, request,
            self.on_insert_order_response, InsertOrderResponse,
            expect_success, "Insert order")
        if expect_success:
            self._last_order_id = expectation.get_response().order_id
        return expectation

    def test_cancel_order(
        self, instrument_symbol: str, order_id: int | None = None,
        expect_success: bool = True,
    ) -> ResponseExpectation[CancelOrderResponse]:
        """Send a CancelOrderRequest and return the verified expectation.

        When *order_id* is None, the order from the last successful
        test_insert_order is cancelled.
        """
        if order_id is None:
            assert self._last_order_id is not None, "No order to cancel"
            order_id = self._last_order_id
        request = CancelOrderRequest(
            instrument_symbol=instrument_symbol, order_id=order_id)
        return self._send_and_expect(
            MessageType.EXEC_CANCEL_ORDER_REQUEST, request,
            self.on_cancel_order_response, CancelOrderResponse,
            expect_success, "Cancel order")

    # ------------------------------------------------------------------
    # Risk Limits – user level
    # ------------------------------------------------------------------

    def test_set_user_risk_limits(
        self, user_risk_limits: UserRiskLimits | None = None,
        expect_success: bool = True,
    ) -> ResponseExpectation[SetUserRiskLimitsResponse]:
        """Send a SetUserRiskLimitsRequest and return the verified expectation.

        When *user_risk_limits* is None, the request's user_risk_limits field
        is left unset.
        """
        request = SetUserRiskLimitsRequest()
        if user_risk_limits is not None:
            request.user_risk_limits.CopyFrom(user_risk_limits)
        return self._send_and_expect(
            MessageType.RISK_SET_USER_LIMITS_REQUEST, request,
            self.on_set_user_limits_response, SetUserRiskLimitsResponse,
            expect_success, "Set user risk limits")

    def test_get_user_risk_limits(
        self, expect_success: bool = True,
    ) -> ResponseExpectation[GetUserRiskLimitsResponse]:
        """Send an empty GetUserRiskLimitsRequest and return the verified expectation."""
        request = GetUserRiskLimitsRequest()
        return self._send_and_expect(
            MessageType.RISK_GET_USER_LIMITS_REQUEST, request,
            self.on_get_user_limits_response, GetUserRiskLimitsResponse,
            expect_success, "Get user risk limits")

    # ------------------------------------------------------------------
    # Risk Limits – instrument level
    # ------------------------------------------------------------------

    def test_set_instrument_risk_limits(
        self, instrument_symbol: str,
        instrument_risk_limits: InstrumentRiskLimits | None = None,
        expect_success: bool = True,
    ) -> ResponseExpectation[SetInstrumentRiskLimitsResponse]:
        """Send a SetInstrumentRiskLimitsRequest for *instrument_symbol*.

        When *instrument_risk_limits* is None, the request's
        instrument_risk_limits field is left unset.
        """
        request = SetInstrumentRiskLimitsRequest(instrument_symbol=instrument_symbol)
        if instrument_risk_limits is not None:
            request.instrument_risk_limits.CopyFrom(instrument_risk_limits)
        return self._send_and_expect(
            MessageType.RISK_SET_INSTRUMENT_LIMITS_REQUEST, request,
            self.on_set_instrument_limits_response, SetInstrumentRiskLimitsResponse,
            expect_success, "Set instrument risk limits")

    def test_get_instrument_risk_limits(
        self, expect_success: bool = True,
    ) -> ResponseExpectation[GetInstrumentRiskLimitsResponse]:
        """Send an empty GetInstrumentRiskLimitsRequest and return the verified expectation."""
        # NOTE(review): unlike test_set_instrument_risk_limits, no
        # instrument_symbol is sent here; confirm against the proto whether
        # this request should be scoped to an instrument.
        request = GetInstrumentRiskLimitsRequest()
        return self._send_and_expect(
            MessageType.RISK_GET_INSTRUMENT_LIMITS_REQUEST, request,
            self.on_get_instrument_limits_response, GetInstrumentRiskLimitsResponse,
            expect_success, "Get instrument risk limits")
|
||
|
||
|
||
class RiskGatewayClientConnectionHandlerFactory(ConnectionHandlerFactory[RiskGatewayTestClient]):
    """Create a RiskGatewayTestClient for each new connection.

    Every client produced by this factory shares the one
    CallExpectationsManager supplied to the constructor.
    """

    def __init__(self, call_expectations_manager: CallExpectationsManager) -> None:
        # Handed to every client built in on_new_connection.
        self.call_expectations_manager = call_expectations_manager

    def on_new_connection(self, socket_fd: socket.socket, ip_address: IpAddress,
                          close_callback: Callable[[], None]) -> RiskGatewayTestClient:
        """Wrap the new connection's socket in a RiskGatewayTestClient and return it."""
        client = RiskGatewayTestClient(
            socket_fd=socket_fd,
            ip_address=ip_address,
            on_close=close_callback,
            call_expectations_manager=self.call_expectations_manager,
        )
        return client

    def on_connection_closed(self, connection_handler: RiskGatewayTestClient) -> None:
        """Do nothing when a connection closes; this factory keeps no per-connection state."""
|