gtat-tech-career-kickstarte.../solution/tests/test_client/risk_gateway_test_client.py

249 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import socket
from decimal import Decimal
from typing import Callable
from unittest.mock import MagicMock
from connection.connection_handler import ConnectionHandler, ConnectionHandlerFactory
from connection.ip_address import IpAddress
from google.protobuf.message import Message
from proto.common_pb2 import LoginRequest, LoginResponse, MessageType, Side
from proto.execution_pb2 import (
InsertOrderRequest, InsertOrderResponse,
CancelOrderRequest, CancelOrderResponse,
)
from proto.risk_limits_pb2 import (
GetUserRiskLimitsRequest, GetUserRiskLimitsResponse,
SetUserRiskLimitsRequest, SetUserRiskLimitsResponse,
GetInstrumentRiskLimitsRequest, GetInstrumentRiskLimitsResponse,
SetInstrumentRiskLimitsRequest, SetInstrumentRiskLimitsResponse,
UserRiskLimits, InstrumentRiskLimits, RollingWindowLimit,
)
from tests.conftest import DEFAULT_PASSWORD
from tests.common.mock_expectations import CallExpectationsManager, ResponseExpectation
# Module-level logger, named after this module so its output can be filtered per module.
logger: logging.Logger = logging.getLogger(name=__name__)
class RiskGatewayTestClient(ConnectionHandler):
    """Test client that speaks both the execution and risk_limits protocols.

    Every ``test_*`` method follows the same round trip:

    1. Send one request stamped with a fresh ``request_id``.
    2. Register the matching response mock as the callback for that id.
    3. Build a ``ResponseExpectation`` and hand it to the shared
       ``CallExpectationsManager``.
    4. Run verification and assert that the expectation was fulfilled.

    The fulfilled ``ResponseExpectation`` is returned to the caller.
    """

    # Response message classes this client can decode, keyed by wire message type.
    _RESPONSE_TYPES: dict[int, type[Message]] = {
        MessageType.AUTH_LOGIN_RESPONSE: LoginResponse,
        MessageType.EXEC_INSERT_ORDER_RESPONSE: InsertOrderResponse,
        MessageType.EXEC_CANCEL_ORDER_RESPONSE: CancelOrderResponse,
        MessageType.RISK_GET_USER_LIMITS_RESPONSE: GetUserRiskLimitsResponse,
        MessageType.RISK_SET_USER_LIMITS_RESPONSE: SetUserRiskLimitsResponse,
        MessageType.RISK_GET_INSTRUMENT_LIMITS_RESPONSE: GetInstrumentRiskLimitsResponse,
        MessageType.RISK_SET_INSTRUMENT_LIMITS_RESPONSE: SetInstrumentRiskLimitsResponse,
    }

    def __init__(self, socket_fd: socket.socket, ip_address: IpAddress,
                 on_close: Callable[[], None],
                 call_expectations_manager: CallExpectationsManager) -> None:
        super().__init__(socket_fd, ip_address, on_close)
        # Next id to stamp onto an outgoing request; incremented per request.
        self.next_request_id = 1
        # Pending response callbacks keyed by request_id; popped when the response arrives.
        self.callbacks: dict[int, Callable] = {}
        self.call_expectations_manager = call_expectations_manager
        # order_id of the most recent successful insert; default target of test_cancel_order.
        self._last_order_id: int | None = None
        # One mock per response kind. Each mock is registered as the request callback
        # and passed to ResponseExpectation together with the request_id.
        self.on_login_response = call_expectations_manager.create_mock("on_login_response")
        self.on_insert_order_response = call_expectations_manager.create_mock("on_insert_order_response")
        self.on_cancel_order_response = call_expectations_manager.create_mock("on_cancel_order_response")
        self.on_get_user_limits_response = call_expectations_manager.create_mock("on_get_user_limits_response")
        self.on_set_user_limits_response = call_expectations_manager.create_mock("on_set_user_limits_response")
        self.on_get_instrument_limits_response = call_expectations_manager.create_mock("on_get_instrument_limits_response")
        self.on_set_instrument_limits_response = call_expectations_manager.create_mock("on_set_instrument_limits_response")

    def on_disconnect(self) -> None:
        """Log that the connection to the gateway was closed."""
        logger.info("Disconnected from %s", self.ip_address)

    @staticmethod
    def _message_type_name(message_type: int) -> str:
        """Return the enum name of ``message_type``, or a placeholder if it is unknown.

        ``MessageType.Name`` raises ``ValueError`` for values missing from the enum.
        Falling back to a placeholder lets logging and error reporting still work
        for such values.
        """
        try:
            return MessageType.Name(message_type)
        except ValueError:
            return f"<unknown message type {message_type}>"

    def handle_message(self, message_type: int, raw_message: bytes) -> None:
        """Decode an incoming response and dispatch it to its registered callback.

        Raises:
            ValueError: if ``message_type`` is not a response this client handles,
                or if no callback is pending for the response's ``request_id``.
        """
        type_name = self._message_type_name(message_type)
        logger.info("Handling message of type %s", type_name)

        response_type = self._RESPONSE_TYPES.get(message_type)
        if response_type is None:
            raise ValueError(f"Unexpected message type: {type_name}")

        msg = response_type.FromString(raw_message)
        assert hasattr(msg, "request_id")
        # Each request_id is answered once, so remove the callback as it is consumed.
        callback = self.callbacks.pop(msg.request_id, None)
        if callback is None:
            raise ValueError(f"No callback for request_id: {msg.request_id}")
        callback(msg)

    # ------------------------------------------------------------------
    # Low-level helpers
    # ------------------------------------------------------------------
    def _send_request(self, message_type: int, message: Message,
                      callback: Callable) -> int:
        """Stamp a fresh request_id onto ``message``, register ``callback``, and send it.

        Returns the request_id used. If sending fails, the callback is
        unregistered again before the exception propagates, so no stale entry
        is left behind.
        """
        request_id = self.next_request_id
        self.next_request_id += 1
        message.request_id = request_id  # type: ignore[union-attr]
        # Register before sending so the callback is present by the time a response is handled.
        self.callbacks[request_id] = callback
        try:
            self.send_message(message_type, message)
        except Exception:
            self.callbacks.pop(request_id, None)
            raise
        return request_id

    def _round_trip(self, message_type: int, request: Message,
                    response_mock: Callable, response_type: type[Message],
                    expect_success: bool, failure_message: str) -> ResponseExpectation:
        """Send ``request``, register an expectation for its response, and verify it.

        Raises:
            AssertionError: with ``failure_message`` if verification leaves the
                expectation unfulfilled.
        """
        request_id = self._send_request(message_type, request, response_mock)
        expectation = ResponseExpectation(
            response_mock, response_type, request_id, expect_success)
        self.call_expectations_manager.add_expectation(expectation)
        self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
        assert expectation.fulfilled, failure_message
        return expectation

    # ------------------------------------------------------------------
    # Auth
    # ------------------------------------------------------------------
    def test_login(self, username: str, password: str = DEFAULT_PASSWORD,
                   expect_success: bool = True) -> ResponseExpectation[LoginResponse]:
        """Send a LoginRequest and assert that a LoginResponse is received."""
        request = LoginRequest(username=username, password=password)
        return self._round_trip(
            MessageType.AUTH_LOGIN_REQUEST, request,
            self.on_login_response, LoginResponse,
            expect_success, "Login response not received")

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------
    def test_insert_order(
        self, instrument_symbol: str, side: Side.ValueType,
        price: Decimal, quantity: int,
        expect_success: bool = True,
    ) -> ResponseExpectation[InsertOrderResponse]:
        """Insert an order and assert that an InsertOrderResponse is received.

        When ``expect_success`` is true, the returned ``order_id`` is remembered
        so that ``test_cancel_order`` can target it by default.
        """
        # The request takes a float price, so the Decimal is converted here.
        request = InsertOrderRequest(
            instrument_symbol=instrument_symbol, side=side,
            price=float(price), quantity=quantity)
        expectation = self._round_trip(
            MessageType.EXEC_INSERT_ORDER_REQUEST, request,
            self.on_insert_order_response, InsertOrderResponse,
            expect_success, "Insert order response not received")
        if expect_success:
            self._last_order_id = expectation.get_response().order_id
        return expectation

    def test_cancel_order(
        self, instrument_symbol: str, order_id: int | None = None,
        expect_success: bool = True,
    ) -> ResponseExpectation[CancelOrderResponse]:
        """Cancel an order and assert that a CancelOrderResponse is received.

        If ``order_id`` is omitted, the id from the last successful
        ``test_insert_order`` call is used.
        """
        if order_id is None:
            assert self._last_order_id is not None, "No order to cancel"
            order_id = self._last_order_id
        request = CancelOrderRequest(
            instrument_symbol=instrument_symbol, order_id=order_id)
        return self._round_trip(
            MessageType.EXEC_CANCEL_ORDER_REQUEST, request,
            self.on_cancel_order_response, CancelOrderResponse,
            expect_success, "Cancel order response not received")

    # ------------------------------------------------------------------
    # Risk Limits user level
    # ------------------------------------------------------------------
    def test_set_user_risk_limits(
        self, user_risk_limits: UserRiskLimits | None = None,
        expect_success: bool = True,
    ) -> ResponseExpectation[SetUserRiskLimitsResponse]:
        """Set user-level risk limits (empty when ``user_risk_limits`` is None) and await the reply."""
        request = SetUserRiskLimitsRequest()
        if user_risk_limits is not None:
            request.user_risk_limits.CopyFrom(user_risk_limits)
        return self._round_trip(
            MessageType.RISK_SET_USER_LIMITS_REQUEST, request,
            self.on_set_user_limits_response, SetUserRiskLimitsResponse,
            expect_success, "Set user risk limits response not received")

    def test_get_user_risk_limits(
        self, expect_success: bool = True,
    ) -> ResponseExpectation[GetUserRiskLimitsResponse]:
        """Request the user-level risk limits and await the reply."""
        return self._round_trip(
            MessageType.RISK_GET_USER_LIMITS_REQUEST, GetUserRiskLimitsRequest(),
            self.on_get_user_limits_response, GetUserRiskLimitsResponse,
            expect_success, "Get user risk limits response not received")

    # ------------------------------------------------------------------
    # Risk Limits instrument level
    # ------------------------------------------------------------------
    def test_set_instrument_risk_limits(
        self, instrument_symbol: str,
        instrument_risk_limits: InstrumentRiskLimits | None = None,
        expect_success: bool = True,
    ) -> ResponseExpectation[SetInstrumentRiskLimitsResponse]:
        """Set risk limits for ``instrument_symbol`` and await the reply."""
        request = SetInstrumentRiskLimitsRequest(instrument_symbol=instrument_symbol)
        if instrument_risk_limits is not None:
            request.instrument_risk_limits.CopyFrom(instrument_risk_limits)
        return self._round_trip(
            MessageType.RISK_SET_INSTRUMENT_LIMITS_REQUEST, request,
            self.on_set_instrument_limits_response, SetInstrumentRiskLimitsResponse,
            expect_success, "Set instrument risk limits response not received")

    def test_get_instrument_risk_limits(
        self, expect_success: bool = True,
    ) -> ResponseExpectation[GetInstrumentRiskLimitsResponse]:
        """Request instrument-level risk limits and await the reply."""
        # NOTE(review): unlike test_set_instrument_risk_limits, this sends no
        # instrument_symbol. Confirm against the GetInstrumentRiskLimitsRequest schema.
        return self._round_trip(
            MessageType.RISK_GET_INSTRUMENT_LIMITS_REQUEST, GetInstrumentRiskLimitsRequest(),
            self.on_get_instrument_limits_response, GetInstrumentRiskLimitsResponse,
            expect_success, "Get instrument risk limits response not received")
class RiskGatewayClientConnectionHandlerFactory(ConnectionHandlerFactory[RiskGatewayTestClient]):
    """Creates one RiskGatewayTestClient per new connection.

    Every client built by a given factory shares that factory's
    CallExpectationsManager.
    """

    def __init__(self, call_expectations_manager: CallExpectationsManager) -> None:
        self.call_expectations_manager = call_expectations_manager

    def on_new_connection(self, socket_fd: socket.socket, ip_address: IpAddress,
                          close_callback: Callable[[], None]) -> RiskGatewayTestClient:
        """Wrap ``socket_fd`` in a RiskGatewayTestClient wired to the shared manager."""
        client = RiskGatewayTestClient(
            socket_fd=socket_fd,
            ip_address=ip_address,
            on_close=close_callback,
            call_expectations_manager=self.call_expectations_manager,
        )
        return client

    def on_connection_closed(self, connection_handler: RiskGatewayTestClient) -> None:
        """Do nothing: the factory keeps no per-connection state to release."""