From 10ee00afb5ac63bde3d9cc2013fa1e5db393f404 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Mon, 15 Dec 2025 21:52:49 +0000 Subject: [PATCH 01/22] Add model manager that automatically manage model across processes --- .../apache_beam/ml/inference/model_manager.py | 669 ++++++++++++++++++ .../ml/inference/model_manager_test.py | 548 ++++++++++++++ 2 files changed, 1217 insertions(+) create mode 100644 sdks/python/apache_beam/ml/inference/model_manager.py create mode 100644 sdks/python/apache_beam/ml/inference/model_manager_test.py diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py new file mode 100644 index 000000000000..e56b8ee8d03f --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -0,0 +1,669 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Module for managing ML models in Apache Beam pipelines. + +This module provides classes and functions to efficiently manage multiple +machine learning models within Apache Beam pipelines. It includes functionality +for loading, caching, and updating models using multi-process shared memory, +ensuring that models are reused across different workers to optimize resource +usage and performance. +""" + +import uuid +import time +import threading +import subprocess +import logging +import gc +import numpy as np +from scipy.optimize import nnls +import torch +import heapq +import itertools +from collections import defaultdict, deque, Counter, OrderedDict +from typing import Dict, Any, Tuple, Optional, Callable + +logger = logging.getLogger(__name__) + + +class GPUMonitor: + def __init__( + self, + fallback_memory_mb: float = 16000.0, + poll_interval: float = 0.5, + peak_window_seconds: float = 30.0): + self._current_usage = 0.0 + self._peak_usage = 0.0 + self._total_memory = fallback_memory_mb + self._poll_interval = poll_interval + self._peak_window_seconds = peak_window_seconds + self._memory_history = deque() + self._running = False + self._thread = None + self._lock = threading.Lock() + self._gpu_available = self._detect_hardware() + + def _detect_hardware(self): + try: + cmd = [ + "nvidia-smi", + "--query-gpu=memory.total", + "--format=csv,noheader,nounits" + ] + output = subprocess.check_output(cmd, text=True).strip() + self._total_memory = float(output) + return True + except (FileNotFoundError, subprocess.CalledProcessError): + logger.warning( + "nvidia-smi not found or failed. Defaulting total memory to %s MB", + self._total_memory) + return False + except Exception as e: + logger.warning( + "Error parsing nvidia-smi output: %s. 
" + "Defaulting total memory to %s MB", + e, + self._total_memory) + return False + + def start(self): + if self._running or not self._gpu_available: + return + self._running = True + self._thread = threading.Thread(target=self._poll_loop, daemon=True) + self._thread.start() + + def stop(self): + self._running = False + if self._thread: + self._thread.join() + + def reset_peak(self): + with self._lock: + now = time.time() + self._memory_history.clear() + self._memory_history.append((now, self._current_usage)) + self._peak_usage = self._current_usage + + def get_stats(self) -> Tuple[float, float, float]: + with self._lock: + return self._current_usage, self._peak_usage, self._total_memory + + def refresh(self): + """Forces an immediate poll of the GPU.""" + usage = self._get_nvidia_smi_used() + now = time.time() + with self._lock: + self._current_usage = usage + self._memory_history.append((now, usage)) + # Recalculate peak immediately + while self._memory_history and (now - self._memory_history[0][0] + > self._peak_window_seconds): + self._memory_history.popleft() + self._peak_usage = ( + max(m for _, m in self._memory_history) + if self._memory_history else usage) + + def _get_nvidia_smi_used(self) -> float: + try: + cmd = "nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits" + output = subprocess.check_output(cmd, shell=True).decode("utf-8").strip() + free_memory = float(output) + return self._total_memory - free_memory + except Exception: + return 0.0 + + def _poll_loop(self): + while self._running: + usage = self._get_nvidia_smi_used() + now = time.time() + with self._lock: + self._current_usage = usage + self._memory_history.append((now, usage)) + while self._memory_history and (now - self._memory_history[0][0] + > self._peak_window_seconds): + self._memory_history.popleft() + self._peak_usage = ( + max(m for _, m in self._memory_history) + if self._memory_history else usage) + time.sleep(self._poll_interval) + + +class ResourceEstimator: + def __init__(self, smoothing_factor: float = 0.2, min_data_points: int = 5): + self.smoothing_factor = smoothing_factor + self.min_data_points = min_data_points + self.estimates: Dict[str, float] = {} + self.history = defaultdict(lambda: deque(maxlen=20)) + self.known_models = set() + self._lock = threading.Lock() + + def is_unknown(self, model_tag: str) -> bool: + with self._lock: + return model_tag not in self.estimates + + def get_estimate(self, model_tag: str, default_mb: float = 4000.0) -> float: + with self._lock: + return self.estimates.get(model_tag, default_mb) + + def set_initial_estimate(self, model_tag: str, cost: float): + with self._lock: + self.estimates[model_tag] = cost + self.known_models.add(model_tag) + logger.info("Initial Profile for %s: %s MB", model_tag, cost) + + def add_observation( + self, active_snapshot: Dict[str, int], peak_memory: float): + logger.info( + "Adding Observation: Snapshot=%s, PeakMemory=%.1f MB", + active_snapshot, + peak_memory) + if not active_snapshot: + return + with self._lock: + config_key = tuple(sorted(active_snapshot.items())) + self.history[config_key].append(peak_memory) + for tag in active_snapshot: + self.known_models.add(tag) + self._solve() + + def _solve(self): + """ + Solves Ax=b using raw readings (no pre-averaging) and NNLS. + This creates a 'tall' matrix A where every memory reading is + a separate equation. 
+ """ + unique = sorted(list(self.known_models)) + + # We need to build the matrix first to know if we have enough data points + A, b = [], [] + + for config_key, mem_values in self.history.items(): + if not mem_values: + continue + + # 1. Create the feature row for this configuration ONCE + # (It represents the model counts + bias) + counts = dict(config_key) + feature_row = [counts.get(model, 0) for model in unique] + feature_row.append(1) # Bias column + + # 2. Add a separate row to the matrix for EVERY individual reading + # Instead of averaging, we flatten the history into the matrix + for reading in mem_values: + A.append(feature_row) # The inputs (models) stay the same + b.append(reading) # The output (memory) varies due to noise + + # Convert to numpy for SciPy + A = np.array(A) + b = np.array(b) + + if len( + self.history.keys()) < len(unique) + 1 or len(A) < self.min_data_points: + # Not enough data to solve yet + return + + logger.info( + "Solving with %s total observations for %s models.", + len(A), + len(unique)) + + try: + # Solve using Non-Negative Least Squares + # x will be >= 0 + x, _ = nnls(A, b) + + weights = x[:-1] + bias = x[-1] + + for i, model in enumerate(unique): + calculated_cost = weights[i] + + if model in self.estimates: + old = self.estimates[model] + new = (old * (1 - self.smoothing_factor)) + ( + calculated_cost * self.smoothing_factor) + self.estimates[model] = new + else: + self.estimates[model] = calculated_cost + + logger.info( + "Updated Estimate for %s: %.1f MB", model, self.estimates[model]) + logger.info("System Bias: %s MB", bias) + + except Exception as e: + logger.error("Solver failed: %s", e) + + +class TrackedModelProxy: + def __init__(self, obj): + object.__setattr__(self, "_wrapped_obj", obj) + object.__setattr__(self, "_beam_tracking_id", str(uuid.uuid4())) + + def __getattr__(self, name): + return getattr(self._wrapped_obj, name) + + def __setattr__(self, name, value): + setattr(self._wrapped_obj, name, value) + + def __call__(self, *args, **kwargs): + return self._wrapped_obj(*args, **kwargs) + + def __setstate__(self, state): + self.__dict__.update(state) + + def __getstate__(self): + return self.__dict__ + + def __str__(self): + return str(self._wrapped_obj) + + def __repr__(self): + return repr(self._wrapped_obj) + + def __dir__(self): + return dir(self._wrapped_obj) + + def unsafe_hard_delete(self): + if hasattr(self._wrapped_obj, "unsafe_hard_delete"): + self._wrapped_obj.unsafe_hard_delete() + + +class ModelManager: + _lock = threading.Lock() + + def __init__( + self, + monitor: Optional['GPUMonitor'] = None, + slack_percentage: float = 0.10, + poll_interval: float = 0.5, + peak_window_seconds: float = 30.0, + min_data_points: int = 5, + smoothing_factor: float = 0.2, + eviction_cooldown_seconds: float = 10.0, + min_model_copies: int = 1): + + self._estimator = ResourceEstimator( + min_data_points=min_data_points, smoothing_factor=smoothing_factor) + self._monitor = monitor if monitor else GPUMonitor( + poll_interval=poll_interval, peak_window_seconds=peak_window_seconds) + self._slack_percentage = slack_percentage + + self._eviction_cooldown = eviction_cooldown_seconds + self._min_model_copies = min_model_copies + + # Resource State + self._models = defaultdict(list) + self._idle_lru = OrderedDict() + self._active_counts = Counter() + self._total_active_jobs = 0 + self._pending_reservations = 0.0 + + self._isolation_mode = False + self._pending_isolation_count = 0 + self._isolation_baseline = 0.0 + + self._wait_queue = [] + 
self._ticket_counter = itertools.count() + self._cv = threading.Condition() + self._load_lock = threading.Lock() + + self._monitor.start() + + def all_models(self, tag) -> list[Any]: + return self._models[tag] + + def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: + current_priority = 0 if self._estimator.is_unknown(tag) else 1 + ticket_num = next(self._ticket_counter) + my_id = object() + + with self._cv: + # FAST PATH + if self._pending_isolation_count == 0 and not self._isolation_mode: + cached_instance = self._try_grab_from_lru(tag) + if cached_instance: + return cached_instance + + # SLOW PATH + logger.info( + "Acquire Queued: tag=%s, priority=%d " + "total models count=%s ticket num=%s", + tag, + current_priority, + len(self._models[tag]), + ticket_num) + heapq.heappush( + self._wait_queue, (current_priority, ticket_num, my_id, tag)) + + should_spawn = False + est_cost = 0.0 + is_unknown = False + + try: + while True: + if not self._wait_queue or self._wait_queue[0][2] is not my_id: + logger.info( + "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num) + self._cv.wait() + continue + + real_is_unknown = self._estimator.is_unknown(tag) + real_priority = 0 if real_is_unknown else 1 + + if current_priority != real_priority: + heapq.heappop(self._wait_queue) + current_priority = real_priority + heapq.heappush( + self._wait_queue, (current_priority, ticket_num, my_id, tag)) + self._cv.notify_all() + continue + + cached_instance = self._try_grab_from_lru(tag) + if cached_instance: + return cached_instance + + is_unknown = real_is_unknown + + # Path A: Isolation + if is_unknown: + if self._total_active_jobs > 0: + logger.info( + "Waiting to enter isolation: tag=%s ticket num=%s", + tag, + ticket_num) + self._cv.wait() + continue + + logger.info("Unknown model %s detected. 
Flushing GPU.", tag) + self._delete_all_models() + + self._isolation_mode = True + self._total_active_jobs += 1 + self._isolation_baseline, _, _ = self._monitor.get_stats() + self._monitor.reset_peak() + should_spawn = True + break + + # Path B: Concurrent + else: + if self._pending_isolation_count > 0 or self._isolation_mode: + logger.info( + "Waiting due to isolation in progress: tag=%s ticket num%s", + tag, + ticket_num) + self._cv.wait() + continue + + curr, _, total = self._monitor.get_stats() + est_cost = self._estimator.get_estimate(tag) + limit = total * (1 - self._slack_percentage) + + # Use current usage for capacity check (ignore old spikes) + if (curr + self._pending_reservations + est_cost) <= limit: + self._pending_reservations += est_cost + self._total_active_jobs += 1 + self._active_counts[tag] += 1 + should_spawn = True + break + + # Evict to make space (passing tag to check demand/existence) + if self._evict_to_make_space(limit, est_cost, requesting_tag=tag): + continue + + idle_count = 0 + other_idle_count = 0 + for item in self._idle_lru.items(): + if item[1][0] == tag: + idle_count += 1 + else: + other_idle_count += 1 + total_model_count = 0 + for _, instances in self._models.items(): + total_model_count += len(instances) + curr, _, _ = self._monitor.get_stats() + logger.info( + "Waiting for resources to free up: " + "tag=%s ticket num%s model count=%s " + "idle count=%s resource usage=%.1f MB " + "total models count=%s other idle=%s", + tag, + ticket_num, + len(self._models[tag]), + idle_count, + curr, + total_model_count, + other_idle_count) + self._cv.wait(timeout=10.0) + + finally: + if self._wait_queue and self._wait_queue[0][2] is my_id: + heapq.heappop(self._wait_queue) + else: + for i, item in enumerate(self._wait_queue): + if item[2] is my_id: + self._wait_queue.pop(i) + heapq.heapify(self._wait_queue) + self._cv.notify_all() + + if should_spawn: + return self._spawn_new_model(tag, loader_func, is_unknown, est_cost) + + def release_model(self, tag: str, instance: Any): + with self._cv: + try: + self._total_active_jobs -= 1 + if self._active_counts[tag] > 0: + self._active_counts[tag] -= 1 + + self._idle_lru[id(instance)] = (tag, instance, time.time()) + + _, peak_during_job, _ = self._monitor.get_stats() + + if self._isolation_mode and self._active_counts[tag] == 0: + cost = max(0, peak_during_job - self._isolation_baseline) + self._estimator.set_initial_estimate(tag, cost) + self._isolation_mode = False + self._isolation_baseline = 0.0 + else: + snapshot = { + t: len(instances) + for t, instances in self._models.items() if len(instances) > 0 + } + if snapshot: + self._estimator.add_observation(snapshot, peak_during_job) + + finally: + self._cv.notify_all() + + def _try_grab_from_lru(self, tag: str) -> Any: + target_key = None + target_instance = None + + for key, (t, instance, _) in reversed(self._idle_lru.items()): + if t == tag: + target_key = key + target_instance = instance + break + + if target_instance: + del self._idle_lru[target_key] + self._active_counts[tag] += 1 + self._total_active_jobs += 1 + return target_instance + + logger.info("No idle model found for tag: %s", tag) + return None + + def _evict_to_make_space( + self, limit: float, est_cost: float, requesting_tag: str) -> bool: + """ + Evicts models based on Demand Magnitude + Tiers. + Crucially: If we have 0 active copies of 'requesting_tag', we FORCE eviction + of the lowest-demand candidate to avoid starvation. 
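    Idle candidates are ranked by score = (pending demand for their tag * 10)
    + tier, where tier 0 is a cold surplus copy, 1 a warm surplus copy, 2 a
    cold last copy and 3 a warm last copy ("cold" means idle for longer than
    the eviction cooldown). Lowest-scoring candidates are evicted first until
    the projected usage fits under the limit.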
+ """ + evicted_something = False + curr, _, _ = self._monitor.get_stats() + projected_usage = curr + self._pending_reservations + est_cost + + if projected_usage <= limit: + return False + + now = time.time() + + demand_map = Counter() + for item in self._wait_queue: + if len(item) >= 4: + demand_map[item[3]] += 1 + + my_demand = demand_map[requesting_tag] + am_i_starving = len(self._models[requesting_tag]) == 0 + + candidates = [] + for key, (tag, instance, release_time) in self._idle_lru.items(): + candidate_demand = demand_map[tag] + + if not am_i_starving: + if candidate_demand >= my_demand: + continue + + age = now - release_time + is_cold = age >= self._eviction_cooldown + + total_copies = len(self._models[tag]) + is_surplus = total_copies > self._min_model_copies + + if is_cold and is_surplus: tier = 0 + elif not is_cold and is_surplus: tier = 1 + elif is_cold and not is_surplus: tier = 2 + else: tier = 3 + + score = (candidate_demand * 10) + tier + + candidates.append((score, release_time, key, tag, instance)) + + candidates.sort(key=lambda x: (x[0], x[1])) + + for score, _, key, tag, instance in candidates: + if projected_usage <= limit: + break + + if key not in self._idle_lru: continue + + self._perform_eviction(key, tag, instance, score) + evicted_something = True + + curr, _, _ = self._monitor.get_stats() + projected_usage = curr + self._pending_reservations + est_cost + + return evicted_something + + def _perform_eviction(self, key, tag, instance, score): + logger.info("Evicting Model: %s (Score %d)", tag, score) + curr, _, _ = self._monitor.get_stats() + logger.info("Resource Usage Before Eviction: %.1f MB", curr) + + if key in self._idle_lru: + del self._idle_lru[key] + + target_id = instance._beam_tracking_id + for i, inst in enumerate(self._models[tag]): + if inst._beam_tracking_id == target_id: + del self._models[tag][i] + break + + instance.unsafe_hard_delete() + del instance + gc.collect() + torch.cuda.empty_cache() + self._monitor.refresh() + self._monitor.reset_peak() + curr, _, _ = self._monitor.get_stats() + logger.info("Resource Usage After Eviction: %.1f MB", curr) + + def _spawn_new_model(self, tag, loader_func, is_unknown, est_cost): + try: + with self._load_lock: + logger.info("Loading Model: %s (Unknown: %s)", tag, is_unknown) + isolation_baseline_snap, _, _ = self._monitor.get_stats() + instance = TrackedModelProxy(loader_func()) + _, peak_during_load, _ = self._monitor.get_stats() + + with self._cv: + snapshot = {tag: 1} + self._estimator.add_observation( + snapshot, peak_during_load - isolation_baseline_snap) + + if not is_unknown: + self._pending_reservations = max( + 0.0, self._pending_reservations - est_cost) + self._models[tag].append(instance) + return instance + + except Exception as e: + logger.error("Load Failed: %s. 
Error: %s", tag, e) + with self._cv: + self._total_active_jobs -= 1 + if is_unknown: + self._isolation_mode = False + self._isolation_baseline = 0.0 + else: + self._pending_reservations = max( + 0.0, self._pending_reservations - est_cost) + self._active_counts[tag] -= 1 + self._cv.notify_all() + raise e + + def _delete_all_models(self): + self._idle_lru.clear() + for _, instances in self._models.items(): + for instance in instances: + if hasattr(instance, "unsafe_hard_delete"): + instance.unsafe_hard_delete() + del instance + self._models.clear() + self._active_counts.clear() + gc.collect() + torch.cuda.empty_cache() + self._monitor.refresh() + self._monitor.reset_peak() + + def _force_reset(self): + logger.warning("Force Reset Triggered") + self._delete_all_models() + self._models = defaultdict(list) + self._idle_lru = OrderedDict() + self._active_counts = Counter() + self._wait_queue = [] + self._total_active_jobs = 0 + self._pending_reservations = 0.0 + self._isolation_mode = False + self._pending_isolation_count = 0 + self._isolation_baseline = 0.0 + + def shutdown(self): + self._delete_all_models() + gc.collect() + torch.cuda.empty_cache() + self._monitor.stop() + + def __del__(self): + self.shutdown() + + def __exit__(self, exc_type, exc_value, traceback): + self.shutdown() diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py new file mode 100644 index 000000000000..7412ea6a6c64 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -0,0 +1,548 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import time +import threading +import random +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import patch + +try: + from apache_beam.ml.inference.model_manager import ModelManager, GPUMonitor, ResourceEstimator +except ImportError as e: + raise unittest.SkipTest("Model Manager dependencies are not installed") + + +class MockGPUMonitor: + """ + Simulates GPU hardware with cumulative memory tracking. + Allows simulating specific allocation spikes and baseline usage. 
+ """ + def __init__(self, total_memory=12000.0, peak_window: int = 5): + self._current = 0.0 + self._peak = 0.0 + self._total = total_memory + self._lock = threading.Lock() + self.running = False + self.history = [] + self.peak_window = peak_window + + def start(self): + self.running = True + + def stop(self): + self.running = False + + def get_stats(self): + with self._lock: + return self._current, self._peak, self._total + + def reset_peak(self): + with self._lock: + self._peak = self._current + self.history = [self._current] + + def set_usage(self, current_mb): + """Sets absolute usage (legacy helper).""" + with self._lock: + self._current = current_mb + self._peak = max(self._peak, current_mb) + + def allocate(self, amount_mb): + """Simulates memory allocation (e.g., tensors loaded to VRAM).""" + with self._lock: + self._current += amount_mb + self.history.append(self._current) + if len(self.history) > self.peak_window: + self.history.pop(0) + self._peak = max(self.history) + + def free(self, amount_mb): + """Simulates memory freeing (not used often if pooling is active).""" + with self._lock: + self._current = max(0.0, self._current - amount_mb) + self.history.append(self._current) + if len(self.history) > self.peak_window: + self.history.pop(0) + self._peak = max(self.history) + + def refresh(self): + """Simulates a refresh of the monitor stats (no-op for mock).""" + pass + + +class MockModel: + def __init__(self, name, size, monitor): + self.name = name + self.size = size + self.monitor = monitor + self.deleted = False + self.monitor.allocate(size) + + def unsafe_hard_delete(self): + if not self.deleted: + self.monitor.free(self.size) + self.deleted = True + + +class TestModelManager(unittest.TestCase): + def setUp(self): + """Force reset the Singleton ModelManager before every test.""" + ModelManager._instance = None + self.mock_monitor = MockGPUMonitor() + self.manager = ModelManager(monitor=self.mock_monitor) + + def tearDown(self): + self.manager.shutdown() + + def test_model_manager_capacity_check(self): + """ + Test that the manager blocks when spawning models exceeds the limit, + and unblocks when resources become available (via reuse). + """ + model_name = "known_model" + model_cost = 3000.0 + self.manager._estimator.set_initial_estimate(model_name, model_cost) + acquired_refs = [] + + def loader(): + self.mock_monitor.allocate(model_cost) + return model_name + + # 1. Saturate GPU with 3 models (9000 MB usage) + for _ in range(3): + inst = self.manager.acquire_model(model_name, loader) + acquired_refs.append(inst) + + # 2. Spawn one more (Should Block because 9000 + 3000 > Limit) + def run_inference(): + return self.manager.acquire_model(model_name, loader) + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(run_inference) + try: + future.result(timeout=0.5) + self.fail("Should have blocked due to capacity") + except TimeoutError: + pass + + # 3. 
Release resources to unblock + item_to_release = acquired_refs.pop() + self.manager.release_model(model_name, item_to_release) + + result = future.result(timeout=2.0) + self.assertIsNotNone(result) + self.assertEqual(result, item_to_release) + + def test_model_manager_unknown_model_runs_isolated(self): + """Test that a model with no history runs in isolation.""" + model_name = "unknown_model_v1" + self.assertTrue(self.manager._estimator.is_unknown(model_name)) + + def dummy_loader(): + time.sleep(0.05) + return "model_instance" + + instance = self.manager.acquire_model(model_name, dummy_loader) + + self.assertTrue(self.manager._isolation_mode) + self.assertEqual(self.manager._total_active_jobs, 1) + + self.manager.release_model(model_name, instance) + self.assertFalse(self.manager._isolation_mode) + self.assertFalse(self.manager._estimator.is_unknown(model_name)) + + def test_model_manager_concurrent_execution(self): + """Test that multiple small known models can run together.""" + model_a = "small_model_a" + model_b = "small_model_b" + + self.manager._estimator.set_initial_estimate(model_a, 1000.0) + self.manager._estimator.set_initial_estimate(model_b, 1000.0) + self.mock_monitor.set_usage(1000.0) + + inst_a = self.manager.acquire_model(model_a, lambda: "A") + inst_b = self.manager.acquire_model(model_b, lambda: "B") + + self.assertEqual(self.manager._total_active_jobs, 2) + + self.manager.release_model(model_a, inst_a) + self.manager.release_model(model_b, inst_b) + self.assertEqual(self.manager._total_active_jobs, 0) + + def test_model_manager_concurrent_mixed_workload_convergence(self): + """ + Simulates a production environment with multiple model types running + concurrently. Verifies that the estimator converges. + """ + TRUE_COSTS = {"model_small": 1500.0, "model_medium": 3000.0} + + def run_job(model_name): + cost = TRUE_COSTS[model_name] + + def loader(): + model = MockModel(model_name, cost, self.mock_monitor) + return model + + instance = self.manager.acquire_model(model_name, loader) + time.sleep(random.uniform(0.01, 0.05)) + self.manager.release_model(model_name, instance) + + # Create a workload stream + workload = ["model_small"] * 15 + ["model_medium"] * 15 + random.shuffle(workload) + + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(run_job, name) for name in workload] + for f in futures: + f.result() + + est_small = self.manager._estimator.get_estimate("model_small") + est_med = self.manager._estimator.get_estimate("model_medium") + + self.assertAlmostEqual(est_small, TRUE_COSTS["model_small"], delta=100.0) + self.assertAlmostEqual(est_med, TRUE_COSTS["model_medium"], delta=100.0) + + def test_model_manager_oom_recovery(self): + """Test that the manager recovers state if a loader crashes.""" + model_name = "crasher_model" + self.manager._estimator.set_initial_estimate(model_name, 1000.0) + + def crashing_loader(): + raise RuntimeError("CUDA OOM or similar") + + with self.assertRaises(RuntimeError): + self.manager.acquire_model(model_name, crashing_loader) + + self.assertEqual(self.manager._total_active_jobs, 0) + self.assertEqual(self.manager._pending_reservations, 0.0) + self.assertFalse(self.manager._cv._is_owned()) + + def test_model_managaer_force_reset_on_exception(self): + """Test that force_reset clears all models from the manager.""" + model_name = "test_model" + + def dummy_loader(): + self.mock_monitor.allocate(1000.0) + raise RuntimeError("Simulated loader exception") + + try: + instance = self.manager.acquire_model( + 
model_name, lambda: "model_instance") + self.manager.release_model(model_name, instance) + instance = self.manager.acquire_model(model_name, dummy_loader) + except RuntimeError: + self.manager._force_reset() + self.assertTrue(len(self.manager._models[model_name]) == 0) + self.assertEqual(self.manager._total_active_jobs, 0) + self.assertEqual(self.manager._pending_reservations, 0.0) + self.assertFalse(self.manager._isolation_mode) + pass + + instance = self.manager.acquire_model(model_name, lambda: "model_instance") + self.manager.release_model(model_name, instance) + + def test_single_model_convergence_with_fluctuations(self): + """ + Tests that the estimator converges to the true usage with fluctuations. + """ + model_name = "fluctuating_model" + model_cost = 3000.0 + load_cost = 2000.0 + + def loader(): + self.mock_monitor.allocate(load_cost) + return model_name + + model = self.manager.acquire_model(model_name, loader) + self.manager.release_model(model_name, model) + initial_est = self.manager._estimator.get_estimate(model_name) + self.assertEqual(initial_est, load_cost) + + def run_inference(): + model = self.manager.acquire_model(model_name, loader) + noise = model_cost - load_cost + random.uniform(-300.0, 300.0) + self.mock_monitor.allocate(noise) + time.sleep(0.1) + self.mock_monitor.free(noise) + self.manager.release_model(model_name, model) + return + + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(run_inference) for _ in range(100)] + + for f in futures: + f.result() + + est_cost = self.manager._estimator.get_estimate(model_name) + self.assertAlmostEqual(est_cost, model_cost, delta=100.0) + + +class TestModelManagerEviction(unittest.TestCase): + def setUp(self): + self.mock_monitor = MockGPUMonitor(total_memory=12000.0) + ModelManager._instance = None + self.manager = ModelManager( + monitor=self.mock_monitor, + slack_percentage=0.0, + min_data_points=1, + eviction_cooldown_seconds=10.0, + min_model_copies=1) + + def tearDown(self): + self.manager.shutdown() + + def create_loader(self, name, size): + return lambda: MockModel(name, size, self.mock_monitor) + + def test_basic_lru_eviction(self): + self.manager._estimator.set_initial_estimate("A", 4000) + self.manager._estimator.set_initial_estimate("B", 4000) + self.manager._estimator.set_initial_estimate("C", 5000) + + model_a = self.manager.acquire_model("A", self.create_loader("A", 4000)) + self.manager.release_model("A", model_a) + + model_b = self.manager.acquire_model("B", self.create_loader("B", 4000)) + self.manager.release_model("B", model_b) + + key_a = list(self.manager._idle_lru.keys())[0] + self.manager._idle_lru[key_a] = ("A", model_a, time.time() - 20.0) + + key_b = list(self.manager._idle_lru.keys())[1] + self.manager._idle_lru[key_b] = ("B", model_b, time.time() - 20.0) + + model_a_again = self.manager.acquire_model( + "A", self.create_loader("A", 4000)) + self.manager.release_model("A", model_a_again) + + self.manager.acquire_model("C", self.create_loader("C", 5000)) + + self.assertEqual(len(self.manager.all_models("B")), 0) + self.assertEqual(len(self.manager.all_models("A")), 1) + + def test_chained_eviction(self): + self.manager._estimator.set_initial_estimate("big_guy", 8000) + models = [] + for i in range(4): + name = f"small_{i}" + m = self.manager.acquire_model(name, self.create_loader(name, 3000)) + self.manager.release_model(name, m) + models.append(m) + + self.manager.acquire_model("big_guy", self.create_loader("big_guy", 8000)) + + self.assertTrue(models[0].deleted) + 
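    # With a 12000 MB card and zero slack, fitting the 8000 MB model requires
    # evicting the three oldest idle 3000 MB copies; only the newest survives.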
self.assertTrue(models[1].deleted) + self.assertTrue(models[2].deleted) + self.assertFalse(models[3].deleted) + + def test_active_models_are_protected(self): + self.manager._estimator.set_initial_estimate("A", 6000) + self.manager._estimator.set_initial_estimate("B", 4000) + self.manager._estimator.set_initial_estimate("C", 4000) + + model_a = self.manager.acquire_model("A", self.create_loader("A", 6000)) + model_b = self.manager.acquire_model("B", self.create_loader("B", 4000)) + self.manager.release_model("B", model_b) + + key_b = list(self.manager._idle_lru.keys())[0] + self.manager._idle_lru[key_b] = ("B", model_b, time.time() - 20.0) + + def acquire_c(): + return self.manager.acquire_model("C", self.create_loader("C", 4000)) + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(acquire_c) + model_c = future.result(timeout=2.0) + + self.assertTrue(model_b.deleted) + self.assertFalse(model_a.deleted) + + self.manager.release_model("A", model_a) + self.manager.release_model("C", model_c) + + def test_unknown_model_clears_memory(self): + self.manager._estimator.set_initial_estimate("A", 2000) + model_a = self.manager.acquire_model("A", self.create_loader("A", 2000)) + self.manager.release_model("A", model_a) + self.assertFalse(model_a.deleted) + + self.assertTrue(self.manager._estimator.is_unknown("X")) + model_x = self.manager.acquire_model("X", self.create_loader("X", 10000)) + + self.assertTrue(model_a.deleted, "Model A should be deleted for isolation") + self.assertEqual(len(self.manager.all_models("A")), 0) + self.assertTrue(self.manager._isolation_mode) + self.manager.release_model("X", model_x) + + def test_concurrent_eviction_pressure(self): + def worker(idx): + name = f"model_{idx % 5}" + try: + m = self.manager.acquire_model(name, self.create_loader(name, 4000)) + time.sleep(0.001) + self.manager.release_model(name, m) + except Exception: + pass + + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(worker, i) for i in range(50)] + for f in futures: + f.result() + + curr, _, _ = self.mock_monitor.get_stats() + expected_usage = 0 + for _, instances in self.manager._models.items(): + expected_usage += len(instances) * 4000 + + self.assertAlmostEqual(curr, expected_usage) + + def test_starvation_prevention_overrides_demand(self): + self.manager._estimator.set_initial_estimate("A", 12000) + m_a = self.manager.acquire_model("A", self.create_loader("A", 12000)) + self.manager.release_model("A", m_a) + + def cycle_a(): + try: + m = self.manager.acquire_model("A", self.create_loader("A", 12000)) + time.sleep(0.3) + self.manager.release_model("A", m) + except Exception: + pass + + executor = ThreadPoolExecutor(max_workers=5) + for _ in range(5): + executor.submit(cycle_a) + + def acquire_b(): + return self.manager.acquire_model("B", self.create_loader("B", 4000)) + + b_future = executor.submit(acquire_b) + model_b = b_future.result() + + self.assertTrue(m_a.deleted) + self.manager.release_model("B", model_b) + executor.shutdown(wait=True) + + +class TestGPUMonitor(unittest.TestCase): + def setUp(self): + self.subprocess_patcher = patch('subprocess.check_output') + self.mock_subprocess = self.subprocess_patcher.start() + + def tearDown(self): + self.subprocess_patcher.stop() + + def test_init_hardware_detected(self): + """Test that init correctly reads total memory when nvidia-smi exists.""" + self.mock_subprocess.return_value = "24576" + monitor = GPUMonitor() + self.assertTrue(monitor._gpu_available) + 
self.assertEqual(monitor._total_memory, 24576.0) + + def test_init_hardware_missing(self): + """Test fallback behavior when nvidia-smi is missing.""" + self.mock_subprocess.side_effect = FileNotFoundError() + monitor = GPUMonitor(fallback_memory_mb=12000.0) + self.assertFalse(monitor._gpu_available) + self.assertEqual(monitor._total_memory, 12000.0) + + @patch('time.sleep') + def test_polling_updates_stats(self, mock_sleep): + """Test that the polling loop updates current and peak usage.""" + def subprocess_side_effect(*args, **kwargs): + if isinstance(args[0], list) and "memory.total" in args[0][1]: + return "16000" + + if isinstance(args[0], str) and "memory.free" in args[0]: + return b"12000" + + raise Exception("Unexpected command") + + self.mock_subprocess.side_effect = subprocess_side_effect + self.mock_subprocess.return_value = None + + monitor = GPUMonitor() + monitor.start() + time.sleep(0.1) + curr, peak, total = monitor.get_stats() + monitor.stop() + + self.assertEqual(curr, 4000.0) + self.assertEqual(peak, 4000.0) + self.assertEqual(total, 16000.0) + + def test_reset_peak(self): + """Test that resetting peak usage works.""" + monitor = GPUMonitor() + monitor._gpu_available = True + + with monitor._lock: + monitor._current_usage = 2000.0 + monitor._peak_usage = 8000.0 + monitor._memory_history.append((time.time(), 8000.0)) + monitor._memory_history.append((time.time(), 2000.0)) + + monitor.reset_peak() + + _, peak, _ = monitor.get_stats() + self.assertEqual(peak, 2000.0) + + +class TestResourceEstimatorSolver(unittest.TestCase): + def setUp(self): + self.estimator = ResourceEstimator() + + @patch('apache_beam.ml.inference.model_manager.nnls') + def test_solver_respects_min_data_points(self, mock_nnls): + mock_nnls.return_value = ([100.0, 50.0], 0.0) + + self.estimator.add_observation({'model_A': 1}, 500) + self.estimator.add_observation({'model_B': 1}, 500) + self.assertFalse(mock_nnls.called) + + self.estimator.add_observation({'model_A': 1, 'model_B': 1}, 1000) + self.assertFalse(mock_nnls.called) + + self.estimator.add_observation({'model_A': 1}, 500) + self.assertFalse(mock_nnls.called) + + self.estimator.add_observation({'model_B': 1}, 500) + self.assertTrue(mock_nnls.called) + + @patch('apache_beam.ml.inference.model_manager.nnls') + def test_solver_respects_unique_model_constraint(self, mock_nnls): + mock_nnls.return_value = ([100.0, 100.0, 50.0], 0.0) + + for _ in range(5): + self.estimator.add_observation({'model_A': 1, 'model_B': 1}, 800) + + for _ in range(5): + self.estimator.add_observation({'model_C': 1}, 400) + + self.assertFalse(mock_nnls.called) + + self.estimator.add_observation({'model_A': 1}, 300) + self.estimator.add_observation({'model_B': 1}, 300) + + self.assertTrue(mock_nnls.called) + + +if __name__ == "__main__": + unittest.main() From b123a95e921f4ad01834d043e53b100b88eba526 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 23 Jan 2026 22:49:02 +0000 Subject: [PATCH 02/22] Add pydoc and move gpu detection to start --- .../apache_beam/ml/inference/model_manager.py | 49 ++++++++++++++++--- .../ml/inference/model_manager_test.py | 4 +- 2 files changed, 45 insertions(+), 8 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index e56b8ee8d03f..8980466aba36 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -42,6 +42,17 @@ class GPUMonitor: + """Monitors GPU memory usage in a separate thread using 
nvidia-smi. + + This class continuously polls GPU memory statistics to track current usage + and peak usage over a sliding time window. It serves as the source of truth + for the ModelManager's resource decisions. + + Attributes: + fallback_memory_mb: Default total memory if hardware detection fails. + poll_interval: Seconds between memory checks. + peak_window_seconds: Duration to track peak memory usage. + """ def __init__( self, fallback_memory_mb: float = 16000.0, @@ -56,7 +67,6 @@ def __init__( self._running = False self._thread = None self._lock = threading.Lock() - self._gpu_available = self._detect_hardware() def _detect_hardware(self): try: @@ -82,6 +92,7 @@ def _detect_hardware(self): return False def start(self): + self._gpu_available = self._detect_hardware() if self._running or not self._gpu_available: return self._running = True @@ -145,6 +156,12 @@ def _poll_loop(self): class ResourceEstimator: + """Estimates individual model memory usage using statistical observation. + + Uses Non-Negative Least Squares (NNLS) to deduce the memory footprint of + individual models based on aggregate system memory readings and the + configuration of active models at that time. + """ def __init__(self, smoothing_factor: float = 0.2, min_data_points: int = 5): self.smoothing_factor = smoothing_factor self.min_data_points = min_data_points @@ -251,6 +268,11 @@ def _solve(self): class TrackedModelProxy: + """A transparent proxy for model objects that adds tracking metadata. + + Wraps the underlying model object to attach a unique ID and intercept + calls, allowing the manager to track individual instances across processes. + """ def __init__(self, obj): object.__setattr__(self, "_wrapped_obj", obj) object.__setattr__(self, "_beam_tracking_id", str(uuid.uuid4())) @@ -279,12 +301,24 @@ def __repr__(self): def __dir__(self): return dir(self._wrapped_obj) - def unsafe_hard_delete(self): - if hasattr(self._wrapped_obj, "unsafe_hard_delete"): - self._wrapped_obj.unsafe_hard_delete() + def trackedModelProxy_unsafe_hard_delete(self): + if hasattr(self._wrapped_obj, "singletonProxy_unsafe_hard_delete"): + try: + self._wrapped_obj.singletonProxy_unsafe_hard_delete() + except Exception: + pass class ModelManager: + """Manages model lifecycles, caching, and resource arbitration. + + This class acts as the central controller for acquiring model instances. + It handles: + 1. LRU Caching of idle models. + 2. Resource estimation and admission control (preventing OOM). + 3. Dynamic eviction of low-priority models when space is needed. + 4. 'Isolation Mode' for safely profiling unknown models. 
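
  A minimal usage sketch (illustrative only; `load_my_model`, `predict` and
  `batch` are placeholders for user-provided code):

    manager = ModelManager()
    model = manager.acquire_model("my_model", load_my_model)
    try:
      predictions = model.predict(batch)
    finally:
      manager.release_model("my_model", model)
    manager.shutdown()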
+ """ _lock = threading.Lock() def __init__( @@ -587,7 +621,8 @@ def _perform_eviction(self, key, tag, instance, score): del self._models[tag][i] break - instance.unsafe_hard_delete() + if hasattr(instance, "trackedModelProxy_unsafe_hard_delete"): + instance.trackedModelProxy_unsafe_hard_delete() del instance gc.collect() torch.cuda.empty_cache() @@ -633,8 +668,8 @@ def _delete_all_models(self): self._idle_lru.clear() for _, instances in self._models.items(): for instance in instances: - if hasattr(instance, "unsafe_hard_delete"): - instance.unsafe_hard_delete() + if hasattr(instance, "trackedModelProxy_unsafe_hard_delete"): + instance.trackedModelProxy_unsafe_hard_delete() del instance self._models.clear() self._active_counts.clear() diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 7412ea6a6c64..57d3816165ec 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -94,7 +94,7 @@ def __init__(self, name, size, monitor): self.deleted = False self.monitor.allocate(size) - def unsafe_hard_delete(self): + def singletonProxy_unsafe_hard_delete(self): if not self.deleted: self.monitor.free(self.size) self.deleted = True @@ -453,6 +453,7 @@ def test_init_hardware_detected(self): """Test that init correctly reads total memory when nvidia-smi exists.""" self.mock_subprocess.return_value = "24576" monitor = GPUMonitor() + monitor.start() self.assertTrue(monitor._gpu_available) self.assertEqual(monitor._total_memory, 24576.0) @@ -460,6 +461,7 @@ def test_init_hardware_missing(self): """Test fallback behavior when nvidia-smi is missing.""" self.mock_subprocess.side_effect = FileNotFoundError() monitor = GPUMonitor(fallback_memory_mb=12000.0) + monitor.start() self.assertFalse(monitor._gpu_available) self.assertEqual(monitor._total_memory, 12000.0) From 4f82caa4b63276878a1a9a6f529f6765f8fc99b8 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Tue, 27 Jan 2026 03:13:29 +0000 Subject: [PATCH 03/22] Add comments and helper function to make it easier to understand the code and cleanup some code logics --- .../apache_beam/ml/inference/model_manager.py | 218 +++++++++++------- 1 file changed, 133 insertions(+), 85 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 8980466aba36..cf6ffb74759f 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -186,10 +186,17 @@ def set_initial_estimate(self, model_tag: str, cost: float): def add_observation( self, active_snapshot: Dict[str, int], peak_memory: float): + if active_snapshot: + model_list = "\n".join( + f"\t- {model}: {count}" + for model, count in sorted(active_snapshot.items())) + else: + model_list = "\t- None" + logger.info( - "Adding Observation: Snapshot=%s, PeakMemory=%.1f MB", - active_snapshot, - peak_memory) + "Adding Observation:\n PeakMemory: %.1f MB\n Instances:\n%s", + peak_memory, + model_list) if not active_snapshot: return with self._lock: @@ -316,7 +323,8 @@ class ModelManager: It handles: 1. LRU Caching of idle models. 2. Resource estimation and admission control (preventing OOM). - 3. Dynamic eviction of low-priority models when space is needed. + 3. Dynamic eviction of low-priority models, determined by count of + pending requests, when space is needed. 4. 'Isolation Mode' for safely profiling unknown models. 
""" _lock = threading.Lock() @@ -343,38 +351,107 @@ def __init__( # Resource State self._models = defaultdict(list) + # Idle LRU used to track released models that + # can be freed or reused upon request. self._idle_lru = OrderedDict() self._active_counts = Counter() self._total_active_jobs = 0 self._pending_reservations = 0.0 + # Isolation state used to profile unknown models, + # ensuring they run alone to get accurate readings. + # isolation_baseline represents the GPU usage before + # loading the unknown model. self._isolation_mode = False - self._pending_isolation_count = 0 self._isolation_baseline = 0.0 + # Waiting Queue and Ticketing to make sure we have fair ordering + # and also priority for unknown models. self._wait_queue = [] self._ticket_counter = itertools.count() self._cv = threading.Condition() - self._load_lock = threading.Lock() self._monitor.start() def all_models(self, tag) -> list[Any]: return self._models[tag] + def enter_isolation_mode(self, tag: str, ticket_num: int) -> bool: + if self._total_active_jobs > 0: + logger.info( + "Waiting to enter isolation: tag=%s ticket num=%s", tag, ticket_num) + self._cv.wait() + # return False since we have waited and need to re-evaluate + # in caller to make sure our priority is still valid. + return False + + logger.info("Unknown model %s detected. Flushing GPU.", tag) + self._delete_all_models() + + self._isolation_mode = True + self._total_active_jobs += 1 + self._isolation_baseline, _, _ = self._monitor.get_stats() + self._monitor.reset_peak() + return True + + def should_spawn_model(self, tag: str, ticket_num: int) -> bool: + curr, _, total = self._monitor.get_stats() + est_cost = self._estimator.get_estimate(tag) + limit = total * (1 - self._slack_percentage) + + # Use current usage for capacity check (ignore old spikes) + if (curr + self._pending_reservations + est_cost) <= limit: + self._pending_reservations += est_cost + self._total_active_jobs += 1 + self._active_counts[tag] += 1 + return True + + # Evict to make space (passing tag to check demand/existence) + if self._evict_to_make_space(limit, est_cost, requesting_tag=tag): + return True + + # Manually log status for debugging if we are going to wait + idle_count = 0 + other_idle_count = 0 + for item in self._idle_lru.items(): + if item[1][0] == tag: + idle_count += 1 + else: + other_idle_count += 1 + total_model_count = 0 + for _, instances in self._models.items(): + total_model_count += len(instances) + curr, _, _ = self._monitor.get_stats() + logger.info( + "Waiting for resources to free up: " + "tag=%s ticket num%s model count=%s " + "idle count=%s resource usage=%.1f MB " + "total models count=%s other idle=%s", + tag, + ticket_num, + len(self._models[tag]), + idle_count, + curr, + total_model_count, + other_idle_count) + self._cv.wait(timeout=10.0) + return False + def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: current_priority = 0 if self._estimator.is_unknown(tag) else 1 ticket_num = next(self._ticket_counter) my_id = object() with self._cv: - # FAST PATH - if self._pending_isolation_count == 0 and not self._isolation_mode: + # FAST PATH: Grab from idle LRU if available + if not self._isolation_mode: cached_instance = self._try_grab_from_lru(tag) if cached_instance: return cached_instance - # SLOW PATH + # SLOW PATH: Enqueue and wait for turn to acquire model, + # with unknown models having priority and order enforced + # by ticket number as FIFO. 
logger.info( "Acquire Queued: tag=%s, priority=%d " "total models count=%s ticket num=%s", @@ -397,9 +474,11 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: self._cv.wait() continue + # Re-evaluate priority in case model became known during wait real_is_unknown = self._estimator.is_unknown(tag) real_priority = 0 if real_is_unknown else 1 + # If priority changed, reinsert into queue and wait if current_priority != real_priority: heapq.heappop(self._wait_queue) current_priority = real_priority @@ -408,6 +487,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: self._cv.notify_all() continue + # Try grab from LRU again in case model was released during wait cached_instance = self._try_grab_from_lru(tag) if cached_instance: return cached_instance @@ -416,27 +496,17 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: # Path A: Isolation if is_unknown: - if self._total_active_jobs > 0: - logger.info( - "Waiting to enter isolation: tag=%s ticket num=%s", - tag, - ticket_num) - self._cv.wait() + if self.enter_isolation_mode(tag, ticket_num): + should_spawn = True + break + else: + # We waited, need to re-evaluate our turn + # because priority may have changed during the wait continue - logger.info("Unknown model %s detected. Flushing GPU.", tag) - self._delete_all_models() - - self._isolation_mode = True - self._total_active_jobs += 1 - self._isolation_baseline, _, _ = self._monitor.get_stats() - self._monitor.reset_peak() - should_spawn = True - break - # Path B: Concurrent else: - if self._pending_isolation_count > 0 or self._isolation_mode: + if self._isolation_mode: logger.info( "Waiting due to isolation in progress: tag=%s ticket num%s", tag, @@ -444,48 +514,17 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: self._cv.wait() continue - curr, _, total = self._monitor.get_stats() - est_cost = self._estimator.get_estimate(tag) - limit = total * (1 - self._slack_percentage) - - # Use current usage for capacity check (ignore old spikes) - if (curr + self._pending_reservations + est_cost) <= limit: - self._pending_reservations += est_cost - self._total_active_jobs += 1 - self._active_counts[tag] += 1 + if self.should_spawn_model(tag, ticket_num): should_spawn = True + est_cost = self._estimator.get_estimate(tag) break - - # Evict to make space (passing tag to check demand/existence) - if self._evict_to_make_space(limit, est_cost, requesting_tag=tag): + else: + # We waited, need to re-evaluate our turn + # because priority may have changed during the wait continue - idle_count = 0 - other_idle_count = 0 - for item in self._idle_lru.items(): - if item[1][0] == tag: - idle_count += 1 - else: - other_idle_count += 1 - total_model_count = 0 - for _, instances in self._models.items(): - total_model_count += len(instances) - curr, _, _ = self._monitor.get_stats() - logger.info( - "Waiting for resources to free up: " - "tag=%s ticket num%s model count=%s " - "idle count=%s resource usage=%.1f MB " - "total models count=%s other idle=%s", - tag, - ticket_num, - len(self._models[tag]), - idle_count, - curr, - total_model_count, - other_idle_count) - self._cv.wait(timeout=10.0) - finally: + # Remove self from wait queue once done if self._wait_queue and self._wait_queue[0][2] is my_id: heapq.heappop(self._wait_queue) else: @@ -495,8 +534,8 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: heapq.heapify(self._wait_queue) self._cv.notify_all() - if should_spawn: - return 
self._spawn_new_model(tag, loader_func, is_unknown, est_cost) + if should_spawn: + return self._spawn_new_model(tag, loader_func, is_unknown, est_cost) def release_model(self, tag: str, instance: Any): with self._cv: @@ -507,14 +546,18 @@ def release_model(self, tag: str, instance: Any): self._idle_lru[id(instance)] = (tag, instance, time.time()) + # Update estimator with latest stats _, peak_during_job, _ = self._monitor.get_stats() if self._isolation_mode and self._active_counts[tag] == 0: + # For isolation mode, we directly set the initial estimate + # so that we can quickly learn the model cost. cost = max(0, peak_during_job - self._isolation_baseline) self._estimator.set_initial_estimate(tag, cost) self._isolation_mode = False self._isolation_baseline = 0.0 else: + # Regular update for known models snapshot = { t: len(instances) for t, instances in self._models.items() if len(instances) > 0 @@ -536,6 +579,7 @@ def _try_grab_from_lru(self, tag: str) -> Any: break if target_instance: + # Found an idle model, remove from LRU and return del self._idle_lru[target_key] self._active_counts[tag] += 1 self._total_active_jobs += 1 @@ -550,20 +594,21 @@ def _evict_to_make_space( Evicts models based on Demand Magnitude + Tiers. Crucially: If we have 0 active copies of 'requesting_tag', we FORCE eviction of the lowest-demand candidate to avoid starvation. + Returns True if space was made, False otherwise. """ - evicted_something = False curr, _, _ = self._monitor.get_stats() projected_usage = curr + self._pending_reservations + est_cost if projected_usage <= limit: - return False + # Memory usage changed and we are already under limit + return True now = time.time() + # Calculate the demand from the wait queue demand_map = Counter() for item in self._wait_queue: - if len(item) >= 4: - demand_map[item[3]] += 1 + demand_map[item[3]] += 1 my_demand = demand_map[requesting_tag] am_i_starving = len(self._models[requesting_tag]) == 0 @@ -572,10 +617,12 @@ def _evict_to_make_space( for key, (tag, instance, release_time) in self._idle_lru.items(): candidate_demand = demand_map[tag] - if not am_i_starving: - if candidate_demand >= my_demand: - continue + if not am_i_starving and candidate_demand >= my_demand: + continue + # Attempts to score candidates based on hotness and manually + # specified minimum copies. Demand is weighted heavily to + # ensure we evict low-demand models first. 
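      # score = demand * 10 + tier: e.g. a cold surplus copy nobody is waiting
      # for scores 0, while a recently released last copy with two waiters
      # scores 23 and is only evicted as a last resort.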
age = now - release_time is_cold = age >= self._eviction_cooldown @@ -593,6 +640,7 @@ def _evict_to_make_space( candidates.sort(key=lambda x: (x[0], x[1])) + # Evict candidates until we are under limit for score, _, key, tag, instance in candidates: if projected_usage <= limit: break @@ -600,14 +648,13 @@ def _evict_to_make_space( if key not in self._idle_lru: continue self._perform_eviction(key, tag, instance, score) - evicted_something = True curr, _, _ = self._monitor.get_stats() projected_usage = curr + self._pending_reservations + est_cost - return evicted_something + return projected_usage <= limit - def _perform_eviction(self, key, tag, instance, score): + def _perform_eviction(self, key: str, tag: str, instance: Any, score: int): logger.info("Evicting Model: %s (Score %d)", tag, score) curr, _, _ = self._monitor.get_stats() logger.info("Resource Usage Before Eviction: %.1f MB", curr) @@ -621,8 +668,7 @@ def _perform_eviction(self, key, tag, instance, score): del self._models[tag][i] break - if hasattr(instance, "trackedModelProxy_unsafe_hard_delete"): - instance.trackedModelProxy_unsafe_hard_delete() + instance.trackedModelProxy_unsafe_hard_delete() del instance gc.collect() torch.cuda.empty_cache() @@ -631,18 +677,22 @@ def _perform_eviction(self, key, tag, instance, score): curr, _, _ = self._monitor.get_stats() logger.info("Resource Usage After Eviction: %.1f MB", curr) - def _spawn_new_model(self, tag, loader_func, is_unknown, est_cost): + def _spawn_new_model( + self, + tag: str, + loader_func: Callable[[], Any], + is_unknown: bool, + est_cost: float) -> Any: try: - with self._load_lock: + with self._cv: logger.info("Loading Model: %s (Unknown: %s)", tag, is_unknown) - isolation_baseline_snap, _, _ = self._monitor.get_stats() + baseline_snap, _, _ = self._monitor.get_stats() instance = TrackedModelProxy(loader_func()) _, peak_during_load, _ = self._monitor.get_stats() - with self._cv: snapshot = {tag: 1} self._estimator.add_observation( - snapshot, peak_during_load - isolation_baseline_snap) + snapshot, peak_during_load - baseline_snap) if not is_unknown: self._pending_reservations = max( @@ -668,8 +718,7 @@ def _delete_all_models(self): self._idle_lru.clear() for _, instances in self._models.items(): for instance in instances: - if hasattr(instance, "trackedModelProxy_unsafe_hard_delete"): - instance.trackedModelProxy_unsafe_hard_delete() + instance.trackedModelProxy_unsafe_hard_delete() del instance self._models.clear() self._active_counts.clear() @@ -688,7 +737,6 @@ def _force_reset(self): self._total_active_jobs = 0 self._pending_reservations = 0.0 self._isolation_mode = False - self._pending_isolation_count = 0 self._isolation_baseline = 0.0 def shutdown(self): From 8a856b0f57e4907b238d1081e0ead09c7ef464ca Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Tue, 27 Jan 2026 03:22:17 +0000 Subject: [PATCH 04/22] Add TODO for threading --- sdks/python/apache_beam/ml/inference/model_manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index cf6ffb74759f..08a6f4fa04b0 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -369,6 +369,8 @@ def __init__( # and also priority for unknown models. self._wait_queue = [] self._ticket_counter = itertools.count() + # TODO: Consider making the wait to be smarter, i.e. + # splitting read/write etc. to avoid potential contention. 
self._cv = threading.Condition() self._monitor.start() From 3a200ba5226992a3a9a7ab624e824b34bd1831da Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Thu, 29 Jan 2026 21:13:14 +0000 Subject: [PATCH 05/22] Remove tracked model proxy and have model manager store tags instead of model instance --- .../apache_beam/ml/inference/model_manager.py | 67 +++++-------------- .../ml/inference/model_manager_test.py | 45 ++++++++++++- 2 files changed, 61 insertions(+), 51 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 08a6f4fa04b0..744f1aa9f45f 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -24,7 +24,6 @@ usage and performance. """ -import uuid import time import threading import subprocess @@ -37,6 +36,7 @@ import itertools from collections import defaultdict, deque, Counter, OrderedDict from typing import Dict, Any, Tuple, Optional, Callable +from apache_beam.utils import multi_process_shared logger = logging.getLogger(__name__) @@ -274,48 +274,6 @@ def _solve(self): logger.error("Solver failed: %s", e) -class TrackedModelProxy: - """A transparent proxy for model objects that adds tracking metadata. - - Wraps the underlying model object to attach a unique ID and intercept - calls, allowing the manager to track individual instances across processes. - """ - def __init__(self, obj): - object.__setattr__(self, "_wrapped_obj", obj) - object.__setattr__(self, "_beam_tracking_id", str(uuid.uuid4())) - - def __getattr__(self, name): - return getattr(self._wrapped_obj, name) - - def __setattr__(self, name, value): - setattr(self._wrapped_obj, name, value) - - def __call__(self, *args, **kwargs): - return self._wrapped_obj(*args, **kwargs) - - def __setstate__(self, state): - self.__dict__.update(state) - - def __getstate__(self): - return self.__dict__ - - def __str__(self): - return str(self._wrapped_obj) - - def __repr__(self): - return repr(self._wrapped_obj) - - def __dir__(self): - return dir(self._wrapped_obj) - - def trackedModelProxy_unsafe_hard_delete(self): - if hasattr(self._wrapped_obj, "singletonProxy_unsafe_hard_delete"): - try: - self._wrapped_obj.singletonProxy_unsafe_hard_delete() - except Exception: - pass - - class ModelManager: """Manages model lifecycles, caching, and resource arbitration. 
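[Editor's note] With TrackedModelProxy removed, cleanup is driven by the MultiProcessShared tag rather than a proxy wrapper. A minimal sketch of that pattern under assumed names (MyModel, the default tag string); the delete call mirrors the string branch of ModelManager._delete_instance shown below.

from apache_beam.utils import multi_process_shared

class MyModel:  # placeholder for a real model wrapper
    def predict(self, x):
        return x

def loader(tag: str = "model_manager_my_model") -> str:
    # Publish the model under the shared tag; the manager only keeps the tag.
    multi_process_shared.MultiProcessShared(
        MyModel, tag=tag, always_proxy=True).acquire()
    return tag

def delete_by_tag(tag: str) -> None:
    # Hard-delete the shared instance using only the tag string.
    multi_process_shared.MultiProcessShared(
        lambda: "N/A", tag=tag).unsafe_hard_delete()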
@@ -619,6 +577,7 @@ def _evict_to_make_space( for key, (tag, instance, release_time) in self._idle_lru.items(): candidate_demand = demand_map[tag] + # TODO: Try to avoid churn if demand is similar if not am_i_starving and candidate_demand >= my_demand: continue @@ -656,6 +615,17 @@ def _evict_to_make_space( return projected_usage <= limit + def _delete_instance(self, instance: Any): + if isinstance(instance, str): + # If the instance is a string, it's a uuid used + # to retrieve the model from MultiProcessShared + multi_process_shared.MultiProcessShared( + lambda: "N/A", tag=instance).unsafe_hard_delete() + if hasattr(instance, 'mock_model_unsafe_hard_delete'): + # Call the mock unsafe hard delete method for testing + instance.mock_model_unsafe_hard_delete() + del instance + def _perform_eviction(self, key: str, tag: str, instance: Any, score: int): logger.info("Evicting Model: %s (Score %d)", tag, score) curr, _, _ = self._monitor.get_stats() @@ -664,14 +634,12 @@ def _perform_eviction(self, key: str, tag: str, instance: Any, score: int): if key in self._idle_lru: del self._idle_lru[key] - target_id = instance._beam_tracking_id for i, inst in enumerate(self._models[tag]): - if inst._beam_tracking_id == target_id: + if instance == inst: del self._models[tag][i] break - instance.trackedModelProxy_unsafe_hard_delete() - del instance + self._delete_instance(instance) gc.collect() torch.cuda.empty_cache() self._monitor.refresh() @@ -689,7 +657,7 @@ def _spawn_new_model( with self._cv: logger.info("Loading Model: %s (Unknown: %s)", tag, is_unknown) baseline_snap, _, _ = self._monitor.get_stats() - instance = TrackedModelProxy(loader_func()) + instance = loader_func() _, peak_during_load, _ = self._monitor.get_stats() snapshot = {tag: 1} @@ -720,8 +688,7 @@ def _delete_all_models(self): self._idle_lru.clear() for _, instances in self._models.items(): for instance in instances: - instance.trackedModelProxy_unsafe_hard_delete() - del instance + self._delete_instance(instance) self._models.clear() self._active_counts.clear() gc.collect() diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 57d3816165ec..ce9cff0064e9 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -21,6 +21,7 @@ import random from concurrent.futures import ThreadPoolExecutor from unittest.mock import patch +from apache_beam.utils import multi_process_shared try: from apache_beam.ml.inference.model_manager import ModelManager, GPUMonitor, ResourceEstimator @@ -94,12 +95,26 @@ def __init__(self, name, size, monitor): self.deleted = False self.monitor.allocate(size) - def singletonProxy_unsafe_hard_delete(self): + def mock_model_unsafe_hard_delete(self): if not self.deleted: self.monitor.free(self.size) self.deleted = True +class Counter(object): + def __init__(self, start=0): + self.running = start + self.lock = threading.Lock() + + def get(self): + return self.running + + def increment(self, value=1): + with self.lock: + self.running += value + return self.running + + class TestModelManager(unittest.TestCase): def setUp(self): """Force reset the Singleton ModelManager before every test.""" @@ -110,6 +125,34 @@ def setUp(self): def tearDown(self): self.manager.shutdown() + def test_model_manager_deletes_multiprocessshared_instances(self): + """Test that MultiProcessShared instances are deleted properly.""" + model_name = "test_model_shared" + tag = 
f"model_manager_test_{model_name}" + + def loader(): + multi_process_shared.MultiProcessShared( + lambda: Counter, tag=tag, always_proxy=True) + return tag + + instance = self.manager.acquire_model(model_name, loader) + instance_before = multi_process_shared.MultiProcessShared( + Counter, tag=tag, always_proxy=True).acquire() + instance_before.increment() + self.assertEqual(instance_before.get(), 1) + self.manager.release_model(model_name, instance) + + # Force delete all models + self.manager._force_reset() + + # Verify that the MultiProcessShared instance is deleted + # and the counter is reseted + with self.assertRaises(Exception): + instance_before.get() + instance_after = multi_process_shared.MultiProcessShared( + Counter, tag=tag, always_proxy=True).acquire() + self.assertEqual(instance_after.get(), 0) + def test_model_manager_capacity_check(self): """ Test that the manager blocks when spawning models exceeds the limit, From 31e87a9cc4f62bb7f68b6b7dece929121165d619 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 30 Jan 2026 17:52:29 +0000 Subject: [PATCH 06/22] Fix import order --- .../apache_beam/ml/inference/model_manager.py | 27 ++++++++++++------- .../ml/inference/model_manager_test.py | 11 +++++--- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 744f1aa9f45f..71ebd34779c7 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -24,18 +24,27 @@ usage and performance. """ -import time -import threading -import subprocess -import logging import gc -import numpy as np -from scipy.optimize import nnls -import torch import heapq import itertools -from collections import defaultdict, deque, Counter, OrderedDict -from typing import Dict, Any, Tuple, Optional, Callable +import logging +import subprocess +import threading +import time +from collections import Counter +from collections import OrderedDict +from collections import defaultdict +from collections import deque +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional +from typing import Tuple + +import numpy as np +import torch +from scipy.optimize import nnls + from apache_beam.utils import multi_process_shared logger = logging.getLogger(__name__) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index ce9cff0064e9..86442b0c11a5 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -15,16 +15,19 @@ # limitations under the License. 
# -import unittest -import time -import threading import random +import threading +import time +import unittest from concurrent.futures import ThreadPoolExecutor from unittest.mock import patch + from apache_beam.utils import multi_process_shared try: - from apache_beam.ml.inference.model_manager import ModelManager, GPUMonitor, ResourceEstimator + from apache_beam.ml.inference.model_manager import GPUMonitor + from apache_beam.ml.inference.model_manager import ModelManager + from apache_beam.ml.inference.model_manager import ResourceEstimator except ImportError as e: raise unittest.SkipTest("Model Manager dependencies are not installed") From 5c89ca796c1b625a5871e47d680a80bf666781b3 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 30 Jan 2026 19:01:53 +0000 Subject: [PATCH 07/22] Clean up and logs --- .../apache_beam/ml/inference/model_manager.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 71ebd34779c7..6c46376347bd 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -345,7 +345,8 @@ def __init__( def all_models(self, tag) -> list[Any]: return self._models[tag] - def enter_isolation_mode(self, tag: str, ticket_num: int) -> bool: + # Should hold _cv lock when calling + def try_enter_isolation_mode(self, tag: str, ticket_num: int) -> bool: if self._total_active_jobs > 0: logger.info( "Waiting to enter isolation: tag=%s ticket num=%s", tag, ticket_num) @@ -363,6 +364,7 @@ def enter_isolation_mode(self, tag: str, ticket_num: int) -> bool: self._monitor.reset_peak() return True + # Should hold _cv lock when calling def should_spawn_model(self, tag: str, ticket_num: int) -> bool: curr, _, total = self._monitor.get_stats() est_cost = self._estimator.get_estimate(tag) @@ -403,6 +405,8 @@ def should_spawn_model(self, tag: str, ticket_num: int) -> bool: curr, total_model_count, other_idle_count) + # Wait since we couldn't make space and + # added timeout to avoid missed notify call. 
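[Editor's note] A hedged, standalone sketch of the admission check should_spawn_model is guarding here: a spawn is allowed only when current usage plus pending reservations plus the model's estimated cost stays under the usable budget. Parameter names and the example numbers are assumptions; slack mirrors the manager's slack_percentage.

def can_admit(curr_mb: float, pending_mb: float, est_cost_mb: float,
              total_mb: float, slack: float = 0.1) -> bool:
    # Hold back a slack fraction of total memory, then project usage.
    limit = total_mb * (1.0 - slack)
    projected = curr_mb + pending_mb + est_cost_mb
    return projected <= limit

# e.g. can_admit(9000, 1500, 4000, 16000) -> False, so the caller waits on
# the condition variable (with a timeout to survive a missed notify).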
self._cv.wait(timeout=10.0) return False @@ -431,7 +435,6 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: heapq.heappush( self._wait_queue, (current_priority, ticket_num, my_id, tag)) - should_spawn = False est_cost = 0.0 is_unknown = False @@ -465,8 +468,8 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: # Path A: Isolation if is_unknown: - if self.enter_isolation_mode(tag, ticket_num): - should_spawn = True + if self.try_enter_isolation_mode(tag, ticket_num): + # We got isolation, can proceed to spawn break else: # We waited, need to re-evaluate our turn @@ -484,8 +487,8 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: continue if self.should_spawn_model(tag, ticket_num): - should_spawn = True est_cost = self._estimator.get_estimate(tag) + # We can proceed to spawn since we have resources break else: # We waited, need to re-evaluate our turn @@ -497,14 +500,18 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: if self._wait_queue and self._wait_queue[0][2] is my_id: heapq.heappop(self._wait_queue) else: + logger.warning( + "Item not at head of wait queue during cleanup" + ", this is not expected: tag=%s ticket num=%s", + tag, + ticket_num) for i, item in enumerate(self._wait_queue): if item[2] is my_id: self._wait_queue.pop(i) heapq.heapify(self._wait_queue) self._cv.notify_all() - if should_spawn: - return self._spawn_new_model(tag, loader_func, is_unknown, est_cost) + return self._spawn_new_model(tag, loader_func, is_unknown, est_cost) def release_model(self, tag: str, instance: Any): with self._cv: @@ -575,6 +582,7 @@ def _evict_to_make_space( now = time.time() # Calculate the demand from the wait queue + # TODO: Also factor in the active counts to avoid thrashing demand_map = Counter() for item in self._wait_queue: demand_map[item[3]] += 1 From a5e8ff3d6bebb0cfc11c3e95fa2859227de0c1ac Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 30 Jan 2026 19:04:41 +0000 Subject: [PATCH 08/22] Added timeout for waiting too long on model acquire --- .../apache_beam/ml/inference/model_manager.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 6c46376347bd..e46f54b42ee4 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -305,7 +305,8 @@ def __init__( min_data_points: int = 5, smoothing_factor: float = 0.2, eviction_cooldown_seconds: float = 10.0, - min_model_copies: int = 1): + min_model_copies: int = 1, + wait_timeout_seconds: float = 300.0): self._estimator = ResourceEstimator( min_data_points=min_data_points, smoothing_factor=smoothing_factor) @@ -315,6 +316,7 @@ def __init__( self._eviction_cooldown = eviction_cooldown_seconds self._min_model_copies = min_model_copies + self._wait_timeout_seconds = wait_timeout_seconds # Resource State self._models = defaultdict(list) @@ -437,9 +439,18 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: est_cost = 0.0 is_unknown = False + wait_time_start = time.time() try: while True: + wait_time_elapsed = time.time() - wait_time_start + if wait_time_elapsed > self._wait_timeout_seconds: + logger.warning( + "Long wait detected for model acquisition: " + "tag=%s ticket num=%s elapsed=%.1f seconds", + tag, + ticket_num, + wait_time_elapsed) if not self._wait_queue or self._wait_queue[0][2] is not my_id: 
logger.info( "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num) From c340399c5d420fad2bc5705eb81d52b23f15418e Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 30 Jan 2026 19:05:31 +0000 Subject: [PATCH 09/22] Throw error if timeout --- sdks/python/apache_beam/ml/inference/model_manager.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index e46f54b42ee4..e36b20967a26 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -445,12 +445,9 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: while True: wait_time_elapsed = time.time() - wait_time_start if wait_time_elapsed > self._wait_timeout_seconds: - logger.warning( - "Long wait detected for model acquisition: " - "tag=%s ticket num=%s elapsed=%.1f seconds", - tag, - ticket_num, - wait_time_elapsed) + raise RuntimeError( + f"Timeout waiting to acquire model: {tag} " + f"after {wait_time_elapsed:.1f} seconds.") if not self._wait_queue or self._wait_queue[0][2] is not my_id: logger.info( "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num) From 82d357795f837e0f833491283ca0460a48cc26ef Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 30 Jan 2026 19:14:29 +0000 Subject: [PATCH 10/22] Add test for timeout and adjust --- .../apache_beam/ml/inference/model_manager.py | 12 ++++++---- .../ml/inference/model_manager_test.py | 24 +++++++++++++++++++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index e36b20967a26..b2bd7c50c342 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -306,7 +306,8 @@ def __init__( smoothing_factor: float = 0.2, eviction_cooldown_seconds: float = 10.0, min_model_copies: int = 1, - wait_timeout_seconds: float = 300.0): + wait_timeout_seconds: float = 300.0, + lock_timeout_seconds: float = 60.0): self._estimator = ResourceEstimator( min_data_points=min_data_points, smoothing_factor=smoothing_factor) @@ -317,6 +318,7 @@ def __init__( self._eviction_cooldown = eviction_cooldown_seconds self._min_model_copies = min_model_copies self._wait_timeout_seconds = wait_timeout_seconds + self._lock_timeout_seconds = lock_timeout_seconds # Resource State self._models = defaultdict(list) @@ -352,7 +354,7 @@ def try_enter_isolation_mode(self, tag: str, ticket_num: int) -> bool: if self._total_active_jobs > 0: logger.info( "Waiting to enter isolation: tag=%s ticket num=%s", tag, ticket_num) - self._cv.wait() + self._cv.wait(timeout=self._lock_timeout_seconds) # return False since we have waited and need to re-evaluate # in caller to make sure our priority is still valid. return False @@ -409,7 +411,7 @@ def should_spawn_model(self, tag: str, ticket_num: int) -> bool: other_idle_count) # Wait since we couldn't make space and # added timeout to avoid missed notify call. 
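[Editor's note] Taken together, the two timeouts added in these patches behave roughly like the standalone sketch below: the overall wait deadline raises instead of blocking forever, while the per-wait lock timeout only bounds a single cv.wait() so a missed notify cannot hang the waiter. Function and parameter names are illustrative, not the manager's API.

import threading
import time

def wait_until(cv: threading.Condition, predicate, wait_timeout: float,
               lock_timeout: float, tag: str) -> None:
    start = time.time()
    with cv:
        while not predicate():
            elapsed = time.time() - start
            if elapsed > wait_timeout:
                # Give up on the whole acquisition, like acquire_model does.
                raise RuntimeError(
                    f"Timeout waiting to acquire model: {tag} "
                    f"after {elapsed:.1f} seconds.")
            # Bounded wait; loop re-checks the predicate and the deadline.
            cv.wait(timeout=lock_timeout)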
- self._cv.wait(timeout=10.0) + self._cv.wait(timeout=self._lock_timeout_seconds) return False def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: @@ -451,7 +453,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: if not self._wait_queue or self._wait_queue[0][2] is not my_id: logger.info( "Waiting for its turn: tag=%s ticket num=%s", tag, ticket_num) - self._cv.wait() + self._cv.wait(timeout=self._lock_timeout_seconds) continue # Re-evaluate priority in case model became known during wait @@ -491,7 +493,7 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: "Waiting due to isolation in progress: tag=%s ticket num%s", tag, ticket_num) - self._cv.wait() + self._cv.wait(timeout=self._lock_timeout_seconds) continue if self.should_spawn_model(tag, ticket_num): diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 86442b0c11a5..8b9f987d4a16 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -156,6 +156,30 @@ def loader(): Counter, tag=tag, always_proxy=True).acquire() self.assertEqual(instance_after.get(), 0) + def test_model_manager_timeout_on_acquire(self): + """Test that acquiring a model times out properly.""" + model_name = "timeout_model" + self.manager = ModelManager( + monitor=self.mock_monitor, + wait_timeout_seconds=1.0, + lock_timeout_seconds=1.0) + + def loader(): + self.mock_monitor.allocate(self.mock_monitor._total) + return model_name + + # Acquire the model in one thread to block others + _ = self.manager.acquire_model(model_name, loader) + + def acquire_model_with_timeout(): + return self.manager.acquire_model(model_name, loader) + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(acquire_model_with_timeout) + with self.assertRaises(RuntimeError) as context: + future.result(timeout=5.0) + self.assertIn("Timeout waiting to acquire model", str(context.exception)) + def test_model_manager_capacity_check(self): """ Test that the manager blocks when spawning models exceeds the limit, From df2f92125dfaa05aed57f63e7da06248d322ac7e Mon Sep 17 00:00:00 2001 From: "RuiLong J." Date: Fri, 30 Jan 2026 11:15:17 -0800 Subject: [PATCH 11/22] Update sdks/python/apache_beam/ml/inference/model_manager.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- sdks/python/apache_beam/ml/inference/model_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index b2bd7c50c342..cd6dded73dd7 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -141,8 +141,8 @@ def refresh(self): def _get_nvidia_smi_used(self) -> float: try: - cmd = "nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits" - output = subprocess.check_output(cmd, shell=True).decode("utf-8").strip() + cmd = ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"] + output = subprocess.check_output(cmd, text=True).strip() free_memory = float(output) return self._total_memory - free_memory except Exception: From c7920e4373b09dbe1f1099fe6fa360806631e903 Mon Sep 17 00:00:00 2001 From: "RuiLong J." 
Date: Fri, 30 Jan 2026 11:15:27 -0800 Subject: [PATCH 12/22] Update sdks/python/apache_beam/ml/inference/model_manager.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- sdks/python/apache_beam/ml/inference/model_manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index cd6dded73dd7..4b504406f992 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -145,7 +145,8 @@ def _get_nvidia_smi_used(self) -> float: output = subprocess.check_output(cmd, text=True).strip() free_memory = float(output) return self._total_memory - free_memory - except Exception: + except Exception as e: + logger.warning('Failed to get GPU memory usage: %s', e) return 0.0 def _poll_loop(self): From b2ebc0aad1ab332f8a5f2a4570767991c78329a1 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 30 Jan 2026 19:19:48 +0000 Subject: [PATCH 13/22] Gemini clean up --- sdks/python/apache_beam/ml/inference/model_manager.py | 4 ---- sdks/python/apache_beam/ml/inference/model_manager_test.py | 2 -- 2 files changed, 6 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 4b504406f992..d79762ce7d60 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -295,8 +295,6 @@ class ModelManager: pending requests, when space is needed. 4. 'Isolation Mode' for safely profiling unknown models. """ - _lock = threading.Lock() - def __init__( self, monitor: Optional['GPUMonitor'] = None, @@ -738,8 +736,6 @@ def _force_reset(self): def shutdown(self): self._delete_all_models() - gc.collect() - torch.cuda.empty_cache() self._monitor.stop() def __del__(self): diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 8b9f987d4a16..657a9e6ec547 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -121,7 +121,6 @@ def increment(self, value=1): class TestModelManager(unittest.TestCase): def setUp(self): """Force reset the Singleton ModelManager before every test.""" - ModelManager._instance = None self.mock_monitor = MockGPUMonitor() self.manager = ModelManager(monitor=self.mock_monitor) @@ -366,7 +365,6 @@ def run_inference(): class TestModelManagerEviction(unittest.TestCase): def setUp(self): self.mock_monitor = MockGPUMonitor(total_memory=12000.0) - ModelManager._instance = None self.manager = ModelManager( monitor=self.mock_monitor, slack_percentage=0.0, From ea21b74f0c9994ec0337a17812667019ff989334 Mon Sep 17 00:00:00 2001 From: "RuiLong J." 
Date: Fri, 30 Jan 2026 11:20:13 -0800 Subject: [PATCH 14/22] Update sdks/python/apache_beam/ml/inference/model_manager_test.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- sdks/python/apache_beam/ml/inference/model_manager_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 657a9e6ec547..0c83b99263b2 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -302,7 +302,7 @@ def crashing_loader(): self.assertEqual(self.manager._pending_reservations, 0.0) self.assertFalse(self.manager._cv._is_owned()) - def test_model_managaer_force_reset_on_exception(self): + def test_model_manager_force_reset_on_exception(self): """Test that force_reset clears all models from the manager.""" model_name = "test_model" From ce433e37d3312dd4d848a2225bea4e5cd983ec7d Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 30 Jan 2026 19:26:51 +0000 Subject: [PATCH 15/22] Update GPU monitor test --- sdks/python/apache_beam/ml/inference/model_manager_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 0c83b99263b2..0712035c645a 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -540,8 +540,9 @@ def subprocess_side_effect(*args, **kwargs): if isinstance(args[0], list) and "memory.total" in args[0][1]: return "16000" - if isinstance(args[0], str) and "memory.free" in args[0]: - return b"12000" + if isinstance(args[0], list) and any("memory.free" in part + for part in args[0]): + return "12000" raise Exception("Unexpected command") From 6846eebccf8c9a9d5abb7ac15b0f6378f660a723 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Fri, 30 Jan 2026 19:54:55 +0000 Subject: [PATCH 16/22] Format --- sdks/python/apache_beam/ml/inference/model_manager.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index d79762ce7d60..0ddf53428700 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -141,7 +141,11 @@ def refresh(self): def _get_nvidia_smi_used(self) -> float: try: - cmd = ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"] + cmd = [ + "nvidia-smi", + "--query-gpu=memory.free", + "--format=csv,noheader,nounits" + ] output = subprocess.check_output(cmd, text=True).strip() free_memory = float(output) return self._total_memory - free_memory From e7c7165768415708a261d16ceb35274dff3efbcc Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Tue, 3 Feb 2026 23:29:29 +0000 Subject: [PATCH 17/22] Cleanup upating is_unkown logic --- sdks/python/apache_beam/ml/inference/model_manager.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 0ddf53428700..395165e3dad8 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -460,8 +460,8 @@ def acquire_model(self, tag: str, loader_func: Callable[[], 
Any]) -> Any: continue # Re-evaluate priority in case model became known during wait - real_is_unknown = self._estimator.is_unknown(tag) - real_priority = 0 if real_is_unknown else 1 + is_unknown = self._estimator.is_unknown(tag) + real_priority = 0 if is_unknown else 1 # If priority changed, reinsert into queue and wait if current_priority != real_priority: @@ -477,8 +477,6 @@ def acquire_model(self, tag: str, loader_func: Callable[[], Any]) -> Any: if cached_instance: return cached_instance - is_unknown = real_is_unknown - # Path A: Isolation if is_unknown: if self.try_enter_isolation_mode(tag, ticket_num): From 6ebb19cbd8b843b52a96dd6444f1e418c01add58 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Wed, 4 Feb 2026 01:32:51 +0000 Subject: [PATCH 18/22] Try to fix flake --- sdks/python/apache_beam/ml/inference/model_manager_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 0712035c645a..d40acc287a02 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -19,6 +19,7 @@ import threading import time import unittest +from concurrent.futures import TimeoutError from concurrent.futures import ThreadPoolExecutor from unittest.mock import patch @@ -353,7 +354,7 @@ def run_inference(): return with ThreadPoolExecutor(max_workers=8) as executor: - futures = [executor.submit(run_inference) for _ in range(100)] + futures = [executor.submit(run_inference) for _ in range(200)] for f in futures: f.result() From 642d5d780f06fb1e0d0a9116dbfe6c517d7f003c Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Wed, 4 Feb 2026 02:04:11 +0000 Subject: [PATCH 19/22] Fix import order --- sdks/python/apache_beam/ml/inference/model_manager_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index d40acc287a02..4481fea551ea 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -19,8 +19,8 @@ import threading import time import unittest -from concurrent.futures import TimeoutError from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import TimeoutError from unittest.mock import patch from apache_beam.utils import multi_process_shared From ed3052f2302593b3c4435835d6026cbe9460d5a7 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Wed, 4 Feb 2026 02:52:20 +0000 Subject: [PATCH 20/22] Fix random seed to avoid flake --- sdks/python/apache_beam/ml/inference/model_manager_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager_test.py b/sdks/python/apache_beam/ml/inference/model_manager_test.py index 4481fea551ea..1bd8edd34d18 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager_test.py +++ b/sdks/python/apache_beam/ml/inference/model_manager_test.py @@ -333,7 +333,9 @@ def test_single_model_convergence_with_fluctuations(self): """ model_name = "fluctuating_model" model_cost = 3000.0 - load_cost = 2000.0 + load_cost = 2500.0 + # Fix random seed for reproducibility + random.seed(42) def loader(): self.mock_monitor.allocate(load_cost) @@ -354,7 +356,7 @@ def run_inference(): return with ThreadPoolExecutor(max_workers=8) as executor: - futures = [executor.submit(run_inference) for _ 
in range(200)] + futures = [executor.submit(run_inference) for _ in range(100)] for f in futures: f.result() From f77ff23b7428f4d779b5f74ead827af52c9f9d3a Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Wed, 4 Feb 2026 17:23:49 +0000 Subject: [PATCH 21/22] Fix identation --- sdks/python/apache_beam/ml/inference/model_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 395165e3dad8..2de279f4290e 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -296,7 +296,7 @@ class ModelManager: 1. LRU Caching of idle models. 2. Resource estimation and admission control (preventing OOM). 3. Dynamic eviction of low-priority models, determined by count of - pending requests, when space is needed. + pending requests, when space is needed. 4. 'Isolation Mode' for safely profiling unknown models. """ def __init__( From 4b0d8a836379293db4e0e1859ee1ac4ad2782446 Mon Sep 17 00:00:00 2001 From: AMOOOMA Date: Wed, 4 Feb 2026 18:10:20 +0000 Subject: [PATCH 22/22] Try fixing doc again --- sdks/python/apache_beam/ml/inference/model_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/model_manager.py b/sdks/python/apache_beam/ml/inference/model_manager.py index 2de279f4290e..cc9f833c2682 100644 --- a/sdks/python/apache_beam/ml/inference/model_manager.py +++ b/sdks/python/apache_beam/ml/inference/model_manager.py @@ -292,7 +292,7 @@ class ModelManager: """Manages model lifecycles, caching, and resource arbitration. This class acts as the central controller for acquiring model instances. - It handles: + 1. LRU Caching of idle models. 2. Resource estimation and admission control (preventing OOM). 3. Dynamic eviction of low-priority models, determined by count of