gtat-tech-career-kickstarte.../solution/tests/common/process_manager.py

181 lines
7.2 KiB
Python

import logging
import json
import os
from dataclasses import dataclass
from pathlib import Path
import signal
import subprocess
import threading
import socket
import time
import tempfile
from typing import IO
import psutil
from connection.ip_address import IpAddress
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# How long stop_process() waits after SIGINT before escalating to SIGTERM.
GRACEFUL_STOP_TIMEOUT: int = 30 # seconds
# How long stop_process() waits after SIGTERM, and again after SIGKILL.
FORCE_KILL_TIMEOUT: int = 5 # seconds
# Default pause between two consecutive PerformanceMonitor samples, in seconds.
PERF_MONITOR_INTERVAL_S: float = 0.2
@dataclass(frozen=True)
class PerformanceStats:
    """Immutable summary of the samples collected by a ``PerformanceMonitor``.

    The monitor rounds every figure to two decimals before storing it here.
    """

    # Highest resident set size seen across all samples, in MiB.
    peak_rss_mb: float
    # Mean resident set size across all samples, in MiB.
    avg_rss_mb: float
    # Highest CPU utilisation reported by psutil across all samples.
    peak_cpu_percent: float
    # Mean CPU utilisation reported by psutil across all samples.
    avg_cpu_percent: float
    # Number of samples the figures above were computed from.
    samples: int
class PerformanceMonitor:
    """Lightweight background sampler for a subprocess's memory and CPU usage.

    A daemon thread samples the process every ``interval_s`` seconds until
    :meth:`stop` is called or psutil reports that the process is gone or
    inaccessible.
    """

    def __init__(self, pid: int, interval_s: float = PERF_MONITOR_INTERVAL_S) -> None:
        """Bind the monitor to ``pid``; sampling only begins on :meth:`start`.

        ``psutil.Process`` is constructed eagerly, so errors for an unknown
        pid surface here rather than inside the sampling thread.

        Args:
            pid: Process id of the process to observe.
            interval_s: Pause between two consecutive samples, in seconds.
        """
        self._ps_process = psutil.Process(pid)
        self._interval_s = interval_s
        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self._poll, daemon=True)
        # The two lists are appended to in lockstep: index i of both lists
        # belongs to the same sampling instant.
        self._rss_samples: list[float] = []
        self._cpu_samples: list[float] = []

    def start(self) -> None:
        """Start the background sampling thread."""
        try:
            # The first cpu_percent() call only primes psutil's counters; its
            # return value is meaningless and therefore discarded.
            self._ps_process.cpu_percent()
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            pass
        self._thread.start()

    def stop(self) -> PerformanceStats:
        """Stop sampling and return the aggregated statistics.

        Safe to call when :meth:`start` was never called or the sampling
        thread has already finished; in that case only the (possibly empty)
        statistics are built.
        """
        self._stop_event.set()
        # Thread.join() raises RuntimeError on a thread that was never
        # started, so only join a thread that is actually running.
        if self._thread.is_alive():
            self._thread.join(timeout=2)
        return self._build_stats()

    def _poll(self) -> None:
        """Sampling loop executed by the background thread."""
        while not self._stop_event.is_set():
            try:
                rss_mb = self._ps_process.memory_info().rss / (1024 * 1024)
                cpu_percent = self._ps_process.cpu_percent()
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                break
            # Record only complete samples so both lists keep equal length
            # even when the process disappears between the two readings.
            self._rss_samples.append(rss_mb)
            self._cpu_samples.append(cpu_percent)
            # wait() doubles as an interruptible sleep: stop() wakes it early.
            self._stop_event.wait(self._interval_s)

    def _build_stats(self) -> PerformanceStats:
        """Aggregate the collected samples; all figures are 0.0 without samples."""
        # Work on snapshots: the sampler may still be running if join() timed out.
        rss_samples = list(self._rss_samples)
        cpu_samples = list(self._cpu_samples)
        avg_rss = sum(rss_samples) / len(rss_samples) if rss_samples else 0.0
        avg_cpu = sum(cpu_samples) / len(cpu_samples) if cpu_samples else 0.0
        return PerformanceStats(
            peak_rss_mb=round(max(rss_samples, default=0.0), 2),
            avg_rss_mb=round(avg_rss, 2),
            peak_cpu_percent=round(max(cpu_samples, default=0.0), 2),
            avg_cpu_percent=round(avg_cpu, 2),
            samples=len(rss_samples),
        )
