Tom Kunc 2026-03-19 21:48:58 +11:00
parent 664ee27835
commit 9bfaa0d57c
102 changed files with 9211 additions and 262 deletions


@ -0,0 +1,44 @@
---
description: Overview of the Tech Career Kickstarter / Optivex project structure, purpose, and conventions
alwaysApply: true
---
# Tech Career Kickstarter — Optivex
This repository supports Optiver's **Tech Career Kickstarter** program: a one-week, on-premise training and assessment program for STEM students. Participants build a simulated financial exchange called **Optivex**.
## Repository Structure
The repo has three top-level areas with different audiences:
| Directory | Audience | Purpose |
|---|---|---|
| `setup_utils/` | Program staff only | Shell scripts for provisioning EC2 dev machines and the UAT machine. |
| `solution/` | Program staff only | Reference system tests and the UAT continuous deployment daemon. |
| `x_template/` | **Participants** | Submodule for the starter project cloned per team. This is the only part participants see. |
## The Exchange System (x_template/)
Participants implement **N interconnected components** that communicate over TCP using Protobuf 3, where N is chosen by each team according to how they design their system. In the end, the components must implement every protocol defined in `x_template/proto/` in order to pass all system tests. The protocols may be extended with new messages and services, but they must remain compatible with the originals for the automated system tests to pass.
## Key Patterns
### BaseApplication
All components inherit from `src/application/application.py::BaseApplication` (ABC). It handles:
- CLI arg `-c/--config` pointing to a JSON config file validated against a JSON Schema
- File + stream logging setup
- SIGINT/SIGTERM → clean `SystemExit(0)`
Subclasses implement `_start()`.
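The contract above can be sketched as follows — a minimal stand-in; only the `BaseApplication` name and the `_start()` hook are from the source, everything else is illustrative:

```python
from abc import ABC, abstractmethod

# Stand-in for src/application/application.py::BaseApplication — the real
# class also parses -c/--config, validates the JSON config against a JSON
# Schema, sets up file + stream logging, and converts SIGINT/SIGTERM into
# a clean SystemExit(0).
class BaseApplication(ABC):
    def run(self) -> None:
        # The real version performs all of the setup above before handing off.
        self._start()

    @abstractmethod
    def _start(self) -> None: ...

# Hypothetical participant component.
class OrderBookApp(BaseApplication):
    def __init__(self) -> None:
        self.started = False

    def _start(self) -> None:
        # A real component would open its TCP listener here.
        self.started = True

app = OrderBookApp()
app.run()
```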
### Deployment & Testing
- `deploy.sh` (in `setup_utils/`) builds a wheel via `uv build` and SCPs it to the UAT machine.
- The UAT daemon (`solution/tests/continuous_deployment/test_finder.py`) polls for new wheels, installs each in an isolated venv, and runs the system tests from `solution/tests/`.
- System tests use `ComponentOrchestrator` and `ProcessManager` to spawn component binaries by console-script name and connect to them over TCP. The component dependencies are defined in `solution/tests/continuous_deployment/testing_dependencies.json`.
- Only protocol messages that are explicitly defined in `x_template/proto/` are allowed to be used in the system tests.
- Test results are written to `/tmp/<timestamp>/test_results.xml` and SCPed back to the dev machine.
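One polling pass of a wheel-finder like the UAT daemon described above could look like this (a hypothetical sketch — the real `test_finder.py` also creates an isolated venv per wheel and runs the system tests; all names here are illustrative):

```python
import pathlib
import tempfile

def find_new_wheels(deploy_dir: pathlib.Path, seen: set) -> list:
    # One polling pass: return wheels not processed yet, oldest path first.
    new = sorted(p for p in deploy_dir.rglob("*.whl") if p not in seen)
    seen.update(new)
    return new

# Simulate a deployment directory with one freshly uploaded wheel.
deploy_dir = pathlib.Path(tempfile.mkdtemp())
(deploy_dir / "team-a").mkdir()
(deploy_dir / "team-a" / "optivex-0.1.0-py3-none-any.whl").touch()

seen = set()
first_pass = find_new_wheels(deploy_dir, seen)   # finds the new wheel
second_pass = find_new_wheels(deploy_dir, seen)  # nothing new
```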
### Component Configuration
Runtime config for each component follows `src/application/config_schema.json`. The full deployment is described in `deployment_config.json` (validated against `deployment_config_schema.json`), which lists components, their implemented protocols, ports, and which system tests to run.
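A hand-rolled sketch of what that validation covers (the real check uses `src/application/config_schema.json` via JSON Schema; the field names below appear in the sample configs in this commit, but the exact schema may differ):

```python
# Fields taken from the sample component configs; assume they are required.
REQUIRED = {"logLevel", "logDirectory", "listenOn"}
VALID_LEVELS = {"DEBUG", "INFO", "WARN", "ERROR"}

def check_config(cfg: dict) -> list:
    # Collect human-readable validation errors instead of raising.
    errors = ["missing field: %s" % k for k in sorted(REQUIRED - cfg.keys())]
    if cfg.get("logLevel") not in VALID_LEVELS:
        errors.append("bad logLevel: %r" % cfg.get("logLevel"))
    if not {"host", "port"} <= set(cfg.get("listenOn", {})):
        errors.append("listenOn must contain host and port")
    return errors

good = {"logLevel": "DEBUG", "logDirectory": "./logs",
        "listenOn": {"host": "localhost", "port": 9002}}
bad = {"logLevel": "TRACE"}
```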


@ -0,0 +1,39 @@
---
name: rebuild-proto
description: Rebuild generated Python protobuf files by running build_proto.sh after any modification to .proto files. Use automatically whenever the agent edits, creates, or deletes any .proto file in the solution/proto/ or x_template/proto/ directories.
---
# Rebuild Proto Files
## When to Trigger
**Automatically** run this skill whenever you modify, create, rename, or delete any `.proto` file under:
- `solution/proto/`
- `x_template/proto/`
This includes changes made via `Write`, `StrReplace`, or `Delete` tools on any `*.proto` file.
## How to Rebuild
Each top-level directory (`solution/` and `x_template/`) has its own `build_proto.sh`. Run the script **from the directory that contains the modified proto file**.
After editing a proto file in `solution/proto/`:
```bash
cd solution && bash build_proto.sh
```
After editing a proto file in `x_template/proto/`:
```bash
cd x_template && bash build_proto.sh
```
If you modified proto files in **both** directories, run both commands.
## Important Notes
- Always run the rebuild **after all proto edits in a batch are complete**, not after each individual file edit. This avoids redundant rebuilds.
- The script generates Python files into `src/proto/` (relative to the directory) and fixes imports to use relative style.
- If `protoc` fails, check that all `.proto` imports resolve correctly and that no syntax errors were introduced.
- Do **not** edit the generated `*_pb2.py` or `*_pb2.pyi` files directly — they will be overwritten on the next rebuild.

.gitignore

@ -72,4 +72,7 @@ dmypy.json
 Thumbs.db
 # Project-specific
 *.log
+performance_report*.json
+.claude/
+*.pem

.gitmodules

@ -0,0 +1,4 @@
[submodule "template"]
path = x_template
url = https://github.com/optiver-external/gtat-tech-career-kickstarter-challenge.git
branch = main

.vscode/settings.json

@ -1,7 +1,11 @@
 {
-    "python.testing.pytestArgs": [
-        "tests"
-    ],
-    "python.testing.unittestEnabled": false,
-    "python.testing.pytestEnabled": true
+    "python.defaultInterpreterPath": "${workspaceFolder}/solution/.venv/bin/python",
+    "python.testing.cwd": "${workspaceFolder}/solution",
+    "python.testing.unittestEnabled": false,
+    "python.testing.pytestEnabled": true,
+    "python.testing.pytestArgs": [
+        "tests",
+        "--venv-path=${workspaceFolder}/solution/.venv",
+        "--deployment-config=${workspaceFolder}/solution/deployment_config.json"
+    ]
 }

README.md

@ -1,122 +1,88 @@
-# Optivex
-You are building **Optivex**, Optiver's own financial exchange. Your task is to design and implement an exchange supporting the features defined in the protobuf contracts in the `proto/` directory.
-Clients of your exchange will connect via TCP to the public endpoints of your **components**. Each component can implement one or more protocols, which are defined in the `proto/` directory. Internally, you can organize your solution in whatever way you see fit, as long as you implement all the protocols defined in the `proto/` directory. The one exception is the **order book**: it must be a dedicated component that does not require authentication (`authRequired: false`), as it is an internal service that other components connect to.
-A pre-populated [user data file](#user-data-file) will be provided, which is recreated automatically each day. Do not rely on any other form of persistent storage; all other data should be kept in memory.
-## Constraints
-* A component is a TCP server listening to a single port. Each component may connect to other components via TCP.
-* The original protocols may be extended with new messages and services, but they must remain compatible with the original version for the automated system tests to pass. The system tests will only exercise messages defined in the original `proto/` files — custom extensions will not be tested.
-* The `request_id` field is used to match requests to responses. It must be set to the same value in the request and response messages.
-* Every response message must contain an `error_message` field.
-* Floating point number precision is 4 decimal places.
-## Project Structure
-```
-├── proto/                         # Protobuf contracts your components must implement
-├── src/
-│   ├── application/               # BaseApplication: config loading, logging, signal handling
-│   │   ├── config_schema.json     # JSON Schema for component runtime config
-│   │   └── data_file_schema.json  # JSON Schema for the user data file
-│   ├── connection/                # TCP connectivity library
-│   └── sample_app/                # Reference implementation — start here
-├── deployment_config.json         # Describes your components, protocols, and ports
-├── deployment_config_schema.json  # JSON Schema for the above
-└── pyproject.toml                 # Project metadata, dependencies, and entry points
-```
-## Getting Started
-1. Install dependencies and activate the virtual environment:
-   ```
-   uv sync && source .venv/bin/activate
-   ```
-2. Run the sample app to see the framework in action:
-   ```
-   sample-app -c src/sample_app/sample_app_config.json
-   ```
-   See [`src/sample_app/README.md`](src/sample_app/README.md) for details.
-3. Read [`proto/README.md`](proto/README.md) — it describes every message your components must implement, and is the best place to understand the full scope of the system.
-## Key Concepts
-### BaseApplication
-`BaseApplication` (`src/application/application.py`) is the base class all your components should extend. It handles config loading and validation, logging setup, and graceful shutdown. Implement `_start()` in your subclass.
-### Entry Points
-Each component must be registered as a **console-script** entry point in `pyproject.toml` under `[project.scripts]`. For example:
-```toml
-[project.scripts]
-sample-app = "sample_app.main:main"
-my-order-book = "order_book.main:main"
-```
-The `packageName` field in `deployment_config.json` must match the console-script name exactly (e.g. `"my-order-book"`). The test runner uses this name to locate and launch the component binary.
-### Component Configuration
-Each component's runtime config is validated against `src/application/config_schema.json`. The key fields are:
-| Field | Required | Description |
-|---|---|---|
-| `logLevel` | Yes | `DEBUG`, `INFO`, `WARN`, or `ERROR` |
-| `logDirectory` | Yes | Directory where log files are written |
-| `listenOn` | Yes | `{ host, port }` — address the component listens on |
-| `connectTo` | No | Map of target component names to `{ host, port }` — keys must match the `name` field in [Deployment Config](#deployment-config) |
-| `dataFilePath` | Yes | Path to the JSON user data file |
-### User Data File
-The user data file follows the schema in `src/application/data_file_schema.json`. It contains an array of users, each with `username`, `password`, and `full_name`. Components that require authentication should read this file to validate credentials.
-### Deployment Config
-`deployment_config.json` describes your full deployment. It has two top-level fields:
-* **`components`** — an array of component configurations, each with:
-  * `name` — a unique identifier for the component.
-  * `packageName` — the console-script entry point name (see [Entry Points](#entry-points)).
-  * `protocols` — which protocols this component implements (`admin`, `order_book`, `info`, `execution`, `risk_limits`).
-  * `config` — the component's runtime config (see [Component Configuration](#component-configuration)).
-* **`systemTests`** — which test suites to run (see [Incremental Testing](#incremental-testing)).
-Full specification is available in `deployment_config_schema.json`.
-## Deploying and Testing
-Once ready to test, run `deploy.sh` from your project root to build and submit your implementation to the testing environment. This script builds the project using `uv build` and copies the results onto the UAT server, where automated system tests will be run against it. Test results will be copied back to your dev server under the `/tmp` directory.
-### How Tests Run
-The test runner reads your `deployment_config.json` to determine which components to start and which tests to run. For each test:
-1. **Dependency resolution** — the required protocols are determined from the test's dependencies and the `connectTo` fields of your components.
-2. **Start order** — components are started in dependency order (a component's `connectTo` targets are started first).
-3. **Config overrides** — the test runner rewrites several fields in your component config at runtime: `listenOn` and `connectTo` ports are replaced with ephemeral ports, and `logDirectory` and `dataFilePath` are pointed to temporary paths. Do not hardcode any of these values in your component logic; always read them from the config.
-4. **Startup** — each component must accept TCP connections on its assigned port within **2 seconds**, or the test will fail.
-5. **Execution** — tests run with 4 parallel workers, so your components must handle concurrent connections. Each test simulates a single day of trading.
-### Incremental Testing
-The `systemTests` array controls which test suites are run. You can test incrementally by listing only the protocols you've implemented so far:
-```json
-"systemTests": ["order_book"]
-```
-As you implement more protocols, add them to the array. The available test suites and their protocol dependencies are:
-| Test Suite | Requires Protocols |
-|---|---|
-| `order_book` | `admin` |
-| `info` | `admin`, `order_book` |
-| `execution` | `admin`, `order_book` |
-| `risk_limits` | `admin`, `execution` |
+# Career Kickstarter - Optivex
+This repository contains the code for the Tech Career Kickstarter program. In this program, participants build a simulated financial exchange called **Optivex**.
+The project is organized into the following directories:
+* `setup_utils/`: The setup utilities for the program.
+* `solution/`: The solution code for the program. Contains the system tests and the UAT continuous deployment daemon.
+* `x_template/`: The template code for the program. This is a submodule pointing to the template repository.
+## Notes to Program Managers
+* Instruct participants on designing iteratively so they can optimize for testing through the week.
+* Expect struggles with modelling the client handler vs. publishing messages to all clients. This requires abstracting things away to avoid circular dependencies.
+* There will always be trade-offs between performance and robustness. Are participants aware of them?
+### Interesting Observation Points for Assessment
+* How do they handle race conditions? For example:
+  * Between creating an order book and creating an instrument in Info?
+  * Between inserting an order and executing a trade?
+  * Between subscribing to an order book and receiving updates?
+* How do they design their initialization logic? Is it clear, maintainable, and robust against race conditions?
+* How do they handle precision issues when dealing with floating point numbers, like price?
+### Guidelines for Participants
+* You have access to all the "framework" code you need to implement your solution, like the Application class, the TCP connection manager, and the message codec. Navigate through the code to understand how it works and how you can use it to your advantage. This is a philosophy we follow extensively at Optiver.
+* You can **control the level of complexity** you have to deal with all at once by **tuning the abstraction layers** in your code — for example, merging two classes that became too similar, or splitting up one that got too large.
+## Tech Setup
+This program is designed to run on AWS EC2 instances. The dev machines are EC2 instances used by the participants to develop their solution. The UAT machine is an EC2 instance used to run the system tests against the participants' solutions.
+Each machine must have installed:
+* *Linux*: The operating system used to run the program. Recommended: Red Hat Enterprise Linux 9. Make sure you have `ssh` and that every machine can access GitHub.
+* *git*: For cloning the repository and pushing changes.
+* *python3.11*: The version of Python used to run the project.
+* *uv*: The package manager for the project, used to build the project and install dependencies.
+* *protobuf-compiler*: For generating the Python files from the protobuf files. Version 3.25.0 or higher is required.
+### UAT Instance
+The program manager should use this machine as their main working machine during the program. It must be set up to have easy `ssh` access to the dev machines.
+This machine **must be set up with** the following:
+* An SSH identity key in `$HOME/.ssh/id_optivex`.
+* Environment variables:
+  * *CK_DEPLOYMENT_DIR*: The directory where the deployment files will be stored.
+  * *CK_TESTS_DIR*: The directory where the tests (under `solution/tests`) are checked out on this server.
+  * *CK_TEMPLATE_REPO_URL*: The URL of the template repository to be cloned on the dev machines.
+* A clone of this repository.
+The `setup_utils` directory contains important scripts for setting up the dev environments of the program. Once this machine is set up, **execute `setup_utils/setup_dev_machine.sh` for every dev machine** to set them up. This script will:
+* Copy the SSH identity key to the dev machines.
+* Copy `setup_utils/deploy_to_uat.sh` onto the dev machines and make it available in their `PATH`.
+* Configure the following environment variables on the dev machines:
+  * *CK_UAT_USER*: The username of the UAT user.
+  * *CK_UAT_HOST*: The hostname of the UAT machine (its public IP).
+  * *CK_UAT_DEPLOYMENT_DIR*: The directory where the deployment files will be stored on the UAT machine.
+* Clone the template repository on the dev machines under `~/optiver-career-kickstarter`.
+Finally, you're ready to run `test_finder.py`.
+#### Continuous Deployment Daemon
+Inside the `solution/` directory, the `continuous_deployment/test_finder.py` script continuously monitors for new deployments on the UAT machine and triggers system tests against them. It must be run on the UAT machine.
+The script runs from within the project's solution. First you must create a virtual environment and install the required dependencies. You can do this with `uv sync` from within `solution/`, then activate the environment with `source .venv/bin/activate`.
+Now you can run the script with `python test_finder.py`.
+The script looks for new deployments in the `CK_DEPLOYMENT_DIR` directory on the UAT machine. When it finds one, it first creates a new virtual environment and installs the deployed package (wheel) there, then runs the system tests against it. The test results are copied back to the dev servers.
+### Dev Instances
+The dev machines must be set up with a user account matching the machine name. For example, if the machine name is `dev-1`, the user account must be `dev-1`.
+After the machine is properly set up by the program manager, participants will have access to:
+* *~/optiver-career-kickstarter*: A clone of this repository's template, already on a new branch named after the participant's username.
+* *deploy_to_uat.sh*: Made available on each dev server. It builds the participant's project using `uv build` and copies the results onto the UAT server, where automated system tests will be run against it.

TODO.md

@ -0,0 +1,2 @@
* Remove GetAllRequest/Response from order book. (check if needed)
* Tests are extensively calling `verify_no_unexpected_calls`. This will likely break with different designs. We should instead create targeted expectations that specific methods are *not* called.
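The suggested style can be sketched with `unittest.mock` — rather than a blanket check over every method, state targeted "no call" expectations (the method names below are hypothetical):

```python
from unittest import mock

# A mocked downstream component, as a system test might hold one.
order_book = mock.Mock()
order_book.insert_order(order_id=1, price=100, volume=5)

# Positive expectation on the method we exercised...
order_book.insert_order.assert_called_once()
# ...and a targeted negative expectation, instead of asserting that
# *no* unexpected method anywhere was called.
order_book.delete_order.assert_not_called()
```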


@ -1,50 +0,0 @@
{
"components": [
{
"name": "admin",
"packageName": "sample-app",
"protocols": ["admin"],
"authRequired": false,
"config": {
"logLevel": "DEBUG",
"logDirectory": "./logs",
"listenOn": {
"host": "localhost",
"port": 9001
},
"connectTo": {
"core": {
"host": "localhost",
"port": 9100
}
}
}
},
{
"name": "core",
"packageName": "sample-app",
"protocols": [
"execution",
"info",
"order_book",
"risk_limits"
],
"authRequired": true,
"config": {
"logLevel": "DEBUG",
"logDirectory": "./logs",
"dataFilePath": "./data.json",
"listenOn": {
"host": "localhost",
"port": 9100
}
}
}
],
"systemTests": [
"execution",
"info",
"order_book",
"risk_limits"
]
}

setup_utils/deploy.sh

@ -0,0 +1,59 @@
#!/bin/bash
# Exit on any error
set -e
UAT_USER="$CK_UAT_USER"
UAT_HOST="$CK_UAT_HOST"
DEPLOYMENT_DIR="$(date +%s)"
UAT_DEST_DIR="$CK_UAT_DEPLOYMENT_DIR/$(whoami)/$DEPLOYMENT_DIR"
PROJECT_ROOT="${1:-$(pwd)}"
IDENTITY_KEY="$HOME/.ssh/${CK_DEPLOY_KEY_NAME:-id_optivex}"
if [ ! -f "$IDENTITY_KEY" ]; then
echo "Identity key not found. Please check your SSH configuration."
exit 1
fi
echo "Starting deployment..."
echo " -> Deployment id: $DEPLOYMENT_DIR"
if [ -z "$VIRTUAL_ENV" ]; then
if [ -d ".venv" ]; then
echo "Activating virtual environment from .venv..."
source .venv/bin/activate
else
echo "No virtual environment found. Please create one and try again."
exit 1
fi
fi
echo "Building the project with 'uv'..."
uv build
DIST_DIR="$PROJECT_ROOT/dist"
if [ ! -d "$DIST_DIR" ]; then
echo "Build failed or dist directory not found."
exit 1
fi
echo "Copying files to UAT server..."
# Create the destination directory on UAT, including its logs/ subdirectory
ssh -i "$IDENTITY_KEY" -o StrictHostKeyChecking=no "$UAT_USER@$UAT_HOST" "mkdir -p $UAT_DEST_DIR/logs"
# -O forces legacy SCP protocol (OpenSSH 8.7+ defaults to SFTP, which
# the restricted_deploy.sh forced command does not allow)
if ! scp -O -i "$IDENTITY_KEY" -o StrictHostKeyChecking=no -r "$DIST_DIR/"* "$UAT_USER@$UAT_HOST:$UAT_DEST_DIR/"; then
    # Note: a separate `$? -ne 0` check after the scp would be dead code,
    # since `set -e` aborts the script before it runs.
    echo "File transfer failed."
    exit 1
fi
echo "Deployment to UAT complete."
echo "Test results will soon be available at /tmp/$DEPLOYMENT_DIR/test_results.xml"
echo "Bumping patch version of the application for the next deployment..."
uv version --bump patch
echo "All done!"

setup_utils/provision_aws.py

File diff suppressed because it is too large.


@ -0,0 +1,115 @@
#!/bin/bash
# restricted_deploy.sh — Forced SSH command for per-team deployment isolation.
#
# Installed on UAT at /usr/local/bin/restricted_deploy.sh.
# Referenced in deployer's authorized_keys as:
# command="/usr/local/bin/restricted_deploy.sh <team>",no-port-forwarding,... ssh-rsa ...
#
# Only allows:
# - scp sink mode (write): scp -t [-r] [-d] /srv/deployments/<team>/...
# - mkdir: mkdir -p /srv/deployments/<team>/...
set -euo pipefail
TEAM_NAME="$1"
BASE_DIR="/srv/deployments"
ALLOWED_DIR="${BASE_DIR}/${TEAM_NAME}"
if [ ! -d "$ALLOWED_DIR" ]; then
echo "ERROR: team directory does not exist" >&2
exit 1
fi
ALLOWED_REAL=$(readlink -f "$ALLOWED_DIR")
ORIG_CMD="${SSH_ORIGINAL_COMMAND:-}"
if [ -z "$ORIG_CMD" ]; then
echo "ERROR: interactive shell access denied" >&2
exit 1
fi
CMD_VERB=$(echo "$ORIG_CMD" | awk '{print $1}')
case "$CMD_VERB" in
scp)
# Extract the target path (always the last argument)
TARGET_PATH=$(echo "$ORIG_CMD" | awk '{print $NF}')
# Validate flags: only -t, -r, -d are allowed
FLAGS=$(echo "$ORIG_CMD" | awk '{$1=""; $NF=""; print}' | tr -s ' ')
for flag in $FLAGS; do
case "$flag" in
-t|-r|-d) ;;
*) echo "ERROR: disallowed scp flag: $flag" >&2; exit 1 ;;
esac
done
# Must contain -t (sink/write mode); anchor the match so a "-t" that
# merely appears inside the target path cannot satisfy the check
if ! echo "$ORIG_CMD" | grep -qE '(^| )-t( |$)'; then
echo "ERROR: only scp sink mode (-t) is allowed" >&2
exit 1
fi
;;
mkdir)
if ! echo "$ORIG_CMD" | grep -q '^mkdir -p '; then
echo "ERROR: only 'mkdir -p <path>' is allowed" >&2
exit 1
fi
TARGET_PATH=$(echo "$ORIG_CMD" | awk '{print $NF}')
;;
*)
echo "ERROR: command not allowed: $CMD_VERB" >&2
exit 1
;;
esac
# ---- Path validation (anti-traversal) ----
# Reject paths containing ".."
if echo "$TARGET_PATH" | grep -q '\.\.'; then
echo "ERROR: path traversal detected" >&2
exit 1
fi
# Textual prefix check
case "$TARGET_PATH" in
"${ALLOWED_DIR}/"*|"${ALLOWED_DIR}")
;;
*)
echo "ERROR: target path is outside allowed directory" >&2
exit 1
;;
esac
# Canonical check for existing paths (defeats symlink escapes)
if [ -e "$TARGET_PATH" ]; then
REAL_TARGET=$(readlink -f "$TARGET_PATH")
case "$REAL_TARGET" in
"${ALLOWED_REAL}/"*|"${ALLOWED_REAL}")
;;
*)
echo "ERROR: resolved path is outside allowed directory" >&2
exit 1
;;
esac
else
# For non-existent paths (mkdir case), resolve the nearest existing parent
CHECK_PATH="$TARGET_PATH"
while [ ! -e "$CHECK_PATH" ] && [ "$CHECK_PATH" != "/" ]; do
CHECK_PATH=$(dirname "$CHECK_PATH")
done
if [ -e "$CHECK_PATH" ]; then
REAL_CHECK=$(readlink -f "$CHECK_PATH")
case "$REAL_CHECK" in
"${ALLOWED_REAL}/"*|"${ALLOWED_REAL}")
;;
*)
echo "ERROR: resolved parent path is outside allowed directory" >&2
exit 1
;;
esac
fi
fi
# All checks passed — execute the original command
exec $ORIG_CMD
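The script's containment logic — canonicalize the nearest existing ancestor of the target and require it to stay under the allowed directory — can be expressed in Python as follows (an illustrative sketch, not the deployed implementation):

```python
import os
import tempfile

def inside_allowed(target: str, allowed_real: str) -> bool:
    # Walk up to the nearest existing ancestor (mkdir targets may not
    # exist yet), then canonicalize it to defeat symlink escapes.
    path = target
    while not os.path.exists(path) and path != "/":
        path = os.path.dirname(path) or "/"
    real = os.path.realpath(path)
    return real == allowed_real or real.startswith(allowed_real + os.sep)

# Demo: a team directory and one sibling it must not write into.
base = tempfile.mkdtemp()
allowed = os.path.join(base, "team-a")
os.makedirs(allowed)
allowed_real = os.path.realpath(allowed)

ok = inside_allowed(os.path.join(allowed, "new", "dir"), allowed_real)
escape = inside_allowed(os.path.join(base, "team-b", "x"), allowed_real)
```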


@ -0,0 +1,80 @@
#!/bin/bash
#
# setup_dev_machine.sh — Run from the UAT instance to finish setting up dev machines.
#
# Discovers all dev instances via AWS tags and clones the template repo on each.
# For AWS deployments, provision_aws.py handles instance creation, SSH keys,
# env vars, and the deploy script via user-data. This script handles the
# remaining steps that require network connectivity between UAT and dev
# (i.e. cloning the template repo onto the dev machines).
#
# Prerequisites (set on UAT via user-data):
# CK_TEMPLATE_REPO_URL — git URL for the template repo
# CK_PROJECT_TAG — project tag used to discover instances
# CK_AWS_REGION — AWS region where instances are running
#
set -e
# ---------------------------------------------------------------------------
if [ -z "${CK_TEMPLATE_REPO_URL:-}" ]; then
echo "CK_TEMPLATE_REPO_URL is not set, skipping template repo clone."
exit 0
fi
if [ -z "${CK_PROJECT_TAG:-}" ]; then
echo "Error: CK_PROJECT_TAG is not set. Cannot discover dev instances."
exit 1
fi
if [ -z "${CK_AWS_REGION:-}" ]; then
echo "Error: CK_AWS_REGION is not set. Cannot query AWS API."
exit 1
fi
# Discover all running dev instances by project tag (exclude UAT)
echo "Discovering dev instances (Project=$CK_PROJECT_TAG, Region=$CK_AWS_REGION)..."
DEV_INSTANCES=$(aws ec2 describe-instances \
--region "$CK_AWS_REGION" \
--filters \
"Name=tag:Project,Values=$CK_PROJECT_TAG" \
"Name=instance-state-name,Values=running" \
--query 'Reservations[].Instances[?Tags[?Key==`Role` && Value!=`uat`]].[Tags[?Key==`Role`].Value | [0], PrivateIpAddress]' \
--output text)
if [ -z "$DEV_INSTANCES" ]; then
echo "No dev instances found. Nothing to do."
exit 0
fi
echo "Found dev instances:"
echo "$DEV_INSTANCES"
echo "---"
FAILED=0
while IFS=$'\t' read -r DEV_NAME DEV_IP; do
echo ""
echo "=== Setting up $DEV_NAME ($DEV_IP) ==="
if ssh -o StrictHostKeyChecking=no -o ConnectTimeout=10 "$DEV_NAME@$DEV_IP" '
git clone '"'$CK_TEMPLATE_REPO_URL'"' ~/optiver-career-kickstarter
cd ~/optiver-career-kickstarter
uv sync
git checkout -b team-$(whoami)
'; then
echo "Template repo cloned on $DEV_NAME ($DEV_IP)"
else
echo "ERROR: Failed to set up $DEV_NAME ($DEV_IP)"
FAILED=$((FAILED + 1))
fi
done <<< "$DEV_INSTANCES"
echo ""
echo "=== Done ==="
if [ "$FAILED" -gt 0 ]; then
echo "$FAILED dev machine(s) failed setup."
exit 1
fi
echo "All dev machines set up successfully."

setup_utils/setup_uat.sh

@ -0,0 +1,92 @@
#!/bin/bash
#
# setup_uat.sh — Run on the UAT instance after provisioning completes.
#
# Discovers dev instances by their Project tag using the EC2 instance metadata
# and AWS CLI, then builds ~/.ssh/config entries so ec2-user can refer to dev
# machines by hostname (e.g. "ssh dev-1" to push test results).
#
# Prerequisites:
# - AWS CLI installed (or use: dnf install -y awscli)
# - Instance must have an IAM role allowing ec2:DescribeInstances
# (or export AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY)
#
# Usage:
# bash setup_uat.sh [--project-tag career-kickstarter]
set -euo pipefail
PROJECT_TAG="${1:-career-kickstarter}"
# Resolve region from instance metadata (IMDSv2)
TOKEN=$(curl -s -X PUT "http://169.254.169.254/latest/api/token" \
-H "X-aws-ec2-metadata-token-ttl-seconds: 60")
REGION=$(curl -s -H "X-aws-ec2-metadata-token: $TOKEN" \
http://169.254.169.254/latest/meta-data/placement/region)
INSTANCE_ID=$(curl -s -H "X-aws-ec2-metadata-token: $TOKEN" \
http://169.254.169.254/latest/meta-data/instance-id)
echo "Region: $REGION"
echo "Instance ID: $INSTANCE_ID"
echo "Project tag: $PROJECT_TAG"
# Discover dev instances (Role != "uat", same Project tag, running)
DEV_INSTANCES=$(aws ec2 describe-instances \
--region "$REGION" \
--filters \
"Name=tag:Project,Values=$PROJECT_TAG" \
"Name=instance-state-name,Values=running" \
--query 'Reservations[].Instances[?Tags[?Key==`Role` && Value!=`uat`]].[Tags[?Key==`Role`].Value | [0], PrivateIpAddress]' \
--output text)
if [ -z "$DEV_INSTANCES" ]; then
echo "No dev instances found with Project=$PROJECT_TAG"
exit 0
fi
# Build SSH config entries
SSH_CONFIG="$HOME/.ssh/config"
MARKER="# --- Career Kickstarter dev machines ---"
# Remove any previous CK entries
if grep -q "$MARKER" "$SSH_CONFIG" 2>/dev/null; then
sed -i "/$MARKER/,/^$/d" "$SSH_CONFIG"
fi
{
echo ""
echo "$MARKER"
while IFS=$'\t' read -r dev_name private_ip; do
[ -z "$dev_name" ] && continue
echo "Host $dev_name"
echo " HostName $private_ip"
echo " User $dev_name"
echo ""
done <<< "$DEV_INSTANCES"
} >> "$SSH_CONFIG"
chmod 600 "$SSH_CONFIG"
echo ""
echo "SSH config updated. Dev machines:"
while IFS=$'\t' read -r dev_name private_ip; do
[ -z "$dev_name" ] && continue
echo " $dev_name -> $private_ip"
done <<< "$DEV_INSTANCES"
echo ""
echo "You can now run: ssh dev-1 (etc.)"
# ---------------------------------------------------------------------------
# Clone template repo and set up solution test suite
# ---------------------------------------------------------------------------
if [ -z "${CK_UAT_REPO_URL:-}" ]; then
echo "CK_UAT_REPO_URL is not set, skipping UAT repo clone."
else
echo "Cloning UAT repository..."
git clone "$CK_UAT_REPO_URL" ~/gtat-tech-career-kickstarter || true
if [ -d ~/gtat-tech-career-kickstarter/solution ]; then
cd ~/gtat-tech-career-kickstarter/solution && uv sync && bash build_proto.sh
echo "Solution test suite ready."
fi
fi

solution/README.md

@ -0,0 +1,68 @@
# Optivex — Reference Solution
Reference implementation of the Optivex exchange and the system tests used to evaluate participant submissions.
## Architecture Overview
Four components communicate over TCP using length-prefixed Protobuf 3 messages:
```
┌──────────┐
│ Admin │
└──┬───┬───┘
│ │
┌───────┘ └────────┐
▼ ▼
┌──────────────┐ ┌──────────────┐
│ Order Book │◄───│ Info │
└──────────────┘ └──────────────┘
▲ ▲
│ ┌───────────────┘
│ │
┌──────┴────┴─────┐
│ Risk Gateway │
└─────────────────┘
```
- **Order Book** — Matching engine. No upstream dependencies.
- **Info** — Market data. Mirrors order book state to serve TOP_OF_BOOK and PRICE_DEPTH_BOOK subscriptions.
- **Admin** — Instrument lifecycle. Orchestrates creation across Order Book and Info.
- **Risk Gateway** — Execution + risk. Enforces limits before forwarding orders to Order Book. Implements both the `execution` and `risk_limits` protocols.
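The transport is described only as "length-prefixed Protobuf 3"; a common framing — assumed here, not confirmed by the source — is a 4-byte big-endian length followed by the serialized message:

```python
import struct

def frame(payload: bytes) -> bytes:
    # Prefix the serialized protobuf message with its byte length.
    return struct.pack(">I", len(payload)) + payload

def unframe(buf: bytes) -> tuple:
    # Return (message, remaining bytes); reject short reads explicitly.
    (length,) = struct.unpack_from(">I", buf)
    if len(buf) < 4 + length:
        raise ValueError("incomplete frame")
    return buf[4:4 + length], buf[4 + length:]

# Round-trip with trailing bytes, as a TCP stream would deliver them.
msg, rest = unframe(frame(b"hello") + b"extra")
```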
## Protocol Extensions vs Template
The solution adds **internal management messages** for Admin ↔ Order Book and Admin ↔ Info communication. These are not part of the participant-facing API.
Added to `common.proto` (numbered at 130+/140+ to avoid collisions with the public API range 0–67):
| Message Type | Purpose |
|---|---|
| `INFO_CREATE_INSTRUMENT_REQUEST/RESPONSE` (130/131) | Admin → Info: register instrument with `order_book_id` |
| `ORDER_BOOK_CREATE_REQUEST/RESPONSE` (140/141) | Admin → Order Book: create order book |
| `ORDER_BOOK_DELETE_REQUEST/RESPONSE` (142/143) | Admin → Order Book: delete order book |
Note that `info.CreateInstrumentRequest` carries `order_book_id` (book already exists), while `admin.CreateInstrumentRequest` carries `tick_size` (book doesn't exist yet). The Admin component bridges this gap with a two-phase flow: create order book first, then register the instrument in Info with the resulting ID.
## Non-Obvious Design Decisions
### Matching engine uses negated prices for max-heap
Python's `heapq` is min-heap only. Bids need max-heap behavior (highest price first). The sort key negates the price for bids: `(-price, timestamp, order_id)` vs `(+price, timestamp, order_id)` for asks. This gives price-time priority with `order_id` as a final tiebreaker.
Cancel uses `list.remove()` + `heapify()` — O(n) rather than O(log n) with lazy deletion, but sufficient for this scale.
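The key scheme can be seen directly with `heapq` (a minimal sketch; the order tuples are illustrative):

```python
import heapq

bids, asks = [], []
# Bids negate the price so the min-heap pops the highest bid first;
# timestamp then order_id break ties (price-time priority).
for order_id, price, ts in [(1, 101.0, 0), (2, 103.0, 1), (3, 103.0, 2)]:
    heapq.heappush(bids, (-price, ts, order_id))
# Asks keep the natural sign: lowest price first.
for order_id, price, ts in [(4, 105.0, 0), (5, 104.0, 1)]:
    heapq.heappush(asks, (price, ts, order_id))

best_bid = heapq.heappop(bids)  # order 2: highest price, earliest timestamp
best_ask = heapq.heappop(asks)  # order 5: lowest price
```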
### Info mirrors order book state from the event stream
Info doesn't access the matching engine directly. `OrderBookClientAggregator` reconstructs book state from `OnOrderInserted`, `OnOrderCancelled`, and `OnTrade` events, maintaining quantity-per-price-level aggregates.
A **pending trades mechanism** handles the ordering issue where `OnOrderInserted` references `trade_ids` for trades that haven't arrived yet. While pending trades exist, the book is considered inconsistent and market data updates are suppressed. Reads (`get_top_of_book`, `get_price_depth_book`) assert consistency.
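The aggregation side of this can be sketched as follows (event shapes and handler names are illustrative, not the actual proto messages or the `OrderBookClientAggregator` API):

```python
from collections import defaultdict

# Quantity-per-price-level aggregates, rebuilt purely from events.
levels = defaultdict(int)

def on_order_inserted(price: float, quantity: int) -> None:
    levels[price] += quantity

def on_order_cancelled(price: float, quantity: int) -> None:
    levels[price] -= quantity
    if levels[price] == 0:
        del levels[price]  # drop empty levels so reads stay clean

on_order_inserted(100.0, 10)
on_order_inserted(100.0, 5)
on_order_cancelled(100.0, 10)
# One level remains: 5 units at 100.0.
```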
### Rolling window limits use in-place list pruning
`RiskLimitsStore` maintains timestamped event lists for each rolling metric (message rate, order quantity, order amount). On each check, entries older than `now - window_seconds` are pruned with `del list[:first_valid]`. This mutates the list in place — O(n) per check but amortized O(1) per entry since each is pruned exactly once.
Limit checks follow a **check-then-record** pattern: `check_order_limits()` validates without side effects, and `record_order_attempt()` is called only after the check passes. This avoids polluting rolling windows with rejected orders.
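A single rolling window with both patterns looks roughly like this (hypothetical names; the real `RiskLimitsStore` tracks several such windows):

```python
class RollingLimit:
    """Sketch of one rolling-window limit with check-then-record."""

    def __init__(self, max_events: int, window_seconds: float) -> None:
        self.max_events = max_events
        self.window_seconds = window_seconds
        self.events: list[float] = []  # sorted timestamps

    def _prune(self, now: float) -> None:
        cutoff = now - self.window_seconds
        first_valid = 0
        while first_valid < len(self.events) and self.events[first_valid] < cutoff:
            first_valid += 1
        del self.events[:first_valid]  # in place; each entry pruned exactly once

    def check(self, now: float) -> bool:
        # Validation has no side effects beyond pruning expired entries.
        self._prune(now)
        return len(self.events) < self.max_events

    def record(self, now: float) -> None:
        # Called only after check() passes, so rejected orders never
        # pollute the rolling window.
        self.events.append(now)
```

A rejected order leaves `events` untouched, so a burst of rejections cannot starve subsequent valid orders.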
### Decimal precision avoids float artifacts
All prices and amounts go through `decimal_from_float()`: `Decimal(str(value)).quantize(Decimal("0.0001"))`. The `str()` intermediate avoids the well-known `Decimal(0.1)``0.1000000000000000055…` problem: `str()` yields the shortest round-tripping representation (`"0.1"`), whereas constructing `Decimal` directly from the float preserves its exact binary expansion.
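Concretely, the helper behaves like this:

```python
from decimal import Decimal

DECIMAL_PRECISION = Decimal("0.0001")

def decimal_from_float(value: float) -> Decimal:
    # str(0.1) -> "0.1" (shortest round-trip repr), so the Decimal is
    # built from the intended text, not the float's binary expansion.
    return Decimal(str(value)).quantize(DECIMAL_PRECISION)

clean = decimal_from_float(0.1)  # Decimal('0.1000')
dirty = Decimal(0.1)             # exact binary expansion of the float
```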


@@ -0,0 +1,101 @@
{
  "components": [
    {
      "name": "admin",
      "packageName": "sample-admin",
      "protocols": ["admin"],
      "authRequired": false,
      "config": {
        "logLevel": "DEBUG",
        "logDirectory": "./logs",
        "listenOn": {
          "host": "localhost",
          "port": 9001
        },
        "connectTo": {
          "order_book": {
            "host": "localhost",
            "port": 9002
          },
          "info": {
            "host": "localhost",
            "port": 9003
          }
        }
      }
    },
    {
      "name": "order_book",
      "packageName": "sample-order-book",
      "protocols": [
        "order_book"
      ],
      "authRequired": false,
      "config": {
        "logLevel": "DEBUG",
        "logDirectory": "./logs",
        "listenOn": {
          "host": "localhost",
          "port": 9002
        }
      }
    },
    {
      "name": "info",
      "packageName": "sample-info",
      "protocols": [
        "info"
      ],
      "authRequired": true,
      "config": {
        "logLevel": "DEBUG",
        "logDirectory": "./logs",
        "dataFilePath": "./data.json",
        "listenOn": {
          "host": "localhost",
          "port": 9003
        },
        "connectTo": {
          "order_book": {
            "host": "localhost",
            "port": 9002
          }
        }
      }
    },
    {
      "name": "risk_gateway",
      "packageName": "sample-risk-gateway",
      "protocols": [
        "risk_limits",
        "execution"
      ],
      "authRequired": true,
      "config": {
        "logLevel": "DEBUG",
        "logDirectory": "./logs",
        "dataFilePath": "./data.json",
        "listenOn": {
          "host": "localhost",
          "port": 9005
        },
        "connectTo": {
          "order_book": {
            "host": "localhost",
            "port": 9002
          },
          "info": {
            "host": "localhost",
            "port": 9003
          }
        }
      }
    }
  ],
  "systemTests": [
    "execution",
    "info",
    "order_book",
    "risk_limits"
  ]
}


@@ -17,6 +17,8 @@ The foundation. Defines:
 - `Side` — the `BUY`/`SELL` enum used across order-related messages.
 - `LoginRequest` / `LoginResponse` — the authentication handshake your components must perform with every connecting client before serving any other message.
+
+User information is provided via a static JSON data file (see `users_data_schema.json`) whose path is supplied in each component's config as `dataFilePath`. Components that require authentication must read this file at startup to validate login credentials.
 
 Read this first to orient yourself. Every other file imports it.
 
 ### `admin.proto`
@@ -24,8 +26,6 @@ The administration interface your exchange must expose to let the system be boot
 - `CreateInstrumentRequest/Response` — your exchange must accept requests to list a new tradeable instrument, create an order book for it, and return the assigned `order_book_id`.
-
-User information is provided via a static JSON data file (see `users_data_schema.json` in the project root) whose path is supplied in each component's config as `dataFilePath`. Components that require authentication must read this file at startup to validate login credentials.
 
 ### `info.proto`
 
 The market data feed your exchange must maintain and push to subscribed clients.


@@ -17,6 +17,9 @@ enum MessageType {
     INFO_ON_TOP_OF_BOOK = 33;
     INFO_ON_PRICE_DEPTH_BOOK = 34;
     INFO_ON_TRADE = 35;
+    // Internal management messages (used by admin service)
+    INFO_CREATE_INSTRUMENT_REQUEST = 130;
+    INFO_CREATE_INSTRUMENT_RESPONSE = 131;
 
     // Order Book
     ORDER_BOOK_ON_ORDER_BOOK_CREATED = 40;
@@ -29,6 +32,11 @@ enum MessageType {
     ORDER_BOOK_INSERT_ORDER_RESPONSE = 47;
     ORDER_BOOK_CANCEL_ORDER_REQUEST = 48;
     ORDER_BOOK_CANCEL_ORDER_RESPONSE = 49;
+    // Internal management messages (used by admin service)
+    ORDER_BOOK_CREATE_REQUEST = 140;
+    ORDER_BOOK_CREATE_RESPONSE = 141;
+    ORDER_BOOK_DELETE_REQUEST = 142;
+    ORDER_BOOK_DELETE_RESPONSE = 143;
 
     // Execution
     EXEC_INSERT_ORDER_REQUEST = 50;


@@ -3,7 +3,28 @@ package optiver.exchange.info;
 
 import "common.proto";
 
-// All active instruments are sent to client on connection, then any new instrument is sent on creation
+// ------------------------------------------------------------
+// Internal management messages (used by admin service)
+// ------------------------------------------------------------
+message CreateInstrumentRequest {
+    int64 request_id = 1;
+    Instrument instrument = 2;
+    int64 order_book_id = 3;
+}
+
+message CreateInstrumentResponse {
+    int64 request_id = 1;
+    string error_message = 2;
+    int64 created_timestamp = 3;
+}
+
+// ------------------------------------------------------------
+// Messages sent to all connected clients with the current
+// state of the service, then upon any changes
+// ------------------------------------------------------------
+// Upon connection, all active instruments are sent to client. Any new instrument is sent on creation.
 message OnInstrument {
     Instrument instrument = 1;
     int64 created_timestamp = 2;


@@ -4,7 +4,32 @@ package optiver.exchange.orderbook;
 
 import "common.proto";
 
 // ------------------------------------------------------------
-// Messages sent to all clients upon logging in with the current
+// Internal management messages (used by admin service)
+// ------------------------------------------------------------
+message CreateOrderBookRequest {
+    int64 request_id = 1;
+    double tick_size = 2;
+}
+
+message CreateOrderBookResponse {
+    int64 request_id = 1;
+    string error_message = 2;
+    int64 order_book_id = 3;
+}
+
+message DeleteOrderBookRequest {
+    int64 request_id = 1;
+    int64 order_book_id = 2;
+}
+
+message DeleteOrderBookResponse {
+    int64 request_id = 1;
+    string error_message = 2;
+}
+
+// ------------------------------------------------------------
+// Messages sent to all connected clients with the current
 // state of the order book, then upon any changes
 // ------------------------------------------------------------


@@ -1,5 +1,5 @@
 [project]
-name = "Optivex"
+name = "OptivexTests"
 version = "0.1.0"
 description = "Career Kickstarter - Optiver's Exchange"
 readme = "README.md"
@@ -11,13 +11,18 @@ dependencies = [
     "mypy>=1.19.1",
     "mypy-protobuf>=5.0.0",
     "protobuf>=7.34.0",
+    "psutil>=7.0.0",
     "pytest>=9.0.2",
+    "pytest-xdist>=3.8.0",
     "types-jsonschema>=4.23.0.20241208",
     "types-protobuf>=6.32.1.20260221",
 ]
 
 [project.scripts]
-sample-app = "sample_app.main:main"
+sample-admin = "admin.main:main"
+sample-order-book = "order_book.main:main"
+sample-info = "info.main:main"
+sample-risk-gateway = "risk_gateway.main:main"
 
 [build-system]
 requires = ["setuptools", "wheel", "uv"]
@@ -25,12 +30,19 @@ build-backend = "setuptools.build_meta"
 
 [tool.setuptools.package-data]
 "application" = ["*.json"]
+"common" = ["*.json"]
 
 [tool.uv]
 package = true
 
-[tool.pytest.ini_options]
-pythonpath = ["src"]
-
 [tool.mypy]
 exclude = "src/proto"
+
+[tool.pytest.ini_options]
+pythonpath = ["src"]
+addopts = "-n 4"
+log_cli = false
+log_file = "logs/pytest.log"
+log_file_level = "DEBUG"
+log_file_format = "%(asctime)s [%(levelname)8s] (%(name)s) %(message)s (%(filename)s:%(lineno)s)"
+log_file_date_format = "%Y-%m-%d %H:%M:%S.%f"


@@ -0,0 +1,94 @@
from datetime import datetime
from typing import Callable
import socket
import logging

from admin.info_client import AdminInfoClient
from admin.order_book_client import AdminOrderBookClient
from connection.ip_address import IpAddress
from connection.connection_handler import ConnectionHandler
from proto.admin_pb2 import CreateInstrumentRequest, CreateInstrumentResponse
from proto.order_book_pb2 import CreateOrderBookRequest, CreateOrderBookResponse
from proto.info_pb2 import CreateInstrumentRequest as CreateInfoInstrumentRequest, CreateInstrumentResponse as CreateInfoInstrumentResponse
from proto.common_pb2 import MessageType

logger = logging.getLogger(__name__)


class AdminClientHandler(ConnectionHandler):
    def __init__(self, socket_fd: socket.socket, ip_address: IpAddress, close_callback: Callable[[], None],
                 order_book_client: AdminOrderBookClient,
                 info_client: AdminInfoClient) -> None:
        super().__init__(socket_fd, ip_address, close_callback)
        self.order_book_client = order_book_client
        self.info_client = info_client

    def on_disconnect(self) -> None:
        logger.info(f"Admin client {self.ip_address} disconnected")

    def handle_message(self, message_type: int, message: bytes) -> None:
        logger.info(f"Handling message of type {MessageType.Name(message_type)} from {self.ip_address}")  # type: ignore
        if message_type == MessageType.ADMIN_CREATE_INSTRUMENT_REQUEST:
            self._handle_create_instrument_request(message)
        else:
            raise ValueError(f"Unexpected message type received: {message_type}")

    def _handle_create_instrument_request(self, message: bytes) -> None:
        # Deserialize outside the try block: the except clause references
        # `request`, which would be unbound if deserialization itself failed.
        request = self._deserialize_message(CreateInstrumentRequest, message)
        try:
            if request.request_id is None:
                raise ValueError("Request ID must be set")
            if request.instrument is None:
                raise ValueError("Instrument must be set")
            if not request.instrument.symbol:
                raise ValueError("Symbol must be set")
            if not request.instrument.description:
                raise ValueError("Description must be set")
            if not request.instrument.currency:
                raise ValueError("Currency must be set")
            if request.instrument.multiplier is None:
                raise ValueError("Multiplier must be set")
            if request.instrument.multiplier <= 0:
                raise ValueError("Multiplier must be positive")
            if request.tick_size is None:
                raise ValueError("Tick size must be set")
            if request.tick_size <= 0:
                raise ValueError("Tick size must be positive")

            logger.info(f"Creating order book for instrument {request.instrument.symbol}")
            self.order_book_client.send_create_order_book(
                CreateOrderBookRequest(tick_size=request.tick_size),
                lambda order_book_response: self._on_create_order_book_response(order_book_response, request))
        except Exception as e:
            self._send_error_response_on_create_instrument(request, str(e))

    def _on_create_order_book_response(self, order_book_response: CreateOrderBookResponse,
                                       request: CreateInstrumentRequest) -> None:
        if order_book_response.error_message:
            self._send_error_response_on_create_instrument(
                request, f"Error from order book service: {order_book_response.error_message}")
            return
        logger.info(f"Order book created with ID {order_book_response.order_book_id}")
        logger.info(f"Creating instrument {request.instrument.symbol} via info service")
        self.info_client.send_create_instrument(
            CreateInfoInstrumentRequest(instrument=request.instrument, order_book_id=order_book_response.order_book_id),
            lambda info_response: self._on_create_instrument_response(info_response, request, order_book_response.order_book_id))

    def _on_create_instrument_response(self, info_response: CreateInfoInstrumentResponse,
                                       request: CreateInstrumentRequest, order_book_id: int) -> None:
        if info_response.error_message:
            self._send_error_response_on_create_instrument(request, f"Error from info service: {info_response.error_message}")
            return
        logger.info(f"Instrument {request.instrument.symbol} created")
        response = CreateInstrumentResponse(request_id=request.request_id)
        response.order_book_id = order_book_id
        response.created_timestamp = info_response.created_timestamp
        logger.debug(f"Sending create instrument response: {response}")
        self.send_message(MessageType.ADMIN_CREATE_INSTRUMENT_RESPONSE, response)

    def _send_error_response_on_create_instrument(self, request: CreateInstrumentRequest, error_message: str) -> None:
        logger.info(f"Error creating instrument {request.instrument.symbol}: {error_message}")
        response = CreateInstrumentResponse(request_id=request.request_id, error_message=error_message)
        self.send_message(MessageType.ADMIN_CREATE_INSTRUMENT_RESPONSE, response)


@@ -0,0 +1,27 @@
from typing import Callable
import socket
import logging

from admin.client_handler import AdminClientHandler
from admin.info_client import AdminInfoClient
from admin.order_book_client import AdminOrderBookClient
from connection.ip_address import IpAddress
from connection.connection_handler import ConnectionHandlerFactory

logger = logging.getLogger(__name__)


class ClientsManager(ConnectionHandlerFactory[AdminClientHandler]):
    def __init__(self, order_book_client: AdminOrderBookClient,
                 info_client: AdminInfoClient) -> None:
        self.order_book_client = order_book_client
        self.info_client = info_client

    def on_new_connection(self, socket_fd: socket.socket, ip_address: IpAddress,
                          close_callback: Callable[[], None]) -> AdminClientHandler:
        return AdminClientHandler(
            socket_fd, ip_address, close_callback,
            self.order_book_client, self.info_client)

    def on_connection_closed(self, connection_handler: AdminClientHandler) -> None:
        pass


@@ -0,0 +1,32 @@
import socket
from typing import Callable

from common.info_client import BaseInfoClient
from connection.ip_address import IpAddress
from proto.common_pb2 import LoginRequest, LoginResponse
from proto.info_pb2 import OnInstrument, OnPriceDepthBook, OnTopOfBook, OnTrade


class AdminInfoClient(BaseInfoClient):
    def __init__(self, socket_fd: socket.socket, ip_address: IpAddress, on_close: Callable[[], None]) -> None:
        super().__init__(socket_fd, ip_address, on_close)

    def on_disconnect(self) -> None:
        super().on_disconnect()
        raise ConnectionError("Lost connection to info server")

    def login(self, on_login_response: Callable[[LoginResponse], None]) -> None:
        # TODO Question candidates how they would approach authentication for the admin client in production
        request = LoginRequest(username="admin", password="admin")
        self.send_login(request, on_login_response)

    def on_instrument(self, message: OnInstrument) -> None:
        pass

    def on_top_of_book(self, message: OnTopOfBook) -> None:
        pass

    def on_price_depth_book(self, message: OnPriceDepthBook) -> None:
        pass

    def on_trade(self, message: OnTrade) -> None:
        pass


@@ -0,0 +1,61 @@
import logging
from importlib.resources import files as resource_files

from application.application import BaseApplication
from connection.ip_address import IpAddress
from connection.tcp_connection_manager import TcpConnectionManager
from proto.common_pb2 import LoginResponse

from admin.clients_manager import ClientsManager
from admin.info_client import AdminInfoClient
from admin.order_book_client import AdminOrderBookClient

logger = logging.getLogger(__name__)


class AdminApplication(BaseApplication):
    def _start(self) -> None:
        tcp_connection_manager = TcpConnectionManager()

        logger.info("Connecting to order book server...")
        order_book_ip = IpAddress(
            host=self._config["connectTo"]["order_book"]["host"],
            port=self._config["connectTo"]["order_book"]["port"])
        order_book_client: AdminOrderBookClient = tcp_connection_manager.connect(
            order_book_ip, lambda *args: AdminOrderBookClient(*args))

        logger.info("Connecting to info server...")
        info_ip = IpAddress(
            host=self._config["connectTo"]["info"]["host"],
            port=self._config["connectTo"]["info"]["port"])
        info_client: AdminInfoClient = tcp_connection_manager.connect(
            info_ip, lambda *args: AdminInfoClient(*args))
        info_client.login(self._on_info_login_response)

        clients_manager = ClientsManager(
            order_book_client=order_book_client,
            info_client=info_client)

        server_ip = IpAddress(
            host=self._config["listenOn"]["host"],
            port=self._config["listenOn"]["port"])
        logger.info(f"Starting admin server on {server_ip}")
        with tcp_connection_manager.listen(server_ip, clients_manager):
            logger.info("Admin server started.")
            logger.info("Running event loop until interrupted.")
            while True:
                tcp_connection_manager.wait_for_events()

    def _on_info_login_response(self, response: LoginResponse) -> None:
        if response.error_message:
            logger.error(f"Failed to log in to info server: {response}")
            raise SystemExit(1)
        logger.info(f"Logged in to info server: {response}")


def main() -> None:
    config_schema_path = resource_files("application").joinpath("config_schema.json")
    app = AdminApplication(config_schema=config_schema_path, app_name="admin")
    app.run()


if __name__ == "__main__":
    main()


@@ -0,0 +1,26 @@
import socket
from typing import Callable

from common.order_book_client import BaseOrderBookClient
from connection.ip_address import IpAddress
from proto.order_book_pb2 import OnOrderBookCreated, OnOrderInserted, OnOrderCancelled, OnTrade


class AdminOrderBookClient(BaseOrderBookClient):
    def __init__(self, socket_fd: socket.socket, ip_address: IpAddress, on_close: Callable[[], None]) -> None:
        super().__init__(socket_fd, ip_address, on_close)

    def on_disconnect(self) -> None:
        super().on_disconnect()
        raise ConnectionError("Lost connection to order book server")

    def on_order_book_created(self, message: OnOrderBookCreated) -> None:
        pass

    def on_order_inserted(self, message: OnOrderInserted) -> None:
        pass

    def on_order_cancelled(self, message: OnOrderCancelled) -> None:
        pass

    def on_trade(self, message: OnTrade) -> None:
        pass


@@ -0,0 +1,38 @@
import json
import logging
from pathlib import Path
from importlib.resources import files as resource_files
from typing import Any

import jsonschema

logger = logging.getLogger(__name__)

USERS_DATA_SCHEMA_PATH = resource_files("application").joinpath("data_file_schema.json")


class DataFileReader:
    def __init__(self, file_path: str) -> None:
        self._file_path = Path(file_path)
        self._data = self._load_and_validate()
        logger.info(f"Loaded data file '{self._file_path}' with {len(self._data['users'])} users")

    def _load_and_validate(self) -> dict[str, Any]:
        with open(self._file_path) as f:
            data = json.load(f)
        with USERS_DATA_SCHEMA_PATH.open() as f:
            schema = json.load(f)
        jsonschema.validate(instance=data, schema=schema)
        return data

    def authenticate(self, username: str, password: str) -> bool:
        user = self._get_user(username)
        if user is None:
            return False
        return user["password"] == password

    def _get_user(self, username: str) -> dict[str, Any] | None:
        for user in self._data["users"]:
            if user["username"] == username:
                return user
        return None


@@ -0,0 +1,45 @@
from typing import Callable
import socket
import logging

from application.data_file_reader import DataFileReader
from connection.ip_address import IpAddress
from connection.connection_handler import ConnectionHandler
from proto.common_pb2 import LoginRequest, LoginResponse

logger = logging.getLogger(__name__)


class BaseClientHandler(ConnectionHandler):
    def __init__(self, socket_fd: socket.socket, ip_address: IpAddress, close_callback: Callable[[], None],
                 data_file_reader: DataFileReader) -> None:
        super().__init__(socket_fd, ip_address, close_callback)
        self.logged_in = False
        self.username: str | None = None
        self._data_file_reader = data_file_reader

    def _login(self, request: LoginRequest) -> LoginResponse:
        response = LoginResponse(request_id=request.request_id)
        try:
            if request.request_id is None:
                raise ValueError("Request ID must be set")
            if request.username is None or request.username.strip() == "":
                raise ValueError("Username must be set")
            if request.password is None or request.password.strip() == "":
                raise ValueError("Password must be set")
            if self.logged_in:
                raise ValueError("Already logged in")
            if not self._data_file_reader.authenticate(request.username, request.password):
                raise ValueError("Invalid username or password")
            self.logged_in = True
            self.username = request.username
            logger.info(f"Client logged in with username {self.username}")
        except Exception as e:
            logger.exception("Error handling login request")
            response.error_message = str(e)
        return response


@@ -0,0 +1,108 @@
from abc import abstractmethod
import logging
import socket
from typing import Callable, TypeVar

from google.protobuf.message import Message

from connection.connection_handler import ConnectionHandler
from connection.ip_address import IpAddress
from proto.info_pb2 import *
from proto.admin_pb2 import CreateInstrumentRequest, CreateInstrumentResponse
from proto.common_pb2 import LoginRequest, LoginResponse, MessageType

logger = logging.getLogger(__name__)

ProtoMessage = TypeVar('ProtoMessage', bound=Message)


class BaseInfoClient(ConnectionHandler):
    def __init__(self, socket_fd: socket.socket, ip_address: IpAddress, on_close: Callable[[], None]) -> None:
        super().__init__(socket_fd, ip_address, on_close)
        self.next_request_id = 1
        self.callbacks: dict[int, Callable[[ProtoMessage], None]] = {}

    def on_disconnect(self) -> None:
        logger.info(f"Disconnected from {self.ip_address}")

    @abstractmethod
    def on_instrument(self, message: OnInstrument) -> None:
        pass

    @abstractmethod
    def on_top_of_book(self, message: OnTopOfBook) -> None:
        pass

    @abstractmethod
    def on_price_depth_book(self, message: OnPriceDepthBook) -> None:
        pass

    @abstractmethod
    def on_trade(self, message: OnTrade) -> None:
        pass

    def send_login(self, login_request: LoginRequest, callback: Callable[[LoginResponse], None]) -> int:
        request_id = self._get_next_request_id()
        self.callbacks[request_id] = callback  # type: ignore
        login_request.request_id = request_id
        self.send_message(MessageType.AUTH_LOGIN_REQUEST, login_request)
        return request_id

    def send_create_instrument(self, create_info_request: CreateInstrumentRequest, callback: Callable[[CreateInstrumentResponse], None]) -> int:
        request_id = self._get_next_request_id()
        self.callbacks[request_id] = callback  # type: ignore
        create_info_request.request_id = request_id
        self.send_message(MessageType.INFO_CREATE_INSTRUMENT_REQUEST, create_info_request)
        return request_id

    def send_order_book_subscribe(self, order_book_subscribe_request: OrderBookSubscribeRequest, callback: Callable[[OrderBookSubscribeResponse], None]) -> int:
        request_id = self._get_next_request_id()
        self.callbacks[request_id] = callback  # type: ignore
        order_book_subscribe_request.request_id = request_id
        self.send_message(MessageType.INFO_SUBSCRIBE_REQUEST, order_book_subscribe_request)
        return request_id

    def handle_message(self, message_type: int, raw_message: bytes) -> None:
        logger.info(f"Handling message of type {MessageType.Name(message_type)}")  # type: ignore
        # Decode the message based on message_type and call the appropriate callback
        message: Message
        if message_type == MessageType.AUTH_LOGIN_RESPONSE:
            message = LoginResponse.FromString(raw_message)
        elif message_type == MessageType.INFO_CREATE_INSTRUMENT_RESPONSE:
            message = CreateInstrumentResponse.FromString(raw_message)
        elif message_type == MessageType.INFO_SUBSCRIBE_RESPONSE:
            message = OrderBookSubscribeResponse.FromString(raw_message)
        elif message_type == MessageType.INFO_ON_INSTRUMENT:
            message = OnInstrument.FromString(raw_message)
            self.on_instrument(message)
            return
        elif message_type == MessageType.INFO_ON_TOP_OF_BOOK:
            message = OnTopOfBook.FromString(raw_message)
            self.on_top_of_book(message)
            return
        elif message_type == MessageType.INFO_ON_PRICE_DEPTH_BOOK:
            message = OnPriceDepthBook.FromString(raw_message)
            self.on_price_depth_book(message)
            return
        elif message_type == MessageType.INFO_ON_TRADE:
            message = OnTrade.FromString(raw_message)
            self.on_trade(message)
            return
        else:
            raise ValueError(f"Received message with unexpected message_type: {message_type}")

        logger.debug(f"Message: {message}")
        assert hasattr(message, 'request_id'), f"Response message of type {message_type} does not have a request_id"
        callback = self.callbacks.pop(message.request_id, None)
        if callback:
            callback(message)
        else:
            raise ValueError(f"Received response with unknown request_id: {message.request_id}")

    def send_message(self, message_type: int, message: ProtoMessage) -> None:
        logger.info(f"Sending message of type {MessageType.Name(MessageType.ValueType(message_type))}")
        super().send_message(message_type, message)

    def _get_next_request_id(self) -> int:
        request_id = self.next_request_id
        self.next_request_id += 1
        return request_id


@@ -0,0 +1,25 @@
from enum import Enum

from proto.common_pb2 import Side as ProtoSide


class Side(Enum):
    BUY = 'buy'
    SELL = 'sell'

    def to_proto(self) -> ProtoSide.ValueType:
        if self == Side.BUY:
            return ProtoSide.BUY
        elif self == Side.SELL:
            return ProtoSide.SELL
        else:
            raise ValueError(f"Unexpected side: {self}")

    @staticmethod
    def from_proto(proto_side: ProtoSide) -> 'Side':
        if proto_side == ProtoSide.BUY:
            return Side.BUY
        elif proto_side == ProtoSide.SELL:
            return Side.SELL
        else:
            raise ValueError(f"Unexpected side: {proto_side}")


@@ -0,0 +1,115 @@
from abc import abstractmethod
import logging
import socket
from typing import Callable, TypeVar

from google.protobuf.message import Message

from connection.connection_handler import ConnectionHandler
from connection.ip_address import IpAddress
from proto.order_book_pb2 import *
from proto.common_pb2 import MessageType

logger = logging.getLogger(__name__)

ProtoMessage = TypeVar('ProtoMessage', bound=Message)


class BaseOrderBookClient(ConnectionHandler):
    def __init__(self, socket_fd: socket.socket, ip_address: IpAddress, on_close: Callable[[], None]) -> None:
        super().__init__(socket_fd, ip_address, on_close)
        self.next_request_id = 1
        self.callbacks: dict[int, Callable[[ProtoMessage], None]] = {}

    def on_disconnect(self) -> None:
        logger.info(f"Disconnected from {self.ip_address}")

    @abstractmethod
    def on_order_book_created(self, message: OnOrderBookCreated) -> None:
        pass

    @abstractmethod
    def on_order_inserted(self, message: OnOrderInserted) -> None:
        pass

    @abstractmethod
    def on_order_cancelled(self, message: OnOrderCancelled) -> None:
        pass

    @abstractmethod
    def on_trade(self, message: OnTrade) -> None:
        pass

    def send_create_order_book(self, request: CreateOrderBookRequest, callback: Callable[[CreateOrderBookResponse], None]) -> int:
        request_id = self._get_next_request_id()
        self.callbacks[request_id] = callback  # type: ignore
        request.request_id = request_id
        self.send_message(MessageType.ORDER_BOOK_CREATE_REQUEST, request)
        return request_id

    def send_delete_order_book(self, request: DeleteOrderBookRequest, callback: Callable[[DeleteOrderBookResponse], None]) -> int:
        request_id = self._get_next_request_id()
        self.callbacks[request_id] = callback  # type: ignore
        request.request_id = request_id
        self.send_message(MessageType.ORDER_BOOK_DELETE_REQUEST, request)
        return request_id

    def send_insert_order(self, insert_order_request: InsertOrderRequest, callback: Callable[[InsertOrderResponse], None]) -> int:
        request_id = self._get_next_request_id()
        self.callbacks[request_id] = callback  # type: ignore
        insert_order_request.request_id = request_id
        self.send_message(MessageType.ORDER_BOOK_INSERT_ORDER_REQUEST, insert_order_request)
        return request_id

    def send_cancel_order(self, cancel_order_request: CancelOrderRequest, callback: Callable[[CancelOrderResponse], None]) -> int:
        request_id = self._get_next_request_id()
        self.callbacks[request_id] = callback  # type: ignore
        cancel_order_request.request_id = request_id
        self.send_message(MessageType.ORDER_BOOK_CANCEL_ORDER_REQUEST, cancel_order_request)
        return request_id

    def handle_message(self, message_type: int, raw_message: bytes) -> None:
        logger.info(f"Handling message of type {MessageType.Name(message_type)}")  # type: ignore
        message: Message
        if message_type == MessageType.ORDER_BOOK_ON_ORDER_BOOK_CREATED:
            message = OnOrderBookCreated.FromString(raw_message)
            self.on_order_book_created(message)
            return
        elif message_type == MessageType.ORDER_BOOK_ON_ORDER_INSERTED:
            message = OnOrderInserted.FromString(raw_message)
            self.on_order_inserted(message)
            return
        elif message_type == MessageType.ORDER_BOOK_ON_ORDER_CANCELLED:
            message = OnOrderCancelled.FromString(raw_message)
            self.on_order_cancelled(message)
            return
        elif message_type == MessageType.ORDER_BOOK_ON_TRADE:
            message = OnTrade.FromString(raw_message)
            self.on_trade(message)
            return

        if message_type == MessageType.ORDER_BOOK_CREATE_RESPONSE:
            message = CreateOrderBookResponse.FromString(raw_message)
        elif message_type == MessageType.ORDER_BOOK_DELETE_RESPONSE:
            message = DeleteOrderBookResponse.FromString(raw_message)
        elif message_type == MessageType.ORDER_BOOK_INSERT_ORDER_RESPONSE:
            message = InsertOrderResponse.FromString(raw_message)
        elif message_type == MessageType.ORDER_BOOK_CANCEL_ORDER_RESPONSE:
            message = CancelOrderResponse.FromString(raw_message)
        else:
            raise ValueError(f"Received message with unexpected message_type: {message_type}")

        logger.debug(f"Message: {message}")
        assert hasattr(message, 'request_id'), f"Response message of type {message_type} does not have a request_id"
        callback = self.callbacks.pop(message.request_id, None)
        if callback:
            callback(message)
        else:
            raise ValueError(f"Received response with unknown request_id: {message.request_id}")

    def send_message(self, message_type: int, message: ProtoMessage) -> None:
        logger.info(f"Sending message of type {MessageType.Name(MessageType.ValueType(message_type))}")
        super().send_message(message_type, message)

    def _get_next_request_id(self) -> int:
        request_id = self.next_request_id
        self.next_request_id += 1
        return request_id


@@ -0,0 +1,8 @@
from decimal import Decimal

DECIMAL_PRECISION = Decimal("0.0001")


def decimal_from_float(value: float) -> Decimal:
    return Decimal(str(value)).quantize(DECIMAL_PRECISION)


@@ -27,6 +27,7 @@ def read_message(socket_fd: socket.socket) -> tuple[int, bytes]:
     raw_msg = socket_fd.recv(msg_len)
     logger.debug(f"Actual message length: {len(raw_msg)}")
 
+    # TODO introduce bug here to test error handling
     message_type = int.from_bytes(raw_msg[:MESSAGE_TYPE_BYTES], byteorder=BYTE_ORDER)
     message = raw_msg[MESSAGE_TYPE_BYTES:]


@@ -169,6 +169,7 @@ class TcpConnectionManager:
                 client_connection.handle_message(message_type, message)
             except Exception:
                 logger.exception(f"Error while handling message from {ip_address}. Client will be disconnected")
+                # TODO do not disconnect and introduce bug here to test error handling?
                 self._close_socket(socket_fd, ip_address)
                 raise
         logger.debug(f"Done handling message")



@ -0,0 +1,128 @@
from typing import Callable
from datetime import datetime
import socket
import logging
from application.data_file_reader import DataFileReader
from common.client_handler import BaseClientHandler
from connection.ip_address import IpAddress
from info.instruments_manager import InstrumentsManager
from info.models import Instrument
from info.subscriptions_manager import SubscriptionsManager
from proto.info_pb2 import *
from proto.common_pb2 import LoginRequest, MessageType
logger = logging.getLogger(__name__)
class InfoClientHandler(BaseClientHandler):
def __init__(self, socket_fd: socket.socket, ip_address: IpAddress, close_callback: Callable[[], None],
instruments_manager: InstrumentsManager,
subscriptions_manager: SubscriptionsManager,
data_file_reader: DataFileReader) -> None:
super().__init__(socket_fd, ip_address, close_callback, data_file_reader)
self.instruments_manager = instruments_manager
self.subscriptions_manager = subscriptions_manager
def on_disconnect(self) -> None:
logger.info(f"Client {self.ip_address} disconnected")
if self.logged_in:
self.subscriptions_manager.remove_client_subscriptions(self)
def handle_message(self, message_type: int, message: bytes) -> None:
logger.info(f"Handling message of type {MessageType.Name(message_type)} by {self.username}") # type: ignore
if message_type == MessageType.AUTH_LOGIN_REQUEST:
self._handle_login_request(message)
elif message_type == MessageType.INFO_CREATE_INSTRUMENT_REQUEST:
self._handle_create_instrument_request(message)
elif message_type == MessageType.INFO_SUBSCRIBE_REQUEST:
self._handle_subscribe_request(message)
else:
raise ValueError(f"Unexpected message type received: {message_type}")
def _handle_login_request(self, message: bytes) -> None:
request = self._deserialize_message(LoginRequest, message)
response = self._login(request)
logger.debug(f"Sending login response: {response}")
self.send_message(MessageType.AUTH_LOGIN_RESPONSE, response)
if response.error_message:
raise ValueError("Login failed, client will be disconnected")
self._publish_all_instruments_to_client()
def _publish_all_instruments_to_client(self) -> None:
assert self.logged_in
logger.info(f"Publishing full state to client {self.username}")
for instrument in self.instruments_manager.get_all_complete_instruments():
self.publish_instrument_to_client(instrument)
def publish_instrument_to_client(self, instrument: Instrument) -> None:
assert self.logged_in
logger.info(f"Publishing instrument {instrument.symbol} to client {self.username}")
self.send_message(MessageType.INFO_ON_INSTRUMENT, OnInstrument(
instrument=instrument.to_instrument_proto(),
order_book_id=instrument.order_book_id,
tick_size=instrument.tick_size))
def _handle_create_instrument_request(self, message: bytes) -> None:
request_timestamp = datetime.now()
response = CreateInstrumentResponse()
try:
request = self._deserialize_message(CreateInstrumentRequest, message)
response.request_id = request.request_id
if not self.logged_in:
raise ValueError("Client is not logged in")
if request.request_id is None:
raise ValueError("Request ID must be set")
if request.instrument is None:
raise ValueError("Instrument must be set")
if request.instrument.symbol is None:
raise ValueError("Symbol must be set")
if request.instrument.description is None:
raise ValueError("Description must be set")
if request.instrument.currency is None:
raise ValueError("Currency must be set")
if request.instrument.multiplier is None:
raise ValueError("Multiplier must be set")
if request.instrument.multiplier <= 0:
raise ValueError("Multiplier must be positive")
if request.order_book_id is None:
raise ValueError("Order book ID must be set")
if self.instruments_manager.exists(request.instrument.symbol):
raise ValueError(f"Instrument {request.instrument.symbol} already exists")
instrument = Instrument.from_proto(request.instrument, request_timestamp, request.order_book_id)
logger.debug(f"New instrument {request.instrument.symbol} is valid, storing it in instruments manager")
self.instruments_manager.add_instrument(instrument)
except Exception as e:
logger.debug("Error creating instrument", exc_info=True)
response.error_message = str(e)
finally:
logger.debug(f"Sending create instrument response: {response}")
self.send_message(MessageType.INFO_CREATE_INSTRUMENT_RESPONSE, response)
def _handle_subscribe_request(self, message: bytes) -> None:
response = OrderBookSubscribeResponse()
try:
request = self._deserialize_message(OrderBookSubscribeRequest, message)
response.request_id = request.request_id
if not self.logged_in:
raise ValueError("Client is not logged in")
if request.request_id is None:
raise ValueError("Request ID must be set")
if request.instrument_symbol is None:
raise ValueError("Instrument symbol must be set")
if request.subscription_type is None:
raise ValueError("Subscription type must be set")
self.subscriptions_manager.add_client_subscription(
self, request.instrument_symbol, request.subscription_type)
except Exception as e:
logger.debug("Error subscribing", exc_info=True)
response.error_message = str(e)
raise # note that the response is sent before this exception is propagated, then the client is disconnected
finally:
logger.debug(f"Sending subscribe response: {response}")
self.send_message(MessageType.INFO_SUBSCRIBE_RESPONSE, response)


@ -0,0 +1,55 @@
from dataclasses import dataclass
from typing import Callable
import socket
import logging
from application.data_file_reader import DataFileReader
from connection.ip_address import IpAddress
from connection.connection_handler import ConnectionHandlerFactory
from info.client_handler import InfoClientHandler
from info.i_instruments_listener import IInstrumentsListener
from info.instruments_manager import InstrumentsManager
from info.models import Instrument
from info.subscriptions_manager import SubscriptionsManager
logger = logging.getLogger(__name__)
@dataclass
class PendingInstrument:
symbol: str | None = None
tick_size: float | None = None
def is_complete(self) -> bool:
return self.symbol is not None and self.tick_size is not None
class ClientsManager(ConnectionHandlerFactory[InfoClientHandler], IInstrumentsListener):
def __init__(self, instruments_manager: InstrumentsManager, subscriptions_manager: SubscriptionsManager,
data_file_reader: DataFileReader) -> None:
self.instruments_manager = instruments_manager
self.subscriptions_manager = subscriptions_manager
self.data_file_reader = data_file_reader
self.connected_clients: dict[socket.socket, InfoClientHandler] = {}
def on_new_connection(self, socket_fd: socket.socket, ip_address: IpAddress,
close_callback: Callable[[], None]) -> InfoClientHandler:
client_handler = InfoClientHandler(
socket_fd, ip_address, close_callback,
self.instruments_manager, self.subscriptions_manager,
self.data_file_reader)
assert socket_fd not in self.connected_clients
self.connected_clients[socket_fd] = client_handler
return client_handler
def on_connection_closed(self, connection_handler: InfoClientHandler) -> None:
assert connection_handler.socket_fd in self.connected_clients
del self.connected_clients[connection_handler.socket_fd]
def on_new_complete_instrument(self, instrument: Instrument) -> None:
assert instrument.is_complete()
logger.info(f"Publishing new complete instrument {instrument.symbol} to {len(self.connected_clients)} clients")
for client in self.connected_clients.values():
if client.logged_in:
client.publish_instrument_to_client(instrument)
else:
logger.debug(f"Skipping client {client.ip_address} that's not logged in")


@ -0,0 +1,10 @@
from abc import ABC, abstractmethod
from info.models import Instrument
from proto.info_pb2 import *
class IInstrumentsListener(ABC):
@abstractmethod
def on_new_complete_instrument(self, instrument: Instrument) -> None:
pass


@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
class IOrderBookListener(ABC):
@abstractmethod
def on_order_book_created(self, order_book_id: int, tick_size: float) -> None:
pass
@abstractmethod
def on_order_book_changed(self, order_book_id: int, change_timestamp: int) -> None:
pass


@ -0,0 +1,14 @@
{
"logLevel": "DEBUG",
"logDirectory": "./logs",
"listenOn": {
"host": "localhost",
"port": 51365
},
"connectTo": {
"order_book": {
"host": "localhost",
"port": 51364
}
}
}


@ -0,0 +1,77 @@
import logging
from typing import Generator
from info.i_instruments_listener import IInstrumentsListener
from info.i_order_book_listener import IOrderBookListener
from .models import Instrument
logger = logging.getLogger(__name__)
class InstrumentsManager(IOrderBookListener):
def __init__(self) -> None:
self.instrument_by_symbol: dict[str, Instrument] = {}
self.instrument_by_order_book_id: dict[int, Instrument] = {}
self._pending_tick_size_by_order_book_id: dict[int, float] = {}
self.instruments_listener: IInstrumentsListener | None = None
def set_instruments_listener(self, instruments_listener: IInstrumentsListener) -> None:
assert self.instruments_listener is None
assert len(self.instrument_by_symbol) == 0, "Instruments manager must be empty when setting instruments listener"
self.instruments_listener = instruments_listener
def add_instrument(self, instrument: Instrument) -> None:
if instrument.symbol in self.instrument_by_symbol:
raise ValueError(f"Instrument {instrument.symbol} already exists")
if instrument.order_book_id in self.instrument_by_order_book_id:
raise ValueError(f"Instrument with order book id {instrument.order_book_id} already exists")
self.instrument_by_symbol[instrument.symbol] = instrument
self.instrument_by_order_book_id[instrument.order_book_id] = instrument
if instrument.order_book_id in self._pending_tick_size_by_order_book_id:
tick_size = self._pending_tick_size_by_order_book_id[instrument.order_book_id]
del self._pending_tick_size_by_order_book_id[instrument.order_book_id]
instrument.tick_size = tick_size
logger.debug(f"Instrument {instrument.symbol} is now complete, publishing to listeners")
self._publish_instrument(instrument)
return
def exists(self, symbol: str) -> bool:
return symbol in self.instrument_by_symbol
def get_instrument_by_symbol(self, symbol: str) -> Instrument:
instrument = self.instrument_by_symbol.get(symbol)
if instrument is None:
raise ValueError(f"Instrument {symbol} not found")
return instrument
def get_instrument_by_order_book_id(self, order_book_id: int) -> Instrument:
instrument = self.instrument_by_order_book_id.get(order_book_id)
if instrument is None:
raise ValueError(f"Instrument with order book id {order_book_id} not found")
return instrument
def get_all_complete_instruments(self) -> Generator[Instrument, None, None]:
for instrument in self.instrument_by_symbol.values():
if instrument.is_complete():
yield instrument
# IOrderBookListener implementation
def on_order_book_created(self, order_book_id: int, tick_size: float) -> None:
logger.debug(f"Order book {order_book_id} created with tick size {tick_size}")
if order_book_id in self.instrument_by_order_book_id:
instrument = self.instrument_by_order_book_id[order_book_id]
instrument.tick_size = tick_size
logger.debug(f"Instrument {instrument.symbol} is now complete, publishing to listeners")
self._publish_instrument(instrument)
return
logger.debug(f"Order book {order_book_id} created but instrument not found, storing tick size for later")
self._pending_tick_size_by_order_book_id[order_book_id] = tick_size
def on_order_book_changed(self, order_book_id: int, change_timestamp: int) -> None:
pass
def _publish_instrument(self, instrument: Instrument) -> None:
assert self.instruments_listener is not None
assert instrument.is_complete()
self.instruments_listener.on_new_complete_instrument(instrument)
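Editor's sketch (illustrative, not from the commit) of the two arrival orders the pending-tick-size map above handles: the order book creation event may reach the manager before or after the matching instrument is registered, and the instrument must end up complete either way.

```python
# Hypothetical, simplified model of InstrumentsManager's completion logic.
pending_tick: dict[int, float] = {}
instruments: dict[int, dict] = {}

def on_order_book_created(order_book_id: int, tick_size: float) -> None:
    if order_book_id in instruments:
        instruments[order_book_id]["tick_size"] = tick_size  # instrument arrived first
    else:
        pending_tick[order_book_id] = tick_size  # stash until the instrument arrives

def add_instrument(order_book_id: int, symbol: str) -> None:
    instruments[order_book_id] = {
        "symbol": symbol,
        "tick_size": pending_tick.pop(order_book_id, None),  # claim a stashed tick size
    }

on_order_book_created(1, 0.01)  # order book event first
add_instrument(1, "AAA")
add_instrument(2, "BBB")        # instrument first
on_order_book_created(2, 0.05)
# Both instruments are complete regardless of arrival order.
```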

solution/src/info/main.py — new file, 59 lines

@ -0,0 +1,59 @@
import logging
from importlib.resources import files as resource_files
from application.application import BaseApplication
from application.data_file_reader import DataFileReader
from connection.ip_address import IpAddress
from connection.tcp_connection_manager import TcpConnectionManager
from info.clients_manager import ClientsManager
from info.instruments_manager import InstrumentsManager
from info.order_book_client_aggregator import OrderBookClientAggregator
from info.subscriptions_manager import SubscriptionsManager
logger = logging.getLogger(__name__)
class InfoApplication(BaseApplication):
def _start(self) -> None:
data_file_reader = DataFileReader(self._config["dataFilePath"])
tcp_connection_manager = TcpConnectionManager()
logger.info("Connecting to order book server...")
order_book_service_ip_address = IpAddress(
host=self._config["connectTo"]["order_book"]["host"],
port=self._config["connectTo"]["order_book"]["port"])
order_book_client_aggregator: OrderBookClientAggregator = tcp_connection_manager.connect(
order_book_service_ip_address, lambda *args: OrderBookClientAggregator(*args))
instruments_manager = InstrumentsManager()
order_book_client_aggregator.add_order_book_listener(instruments_manager)
subscriptions_manager = SubscriptionsManager(
instruments_manager=instruments_manager,
order_book_aggregator=order_book_client_aggregator)
order_book_client_aggregator.add_order_book_listener(subscriptions_manager)
clients_manager = ClientsManager(
instruments_manager=instruments_manager,
subscriptions_manager=subscriptions_manager,
data_file_reader=data_file_reader)
instruments_manager.set_instruments_listener(clients_manager)
server_ip_address = IpAddress(
host=self._config["listenOn"]["host"],
port=self._config["listenOn"]["port"])
logger.info(f"Starting server on {server_ip_address}")
with tcp_connection_manager.listen(server_ip_address, clients_manager):
logger.info("Server started.")
logger.info("Running event loop until interrupted.")
while True:
tcp_connection_manager.wait_for_events()
def main() -> None:
config_schema_path = resource_files("application").joinpath("config_schema.json")
app = InfoApplication(config_schema=config_schema_path, app_name="info")
app.run()
if __name__ == "__main__":
main()


@ -0,0 +1,36 @@
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal
from proto.common_pb2 import Instrument as ProtoInstrument
@dataclass
class Instrument:
symbol: str
description: str
currency: str
multiplier: Decimal
created_timestamp: datetime
order_book_id: int
tick_size: float | None = None
def is_complete(self) -> bool:
return self.tick_size is not None
def to_instrument_proto(self) -> ProtoInstrument:
return ProtoInstrument(
symbol=self.symbol,
description=self.description,
currency=self.currency,
multiplier=float(self.multiplier))
@staticmethod
def from_proto(proto_instrument: ProtoInstrument, created_timestamp: datetime, order_book_id: int) -> 'Instrument':
return Instrument(
symbol=proto_instrument.symbol,
description=proto_instrument.description,
currency=proto_instrument.currency,
multiplier=Decimal(str(proto_instrument.multiplier)),  # round-trip through str to avoid binary float artifacts
created_timestamp=created_timestamp,
order_book_id=order_book_id)


@ -0,0 +1,208 @@
from dataclasses import dataclass, field
from datetime import datetime
from decimal import Decimal
import logging
import socket
from typing import Callable, Iterator
from common.models import Side
from common.order_book_client import BaseOrderBookClient
from connection.ip_address import IpAddress
from info.i_order_book_listener import IOrderBookListener
from proto.order_book_pb2 import *
from proto.info_pb2 import PriceLevel as ProtoPriceLevel
from collections import defaultdict
logger = logging.getLogger(__name__)
@dataclass
class Order:
order_id: int
order_book_id: int
quantity: int
price: Decimal
side: Side
@dataclass
class OrderBookLevel:
buy_quantity: int = 0
sell_quantity: int = 0
@dataclass
class OrderBook:
order_book_id: int
tick_size: Decimal
created_timestamp: datetime
levels: defaultdict[Decimal, OrderBookLevel] = field(default_factory=lambda: defaultdict(OrderBookLevel))
def __getitem__(self, price: Decimal) -> OrderBookLevel:
return self.levels[price]
def __setitem__(self, price: Decimal, level: OrderBookLevel) -> None:
self.levels[price] = level
def __delitem__(self, price: Decimal) -> None:
del self.levels[price]
def __len__(self) -> int:
return len(self.levels)
def __iter__(self) -> Iterator[Decimal]:
return iter(self.levels)
def __contains__(self, price: Decimal) -> bool:
return price in self.levels
@dataclass
class OrderBookPriceLevel:
quantity: int
price: Decimal
def to_proto(self) -> ProtoPriceLevel:
return ProtoPriceLevel(
quantity=self.quantity,
price=float(self.price)
)
class OrderBookClientAggregator(BaseOrderBookClient):
def __init__(self, socket_fd: socket.socket, ip_address: IpAddress, on_close: Callable[[], None]) -> None:
super().__init__(socket_fd, ip_address, on_close)
self._order_book_listeners: list[IOrderBookListener] = []
self._order_books = dict[int, OrderBook]()
self._order_by_id = dict[int, Order]()
self._pending_trade_ids_by_order_book = defaultdict[int, set[int]](set)
def add_order_book_listener(self, order_book_listener: IOrderBookListener) -> None:
self._order_book_listeners.append(order_book_listener)
def get_top_of_book(self, order_book_id: int) -> dict[Side, OrderBookPriceLevel]:
assert order_book_id in self._order_books, f"Order book {order_book_id} not found"
assert self._is_order_book_consistent(order_book_id), "Not all trades have been processed, order book would be inconsistent"
top_of_book: dict[Side, OrderBookPriceLevel] = {}
order_book = self._order_books[order_book_id]
best_bid = max((price for price in order_book if order_book[price].buy_quantity > 0), default=None)
if best_bid is not None:
top_of_book[Side.BUY] = OrderBookPriceLevel(
quantity=order_book[best_bid].buy_quantity,
price=best_bid
)
best_ask = min((price for price in order_book if order_book[price].sell_quantity > 0), default=None)
if best_ask is not None:
top_of_book[Side.SELL] = OrderBookPriceLevel(
quantity=order_book[best_ask].sell_quantity,
price=best_ask
)
assert len(top_of_book) <= 2
return top_of_book
def get_price_depth_book(self, order_book_id: int) -> dict[Decimal, OrderBookLevel]:
assert order_book_id in self._order_books, f"Order book {order_book_id} not found"
assert self._is_order_book_consistent(order_book_id), "Not all trades have been processed, order book would be inconsistent"
order_book = self._order_books[order_book_id]
return {price: level for price, level in order_book.levels.items()
if level.buy_quantity > 0 or level.sell_quantity > 0}
def on_order_book_created(self, message: OnOrderBookCreated) -> None:
order_book_id = message.order_book_id
logger.info(f"Received on order book created message with id {order_book_id}")
assert order_book_id not in self._order_books, f"Order book {order_book_id} already exists"
self._order_books[order_book_id] = OrderBook(
order_book_id=order_book_id,
tick_size=Decimal(str(message.tick_size)),
created_timestamp=datetime.fromtimestamp(message.created_timestamp)
)
logger.info(f"Order book {order_book_id} created")
self._report_order_book_created(self._order_books[order_book_id])
def on_order_inserted(self, order_inserted_message: OnOrderInserted) -> None:
order_book_id = order_inserted_message.order_book_id
logger.info(f"Received on order inserted message on order book {order_book_id}")
assert order_book_id in self._order_books, f"Order book {order_book_id} not found"
new_order = Order(
order_id=order_inserted_message.order_id,
order_book_id=order_book_id,
quantity=order_inserted_message.quantity,
price=Decimal(str(order_inserted_message.price)),
side=Side.from_proto(order_inserted_message.side)
)
self._order_by_id[new_order.order_id] = new_order
if order_inserted_message.trade_ids:
logger.debug(f"Order {new_order.order_id} has pending trades {order_inserted_message.trade_ids}")
pending_trade_ids = self._pending_trade_ids_by_order_book[new_order.order_book_id]
pending_trade_ids.update(order_inserted_message.trade_ids)
order_book = self._order_books[new_order.order_book_id]
if new_order.side == Side.BUY:
order_book[new_order.price].buy_quantity += order_inserted_message.quantity
else:
order_book[new_order.price].sell_quantity += order_inserted_message.quantity
self._report_order_book_changed(new_order.order_book_id, order_inserted_message.timestamp)
def on_order_cancelled(self, cancelled_order_message: OnOrderCancelled) -> None:
order_id = cancelled_order_message.order_id
logger.info(f"Received on order cancelled message for order id {order_id}")
assert order_id in self._order_by_id, f"Order {order_id} not found"
cancelled_order = self._order_by_id[order_id]
self._order_by_id.pop(order_id)
assert cancelled_order.order_book_id in self._order_books, f"Order book {cancelled_order.order_book_id} not found"
order_book = self._order_books[cancelled_order.order_book_id]
if cancelled_order.side == Side.BUY:
order_book[cancelled_order.price].buy_quantity -= cancelled_order.quantity
else:
order_book[cancelled_order.price].sell_quantity -= cancelled_order.quantity
self._report_order_book_changed(cancelled_order.order_book_id, cancelled_order_message.cancellation_timestamp)
def on_trade(self, trade_message: OnTrade) -> None:
# Assuming that the trade always comes after the order insertion
assert trade_message.order_book_id in self._pending_trade_ids_by_order_book, f"Order book {trade_message.order_book_id} not found"
pending_trade_ids = self._pending_trade_ids_by_order_book[trade_message.order_book_id]
pending_trade_ids.remove(trade_message.trade_id)
# Reduce quantities for both buy and sell sides at the trade price
buy_order = self._order_by_id[trade_message.buy_order_id]
buy_order.quantity -= trade_message.quantity
if buy_order.quantity == 0:
self._order_by_id.pop(trade_message.buy_order_id)
sell_order = self._order_by_id[trade_message.sell_order_id]
sell_order.quantity -= trade_message.quantity
if sell_order.quantity == 0:
self._order_by_id.pop(trade_message.sell_order_id)
trade_price = Decimal(str(trade_message.price))
order_book = self._order_books[trade_message.order_book_id]
order_book[trade_price].buy_quantity -= trade_message.quantity
order_book[trade_price].sell_quantity -= trade_message.quantity
self._report_order_book_changed(trade_message.order_book_id, trade_message.timestamp)
def _is_order_book_consistent(self, order_book_id: int) -> bool:
awaiting_trade_ids = self._pending_trade_ids_by_order_book[order_book_id]
return len(awaiting_trade_ids) == 0
def _report_order_book_created(self, order_book: OrderBook) -> None:
if not self._order_book_listeners:
logger.warning("No order book listeners set, skipping update")
return
for listener in self._order_book_listeners:
listener.on_order_book_created(order_book.order_book_id, float(order_book.tick_size))  # listener interface takes a float tick size
def _report_order_book_changed(self, order_book_id: int, change_timestamp: int) -> None:
if not self._order_book_listeners:
logger.warning("No order book listeners set, skipping update")
return
if not self._is_order_book_consistent(order_book_id):
logger.debug(f"Order book {order_book_id} not consistent yet, skipping update")
return
for listener in self._order_book_listeners:
listener.on_order_book_changed(order_book_id, change_timestamp)
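A standalone sketch (editor's note; the bare `{price: (buy_qty, sell_qty)}` mapping is an assumption, simplified from the `OrderBookLevel` structure) of the best-bid/best-ask scan that `get_top_of_book` performs:

```python
from decimal import Decimal

# Levels keyed by price; tuples are (buy_quantity, sell_quantity).
levels: dict[Decimal, tuple[int, int]] = {
    Decimal("99"): (5, 0),
    Decimal("100"): (3, 0),
    Decimal("101"): (0, 2),
}

# Best bid is the highest price with resting buy quantity...
best_bid = max((p for p, (buy, _) in levels.items() if buy > 0), default=None)
# ...and best ask is the lowest price with resting sell quantity.
best_ask = min((p for p, (_, sell) in levels.items() if sell > 0), default=None)

assert best_bid == Decimal("100")
assert best_ask == Decimal("101")
```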


@ -0,0 +1,117 @@
from collections import defaultdict
from dataclasses import dataclass
from decimal import Decimal
import logging
from common.client_handler import BaseClientHandler
from proto.common_pb2 import MessageType
from proto.info_pb2 import OnTopOfBook, OnPriceDepthBook, SubscriptionType, PriceLevel as ProtoPriceLevel
from common.models import Side
from info.i_order_book_listener import IOrderBookListener
from info.instruments_manager import InstrumentsManager
from info.order_book_client_aggregator import OrderBookClientAggregator
logger = logging.getLogger(__name__)
@dataclass
class Subscription:
client: BaseClientHandler
subscription_type: SubscriptionType.ValueType
def __eq__(self, other):
if not isinstance(other, Subscription):
return NotImplemented
return self.client is other.client
def __hash__(self):
return hash(id(self.client))
class SubscriptionsManager(IOrderBookListener):
def __init__(self, instruments_manager: InstrumentsManager, order_book_aggregator: OrderBookClientAggregator) -> None:
self.instruments_manager: InstrumentsManager = instruments_manager
self.order_book_aggregator: OrderBookClientAggregator = order_book_aggregator
self.subscriptions_by_order_book: dict[int, set[Subscription]] = defaultdict(set)
self._last_change_timestamp_by_order_book: dict[int, int] = {}
# TODO cache last top-of-book and price-depth-book updates, compare before sending to clients
def add_client_subscription(self, client_handler: BaseClientHandler, instrument_symbol: str,
subscription_type: SubscriptionType.ValueType) -> None:
assert client_handler.logged_in, "Client must be logged in to subscribe"
logger.info(f"Client {client_handler.username} subscribed to {instrument_symbol} for {SubscriptionType.Name(subscription_type)}")
order_book_id = self.instruments_manager.get_instrument_by_symbol(instrument_symbol).order_book_id
self.subscriptions_by_order_book[order_book_id].add(Subscription(client_handler, subscription_type))
last_change_timestamp = self._last_change_timestamp_by_order_book.get(order_book_id)
if last_change_timestamp is None:
return
logger.debug(f"Sending initial update to {client_handler.username}")
if subscription_type == SubscriptionType.TOP_OF_BOOK:
top_of_book_update = self._make_top_of_book_update(order_book_id, last_change_timestamp)
client_handler.send_message(MessageType.INFO_ON_TOP_OF_BOOK, top_of_book_update)
elif subscription_type == SubscriptionType.PRICE_DEPTH_BOOK:
price_depth_book = self._make_price_depth_book_update(order_book_id, last_change_timestamp)
client_handler.send_message(MessageType.INFO_ON_PRICE_DEPTH_BOOK, price_depth_book)
def remove_client_subscriptions(self, client_handler: BaseClientHandler) -> None:
logger.info(f"Removing all subscriptions for {client_handler.username}")
for subscriptions in self.subscriptions_by_order_book.values():
# equality ignores subscription_type, so any value works for the removal probe
subscriptions.discard(Subscription(client_handler, SubscriptionType.TOP_OF_BOOK))
def on_order_book_created(self, order_book_id: int, tick_size: float) -> None:
pass
def on_order_book_changed(self, order_book_id: int, change_timestamp: int) -> None:
logger.info(f"Order book {order_book_id} changed at {change_timestamp}")
self._last_change_timestamp_by_order_book[order_book_id] = change_timestamp
top_of_book_update = self._make_top_of_book_update(order_book_id, change_timestamp)
price_depth_book = self._make_price_depth_book_update(order_book_id, change_timestamp)
subscriptions = self.subscriptions_by_order_book[order_book_id]
logger.debug(f"Sending updates to {len(subscriptions)} clients")
for subscription in subscriptions:
if subscription.subscription_type == SubscriptionType.TOP_OF_BOOK:
subscription.client.send_message(MessageType.INFO_ON_TOP_OF_BOOK, top_of_book_update)
elif subscription.subscription_type == SubscriptionType.PRICE_DEPTH_BOOK:
subscription.client.send_message(MessageType.INFO_ON_PRICE_DEPTH_BOOK, price_depth_book)
logger.debug("All updates sent")
def _make_top_of_book_update(self, order_book_id: int, change_timestamp: int) -> OnTopOfBook:
instrument = self.instruments_manager.get_instrument_by_order_book_id(order_book_id)
top_of_book = self.order_book_aggregator.get_top_of_book(order_book_id)
best_bid = top_of_book[Side.BUY].to_proto() if Side.BUY in top_of_book else None
best_ask = top_of_book[Side.SELL].to_proto() if Side.SELL in top_of_book else None
return OnTopOfBook(
instrument_symbol=instrument.symbol,
timestamp=change_timestamp,
best_bid=best_bid,
best_ask=best_ask
)
def _make_price_depth_book_update(self, order_book_id: int, change_timestamp: int) -> OnPriceDepthBook:
instrument = self.instruments_manager.get_instrument_by_order_book_id(order_book_id)
price_depth_book = self.order_book_aggregator.get_price_depth_book(order_book_id)
bids = []
asks = []
for price, level in price_depth_book.items():
if level.buy_quantity > 0:
bids.append(self._make_proto_price_level(price, level.buy_quantity))
if level.sell_quantity > 0:
asks.append(self._make_proto_price_level(price, level.sell_quantity))
return OnPriceDepthBook(
instrument_symbol=instrument.symbol,
timestamp=change_timestamp,
bids=bids,
asks=asks
)
def _make_proto_price_level(self, price: Decimal, quantity: int) -> ProtoPriceLevel:
return ProtoPriceLevel(
quantity=quantity,
price=float(price)
)
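Editor's sketch of why `Subscription` hashes and compares on client identity alone: one `discard` with any `subscription_type` removes the client's entry, which is what `remove_client_subscriptions` relies on. Class and field names below are illustrative stand-ins, not the commit's types.

```python
class Client:
    pass

class Subscription:
    def __init__(self, client, subscription_type):
        self.client = client
        self.subscription_type = subscription_type

    def __eq__(self, other):
        if not isinstance(other, Subscription):
            return NotImplemented
        return self.client is other.client  # identity, not value

    def __hash__(self):
        return hash(id(self.client))

client = Client()
subs = {Subscription(client, "TOP_OF_BOOK")}
# A probe with a different subscription_type still matches by client identity.
subs.discard(Subscription(client, "PRICE_DEPTH_BOOK"))
assert not subs
```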


@ -0,0 +1,204 @@
import itertools
from typing import Callable
from datetime import datetime
import socket
import logging
from common.models import Side
from common.utils import decimal_from_float
from connection.ip_address import IpAddress
from connection.connection_handler import ConnectionHandler
from order_book.i_clients_manager import IClientsManager, create_order_inserted_message
from order_book.matching_engine import OrderBook
from order_book.models import ExecutionReport, Order
from order_book.order_book_manager import OrderBookManager
from proto.order_book_pb2 import * # noqa: F403
from proto.order_book_pb2 import CreateOrderBookRequest, CreateOrderBookResponse, DeleteOrderBookRequest, DeleteOrderBookResponse
from proto.common_pb2 import MessageType
logger = logging.getLogger(__name__)
class OrderBookClientHandler(ConnectionHandler):
def __init__(self, socket_fd: socket.socket, ip_address: IpAddress, close_callback: Callable[[], None],
client_manager: IClientsManager, order_book_manager: OrderBookManager) -> None:
super().__init__(socket_fd, ip_address, close_callback)
self.client_manager = client_manager
self.order_book_manager = order_book_manager
def on_disconnect(self) -> None:
logger.info(f"Client {self.ip_address} disconnected")
# TODO cancel on disconnect
def handle_message(self, message_type: int, message: bytes) -> None:
logger.info(f"Handling message of type {MessageType.Name(message_type)} from {self.ip_address}") # type: ignore
if message_type == MessageType.ORDER_BOOK_CREATE_REQUEST:
self._handle_create_order_book_request(message)
elif message_type == MessageType.ORDER_BOOK_DELETE_REQUEST:
self._handle_delete_order_book_request(message)
elif message_type == MessageType.ORDER_BOOK_INSERT_ORDER_REQUEST:
self._handle_insert_order_request(message)
elif message_type == MessageType.ORDER_BOOK_CANCEL_ORDER_REQUEST:
self._handle_cancel_order_request(message)
else:
raise ValueError(f"Unexpected message type received: {message_type}")
def publish_full_state(self) -> None:
logger.info(f"Publishing full state to client {self.ip_address}")
for order_book in self.order_book_manager.order_book_map.values():
on_order_book_created = OnOrderBookCreated(
order_book_id=order_book.order_book_id, tick_size=float(order_book.tick_size))
self.send_message(MessageType.ORDER_BOOK_ON_ORDER_BOOK_CREATED, on_order_book_created)
for order in itertools.chain(order_book.bids, order_book.asks):
on_order_inserted = create_order_inserted_message(order, order_book.order_book_id)
self.send_message(MessageType.ORDER_BOOK_ON_ORDER_INSERTED, on_order_inserted)
def _handle_create_order_book_request(self, message: bytes) -> None:
response = CreateOrderBookResponse()
order_book: OrderBook | None = None
try:
request = self._deserialize_message(CreateOrderBookRequest, message)
response.request_id = request.request_id
if request.request_id is None:
raise ValueError("Request ID must be set")
if request.tick_size is None:
raise ValueError("Tick size must be set")
if request.tick_size <= 0:
raise ValueError(f"Tick size must be positive: {request.tick_size}")
logger.info(f"Creating order book with tick size {request.tick_size}")
order_book = self.order_book_manager.create_order_book(request.tick_size)
assert order_book is not None
response.order_book_id = order_book.order_book_id
logger.info(f"Order book created with id {order_book.order_book_id}")
except Exception as e:
logger.debug("Error creating order book", exc_info=True)
response.error_message = str(e)
raise
finally:
logger.debug(f"Sending create order book response: {response}")
self.send_message(MessageType.ORDER_BOOK_CREATE_RESPONSE, response)
if order_book is not None:
self.client_manager.publish_order_book_created(order_book.order_book_id, float(order_book.tick_size))
def _handle_delete_order_book_request(self, message: bytes) -> None:
response = DeleteOrderBookResponse()
try:
request = self._deserialize_message(DeleteOrderBookRequest, message)
response.request_id = request.request_id
if request.request_id is None:
raise ValueError("Request ID must be set")
if request.order_book_id is None:
raise ValueError("Order book ID must be set")
logger.info(f"Deleting order book {request.order_book_id}")
self.order_book_manager.delete_order_book(request.order_book_id)
logger.info(f"Order book {request.order_book_id} deleted")
except Exception as e:
logger.debug("Error deleting order book", exc_info=True)
response.error_message = str(e)
raise
finally:
logger.debug(f"Sending delete order book response: {response}")
self.send_message(MessageType.ORDER_BOOK_DELETE_RESPONSE, response)
def _handle_insert_order_request(self, message: bytes) -> None:
request_timestamp = datetime.now()
response = InsertOrderResponse()
order_id: int | None = None
execution_reports: list[ExecutionReport] = []
try:
request = self._deserialize_message(InsertOrderRequest, message)
response.request_id = request.request_id
if request.request_id is None:
raise ValueError("Request ID must be set")
if request.order_book_id is None:
raise ValueError("Order book ID must be set")
if request.side is None:
raise ValueError("Side must be set")
if request.price is None:
raise ValueError("Price must be set")
if request.quantity is None:
raise ValueError("Quantity must be set")
if request.quantity <= 0:
raise ValueError(f"Quantity must be positive: {request.quantity}")
if not request.username:
raise ValueError("Username must be set")
price_decimal = decimal_from_float(request.price)
if price_decimal <= 0:
raise ValueError(f"Price must be positive: {price_decimal}")
order_book = self.order_book_manager.get_order_book(request.order_book_id)
if price_decimal % order_book.tick_size != 0:
raise ValueError(f"Price {price_decimal} must be a multiple of tick size {order_book.tick_size}")
order_side = Side.from_proto(request.side)
order = Order(order_id=None, timestamp=request_timestamp, side=order_side, price=price_decimal,
original_quantity=request.quantity, username=request.username)
logger.info(f"Inserting order into order book {request.order_book_id}: {order}")
order_id = order_book.add_order(order)
logger.info(f"Order added with id {order_id}")
logger.debug(f"Matching orders in order book {request.order_book_id}")
execution_reports = order_book.match_orders(aggressor_side=order_side)
logger.info(f"Executed {len(execution_reports)} trades")
assert set(report.trade_id for report in execution_reports) == set(order.aggressive_trade_ids)
response.order_id = order_id
response.timestamp = int(request_timestamp.timestamp())
response.trade_ids.extend(order.aggressive_trade_ids)
response.traded_quantity = order.executed_quantity
except Exception as e:
logger.debug("Error inserting order", exc_info=True)
response.error_message = str(e)
finally:
logger.debug(f"Sending insert order response: {response}")
self.send_message(MessageType.ORDER_BOOK_INSERT_ORDER_RESPONSE, response)
if response.error_message:
assert execution_reports == [] and order_id is None, "Order should not have been inserted if there was an error"
return
logger.debug("Publishing updates to connected clients")
self.client_manager.publish_order_inserted(order, request.order_book_id)
self.client_manager.publish_trades(execution_reports, request.order_book_id)
def _handle_cancel_order_request(self, message: bytes) -> None:
request_timestamp = datetime.now()
response = CancelOrderResponse()
try:
request = self._deserialize_message(CancelOrderRequest, message)
response.request_id = request.request_id
if request.request_id is None:
raise ValueError("Request ID must be set")
if request.order_book_id is None:
raise ValueError("Order book ID must be set")
if request.order_id is None:
raise ValueError("Order ID must be set")
order_book = self.order_book_manager.get_order_book(request.order_book_id)
order = order_book.get_order(request.order_id)
if order is None:
raise ValueError(f"Order with ID {request.order_id} not found")
remaining_quantity = order.open_quantity
logger.info(f"Cancelling order in order book {request.order_book_id}: {request.order_id}")
order_book.cancel_order(request.order_id)
logger.info(f"Order cancelled")
response.cancellation_timestamp = int(request_timestamp.timestamp())
response.remaining_quantity = remaining_quantity
except Exception as e:
logger.debug("Error cancelling order", exc_info=True)
response.error_message = str(e)
finally:
logger.debug(f"Sending cancel order response: {response}")
self.send_message(MessageType.ORDER_BOOK_CANCEL_ORDER_RESPONSE, response)
if response.error_message:
return
logger.debug("Publishing updates to connected clients")
self.client_manager.publish_order_cancelled(request.order_id, request_timestamp)


@ -0,0 +1,62 @@
from datetime import datetime
from typing import Callable
import socket
import logging
from connection.ip_address import IpAddress
from connection.connection_handler import ConnectionHandlerFactory
from order_book.client_handler import OrderBookClientHandler
from order_book.i_clients_manager import IClientsManager, create_order_inserted_message
from order_book.models import ExecutionReport, Order
from order_book.order_book_manager import OrderBookManager
from proto.order_book_pb2 import *
from proto.common_pb2 import MessageType
from google.protobuf.message import Message
logger = logging.getLogger(__name__)
class ClientsManager(ConnectionHandlerFactory[OrderBookClientHandler], IClientsManager):
def __init__(self, order_book_manager: OrderBookManager) -> None:
self.order_book_manager = order_book_manager
self.connected_clients: dict[socket.socket, OrderBookClientHandler] = {}
def on_new_connection(self, socket_fd: socket.socket, ip_address: IpAddress,
close_callback: Callable[[], None]) -> OrderBookClientHandler:
client_handler = OrderBookClientHandler(socket_fd, ip_address, close_callback, self, self.order_book_manager)
assert socket_fd not in self.connected_clients
self.connected_clients[socket_fd] = client_handler
client_handler.publish_full_state()
return client_handler
def on_connection_closed(self, connection_handler: OrderBookClientHandler) -> None:
assert connection_handler.socket_fd in self.connected_clients
del self.connected_clients[connection_handler.socket_fd]
def publish_order_book_created(self, order_book_id: int, tick_size: float) -> None:
message = OnOrderBookCreated(order_book_id=order_book_id, tick_size=tick_size)
self._publish_update_to_all_clients(MessageType.ORDER_BOOK_ON_ORDER_BOOK_CREATED, message)
def publish_order_inserted(self, order: Order, order_book_id: int) -> None:
order_inserted_message = create_order_inserted_message(order, order_book_id)
self._publish_update_to_all_clients(MessageType.ORDER_BOOK_ON_ORDER_INSERTED, order_inserted_message)
def publish_order_cancelled(self, order_id: int, cancellation_timestamp: datetime) -> None:
order_cancelled_message = OnOrderCancelled(order_id=order_id, cancellation_timestamp=int(cancellation_timestamp.timestamp()))
self._publish_update_to_all_clients(MessageType.ORDER_BOOK_ON_ORDER_CANCELLED, order_cancelled_message)
def publish_trades(self, trades: list[ExecutionReport], order_book_id: int) -> None:
for execution_report in trades:
trade_message = OnTrade(
trade_id=execution_report.trade_id, order_book_id=order_book_id,
timestamp=int(execution_report.timestamp.timestamp()),
buy_order_id=execution_report.buy_order_id, sell_order_id=execution_report.sell_order_id,
price=float(execution_report.price), quantity=execution_report.quantity,
aggressor_side=execution_report.aggressor_side.to_proto())
self._publish_update_to_all_clients(MessageType.ORDER_BOOK_ON_TRADE, trade_message)
def _publish_update_to_all_clients(self, message_type: MessageType.ValueType, update_message: Message) -> None:
logger.info(f"Publishing update message of type {MessageType.Name(message_type)} to {len(self.connected_clients)} clients")
for client in self.connected_clients.values():
client.send_message(int(message_type), update_message)


@ -0,0 +1,32 @@
from abc import ABC, abstractmethod
from datetime import datetime
from order_book.models import ExecutionReport, Order
from proto.order_book_pb2 import *
class IClientsManager(ABC):
@abstractmethod
def publish_order_book_created(self, order_book_id: int, tick_size: float) -> None:
pass
@abstractmethod
def publish_order_inserted(self, order: Order, order_book_id: int) -> None:
pass
@abstractmethod
def publish_order_cancelled(self, order_id: int, cancellation_timestamp: datetime) -> None:
pass
@abstractmethod
def publish_trades(self, trades: list[ExecutionReport], order_book_id: int) -> None:
pass
def create_order_inserted_message(order: Order, order_book_id: int) -> OnOrderInserted:
assert order.order_id is not None
proto_side = order.side.to_proto()
return OnOrderInserted(
order_id=order.order_id, order_book_id=order_book_id,
side=proto_side, price=float(order.price), quantity=order.original_quantity,
username=order.username, trade_ids=order.aggressive_trade_ids)
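`IClientsManager` follows the same pattern the repo's `BaseApplication` uses: an `ABC` whose `@abstractmethod`s define the publish interface, with `ClientsManager` as the concrete implementation. A minimal self-contained sketch of that pattern (the `Publisher`/`LoggingPublisher` names here are illustrative, not part of the repo):

```python
from abc import ABC, abstractmethod

class Publisher(ABC):
    """Illustrative stand-in for an IClientsManager-style interface."""
    @abstractmethod
    def publish(self, payload: str) -> None: ...

class LoggingPublisher(Publisher):
    def __init__(self) -> None:
        self.seen: list[str] = []
    def publish(self, payload: str) -> None:
        self.seen.append(payload)

# An ABC cannot be instantiated until every abstract method is implemented.
try:
    Publisher()  # type: ignore[abstract]
except TypeError as exc:
    error = str(exc)

pub = LoggingPublisher()
pub.publish("on_trade")
print(pub.seen)  # ['on_trade']
```

This keeps the order-book side of the code (e.g. `client_handler.py`) depending only on the abstract publish methods, so the networking-heavy `ClientsManager` can be swapped out in tests.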


@ -1,24 +1,25 @@
 import logging
-from pathlib import Path
+from importlib.resources import files as resource_files
 from application.application import BaseApplication
 from connection.ip_address import IpAddress
 from connection.tcp_connection_manager import TcpConnectionManager
-from sample_app.connection_handler import PingPongClientHandlerFactory
+from order_book.client_handler import OrderBookManager
+from order_book.clients_manager import ClientsManager
 logger = logging.getLogger(__name__)
-class SampleApplication(BaseApplication):
+class OrderBookApplication(BaseApplication):
     def _start(self) -> None:
-        logger.info("Starting the sample application...")
-        connection_handler_factory = PingPongClientHandlerFactory()
-        tcp_connection_manager = TcpConnectionManager()
+        order_book_manager = OrderBookManager()
+        clients_manager = ClientsManager(order_book_manager=order_book_manager)
         server_ip_address = IpAddress(
             host=self._config["listenOn"]["host"],
             port=self._config["listenOn"]["port"])
         logger.info(f"Starting server on {server_ip_address}")
-        with tcp_connection_manager.listen(server_ip_address, connection_handler_factory):
+        tcp_connection_manager = TcpConnectionManager()
+        with tcp_connection_manager.listen(server_ip_address, clients_manager):
             logger.info("Server started.")
             logger.info("Running event loop until interrupted.")
             while True:
@ -26,8 +27,8 @@ class SampleApplication(BaseApplication):
 def main() -> None:
-    config_schema_path = Path(__file__).parent.parent / "application" / "config_schema.json"
-    app = SampleApplication(config_schema=config_schema_path, app_name="sample_app")
+    config_schema_path = resource_files("application").joinpath("config_schema.json")
+    app = OrderBookApplication(config_schema=config_schema_path, app_name="order_book")
     app.run()


@ -0,0 +1,99 @@
import heapq
import logging
from datetime import datetime
from decimal import Decimal
from common.models import Side
from order_book.models import ExecutionReport, Order
logger = logging.getLogger(__name__)
class OrderBook:
def __init__(self, order_book_id: int, tick_size: Decimal) -> None:
self.order_book_id = order_book_id
self.tick_size = tick_size
self.bids: list[Order] = []
self.asks: list[Order] = []
self.order_map: dict[int, Order] = {}
self._next_order_id: int = 1
self._next_trade_id: int = 1
def get_order(self, order_id: int) -> Order | None:
return self.order_map.get(order_id)
def add_order(self, order: Order) -> int:
if order.price % self.tick_size != 0:
raise ValueError(f"Price {order.price} must be a multiple of tick size {self.tick_size}")
order.order_id = self._next_order_id
logger.debug(f"Adding order: {order}")
self._next_order_id += 1
self.order_map[order.order_id] = order
orders_heap = self.bids if order.side == Side.BUY else self.asks
heapq.heappush(orders_heap, order)
logger.info(f"Order added: {order}")
return order.order_id
def cancel_order(self, order_id: int) -> None:
logger.debug(f"Cancelling order: {order_id}")
order = self.order_map.pop(order_id, None)
if order is None:
raise ValueError(f"Order {order_id} not found")
orders_heap = self.bids if order.side == Side.BUY else self.asks
orders_heap.remove(order)
heapq.heapify(orders_heap)
logger.info(f"Order cancelled: {order_id}")
def match_orders(self, aggressor_side: Side) -> list[ExecutionReport]:
logger.debug("Matching orders...")
execution_reports: list[ExecutionReport] = []
while self.bids and self.asks:
best_bid: Order = self.bids[0]
best_ask: Order = self.asks[0]
assert best_bid.order_id is not None
assert best_ask.order_id is not None
logger.debug(f"Best bid: {best_bid.open_quantity} @ {best_bid.price}"
f", Best ask: {best_ask.open_quantity} @ {best_ask.price}")
if best_bid.price >= best_ask.price:
trade_id = self._next_trade_id
self._next_trade_id += 1
trade_quantity = min(best_bid.open_quantity, best_ask.open_quantity)
# TODO execution price might be an interesting puzzle for the group.
if aggressor_side == Side.BUY:
trade_price = best_ask.price
best_bid.aggressive_trade_ids.append(trade_id)
else:
trade_price = best_bid.price
best_ask.aggressive_trade_ids.append(trade_id)
execution_report = ExecutionReport(
trade_id=trade_id,
buy_order_id=best_bid.order_id,
sell_order_id=best_ask.order_id,
price=trade_price,
quantity=trade_quantity,
timestamp=datetime.now(),
aggressor_side=aggressor_side
)
execution_reports.append(execution_report)
logger.info(f"Trade executed: {trade_quantity} @ {trade_price}"
            f" between {best_bid.order_id} and {best_ask.order_id}")
best_bid.executed_quantity += trade_quantity
if best_bid.open_quantity == 0:
logger.info(f"Removing fully executed buy order {best_bid.order_id}")
heapq.heappop(self.bids)
best_ask.executed_quantity += trade_quantity
if best_ask.open_quantity == 0:
logger.info(f"Removing fully executed sell order {best_ask.order_id}")
heapq.heappop(self.asks)
else:
break
logger.debug(f"Done matching orders. {len(execution_reports)} trades executed.")
return execution_reports
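The two heaps above rely on `Order.__lt__` (defined in `order_book/models.py`) negating buy prices, so `heapq`'s min-heap surfaces the highest bid and the lowest ask, with arrival order breaking ties. A self-contained sketch of that price-time priority using plain tuples in place of the real `Order` class:

```python
import heapq

# Stand-in for Order.__get_sorting_tuple: negate buy prices so the best
# (highest) bid pops first from a min-heap; asks sort naturally so the
# lowest pops first. Ties at the same price fall back to arrival order.
def sort_key(side: str, price: float, seq: int) -> tuple:
    multiplier = -1 if side == "BUY" else 1
    return (multiplier * price, seq)

bids: list[tuple] = []
asks: list[tuple] = []
orders = [("BUY", 101.0), ("BUY", 103.0), ("BUY", 103.0),
          ("SELL", 105.0), ("SELL", 102.0)]
for seq, (side, price) in enumerate(orders):
    heap = bids if side == "BUY" else asks
    heapq.heappush(heap, (*sort_key(side, price, seq), price))

best_bid = bids[0][-1]  # highest bid price
best_ask = asks[0][-1]  # lowest ask price
print(best_bid, best_ask)  # 103.0 102.0
```

The book crosses exactly when `best_bid >= best_ask`, which is the loop condition `match_orders` checks above before generating a trade.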


@ -0,0 +1,49 @@
from dataclasses import dataclass, field
from datetime import datetime
from decimal import Decimal
from common.models import Side
@dataclass
class Order:
order_id: int | None
username: str
timestamp: datetime
side: Side
price: Decimal
original_quantity: int
aggressive_trade_ids: list[int] = field(default_factory=list)
executed_quantity: int = 0
@property
def open_quantity(self) -> int:
return self.original_quantity - self.executed_quantity
def __eq__(self, other: object) -> bool:
if not isinstance(other, Order):
return False
if self.order_id is None or other.order_id is None:
raise ValueError("Order ID must be set before comparing")
return self.order_id == other.order_id
def __lt__(self, other: 'Order') -> bool:
    # Returning NotImplemented lets Python fall back to the reflected comparison
    # instead of silently treating non-Order operands as greater.
    if not isinstance(other, Order):
        return NotImplemented
    return self.__get_sorting_tuple() < other.__get_sorting_tuple()
def __get_sorting_tuple(self) -> tuple:
# We want to sort buy orders in descending order and sell orders in ascending order of price
price_multiplier = -1 if self.side == Side.BUY else 1
return (price_multiplier * self.price, self.timestamp, self.order_id)
@dataclass
class ExecutionReport:
trade_id: int
buy_order_id: int
sell_order_id: int
price: Decimal
quantity: int
timestamp: datetime
aggressor_side: Side


@ -0,0 +1,8 @@
{
"logLevel": "DEBUG",
"logDirectory": "./logs",
"listenOn": {
"host": "localhost",
"port": 51364
}
}


@ -0,0 +1,27 @@
from decimal import Decimal
from order_book.matching_engine import OrderBook
class OrderBookManager:
def __init__(self) -> None:
self.order_book_map: dict[int, OrderBook] = {}
self._next_order_book_id = 1
def create_order_book(self, tick_size: float) -> OrderBook:
order_book_id = self._next_order_book_id
self._next_order_book_id += 1
new_order_book = OrderBook(order_book_id, Decimal(str(tick_size)))
self.order_book_map[order_book_id] = new_order_book
return new_order_book
def delete_order_book(self, order_book_id: int) -> None:
if order_book_id not in self.order_book_map:
raise ValueError(f"Order book {order_book_id} not found")
del self.order_book_map[order_book_id]
def get_order_book(self, order_book_id: int) -> OrderBook:
if order_book_id not in self.order_book_map:
raise ValueError(f"Order book {order_book_id} not found")
return self.order_book_map[order_book_id]
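`create_order_book` converts the incoming `float` tick size via `Decimal(str(tick_size))` rather than `Decimal(tick_size)`. Going through `str` matters: constructing a `Decimal` directly from a binary float preserves the float's representation error, which would make the `price % tick_size` multiple-of-tick check in the matching engine reject valid prices. A small illustration:

```python
from decimal import Decimal

# Decimal(float) captures the exact binary value, which is NOT 0.1:
direct = Decimal(0.1)
via_str = Decimal(str(0.1))
print(direct)   # 0.1000000000000000055511151231257827021181583404541015625
print(via_str)  # 0.1

# With the str-based tick size, a price of 0.3 is a clean multiple of the tick:
price = Decimal("0.3")
print(price % via_str)  # 0.0
# Against the raw-float tick it is not, so the order would be rejected.
print(price % direct == 0)  # False
```

The same reasoning applies to `decimal_from_float`, which the handlers use on incoming order prices.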


@ -0,0 +1,270 @@
import logging
import socket
from typing import Callable
from common.client_handler import BaseClientHandler
from application.data_file_reader import DataFileReader
from common.models import Side
from common.utils import decimal_from_float
from connection.connection_handler import ConnectionHandlerFactory
from connection.ip_address import IpAddress
from proto.common_pb2 import LoginRequest
from proto.common_pb2 import MessageType
from proto.order_book_pb2 import InsertOrderResponse as OrderBookInsertOrderResponse, CancelOrderResponse as OrderBookCancelOrderResponse
from proto.execution_pb2 import (
CancelOrderRequest, CancelOrderResponse, InsertOrderRequest, InsertOrderResponse
)
from proto.risk_limits_pb2 import (
GetInstrumentRiskLimitsResponse, GetUserRiskLimitsResponse,
GetUserRiskLimitsRequest, SetInstrumentRiskLimitsResponse, SetUserRiskLimitsRequest,
GetInstrumentRiskLimitsRequest, SetInstrumentRiskLimitsRequest, SetUserRiskLimitsResponse
)
from risk_gateway.order_book_client import OrderBookClient
from risk_gateway.risk_limits_store import InstrumentRiskLimits, RiskLimitsStore, UserRiskLimits
logger = logging.getLogger(__name__)
class GatewayClientConnectionHandler(BaseClientHandler):
def __init__(self, socket_fd: socket.socket, ip_address: IpAddress, close_callback: Callable[[], None],
risk_limits_store: RiskLimitsStore, order_book_client: OrderBookClient,
data_file_reader: DataFileReader) -> None:
super().__init__(socket_fd, ip_address, close_callback, data_file_reader)
self.risk_limits_store = risk_limits_store
self.order_book_client = order_book_client
def on_disconnect(self) -> None:
logger.info(f"Client {self.ip_address} disconnected")
def handle_message(self, message_type: int, message: bytes) -> None:
logger.info(f"Handling message of type {MessageType.Name(message_type)} by {self.username}") # type: ignore
if message_type == MessageType.AUTH_LOGIN_REQUEST:
self._handle_login_request(message)
elif message_type == MessageType.EXEC_INSERT_ORDER_REQUEST:
self._handle_insert_order_request(message)
elif message_type == MessageType.EXEC_CANCEL_ORDER_REQUEST:
self._handle_cancel_order_request(message)
elif message_type == MessageType.RISK_GET_USER_LIMITS_REQUEST:
self._handle_get_user_risk_limits_request(message)
elif message_type == MessageType.RISK_SET_USER_LIMITS_REQUEST:
self._handle_set_user_risk_limits_request(message)
elif message_type == MessageType.RISK_GET_INSTRUMENT_LIMITS_REQUEST:
self._handle_get_instrument_risk_limits_request(message)
elif message_type == MessageType.RISK_SET_INSTRUMENT_LIMITS_REQUEST:
self._handle_set_instrument_risk_limits_request(message)
else:
raise ValueError(f"Unexpected message type received: {message_type}")
def _handle_login_request(self, message: bytes) -> None:
request = self._deserialize_message(LoginRequest, message)
response = self._login(request)
logger.debug(f"Sending login response: {response}")
self.send_message(MessageType.AUTH_LOGIN_RESPONSE, response)
if response.error_message:
raise ValueError("Login failed, client will be disconnected")
def _handle_insert_order_request(self, message: bytes) -> None:
try:
request = self._deserialize_message(InsertOrderRequest, message)
logger.debug(f"Received insert order request: {request}")
if not self.logged_in:
raise ValueError("Client is not logged in")
if request.request_id is None:
raise ValueError("Request ID must be set")
if request.instrument_symbol is None:
raise ValueError("Instrument symbol must be set")
if request.side is None:
raise ValueError("Side must be set")
if request.price is None:
raise ValueError("Price must be set")
if request.quantity is None:
raise ValueError("Quantity must be set")
request_side = Side.from_proto(request.side)
assert self.username is not None
price_decimal = decimal_from_float(request.price)
limit_error = self.risk_limits_store.check_order_limits(
self.username, request.instrument_symbol, price_decimal, request.quantity)
if limit_error:
logger.info(f"Order rejected by risk limits: {limit_error}")
error_response = InsertOrderResponse(
request_id=request.request_id, error_message=limit_error)
self.send_message(MessageType.EXEC_INSERT_ORDER_RESPONSE, error_response)
return
self.risk_limits_store.record_order_attempt(
self.username, request.instrument_symbol, price_decimal, request.quantity)
client_request_id = request.request_id
def on_response(ob_response: OrderBookInsertOrderResponse) -> None:
logger.debug(f"Received insert order response from order book: {ob_response}")
exec_response = InsertOrderResponse()
exec_response.request_id = client_request_id
exec_response.error_message = ob_response.error_message
exec_response.order_id = ob_response.order_id
exec_response.timestamp = ob_response.timestamp
exec_response.trade_ids.extend(ob_response.trade_ids)
exec_response.traded_quantity = ob_response.traded_quantity
self.send_message(MessageType.EXEC_INSERT_ORDER_RESPONSE, exec_response)
self.order_book_client.insert_order_for_user(
request.instrument_symbol, self.username, request_side, price_decimal, request.quantity, on_response)
except Exception as e:
logger.debug("Error inserting order", exc_info=True)
error_response = InsertOrderResponse(
request_id=request.request_id, error_message=str(e))
logger.debug(f"Sending insert order response: {error_response}")
self.send_message(MessageType.EXEC_INSERT_ORDER_RESPONSE, error_response)
def _handle_cancel_order_request(self, message: bytes) -> None:
try:
request = self._deserialize_message(CancelOrderRequest, message)
logger.debug(f"Received cancel order request: {request}")
if not self.logged_in:
raise ValueError("Client is not logged in")
if request.request_id is None:
raise ValueError("Request ID must be set")
if request.instrument_symbol is None:
raise ValueError("Instrument symbol must be set")
if request.order_id is None:
raise ValueError("Order ID must be set")
assert self.username is not None
client_request_id = request.request_id
def on_response(ob_response: OrderBookCancelOrderResponse) -> None:
logger.debug(f"Received cancel order response from order book: {ob_response}")
exec_response = CancelOrderResponse()
exec_response.request_id = client_request_id
exec_response.error_message = ob_response.error_message
self.send_message(MessageType.EXEC_CANCEL_ORDER_RESPONSE, exec_response)
self.order_book_client.cancel_order_for_user(request.instrument_symbol, self.username, request.order_id, on_response)
except Exception as e:
logger.debug("Error cancelling order", exc_info=True)
error_response = CancelOrderResponse(
request_id=request.request_id, error_message=str(e))
logger.debug(f"Sending cancel order response: {error_response}")
self.send_message(MessageType.EXEC_CANCEL_ORDER_RESPONSE, error_response)
def _handle_get_user_risk_limits_request(self, message: bytes) -> None:
response = GetUserRiskLimitsResponse()
try:
request = self._deserialize_message(GetUserRiskLimitsRequest, message)
logger.debug(f"Received get user risk limits request: {request}")
response.request_id = request.request_id
if not self.logged_in:
raise ValueError("Client is not logged in")
if request.request_id is None:
raise ValueError("Request ID must be set")
assert self.username is not None
user_risk_limits = self.risk_limits_store.get_user_risk_limits(self.username)
if user_risk_limits is None:
response.error_message = "User risk limits not found"
else:
logger.debug(f"User risk limits: {user_risk_limits}")
response.user_risk_limits.CopyFrom(user_risk_limits.to_proto())
except Exception as e:
logger.debug("Error getting user risk limits", exc_info=True)
response.error_message = str(e)
finally:
logger.debug(f"Sending get user risk limits response: {response}")
self.send_message(MessageType.RISK_GET_USER_LIMITS_RESPONSE, response)
def _handle_set_user_risk_limits_request(self, message: bytes) -> None:
response = SetUserRiskLimitsResponse()
try:
request = self._deserialize_message(SetUserRiskLimitsRequest, message)
logger.debug(f"Received set user risk limits request: {request}")
response.request_id = request.request_id
if not self.logged_in:
raise ValueError("Client is not logged in")
if request.request_id is None:
raise ValueError("Request ID must be set")
assert self.username is not None
# proto3 message fields are never None; use HasField to detect an absent sub-message.
if not request.HasField("user_risk_limits"):
    self.risk_limits_store.clear_user_risk_limits(self.username)
else:
    user_risk_limits = UserRiskLimits.from_proto(request.user_risk_limits)
    assert user_risk_limits is not None
    self.risk_limits_store.set_user_risk_limits(self.username, user_risk_limits)
except Exception as e:
logger.debug("Error setting user risk limits", exc_info=True)
response.error_message = str(e)
finally:
logger.debug(f"Sending set user risk limits response: {response}")
self.send_message(MessageType.RISK_SET_USER_LIMITS_RESPONSE, response)
def _handle_get_instrument_risk_limits_request(self, message: bytes) -> None:
response = GetInstrumentRiskLimitsResponse()
try:
request = self._deserialize_message(GetInstrumentRiskLimitsRequest, message)
logger.debug(f"Received get instrument risk limits request: {request}")
response.request_id = request.request_id
if not self.logged_in:
raise ValueError("Client is not logged in")
if request.request_id is None:
raise ValueError("Request ID must be set")
assert self.username is not None
instrument_risk_limits = self.risk_limits_store.get_all_instrument_risk_limits(self.username)
logger.debug(f"Returning {len(instrument_risk_limits)} instrument risk limits")
for instrument_symbol, risk_limits in instrument_risk_limits.items():
response.risk_limits_by_instrument[instrument_symbol].CopyFrom(risk_limits.to_proto())
except Exception as e:
logger.debug("Error getting instrument risk limits", exc_info=True)
response.error_message = str(e)
finally:
logger.debug(f"Sending get instrument risk limits response: {response}")
self.send_message(MessageType.RISK_GET_INSTRUMENT_LIMITS_RESPONSE, response)
def _handle_set_instrument_risk_limits_request(self, message: bytes) -> None:
response = SetInstrumentRiskLimitsResponse()
try:
request = self._deserialize_message(SetInstrumentRiskLimitsRequest, message)
logger.info(f"Received set instrument risk limits request: {request}")
response.request_id = request.request_id
if not self.logged_in:
raise ValueError("Client is not logged in")
if request.request_id is None:
raise ValueError("Request ID must be set")
if request.instrument_symbol is None:
raise ValueError("Instrument symbol must be set")
assert self.username is not None
# proto3 message fields are never None; use HasField to detect an absent sub-message.
if not request.HasField("instrument_risk_limits"):
    self.risk_limits_store.clear_instrument_risk_limits(request.instrument_symbol)
else:
    risk_limits = InstrumentRiskLimits.from_proto(request.instrument_risk_limits)
    assert risk_limits is not None
    self.risk_limits_store.set_instrument_risk_limits(self.username, request.instrument_symbol, risk_limits)
except Exception as e:
logger.debug("Error setting instrument risk limits", exc_info=True)
response.error_message = str(e)
finally:
logger.debug(f"Sending set instrument risk limits response: {response}")
self.send_message(MessageType.RISK_SET_INSTRUMENT_LIMITS_RESPONSE, response)
class GatewayClientConnectionHandlerFactory(ConnectionHandlerFactory[GatewayClientConnectionHandler]):
def __init__(self, risk_limits_store: RiskLimitsStore, order_book_client: OrderBookClient,
data_file_reader: DataFileReader) -> None:
self.risk_limits_store = risk_limits_store
self.order_book_client = order_book_client
self.data_file_reader = data_file_reader
def on_new_connection(self, socket_fd: socket.socket, ip_address: IpAddress,
close_callback: Callable[[], None]) -> GatewayClientConnectionHandler:
return GatewayClientConnectionHandler(
socket_fd, ip_address, close_callback,
self.risk_limits_store, self.order_book_client, self.data_file_reader)
def on_connection_closed(self, connection_handler: GatewayClientConnectionHandler) -> None:
pass
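The insert and cancel handlers above forward each client request to the order book and register a per-request callback; `client_request_id` is captured in a closure so the eventual order-book response can be stamped with the originating client's request ID even though responses arrive asynchronously. A stripped-down sketch of that correlation pattern (the `UpstreamClient` class and FIFO delivery here are illustrative assumptions, not the repo's API):

```python
from typing import Callable

# Hypothetical stand-in for OrderBookClient: callbacks are queued per upstream
# request and fired in order as responses come back over the connection.
class UpstreamClient:
    def __init__(self) -> None:
        self._pending: list[Callable[[dict], None]] = []

    def send(self, request: dict, on_response: Callable[[dict], None]) -> None:
        self._pending.append(on_response)

    def deliver(self, response: dict) -> None:
        # Responses arrive FIFO relative to requests on this connection.
        self._pending.pop(0)(response)

sent: list[dict] = []
upstream = UpstreamClient()

def handle_client_request(request_id: int) -> None:
    client_request_id = request_id  # captured by the closure below

    def on_response(ob_response: dict) -> None:
        # Re-stamp the upstream response with the originating client's ID.
        sent.append({"request_id": client_request_id, **ob_response})

    upstream.send({"qty": 5}, on_response)

handle_client_request(41)
handle_client_request(42)
upstream.deliver({"order_id": 7})
upstream.deliver({"order_id": 8})
print(sent)  # [{'request_id': 41, 'order_id': 7}, {'request_id': 42, 'order_id': 8}]
```

Because each `on_response` closes over its own `client_request_id`, no shared mutable state is needed to match responses to clients.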


@ -0,0 +1,39 @@
import socket
from typing import Callable
from common.info_client import BaseInfoClient
from connection.ip_address import IpAddress
from proto.common_pb2 import LoginRequest, LoginResponse
from proto.info_pb2 import OnInstrument, OnPriceDepthBook, OnTopOfBook, OnTrade
class InfoClientMapping(BaseInfoClient):
def __init__(self, socket_fd: socket.socket, ip_address: IpAddress, on_close: Callable[[], None]) -> None:
super().__init__(socket_fd, ip_address, on_close)
self.symbol_to_order_book_id: dict[str, int] = {}
self.order_book_id_to_symbol: dict[int, str] = {}
def login(self, on_login_response: Callable[[LoginResponse], None]) -> None:
request = LoginRequest(username="risk_gateway", password="risk_gateway")
self.send_login(request, on_login_response)
def get_order_book_id_by_instrument(self, symbol: str) -> int | None:
return self.symbol_to_order_book_id.get(symbol)
def get_instrument_symbol_by_order_book_id(self, order_book_id: int) -> str | None:
return self.order_book_id_to_symbol.get(order_book_id)
def on_instrument(self, message: OnInstrument) -> None:
symbol = message.instrument.symbol
order_book_id = message.order_book_id
assert symbol not in self.symbol_to_order_book_id
self.symbol_to_order_book_id[symbol] = order_book_id
self.order_book_id_to_symbol[order_book_id] = symbol
def on_top_of_book(self, message: OnTopOfBook) -> None:
pass
def on_price_depth_book(self, message: OnPriceDepthBook) -> None:
pass
def on_trade(self, message: OnTrade) -> None:
pass


@ -0,0 +1,67 @@
import logging
from importlib.resources import files as resource_files
from application.application import BaseApplication
from application.data_file_reader import DataFileReader
from connection.ip_address import IpAddress
from connection.tcp_connection_manager import TcpConnectionManager
from proto.common_pb2 import LoginResponse
from risk_gateway.connection_handler import GatewayClientConnectionHandlerFactory
from risk_gateway.info_client_mapping import InfoClientMapping
from risk_gateway.order_book_client import OrderBookClient
from risk_gateway.risk_limits_store import RiskLimitsStore
logger = logging.getLogger(__name__)
class RiskGatewayApplication(BaseApplication):
def _start(self) -> None:
tcp_connection_manager = TcpConnectionManager()
data_file_reader = DataFileReader(self._config["dataFilePath"])
logger.info("Connecting to info server...")
info_service_ip_address = IpAddress(
host=self._config["connectTo"]["info"]["host"],
port=self._config["connectTo"]["info"]["port"])
info_client: InfoClientMapping = tcp_connection_manager.connect(
    info_service_ip_address, InfoClientMapping)
info_client.login(self._on_info_login_response)
risk_limits_store = RiskLimitsStore()
logger.info("Connecting to order book server...")
order_book_service_ip_address = IpAddress(
host=self._config["connectTo"]["order_book"]["host"],
port=self._config["connectTo"]["order_book"]["port"])
order_book_client: OrderBookClient = tcp_connection_manager.connect(
order_book_service_ip_address, lambda *args: OrderBookClient(*args, info_client, risk_limits_store))
connection_handler_factory = GatewayClientConnectionHandlerFactory(
risk_limits_store=risk_limits_store,
order_book_client=order_book_client,
data_file_reader=data_file_reader)
server_ip_address = IpAddress(
host=self._config["listenOn"]["host"],
port=self._config["listenOn"]["port"])
logger.info(f"Starting server on {server_ip_address}")
with tcp_connection_manager.listen(server_ip_address, connection_handler_factory):
logger.info("Server started.")
logger.info("Running event loop until interrupted.")
while True:
tcp_connection_manager.wait_for_events()
def _on_info_login_response(self, response: LoginResponse) -> None:
if response.error_message:
logger.error(f"Failed to log in to info server: {response}")
raise SystemExit(1)
logger.info(f"Logged in to info server: {response}")
def main():
config_schema_path = resource_files("application").joinpath("config_schema.json")
app = RiskGatewayApplication(config_schema=config_schema_path, app_name="risk_gateway")
app.run()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,81 @@
from decimal import Decimal
import logging
import socket
from typing import Callable
from common.models import Side
from common.order_book_client import BaseOrderBookClient
from common.utils import decimal_from_float
from connection.ip_address import IpAddress
from proto.order_book_pb2 import (
InsertOrderRequest, InsertOrderResponse,
CancelOrderRequest, CancelOrderResponse,
OnOrderBookCreated, OnOrderCancelled, OnOrderInserted, OnTrade,
)
from risk_gateway.info_client_mapping import InfoClientMapping
from risk_gateway.risk_limits_store import RiskLimitsStore
logger = logging.getLogger(__name__)
class OrderBookClient(BaseOrderBookClient):
def __init__(self, socket_fd: socket.socket, ip_address: IpAddress, on_close: Callable[[], None],
info_client: InfoClientMapping, risk_limits_store: RiskLimitsStore) -> None:
super().__init__(socket_fd, ip_address, on_close)
self.info_client = info_client
self.risk_limits_store = risk_limits_store
self.username_by_order_id: dict[int, str] = {}
def insert_order_for_user(self, instrument_symbol: str, username: str, side: Side, price: Decimal, quantity: int, callback: Callable[[InsertOrderResponse], None]) -> int:
order_book_id = self.info_client.get_order_book_id_by_instrument(instrument_symbol)
if order_book_id is None:
raise ValueError(f"Order book for instrument {instrument_symbol} not found")
request = InsertOrderRequest(order_book_id=order_book_id, username=username, side=side.to_proto(), price=float(price), quantity=quantity)
def on_insert_order_response(response: InsertOrderResponse) -> None:
            # Proto3 scalars are never None; an unset order_id reads as 0.
            if response.order_id:
                self.username_by_order_id[response.order_id] = username
callback(response)
return self.send_insert_order(request, on_insert_order_response)
def cancel_order_for_user(self, instrument_symbol: str, username: str, order_id: int, callback: Callable[[CancelOrderResponse], None]) -> None:
if order_id not in self.username_by_order_id:
raise ValueError(f"Order {order_id} not found")
if self.username_by_order_id[order_id] != username:
raise ValueError(f"Order {order_id} does not belong to user {username}")
order_book_id = self.info_client.get_order_book_id_by_instrument(instrument_symbol)
if order_book_id is None:
raise ValueError(f"Order book for instrument {instrument_symbol} not found")
request = CancelOrderRequest(order_book_id=order_book_id, order_id=order_id)
def on_cancel_order_response(response: CancelOrderResponse) -> None:
            if not response.error_message:
                self.username_by_order_id.pop(order_id, None)
callback(response)
self.send_cancel_order(request, on_cancel_order_response)
def on_order_book_created(self, message: OnOrderBookCreated) -> None:
logger.debug(f"Order book created: id={message.order_book_id}")
def on_order_inserted(self, message: OnOrderInserted) -> None:
instrument_symbol = self.info_client.get_instrument_symbol_by_order_book_id(message.order_book_id)
if instrument_symbol is None:
logger.warning(f"Unknown order_book_id {message.order_book_id} in OnOrderInserted")
return
logger.debug(f"Order inserted: id={message.order_id}, instrument={instrument_symbol}, "
f"user={message.username}, price={message.price}, qty={message.quantity}")
self.risk_limits_store.track_order(
message.order_id, message.username, instrument_symbol,
decimal_from_float(message.price), message.quantity)
def on_order_cancelled(self, message: OnOrderCancelled) -> None:
logger.debug(f"Order cancelled: id={message.order_id}")
self.risk_limits_store.untrack_order(message.order_id)
def on_trade(self, message: OnTrade) -> None:
logger.debug(f"Trade: id={message.trade_id}, buy_order={message.buy_order_id}, "
f"sell_order={message.sell_order_id}, qty={message.quantity}")
self.risk_limits_store.reduce_order_quantity(message.buy_order_id, message.quantity)
self.risk_limits_store.reduce_order_quantity(message.sell_order_id, message.quantity)
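
`on_trade` decrements the outstanding quantity on both sides of a fill. A toy version of that bookkeeping, with made-up order ids and sizes, shows why a fully filled order must also stop being tracked:

```python
# Minimal stand-in for the tracked-order bookkeeping: a trade reduces the
# remaining quantity of both sides, and a fully filled order is untracked.
remaining_by_order_id = {7: 10, 8: 4}

def reduce_order_quantity(order_id: int, traded_quantity: int) -> None:
    remaining = remaining_by_order_id.get(order_id)
    if remaining is None:
        return  # unknown order, e.g. inserted before this gateway connected
    remaining -= min(traded_quantity, remaining)
    if remaining <= 0:
        del remaining_by_order_id[order_id]
    else:
        remaining_by_order_id[order_id] = remaining

# A 4-lot trade between buy order 7 and sell order 8:
reduce_order_quantity(7, 4)
reduce_order_quantity(8, 4)
assert remaining_by_order_id == {7: 6}  # order 8 fully filled and untracked
```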

View File

@ -3,16 +3,16 @@
"logDirectory": "./logs", "logDirectory": "./logs",
"listenOn": { "listenOn": {
"host": "localhost", "host": "localhost",
"port": 51301 "port": 51366
}, },
"connectTo": { "connectTo": {
"client1": { "order_book": {
"host": "localhost", "host": "localhost",
"port": 51302 "port": 51364
}, },
"client2": { "info": {
"host": "localhost", "host": "localhost",
"port": 51303 "port": 51365
} }
} }
} }

View File

@ -0,0 +1,290 @@
from decimal import Decimal
import logging
import time
from dataclasses import dataclass
from typing import NamedTuple, Optional
from common.utils import decimal_from_float
from proto.risk_limits_pb2 import (
RollingWindowLimit as ProtoRollingWindowLimit, InstrumentRiskLimits as ProtoInstrumentRiskLimits,
UserRiskLimits as ProtoUserRiskLimits
)
logger = logging.getLogger(__name__)
@dataclass
class RollingWindowLimit:
limit: int
window_in_seconds: int
def to_proto(self) -> ProtoRollingWindowLimit:
return ProtoRollingWindowLimit(limit=self.limit, window_in_seconds=self.window_in_seconds)
@staticmethod
def from_proto(proto: ProtoRollingWindowLimit) -> Optional['RollingWindowLimit']:
if proto is None:
return None
return RollingWindowLimit(limit=proto.limit, window_in_seconds=proto.window_in_seconds)
@dataclass
class InstrumentRiskLimits:
max_outstanding_quantity: int | None
max_outstanding_amount: Decimal | None
order_quantity_rolling_limit: RollingWindowLimit | None
order_amount_rolling_limit: RollingWindowLimit | None
def to_proto(self) -> ProtoInstrumentRiskLimits:
return ProtoInstrumentRiskLimits(
max_outstanding_quantity=self.max_outstanding_quantity, # type: ignore
max_outstanding_amount=float(self.max_outstanding_amount) if self.max_outstanding_amount is not None else None,
order_quantity_rolling_limit=self.order_quantity_rolling_limit.to_proto() if self.order_quantity_rolling_limit is not None else None,
order_amount_rolling_limit=self.order_amount_rolling_limit.to_proto() if self.order_amount_rolling_limit is not None else None
)
@staticmethod
def from_proto(proto: ProtoInstrumentRiskLimits | None) -> Optional['InstrumentRiskLimits']:
if proto is None:
return None
return InstrumentRiskLimits(
max_outstanding_quantity=proto.max_outstanding_quantity,
max_outstanding_amount=decimal_from_float(proto.max_outstanding_amount),
order_quantity_rolling_limit=RollingWindowLimit.from_proto(proto.order_quantity_rolling_limit),
order_amount_rolling_limit=RollingWindowLimit.from_proto(proto.order_amount_rolling_limit)
)
@dataclass
class UserRiskLimits:
max_outstanding_quantity: int | None
message_rate_rolling_limit: RollingWindowLimit | None
def to_proto(self) -> ProtoUserRiskLimits:
return ProtoUserRiskLimits(
max_outstanding_quantity=self.max_outstanding_quantity, # type: ignore
message_rate_rolling_limit=self.message_rate_rolling_limit.to_proto() if self.message_rate_rolling_limit is not None else None
)
@staticmethod
def from_proto(proto: ProtoUserRiskLimits | None) -> Optional['UserRiskLimits']:
if proto is None:
return None
return UserRiskLimits(
max_outstanding_quantity=proto.max_outstanding_quantity,
message_rate_rolling_limit=RollingWindowLimit.from_proto(proto.message_rate_rolling_limit)
)
@dataclass
class RiskLimits:
user_risk_limits: UserRiskLimits | None
instrument_risk_limits_by_symbol: dict[str, InstrumentRiskLimits]
@dataclass
class TrackedOrder:
order_id: int
username: str
instrument_symbol: str
price: Decimal
remaining_quantity: int
@dataclass
class TimestampedValue:
timestamp: float
value: Decimal
class UserInstrumentKey(NamedTuple):
username: str
instrument_symbol: str
class RiskLimitsStore:
def __init__(self) -> None:
self.limits_by_username: dict[str, RiskLimits] = {}
self.tracked_orders: dict[int, TrackedOrder] = {}
self._user_outstanding_qty: dict[str, int] = {}
self._instrument_outstanding_qty: dict[UserInstrumentKey, int] = {}
self._instrument_outstanding_amt: dict[UserInstrumentKey, Decimal] = {}
self.message_timestamps: dict[str, list[float]] = {}
self.order_quantity_events: dict[UserInstrumentKey, list[TimestampedValue]] = {}
self.order_amount_events: dict[UserInstrumentKey, list[TimestampedValue]] = {}
# ---- Limit configuration ----
def get_user_risk_limits(self, username: str) -> UserRiskLimits | None:
risk_limits = self.limits_by_username.get(username)
if risk_limits is None:
return None
return risk_limits.user_risk_limits
def set_user_risk_limits(self, username: str, risk_limits: UserRiskLimits) -> None:
if username not in self.limits_by_username:
self.limits_by_username[username] = RiskLimits(user_risk_limits=risk_limits, instrument_risk_limits_by_symbol={})
else:
self.limits_by_username[username].user_risk_limits = risk_limits
def clear_user_risk_limits(self, username: str) -> None:
risk_limits = self.limits_by_username.get(username)
if risk_limits is None:
return
risk_limits.user_risk_limits = None
def get_all_instrument_risk_limits(self, username: str) -> dict[str, InstrumentRiskLimits]:
risk_limits = self.limits_by_username.get(username)
if risk_limits is None:
return {}
return risk_limits.instrument_risk_limits_by_symbol
def get_instrument_risk_limits(self, username: str, instrument_symbol: str) -> InstrumentRiskLimits | None:
risk_limits = self.limits_by_username.get(username)
if risk_limits is None:
return None
return risk_limits.instrument_risk_limits_by_symbol.get(instrument_symbol)
def set_instrument_risk_limits(self, username: str, instrument_symbol: str, risk_limits: InstrumentRiskLimits) -> None:
if username not in self.limits_by_username:
self.limits_by_username[username] = RiskLimits(user_risk_limits=None, instrument_risk_limits_by_symbol={})
self.limits_by_username[username].instrument_risk_limits_by_symbol[instrument_symbol] = risk_limits
def clear_instrument_risk_limits(self, username: str, instrument_symbol: str) -> None:
risk_limits = self.limits_by_username.get(username)
if risk_limits is None:
return
risk_limits.instrument_risk_limits_by_symbol.pop(instrument_symbol, None)
# ---- Live order tracking ----
def track_order(self, order_id: int, username: str, instrument_symbol: str, price: Decimal, quantity: int) -> None:
self.tracked_orders[order_id] = TrackedOrder(
order_id=order_id, username=username,
instrument_symbol=instrument_symbol, price=price,
remaining_quantity=quantity)
key = UserInstrumentKey(username, instrument_symbol)
self._user_outstanding_qty[username] = self._user_outstanding_qty.get(username, 0) + quantity
self._instrument_outstanding_qty[key] = self._instrument_outstanding_qty.get(key, 0) + quantity
        self._instrument_outstanding_amt[key] = self._instrument_outstanding_amt.get(key, Decimal(0)) + price * quantity
def untrack_order(self, order_id: int) -> None:
order = self.tracked_orders.pop(order_id, None)
if order is None:
return
self._adjust_totals(order, -order.remaining_quantity)
def reduce_order_quantity(self, order_id: int, traded_quantity: int) -> None:
order = self.tracked_orders.get(order_id)
if order is None:
return
reduction = min(traded_quantity, order.remaining_quantity)
self._adjust_totals(order, -reduction)
order.remaining_quantity -= reduction
if order.remaining_quantity <= 0:
del self.tracked_orders[order_id]
def _adjust_totals(self, order: TrackedOrder, qty_delta: int) -> None:
self._user_outstanding_qty[order.username] = self._user_outstanding_qty.get(order.username, 0) + qty_delta
key = UserInstrumentKey(order.username, order.instrument_symbol)
self._instrument_outstanding_qty[key] = self._instrument_outstanding_qty.get(key, 0) + qty_delta
        self._instrument_outstanding_amt[key] = self._instrument_outstanding_amt.get(key, Decimal(0)) + order.price * qty_delta
# ---- Limit checking ----
def check_order_limits(self, username: str, instrument_symbol: str, price: Decimal, quantity: int) -> str | None:
"""Returns an error message if any limit would be breached, None otherwise."""
now = time.time()
key = UserInstrumentKey(username, instrument_symbol)
user_limits = self.get_user_risk_limits(username)
if user_limits is not None:
if user_limits.max_outstanding_quantity and user_limits.max_outstanding_quantity > 0:
current = self._user_outstanding_qty.get(username, 0)
if current + quantity > user_limits.max_outstanding_quantity:
return (f"User max outstanding quantity breached: "
f"{current} + {quantity} > {user_limits.max_outstanding_quantity}")
rl = user_limits.message_rate_rolling_limit
if rl and rl.limit > 0 and rl.window_in_seconds > 0:
count = self._count_in_window(
self.message_timestamps.get(username, []), rl.window_in_seconds, now)
if count + 1 > rl.limit:
return (f"User message rate limit breached: "
f"{count} + 1 > {rl.limit} in {rl.window_in_seconds}s window")
instrument_limits = self.get_instrument_risk_limits(username, instrument_symbol)
if instrument_limits is not None:
if instrument_limits.max_outstanding_quantity and instrument_limits.max_outstanding_quantity > 0:
current = self._instrument_outstanding_qty.get(key, 0)
if current + quantity > instrument_limits.max_outstanding_quantity:
return (f"Instrument max outstanding quantity breached: "
f"{current} + {quantity} > {instrument_limits.max_outstanding_quantity}")
if instrument_limits.max_outstanding_amount and instrument_limits.max_outstanding_amount > 0:
                current = self._instrument_outstanding_amt.get(key, Decimal(0))
new_amount = price * quantity
if current + new_amount > instrument_limits.max_outstanding_amount:
return (f"Instrument max outstanding amount breached: "
f"{current} + {new_amount} > {instrument_limits.max_outstanding_amount}")
rl = instrument_limits.order_quantity_rolling_limit
if rl and rl.limit > 0 and rl.window_in_seconds > 0:
current = self._sum_in_window(
self.order_quantity_events.get(key, []), rl.window_in_seconds, now)
if current + quantity > rl.limit:
return (f"Instrument order quantity rolling limit breached: "
f"{current} + {quantity} > {rl.limit} in {rl.window_in_seconds}s window")
rl = instrument_limits.order_amount_rolling_limit
if rl and rl.limit > 0 and rl.window_in_seconds > 0:
current = self._sum_in_window(
self.order_amount_events.get(key, []), rl.window_in_seconds, now)
new_amount = price * quantity
if current + new_amount > rl.limit:
return (f"Instrument order amount rolling limit breached: "
f"{current} + {new_amount} > {rl.limit} in {rl.window_in_seconds}s window")
return None
def record_order_attempt(self, username: str, instrument_symbol: str, price: Decimal, quantity: int) -> None:
"""Record rolling-window events after an order passes limit checks."""
now = time.time()
self.message_timestamps.setdefault(username, []).append(now)
key = UserInstrumentKey(username, instrument_symbol)
self.order_quantity_events.setdefault(key, []).append(TimestampedValue(now, Decimal(quantity)))
self.order_amount_events.setdefault(key, []).append(TimestampedValue(now, price * quantity))
# ---- Internal helpers ----
@staticmethod
def _count_in_window(timestamps: list[float], window_seconds: int, now: float) -> int:
cutoff = now - window_seconds
first_valid = 0
for i, t in enumerate(timestamps):
if t >= cutoff:
first_valid = i
break
else:
timestamps.clear()
return 0
if first_valid > 0:
del timestamps[:first_valid]
return len(timestamps)
@staticmethod
def _sum_in_window(events: list[TimestampedValue], window_seconds: int, now: float) -> Decimal:
cutoff = now - window_seconds
first_valid = 0
for i, e in enumerate(events):
if e.timestamp >= cutoff:
first_valid = i
break
else:
events.clear()
return Decimal(0)
if first_valid > 0:
del events[:first_valid]
        return sum((e.value for e in events), Decimal(0))
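
The window helpers prune expired entries in place so the per-user lists stay bounded. The same pruning logic, extracted into a runnable sketch:

```python
def count_in_window(timestamps: list[float], window_seconds: int, now: float) -> int:
    # Keep only entries at or after the window start; mutates the list in place.
    cutoff = now - window_seconds
    first_valid = 0
    for i, t in enumerate(timestamps):
        if t >= cutoff:
            first_valid = i
            break
    else:
        timestamps.clear()  # everything expired
        return 0
    if first_valid > 0:
        del timestamps[:first_valid]
    return len(timestamps)

ts = [1.0, 2.0, 9.5, 9.9]
assert count_in_window(ts, 1, 10.0) == 2  # only 9.5 and 9.9 fall in [9.0, 10.0]
assert ts == [9.5, 9.9]                   # expired entries were pruned in place
```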

View File

@ -0,0 +1,52 @@
"""Reusable authentication tests for components with ``authRequired=true``.
Any test class that includes :class:`AuthenticationTests` as a mixin must:
1. Implement :meth:`_connect_unauthenticated` returning a client with a
``test_login(username, password, expect_success)`` method.
2. Have ``self.auth_required``, ``self.test_name``, and
``self.call_expectations_manager`` available (provided by the standard
``conftest.py`` fixtures and per-class ``setup``).
Tests are automatically skipped when ``self.auth_required`` is ``False``.
"""
import pytest
from tests.conftest import DEFAULT_PASSWORD
class AuthenticationTests:
def _connect_unauthenticated(self):
"""Connect to the component under test *without* logging in."""
raise NotImplementedError
def test_login_with_valid_credentials(self) -> None:
if not self.auth_required:
pytest.skip("Component does not require authentication")
client = self._connect_unauthenticated()
client.test_login(username=self.test_name, password=DEFAULT_PASSWORD)
self.call_expectations_manager.verify_no_unexpected_calls()
def test_login_with_invalid_username(self) -> None:
if not self.auth_required:
pytest.skip("Component does not require authentication")
client = self._connect_unauthenticated()
client.test_login(username="invalid_username", expect_success=False)
self.call_expectations_manager.verify_no_unexpected_calls()
def test_login_with_wrong_password(self) -> None:
if not self.auth_required:
pytest.skip("Component does not require authentication")
client = self._connect_unauthenticated()
client.test_login(username=self.test_name, password="wrong_password", expect_success=False)
self.call_expectations_manager.verify_no_unexpected_calls()
def test_login_twice_on_same_connection_is_rejected(self) -> None:
if not self.auth_required:
pytest.skip("Component does not require authentication")
client = self._connect_unauthenticated()
client.test_login(username=self.test_name, password=DEFAULT_PASSWORD)
client.test_login(username=self.test_name, expect_success=False)
self.call_expectations_manager.verify_no_unexpected_calls()
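
The mixin pattern above can be exercised without any network at all. A toy illustration with a fake client standing in for a real connection (all names below are invented for the sketch, not part of the test suite):

```python
class AuthChecks:
    """Reusable checks; concrete classes supply _connect_unauthenticated."""

    def _connect_unauthenticated(self):
        raise NotImplementedError

    def check_rejects_wrong_password(self) -> bool:
        client = self._connect_unauthenticated()
        return client.login("alice", "wrong_password") is False


class FakeClient:
    def login(self, username: str, password: str) -> bool:
        return username == "alice" and password == "secret"


class FakeAuthTests(AuthChecks):
    def _connect_unauthenticated(self) -> FakeClient:
        return FakeClient()


assert FakeAuthTests().check_rejects_wrong_password()
```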

View File

@ -0,0 +1,265 @@
"""Generic component orchestrator for system tests.
Reads deployment_config.json and testing_dependencies.json to determine which
components to start for a given protocol, resolves their startup order via
topological sort on connectTo references, assigns dynamic ports, and manages
process lifecycles.
"""
import json
import logging
import socket
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from connection.ip_address import IpAddress
from tests.common.process_manager import PerformanceStats, ProcessManager
from tests.common import constants
logger = logging.getLogger(__name__)
COMPONENT_DEPENDENCIES_FILE = Path(__file__).parent.parent / "continuous_deployment" / "testing_dependencies.json"
def _allocate_ports(count: int) -> list[int]:
"""Reserve `count` ephemeral ports from the OS, minimising collision risk."""
sockets: list[socket.socket] = []
ports: list[int] = []
try:
for _ in range(count):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(("localhost", 0))
ports.append(s.getsockname()[1])
sockets.append(s)
finally:
for s in sockets:
s.close()
return ports
@dataclass
class ComponentInstance:
name: str
package_name: str
protocols: list[str]
original_config: dict[str, Any]
auth_required: bool
address: IpAddress | None = None
process_manager: ProcessManager | None = None
class ComponentOrchestrator:
"""Manages the lifecycle of all components needed to test a specific protocol.
Typical usage inside a pytest fixture::
orchestrator = ComponentOrchestrator(venv_path, deployment_config, users_data_file)
orchestrator.start_for_protocol("order_book")
server_addr = orchestrator.get_server_address("order_book")
yield
orchestrator.stop_all()
"""
def __init__(self, venv_path: Path, deployment_config: dict[str, Any],
data_file_path: Path) -> None:
self.venv_path = venv_path
self.deployment_config = deployment_config
self.data_file_path = data_file_path
self.component_dependencies = self._load_component_dependencies()
self._running_components: list[ComponentInstance] = []
self._protocol_to_address: dict[str, IpAddress] = {}
@staticmethod
def _load_component_dependencies() -> dict[str, Any]:
with open(COMPONENT_DEPENDENCIES_FILE) as f:
return json.load(f)
def start_for_protocol(self, protocol: str) -> None:
"""Resolve dependencies and start every component needed to test *protocol*."""
required_protocols = self._resolve_required_protocols(protocol)
logger.info(f"Protocol '{protocol}' requires protocols: {required_protocols}")
components = self._select_components(required_protocols)
ordered = self._topological_sort(components)
logger.info(f"Component startup order: {[c.name for c in ordered]}")
ports = _allocate_ports(len(ordered))
port_assignments: dict[str, int] = {
comp.name: port for comp, port in zip(ordered, ports)
}
for comp in ordered:
config = self._build_runtime_config(comp, port_assignments)
address = IpAddress(host="localhost", port=port_assignments[comp.name])
pm = ProcessManager(self.venv_path)
assert pm.start_process(comp.package_name, config), (
f"Failed to start component '{comp.name}' (binary: {comp.package_name})"
)
startup_ms = pm.wait_until_server_is_ready(address, timeout_in_seconds=2)
logger.info(
f"Component '{comp.name}' ready in {startup_ms:.2f} ms "
f"on port {address.port}"
)
comp.address = address
comp.process_manager = pm
self._running_components.append(comp)
for p in comp.protocols:
self._protocol_to_address[p] = address
def get_server_address(self, protocol: str) -> IpAddress:
"""Return the listen address of the component implementing *protocol*."""
if protocol not in self._protocol_to_address:
raise ValueError(f"No running component implements protocol '{protocol}'")
return self._protocol_to_address[protocol]
def is_auth_required(self, protocol: str) -> bool:
"""Return whether the component implementing *protocol* requires authentication."""
for comp in self._running_components:
if protocol in comp.protocols:
return comp.auth_required
raise ValueError(f"No running component implements protocol '{protocol}'")
def stop_all(self) -> dict[str, PerformanceStats]:
"""Stop all running components in reverse startup order.
Returns a mapping of component name to performance statistics.
"""
perf_stats: dict[str, PerformanceStats] = {}
for comp in reversed(self._running_components):
if comp.process_manager:
try:
comp.process_manager.stop_process()
if comp.process_manager.performance_stats:
perf_stats[comp.name] = comp.process_manager.performance_stats
except Exception:
logger.exception(f"Error stopping component '{comp.name}'")
self._running_components.clear()
self._protocol_to_address.clear()
return perf_stats
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _resolve_required_protocols(self, protocol: str) -> set[str]:
"""Transitively collect all *required* protocols for *protocol*."""
result: set[str] = {protocol}
queue: deque[str] = deque([protocol])
while queue:
current = queue.popleft()
deps = self.component_dependencies.get(current, [])
for dep in deps:
if dep not in result:
result.add(dep)
queue.append(dep)
return result
def _select_components(self, required_protocols: set[str]) -> list[ComponentInstance]:
"""Pick the minimal set of components from the deployment config that
cover *required_protocols*, plus any transitive ``connectTo``
dependencies. A single component may cover several protocols."""
comp_cfg_by_name: dict[str, dict[str, Any]] = {
cfg["name"]: cfg for cfg in self.deployment_config["components"]
}
selected_names: set[str] = set()
for comp_cfg in self.deployment_config["components"]:
if set(comp_cfg["protocols"]) & required_protocols:
selected_names.add(comp_cfg["name"])
queue: deque[str] = deque(selected_names)
while queue:
name = queue.popleft()
cfg = comp_cfg_by_name.get(name)
assert cfg is not None, f"Component '{name}' not defined in deployment config"
for target in cfg["config"].get("connectTo", {}):
if target not in selected_names:
assert target in comp_cfg_by_name, f"Component '{target}' not defined in deployment config"
selected_names.add(target)
queue.append(target)
components: list[ComponentInstance] = []
covered: set[str] = set()
for comp_cfg in self.deployment_config["components"]:
if comp_cfg["name"] in selected_names:
comp_protocols = set(comp_cfg["protocols"])
components.append(ComponentInstance(
name=comp_cfg["name"],
package_name=comp_cfg["packageName"],
protocols=comp_protocols,
original_config=comp_cfg["config"],
auth_required=comp_cfg.get("authRequired", True),
))
covered |= comp_protocols
missing = required_protocols - covered
if missing:
raise ValueError(
f"Protocols {missing} are required but not implemented by "
f"any component in deployment_config.json"
)
return components
def _topological_sort(self, components: list[ComponentInstance]) -> list[ComponentInstance]:
"""Sort components so that each component's connectTo targets are
started before the component itself (Kahn's algorithm)."""
name_set = {c.name for c in components}
by_name = {c.name: c for c in components}
# adj[A] = {B, ...} means A depends on (connects to) B
adj: dict[str, set[str]] = {c.name: set() for c in components}
for comp in components:
for target in comp.original_config.get("connectTo", {}):
if target in name_set:
adj[comp.name].add(target)
in_deg = {name: len(deps) for name, deps in adj.items()}
queue: deque[str] = deque(n for n, d in in_deg.items() if d == 0)
result: list[ComponentInstance] = []
while queue:
name = queue.popleft()
result.append(by_name[name])
for other, deps in adj.items():
if name in deps:
in_deg[other] -= 1
if in_deg[other] == 0:
queue.append(other)
if len(result) != len(components):
raise ValueError("Circular dependency detected among components")
return result
def _build_runtime_config(
self, comp: ComponentInstance, port_assignments: dict[str, int]
) -> dict[str, Any]:
"""Clone the component's config with dynamic ports and the test log dir."""
config: dict[str, Any] = dict(comp.original_config)
config["logDirectory"] = constants.LOG_DIRECTORY
config["listenOn"] = {
"host": "localhost",
"port": port_assignments[comp.name],
}
config["dataFilePath"] = str(self.data_file_path)
if "connectTo" in config:
rewritten: dict[str, Any] = {}
for target_name, target_cfg in config["connectTo"].items():
if target_name in port_assignments:
rewritten[target_name] = {
"host": "localhost",
"port": port_assignments[target_name],
}
else:
rewritten[target_name] = target_cfg
config["connectTo"] = rewritten
return config
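
`_topological_sort` above is a standard Kahn's algorithm over the connectTo graph. A worked example on a hypothetical three-component deployment (component names invented):

```python
from collections import deque

# comp -> set of components it connects to, and therefore must start after.
connect_to = {
    "risk_gateway": {"order_book", "info"},
    "order_book": {"info"},
    "info": set(),
}
in_deg = {name: len(deps) for name, deps in connect_to.items()}
queue = deque(name for name, deg in in_deg.items() if deg == 0)
startup_order = []
while queue:
    name = queue.popleft()
    startup_order.append(name)
    for other, deps in connect_to.items():
        if name in deps:
            in_deg[other] -= 1
            if in_deg[other] == 0:
                queue.append(other)

# Dependencies start first; a leftover non-zero in-degree would mean a cycle.
assert startup_order == ["info", "order_book", "risk_gateway"]
```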

View File

@ -0,0 +1 @@
LOG_DIRECTORY = "./logs"

View File

@ -0,0 +1,197 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Generic, TypeVar
from unittest.mock import MagicMock, _Call
import logging
from connection.tcp_connection_manager import TcpConnectionManager
logger = logging.getLogger(__name__)
T = TypeVar('T')
DEFAULT_RESPONSE_TIMEOUT = 0.5 # in seconds
class CallExpectations(ABC):
def __init__(self, mocked_method: MagicMock) -> None:
self.mocked_method = mocked_method
self.fulfilled = False
self.fulfilled_call: _Call | None = None
@abstractmethod
def is_expected_call(self, mock_call: _Call) -> bool:
"""
Check if the given mock call is the expected event.
Do NOT raise AssertionError in this method. Return False if the call is not expected.
"""
pass
@abstractmethod
def validate_call(self, mock_call: _Call) -> None:
"""
Validate the given mock call.
Raise AssertionError if the call is not expected.
"""
pass
def verify(self) -> bool:
for index in range(self.mocked_method.call_count):
mock_call = self.mocked_method.call_args_list[index]
if not self.is_expected_call(mock_call):
logger.debug(f"Skipping unwanted call: {mock_call}")
continue
logger.debug(f"Found expected call: {mock_call}")
logger.debug("Validating call...")
self.validate_call(mock_call)
logger.debug(f"Call validated successfully")
self.fulfilled = True
self.fulfilled_call = mock_call
self.mocked_method.call_count -= 1
self.mocked_method.call_args_list.pop(index)
self.mocked_method.mock_calls.pop(index)
return True
logger.debug(f"Expected call of {self.mocked_method._mock_name} not found")
return False
def __repr__(self) -> str:
cls_name = self.__class__.__name__
base = f"<{cls_name}(mocked_method={getattr(self.mocked_method, '_mock_name', self.mocked_method)})"
extra_attrs = []
for k, v in sorted(self.__dict__.items()):
if k in {'mocked_method', 'fulfilled', 'fulfilled_call'}:
continue
extra_attrs.append(f"{k}={repr(v)}")
if extra_attrs:
base += ", " + ", ".join(extra_attrs)
return base + ">"
class ResponseExpectation(Generic[T], CallExpectations):
def __init__(self, mocked_method: MagicMock, response_type: type[T], request_id: int, expect_success: bool = True) -> None:
super().__init__(mocked_method)
self.response_type = response_type
self.request_id = request_id
self.expect_success = expect_success
def is_expected_call(self, mock_call: _Call) -> bool:
if len(mock_call.args) != 1:
return False
call_arg = mock_call.args[0]
if not isinstance(call_arg, self.response_type):
return False
if not hasattr(call_arg, "request_id"):
return False
if call_arg.request_id != self.request_id:
return False
return True
def validate_call(self, mock_call: _Call) -> None:
call_arg = mock_call.args[0]
assert hasattr(call_arg, "error_message")
if self.expect_success:
assert call_arg.error_message == "", f"Unexpected error: {call_arg.error_message}"
else:
assert call_arg.error_message != "", "Unexpected success response"
def get_response(self) -> T:
assert self.fulfilled_call is not None
return self.fulfilled_call.args[0]
class CallExpectationsManager:
def __init__(self) -> None:
self.pending_expectations_by_mock: dict[MagicMock, list[CallExpectations]] = defaultdict(list)
self.fulfilled_expectations_by_mock: dict[MagicMock, list[CallExpectations]] = defaultdict(list)
self.tcp_connection_manager: TcpConnectionManager | None = None
self.response_timeout: float = DEFAULT_RESPONSE_TIMEOUT
def create_mock(self, mock_name: str) -> MagicMock:
mocked_method = MagicMock(name=mock_name)
self.register_mock(mocked_method)
return mocked_method
def register_mock(self, mocked_method: MagicMock) -> None:
        # Touch the defaultdict so the mock is tracked even with no expectations yet.
        self.pending_expectations_by_mock[mocked_method]
def add_expectation(self, call_expectation: CallExpectations) -> None:
self.pending_expectations_by_mock[call_expectation.mocked_method].append(call_expectation)
def get_number_of_pending_calls(self) -> int:
return sum([len(pending_expectations) for pending_expectations in self.pending_expectations_by_mock.values()])
def setup_network(self, tcp_connection_manager: TcpConnectionManager, response_timeout: float = DEFAULT_RESPONSE_TIMEOUT) -> None:
"""
Set up the network to be able to await network events when verifying expectations.
Ideally this would be set in the constructor, but it's not possible due to circular dependencies.
For this reason, we allow the network to not be set up and expectations still to be checked.
"""
self.tcp_connection_manager = tcp_connection_manager
self.response_timeout = response_timeout
def verify_expectations(self, assert_no_pending_calls: bool = True) -> None:
# Check pending expectations before awaiting network events because it may fulfill some expectations and avoid
# unnecessary blocking network calls. Then check it again after awaiting network events to ensure all expectations
# are fulfilled.
self._check_pending_expectations()
if self.tcp_connection_manager:
self._await_network_events()
self._check_pending_expectations()
if assert_no_pending_calls:
if self.get_number_of_pending_calls() > 0:
logger.error(self._make_issues_report())
assert self.get_number_of_pending_calls() == 0, self._make_issues_report()
def verify_no_unexpected_calls(self, assert_no_unexpected_calls: bool = True) -> None:
has_unexpected_calls = False
issues = []
        for mocked_method in self.pending_expectations_by_mock:
            for mock_call in mocked_method.call_args_list:
                has_unexpected_calls = True
                issues.append(f"Detected unexpected call to {mocked_method._mock_name}: {repr(mock_call)}")
if assert_no_unexpected_calls:
assert not has_unexpected_calls, "\n".join(issues)
def _make_issues_report(self) -> str:
return "Expectations not fulfilled:\n" + "\n".join([f"{mocked_method._mock_name}: {call_expectation}"
for mocked_method, call_expectations in self.pending_expectations_by_mock.items()
for call_expectation in call_expectations if not call_expectation.fulfilled])
def _check_pending_expectations(self) -> None:
for mocked_method, call_expectations in self.pending_expectations_by_mock.items():
if not call_expectations:
continue
logger.info(f"Verifying {len(call_expectations)} expectations for {mocked_method._mock_name}")
for call_expectation in call_expectations.copy():
if call_expectation.verify():
call_expectations.remove(call_expectation)
self.fulfilled_expectations_by_mock[mocked_method].append(call_expectation)
else:
logger.debug(f"Expectation still pending for {mocked_method._mock_name}: {call_expectation}")
def _await_network_events(self) -> None:
assert self.tcp_connection_manager is not None, "Network not set up"
pending_calls_count = self.get_number_of_pending_calls()
next_timeout = self.response_timeout if pending_calls_count > 0 else 0
logger.info(f"Awaiting (at least) {pending_calls_count} network events for up to {next_timeout} seconds...")
total_received_events = 0
while True:
received_events = self.tcp_connection_manager.wait_for_events(timeout_in_seconds=next_timeout)
logger.debug(f"Received {received_events} network events")
if received_events == 0:
logger.debug("Stopping network event waiting loop")
break
total_received_events += received_events
pending_calls_count = self.get_number_of_pending_calls()
if pending_calls_count == 0:
logger.debug("Received all pending network events. Next network read attempts will be non-blocking.")
next_timeout = 0
else:
logger.debug(f"Still waiting for {pending_calls_count} pending network events")
logger.debug(f"Received {total_received_events} network events in total")
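As a self-contained sketch of the pending/fulfilled flow above (hypothetical names; the real manager also drains network events between checks): a `MagicMock` records every call, and an expectation is fulfilled once a matching entry appears in `call_args_list`.

```python
from unittest.mock import MagicMock, call

# One mock, one pending expectation: a call made with "req-1".
mock = MagicMock(name="on_response")
pending = [lambda c: c == call("req-1")]
fulfilled = []

mock("req-1")  # the code under test fires the callback

for expectation in pending.copy():
    if any(expectation(c) for c in mock.call_args_list):
        pending.remove(expectation)
        fulfilled.append(expectation)

assert not pending  # every expectation matched a recorded call
```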

import logging
import json
import os
from dataclasses import dataclass
from pathlib import Path
import signal
import subprocess
import threading
import socket
import time
import tempfile
from typing import IO
import psutil
from connection.ip_address import IpAddress
logger = logging.getLogger(__name__)
GRACEFUL_STOP_TIMEOUT: int = 30 # seconds
FORCE_KILL_TIMEOUT: int = 5 # seconds
PERF_MONITOR_INTERVAL_S: float = 0.2
@dataclass(frozen=True)
class PerformanceStats:
peak_rss_mb: float
avg_rss_mb: float
peak_cpu_percent: float
avg_cpu_percent: float
samples: int
class PerformanceMonitor:
"""Lightweight background sampler for a subprocess's memory and CPU usage."""
def __init__(self, pid: int, interval_s: float = PERF_MONITOR_INTERVAL_S) -> None:
self._ps_process = psutil.Process(pid)
self._interval_s = interval_s
self._stop_event = threading.Event()
self._thread = threading.Thread(target=self._poll, daemon=True)
self._rss_samples: list[float] = []
self._cpu_samples: list[float] = []
def start(self) -> None:
try:
self._ps_process.cpu_percent()
except (psutil.NoSuchProcess, psutil.AccessDenied):
pass
self._thread.start()
def stop(self) -> PerformanceStats:
self._stop_event.set()
self._thread.join(timeout=2)
return self._build_stats()
def _poll(self) -> None:
while not self._stop_event.is_set():
try:
rss_mb = self._ps_process.memory_info().rss / (1024 * 1024)
self._rss_samples.append(rss_mb)
self._cpu_samples.append(self._ps_process.cpu_percent())
except (psutil.NoSuchProcess, psutil.AccessDenied):
break
self._stop_event.wait(self._interval_s)
def _build_stats(self) -> PerformanceStats:
avg_rss = sum(self._rss_samples) / len(self._rss_samples) if self._rss_samples else 0
peak_rss = max(self._rss_samples) if self._rss_samples else 0
avg_cpu = sum(self._cpu_samples) / len(self._cpu_samples) if self._cpu_samples else 0
peak_cpu = max(self._cpu_samples) if self._cpu_samples else 0
return PerformanceStats(
peak_rss_mb=round(peak_rss, 2),
avg_rss_mb=round(avg_rss, 2),
peak_cpu_percent=round(peak_cpu, 2),
avg_cpu_percent=round(avg_cpu, 2),
samples=len(self._rss_samples),
)
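The aggregation in `_build_stats` reduces the raw samples to peak and average values; with hand-rolled samples (no running process needed) the arithmetic looks like this:

```python
# Sampled RSS (MB) and CPU (%) values, as _poll would collect them.
rss_samples = [100.0, 120.0, 110.0]
cpu_samples = [10.0, 50.0, 30.0]

stats = {
    "peak_rss_mb": round(max(rss_samples), 2) if rss_samples else 0,
    "avg_rss_mb": round(sum(rss_samples) / len(rss_samples), 2) if rss_samples else 0,
    "peak_cpu_percent": round(max(cpu_samples), 2) if cpu_samples else 0,
    "avg_cpu_percent": round(sum(cpu_samples) / len(cpu_samples), 2) if cpu_samples else 0,
    "samples": len(rss_samples),
}
# avg_rss_mb is 110.0, peak_rss_mb is 120.0
```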
class ProcessManager:
    def __init__(self, venv_path: Path) -> None:
        self.venv_path = venv_path
        self.process: subprocess.Popen | None = None
        self.stdout_thread: threading.Thread | None = None
        self.temp_config_file = None  # set by _create_temp_app_config when a config is used
        self._perf_monitor: PerformanceMonitor | None = None
        self.performance_stats: PerformanceStats | None = None
def _create_temp_app_config(self, config_data: dict) -> str:
"""Creates a temporary configuration file and returns its path."""
self.temp_config_file = tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json")
json.dump(config_data, self.temp_config_file)
self.temp_config_file.close()
return self.temp_config_file.name
    def start_process(self, package_name: str, config_data: dict | None = None) -> bool:
        """Starts the process, optionally using a temporary configuration file."""
        try:
            command = [str(self.venv_path / "bin" / package_name)]
if config_data:
temp_config_path = self._create_temp_app_config(config_data)
logger.info(f"Temporary config file created at: {temp_config_path}")
command.extend(["-c", str(temp_config_path)])
logger.info(f"Starting process with command: {' '.join(command)}")
self.process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True
)
time.sleep(0.15) # Give the process some time to start (or crash)
if self.process.poll() is not None:
raise RuntimeError(f"Failed to start process {package_name}. Return code: {self.process.returncode}")
self.stdout_thread = threading.Thread(target=_log_output, args=(package_name, self.process.stdout,), daemon=True)
self.stdout_thread.start()
self._perf_monitor = PerformanceMonitor(self.process.pid)
self._perf_monitor.start()
logger.info(f"Process started with PID: {self.process.pid}")
return True
        except Exception:
            logger.exception("Error starting process")
            return False
    def wait_until_server_is_ready(self, ip_address: IpAddress, timeout_in_seconds: float) -> float:
        assert self.process is not None, "Process is not running"
        start_time = time.time()
        while time.time() - start_time < timeout_in_seconds:
            if self.process.poll() is not None:
raise RuntimeError(f"Process crashed. Return code: {self.process.returncode}")
try:
conn_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
conn_socket.connect((ip_address.host, ip_address.port))
time_to_startup = (time.time() - start_time) * 1000
logger.info(f"Server started in approximately {time_to_startup:.2f} ms")
conn_socket.close()
return time_to_startup
except ConnectionRefusedError:
logger.debug("Server is not ready yet")
time.sleep(0.1)
raise TimeoutError(f"Server did not start in {timeout_in_seconds} seconds")
def stop_process(self) -> None:
assert self.process is not None, "Process is not running"
if self._perf_monitor is not None:
self.performance_stats = self._perf_monitor.stop()
logger.info(f"Sending SIGINT to process {self.process.pid}")
self.process.send_signal(signal.SIGINT)
try:
self.process.wait(timeout=GRACEFUL_STOP_TIMEOUT)
except subprocess.TimeoutExpired:
logger.warning(f"Process {self.process.pid} did not stop after SIGINT grace period ({GRACEFUL_STOP_TIMEOUT} seconds). Sending SIGTERM...")
self.process.send_signal(signal.SIGTERM)
try:
self.process.wait(timeout=FORCE_KILL_TIMEOUT)
except subprocess.TimeoutExpired:
logger.warning(f"Process {self.process.pid} still running after SIGTERM. Sending SIGKILL...")
self.process.send_signal(signal.SIGKILL)
self.process.wait(timeout=FORCE_KILL_TIMEOUT)
if self.process.returncode != 0:
logger.warning(f"Process {self.process.pid} exited with return code: {self.process.returncode}")
logger.info("Process stopped")
        if self.temp_config_file is not None:
            os.unlink(self.temp_config_file.name)
            logger.info(f"Temporary config file deleted: {self.temp_config_file.name}")
            self.temp_config_file = None
assert self.stdout_thread is not None
self.stdout_thread.join()
def _log_output(package_name: str, pipe: IO[str]) -> None:
sub_logger = logging.getLogger(f"SUBPROCESS_LOG_{package_name}")
for line in iter(pipe.readline, ''):
sub_logger.info(line.strip())
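The SIGINT → SIGTERM → SIGKILL escalation used by `stop_process` can be exercised against a throwaway child (POSIX only; timeouts shortened for the demo — a cooperative child exits on the first SIGINT):

```python
import signal
import subprocess
import sys

# Child that would run for a minute if left alone.
child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])

child.send_signal(signal.SIGINT)  # polite request first
try:
    child.wait(timeout=5)
except subprocess.TimeoutExpired:
    child.terminate()  # SIGTERM after the grace period
    try:
        child.wait(timeout=5)
    except subprocess.TimeoutExpired:
        child.kill()  # SIGKILL as the last resort
        child.wait(timeout=5)

assert child.returncode is not None  # the child has been reaped
```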

solution/tests/conftest.py
import dataclasses
from datetime import datetime
import json
import logging
import tempfile
from pathlib import Path
from typing import Any, Generator
import pytest
from tests.common.component_orchestrator import ComponentOrchestrator
from tests.common import constants
logger = logging.getLogger(__name__)
DEFAULT_PASSWORD = "password"
EXTRA_TEST_USERS = 5
SERVICE_USERS = [
{"username": "admin", "password": "admin", "full_name": "Admin Service Account"},
{"username": "risk_gateway", "password": "risk_gateway", "full_name": "Risk Gateway Service Account"},
]
_performance_data: list[dict[str, Any]] = []
def pytest_addoption(parser: pytest.Parser) -> None:
parser.addoption("--venv-path", action="store", default="", help="The path to the virtual environment")
parser.addoption(
"--deployment-config",
action="store",
default="deployment_config.json",
help="Path to the deployment_config.json for the deployment under test",
)
@pytest.fixture
def venv_path(request: pytest.FixtureRequest) -> Path:
venv_path = request.config.getoption("--venv-path")
assert venv_path is not None, "Virtual environment path is not set"
assert Path(venv_path).exists(), f"Virtual environment path {venv_path} does not exist"
return Path(venv_path)
@pytest.fixture
def deployment_config(request: pytest.FixtureRequest) -> dict:
config_path = Path(request.config.getoption("--deployment-config"))
assert config_path.exists(), f"Deployment config not found: {config_path}"
with config_path.open() as f:
return json.load(f)
def _create_data_file_with_users(test_name: str) -> Path:
    """Create a temporary data file with service accounts and test-specific users.

    Creates the primary test user plus numbered extras (``<test>_2`` through
    ``<test>_6``) so that multi-client tests can log in with distinct identities.
    """
users = list(SERVICE_USERS) + [
{"username": test_name, "password": DEFAULT_PASSWORD, "full_name": f"Test User ({test_name})"},
]
for i in range(2, EXTRA_TEST_USERS + 2):
username = f"{test_name}_{i}"
users.append({"username": username, "password": DEFAULT_PASSWORD, "full_name": f"Test User ({username})"})
data = {"users": users}
tmp = tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json", prefix="data_file_")
json.dump(data, tmp, indent=2)
tmp.close()
logger.info(f"Created temporary data file with {len(users)} users: {tmp.name}")
return Path(tmp.name)
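For a test named `test_login`, the helper above produces eight users: two service accounts, the primary user, and five numbered extras. A standalone sketch of the list construction:

```python
DEFAULT_PASSWORD = "password"
EXTRA_TEST_USERS = 5
SERVICE_USERS = [
    {"username": "admin", "password": "admin", "full_name": "Admin Service Account"},
    {"username": "risk_gateway", "password": "risk_gateway", "full_name": "Risk Gateway Service Account"},
]

test_name = "test_login"
users = list(SERVICE_USERS) + [
    {"username": test_name, "password": DEFAULT_PASSWORD, "full_name": f"Test User ({test_name})"},
]
for i in range(2, EXTRA_TEST_USERS + 2):  # test_login_2 .. test_login_6
    username = f"{test_name}_{i}"
    users.append({"username": username, "password": DEFAULT_PASSWORD, "full_name": f"Test User ({username})"})

assert len(users) == 8  # 2 service + 1 primary + 5 extras
```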
def pytest_sessionfinish(session: pytest.Session, exitstatus: int) -> None:
"""Write collected performance data to ``performance_report.json``.
With pytest-xdist each worker writes a partial file; the controller
(or a single-process run) merges them into the final report.
"""
log_dir = Path(constants.LOG_DIRECTORY)
log_dir.mkdir(parents=True, exist_ok=True)
worker_input = getattr(session.config, "workerinput", None)
if worker_input is not None:
worker_id = worker_input["workerid"]
partial = log_dir / f"_perf_partial_{worker_id}.json"
partial.write_text(json.dumps(_performance_data, indent=2))
else:
all_data: list[dict[str, Any]] = list(_performance_data)
for partial in sorted(log_dir.glob("_perf_partial_*.json")):
all_data.extend(json.loads(partial.read_text()))
partial.unlink()
if all_data:
now = datetime.now().strftime('%Y%m%d_%H%M%S')
report_path = log_dir / f"performance_report_{now}.json"
report_path.write_text(json.dumps(all_data, indent=2))
logger.info(f"Performance report written to {report_path}")
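The controller-side merge can be sketched with plain files (a temporary directory stands in for the log directory; the worker ids are illustrative):

```python
import json
import tempfile
from pathlib import Path

log_dir = Path(tempfile.mkdtemp())
(log_dir / "_perf_partial_gw0.json").write_text(json.dumps([{"test": "a"}]))
(log_dir / "_perf_partial_gw1.json").write_text(json.dumps([{"test": "b"}]))

# Concatenate the partials in sorted order, then remove them.
all_data = []
for partial in sorted(log_dir.glob("_perf_partial_*.json")):
    all_data.extend(json.loads(partial.read_text()))
    partial.unlink()

assert [d["test"] for d in all_data] == ["a", "b"]
```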
@pytest.fixture(autouse=True)
def start_components(
request: pytest.FixtureRequest, venv_path: Path, deployment_config: dict
) -> Generator[None, None, None]:
"""Automatically start all components needed for the test class's declared PROTOCOL.
Any test class with a ``PROTOCOL`` class attribute (matching a protocol name
from the deployment-config schema) will have the required components started
before each test and stopped afterwards. The component implementing the
declared protocol is exposed as ``self.server_address`` on the test instance.
"""
protocol: str | None = getattr(request.cls, "PROTOCOL", None)
if protocol is None:
yield
return
test_name = request.node.name
data_file_path = _create_data_file_with_users(test_name)
try:
orchestrator = ComponentOrchestrator(venv_path, deployment_config, data_file_path)
orchestrator.start_for_protocol(protocol)
if request.instance is not None:
request.instance.server_address = orchestrator.get_server_address(protocol)
request.instance.orchestrator = orchestrator
request.instance.auth_required = orchestrator.is_auth_required(protocol)
yield
perf_stats = orchestrator.stop_all()
if perf_stats:
_performance_data.append({
"test": test_name,
"components": {name: dataclasses.asdict(stats) for name, stats in perf_stats.items()},
})
finally:
data_file_path.unlink(missing_ok=True)

import json
import logging
import os
import signal
import time
import subprocess
import venv
from datetime import datetime
from pathlib import Path
from typing import Optional
from dataclasses import dataclass
import jsonschema
logger = logging.getLogger(__name__)
DEPLOYMENT_CONFIG_SCHEMA_FILE: Path = Path(__file__).parent.parent.parent / "deployment_config_schema.json"
DEPLOYMENT_DIR: str = os.path.expanduser(os.environ.get("CK_DEPLOYMENT_DIR", "~/deployment"))
TESTS_DIR: str = os.path.expanduser(os.environ.get("CK_TESTS_DIR", "~/gtat-tech-career-kickstarter/solution/tests"))
TEST_STARTED_FLAG: str = "started.flag"
CHECK_INTERVAL: int = 30 # seconds
WORKERS_COUNT: int = 4
TEST_TIMEOUT: int = 20 * 60 # 20 minutes
GRACEFUL_STOP_TIMEOUT: int = 30 # seconds
@dataclass
class DeployedInstance:
user_dir: Path
deployment_dir: Path
wheel_file: Path
venv_dir: Optional[Path] = None
@property
def py(self) -> str:
assert self.venv_dir is not None
return str(self.venv_dir / "bin" / "python")
def __repr__(self) -> str:
return f"{self.user_dir}//{self.deployment_dir}"
def find_oldest_untested_deployment() -> Optional[DeployedInstance]:
oldest_deployment: Optional[DeployedInstance] = None
for user_dir in Path(DEPLOYMENT_DIR).iterdir():
if not user_dir.is_dir():
continue
for deployment_dir in user_dir.iterdir():
if not deployment_dir.is_dir():
continue
started_flag = deployment_dir / TEST_STARTED_FLAG
if started_flag.exists():
logger.debug(f"Skipping {deployment_dir} as it has already been marked as started.")
continue
wheels_files = list(deployment_dir.glob("*.whl"))
if not wheels_files:
logger.debug(f"Skipping {deployment_dir} as it does not contain any wheel files.")
continue
            if len(wheels_files) > 1:
                raise ValueError(f"Multiple wheel files found in {deployment_dir}.")
deployment = DeployedInstance(user_dir, deployment_dir, wheels_files[0])
if oldest_deployment is None or deployment_dir.stat().st_mtime < oldest_deployment.deployment_dir.stat().st_mtime:
oldest_deployment = deployment
return oldest_deployment
def create_venv(deployed_instance: DeployedInstance) -> None:
logger.info(f"Creating virtual environment for {deployed_instance}...")
deployed_instance.venv_dir = deployed_instance.deployment_dir / "venv"
venv.create(str(deployed_instance.venv_dir), with_pip=True)
logger.info(f"Virtual environment created for {deployed_instance}.")
def install_wheel(deployed_instance: DeployedInstance) -> None:
logger.info(f"Installing wheel {deployed_instance.wheel_file} into its virtual environment...")
subprocess.run([deployed_instance.py, "-m", "pip", "install", str(deployed_instance.wheel_file)], check=True)
logger.info(f"Wheel {deployed_instance.wheel_file} installed into its virtual environment.")
def _resolve_test_files(deployed_instance: DeployedInstance) -> list[Path]:
config_file = deployed_instance.deployment_dir / "deployment_config.json"
with config_file.open() as f:
config = json.load(f)
with DEPLOYMENT_CONFIG_SCHEMA_FILE.open() as f:
schema = json.load(f)
jsonschema.validate(instance=config, schema=schema)
system_tests = config["systemTests"]
test_files = [Path(TESTS_DIR) / f"test_{system_test}_system.py" for system_test in system_tests]
missing = [str(p) for p in test_files if not p.exists()]
if missing:
raise FileNotFoundError(f"Test files not found: {missing}")
return test_files
def run_pytest(deployed_instance: DeployedInstance) -> list[Path]:
logger.info(f"Running pytest for {deployed_instance}...")
test_files = _resolve_test_files(deployed_instance)
logger.info(f"Running system tests: {[str(f) for f in test_files]}")
junit_report_file = deployed_instance.deployment_dir / "test_results.xml"
final_report_file = deployed_instance.deployment_dir / "final_report.json"
deployment_config_file = deployed_instance.deployment_dir / "deployment_config.json"
command = [
deployed_instance.py, "-m", "pytest", *[str(f) for f in test_files],
        "-n", str(WORKERS_COUNT),
"-W error::pytest.PytestUnhandledThreadExceptionWarning",
f"--venv-path={str(deployed_instance.venv_dir)}",
f"--deployment-config={str(deployment_config_file)}",
f"--junit-xml={junit_report_file}",
]
logger.debug(f"Running command: {command}")
process = subprocess.Popen(command, cwd=deployed_instance.deployment_dir, start_new_session=True)
logger.debug(f"Process started with PID: {process.pid}")
started_at = time.monotonic()
timed_out = False
try:
process.wait(timeout=TEST_TIMEOUT)
except subprocess.TimeoutExpired:
timed_out = True
logger.warning(f"Pytest process {process.pid} exceeded timeout ({TEST_TIMEOUT} seconds) for {deployed_instance}. Shutting down...")
_gracefully_stop_process(process, timeout=GRACEFUL_STOP_TIMEOUT)
testing_duration_ms = int((time.monotonic() - started_at) * 1000)
logger.debug(f"Testing done in {testing_duration_ms}ms")
_write_final_report(
final_report_file=final_report_file,
testing_duration_ms=testing_duration_ms,
testing_timed_out=timed_out,
)
logger.info(f"Testing completed for {deployed_instance}.")
return [junit_report_file, final_report_file]
def _write_final_report(final_report_file: Path, testing_duration_ms: int, testing_timed_out: bool) -> None:
report_content = {
"testing_duration_ms": testing_duration_ms,
"testing_timed_out": testing_timed_out,
}
with final_report_file.open("w", encoding="utf-8") as f:
json.dump(report_content, f)
logger.info(f"Final test run report written to {final_report_file}")
def _gracefully_stop_process(process: subprocess.Popen[bytes], timeout: int = GRACEFUL_STOP_TIMEOUT) -> int:
FORCE_KILL_TIMEOUT: int = 5 # seconds
if process.poll() is not None:
return process.returncode
logger.info(f"Sending SIGINT to process {process.pid}")
os.killpg(process.pid, signal.SIGINT)
try:
process.wait(timeout=timeout)
return process.returncode
except subprocess.TimeoutExpired:
logger.warning(
f"Process {process.pid} did not stop after SIGINT grace period ({timeout} seconds). "
"Sending SIGTERM..."
)
if process.poll() is not None:
return process.returncode
os.killpg(process.pid, signal.SIGTERM)
try:
process.wait(timeout=FORCE_KILL_TIMEOUT)
return process.returncode
except subprocess.TimeoutExpired:
logger.warning(f"Process {process.pid} still running after SIGTERM. Sending SIGKILL...")
if process.poll() is not None:
return process.returncode
os.killpg(process.pid, signal.SIGKILL)
process.wait(timeout=FORCE_KILL_TIMEOUT)
return process.returncode
def copy_test_results_to_dev_server(report_files: list[Path], dev_server_host: str) -> None:
logger.info(f"Copying test result files to dev server {dev_server_host}...")
for report_file in report_files:
subprocess.run(["scp", str(report_file), f"{dev_server_host}:/tmp/{report_file.name}"], check=True)
def copy_logs_to_dev_server(deployed_instance: DeployedInstance, dev_server_host: str) -> None:
    logger.info(f"Copying logs to dev server {dev_server_host}...")
    # A glob is not expanded without a shell, so resolve the log files locally first.
    log_files = [str(p) for p in (deployed_instance.deployment_dir / "logs").glob("*")]
    if not log_files:
        logger.warning("No log files found to copy.")
        return
    subprocess.run(["scp", *log_files, f"{dev_server_host}:/tmp/"], check=True)
def main() -> None:
logger.info("Starting test runner...")
while True:
logger.info("Checking for untested deployments...")
untested_deployment = find_oldest_untested_deployment()
if untested_deployment:
logger.info(f"Found untested deployment in {untested_deployment.deployment_dir}")
started_flag = untested_deployment.deployment_dir / TEST_STARTED_FLAG
started_flag.touch()
try:
create_venv(untested_deployment)
install_wheel(untested_deployment)
except subprocess.CalledProcessError as e:
logger.exception(f"Error during installation: {e}")
continue
report_files = None
try:
report_files = run_pytest(untested_deployment)
except subprocess.SubprocessError as e:
logger.exception(f"Error while running tests: {e}")
continue
try:
username = untested_deployment.user_dir.name
if report_files is None:
logger.error("Report files are None, skipping copy to dev server.")
continue
copy_test_results_to_dev_server(report_files, username)
copy_logs_to_dev_server(untested_deployment, username)
logger.info(f"Test results copied to dev server for user: {username}")
except subprocess.CalledProcessError as e:
logger.exception(f"Error copying test results to dev server: {e}")
continue
else:
logger.info(f"No untested deployments found. Retrying in {CHECK_INTERVAL} seconds...")
time.sleep(CHECK_INTERVAL)
def setup_logging() -> None:
    log_dir = Path(__file__).parent / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file_name = log_dir / f"test_finder_{datetime.now():%Y%m%d_%H%M%S}.log"
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(log_file_name),
logging.StreamHandler()
]
)
logger.info(f"Logging setup complete. Log file: {log_file_name}")
if __name__ == "__main__":
setup_logging()
main()

{
"admin": [],
"order_book": ["admin"],
"info": ["admin", "order_book"],
"execution": ["admin", "order_book"],
"risk_limits": ["admin", "execution"]
}
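A start order satisfying this dependency map can be derived with a topological sort — a sketch using the standard library (the actual orchestrator's strategy may differ):

```python
from graphlib import TopologicalSorter

# component -> components it depends on (mirrors the JSON above)
dependencies = {
    "admin": [],
    "order_book": ["admin"],
    "info": ["admin", "order_book"],
    "execution": ["admin", "order_book"],
    "risk_limits": ["admin", "execution"],
}

start_order = list(TopologicalSorter(dependencies).static_order())
position = {name: i for i, name in enumerate(start_order)}

# Every component starts only after all of its dependencies.
assert all(position[dep] < position[comp]
           for comp, deps in dependencies.items() for dep in deps)
```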

from decimal import Decimal
import socket
from typing import Callable
import logging
from connection.connection_handler import ConnectionHandler, ConnectionHandlerFactory
from connection.ip_address import IpAddress
from proto.admin_pb2 import CreateInstrumentRequest, CreateInstrumentResponse
from proto.common_pb2 import Instrument, MessageType
from tests.common.mock_expectations import CallExpectationsManager, ResponseExpectation
logger = logging.getLogger(__name__)
class AdminTestClient(ConnectionHandler):
def __init__(self, socket_fd: socket.socket, ip_address: IpAddress, on_close: Callable[[], None],
call_expectations_manager: CallExpectationsManager) -> None:
super().__init__(socket_fd, ip_address, on_close)
self.next_request_id = 1
self.callbacks: dict[int, Callable] = {}
self.call_expectations_manager = call_expectations_manager
self.on_create_instrument_response = self.call_expectations_manager.create_mock("on_create_instrument_response")
def on_disconnect(self) -> None:
logger.info(f"Disconnected from {self.ip_address}")
def handle_message(self, message_type: int, raw_message: bytes) -> None:
logger.info(f"Handling message of type {MessageType.Name(message_type)}")
if message_type == MessageType.ADMIN_CREATE_INSTRUMENT_RESPONSE:
message = CreateInstrumentResponse.FromString(raw_message)
else:
raise ValueError(f"Unexpected message type: {MessageType.Name(message_type)}")
callback = self.callbacks.pop(message.request_id, None)
if callback:
callback(message)
else:
raise ValueError(f"Received response with unknown request_id: {message.request_id}")
def test_create_instrument(self, instrument: Instrument, tick_size: Decimal,
expect_success: bool = True) -> ResponseExpectation[CreateInstrumentResponse]:
request = CreateInstrumentRequest(instrument=instrument, tick_size=float(tick_size))
request_id = self._get_next_request_id()
request.request_id = request_id
self.callbacks[request_id] = self.on_create_instrument_response
self.send_message(MessageType.ADMIN_CREATE_INSTRUMENT_REQUEST, request)
response_expectation = ResponseExpectation(
self.on_create_instrument_response, CreateInstrumentResponse, request_id, expect_success)
self.call_expectations_manager.add_expectation(response_expectation)
if expect_success:
self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
assert response_expectation.fulfilled, "Expected CreateInstrumentResponse not received"
return response_expectation
def send_message(self, message_type: int, message) -> None:
logger.info(f"Sending message of type {MessageType.Name(MessageType.ValueType(message_type))}")
super().send_message(message_type, message)
def _get_next_request_id(self) -> int:
request_id = self.next_request_id
self.next_request_id += 1
return request_id
class AdminClientConnectionHandlerFactory(ConnectionHandlerFactory[AdminTestClient]):
def __init__(self, call_expectations_manager: CallExpectationsManager) -> None:
self.call_expectations_manager = call_expectations_manager
def on_new_connection(self, socket_fd: socket.socket, ip_address: IpAddress,
close_callback: Callable[[], None]) -> AdminTestClient:
return AdminTestClient(socket_fd, ip_address, close_callback, self.call_expectations_manager)
def on_connection_closed(self, connection_handler: AdminTestClient) -> None:
pass
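The request/response correlation in `handle_message` reduces to a monotonically increasing id and a callbacks dict; a self-contained sketch (plain strings stand in for the protobuf messages):

```python
from unittest.mock import MagicMock

next_request_id = 1
callbacks = {}

def send_request(callback):
    """Allocate a request id and remember which callback handles the response."""
    global next_request_id
    request_id = next_request_id
    next_request_id += 1
    callbacks[request_id] = callback
    return request_id

def handle_response(request_id, message):
    """Dispatch a response to its registered callback, exactly once."""
    callback = callbacks.pop(request_id, None)
    if callback is None:
        raise ValueError(f"Received response with unknown request_id: {request_id}")
    callback(message)

on_response = MagicMock(name="on_create_instrument_response")
rid = send_request(on_response)
handle_response(rid, "ok")

assert on_response.call_args.args == ("ok",)
assert rid not in callbacks  # a second response with this id would raise
```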

from decimal import Decimal
import socket
from typing import Callable
from unittest.mock import ANY, _Call, MagicMock
import logging
from common.info_client import BaseInfoClient
from common.utils import decimal_from_float
from connection.connection_handler import ConnectionHandlerFactory
from connection.ip_address import IpAddress
from proto.info_pb2 import *
from proto.common_pb2 import Instrument, LoginRequest, LoginResponse, Side
from tests.conftest import DEFAULT_PASSWORD
from tests.common.mock_expectations import CallExpectations, CallExpectationsManager, ResponseExpectation
logger = logging.getLogger(__name__)
class InfoTestClient(BaseInfoClient):
def __init__(self, socket_fd: socket.socket, ip_address: IpAddress, on_close: Callable[[], None],
call_expectations_manager: CallExpectationsManager) -> None:
super().__init__(socket_fd, ip_address, on_close)
self._last_order_book_id: int | None = None
self._last_order_id: int | None = None
self.call_expectations_manager = call_expectations_manager
self.on_login_response = self.call_expectations_manager.create_mock("on_login_response")
self.on_create_instrument_response = self.call_expectations_manager.create_mock("on_create_instrument_response")
self.on_order_book_subscribe_response = self.call_expectations_manager.create_mock("on_order_book_subscribe_response")
self.on_instrument = self.call_expectations_manager.create_mock("on_instrument") # type: ignore
self.on_top_of_book = self.call_expectations_manager.create_mock("on_top_of_book") # type: ignore
self.on_price_depth_book = self.call_expectations_manager.create_mock("on_price_depth_book") # type: ignore
self.on_trade = self.call_expectations_manager.create_mock("on_trade") # type: ignore
def test_login(self, username: str, password: str = DEFAULT_PASSWORD, expect_success: bool = True) -> ResponseExpectation[LoginResponse]:
request = LoginRequest(username=username, password=password)
request_id = self.send_login(request, callback=self.on_login_response)
response_expectation = ResponseExpectation(self.on_login_response, LoginResponse, request_id, expect_success)
self.call_expectations_manager.add_expectation(response_expectation)
self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
assert response_expectation.fulfilled, f"Expected response not received: {response_expectation}"
return response_expectation
def test_create_instrument(self, instrument: Instrument, tick_size: float, expect_success: bool = True, expect_public_feed: bool = True) -> ResponseExpectation[CreateInstrumentResponse]:
request = CreateInstrumentRequest(instrument=instrument, tick_size=tick_size)
request_id = self.send_create_instrument(request, callback=self.on_create_instrument_response)
response_expectation = ResponseExpectation(self.on_create_instrument_response, CreateInstrumentResponse, request_id, expect_success)
self.call_expectations_manager.add_expectation(response_expectation)
if expect_success:
self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
assert response_expectation.fulfilled, f"Expected response not received: {response_expectation}"
response = response_expectation.get_response()
assert response.error_message == ""
self._last_order_book_id = response.order_book_id
if expect_success and expect_public_feed:
self.expect_on_instrument_event(instrument, tick_size, response.order_book_id)
return response_expectation
def test_subscribe_to_order_book(self, instrument_symbol: str, type: SubscriptionType.ValueType, expect_success: bool = True) -> ResponseExpectation[OrderBookSubscribeResponse]:
request = OrderBookSubscribeRequest(instrument_symbol=instrument_symbol, subscription_type=type)
request_id = self.send_order_book_subscribe(request, callback=self.on_order_book_subscribe_response)
response_expectation = ResponseExpectation(self.on_order_book_subscribe_response, OrderBookSubscribeResponse, request_id, expect_success)
self.call_expectations_manager.add_expectation(response_expectation)
self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
assert response_expectation.fulfilled, f"Expected response not received: {response_expectation}"
return response_expectation
def expect_on_instrument_event(self, instrument: Instrument, tick_size: float, order_book_id: int) -> None:
class OnInstrumentExpectation(CallExpectations):
def is_expected_call(self, mock_call: _Call) -> bool:
if len(mock_call.args) != 1:
return False
call_arg = mock_call.args[0]
if not isinstance(call_arg, OnInstrument):
return False
if call_arg.instrument.symbol != instrument.symbol:
return False
return True
def validate_call(self, mock_call: _Call) -> None:
call_arg = mock_call.args[0]
assert call_arg.instrument == instrument
                assert tick_size is ANY or abs(call_arg.tick_size - tick_size) < 1e-5, f"Expected tick size {tick_size}, but got {call_arg.tick_size}"
assert call_arg.order_book_id == order_book_id
assert isinstance(self.on_instrument, MagicMock)
self.call_expectations_manager.add_expectation(OnInstrumentExpectation(self.on_instrument))
def expect_on_top_of_book_event(self, instrument_symbol: str, best_bid: PriceLevel | None, best_ask: PriceLevel | None) -> None:
class OnTopOfBookExpectation(CallExpectations):
def is_expected_call(self, mock_call: _Call) -> bool:
if len(mock_call.args) != 1:
return False
call_arg = mock_call.args[0]
if not isinstance(call_arg, OnTopOfBook):
return False
if call_arg.instrument_symbol != instrument_symbol:
return False
if (best_bid is not None) != call_arg.HasField("best_bid"):
return False
if (best_ask is not None) != call_arg.HasField("best_ask"):
return False
return True
            def validate_call(self, mock_call: _Call) -> None:
                call_arg = mock_call.args[0]
                assert isinstance(call_arg, OnTopOfBook)
                self._compare_price_levels(best_bid, call_arg.best_bid if call_arg.HasField("best_bid") else None)
                self._compare_price_levels(best_ask, call_arg.best_ask if call_arg.HasField("best_ask") else None)
            def _compare_price_levels(self, expected: PriceLevel | None, actual: PriceLevel | None) -> None:
                if expected is None:
                    assert actual is None, f"Expected no price level, but got {actual}"
                    return
                assert actual is not None, f"Expected price level {expected}, but got None"
                assert actual.quantity == expected.quantity, f"Expected quantity {expected.quantity}, but got {actual.quantity}"
                assert abs(actual.price - expected.price) < 1e-5, f"Expected price {expected.price}, but got {actual.price}"
assert isinstance(self.on_top_of_book, MagicMock)
self.call_expectations_manager.add_expectation(OnTopOfBookExpectation(self.on_top_of_book))
def expect_on_price_depth_book_event(self, instrument_symbol: str, bids: list[PriceLevel], asks: list[PriceLevel]) -> None:
class OnPriceDepthBookExpectation(CallExpectations):
def is_expected_call(self, mock_call: _Call) -> bool:
if len(mock_call.args) != 1:
return False
call_arg = mock_call.args[0]
if not isinstance(call_arg, OnPriceDepthBook):
return False
if call_arg.instrument_symbol != instrument_symbol:
return False
if (len(bids) > 0) != (len(call_arg.bids) > 0):
return False
if (len(asks) > 0) != (len(call_arg.asks) > 0):
return False
return True
def validate_call(self, mock_call: _Call) -> None:
call_arg = mock_call.args[0]
self._compare_price_levels(bids, list(call_arg.bids))
self._compare_price_levels(asks, list(call_arg.asks))
def _compare_price_levels(self, expected: list[PriceLevel], actual: list[PriceLevel]) -> None:
differences = []
# Create dicts keyed by price for fast lookup (prices are unique per list)
expected_by_price = {decimal_from_float(pl.price): pl for pl in expected}
actual_by_price = {decimal_from_float(pl.price): pl for pl in actual}
# Check for missing prices and differences in quantity
for price in expected_by_price.keys():
if price not in actual_by_price:
differences.append(f"Missing actual price level at price {price}")
continue
exp = expected_by_price[price]
act = actual_by_price[price]
if exp.quantity != act.quantity:
differences.append(f"For price {price}: expected quantity {exp.quantity}, but got {act.quantity}")
# Check for extra actual price levels
for price in actual_by_price.keys():
if price not in expected_by_price:
differences.append(f"Extra actual price level at price {price}")
assert not differences, "\n".join(differences)
assert isinstance(self.on_price_depth_book, MagicMock)
self.call_expectations_manager.add_expectation(OnPriceDepthBookExpectation(self.on_price_depth_book))
def expect_on_trade_event(self, instrument_symbol: str, price: float, quantity: int,
aggressor_side: Side.ValueType) -> None:
class OnTradeExpectation(CallExpectations):
def is_expected_call(self, mock_call: _Call) -> bool:
if len(mock_call.args) != 1:
return False
call_arg = mock_call.args[0]
if not isinstance(call_arg, OnTrade):
return False
if call_arg.instrument_symbol != instrument_symbol:
return False
return True
def validate_call(self, mock_call: _Call) -> None:
call_arg = mock_call.args[0]
assert abs(call_arg.price - price) < 1e-5, f"Expected price {price}, got {call_arg.price}"
assert call_arg.quantity == quantity, f"Expected quantity {quantity}, got {call_arg.quantity}"
assert call_arg.aggressor_side == aggressor_side, \
f"Expected aggressor side {Side.Name(aggressor_side)}, got {Side.Name(call_arg.aggressor_side)}"
assert isinstance(self.on_trade, MagicMock)
self.call_expectations_manager.add_expectation(OnTradeExpectation(self.on_trade))
def on_instrument(self, message: OnInstrument) -> None:
# This method is mocked with MagicMock
pass
def on_top_of_book(self, message: OnTopOfBook) -> None:
# This method is mocked with MagicMock
pass
def on_price_depth_book(self, message: OnPriceDepthBook) -> None:
# This method is mocked with MagicMock
pass
def on_trade(self, message: OnTrade) -> None:
# This method is mocked with MagicMock
pass
class InfoClientConnectionHandlerFactory(ConnectionHandlerFactory[InfoTestClient]):
def __init__(self, call_expectations_manager: CallExpectationsManager) -> None:
self.call_expectations_manager = call_expectations_manager
def on_new_connection(self, socket_fd: socket.socket, ip_address: IpAddress, close_callback: Callable[[], None]) -> InfoTestClient:
return InfoTestClient(socket_fd, ip_address, close_callback, self.call_expectations_manager)
def on_connection_closed(self, connection_handler: InfoTestClient) -> None:
pass
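Every `*Expectation` above splits matching into two phases: a cheap `is_expected_call` that identifies which event a mock call carries, then a strict `validate_call` that asserts its fields with descriptive messages. A minimal, self-contained sketch of that pattern — `Event`, `Expectation`, and `dispatch` here are illustrative stand-ins, not the project's `CallExpectations`/`CallExpectationsManager` API:

```python
from dataclasses import dataclass


@dataclass
class Event:
    symbol: str
    price: float


class Expectation:
    """Two-phase check: cheap matching first, strict validation second."""

    def __init__(self, symbol: str, price: float) -> None:
        self.symbol = symbol
        self.price = price
        self.fulfilled = False

    def is_expected_call(self, event: Event) -> bool:
        # Phase 1: decide whether this event is the one we are waiting for.
        return event.symbol == self.symbol

    def validate_call(self, event: Event) -> None:
        # Phase 2: once matched, assert every field with a helpful message.
        assert abs(event.price - self.price) < 1e-5, \
            f"Expected price {self.price}, but got {event.price}"
        self.fulfilled = True


def dispatch(expectations: list[Expectation], event: Event) -> None:
    # Route each incoming event to the first unfulfilled expectation that matches.
    for expectation in expectations:
        if not expectation.fulfilled and expectation.is_expected_call(event):
            expectation.validate_call(event)
            return
    raise AssertionError(f"Unexpected event: {event}")


exp = Expectation("TEST.1", 100.0)
dispatch([exp], Event("TEST.1", 100.0))
assert exp.fulfilled
```

Separating the phases keeps mismatched events (e.g. for another instrument) from tripping field assertions meant for a different expectation.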
@@ -0,0 +1,263 @@
from decimal import Decimal
import socket
from typing import Callable
from unittest.mock import ANY, _Call, MagicMock
import logging
from common.order_book_client import BaseOrderBookClient
from common.utils import decimal_from_float
from connection.connection_handler import ConnectionHandlerFactory
from connection.ip_address import IpAddress
from proto.order_book_pb2 import * # noqa: F403
from proto.order_book_pb2 import OnOrderBookCreated
from proto.common_pb2 import Side
from tests.common.mock_expectations import CallExpectations, CallExpectationsManager, ResponseExpectation
logger = logging.getLogger(__name__)
class TradeExpectation(CallExpectations):
def __init__(self, on_trade_mock: MagicMock, trade_id: int, order_book_id: int, buy_order_id: int, sell_order_id: int,
aggressor_side: Side.ValueType, price: Decimal, quantity: int):
super().__init__(on_trade_mock)
self.trade_id = trade_id
self.order_book_id = order_book_id
self.buy_order_id = buy_order_id
self.sell_order_id = sell_order_id
self.aggressor_side = aggressor_side
self.price = price
self.quantity = quantity
self.on_trade_message: OnTrade | None = None
def is_expected_call(self, mock_call: _Call) -> bool:
if len(mock_call.args) != 1:
return False
call_arg = mock_call.args[0]
if not isinstance(call_arg, OnTrade):
return False
if call_arg.trade_id != self.trade_id:
return False
return True
def validate_call(self, mock_call: _Call) -> None:
call_arg = mock_call.args[0]
assert call_arg.order_book_id == self.order_book_id
assert call_arg.buy_order_id == self.buy_order_id
assert call_arg.sell_order_id == self.sell_order_id
assert call_arg.aggressor_side == self.aggressor_side
assert self.price is ANY or abs(decimal_from_float(call_arg.price) - self.price) < 1e-4, f"Expected price {self.price}, but got {call_arg.price}"
assert call_arg.quantity == self.quantity
self.on_trade_message = call_arg
def get_message(self) -> OnTrade:
assert self.on_trade_message is not None, "Trade message not received"
return self.on_trade_message
class OrderBookCreatedExpectation(CallExpectations):
def __init__(self, on_order_book_created_mock: MagicMock, tick_size: Decimal, order_book_id: int | None = None):
super().__init__(on_order_book_created_mock)
self.tick_size = tick_size
self.expected_order_book_id = order_book_id
self.on_order_book_created_message: OnOrderBookCreated | None = None
def is_expected_call(self, mock_call: _Call) -> bool:
if len(mock_call.args) != 1:
return False
call_arg = mock_call.args[0]
if not isinstance(call_arg, OnOrderBookCreated):
return False
if self.expected_order_book_id is not None:
if call_arg.order_book_id != self.expected_order_book_id:
return False
return True
def validate_call(self, mock_call: _Call) -> None:
call_arg = mock_call.args[0]
assert abs(decimal_from_float(call_arg.tick_size) - self.tick_size) < 1e-5, f"Expected tick size {self.tick_size}, but got {call_arg.tick_size}"
self.on_order_book_created_message = call_arg
def get_message(self) -> OnOrderBookCreated:
assert self.on_order_book_created_message is not None, "OnOrderBookCreated message not received"
return self.on_order_book_created_message
class OrderInsertedExpectation(CallExpectations):
def __init__(self, on_order_inserted_mock: MagicMock, order_id: int, order_book_id: int,
side: Side.ValueType, price: Decimal, quantity: int, username: str | None = None):
super().__init__(on_order_inserted_mock)
self.order_id = order_id
self.order_book_id = order_book_id
self.side = side
self.price = price
self.quantity = quantity
self.expected_username = username
self.on_order_inserted_message: OnOrderInserted | None = None
def is_expected_call(self, mock_call: _Call) -> bool:
if len(mock_call.args) != 1:
return False
call_arg = mock_call.args[0]
if not isinstance(call_arg, OnOrderInserted):
return False
if call_arg.order_id != self.order_id:
return False
return True
def validate_call(self, mock_call: _Call) -> None:
call_arg = mock_call.args[0]
assert call_arg.order_book_id == self.order_book_id
assert call_arg.side == self.side
assert self.price is ANY or abs(decimal_from_float(call_arg.price) - self.price) < 1e-5, \
f"Expected price {self.price}, but got {call_arg.price}"
assert call_arg.quantity == self.quantity
if self.expected_username is not None:
assert call_arg.username == self.expected_username, \
f"Expected username '{self.expected_username}', but got '{call_arg.username}'"
self.on_order_inserted_message = call_arg
def get_message(self) -> OnOrderInserted:
assert self.on_order_inserted_message is not None, "OnOrderInserted message not received"
return self.on_order_inserted_message
class OrderCancelledExpectation(CallExpectations):
def __init__(self, on_order_cancelled_mock: MagicMock, order_id: int):
super().__init__(on_order_cancelled_mock)
self.order_id = order_id
self.on_order_cancelled_message: OnOrderCancelled | None = None
def is_expected_call(self, mock_call: _Call) -> bool:
if len(mock_call.args) != 1:
return False
call_arg = mock_call.args[0]
if not isinstance(call_arg, OnOrderCancelled):
return False
if call_arg.order_id != self.order_id:
return False
return True
def validate_call(self, mock_call: _Call) -> None:
self.on_order_cancelled_message = mock_call.args[0]
def get_message(self) -> OnOrderCancelled:
assert self.on_order_cancelled_message is not None, "OnOrderCancelled message not received"
return self.on_order_cancelled_message
class OrderBookTestClient(BaseOrderBookClient):
def __init__(self, socket_fd: socket.socket, ip_address: IpAddress, on_close: Callable[[], None],
call_expectations_manager: CallExpectationsManager) -> None:
super().__init__(socket_fd, ip_address, on_close)
self._last_order_book_id: int | None = None
self._last_order_id: int | None = None
self.call_expectations_manager = call_expectations_manager
self.on_order_book_created = self.call_expectations_manager.create_mock("on_order_book_created") # type: ignore
self.on_insert_order_response = self.call_expectations_manager.create_mock("on_insert_order_response")
self.on_cancel_order_response = self.call_expectations_manager.create_mock("on_cancel_order_response")
self.on_order_inserted = self.call_expectations_manager.create_mock("on_order_inserted") # type: ignore
self.on_order_cancelled = self.call_expectations_manager.create_mock("on_order_cancelled") # type: ignore
self.on_trade = self.call_expectations_manager.create_mock("on_trade") # type: ignore
def test_insert_order(self, side: Side.ValueType, price: Decimal, quantity: int, order_book_id: int | None = None,
username: str = "test_user",
expect_success: bool = True, expect_public_feed: bool = True) -> ResponseExpectation[InsertOrderResponse]:
if order_book_id is None:
assert self._last_order_book_id is not None, "No order book ID provided"
order_book_id = self._last_order_book_id
insert_order_request = InsertOrderRequest(order_book_id=order_book_id, side=side, price=float(price), quantity=quantity,
username=username)
request_id = self.send_insert_order(insert_order_request, callback=self.on_insert_order_response)
response_expectation = ResponseExpectation(self.on_insert_order_response, InsertOrderResponse, request_id, expect_success)
self.call_expectations_manager.add_expectation(response_expectation)
if expect_success:
self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
assert response_expectation.fulfilled, f"Expected response not received: {response_expectation}"
response = response_expectation.get_response()
self._last_order_id = response.order_id
if expect_success and expect_public_feed:
self.expect_on_order_inserted_event(order_id=response.order_id, order_book_id=order_book_id, side=side,
price=price, quantity=quantity)
return response_expectation
def test_cancel_order(self, order_id: int | None = None, order_book_id: int | None = None,
expect_success: bool = True, expect_public_feed: bool = True) -> ResponseExpectation[CancelOrderResponse]:
if order_book_id is None:
assert self._last_order_book_id is not None, "No order book ID provided"
order_book_id = self._last_order_book_id
if order_id is None:
assert self._last_order_id is not None, "No order ID provided"
order_id = self._last_order_id
cancel_order_request = CancelOrderRequest(order_book_id=order_book_id, order_id=order_id)
request_id = self.send_cancel_order(cancel_order_request, callback=self.on_cancel_order_response)
response_expectation = ResponseExpectation(self.on_cancel_order_response, CancelOrderResponse, request_id, expect_success)
self.call_expectations_manager.add_expectation(response_expectation)
if expect_success:
self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
assert response_expectation.fulfilled, f"Expected response not received: {response_expectation}"
if expect_success and expect_public_feed:
self.expect_on_order_cancelled_event(order_id=order_id)
return response_expectation
def expect_on_order_book_created_event(self, tick_size: Decimal, order_book_id: int | None = None) -> OrderBookCreatedExpectation:
assert isinstance(self.on_order_book_created, MagicMock)
expectation = OrderBookCreatedExpectation(self.on_order_book_created, tick_size, order_book_id)
self.call_expectations_manager.add_expectation(expectation)
return expectation
def expect_on_order_inserted_event(self, order_id: int, order_book_id: int, side: Side.ValueType,
price: Decimal, quantity: int,
username: str | None = None) -> OrderInsertedExpectation:
assert isinstance(self.on_order_inserted, MagicMock)
expectation = OrderInsertedExpectation(
self.on_order_inserted, order_id, order_book_id, side, price, quantity, username)
self.call_expectations_manager.add_expectation(expectation)
return expectation
def expect_on_order_cancelled_event(self, order_id: int) -> OrderCancelledExpectation:
assert isinstance(self.on_order_cancelled, MagicMock)
expectation = OrderCancelledExpectation(self.on_order_cancelled, order_id)
self.call_expectations_manager.add_expectation(expectation)
return expectation
def expect_on_trade_event(self, trade_id: int, order_book_id: int, buy_order_id: int, sell_order_id: int,
aggressive_side: Side.ValueType, price: Decimal, quantity: int) -> TradeExpectation:
assert isinstance(self.on_trade, MagicMock)
trade_expectation = TradeExpectation(self.on_trade, trade_id, order_book_id, buy_order_id, sell_order_id,
aggressive_side, price, quantity)
self.call_expectations_manager.add_expectation(trade_expectation)
return trade_expectation
def on_order_book_created(self, message: OnOrderBookCreated) -> None:
pass
def on_order_inserted(self, message: OnOrderInserted) -> None:
pass
def on_order_cancelled(self, message: OnOrderCancelled) -> None:
pass
def on_trade(self, message: OnTrade) -> None:
pass
class OrderBookClientConnectionHandlerFactory(ConnectionHandlerFactory[OrderBookTestClient]):
def __init__(self, call_expectations_manager: CallExpectationsManager) -> None:
self.call_expectations_manager = call_expectations_manager
def on_new_connection(self, socket_fd: socket.socket, ip_address: IpAddress, close_callback: Callable[[], None]) -> OrderBookTestClient:
return OrderBookTestClient(socket_fd, ip_address, close_callback, self.call_expectations_manager)
def on_connection_closed(self, connection_handler: OrderBookTestClient) -> None:
pass
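`_compare_price_levels` in the info client keys both books by `decimal_from_float(price)` so lookups are O(1) and floating-point keys with differing binary representations cannot silently mismatch. A standalone sketch of the same diff-reporting approach, using plain `Decimal -> quantity` dicts in place of the project's `PriceLevel` messages (`diff_price_levels` is an illustrative name):

```python
from decimal import Decimal


def diff_price_levels(expected: dict[Decimal, int], actual: dict[Decimal, int]) -> list[str]:
    """Compare two books keyed by price, reporting every difference at once."""
    differences = []
    # Missing prices and quantity mismatches, relative to the expected book.
    for price, quantity in expected.items():
        if price not in actual:
            differences.append(f"Missing actual price level at price {price}")
        elif actual[price] != quantity:
            differences.append(
                f"For price {price}: expected quantity {quantity}, but got {actual[price]}")
    # Levels present in the actual book that were never expected.
    for price in actual:
        if price not in expected:
            differences.append(f"Extra actual price level at price {price}")
    return differences


expected = {Decimal("100.00"): 10, Decimal("99.95"): 5}
actual = {Decimal("100.00"): 10, Decimal("99.90"): 5}
# Two differences: one missing level (99.95) and one extra level (99.90).
diffs = diff_price_levels(expected, actual)
```

Collecting all differences before asserting (as the original does with `assert not differences, "\n".join(differences)`) gives a complete picture of a broken book in one test failure instead of one field at a time.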
@@ -0,0 +1,248 @@
import logging
import socket
from decimal import Decimal
from typing import Callable
from unittest.mock import MagicMock
from connection.connection_handler import ConnectionHandler, ConnectionHandlerFactory
from connection.ip_address import IpAddress
from google.protobuf.message import Message
from proto.common_pb2 import LoginRequest, LoginResponse, MessageType, Side
from proto.execution_pb2 import (
InsertOrderRequest, InsertOrderResponse,
CancelOrderRequest, CancelOrderResponse,
)
from proto.risk_limits_pb2 import (
GetUserRiskLimitsRequest, GetUserRiskLimitsResponse,
SetUserRiskLimitsRequest, SetUserRiskLimitsResponse,
GetInstrumentRiskLimitsRequest, GetInstrumentRiskLimitsResponse,
SetInstrumentRiskLimitsRequest, SetInstrumentRiskLimitsResponse,
UserRiskLimits, InstrumentRiskLimits, RollingWindowLimit,
)
from tests.conftest import DEFAULT_PASSWORD
from tests.common.mock_expectations import CallExpectationsManager, ResponseExpectation
logger = logging.getLogger(__name__)
class RiskGatewayTestClient(ConnectionHandler):
"""Test client that speaks both the execution and risk_limits protocols."""
def __init__(self, socket_fd: socket.socket, ip_address: IpAddress,
on_close: Callable[[], None],
call_expectations_manager: CallExpectationsManager) -> None:
super().__init__(socket_fd, ip_address, on_close)
self.next_request_id = 1
self.callbacks: dict[int, Callable] = {}
self.call_expectations_manager = call_expectations_manager
self._last_order_id: int | None = None
self.on_login_response = call_expectations_manager.create_mock("on_login_response")
self.on_insert_order_response = call_expectations_manager.create_mock("on_insert_order_response")
self.on_cancel_order_response = call_expectations_manager.create_mock("on_cancel_order_response")
self.on_get_user_limits_response = call_expectations_manager.create_mock("on_get_user_limits_response")
self.on_set_user_limits_response = call_expectations_manager.create_mock("on_set_user_limits_response")
self.on_get_instrument_limits_response = call_expectations_manager.create_mock("on_get_instrument_limits_response")
self.on_set_instrument_limits_response = call_expectations_manager.create_mock("on_set_instrument_limits_response")
def on_disconnect(self) -> None:
logger.info(f"Disconnected from {self.ip_address}")
def handle_message(self, message_type: int, raw_message: bytes) -> None:
logger.info(f"Handling message of type {MessageType.Name(message_type)}")
msg: Message
if message_type == MessageType.AUTH_LOGIN_RESPONSE:
msg = LoginResponse.FromString(raw_message)
elif message_type == MessageType.EXEC_INSERT_ORDER_RESPONSE:
msg = InsertOrderResponse.FromString(raw_message)
elif message_type == MessageType.EXEC_CANCEL_ORDER_RESPONSE:
msg = CancelOrderResponse.FromString(raw_message)
elif message_type == MessageType.RISK_GET_USER_LIMITS_RESPONSE:
msg = GetUserRiskLimitsResponse.FromString(raw_message)
elif message_type == MessageType.RISK_SET_USER_LIMITS_RESPONSE:
msg = SetUserRiskLimitsResponse.FromString(raw_message)
elif message_type == MessageType.RISK_GET_INSTRUMENT_LIMITS_RESPONSE:
msg = GetInstrumentRiskLimitsResponse.FromString(raw_message)
elif message_type == MessageType.RISK_SET_INSTRUMENT_LIMITS_RESPONSE:
msg = SetInstrumentRiskLimitsResponse.FromString(raw_message)
else:
raise ValueError(f"Unexpected message type: {MessageType.Name(message_type)}")
assert hasattr(msg, "request_id")
callback = self.callbacks.pop(msg.request_id, None)
if callback:
callback(msg)
else:
raise ValueError(f"No callback for request_id: {msg.request_id}")
# ------------------------------------------------------------------
# Low-level helpers
# ------------------------------------------------------------------
def _send_request(self, message_type: int, message: Message,
callback: Callable) -> int:
request_id = self.next_request_id
self.next_request_id += 1
message.request_id = request_id # type: ignore[union-attr]
self.callbacks[request_id] = callback
self.send_message(message_type, message)
return request_id
# ------------------------------------------------------------------
# Auth
# ------------------------------------------------------------------
def test_login(self, username: str, password: str = DEFAULT_PASSWORD,
expect_success: bool = True) -> ResponseExpectation[LoginResponse]:
request = LoginRequest(username=username, password=password)
request_id = self._send_request(
MessageType.AUTH_LOGIN_REQUEST, request, self.on_login_response)
expectation = ResponseExpectation(
self.on_login_response, LoginResponse, request_id, expect_success)
self.call_expectations_manager.add_expectation(expectation)
self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
assert expectation.fulfilled, "Login response not received"
return expectation
# ------------------------------------------------------------------
# Execution
# ------------------------------------------------------------------
def test_insert_order(
self, instrument_symbol: str, side: Side.ValueType,
price: Decimal, quantity: int,
expect_success: bool = True,
) -> ResponseExpectation[InsertOrderResponse]:
request = InsertOrderRequest(
instrument_symbol=instrument_symbol, side=side,
price=float(price), quantity=quantity)
request_id = self._send_request(
MessageType.EXEC_INSERT_ORDER_REQUEST, request,
self.on_insert_order_response)
expectation = ResponseExpectation(
self.on_insert_order_response, InsertOrderResponse,
request_id, expect_success)
self.call_expectations_manager.add_expectation(expectation)
self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
assert expectation.fulfilled, "Insert order response not received"
if expect_success:
self._last_order_id = expectation.get_response().order_id
return expectation
def test_cancel_order(
self, instrument_symbol: str, order_id: int | None = None,
expect_success: bool = True,
) -> ResponseExpectation[CancelOrderResponse]:
if order_id is None:
assert self._last_order_id is not None, "No order to cancel"
order_id = self._last_order_id
request = CancelOrderRequest(
instrument_symbol=instrument_symbol, order_id=order_id)
request_id = self._send_request(
MessageType.EXEC_CANCEL_ORDER_REQUEST, request,
self.on_cancel_order_response)
expectation = ResponseExpectation(
self.on_cancel_order_response, CancelOrderResponse,
request_id, expect_success)
self.call_expectations_manager.add_expectation(expectation)
self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
assert expectation.fulfilled, "Cancel order response not received"
return expectation
# ------------------------------------------------------------------
# Risk Limits user level
# ------------------------------------------------------------------
def test_set_user_risk_limits(
self, user_risk_limits: UserRiskLimits | None = None,
expect_success: bool = True,
) -> ResponseExpectation[SetUserRiskLimitsResponse]:
request = SetUserRiskLimitsRequest()
if user_risk_limits is not None:
request.user_risk_limits.CopyFrom(user_risk_limits)
request_id = self._send_request(
MessageType.RISK_SET_USER_LIMITS_REQUEST, request,
self.on_set_user_limits_response)
expectation = ResponseExpectation(
self.on_set_user_limits_response, SetUserRiskLimitsResponse,
request_id, expect_success)
self.call_expectations_manager.add_expectation(expectation)
self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
assert expectation.fulfilled, "Set user risk limits response not received"
return expectation
def test_get_user_risk_limits(
self, expect_success: bool = True,
) -> ResponseExpectation[GetUserRiskLimitsResponse]:
request = GetUserRiskLimitsRequest()
request_id = self._send_request(
MessageType.RISK_GET_USER_LIMITS_REQUEST, request,
self.on_get_user_limits_response)
expectation = ResponseExpectation(
self.on_get_user_limits_response, GetUserRiskLimitsResponse,
request_id, expect_success)
self.call_expectations_manager.add_expectation(expectation)
self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
assert expectation.fulfilled, "Get user risk limits response not received"
return expectation
# ------------------------------------------------------------------
# Risk Limits instrument level
# ------------------------------------------------------------------
def test_set_instrument_risk_limits(
self, instrument_symbol: str,
instrument_risk_limits: InstrumentRiskLimits | None = None,
expect_success: bool = True,
) -> ResponseExpectation[SetInstrumentRiskLimitsResponse]:
request = SetInstrumentRiskLimitsRequest(instrument_symbol=instrument_symbol)
if instrument_risk_limits is not None:
request.instrument_risk_limits.CopyFrom(instrument_risk_limits)
request_id = self._send_request(
MessageType.RISK_SET_INSTRUMENT_LIMITS_REQUEST, request,
self.on_set_instrument_limits_response)
expectation = ResponseExpectation(
self.on_set_instrument_limits_response, SetInstrumentRiskLimitsResponse,
request_id, expect_success)
self.call_expectations_manager.add_expectation(expectation)
self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
assert expectation.fulfilled, "Set instrument risk limits response not received"
return expectation
def test_get_instrument_risk_limits(
self, expect_success: bool = True,
) -> ResponseExpectation[GetInstrumentRiskLimitsResponse]:
request = GetInstrumentRiskLimitsRequest()
request_id = self._send_request(
MessageType.RISK_GET_INSTRUMENT_LIMITS_REQUEST, request,
self.on_get_instrument_limits_response)
expectation = ResponseExpectation(
self.on_get_instrument_limits_response, GetInstrumentRiskLimitsResponse,
request_id, expect_success)
self.call_expectations_manager.add_expectation(expectation)
self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
assert expectation.fulfilled, "Get instrument risk limits response not received"
return expectation
class RiskGatewayClientConnectionHandlerFactory(ConnectionHandlerFactory[RiskGatewayTestClient]):
def __init__(self, call_expectations_manager: CallExpectationsManager) -> None:
self.call_expectations_manager = call_expectations_manager
def on_new_connection(self, socket_fd: socket.socket, ip_address: IpAddress,
close_callback: Callable[[], None]) -> RiskGatewayTestClient:
return RiskGatewayTestClient(
socket_fd, ip_address, close_callback, self.call_expectations_manager)
def on_connection_closed(self, connection_handler: RiskGatewayTestClient) -> None:
pass
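The `_send_request`/`handle_message` pair above implements a common async request/response router: allocate a monotonically increasing `request_id`, register the callback before the message goes on the wire, and `pop` it on response so each id fires at most once. A minimal sketch under those assumptions — the `RequestRouter` name and plain-string payload are illustrative, not the project's types:

```python
from typing import Callable


class RequestRouter:
    """Match asynchronous responses to requests via a monotonically increasing id."""

    def __init__(self) -> None:
        self.next_request_id = 1
        self.callbacks: dict[int, Callable[[str], None]] = {}

    def send(self, callback: Callable[[str], None]) -> int:
        request_id = self.next_request_id
        self.next_request_id += 1
        # Register before "sending" so a fast response can never arrive
        # for a request_id we do not yet know about.
        self.callbacks[request_id] = callback
        return request_id

    def on_response(self, request_id: int, payload: str) -> None:
        # pop() ensures each request_id is answered at most once; a second
        # response (or an unknown id) is a protocol error.
        callback = self.callbacks.pop(request_id, None)
        if callback is None:
            raise ValueError(f"No callback for request_id: {request_id}")
        callback(payload)


router = RequestRouter()
received: list[str] = []
request_id = router.send(received.append)
router.on_response(request_id, "filled")
```

Popping rather than reading the callback is what makes duplicate or stale responses fail loudly, mirroring the `ValueError` raised in `handle_message`.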
@@ -0,0 +1,139 @@
from decimal import Decimal
import logging
import pytest
from connection.tcp_connection_manager import TcpConnectionManager
from proto.common_pb2 import Instrument, Side
from tests.conftest import DEFAULT_PASSWORD
from tests.common.auth_tests import AuthenticationTests
from tests.common.mock_expectations import CallExpectationsManager
from tests.test_client.admin_test_client import (
AdminClientConnectionHandlerFactory, AdminTestClient,
)
from tests.test_client.risk_gateway_test_client import (
RiskGatewayClientConnectionHandlerFactory, RiskGatewayTestClient,
)
logger = logging.getLogger(__name__)
class TestExecutionSystem(AuthenticationTests):
PROTOCOL = "execution"
@pytest.fixture(autouse=True)
def setup(self, request: pytest.FixtureRequest) -> None:
self.test_name = request.node.name
logger.info(f"Setting up test: {self.test_name}")
self.tcp_connection_manager = TcpConnectionManager()
self.call_expectations_manager = CallExpectationsManager()
self.call_expectations_manager.setup_network(self.tcp_connection_manager)
self._client_factory = RiskGatewayClientConnectionHandlerFactory(self.call_expectations_manager)
self._admin_factory = AdminClientConnectionHandlerFactory(self.call_expectations_manager)
self._admin_client: AdminTestClient | None = None
self._next_instrument_id = 1
def _connect_unauthenticated(self) -> RiskGatewayTestClient:
return self.tcp_connection_manager.connect(self.server_address, self._client_factory)
def _connect_and_login(self, username: str | None = None) -> RiskGatewayTestClient:
"""Connect to the execution server, log in (if auth required), and return the client."""
client: RiskGatewayTestClient = self._connect_unauthenticated()
if self.auth_required:
client.test_login(username=username or self.test_name, password=DEFAULT_PASSWORD)
return client
def _get_admin_client(self) -> AdminTestClient:
if self._admin_client is None:
admin_address = self.orchestrator.get_server_address("admin")
self._admin_client = self.tcp_connection_manager.connect(
admin_address, self._admin_factory)
return self._admin_client
def _create_instrument_via_admin(self, tick_size: Decimal = Decimal("0.01")) -> str:
"""Create an instrument via admin and return its symbol."""
symbol = f"TEST.{self._next_instrument_id}"
self._next_instrument_id += 1
instrument = Instrument(
symbol=symbol, description="Test instrument",
currency="USD", multiplier=1.0)
admin = self._get_admin_client()
admin.test_create_instrument(instrument, tick_size)
return symbol
# =========================================================================
# Authentication
# =========================================================================
def test_insert_order_before_login_is_rejected(self) -> None:
"""Operations that require authentication must fail before login."""
if not self.auth_required:
pytest.skip("Component does not require authentication")
symbol = self._create_instrument_via_admin()
client = self._connect_unauthenticated()
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10,
expect_success=False)
# =========================================================================
# Insert order
# =========================================================================
def test_insert_order(self) -> None:
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
resp = client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10)
response = resp.get_response()
assert response.order_id != 0, "InsertOrderResponse must include an order_id"
assert response.timestamp != 0, "InsertOrderResponse must include a timestamp"
def test_insert_order_on_unknown_instrument(self) -> None:
self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_insert_order("UNKNOWN.SYMBOL", Side.BUY, Decimal("100.0"), 10,
expect_success=False)
def test_insert_order_matching_produces_trade(self) -> None:
symbol = self._create_instrument_via_admin()
buyer = self._connect_and_login(f"{self.test_name}_2")
seller = self._connect_and_login(f"{self.test_name}_3")
buyer.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10)
sell_resp = seller.test_insert_order(symbol, Side.SELL, Decimal("100.0"), 10)
sell_response = sell_resp.get_response()
assert len(sell_response.trade_ids) == 1, "Expected exactly one trade"
assert sell_response.traded_quantity == 10
def test_orders_pass_without_any_limits_configured(self) -> None:
"""When no risk limits are set, orders must flow through without restriction."""
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
for _ in range(10):
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 1000)
# =========================================================================
# Cancel order
# =========================================================================
def test_cancel_resting_order(self) -> None:
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10)
client.test_cancel_order(symbol)
def test_cancel_nonexistent_order(self) -> None:
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_cancel_order(symbol, order_id=999999, expect_success=False)
def test_cancel_other_users_order_is_rejected(self) -> None:
"""A user must not be able to cancel another user's order."""
symbol = self._create_instrument_via_admin()
client_a = self._connect_and_login(f"{self.test_name}_2")
client_b = self._connect_and_login(f"{self.test_name}_3")
resp_a = client_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10)
order_id_a = resp_a.get_response().order_id
client_b.test_cancel_order(symbol, order_id=order_id_a, expect_success=False)
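Throughout these clients, protobuf `float` prices are never compared with `==`; every check uses an absolute tolerance (`abs(a - b) < 1e-5`) or a `Decimal` round-trip via `decimal_from_float`. A minimal illustration of why:

```python
# Binary floats cannot represent 0.1, 0.2, or 0.3 exactly, so strict
# equality fails even though the arithmetic is "right":
expected_price = 0.3
actual_price = 0.1 + 0.2
assert actual_price != expected_price

# A small absolute tolerance absorbs the representation error:
assert abs(actual_price - expected_price) < 1e-5
```

Prices that survive a proto-serialise/deserialise round trip can drift in the last bits, so the tolerance keeps tests deterministic without masking real quantity or tick-size errors (quantities stay integers and are compared exactly).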
@@ -0,0 +1,941 @@
from decimal import Decimal
import logging
import math
import random
from unittest.mock import ANY
import pytest
from connection.tcp_connection_manager import TcpConnectionManager
from proto.common_pb2 import Instrument, Side
from proto.info_pb2 import PriceLevel, SubscriptionType
from tests.test_client.admin_test_client import AdminClientConnectionHandlerFactory, AdminTestClient
from tests.test_client.info_test_client import InfoClientConnectionHandlerFactory, InfoTestClient
from tests.test_client.order_book_test_client import OrderBookClientConnectionHandlerFactory, OrderBookTestClient
from tests.common.auth_tests import AuthenticationTests
from tests.common.mock_expectations import CallExpectationsManager
logger = logging.getLogger(__name__)
class TestInfoSystem(AuthenticationTests):
PROTOCOL = "info"
@pytest.fixture(autouse=True)
def setup(self, request: pytest.FixtureRequest) -> None:
self.test_name = request.node.name
logger.info(f"Setting up test: {self.test_name}")
self.tcp_connection_manager = TcpConnectionManager()
self.call_expectations_manager = CallExpectationsManager()
self.call_expectations_manager.setup_network(self.tcp_connection_manager)
self.info_client_factory = InfoClientConnectionHandlerFactory(self.call_expectations_manager)
self._admin_factory = AdminClientConnectionHandlerFactory(self.call_expectations_manager)
self._order_book_factory = OrderBookClientConnectionHandlerFactory(self.call_expectations_manager)
self._admin_client: AdminTestClient | None = None
self._order_book_client: OrderBookTestClient | None = None
self._next_instrument_id = 1
def _connect_unauthenticated(self) -> InfoTestClient:
return self.tcp_connection_manager.connect(self.server_address, self.info_client_factory)
def _connect_and_login(self, username: str | None = None) -> InfoTestClient:
"""Connects to info, logs in (if auth required), and expects OnInstrument for existing instruments."""
client = self._connect_unauthenticated()
if self.auth_required:
client.test_login(username=username or self.test_name)
if self.instrument is not None:
client.expect_on_instrument_event(self.instrument, tick_size=self.tick_size, order_book_id=self.order_book_id)
return client
def _get_admin_client(self) -> AdminTestClient:
if self._admin_client is None:
admin_address = self.orchestrator.get_server_address("admin")
self._admin_client = self.tcp_connection_manager.connect(admin_address, self._admin_factory)
return self._admin_client
def _get_order_book_client(self) -> OrderBookTestClient:
if self._order_book_client is None:
ob_address = self.orchestrator.get_server_address("order_book")
self._order_book_client = self.tcp_connection_manager.connect(ob_address, self._order_book_factory)
if self.order_book_id is not None:
self._order_book_client.expect_on_order_book_created_event(
tick_size=self.tick_size, order_book_id=self.order_book_id)
self.call_expectations_manager.verify_expectations(assert_no_pending_calls=False)
self._order_book_client._last_order_book_id = self.order_book_id
return self._order_book_client
def _create_instrument_via_admin(self, instrument: Instrument | None = None, tick_size: Decimal = Decimal("0.01")) -> tuple[Instrument, Decimal, int]:
"""Creates a test instrument via admin. Sets self.instrument, self.tick_size, self.order_book_id.
Returns (instrument, tick_size, order_book_id) for convenience in multi-instrument tests."""
self.tick_size = tick_size
if instrument is None:
self.instrument = Instrument(
symbol=f"TEST.{self._next_instrument_id}",
description="Test instrument",
currency="USD",
multiplier=1.0)
self._next_instrument_id += 1
else:
self.instrument = instrument
admin_client = self._get_admin_client()
response_expectation = admin_client.test_create_instrument(self.instrument, self.tick_size)
response = response_expectation.get_response()
self.order_book_id = response.order_book_id
return self.instrument, self.tick_size, self.order_book_id
def _expect_book_update(self, client: InfoTestClient, subscription_type: SubscriptionType.ValueType,
instrument_symbol: str,
expected_bids: dict[Decimal, int],
expected_asks: dict[Decimal, int]) -> None:
"""Set up the correct TOB or PDB expectation based on subscription type."""
if subscription_type == SubscriptionType.TOP_OF_BOOK:
best_bid = None
best_ask = None
if expected_bids:
best_price = max(expected_bids)
best_bid = PriceLevel(price=float(best_price), quantity=expected_bids[best_price])
if expected_asks:
best_price = min(expected_asks)
best_ask = PriceLevel(price=float(best_price), quantity=expected_asks[best_price])
client.expect_on_top_of_book_event(instrument_symbol, best_bid=best_bid, best_ask=best_ask)
elif subscription_type == SubscriptionType.PRICE_DEPTH_BOOK:
client.expect_on_price_depth_book_event(
instrument_symbol,
bids=[PriceLevel(price=float(p), quantity=q) for p, q in expected_bids.items()],
asks=[PriceLevel(price=float(p), quantity=q) for p, q in expected_asks.items()])
def _insert_orders(self, side: Side.ValueType | None, orders_per_level: int, number_of_levels: int,
best_bid_price: Decimal, best_ask_price: Decimal, quantity_per_order: int,
add_cancellation: bool = False) -> tuple[dict[Decimal, int], dict[Decimal, int]]:
"""Insert orders into the real order book and return the expected price level state."""
order_book_client = self._get_order_book_client()
def next_side(i: int) -> Side.ValueType:
if side is not None:
return side
return Side.BUY if i % 2 == 0 else Side.SELL
def next_price_adjustment(i: int) -> Decimal:
# if side is None we always alternate between buy and sell, so the adjustment needs to be divided by 2
tick_adjustment = math.floor(i / 2) if side is None else i
return self.tick_size * tick_adjustment
price_levels_by_side: dict[Side.ValueType, dict[Decimal, int]] = {Side.BUY: {}, Side.SELL: {}}
for level_index in range(number_of_levels):
level_side = next_side(level_index)
price_adjustment = next_price_adjustment(level_index)
if level_side == Side.BUY:
level_price = best_bid_price - price_adjustment
else:
level_price = best_ask_price + price_adjustment
price_levels_by_side[level_side][level_price] = quantity_per_order * orders_per_level
for i in range(orders_per_level):
order_book_client.test_insert_order(
side=level_side, price=level_price, quantity=quantity_per_order,
order_book_id=self.order_book_id, username=f"{self.test_name}_orders")
is_top_of_book_level = price_adjustment == 0
if add_cancellation and not is_top_of_book_level and (level_index + i) % 2 == 0:
order_book_client.test_cancel_order(order_book_id=self.order_book_id)
price_levels_by_side[level_side][level_price] -= quantity_per_order
if price_levels_by_side[level_side][level_price] == 0:
del price_levels_by_side[level_side][level_price]
bids = price_levels_by_side[Side.BUY]
asks = price_levels_by_side[Side.SELL]
return bids, asks
def test_create_instrument(self) -> None:
self._create_instrument_via_admin(
instrument=Instrument(symbol="AAPL", description="Apple", currency="USD", multiplier=1.0))
client = self._connect_and_login(username="test_create_instrument")
self.call_expectations_manager.verify_expectations()
self.call_expectations_manager.verify_no_unexpected_calls()
@pytest.mark.parametrize(
"subscription_type, side, orders_per_level, number_of_levels",
[
pytest.param(SubscriptionType.TOP_OF_BOOK, Side.BUY, 1, 1, id="TOB 1 buy order"),
pytest.param(SubscriptionType.TOP_OF_BOOK, Side.SELL, 1, 1, id="TOB 1 sell order"),
pytest.param(SubscriptionType.TOP_OF_BOOK, Side.BUY, 2, 1, id="TOB 2 buy orders, 1 level"),
pytest.param(SubscriptionType.TOP_OF_BOOK, Side.SELL, 2, 1, id="TOB 2 sell orders, 1 level"),
pytest.param(SubscriptionType.TOP_OF_BOOK, Side.BUY, 1, 2, id="TOB 1 buy order by 2 levels"),
pytest.param(SubscriptionType.TOP_OF_BOOK, Side.SELL, 1, 2, id="TOB 1 sell order by 2 levels"),
pytest.param(SubscriptionType.TOP_OF_BOOK, None, 1, 2, id="TOB 1 order by level, both sides"),
pytest.param(SubscriptionType.TOP_OF_BOOK, None, 2, 2, id="TOB 2 orders by 2 levels, both sides"),
pytest.param(SubscriptionType.TOP_OF_BOOK, None, 5, 2, id="TOB 5 orders by 2 levels, both sides (unbalanced)"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, Side.BUY, 1, 1, id="PDB 1 buy order"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, Side.SELL, 1, 1, id="PDB 1 sell order"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, Side.BUY, 2, 1, id="PDB 2 buy orders, 1 level"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, Side.SELL, 2, 1, id="PDB 2 sell orders, 1 level"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, Side.BUY, 1, 2, id="PDB 1 buy order by 2 levels"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, Side.SELL, 1, 2, id="PDB 1 sell order by 2 levels"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, None, 1, 2, id="PDB 1 order by level, both sides"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, None, 2, 2, id="PDB 2 orders by 2 levels, both sides"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, None, 5, 2, id="PDB 5 orders by 2 levels, both sides (unbalanced)"),
]
)
def test_order_book_on_subscribe_with_insert(self, subscription_type: SubscriptionType.ValueType, side: Side.ValueType | None, orders_per_level: int, number_of_levels: int) -> None:
self._create_instrument_via_admin()
client = self._connect_and_login()
best_bid_price = random.randint(100, 10000) * self.tick_size
best_ask_price = best_bid_price + self.tick_size
quantity_per_order = random.randint(1, 1000)
expected_bids, expected_asks = self._insert_orders(
side, orders_per_level, number_of_levels,
best_bid_price, best_ask_price, quantity_per_order)
logger.info("Checking that no updates were sent before subscribing")
self.call_expectations_manager.verify_expectations()
self.call_expectations_manager.verify_no_unexpected_calls()
if subscription_type == SubscriptionType.TOP_OF_BOOK:
best_bid = PriceLevel(price=float(best_bid_price), quantity=expected_bids[best_bid_price]) if side is None or side == Side.BUY else None
best_ask = PriceLevel(price=float(best_ask_price), quantity=expected_asks[best_ask_price]) if side is None or side == Side.SELL else None
client.test_subscribe_to_order_book(instrument_symbol=self.instrument.symbol, type=SubscriptionType.TOP_OF_BOOK)
client.expect_on_top_of_book_event(self.instrument.symbol, best_bid=best_bid, best_ask=best_ask)
elif subscription_type == SubscriptionType.PRICE_DEPTH_BOOK:
client.test_subscribe_to_order_book(instrument_symbol=self.instrument.symbol, type=SubscriptionType.PRICE_DEPTH_BOOK)
client.expect_on_price_depth_book_event(
self.instrument.symbol,
bids=[PriceLevel(price=float(price), quantity=quantity) for price, quantity in expected_bids.items()],
asks=[PriceLevel(price=float(price), quantity=quantity) for price, quantity in expected_asks.items()])
else:
raise ValueError(f"Unexpected subscription type: {subscription_type}")
self.call_expectations_manager.verify_expectations()
@pytest.mark.parametrize(
"subscription_type, side, orders_per_level, number_of_levels",
[
pytest.param(SubscriptionType.TOP_OF_BOOK, Side.BUY, 1, 1, id="TOB 1 buy order"),
pytest.param(SubscriptionType.TOP_OF_BOOK, Side.SELL, 1, 1, id="TOB 1 sell order"),
pytest.param(SubscriptionType.TOP_OF_BOOK, Side.BUY, 2, 1, id="TOB 2 buy orders, 1 level"),
pytest.param(SubscriptionType.TOP_OF_BOOK, Side.SELL, 2, 1, id="TOB 2 sell orders, 1 level"),
pytest.param(SubscriptionType.TOP_OF_BOOK, Side.BUY, 1, 2, id="TOB 1 buy order by 2 levels"),
pytest.param(SubscriptionType.TOP_OF_BOOK, Side.SELL, 1, 2, id="TOB 1 sell order by 2 levels"),
pytest.param(SubscriptionType.TOP_OF_BOOK, None, 1, 2, id="TOB 1 order by level, both sides"),
pytest.param(SubscriptionType.TOP_OF_BOOK, None, 2, 2, id="TOB 2 orders by 2 levels, both sides"),
pytest.param(SubscriptionType.TOP_OF_BOOK, None, 5, 2, id="TOB 5 orders by 2 levels, both sides (unbalanced)"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, Side.BUY, 1, 1, id="PDB 1 buy order"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, Side.SELL, 1, 1, id="PDB 1 sell order"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, Side.BUY, 2, 1, id="PDB 2 buy orders, 1 level"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, Side.SELL, 2, 1, id="PDB 2 sell orders, 1 level"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, Side.BUY, 1, 2, id="PDB 1 buy order by 2 levels"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, Side.SELL, 1, 2, id="PDB 1 sell order by 2 levels"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, None, 1, 2, id="PDB 1 order by level, both sides"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, None, 2, 2, id="PDB 2 orders by 2 levels, both sides"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, None, 5, 2, id="PDB 5 orders by 2 levels, both sides (unbalanced)"),
]
)
def test_order_book_on_subscribe_with_insert_and_cancel(self, subscription_type: SubscriptionType.ValueType, side: Side.ValueType | None, orders_per_level: int, number_of_levels: int) -> None:
self._create_instrument_via_admin()
client = self._connect_and_login()
best_bid_price = random.randint(100, 10000) * self.tick_size
best_ask_price = best_bid_price + self.tick_size
quantity_per_order = random.randint(1, 1000)
expected_bids, expected_asks = self._insert_orders(
side, orders_per_level, number_of_levels,
best_bid_price, best_ask_price, quantity_per_order,
add_cancellation=True)
logger.info("Checking that no updates were sent before subscribing")
self.call_expectations_manager.verify_expectations()
self.call_expectations_manager.verify_no_unexpected_calls()
if subscription_type == SubscriptionType.TOP_OF_BOOK:
best_bid = PriceLevel(price=float(best_bid_price), quantity=expected_bids[best_bid_price]) if side is None or side == Side.BUY else None
best_ask = PriceLevel(price=float(best_ask_price), quantity=expected_asks[best_ask_price]) if side is None or side == Side.SELL else None
client.test_subscribe_to_order_book(instrument_symbol=self.instrument.symbol, type=SubscriptionType.TOP_OF_BOOK)
client.expect_on_top_of_book_event(self.instrument.symbol, best_bid=best_bid, best_ask=best_ask)
elif subscription_type == SubscriptionType.PRICE_DEPTH_BOOK:
client.test_subscribe_to_order_book(instrument_symbol=self.instrument.symbol, type=SubscriptionType.PRICE_DEPTH_BOOK)
client.expect_on_price_depth_book_event(
self.instrument.symbol,
bids=[PriceLevel(price=float(price), quantity=quantity) for price, quantity in expected_bids.items()],
asks=[PriceLevel(price=float(price), quantity=quantity) for price, quantity in expected_asks.items()])
else:
raise ValueError(f"Unexpected subscription type: {subscription_type}")
self.call_expectations_manager.verify_expectations()
@pytest.mark.parametrize(
"subscription_type",
[
pytest.param(SubscriptionType.TOP_OF_BOOK, id="TOB"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, id="PDB"),
]
)
def test_order_book_on_subscribe_in_cross_then_trade(self, subscription_type: SubscriptionType.ValueType) -> None:
self._create_instrument_via_admin()
client = self._connect_and_login()
order_book_client = self._get_order_book_client()
logger.info("Inserting passive buy order")
crossing_price = Decimal("100.0")
order_book_client.test_insert_order(
side=Side.BUY, price=crossing_price, quantity=10,
order_book_id=self.order_book_id, username=f"{self.test_name}_buyer")
logger.info("Checking that no updates were sent before subscribing")
self.call_expectations_manager.verify_expectations()
self.call_expectations_manager.verify_no_unexpected_calls()
client.test_subscribe_to_order_book(instrument_symbol=self.instrument.symbol, type=subscription_type)
if subscription_type == SubscriptionType.TOP_OF_BOOK:
client.expect_on_top_of_book_event(self.instrument.symbol, best_bid=PriceLevel(price=float(crossing_price), quantity=10), best_ask=None)
elif subscription_type == SubscriptionType.PRICE_DEPTH_BOOK:
client.expect_on_price_depth_book_event(self.instrument.symbol, bids=[PriceLevel(price=float(crossing_price), quantity=10)], asks=[])
else:
raise ValueError(f"Unexpected subscription type: {subscription_type}")
self.call_expectations_manager.verify_expectations()
logger.info("Inserting crossing sell order (triggers trade)")
order_book_client.test_insert_order(
side=Side.SELL, price=crossing_price, quantity=10,
order_book_id=self.order_book_id, username=f"{self.test_name}_seller")
order_book_client.expect_on_trade_event(
trade_id=ANY, order_book_id=self.order_book_id,
buy_order_id=ANY, sell_order_id=ANY,
aggressive_side=Side.SELL, price=float(crossing_price), quantity=10)
if subscription_type == SubscriptionType.TOP_OF_BOOK:
client.expect_on_top_of_book_event(self.instrument.symbol, best_bid=None, best_ask=None)
elif subscription_type == SubscriptionType.PRICE_DEPTH_BOOK:
client.expect_on_price_depth_book_event(self.instrument.symbol, bids=[], asks=[])
else:
raise ValueError(f"Unexpected subscription type: {subscription_type}")
self.call_expectations_manager.verify_expectations()
# ----------------------------------------------------------------
# Instrument lifecycle push to already-connected clients
# ----------------------------------------------------------------
def test_instrument_pushed_to_already_connected_client(self) -> None:
"""Client connects (and logs in if auth required) before any instrument
exists, then instrument is created via admin. The info component should
push OnInstrument."""
client = self._connect_unauthenticated()
if self.auth_required:
client.test_login(username=self.test_name)
self.call_expectations_manager.verify_no_unexpected_calls()
self._create_instrument_via_admin()
client.expect_on_instrument_event(
self.instrument, tick_size=self.tick_size, order_book_id=self.order_book_id)
self.call_expectations_manager.verify_expectations()
self.call_expectations_manager.verify_no_unexpected_calls()
def test_multiple_instruments_received_on_login(self) -> None:
"""Create several instruments, then connect (and log in if auth
required) a new client. All instruments should be delivered as
OnInstrument events."""
instruments: list[tuple[Instrument, Decimal, int]] = []
for _ in range(3):
instruments.append(self._create_instrument_via_admin())
client = self._connect_unauthenticated()
if self.auth_required:
client.test_login(username=self.test_name)
for inst, tick, ob_id in instruments:
client.expect_on_instrument_event(inst, tick_size=tick, order_book_id=ob_id)
self.call_expectations_manager.verify_expectations()
self.call_expectations_manager.verify_no_unexpected_calls()
def test_instrument_pushed_to_multiple_connected_clients(self) -> None:
"""Two clients are connected before the instrument is created.
Both should receive the OnInstrument event."""
client1 = self._connect_unauthenticated()
if self.auth_required:
client1.test_login(username=self.test_name)
client2 = self._connect_unauthenticated()
if self.auth_required:
client2.test_login(username=f"{self.test_name}_2")
self.call_expectations_manager.verify_no_unexpected_calls()
self._create_instrument_via_admin()
client1.expect_on_instrument_event(
self.instrument, tick_size=self.tick_size, order_book_id=self.order_book_id)
client2.expect_on_instrument_event(
self.instrument, tick_size=self.tick_size, order_book_id=self.order_book_id)
self.call_expectations_manager.verify_expectations()
# ----------------------------------------------------------------
# Subscription edge cases
# ----------------------------------------------------------------
def test_subscribe_to_nonexistent_instrument(self) -> None:
"""Subscribing to an unknown symbol should return an error."""
self._create_instrument_via_admin()
client = self._connect_and_login()
self.call_expectations_manager.verify_expectations()
client.test_subscribe_to_order_book(
instrument_symbol="NONEXISTENT", type=SubscriptionType.TOP_OF_BOOK,
expect_success=False)
# ----------------------------------------------------------------
# Live updates order insert
# ----------------------------------------------------------------
@pytest.mark.parametrize(
"subscription_type",
[
pytest.param(SubscriptionType.TOP_OF_BOOK, id="TOB"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, id="PDB"),
]
)
def test_live_update_on_order_insert(self, subscription_type: SubscriptionType.ValueType) -> None:
"""Subscribe to an empty order book, then insert an order.
The subscriber should receive a book update with the new order."""
self._create_instrument_via_admin()
client = self._connect_and_login()
self.call_expectations_manager.verify_expectations()
client.test_subscribe_to_order_book(
instrument_symbol=self.instrument.symbol, type=subscription_type)
price = Decimal("100.0")
quantity = 10
self._expect_book_update(client, subscription_type, self.instrument.symbol,
expected_bids={price: quantity}, expected_asks={})
order_book_client = self._get_order_book_client()
order_book_client.test_insert_order(
side=Side.BUY, price=price, quantity=quantity,
order_book_id=self.order_book_id, username=f"{self.test_name}_orders")
self.call_expectations_manager.verify_expectations()
# ----------------------------------------------------------------
# Live updates order cancel
# ----------------------------------------------------------------
@pytest.mark.parametrize(
"subscription_type",
[
pytest.param(SubscriptionType.TOP_OF_BOOK, id="TOB"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, id="PDB"),
]
)
def test_live_update_on_order_cancel(self, subscription_type: SubscriptionType.ValueType) -> None:
"""Subscribe, insert an order (verify update), then cancel it.
The book should go back to empty."""
self._create_instrument_via_admin()
client = self._connect_and_login()
self.call_expectations_manager.verify_expectations()
client.test_subscribe_to_order_book(
instrument_symbol=self.instrument.symbol, type=subscription_type)
price = Decimal("100.0")
quantity = 10
self._expect_book_update(client, subscription_type, self.instrument.symbol,
expected_bids={price: quantity}, expected_asks={})
order_book_client = self._get_order_book_client()
order_book_client.test_insert_order(
side=Side.BUY, price=price, quantity=quantity,
order_book_id=self.order_book_id, username=f"{self.test_name}_orders")
self.call_expectations_manager.verify_expectations()
self._expect_book_update(client, subscription_type, self.instrument.symbol,
expected_bids={}, expected_asks={})
order_book_client.test_cancel_order(order_book_id=self.order_book_id)
self.call_expectations_manager.verify_expectations()
# ----------------------------------------------------------------
# Live updates full trade clears book
# ----------------------------------------------------------------
@pytest.mark.parametrize(
"subscription_type",
[
pytest.param(SubscriptionType.TOP_OF_BOOK, id="TOB"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, id="PDB"),
]
)
def test_live_update_after_full_trade(self, subscription_type: SubscriptionType.ValueType) -> None:
"""Subscribe, insert a passive buy, then a crossing sell that fully
trades. The book should end up empty."""
self._create_instrument_via_admin()
client = self._connect_and_login()
self.call_expectations_manager.verify_expectations()
client.test_subscribe_to_order_book(
instrument_symbol=self.instrument.symbol, type=subscription_type)
price = Decimal("100.0")
quantity = 10
self._expect_book_update(client, subscription_type, self.instrument.symbol,
expected_bids={price: quantity}, expected_asks={})
order_book_client = self._get_order_book_client()
order_book_client.test_insert_order(
side=Side.BUY, price=price, quantity=quantity,
order_book_id=self.order_book_id, username=f"{self.test_name}_buyer")
self.call_expectations_manager.verify_expectations()
self._expect_book_update(client, subscription_type, self.instrument.symbol,
expected_bids={}, expected_asks={})
order_book_client.test_insert_order(
side=Side.SELL, price=price, quantity=quantity,
order_book_id=self.order_book_id, username=f"{self.test_name}_seller")
order_book_client.expect_on_trade_event(
trade_id=ANY, order_book_id=self.order_book_id,
buy_order_id=ANY, sell_order_id=ANY,
aggressive_side=Side.SELL, price=float(price), quantity=quantity)
self.call_expectations_manager.verify_expectations()
# ----------------------------------------------------------------
# Live updates partial trade
# ----------------------------------------------------------------
@pytest.mark.parametrize(
"subscription_type",
[
pytest.param(SubscriptionType.TOP_OF_BOOK, id="TOB"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, id="PDB"),
]
)
def test_live_update_after_partial_trade(self, subscription_type: SubscriptionType.ValueType) -> None:
"""Subscribe, insert a large passive buy (qty 10), then a smaller
crossing sell (qty 4). After the partial fill the remaining buy
quantity (6) should be reflected in the book update."""
self._create_instrument_via_admin()
client = self._connect_and_login()
self.call_expectations_manager.verify_expectations()
client.test_subscribe_to_order_book(
instrument_symbol=self.instrument.symbol, type=subscription_type)
price = Decimal("100.0")
buy_qty = 10
sell_qty = 4
remaining = buy_qty - sell_qty
self._expect_book_update(client, subscription_type, self.instrument.symbol,
expected_bids={price: buy_qty}, expected_asks={})
order_book_client = self._get_order_book_client()
order_book_client.test_insert_order(
side=Side.BUY, price=price, quantity=buy_qty,
order_book_id=self.order_book_id, username=f"{self.test_name}_buyer")
self.call_expectations_manager.verify_expectations()
self._expect_book_update(client, subscription_type, self.instrument.symbol,
expected_bids={price: remaining}, expected_asks={})
order_book_client.test_insert_order(
side=Side.SELL, price=price, quantity=sell_qty,
order_book_id=self.order_book_id, username=f"{self.test_name}_seller")
order_book_client.expect_on_trade_event(
trade_id=ANY, order_book_id=self.order_book_id,
buy_order_id=ANY, sell_order_id=ANY,
aggressive_side=Side.SELL, price=float(price), quantity=sell_qty)
self.call_expectations_manager.verify_expectations()
# ----------------------------------------------------------------
# Live updates multiple sequential inserts
# ----------------------------------------------------------------
@pytest.mark.parametrize(
"subscription_type",
[
pytest.param(SubscriptionType.TOP_OF_BOOK, id="TOB"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, id="PDB"),
]
)
def test_live_update_multiple_sequential_inserts(self, subscription_type: SubscriptionType.ValueType) -> None:
"""Subscribe then insert several orders one by one, verifying
the book update after each insert."""
self._create_instrument_via_admin()
client = self._connect_and_login()
self.call_expectations_manager.verify_expectations()
client.test_subscribe_to_order_book(
instrument_symbol=self.instrument.symbol, type=subscription_type)
order_book_client = self._get_order_book_client()
cumulative_bids: dict[Decimal, int] = {}
prices = [Decimal("100.0"), Decimal("99.99"), Decimal("99.98")]
quantity = 5
for price in prices:
cumulative_bids[price] = quantity
self._expect_book_update(client, subscription_type, self.instrument.symbol,
expected_bids=dict(cumulative_bids), expected_asks={})
order_book_client.test_insert_order(
side=Side.BUY, price=price, quantity=quantity,
order_book_id=self.order_book_id, username=f"{self.test_name}_orders")
self.call_expectations_manager.verify_expectations()
self.call_expectations_manager.verify_no_unexpected_calls()
# ----------------------------------------------------------------
# Multi-client both subscribers receive updates
# ----------------------------------------------------------------
@pytest.mark.parametrize(
"subscription_type",
[
pytest.param(SubscriptionType.TOP_OF_BOOK, id="TOB"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, id="PDB"),
]
)
def test_multiple_clients_receive_updates(self, subscription_type: SubscriptionType.ValueType) -> None:
"""Two clients subscribe to the same instrument. Both should
receive book updates when an order is inserted."""
self._create_instrument_via_admin()
client1 = self._connect_and_login(username=self.test_name)
client2 = self._connect_and_login(username=f"{self.test_name}_2")
self.call_expectations_manager.verify_expectations()
client1.test_subscribe_to_order_book(
instrument_symbol=self.instrument.symbol, type=subscription_type)
client2.test_subscribe_to_order_book(
instrument_symbol=self.instrument.symbol, type=subscription_type)
price = Decimal("100.0")
quantity = 10
self._expect_book_update(client1, subscription_type, self.instrument.symbol,
expected_bids={price: quantity}, expected_asks={})
self._expect_book_update(client2, subscription_type, self.instrument.symbol,
expected_bids={price: quantity}, expected_asks={})
order_book_client = self._get_order_book_client()
order_book_client.test_insert_order(
side=Side.BUY, price=price, quantity=quantity,
order_book_id=self.order_book_id, username=f"{self.test_name}_orders")
self.call_expectations_manager.verify_expectations()
# ----------------------------------------------------------------
# Multi-client late subscriber gets current snapshot
# ----------------------------------------------------------------
@pytest.mark.parametrize(
"subscription_type",
[
pytest.param(SubscriptionType.TOP_OF_BOOK, id="TOB"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, id="PDB"),
]
)
def test_late_subscriber_gets_current_snapshot(self, subscription_type: SubscriptionType.ValueType) -> None:
"""Client A subscribes and receives updates as orders are placed.
Client B subscribes later and should receive a snapshot reflecting
the current book state, including all previous inserts."""
self._create_instrument_via_admin()
client_a = self._connect_and_login(username=self.test_name)
self.call_expectations_manager.verify_expectations()
client_a.test_subscribe_to_order_book(
instrument_symbol=self.instrument.symbol, type=subscription_type)
bid_price = Decimal("100.0")
ask_price = Decimal("100.01")
quantity = 10
self._expect_book_update(client_a, subscription_type, self.instrument.symbol,
expected_bids={bid_price: quantity}, expected_asks={})
order_book_client = self._get_order_book_client()
order_book_client.test_insert_order(
side=Side.BUY, price=bid_price, quantity=quantity,
order_book_id=self.order_book_id, username=f"{self.test_name}_orders")
self.call_expectations_manager.verify_expectations()
self._expect_book_update(client_a, subscription_type, self.instrument.symbol,
expected_bids={bid_price: quantity},
expected_asks={ask_price: quantity})
order_book_client.test_insert_order(
side=Side.SELL, price=ask_price, quantity=quantity,
order_book_id=self.order_book_id, username=f"{self.test_name}_orders")
self.call_expectations_manager.verify_expectations()
logger.info("Late subscriber (client B) connects and subscribes")
client_b = self._connect_and_login(username=f"{self.test_name}_2")
self.call_expectations_manager.verify_expectations()
self._expect_book_update(client_b, subscription_type, self.instrument.symbol,
expected_bids={bid_price: quantity},
expected_asks={ask_price: quantity})
client_b.test_subscribe_to_order_book(
instrument_symbol=self.instrument.symbol, type=subscription_type)
self.call_expectations_manager.verify_expectations()
# ----------------------------------------------------------------
# Multi-client mixed subscription types on the same instrument
# ----------------------------------------------------------------
@pytest.mark.parametrize(
"side",
[
pytest.param(Side.BUY, id="BUY"),
pytest.param(Side.SELL, id="SELL"),
]
)
def test_mixed_subscription_types_same_instrument(self, side: Side.ValueType) -> None:
"""One client subscribes with TOP_OF_BOOK and another with
PRICE_DEPTH_BOOK to the same instrument. Each client should
receive the appropriate update format when an order is inserted."""
self._create_instrument_via_admin()
self._tob_client = self._connect_and_login(username=self.test_name)
self._pdb_client = self._connect_and_login(username=f"{self.test_name}_2")
self.call_expectations_manager.verify_expectations()
self._tob_client.test_subscribe_to_order_book(
instrument_symbol=self.instrument.symbol, type=SubscriptionType.TOP_OF_BOOK)
self._pdb_client.test_subscribe_to_order_book(
instrument_symbol=self.instrument.symbol, type=SubscriptionType.PRICE_DEPTH_BOOK)
self._mixed_best_price = Decimal("100.0")
self._mixed_quantity = 10
best_bid = PriceLevel(price=float(self._mixed_best_price), quantity=self._mixed_quantity) if side == Side.BUY else None
best_ask = PriceLevel(price=float(self._mixed_best_price), quantity=self._mixed_quantity) if side == Side.SELL else None
expected_bids = [PriceLevel(price=float(self._mixed_best_price), quantity=self._mixed_quantity)] if side == Side.BUY else []
expected_asks = [PriceLevel(price=float(self._mixed_best_price), quantity=self._mixed_quantity)] if side == Side.SELL else []
self._tob_client.expect_on_top_of_book_event(
self.instrument.symbol, best_bid=best_bid, best_ask=best_ask)
self._pdb_client.expect_on_price_depth_book_event(
self.instrument.symbol, bids=expected_bids, asks=expected_asks)
order_book_client = self._get_order_book_client()
order_book_client.test_insert_order(
side=side, price=self._mixed_best_price, quantity=self._mixed_quantity,
order_book_id=self.order_book_id, username=f"{self.test_name}_orders")
self.call_expectations_manager.verify_expectations()
@pytest.mark.parametrize(
"side",
[
pytest.param(Side.BUY, id="BUY"),
pytest.param(Side.SELL, id="SELL"),
]
)
def test_mixed_subscriptions_tob_unchanged_after_non_best_insert(self, side: Side.ValueType) -> None:
"""After the first order establishes the best price, insert a second
order at a worse price. The PDB subscriber should see a new level
added while the TOB subscriber sees the same best price/quantity."""
self.test_mixed_subscription_types_same_instrument(side)
worse_price = (self._mixed_best_price - self.tick_size) if side == Side.BUY else (self._mixed_best_price + self.tick_size)
worse_qty = 5
best_bid = PriceLevel(price=float(self._mixed_best_price), quantity=self._mixed_quantity) if side == Side.BUY else None
best_ask = PriceLevel(price=float(self._mixed_best_price), quantity=self._mixed_quantity) if side == Side.SELL else None
self._tob_client.expect_on_top_of_book_event(
self.instrument.symbol, best_bid=best_bid, best_ask=best_ask)
if side == Side.BUY:
self._pdb_client.expect_on_price_depth_book_event(
self.instrument.symbol,
bids=[
PriceLevel(price=float(self._mixed_best_price), quantity=self._mixed_quantity),
PriceLevel(price=float(worse_price), quantity=worse_qty),
],
asks=[])
else:
self._pdb_client.expect_on_price_depth_book_event(
self.instrument.symbol,
bids=[],
asks=[
PriceLevel(price=float(self._mixed_best_price), quantity=self._mixed_quantity),
PriceLevel(price=float(worse_price), quantity=worse_qty),
])
order_book_client = self._get_order_book_client()
order_book_client.test_insert_order(
side=side, price=worse_price, quantity=worse_qty,
order_book_id=self.order_book_id, username=f"{self.test_name}_orders")
self.call_expectations_manager.verify_expectations()
# ----------------------------------------------------------------
# Subscription isolation: different instruments
# ----------------------------------------------------------------
def test_subscriptions_to_different_instruments_are_isolated(self) -> None:
"""Subscribe to instrument A only. Insert orders for both A and B.
The subscriber should only receive updates for instrument A."""
inst_a, tick_a, ob_id_a = self._create_instrument_via_admin()
client = self._connect_and_login()
self.call_expectations_manager.verify_expectations()
client.test_subscribe_to_order_book(
instrument_symbol=inst_a.symbol, type=SubscriptionType.TOP_OF_BOOK)
order_book_client = self._get_order_book_client()
price_a = Decimal("100.0")
client.expect_on_top_of_book_event(
inst_a.symbol,
best_bid=PriceLevel(price=float(price_a), quantity=10),
best_ask=None)
order_book_client.test_insert_order(
side=Side.BUY, price=price_a, quantity=10,
order_book_id=ob_id_a, username=f"{self.test_name}_orders")
self.call_expectations_manager.verify_expectations()
order_book_client.expect_on_order_book_created_event(
tick_size=Decimal("0.02"), order_book_id=None)
_, _, ob_id_b = self._create_instrument_via_admin(tick_size=Decimal("0.02"))
inst_b = self.instrument
client.expect_on_instrument_event(
inst_b, tick_size=self.tick_size, order_book_id=ob_id_b)
self.call_expectations_manager.verify_expectations()
order_book_client.test_insert_order(
side=Side.BUY, price=Decimal("200.0"), quantity=20,
order_book_id=ob_id_b, username=f"{self.test_name}_orders")
self.call_expectations_manager.verify_expectations()
self.call_expectations_manager.verify_no_unexpected_calls()
# ----------------------------------------------------------------
# Unsubscribed client should not get order book updates
# ----------------------------------------------------------------
@pytest.mark.parametrize(
"subscription_type",
[
pytest.param(SubscriptionType.TOP_OF_BOOK, id="TOB"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, id="PDB"),
]
)
def test_unsubscribed_client_receives_no_updates(self, subscription_type: SubscriptionType.ValueType) -> None:
"""A client that is logged in but has not subscribed should not
receive any TOB or PDB updates when the order book changes."""
self._create_instrument_via_admin()
client = self._connect_and_login()
self.call_expectations_manager.verify_expectations()
order_book_client = self._get_order_book_client()
order_book_client.test_insert_order(
side=Side.BUY, price=Decimal("100.0"), quantity=10,
order_book_id=self.order_book_id, username=f"{self.test_name}_orders")
self.call_expectations_manager.verify_expectations()
self.call_expectations_manager.verify_no_unexpected_calls()
# ----------------------------------------------------------------
# Live update: new instrument created while subscribed to another
# ----------------------------------------------------------------
def test_new_instrument_while_subscribed_to_another(self) -> None:
"""While subscribed to instrument A, a new instrument B is created.
The client should receive OnInstrument for B but no spurious
book updates for A."""
self._create_instrument_via_admin()
client = self._connect_and_login()
self.call_expectations_manager.verify_expectations()
client.test_subscribe_to_order_book(
instrument_symbol=self.instrument.symbol, type=SubscriptionType.TOP_OF_BOOK)
order_book_client = self._get_order_book_client()
order_book_client.expect_on_order_book_created_event(
tick_size=Decimal("0.05"), order_book_id=None)
self._create_instrument_via_admin(tick_size=Decimal("0.05"))
client.expect_on_instrument_event(
self.instrument, tick_size=self.tick_size, order_book_id=self.order_book_id)
self.call_expectations_manager.verify_expectations()
self.call_expectations_manager.verify_no_unexpected_calls()
# ----------------------------------------------------------------
# Live updates: both sides of the book
# ----------------------------------------------------------------
@pytest.mark.parametrize(
"subscription_type",
[
pytest.param(SubscriptionType.TOP_OF_BOOK, id="TOB"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, id="PDB"),
]
)
def test_live_update_with_both_sides(self, subscription_type: SubscriptionType.ValueType) -> None:
"""Subscribe, insert a buy then a sell at different prices.
The book should reflect orders on both sides."""
self._create_instrument_via_admin()
client = self._connect_and_login()
self.call_expectations_manager.verify_expectations()
client.test_subscribe_to_order_book(
instrument_symbol=self.instrument.symbol, type=subscription_type)
bid_price = Decimal("100.0")
ask_price = Decimal("100.01")
quantity = 10
self._expect_book_update(client, subscription_type, self.instrument.symbol,
expected_bids={bid_price: quantity}, expected_asks={})
order_book_client = self._get_order_book_client()
order_book_client.test_insert_order(
side=Side.BUY, price=bid_price, quantity=quantity,
order_book_id=self.order_book_id, username=f"{self.test_name}_orders")
self.call_expectations_manager.verify_expectations()
self._expect_book_update(client, subscription_type, self.instrument.symbol,
expected_bids={bid_price: quantity},
expected_asks={ask_price: quantity})
order_book_client.test_insert_order(
side=Side.SELL, price=ask_price, quantity=quantity,
order_book_id=self.order_book_id, username=f"{self.test_name}_orders")
self.call_expectations_manager.verify_expectations()
# ----------------------------------------------------------------
# Live updates: aggregation at same price level
# ----------------------------------------------------------------
@pytest.mark.parametrize(
"subscription_type",
[
pytest.param(SubscriptionType.TOP_OF_BOOK, id="TOB"),
pytest.param(SubscriptionType.PRICE_DEPTH_BOOK, id="PDB"),
]
)
def test_live_update_quantity_aggregation_at_same_price(self, subscription_type: SubscriptionType.ValueType) -> None:
"""Insert two orders at the same price after subscribing. The
reported quantity should be the sum of both orders."""
self._create_instrument_via_admin()
client = self._connect_and_login()
self.call_expectations_manager.verify_expectations()
client.test_subscribe_to_order_book(
instrument_symbol=self.instrument.symbol, type=subscription_type)
price = Decimal("100.0")
qty1 = 10
qty2 = 15
self._expect_book_update(client, subscription_type, self.instrument.symbol,
expected_bids={price: qty1}, expected_asks={})
order_book_client = self._get_order_book_client()
order_book_client.test_insert_order(
side=Side.BUY, price=price, quantity=qty1,
order_book_id=self.order_book_id, username=f"{self.test_name}_orders")
self.call_expectations_manager.verify_expectations()
self._expect_book_update(client, subscription_type, self.instrument.symbol,
expected_bids={price: qty1 + qty2}, expected_asks={})
order_book_client.test_insert_order(
side=Side.BUY, price=price, quantity=qty2,
order_book_id=self.order_book_id, username=f"{self.test_name}_orders")
self.call_expectations_manager.verify_expectations()

File diff suppressed because it is too large

@@ -0,0 +1,624 @@
from decimal import Decimal
import logging
import pytest
from connection.tcp_connection_manager import TcpConnectionManager
from proto.common_pb2 import Instrument, Side
from proto.risk_limits_pb2 import (
InstrumentRiskLimits, RollingWindowLimit, UserRiskLimits,
)
from tests.conftest import DEFAULT_PASSWORD
from tests.common.auth_tests import AuthenticationTests
from tests.common.mock_expectations import CallExpectationsManager
from tests.test_client.admin_test_client import (
AdminClientConnectionHandlerFactory, AdminTestClient,
)
from tests.test_client.risk_gateway_test_client import (
RiskGatewayClientConnectionHandlerFactory, RiskGatewayTestClient,
)
logger = logging.getLogger(__name__)
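# Illustrative sketch (hypothetical helpers, not part of the test suite or the
# reference solution): the risk checks exercised below rest on two mechanisms,
# exact accounting of outstanding amounts (see the floating-point precision
# tests) and rolling-window rate limits (see the message-rate tests). A
# minimal model of the assumed semantics:
from collections import deque


def _outstanding_amount(orders: list[tuple[Decimal, int]]) -> Decimal:
    """Sum price * quantity exactly with Decimal, so that
    Decimal("100.01") * 3 == Decimal("300.03") with no IEEE 754 drift."""
    return sum((price * quantity for price, quantity in orders), Decimal("0"))


class _RollingWindowCounter:
    """Allow at most `limit` events within any trailing `window_in_seconds`."""

    def __init__(self, limit: int, window_in_seconds: float) -> None:
        self._limit = limit
        self._window = window_in_seconds
        self._events: deque[float] = deque()  # timestamps of accepted events

    def try_acquire(self, now: float) -> bool:
        # Evict events that have fallen out of the trailing window.
        while self._events and now - self._events[0] >= self._window:
            self._events.popleft()
        if len(self._events) >= self._limit:
            return False  # over the limit: reject and do not record
        self._events.append(now)
        return True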
class TestRiskGatewaySystem(AuthenticationTests):
PROTOCOL = "risk_limits"
@pytest.fixture(autouse=True)
def setup(self, request: pytest.FixtureRequest) -> None:
self.test_name = request.node.name
logger.info(f"Setting up test: {self.test_name}")
self.tcp_connection_manager = TcpConnectionManager()
self.call_expectations_manager = CallExpectationsManager()
self.call_expectations_manager.setup_network(self.tcp_connection_manager)
self._client_factory = RiskGatewayClientConnectionHandlerFactory(self.call_expectations_manager)
self._admin_factory = AdminClientConnectionHandlerFactory(self.call_expectations_manager)
self._admin_client: AdminTestClient | None = None
self._next_instrument_id = 1
def _connect_unauthenticated(self) -> RiskGatewayTestClient:
return self.tcp_connection_manager.connect(self.server_address, self._client_factory)
def _connect_and_login(self, username: str | None = None) -> RiskGatewayTestClient:
"""Connect to the risk gateway, log in (if auth required), and return the client."""
client: RiskGatewayTestClient = self.tcp_connection_manager.connect(self.server_address, self._client_factory)
if self.auth_required:
client.test_login(username=username or self.test_name, password=DEFAULT_PASSWORD)
return client
def _get_admin_client(self) -> AdminTestClient:
if self._admin_client is None:
admin_address = self.orchestrator.get_server_address("admin")
self._admin_client = self.tcp_connection_manager.connect(admin_address, self._admin_factory)
return self._admin_client
def _create_instrument_via_admin(self, tick_size: Decimal = Decimal("0.01")) -> str:
"""Create an instrument via admin and return its symbol."""
symbol = f"TEST.{self._next_instrument_id}"
self._next_instrument_id += 1
instrument = Instrument(
symbol=symbol, description="Test instrument",
currency="USD", multiplier=1.0)
admin = self._get_admin_client()
admin.test_create_instrument(instrument, tick_size)
return symbol
# =========================================================================
# Risk limits CRUD: user level
# =========================================================================
def test_set_and_get_user_risk_limits(self) -> None:
client = self._connect_and_login()
limits = UserRiskLimits(max_outstanding_quantity=500)
client.test_set_user_risk_limits(limits)
resp = client.test_get_user_risk_limits()
response = resp.get_response()
assert response.user_risk_limits.max_outstanding_quantity == 500
def test_get_user_risk_limits_when_not_set(self) -> None:
client = self._connect_and_login()
client.test_get_user_risk_limits(expect_success=False)
# =========================================================================
# Risk limits CRUD: instrument level
# =========================================================================
def test_set_and_get_instrument_risk_limits(self) -> None:
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
limits = InstrumentRiskLimits(max_outstanding_quantity=100)
client.test_set_instrument_risk_limits(symbol, limits)
resp = client.test_get_instrument_risk_limits()
response = resp.get_response()
assert symbol in response.risk_limits_by_instrument
assert response.risk_limits_by_instrument[symbol].max_outstanding_quantity == 100
def test_get_instrument_risk_limits_when_not_set(self) -> None:
client = self._connect_and_login()
resp = client.test_get_instrument_risk_limits()
response = resp.get_response()
assert len(response.risk_limits_by_instrument) == 0
def test_set_instrument_risk_limits_for_multiple_instruments(self) -> None:
sym_a = self._create_instrument_via_admin()
sym_b = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_instrument_risk_limits(
sym_a, InstrumentRiskLimits(max_outstanding_quantity=100))
client.test_set_instrument_risk_limits(
sym_b, InstrumentRiskLimits(max_outstanding_quantity=200))
resp = client.test_get_instrument_risk_limits()
by_instrument = resp.get_response().risk_limits_by_instrument
assert by_instrument[sym_a].max_outstanding_quantity == 100
assert by_instrument[sym_b].max_outstanding_quantity == 200
# =========================================================================
# User max outstanding quantity
# =========================================================================
def test_user_max_outstanding_quantity_accepts_within_limit(self) -> None:
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_user_risk_limits(
UserRiskLimits(max_outstanding_quantity=20))
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 20)
def test_user_max_outstanding_quantity_rejects_over_limit(self) -> None:
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_user_risk_limits(
UserRiskLimits(max_outstanding_quantity=20))
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 20)
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 1,
expect_success=False)
def test_user_max_outstanding_quantity_freed_by_cancellation(self) -> None:
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_user_risk_limits(
UserRiskLimits(max_outstanding_quantity=20))
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 20)
client.test_cancel_order(symbol)
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 20)
def test_user_max_outstanding_quantity_freed_by_full_trade(self) -> None:
symbol = self._create_instrument_via_admin()
user_a = self._connect_and_login(f"{self.test_name}_2")
user_b = self._connect_and_login(f"{self.test_name}_3")
user_a.test_set_user_risk_limits(
UserRiskLimits(max_outstanding_quantity=20))
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 20)
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 1,
expect_success=False)
user_b.test_insert_order(symbol, Side.SELL, Decimal("100.0"), 20)
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 20)
def test_user_max_outstanding_quantity_freed_by_partial_trade(self) -> None:
symbol = self._create_instrument_via_admin()
user_a = self._connect_and_login(f"{self.test_name}_2")
user_b = self._connect_and_login(f"{self.test_name}_3")
user_a.test_set_user_risk_limits(
UserRiskLimits(max_outstanding_quantity=20))
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 20)
user_b.test_insert_order(symbol, Side.SELL, Decimal("100.0"), 10)
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10)
def test_user_max_outstanding_quantity_spans_all_instruments(self) -> None:
"""The user-level limit applies across all instruments combined."""
sym_a = self._create_instrument_via_admin()
sym_b = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_user_risk_limits(
UserRiskLimits(max_outstanding_quantity=30))
client.test_insert_order(sym_a, Side.BUY, Decimal("100.0"), 20)
client.test_insert_order(sym_b, Side.BUY, Decimal("100.0"), 10)
client.test_insert_order(sym_b, Side.BUY, Decimal("100.0"), 1,
expect_success=False)
# =========================================================================
# Instrument max outstanding quantity
# =========================================================================
def test_instrument_max_outstanding_quantity_rejects_over_limit(self) -> None:
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_instrument_risk_limits(
symbol, InstrumentRiskLimits(max_outstanding_quantity=15))
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 15)
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 1,
expect_success=False)
def test_instrument_max_outstanding_quantity_independent_per_instrument(self) -> None:
"""Limits on instrument A must not affect orders on instrument B."""
sym_a = self._create_instrument_via_admin()
sym_b = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_instrument_risk_limits(
sym_a, InstrumentRiskLimits(max_outstanding_quantity=10))
client.test_insert_order(sym_a, Side.BUY, Decimal("100.0"), 10)
client.test_insert_order(sym_a, Side.BUY, Decimal("100.0"), 1,
expect_success=False)
client.test_insert_order(sym_b, Side.BUY, Decimal("100.0"), 100)
def test_instrument_max_outstanding_quantity_freed_by_cancellation(self) -> None:
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_instrument_risk_limits(
symbol, InstrumentRiskLimits(max_outstanding_quantity=10))
resp = client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10)
order_id = resp.get_response().order_id
client.test_cancel_order(symbol, order_id=order_id)
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10)
def test_instrument_max_outstanding_quantity_freed_by_trade(self) -> None:
symbol = self._create_instrument_via_admin()
user_a = self._connect_and_login(f"{self.test_name}_2")
user_b = self._connect_and_login(f"{self.test_name}_3")
user_a.test_set_instrument_risk_limits(
symbol, InstrumentRiskLimits(max_outstanding_quantity=10))
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10)
user_b.test_insert_order(symbol, Side.SELL, Decimal("100.0"), 10)
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10)
# =========================================================================
# Instrument max outstanding amount
# =========================================================================
def test_instrument_max_outstanding_amount_rejects_over_limit(self) -> None:
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_instrument_risk_limits(
symbol, InstrumentRiskLimits(max_outstanding_amount=1000.0))
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10)
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 1,
expect_success=False)
def test_instrument_max_outstanding_amount_considers_price_and_quantity(self) -> None:
"""Amount = price x quantity. A higher price hits the limit sooner."""
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_instrument_risk_limits(
symbol, InstrumentRiskLimits(max_outstanding_amount=599.0))
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 5)
client.test_insert_order(symbol, Side.BUY, Decimal("10.0"), 10,
expect_success=False)
client.test_insert_order(symbol, Side.SELL, Decimal("100.0"), 1,
expect_success=False)
def test_instrument_max_outstanding_amount_freed_by_trade(self) -> None:
symbol = self._create_instrument_via_admin()
user_a = self._connect_and_login(f"{self.test_name}_2")
user_b = self._connect_and_login(f"{self.test_name}_3")
user_a.test_set_instrument_risk_limits(
symbol, InstrumentRiskLimits(max_outstanding_amount=1000.0))
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10)
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 1,
expect_success=False)
user_b.test_insert_order(symbol, Side.SELL, Decimal("100.0"), 10)
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10)
# =========================================================================
# Message rate rolling limit
# =========================================================================
def test_message_rate_limit_rejects_excess(self) -> None:
"""After hitting the message rate limit within the window, further orders are rejected."""
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_user_risk_limits(UserRiskLimits(
message_rate_rolling_limit=RollingWindowLimit(limit=3, window_in_seconds=60)))
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 1)
client.test_insert_order(symbol, Side.BUY, Decimal("101.0"), 1)
client.test_insert_order(symbol, Side.BUY, Decimal("102.0"), 1)
client.test_insert_order(symbol, Side.BUY, Decimal("103.0"), 1,
expect_success=False)
# =========================================================================
# Order quantity rolling limit
# =========================================================================
def test_order_quantity_rolling_limit_rejects_excess(self) -> None:
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_instrument_risk_limits(symbol, InstrumentRiskLimits(
order_quantity_rolling_limit=RollingWindowLimit(limit=50, window_in_seconds=60)))
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 30)
client.test_insert_order(symbol, Side.BUY, Decimal("101.0"), 20)
client.test_insert_order(symbol, Side.BUY, Decimal("102.0"), 1,
expect_success=False)
# =========================================================================
# Order amount rolling limit
# =========================================================================
def test_order_amount_rolling_limit_rejects_excess(self) -> None:
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_instrument_risk_limits(symbol, InstrumentRiskLimits(
order_amount_rolling_limit=RollingWindowLimit(limit=5000, window_in_seconds=60)))
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 30)
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 20)
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 1,
expect_success=False)
# =========================================================================
# Multi-user isolation
# =========================================================================
def test_limits_are_isolated_between_users(self) -> None:
"""User A's limits must not affect user B's ability to trade."""
symbol = self._create_instrument_via_admin()
user_a = self._connect_and_login(f"{self.test_name}_2")
user_b = self._connect_and_login(f"{self.test_name}_3")
user_a.test_set_user_risk_limits(
UserRiskLimits(max_outstanding_quantity=10))
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10)
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 1,
expect_success=False)
user_b.test_insert_order(symbol, Side.SELL, Decimal("101.0"), 100)
def test_order_from_user_a_does_not_count_against_user_b(self) -> None:
"""Outstanding orders from user A must not consume user B's limits."""
symbol = self._create_instrument_via_admin()
user_a = self._connect_and_login(f"{self.test_name}_2")
user_b = self._connect_and_login(f"{self.test_name}_3")
user_b.test_set_instrument_risk_limits(
symbol, InstrumentRiskLimits(max_outstanding_quantity=10))
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 100)
user_b.test_insert_order(symbol, Side.SELL, Decimal("101.0"), 10)
# =========================================================================
# State tracking across trades and partial fills
# =========================================================================
def test_partial_trade_correctly_updates_outstanding_quantity(self) -> None:
"""After a partial fill, the remaining quantity still counts towards limits."""
symbol = self._create_instrument_via_admin()
user_a = self._connect_and_login(f"{self.test_name}_2")
user_b = self._connect_and_login(f"{self.test_name}_3")
user_a.test_set_instrument_risk_limits(
symbol, InstrumentRiskLimits(max_outstanding_quantity=20))
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 20)
user_b.test_insert_order(symbol, Side.SELL, Decimal("100.0"), 5)
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 5)
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 1,
expect_success=False)
def test_trade_between_two_limited_users_updates_both(self) -> None:
"""When two users trade, both users' outstanding quantities must decrease."""
symbol = self._create_instrument_via_admin()
user_a = self._connect_and_login(f"{self.test_name}_2")
user_b = self._connect_and_login(f"{self.test_name}_3")
user_a.test_set_user_risk_limits(
UserRiskLimits(max_outstanding_quantity=10))
user_b.test_set_user_risk_limits(
UserRiskLimits(max_outstanding_quantity=10))
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10)
user_b.test_insert_order(symbol, Side.SELL, Decimal("100.0"), 10)
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10)
user_b.test_insert_order(symbol, Side.SELL, Decimal("101.0"), 10)
def test_multiple_partial_fills_track_remaining_correctly(self) -> None:
"""Several small trades against a large order must reduce the outstanding by the cumulative traded amount."""
symbol = self._create_instrument_via_admin()
user_a = self._connect_and_login(f"{self.test_name}_2")
user_b = self._connect_and_login(f"{self.test_name}_3")
user_a.test_set_instrument_risk_limits(
symbol, InstrumentRiskLimits(max_outstanding_quantity=30))
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 30)
for _ in range(3):
user_b.test_insert_order(symbol, Side.SELL, Decimal("100.0"), 5)
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 15)
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 1,
expect_success=False)
# =========================================================================
# Combined limits
# =========================================================================
def test_both_user_and_instrument_limits_checked(self) -> None:
"""An order must satisfy both user-level and instrument-level limits."""
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_user_risk_limits(
UserRiskLimits(max_outstanding_quantity=100))
client.test_set_instrument_risk_limits(
symbol, InstrumentRiskLimits(max_outstanding_quantity=10))
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10)
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 1,
expect_success=False)
def test_instrument_limit_hit_while_user_limit_ok(self) -> None:
"""Even if the user-level limit has room, the instrument-level limit must block."""
sym_a = self._create_instrument_via_admin()
sym_b = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_user_risk_limits(
UserRiskLimits(max_outstanding_quantity=100))
client.test_set_instrument_risk_limits(
sym_a, InstrumentRiskLimits(max_outstanding_quantity=5))
client.test_insert_order(sym_a, Side.BUY, Decimal("100.0"), 5)
client.test_insert_order(sym_a, Side.BUY, Decimal("100.0"), 1,
expect_success=False)
client.test_insert_order(sym_b, Side.BUY, Decimal("100.0"), 50)
def test_user_limit_hit_while_instrument_limit_ok(self) -> None:
"""Even if the instrument-level limit has room, the user-level limit must block."""
sym_a = self._create_instrument_via_admin()
sym_b = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_user_risk_limits(
UserRiskLimits(max_outstanding_quantity=15))
client.test_set_instrument_risk_limits(
sym_a, InstrumentRiskLimits(max_outstanding_quantity=100))
client.test_set_instrument_risk_limits(
sym_b, InstrumentRiskLimits(max_outstanding_quantity=100))
client.test_insert_order(sym_a, Side.BUY, Decimal("100.0"), 10)
client.test_insert_order(sym_b, Side.BUY, Decimal("100.0"), 5)
client.test_insert_order(sym_b, Side.BUY, Decimal("100.0"), 1,
expect_success=False)
# =========================================================================
# Floating-point precision (4-decimal-digit prices/amounts)
#
# The protocol uses double for prices and amounts. Naive use of IEEE 754
# floats leads to well-known rounding errors (e.g. 0.1 * 3 != 0.3).
# These tests verify that the solution uses proper precision handling so
# that orders exactly at the limit are not falsely rejected, and that
# add/subtract cycles leave no residual error that blocks future orders.
# =========================================================================
def test_amount_boundary_not_falsely_rejected_by_float_multiplication(self) -> None:
"""price * qty can overshoot in float (e.g. 0.1 * 3 = 0.30000000000000004).
An order whose exact amount (4-decimal precision) equals the limit must
be accepted despite IEEE 754 rounding in the multiplication.
"""
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_instrument_risk_limits(
symbol, InstrumentRiskLimits(max_outstanding_amount=300.03))
# Exact: 100.01 x 3 = 300.03 — equals the limit, must be accepted.
# In float: 100.01 * 3 may produce 300.03000000000003 due to the
# inexact representation of 100.01 in IEEE 754.
client.test_insert_order(symbol, Side.BUY, Decimal("100.01"), 3)
def test_amount_boundary_not_falsely_rejected_by_float_accumulation(self) -> None:
"""Repeated float additions drift away from the true sum.
Three additions of 33.3333 should total 99.9999, leaving room for
exactly 0.0001 more. A naive float accumulation may overshoot 99.9999,
falsely blocking the last order.
"""
symbol = self._create_instrument_via_admin(tick_size=Decimal("0.0001"))
client = self._connect_and_login()
client.test_set_instrument_risk_limits(
symbol, InstrumentRiskLimits(max_outstanding_amount=100.0))
client.test_insert_order(symbol, Side.BUY, Decimal("33.3333"), 1)
client.test_insert_order(symbol, Side.BUY, Decimal("33.3333"), 1)
client.test_insert_order(symbol, Side.BUY, Decimal("33.3333"), 1)
# Exact total so far: 99.9999. One more at 0.0001 → 100.0 = limit.
client.test_insert_order(symbol, Side.BUY, Decimal("0.0001"), 1)
def test_amount_boundary_with_classic_point_one_plus_point_two(self) -> None:
"""The textbook float error: 0.1 + 0.2 = 0.30000000000000004 in IEEE 754.
With limit = 3000, two orders (0.1 x 10000 = 1000, 0.2 x 10000 = 2000)
total exactly 3000, exhausting the limit, so any further order must be rejected.
"""
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_instrument_risk_limits(
symbol, InstrumentRiskLimits(max_outstanding_amount=3000.0))
client.test_insert_order(symbol, Side.BUY, Decimal("0.1"), 10000)
client.test_insert_order(symbol, Side.BUY, Decimal("0.2"), 10000)
# Exact total: 1000 + 2000 = 3000 = limit.
# Any additional order must be rejected.
client.test_insert_order(symbol, Side.BUY, Decimal("0.01"), 1,
expect_success=False)
def test_amount_freed_correctly_after_cancelling_fractional_price_orders(self) -> None:
"""After inserting and cancelling orders at fractional prices, the full
capacity must be available again; no residual float error may block it.
"""
symbol = self._create_instrument_via_admin(tick_size=Decimal("0.0001"))
client = self._connect_and_login()
client.test_set_instrument_risk_limits(
symbol, InstrumentRiskLimits(max_outstanding_amount=1000.0))
for price in ["33.3333", "33.3334", "33.3333"]:
resp = client.test_insert_order(symbol, Side.BUY, Decimal(price), 1)
order_id = resp.get_response().order_id
client.test_cancel_order(symbol, order_id=order_id)
# Full capacity must be restored despite float add/subtract drift.
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10)
def test_amount_freed_correctly_after_trades_at_fractional_prices(self) -> None:
"""After trades fully fill orders at fractional prices, the outstanding
amount must return to zero so the full limit is available again.
"""
symbol = self._create_instrument_via_admin(tick_size=Decimal("0.0001"))
user_a = self._connect_and_login(f"{self.test_name}_2")
user_b = self._connect_and_login(f"{self.test_name}_3")
user_a.test_set_instrument_risk_limits(
symbol, InstrumentRiskLimits(max_outstanding_amount=1000.0))
# Insert and fully fill 3 orders at fractional prices.
for price in ["33.3333", "33.3334", "33.3333"]:
user_a.test_insert_order(symbol, Side.BUY, Decimal(price), 1)
user_b.test_insert_order(symbol, Side.SELL, Decimal(price), 1)
# Full capacity must be available again.
user_a.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 10)
def test_amount_freed_correctly_after_partial_trade_at_fractional_price(self) -> None:
"""A partial fill at a fractional price must reduce the outstanding
amount by exactly traded_qty x price, not by a drifted float value.
"""
symbol = self._create_instrument_via_admin(tick_size=Decimal("0.0001"))
user_a = self._connect_and_login(f"{self.test_name}_2")
user_b = self._connect_and_login(f"{self.test_name}_3")
user_a.test_set_instrument_risk_limits(
symbol, InstrumentRiskLimits(max_outstanding_amount=100.0))
# Insert 5 units at 10.0001 → outstanding = 50.0005, within the 100.0 limit.
user_a.test_insert_order(symbol, Side.BUY, Decimal("10.0001"), 5)
# Outstanding: 50.0005. Partial fill 3 units → frees 30.0003.
user_b.test_insert_order(symbol, Side.SELL, Decimal("10.0001"), 3)
# Remaining outstanding: 20.0002. Room left: 79.9998.
user_a.test_insert_order(symbol, Side.BUY, Decimal("79.9998"), 1)
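The arithmetic traced in the comments stays exact when done with Decimal (standalone check, not the repository's code):

```python
from decimal import Decimal

limit = Decimal("100.0")
outstanding = Decimal("10.0001") * 5   # 50.0005 after the insert
outstanding -= Decimal("10.0001") * 3  # partial fill frees exactly 30.0003

assert outstanding == Decimal("20.0002")
assert limit - outstanding == Decimal("79.9998")  # room for the last order
```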
def test_order_amount_rolling_limit_boundary_with_fractional_multiplication(self) -> None:
"""The rolling-window amount sum must handle precision the same way as
the outstanding amount: no false rejection at the exact boundary.
"""
symbol = self._create_instrument_via_admin()
client = self._connect_and_login()
client.test_set_instrument_risk_limits(symbol, InstrumentRiskLimits(
order_amount_rolling_limit=RollingWindowLimit(
limit=300, window_in_seconds=60)))
# Fill the rolling window up to just under the limit: 3 x 99.99 = 299.97.
client.test_insert_order(symbol, Side.BUY, Decimal("99.99"), 1)
client.test_insert_order(symbol, Side.BUY, Decimal("99.99"), 1)
client.test_insert_order(symbol, Side.BUY, Decimal("99.99"), 1)
# Exact rolling total: 299.97. Room for 0.03 more.
client.test_insert_order(symbol, Side.BUY, Decimal("0.01"), 3)
# Now exact total: 299.97 + 0.03 = 300.0 = limit. Next must fail.
client.test_insert_order(symbol, Side.BUY, Decimal("0.01"), 1,
expect_success=False)
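One way a team might implement the exact boundary semantics this test requires; this is a hypothetical sketch whose class and method names are invented, not taken from `x_template/`:

```python
import time
from collections import deque
from decimal import Decimal
from typing import Deque, Optional, Tuple


class RollingAmountWindow:
    """Hypothetical rolling-window amount limiter using exact Decimal sums."""

    def __init__(self, limit: Decimal, window_in_seconds: float) -> None:
        self._limit = limit
        self._window = window_in_seconds
        self._events: Deque[Tuple[float, Decimal]] = deque()
        self._total = Decimal("0")

    def try_add(self, amount: Decimal, now: Optional[float] = None) -> bool:
        """Accept the amount if the in-window total stays at or below the limit."""
        if now is None:
            now = time.monotonic()
        # Expire events that have left the window before checking the limit.
        while self._events and now - self._events[0][0] > self._window:
            self._total -= self._events.popleft()[1]
        if self._total + amount > self._limit:
            return False
        self._events.append((now, amount))
        self._total += amount
        return True


window = RollingAmountWindow(Decimal("300"), 60.0)
for _ in range(3):
    assert window.try_add(Decimal("99.99"), now=0.0)  # rolling total 299.97
assert window.try_add(Decimal("0.01") * 3, now=0.0)   # exactly reaches 300
assert not window.try_add(Decimal("0.01"), now=0.0)   # over the limit: rejected
```

Keeping the running total as a Decimal and subtracting the exact amounts that expire is what avoids the residual drift the surrounding tests probe for.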
def test_many_fractional_insert_cancel_cycles_leave_no_residual(self) -> None:
"""Repeated insert-cancel cycles at the same fractional price must not
accumulate residual float error that eventually blocks orders.
"""
symbol = self._create_instrument_via_admin(tick_size=Decimal("0.0001"))
client = self._connect_and_login()
client.test_set_instrument_risk_limits(
symbol, InstrumentRiskLimits(max_outstanding_amount=100.0))
for _ in range(20):
resp = client.test_insert_order(
symbol, Side.BUY, Decimal("33.3333"), 1)
order_id = resp.get_response().order_id
client.test_cancel_order(symbol, order_id=order_id)
# After 20 round-trips, full capacity must still be available.
client.test_insert_order(symbol, Side.BUY, Decimal("100.0"), 1)
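The residual-free bookkeeping these cycle tests demand can be sketched directly (illustrative only, not the repository's implementation):

```python
from decimal import Decimal

outstanding = Decimal("0")
for _ in range(20):
    outstanding += Decimal("33.3333")  # insert reserves the amount
    outstanding -= Decimal("33.3333")  # cancel releases it exactly

# Decimal round-trips leave no residue, so the full limit stays available.
assert outstanding == Decimal("0")
```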


@@ -4,11 +4,11 @@ requires-python = ">=3.11"
[[package]]
name = "attrs"
-version = "25.3.0"
+version = "25.1.0"
source = { registry = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/simple" }
-sdist = { url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/5a/b0/1367933a8532ee6ff8d63537de4f1177af4bff9f3e829baf7331f595bb24/attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b" }
+sdist = { url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/49/7c/fdf464bcc51d23881d110abd74b512a42b3d5d376a55a831b44c603ae17f/attrs-25.1.0.tar.gz", hash = "sha256:1c97078a80c814273a76b2a298a932eb681c87415c11dee0a6921de7f1b02c3e" }
wheels = [
-{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/fc/30/d4986a882011f9df997a55e6becd864812ccfcd821d64aac8570ee39f719/attrs-25.1.0-py3-none-any.whl", hash = "sha256:c75a69e28a550a7e93789579c22aa26b0f5b83b75dc4e08fe092980051e1090a" },
]
[[package]]
@@ -20,6 +20,15 @@ wheels = [
{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6" },
]
+[[package]]
+name = "execnet"
+version = "2.1.2"
+source = { registry = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/simple" }
+sdist = { url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/bf/89/780e11f9588d9e7128a3f87788354c7946a9cbb1401ad38a48c4db9a4f07/execnet-2.1.2.tar.gz", hash = "sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd" }
+wheels = [
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/ab/84/02fc1827e8cdded4aa65baef11296a9bbe595c474f0d6d758af082d849fd/execnet-2.1.2-py3-none-any.whl", hash = "sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec" },
+]
[[package]]
name = "flake8"
version = "7.3.0"
@@ -50,11 +59,11 @@ wheels = [
[[package]]
name = "iniconfig"
-version = "2.1.0"
+version = "2.0.0"
source = { registry = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/simple" }
-sdist = { url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7" }
+sdist = { url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3" }
wheels = [
-{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374" },
]
[[package]]
@@ -228,7 +237,7 @@ wheels = [
]
[[package]]
-name = "optivex"
+name = "optivextests"
version = "0.1.0"
source = { editable = "." }
dependencies = [
@@ -238,7 +247,9 @@ dependencies = [
{ name = "mypy" },
{ name = "mypy-protobuf" },
{ name = "protobuf" },
+{ name = "psutil" },
{ name = "pytest" },
+{ name = "pytest-xdist" },
{ name = "types-jsonschema" },
{ name = "types-protobuf" },
]
@@ -251,7 +262,9 @@ requires-dist = [
{ name = "mypy", specifier = ">=1.19.1" },
{ name = "mypy-protobuf", specifier = ">=5.0.0" },
{ name = "protobuf", specifier = ">=7.34.0" },
+{ name = "psutil", specifier = ">=7.0.0" },
{ name = "pytest", specifier = ">=9.0.2" },
+{ name = "pytest-xdist", specifier = ">=3.8.0" },
{ name = "types-jsonschema", specifier = ">=4.23.0.20241208" },
{ name = "types-protobuf", specifier = ">=6.32.1.20260221" },
]
@@ -298,6 +311,34 @@ wheels = [
{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/a4/e7/14dc9366696dcb53a413449881743426ed289d687bcf3d5aee4726c32ebb/protobuf-7.34.0-py3-none-any.whl", hash = "sha256:e3b914dd77fa33fa06ab2baa97937746ab25695f389869afdf03e81f34e45dc7" },
]
+[[package]]
+name = "psutil"
+version = "7.2.2"
+source = { registry = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/simple" }
+sdist = { url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/aa/c6/d1ddf4abb55e93cebc4f2ed8b5d6dbad109ecb8d63748dd2b20ab5e57ebe/psutil-7.2.2.tar.gz", hash = "sha256:0746f5f8d406af344fd547f1c8daa5f5c33dbc293bb8d6a16d80b4bb88f59372" }
+wheels = [
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/51/08/510cbdb69c25a96f4ae523f733cdc963ae654904e8db864c07585ef99875/psutil-7.2.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:2edccc433cbfa046b980b0df0171cd25bcaeb3a68fe9022db0979e7aa74a826b" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/d6/f5/97baea3fe7a5a9af7436301f85490905379b1c6f2dd51fe3ecf24b4c5fbf/psutil-7.2.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e78c8603dcd9a04c7364f1a3e670cea95d51ee865e4efb3556a3a63adef958ea" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/37/d6/246513fbf9fa174af531f28412297dd05241d97a75911ac8febefa1a53c6/psutil-7.2.2-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1a571f2330c966c62aeda00dd24620425d4b0cc86881c89861fbc04549e5dc63" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/b8/b5/9182c9af3836cca61696dabe4fd1304e17bc56cb62f17439e1154f225dd3/psutil-7.2.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:917e891983ca3c1887b4ef36447b1e0873e70c933afc831c6b6da078ba474312" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/16/ba/0756dca669f5a9300d0cbcbfae9a4c30e446dfc7440ffe43ded5724bfd93/psutil-7.2.2-cp313-cp313t-win_amd64.whl", hash = "sha256:ab486563df44c17f5173621c7b198955bd6b613fb87c71c161f827d3fb149a9b" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/1c/61/8fa0e26f33623b49949346de05ec1ddaad02ed8ba64af45f40a147dbfa97/psutil-7.2.2-cp313-cp313t-win_arm64.whl", hash = "sha256:ae0aefdd8796a7737eccea863f80f81e468a1e4cf14d926bd9b6f5f2d5f90ca9" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/81/69/ef179ab5ca24f32acc1dac0c247fd6a13b501fd5534dbae0e05a1c48b66d/psutil-7.2.2-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:eed63d3b4d62449571547b60578c5b2c4bcccc5387148db46e0c2313dad0ee00" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/7b/64/665248b557a236d3fa9efc378d60d95ef56dd0a490c2cd37dafc7660d4a9/psutil-7.2.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:7b6d09433a10592ce39b13d7be5a54fbac1d1228ed29abc880fb23df7cb694c9" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/d5/2e/e6782744700d6759ebce3043dcfa661fb61e2fb752b91cdeae9af12c2178/psutil-7.2.2-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1fa4ecf83bcdf6e6c8f4449aff98eefb5d0604bf88cb883d7da3d8d2d909546a" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/57/49/0a41cefd10cb7505cdc04dab3eacf24c0c2cb158a998b8c7b1d27ee2c1f5/psutil-7.2.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e452c464a02e7dc7822a05d25db4cde564444a67e58539a00f929c51eddda0cf" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/dd/2c/ff9bfb544f283ba5f83ba725a3c5fec6d6b10b8f27ac1dc641c473dc390d/psutil-7.2.2-cp314-cp314t-win_amd64.whl", hash = "sha256:c7663d4e37f13e884d13994247449e9f8f574bc4655d509c3b95e9ec9e2b9dc1" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/f2/fc/f8d9c31db14fcec13748d373e668bc3bed94d9077dbc17fb0eebc073233c/psutil-7.2.2-cp314-cp314t-win_arm64.whl", hash = "sha256:11fe5a4f613759764e79c65cf11ebdf26e33d6dd34336f8a337aa2996d71c841" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/e7/36/5ee6e05c9bd427237b11b3937ad82bb8ad2752d72c6969314590dd0c2f6e/psutil-7.2.2-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ed0cace939114f62738d808fdcecd4c869222507e266e574799e9c0faa17d486" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/80/c4/f5af4c1ca8c1eeb2e92ccca14ce8effdeec651d5ab6053c589b074eda6e1/psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:1a7b04c10f32cc88ab39cbf606e117fd74721c831c98a27dc04578deb0c16979" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:076a2d2f923fd4821644f5ba89f059523da90dc9014e85f8e45a5774ca5bc6f9" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/63/65/37648c0c158dc222aba51c089eb3bdfa238e621674dc42d48706e639204f/psutil-7.2.2-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b0726cecd84f9474419d67252add4ac0cd9811b04d61123054b9fb6f57df6e9e" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/8e/13/125093eadae863ce03c6ffdbae9929430d116a246ef69866dad94da3bfbc/psutil-7.2.2-cp36-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:fd04ef36b4a6d599bbdb225dd1d3f51e00105f6d48a28f006da7f9822f2606d8" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/04/78/0acd37ca84ce3ddffaa92ef0f571e073faa6d8ff1f0559ab1272188ea2be/psutil-7.2.2-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b58fabe35e80b264a4e3bb23e6b96f9e45a3df7fb7eed419ac0e5947c61e47cc" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/b4/90/e2159492b5426be0c1fef7acba807a03511f97c5f86b3caeda6ad92351a7/psutil-7.2.2-cp37-abi3-win_amd64.whl", hash = "sha256:eb7e81434c8d223ec4a219b5fc1c47d0417b12be7ea866e24fb5ad6e84b3d988" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/8c/c7/7bb2e321574b10df20cbde462a94e2b71d05f9bbda251ef27d104668306a/psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee" },
+]
[[package]]
name = "pycodestyle"
version = "2.14.0"
@@ -341,6 +382,19 @@ wheels = [
{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b" },
]
+[[package]]
+name = "pytest-xdist"
+version = "3.8.0"
+source = { registry = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/simple" }
+dependencies = [
+{ name = "execnet" },
+{ name = "pytest" },
+]
+sdist = { url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/78/b4/439b179d1ff526791eb921115fca8e44e596a13efeda518b9d845a619450/pytest_xdist-3.8.0.tar.gz", hash = "sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1" }
+wheels = [
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/ca/31/d4e37e9e550c2b92a9cbc2e4d0b7420a27224968580b5a447f420847c975/pytest_xdist-3.8.0-py3-none-any.whl", hash = "sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88" },
+]
[[package]]
name = "referencing"
version = "0.36.2"
@@ -486,9 +540,9 @@ wheels = [
[[package]]
name = "typing-extensions"
-version = "4.13.0"
+version = "4.12.2"
source = { registry = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/simple" }
-sdist = { url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/0e/3e/b00a62db91a83fff600de219b6ea9908e6918664899a2d85db222f4fbf19/typing_extensions-4.13.0.tar.gz", hash = "sha256:0a4ac55a5820789d87e297727d229866c9650f6521b64206413c4fbada24d95b" }
+sdist = { url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/df/db/f35a00659bc03fec321ba8bce9420de607a1d37f8342eee1863174c69557/typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8" }
wheels = [
-{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/e0/86/39b65d676ec5732de17b7e3c476e45bb80ec64eb50737a8dce1a4178aba1/typing_extensions-4.13.0-py3-none-any.whl", hash = "sha256:c8dd92cc0d6425a97c18fbb9d1954e5ff92c1ca881a309c45f06ebc0b79058e5" },
+{ url = "https://optiver.jfrog.io/artifactory/api/pypi/pypi/packages/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d" },
]


@@ -1,26 +0,0 @@
# Sample App
This is a sample application for the Stock Exchange training program.
## Quick Tip: Connecting to the Server via Command Line (Linux)
1. **Connect to the Server**
Use `telnet` to connect to the server:
```bash
telnet <server_address> <port>
```
Replace `<server_address>` and `<port>` with the appropriate values.
2. **Send Messages**
Once connected, type your message and press `Enter` to send it to the server.
3. **View Responses**
The server will echo back the messages you send.
### Example
```bash
telnet localhost 8080
Hello, Server!
# Server response
```

Some files were not shown because too many files have changed in this diff.