import threading
from concurrent.futures import ThreadPoolExecutor

# NOTE(review): reconstructed from a whitespace-mangled patch. Only the code the
# patch ADDS to flixopt/comparison.py is shown here; surrounding context that the
# patch leaves untouched (module imports such as `xarray`, the TYPE_CHECKING
# block, `_CASE_SLOTS`, `__init__`, `__repr__`, `inputs`, `ComparisonStatistics`,
# ...) is unchanged. The file declares `from __future__ import annotations`, so
# the `pathlib.Path` annotations below resolve lazily via the TYPE_CHECKING
# `import pathlib`.

# The netCDF4 C library is not thread-safe: concurrent `xr.load_dataset` calls
# with engine='netcdf4' can segfault because the HDF5 error stack and library
# state are global. We serialize only the file-read step (the CPU-heavy
# deserialization that follows runs in parallel).
_NETCDF_READ_LOCK = threading.Lock()


class Comparison:
    # NOTE(review): only the two methods added by the patch are shown; the rest
    # of the class body is unchanged.

    @classmethod
    def from_netcdf(
        cls,
        paths: list[str | pathlib.Path] | dict[str | pathlib.Path, str],
        max_workers: int | None = None,
    ) -> Comparison:
        """Load multiple FlowSystems from NetCDF files and combine them into a Comparison.

        The file read itself is serialized (the netCDF4 C library is not
        thread-safe — concurrent reads can segfault), but the CPU-heavy
        deserialization — JSON attrs and rebuilding the FlowSystem from the
        dataset — runs in parallel across a thread pool. This typically still
        gives a solid speedup because deserialization dominates the total load
        time for non-trivial systems.

        Args:
            paths: Either a list of file paths (case names are derived from the
                filename stems), or a dict mapping file paths to explicit case
                names.
            max_workers: Maximum number of threads used to deserialize loaded
                datasets. ``None`` uses the default of
                :class:`concurrent.futures.ThreadPoolExecutor`. Set to ``1`` to
                run sequentially.

        Returns:
            A new :class:`Comparison` containing the loaded FlowSystems.

        Examples:
            ```python
            # From a list (names come from filenames)
            comp = fx.Comparison.from_netcdf(['results/base.nc', 'results/modified.nc'])

            # With explicit names
            comp = fx.Comparison.from_netcdf({'results/base.nc': 'baseline', 'results/modified.nc': 'variant'})
            ```
        """
        import pathlib as _pl

        from .flow_system import FlowSystem
        from .io import load_dataset_from_netcdf

        if isinstance(paths, dict):
            path_list = list(paths.keys())
            names: list[str] | None = list(paths.values())
        else:
            path_list = list(paths)
            names = None

        # Resolve each case name up front so the loaded FlowSystem's own name
        # matches its Comparison case name. (Previously an explicit dict name
        # was ignored and the file stem was always assigned to `fs.name`, so
        # `comp['baseline'].name` could disagree with the case name.)
        case_names = names if names is not None else [_pl.Path(p).stem for p in path_list]

        def _load_one(item: tuple[str | _pl.Path, str]) -> FlowSystem:
            path, case_name = item
            # Only the C-level file read is serialized; deserialization below
            # runs concurrently across worker threads.
            with _NETCDF_READ_LOCK:
                ds = load_dataset_from_netcdf(path)
            fs = FlowSystem.from_dataset(ds)
            fs.name = case_name
            return fs

        # executor.map preserves input order, so systems line up with `names`.
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            flow_systems = list(executor.map(_load_one, zip(path_list, case_names)))

        return cls(flow_systems, names=names)

    def expand(self, max_workers: int | None = None) -> Comparison:
        """Expand clustered FlowSystems back to full timesteps in parallel.

        Calls :meth:`FlowSystem.transform.expand` on every contained FlowSystem
        that has clustering applied. FlowSystems without clustering are passed
        through unchanged, so mixed comparisons are safe.

        Expansion is CPU-bound but vectorized through xarray/numpy, which
        release the GIL for most operations — a thread pool is typically
        enough to get a speedup.

        Args:
            max_workers: Maximum number of threads used to expand systems.
                ``None`` uses the default of
                :class:`concurrent.futures.ThreadPoolExecutor`. Set to ``1`` to
                expand sequentially.

        Returns:
            A new :class:`Comparison` with expanded FlowSystems, preserving
            the original case names.

        Examples:
            ```python
            comp_reduced = fx.Comparison([fs_clustered_a, fs_clustered_b])
            comp_full = comp_reduced.expand()
            comp_full.stats.plot.balance('Heat')  # Full-resolution plots
            ```
        """

        def _expand_one(fs: FlowSystem) -> FlowSystem:
            # getattr guards FlowSystems that lack the `clustering` attribute
            # entirely — NOTE(review): assumed possible for mixed/legacy
            # systems; confirm whether FlowSystem always defines it.
            return fs.transform.expand() if getattr(fs, 'clustering', None) is not None else fs

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            expanded = list(executor.map(_expand_one, self._systems))

        return type(self)(expanded, names=list(self._names))
# NOTE(review): reconstructed from a whitespace-mangled patch. Only the tests
# the patch ADDS to tests/test_comparison.py are shown; the pre-existing tests
# (e.g. test_diff_invalid_reference_raises) and the module's imports are
# unchanged.