class ProcessManager:
def __init__(self, venv_path: Path) -> None:
self.venv_path = venv_path
self.process: subprocess.Popen | None = None
self.stdout_thread: threading.Thread | None = None
self._perf_monitor: PerformanceMonitor | None = None
self.performance_stats: PerformanceStats | None = None
def _create_temp_app_config(self, config_data: dict) -> str:
"""Creates a temporary configuration file and returns its path."""
self.temp_config_file = tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json")
json.dump(config_data, self.temp_config_file)
self.temp_config_file.close()
return self.temp_config_file.name
def start_process(self, package_name: str, config_data: dict | None = None) -> bool:
try:
"""Starts the process, optionally using a temporary configuration file."""
command = [str(self.venv_path / "bin" / package_name)]
if config_data:
temp_config_path = self._create_temp_app_config(config_data)
logger.info(f"Temporary config file created at: {temp_config_path}")
command.extend(["-c", str(temp_config_path)])
logger.info(f"Starting process with command: {' '.join(command)}")
self.process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True
)
time.sleep(0.15) # Give the process some time to start (or crash)
if self.process.poll() is not None:
raise RuntimeError(f"Failed to start process {package_name}. Return code: {self.process.returncode}")
self.stdout_thread = threading.Thread(target=_log_output, args=(package_name, self.process.stdout,), daemon=True)
self.stdout_thread.start()
self._perf_monitor = PerformanceMonitor(self.process.pid)
self._perf_monitor.start()
logger.info(f"Process started with PID: {self.process.pid}")
return True
except Exception as e:
logger.exception("Error starting process")
return False
def wait_until_server_is_ready(self, ip_address: IpAddress, timeout_in_seconds: float) -> float:
start_time = time.time()
while time.time() - start_time < timeout_in_seconds:
if self.process.poll() is not None:
raise RuntimeError(f"Process crashed. Return code: {self.process.returncode}")
try:
conn_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM)
conn_socket.connect((ip_address.host, ip_address.port))
time_to_startup = (time.time() - start_time) * 1000
logger.info(f"Server started in approximately {time_to_startup:.2f} ms")
conn_socket.close()
return time_to_startup
except ConnectionRefusedError:
logger.debug("Server is not ready yet")
time.sleep(0.1)
raise TimeoutError(f"Server did not start in {timeout_in_seconds} seconds")
def stop_process(self) -> None:
assert self.process is not None, "Process is not running"
if self._perf_monitor is not None:
self.performance_stats = self._perf_monitor.stop()
logger.info(f"Sending SIGINT to process {self.process.pid}")
self.process.send_signal(signal.SIGINT)
try:
self.process.wait(timeout=GRACEFUL_STOP_TIMEOUT)
except subprocess.TimeoutExpired:
logger.warning(f"Process {self.process.pid} did not stop after SIGINT grace period ({GRACEFUL_STOP_TIMEOUT} seconds). Sending SIGTERM...")
self.process.send_signal(signal.SIGTERM)
try:
self.process.wait(timeout=FORCE_KILL_TIMEOUT)
except subprocess.TimeoutExpired:
logger.warning(f"Process {self.process.pid} still running after SIGTERM. Sending SIGKILL...")
self.process.send_signal(signal.SIGKILL)
self.process.wait(timeout=FORCE_KILL_TIMEOUT)
if self.process.returncode != 0:
logger.warning(f"Process {self.process.pid} exited with return code: {self.process.returncode}")
logger.info("Process stopped")
if self.temp_config_file:
os.unlink(self.temp_config_file.name)
logger.info(f"Temporary config file deleted: {self.temp_config_file.name}")
assert self.stdout_thread is not None
self.stdout_thread.join()
def _log_output(package_name: str, pipe: IO[str]) -> None:
    """Forward every line read from ``pipe`` to a per-package logger.

    Each line is stripped of surrounding whitespace and logged at INFO level
    on the logger named ``SUBPROCESS_LOG_<package_name>``. The function
    returns once ``pipe.readline()`` yields an empty string (end of stream).
    """
    child_log = logging.getLogger(f"SUBPROCESS_LOG_{package_name}")
    while (raw_line := pipe.readline()) != '':
        child_log.info(raw_line.strip())