gtat-tech-career-kickstarte.../solution/tests/common/component_orchestrator.py

266 lines
11 KiB
Python

"""Generic component orchestrator for system tests.

Uses the parsed deployment_config.json (supplied by the caller) together with
testing_dependencies.json (loaded from disk) to determine which components to
start for a given protocol, resolves their startup order via a topological
sort on connectTo references, assigns dynamic ports, and manages process
lifecycles.
"""
import json
import logging
import socket
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from connection.ip_address import IpAddress
from tests.common.process_manager import PerformanceStats, ProcessManager
from tests.common import constants
# Module-level logger.
logger = logging.getLogger(__name__)

# Location of testing_dependencies.json, resolved relative to this module:
# two directory levels up, then into continuous_deployment/.
COMPONENT_DEPENDENCIES_FILE = (
    Path(__file__).parent.parent
    / "continuous_deployment"
    / "testing_dependencies.json"
)
def _allocate_ports(count: int) -> list[int]:
"""Reserve `count` ephemeral ports from the OS, minimising collision risk."""
sockets: list[socket.socket] = []
ports: list[int] = []
try:
for _ in range(count):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(("localhost", 0))
ports.append(s.getsockname()[1])
sockets.append(s)
finally:
for s in sockets:
s.close()
return ports
@dataclass
class ComponentInstance:
    """One selected component from the deployment config plus its runtime state.

    The first five fields are filled from the component's deployment-config
    entry. ``address`` and ``process_manager`` stay ``None`` until
    ``ComponentOrchestrator`` has started the component and it reported ready.
    """
    # Component name (the entry's "name"); also the key used in connectTo maps.
    name: str
    # The entry's "packageName"; passed to ProcessManager.start_process as the binary.
    package_name: str
    # Protocols this component implements; callers only iterate over it and
    # test membership, so any iterable container works.
    protocols: list[str]
    # The entry's untouched "config" section; runtime configs are built from a copy.
    original_config: dict[str, Any]
    # The entry's "authRequired" flag (True when the key is absent).
    auth_required: bool
    # Listen address assigned at startup.
    address: IpAddress | None = None
    # Handle for the launched process; used to stop it and collect its stats.
    process_manager: ProcessManager | None = None