# ============================================================================
# PARALLEL LOAD / EXPAND TESTS
# ============================================================================


class TestComparisonFromNetcdf:
    """Tests for Comparison.from_netcdf classmethod."""

    def test_from_netcdf_list_of_paths(self, tmp_path, optimized_base, optimized_with_chp):
        """List of paths loads systems with names derived from filenames."""
        file_a = tmp_path / 'base.nc'
        file_b = tmp_path / 'with_chp.nc'
        optimized_base.to_netcdf(file_a)
        optimized_with_chp.to_netcdf(file_b)

        comp = fx.Comparison.from_netcdf([file_a, file_b])

        assert comp.names == ['base', 'with_chp']
        assert len(comp) == 2
        assert comp.is_optimized

    def test_from_netcdf_dict_paths_to_names(self, tmp_path, optimized_base, optimized_with_chp):
        """Dict input uses explicit names instead of filenames."""
        file_a = tmp_path / 'base.nc'
        file_b = tmp_path / 'chp.nc'
        optimized_base.to_netcdf(file_a)
        optimized_with_chp.to_netcdf(file_b)

        comp = fx.Comparison.from_netcdf({file_a: 'baseline', file_b: 'variant'})

        assert comp.names == ['baseline', 'variant']

    def test_from_netcdf_serial_matches_parallel(self, tmp_path, optimized_base, optimized_with_chp):
        """max_workers=1 produces the same result as the default parallel load."""
        file_a = tmp_path / 'base.nc'
        file_b = tmp_path / 'chp.nc'
        optimized_base.to_netcdf(file_a)
        optimized_with_chp.to_netcdf(file_b)

        comp_parallel = fx.Comparison.from_netcdf([file_a, file_b])
        comp_serial = fx.Comparison.from_netcdf([file_a, file_b], max_workers=1)

        assert comp_parallel.names == comp_serial.names
        xr.testing.assert_identical(comp_parallel.solution, comp_serial.solution)


class TestComparisonExpand:
    """Tests for Comparison.expand method."""

    @pytest.fixture(scope='class')
    def clustered_systems(self):
        """Build two clustered, optimized FlowSystems (class-scoped: solve once)."""
        pytest.importorskip('tsam')
        n_hours = 168  # 7 days
        ts = pd.date_range('2024-01-01', periods=n_hours, freq='h', name='time')
        demand = np.sin(np.linspace(0, 14 * np.pi, n_hours)) + 2

        def _build(name: str, cost: float) -> fx.FlowSystem:
            fs = fx.FlowSystem(ts, name=name)
            fs.add_elements(
                fx.Effect('costs', '€', 'Costs', is_standard=True, is_objective=True),
                fx.Bus('Electricity'),
                fx.Source(
                    'Grid',
                    outputs=[fx.Flow('P_el', bus='Electricity', size=100, effects_per_flow_hour={'costs': cost})],
                ),
                fx.Sink(
                    'Demand',
                    inputs=[
                        fx.Flow(
                            'P_demand',
                            bus='Electricity',
                            size=100,
                            fixed_relative_profile=fx.TimeSeriesData(demand / 100),
                        )
                    ],
                ),
            )
            return fs

        solver = fx.solvers.HighsSolver(mip_gap=0, time_limit_seconds=60, log_to_console=False)
        systems = []
        for name, cost in [('A', 0.3), ('B', 0.25)]:
            fs = _build(name, cost).transform.cluster(n_clusters=2, cluster_duration='1D')
            fs.optimize(solver)
            systems.append(fs)
        return systems

    def test_expand_returns_new_comparison(self, clustered_systems):
        """expand() returns a new Comparison instance, preserving names."""
        comp = fx.Comparison(clustered_systems, names=['a', 'b'])
        expanded = comp.expand()

        assert isinstance(expanded, fx.Comparison)
        assert expanded is not comp
        assert expanded.names == ['a', 'b']

    def test_expand_restores_full_timesteps(self, clustered_systems):
        """Each expanded FlowSystem has the full (original) timestep count."""
        comp = fx.Comparison(clustered_systems, names=['a', 'b'])
        expanded = comp.expand()

        for fs in expanded.values():
            # Original was 168 hours; clustering exposes 2D shape but expand
            # restores a single time axis with 168 steps (+1 boundary).
            assert 'time' in fs.solution.dims
            assert fs.solution.sizes['time'] == 168 + 1

    def test_expand_serial_matches_parallel(self, clustered_systems):
        """max_workers=1 gives identical results to the default parallel path."""
        comp = fx.Comparison(clustered_systems, names=['a', 'b'])

        expanded_parallel = comp.expand()
        expanded_serial = comp.expand(max_workers=1)

        xr.testing.assert_identical(expanded_parallel.solution, expanded_serial.solution)

    def test_expand_passes_through_non_clustered(self, clustered_systems, optimized_base):
        """Systems without clustering are passed through unchanged (mixed comparison)."""
        comp = fx.Comparison([clustered_systems[0], optimized_base], names=['clustered', 'plain'])
        expanded = comp.expand()

        # The non-clustered system is the same object, untouched.
        assert expanded['plain'] is optimized_base
        # The clustered system was actually expanded (new object).
        assert expanded['clustered'] is not clustered_systems[0]