class ComponentOrchestrator:
    """Manages the lifecycle of all components needed to test a specific protocol.

    Typical usage inside a pytest fixture::

        orchestrator = ComponentOrchestrator(venv_path, deployment_config, users_data_file)
        orchestrator.start_for_protocol("order_book")
        server_addr = orchestrator.get_server_address("order_book")
        yield
        orchestrator.stop_all()
    """

    def __init__(self, venv_path: Path, deployment_config: dict[str, Any],
                 data_file_path: Path) -> None:
        """Store configuration and load the protocol dependency map.

        No component is started here; call ``start_for_protocol`` for that.

        Args:
            venv_path: Virtual environment handed to every ``ProcessManager``.
            deployment_config: Parsed deployment config whose ``"components"``
                list describes every available component.
            data_file_path: Written into each component's runtime config as
                ``dataFilePath``.
        """
        self.venv_path = venv_path
        self.deployment_config = deployment_config
        self.data_file_path = data_file_path
        self.component_dependencies = self._load_component_dependencies()
        self._running_components: list[ComponentInstance] = []
        self._protocol_to_address: dict[str, IpAddress] = {}

    @staticmethod
    def _load_component_dependencies() -> dict[str, Any]:
        """Read the protocol -> required-protocols map from COMPONENT_DEPENDENCIES_FILE."""
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
        with open(COMPONENT_DEPENDENCIES_FILE, encoding="utf-8") as f:
            return json.load(f)

    def start_for_protocol(self, protocol: str, *,
                           startup_timeout_in_seconds: float = 2) -> None:
        """Resolve dependencies and start every component needed to test *protocol*.

        Components are launched one at a time in dependency order. Each must
        report ready before the next one starts, so every ``connectTo`` target
        is already listening when its client comes up.

        Args:
            protocol: The protocol under test.
            startup_timeout_in_seconds: Readiness timeout passed to
                ``ProcessManager.wait_until_server_is_ready`` for each component.

        Raises:
            ValueError: If the deployment config cannot satisfy *protocol*
                (unimplemented protocol, undefined component, or a connectTo cycle).
            AssertionError: If a component's process fails to start.

        Whatever the readiness check raises is propagated after the failing
        component's process has been stopped. Components started earlier in
        the same call remain tracked and are stopped by ``stop_all``.
        """
        required_protocols = self._resolve_required_protocols(protocol)
        logger.info(f"Protocol '{protocol}' requires protocols: {required_protocols}")
        components = self._select_components(required_protocols)
        ordered = self._topological_sort(components)
        logger.info(f"Component startup order: {[c.name for c in ordered]}")
        ports = _allocate_ports(len(ordered))
        port_assignments: dict[str, int] = {
            comp.name: port for comp, port in zip(ordered, ports)
        }
        for comp in ordered:
            self._start_component(comp, port_assignments, startup_timeout_in_seconds)

    def _start_component(self, comp: ComponentInstance,
                         port_assignments: dict[str, int],
                         startup_timeout_in_seconds: float) -> None:
        """Launch *comp*, wait until it is ready, then record it as running."""
        config = self._build_runtime_config(comp, port_assignments)
        address = IpAddress(host="localhost", port=port_assignments[comp.name])
        pm = ProcessManager(self.venv_path)
        # An explicit check rather than `assert pm.start_process(...)`: under
        # `python -O` the assert statement, and with it the call, is stripped.
        if not pm.start_process(comp.package_name, config):
            raise AssertionError(
                f"Failed to start component '{comp.name}' (binary: {comp.package_name})"
            )
        try:
            startup_ms = pm.wait_until_server_is_ready(
                address, timeout_in_seconds=startup_timeout_in_seconds
            )
            logger.info(
                f"Component '{comp.name}' ready in {startup_ms:.2f} ms "
                f"on port {address.port}"
            )
        except BaseException:
            # The process is running but not yet in _running_components, so
            # stop_all() could never reach it; stop it here before re-raising.
            self._stop_process_quietly(comp.name, pm)
            raise
        comp.address = address
        comp.process_manager = pm
        self._running_components.append(comp)
        for p in comp.protocols:
            self._protocol_to_address[p] = address

    @staticmethod
    def _stop_process_quietly(name: str, pm: ProcessManager) -> None:
        """Best-effort stop used during error cleanup; logs instead of raising."""
        try:
            pm.stop_process()
        except Exception:
            logger.exception(f"Error stopping component '{name}'")

    def get_server_address(self, protocol: str) -> IpAddress:
        """Return the listen address of the component implementing *protocol*.

        Raises:
            ValueError: If no running component implements *protocol*.
        """
        try:
            return self._protocol_to_address[protocol]
        except KeyError:
            raise ValueError(
                f"No running component implements protocol '{protocol}'"
            ) from None

    def is_auth_required(self, protocol: str) -> bool:
        """Return whether the component implementing *protocol* requires authentication.

        The first running component (in startup order) implementing *protocol* wins.

        Raises:
            ValueError: If no running component implements *protocol*.
        """
        for comp in self._running_components:
            if protocol in comp.protocols:
                return comp.auth_required
        raise ValueError(f"No running component implements protocol '{protocol}'")

    def stop_all(self) -> dict[str, PerformanceStats]:
        """Stop all running components in reverse startup order.

        Stopping is best-effort: an error for one component is logged and the
        remaining components are still stopped.

        Returns:
            A mapping of component name to performance statistics, for the
            components whose process manager reported any.
        """
        perf_stats: dict[str, PerformanceStats] = {}
        # Reverse order: clients go down before the servers they connect to.
        for comp in reversed(self._running_components):
            if comp.process_manager:
                try:
                    comp.process_manager.stop_process()
                    if comp.process_manager.performance_stats:
                        perf_stats[comp.name] = comp.process_manager.performance_stats
                except Exception:
                    logger.exception(f"Error stopping component '{comp.name}'")
        self._running_components.clear()
        self._protocol_to_address.clear()
        return perf_stats

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _resolve_required_protocols(self, protocol: str) -> set[str]:
        """Transitively collect all *required* protocols for *protocol*.

        Breadth-first walk over ``component_dependencies``; the result
        includes *protocol* itself.
        """
        result: set[str] = {protocol}
        queue: deque[str] = deque([protocol])
        while queue:
            current = queue.popleft()
            for dep in self.component_dependencies.get(current, []):
                if dep not in result:
                    result.add(dep)
                    queue.append(dep)
        return result

    def _select_components(self, required_protocols: set[str]) -> list[ComponentInstance]:
        """Select the components needed to cover *required_protocols*.

        Every component implementing at least one required protocol is
        selected. This is not a minimal cover: if several components implement
        the same protocol, all of them are chosen. The transitive closure of
        their ``connectTo`` targets is added on top. Results follow
        deployment-config order.

        Raises:
            ValueError: If a ``connectTo`` target is not defined, or some
                required protocol is implemented by no component.
        """
        component_cfgs: list[dict[str, Any]] = self.deployment_config["components"]
        comp_cfg_by_name: dict[str, dict[str, Any]] = {
            cfg["name"]: cfg for cfg in component_cfgs
        }
        selected_names: set[str] = {
            cfg["name"] for cfg in component_cfgs
            if set(cfg["protocols"]) & required_protocols
        }
        # Every name enqueued below is a key of comp_cfg_by_name: it either
        # came from the config itself or was validated before being added.
        queue: deque[str] = deque(selected_names)
        while queue:
            name = queue.popleft()
            for target in comp_cfg_by_name[name]["config"].get("connectTo", {}):
                if target in selected_names:
                    continue
                if target not in comp_cfg_by_name:
                    raise ValueError(
                        f"Component '{target}' not defined in deployment config"
                    )
                selected_names.add(target)
                queue.append(target)

        components: list[ComponentInstance] = []
        covered: set[str] = set()
        for cfg in component_cfgs:
            if cfg["name"] not in selected_names:
                continue
            comp_protocols = list(cfg["protocols"])
            components.append(ComponentInstance(
                name=cfg["name"],
                package_name=cfg["packageName"],
                protocols=comp_protocols,
                original_config=cfg["config"],
                auth_required=cfg.get("authRequired", True),
            ))
            covered.update(comp_protocols)
        missing = required_protocols - covered
        if missing:
            raise ValueError(
                f"Protocols {missing} are required but not implemented by "
                f"any component in deployment_config.json"
            )
        return components

    def _topological_sort(self, components: list[ComponentInstance]) -> list[ComponentInstance]:
        """Sort components so that each component's connectTo targets are
        started before the component itself (Kahn's algorithm).

        connectTo targets outside *components* are ignored. Ties keep the
        input order.

        Raises:
            ValueError: If the connectTo references form a cycle.
        """
        by_name = {c.name: c for c in components}
        # deps[A] = {B, ...} means A depends on (connects to) B.
        deps: dict[str, set[str]] = {c.name: set() for c in components}
        # dependents[B] = [A, ...]: reverse edges, so each pop touches only the
        # nodes that actually depend on it instead of rescanning every node.
        dependents: dict[str, list[str]] = {c.name: [] for c in components}
        for comp in components:
            for target in comp.original_config.get("connectTo", {}):
                if target in by_name and target not in deps[comp.name]:
                    deps[comp.name].add(target)
                    dependents[target].append(comp.name)
        remaining = {name: len(d) for name, d in deps.items()}
        queue: deque[str] = deque(n for n, d in remaining.items() if d == 0)
        result: list[ComponentInstance] = []
        while queue:
            name = queue.popleft()
            result.append(by_name[name])
            for dependent in dependents[name]:
                remaining[dependent] -= 1
                if remaining[dependent] == 0:
                    queue.append(dependent)
        if len(result) != len(components):
            raise ValueError("Circular dependency detected among components")
        return result

    def _build_runtime_config(
        self, comp: ComponentInstance, port_assignments: dict[str, int]
    ) -> dict[str, Any]:
        """Clone the component's config with dynamic ports and the test log dir.

        The copy is shallow, but every key changed here is reassigned rather
        than mutated, so ``comp.original_config`` is left untouched. connectTo
        targets with an assigned port are redirected to localhost; others keep
        their configured endpoint.
        """
        config: dict[str, Any] = dict(comp.original_config)
        config["logDirectory"] = constants.LOG_DIRECTORY
        config["listenOn"] = {
            "host": "localhost",
            "port": port_assignments[comp.name],
        }
        config["dataFilePath"] = str(self.data_file_path)
        if "connectTo" in config:
            config["connectTo"] = {
                target_name: (
                    {"host": "localhost", "port": port_assignments[target_name]}
                    if target_name in port_assignments
                    else target_cfg
                )
                for target_name, target_cfg in config["connectTo"].items()
            }
        return config