diff --git a/src/histalign/backend/ccf/__init__.py b/src/histalign/backend/ccf/__init__.py index e061d43..b04033f 100644 --- a/src/histalign/backend/ccf/__init__.py +++ b/src/histalign/backend/ccf/__init__.py @@ -9,7 +9,7 @@ from pathlib import Path import shutil import ssl -from typing import Literal +from typing import Callable, Literal import urllib.error from urllib.request import urlopen @@ -37,40 +37,72 @@ def download_atlas( resolution: Resolution, atlas_type: Literal["average_template", "ara_nissl"] = "average_template", -) -> None: + callback: Callable | None = None, +) -> bool: """Downloads the atlas file for the given type and resolution. Args: resolution (Resolution): Resolution of the atlas. atlas_type (Literal["average_template", "ara_nissl"], optional): Type of the atlas. + callback (Callable | None, optional): + Callback to run after every downloaded chunk. This is expected to return + a boolean indicating whether to cancel (`True`) or continue the download + (`False`). + + Returns: + A boolean indicating whether the download finished successfully. `True` means + the file was downloaded, `False` means the download was cancelled or the URL + could not be found. """ atlas_file_name = f"{atlas_type}_{resolution.value}.nrrd" url = "/".join([BASE_ATLAS_URL, atlas_type, atlas_file_name]) atlas_path = ATLAS_ROOT_DIRECTORY / atlas_file_name - download(url, atlas_path) + return download(url, atlas_path, callback) -def download_annotation_volume(resolution: Resolution) -> None: +def download_annotation_volume( + resolution: Resolution, callback: Callable | None = None +) -> bool: """Downloads the annotation volume file for the given resolution. Args: resolution (Resolution): Resolution of the atlas. + callback (Callable | None, optional): + Callback to run after every downloaded chunk. This is expected to return + a boolean indicating whether to cancel (`True`) or continue the download + (`False`). 
+ + Returns: + A boolean indicating whether the download finished successfully. `True` means + the file was downloaded, `False` means the download was cancelled or the URL + could not be found. """ volume_file_name = f"annotation_{resolution}.nrrd" url = "/".join([BASE_ANNOTATION_URL, volume_file_name]) volume_path = ANNOTATION_ROOT_DIRECTORY / volume_file_name - download(url, volume_path) + return download(url, volume_path, callback) -def download_structure_mask(structure_name: str, resolution: Resolution) -> None: +def download_structure_mask( + structure_name: str, resolution: Resolution, callback: Callable | None = None +) -> bool: """Downloads the structure mask file for the given name and resolution. Args: structure_name (str): Name of the structure. resolution (Resolution): Resolution of the atlas. + callback (Callable | None, optional): + Callback to run after every downloaded chunk. This is expected to return + a boolean indicating whether to cancel (`True`) or continue the download + (`False`). + + Returns: + A boolean indicating whether the download finished successfully. `True` means + the file was downloaded, `False` means the download was cancelled or the URL + could not be found. """ structure_id = get_structure_id(structure_name, resolution) structure_file_name = f"structure_{structure_id}.nrrd" @@ -85,15 +117,24 @@ def download_structure_mask(structure_name: str, resolution: Resolution) -> None os.makedirs(structure_path.parent, exist_ok=True) - download(url, structure_path) + return download(url, structure_path, callback) -def download(url: str, file_path: str | Path) -> None: +def download(url: str, file_path: str | Path, callback: Callable | None = None) -> bool: """Downloads a file from the given URL and saves it to the given path. Args: url (str): URL to fetch. file_path (str | Path): Path to save the result to. + callback (Callable | None, optional): + Callback to run after every downloaded chunk. 
This is expected to return + a boolean indicating whether to cancel (`True`) or continue the download + (`False`). + + Returns: + A boolean indicating whether the download finished successfully. `True` means + the file was downloaded, `False` means the download was cancelled or the URL + could not be found. """ # Thin guard to not just download anything... if not url.startswith(BASE_ATLAS_URL) and not url.startswith(BASE_MASK_URL): @@ -106,15 +147,25 @@ def download(url: str, file_path: str | Path) -> None: # Allen SSL certificate is apparently not valid... context = get_ssl_context(check_hostname=False, check_certificate=False) try: + buffer_size = 1024 * 1024 # 1 MiB with ( urlopen(url, context=context) as response, open(tmp_file_path, "wb") as handle, ): - shutil.copyfileobj(response, handle) + while True: + buffer = response.read(buffer_size) + if not buffer: + break + handle.write(buffer) + + if callback is not None and callback(): + return False except urllib.error.HTTPError: _module_logger.error(f"URL not found ('{url}').") + return False shutil.move(tmp_file_path, file_path) + return True def get_ssl_context( @@ -269,7 +320,7 @@ def get_structure_tree(resolution: Resolution) -> StructureTree: return ReferenceSpaceCache( resolution=resolution.value, reference_space_key=os.path.join("annotation", "ccf_2017"), - manifest=str(DATA_ROOT / f"manifest.json"), + manifest=str(DATA_ROOT / "manifest.json"), ).get_structure_tree() @@ -281,7 +332,7 @@ def get_structures_hierarchy_path() -> str: Returns: The path to the `structures.json` hierarchy file. 
""" - path = DATA_ROOT / f"structures.json" + path = DATA_ROOT / "structures.json" # Easiest option to have the Allen SDK do the work for us if not path.exists(): diff --git a/src/histalign/backend/maths/__init__.py b/src/histalign/backend/maths/__init__.py index 96b05b5..ba90f39 100644 --- a/src/histalign/backend/maths/__init__.py +++ b/src/histalign/backend/maths/__init__.py @@ -392,13 +392,18 @@ def get_sk_transform_from_parameters( return AffineTransform(matrix=matrix) -def normalise_array(array: np.ndarray, dtype: Optional[np.dtype] = None) -> np.ndarray: +def normalise_array( + array: np.ndarray, dtype: Optional[np.dtype] = None, fast: bool = False +) -> np.ndarray: """Normalise an array to the range between 0 and the dtype's maximum value. Args: array (np.ndarray): Array to normalise. dtype (np.dtype, optional): Target dtype. If `None`, the dtype will be inferred as the dtype of `array`. + fast (bool, optional): + Whether to normalise without using intermediary float arrays. This will lead + to reduced accuracy but no extra memory usage. Returns: The normalised array. 
@@ -406,10 +411,20 @@ def normalise_array(array: np.ndarray, dtype: Optional[np.dtype] = None) -> np.n
     dtype = dtype or array.dtype
     maximum = get_dtype_maximum(dtype)
 
-    array = array.astype(np.float64)
-    array -= array.min()
-    array /= max(array.max(), 1)
-    array *= maximum
+    if fast:
+        # Integer-only path: `array` is modified in place to avoid a float64 copy.
+        # Values are only ever scaled down to fit `dtype`, never stretched up.
+        array -= array.min()
+        # Ceiling division: a floored ratio leaves peaks in [maximum, 2 * maximum)
+        # unscaled and they would then wrap around in the final cast.
+        divisor = -(-int(array.max()) // max(int(maximum), 1))
+        if divisor > 1:
+            array[:] //= divisor
+    else:
+        array = array.astype(np.float64)
+        array -= array.min()
+        array /= max(array.max(), 1)
+        array *= maximum
 
     return array.astype(dtype)
 
diff --git a/src/histalign/backend/models/__init__.py b/src/histalign/backend/models/__init__.py
index 867618e..1c0ac0b 100644
--- a/src/histalign/backend/models/__init__.py
+++ b/src/histalign/backend/models/__init__.py
@@ -125,7 +125,11 @@ class VolumeSettings(BaseModel, validate_assignment=True):
 
     @property
     def shape(self) -> tuple[int, int, int]:
-        match self.resolution:
+        return self.get_shape_from_resolution(self.resolution)
+
+    @staticmethod
+    def get_shape_from_resolution(resolution: Resolution) -> tuple[int, int, int]:
+        match resolution:
             case Resolution.MICRONS_100:
                 return 132, 80, 114
             case Resolution.MICRONS_50:
@@ -133,7 +137,7 @@ def shape(self) -> tuple[int, int, int]:
             case Resolution.MICRONS_25:
                 return 528, 320, 456
             case Resolution.MICRONS_10:
-                return 1320, 800, 114
+                return 1320, 800, 1140
             case _:
                 raise Exception("ASSERT NOT REACHED")
 
diff --git a/src/histalign/backend/workspace/__init__.py b/src/histalign/backend/workspace/__init__.py
index 8cf119a..6eb188a 100644
--- a/src/histalign/backend/workspace/__init__.py
+++ b/src/histalign/backend/workspace/__init__.py
@@ -6,8 +6,6 @@
 
 from collections.abc import Iterator, Sequence
 from concurrent.futures import ThreadPoolExecutor
-import contextlib
-from functools import partial
 import hashlib
 import json
 import logging
@@ -15,14 +13,14 @@
 from multiprocessing import Process, Queue
 import os
 from pathlib import Path
-from queue import Empty
import re import shutil from threading import Event import time -from typing import Any, get_type_hints, Literal, Optional +from typing import Any, Callable, get_type_hints, Literal, Optional from allensdk.core.structure_tree import StructureTree +import nrrd import numpy as np from PIL import Image from PySide6 import QtCore @@ -175,6 +173,10 @@ def __init__( def is_loaded(self) -> bool: return self._volume is not None + @property + def is_downloaded(self) -> bool: + return self.path.exists() + def ensure_loaded(self) -> None: """Ensures the volume is loaded (and downloads it if necessary).""" self._ensure_downloaded() @@ -182,6 +184,9 @@ def ensure_loaded(self) -> None: def update_from_array(self, array: np.ndarray) -> None: """Updates the wrapped volume with a `vedo.Volume` of `array`.""" + # Very ugly but override vedo's forced deep copy of the array to create a volume + # so we don't temporarily need twice the memory. + vedo.utils.numpy2vtk.__defaults__ = (None, False, "") self._volume = vedo.Volume(array) def load(self) -> np.ndarray: @@ -189,13 +194,13 @@ def load(self) -> np.ndarray: return io.load_volume(self.path, self.dtype, as_array=True) def _ensure_downloaded(self) -> None: - if not self.path.exists() and not self.is_loaded: + if not self.is_downloaded and not self.is_loaded: self._download() self.downloaded.emit() - def _download(self) -> None: - download_atlas(self.resolution) + def _download(self, callback: Callable | None = None) -> bool: + return download_atlas(self.resolution, callback=callback) def _ensure_loaded(self) -> None: if not self.is_loaded: @@ -212,7 +217,11 @@ def __getattr__(self, name: str) -> Any: return getattr(self._volume, name) def __setattr__(self, name: str, value: Any) -> None: - if name in get_type_hints(type(self)).keys() or name in dir(self): + if ( + name in get_type_hints(type(self)).keys() + or name in dir(self) + or name == "__METAOBJECT__" # Used by PySide6 when multithreading? 
+ ): return super().__setattr__(name, value) if not self.is_loaded: @@ -221,21 +230,20 @@ def __setattr__(self, name: str, value: Any) -> None: class AnnotationVolume(Volume): - """A wrapper around the Allen Institute's annotated CCF volumes. - - Since the Allen Institute has reserved some ID ranges, there are huge gaps in the - values of the annotated volume. This wrapper maps the IDs present in the raw file - into sequential values to allow a volume of uint16 instead of uint32, freeing a lot - of memory and not really incurring any loading cost (around 2 seconds on my - machine for the 25um annotated volume). - - References: - Algorithm for efficient value replacement: https://stackoverflow.com/a/29408060 - """ + """A wrapper around the Allen Institute's annotated CCF volumes.""" - _id_translation_table: np.ndarray _structure_tree: StructureTree + def __init__( + self, + path: str | Path, + resolution: Resolution, + convert_dtype: Optional[type | np.dtype] = None, + lazy: bool = False, + ) -> None: + self._structure_tree = get_structure_tree(Resolution.MICRONS_100) + super().__init__(path, resolution, convert_dtype, lazy) + def get_name_from_voxel(self, coordinates: Sequence) -> str: """Returns the name of the brain structure at `coordinates`. 
@@ -264,9 +272,7 @@ def get_name_from_voxel(self, coordinates: Sequence) -> str:
 
         value = self._volume.tonumpy()[coordinates]
 
-        node_details = self._structure_tree.get_structures_by_id(
-            [self._id_translation_table[value]]
-        )[0]
+        node_details = self._structure_tree.get_structures_by_id([value])[0]
         name: str
         if node_details is not None:
             name = node_details["name"]
@@ -275,18 +281,8 @@ def get_name_from_voxel(self, coordinates: Sequence) -> str:
 
         return name
 
-    def update_from_array(self, array: np.ndarray) -> None:
-        unique_values = np.unique(array)
-        replacement_array = np.empty(array.max() + 1, dtype=np.uint16)
-        replacement_array[unique_values] = np.arange(len(unique_values))
-
-        self._id_translation_table = unique_values
-        self._structure_tree = get_structure_tree(Resolution.MICRONS_100)
-
-        super().update_from_array(replacement_array[array])
-
-    def _download(self) -> None:
-        download_annotation_volume(self.resolution)
+    def _download(self, callback: Callable | None = None) -> bool:
+        return download_annotation_volume(self.resolution, callback=callback)
 
 
 class VolumeLoaderThread(QtCore.QThread):
@@ -315,7 +311,6 @@ def start(
         super().start(priority)
 
     def run(self):
-        # Shortcircuit to avoid pickling an already-loaded volume
         if self.volume.is_loaded:
             self.volume.downloaded.emit()
             self.volume.loaded.emit()
@@ -326,33 +321,60 @@ def run(self):
         process.start()
         while process.is_alive():
             if self.isInterruptionRequested():
-                process.terminate()
-                process.join()
+                process.kill()
                 return
 
-            time.sleep(0.25)
+            time.sleep(0.1)
 
         self.volume.downloaded.emit()
 
         # Load
         queue = Queue()
         process = Process(
-            target=partial(self._run, self.volume, queue),
+            target=_load_volume, args=(self.volume.path, queue),
             daemon=True
         )
+        _module_logger.debug("Starting volume loader process.")
+        byte_array = bytearray()
         process.start()
         while process.is_alive():
             if self.isInterruptionRequested():
-                process.terminate()
-                process.join()
+                _module_logger.debug("VolumeLoaderThread interrupted.")
+                process.kill()
                 return
 
-            with contextlib.suppress(Empty):
-                self.volume.update_from_array(queue.get(block=False))
-                self.volume.loaded.emit()
+            while not queue.empty():
+                byte_array += queue.get()
+
+                if self.isInterruptionRequested():
+                    _module_logger.debug("VolumeLoaderThread interrupted.")
+                    process.kill()
+                    return
 
             time.sleep(0.1)
 
+        # The loader process can exit between two polls with its last chunk(s) still
+        # sitting in the pipe, so drain what is left before rebuilding the array.
+        while not queue.empty():
+            byte_array += queue.get()
+
+        # Reconstruct the NumPy array
+        dtype = nrrd.reader._determine_datatype(nrrd.read_header(str(self.volume.path)))
+        array = np.ndarray(
+            # We can't get the shape directly from self.volume as that would force a call
+            # to `ensure_loaded` and load the volume to get the answer.
+            shape=VolumeSettings.get_shape_from_resolution(self.volume.resolution),
+            dtype=dtype,
+            buffer=byte_array,
+            order="F",
+        )
+
+        # Potentially normalise fast
+        array = io.normalise_array(array, self.volume.dtype, fast=True)
+
+        self.volume.update_from_array(array)
+        self.volume.loaded.emit()
+
     @staticmethod
     def _run(volume: Volume, queue: Queue) -> None:
         queue.put(volume.load())
@@ -1053,3 +1075,37 @@ def alignment_directory_has_volumes(directory: Path) -> bool:
     )
 
     return has_aligned and has_interpolated
+
+
+def _load_volume(path: Path, queue: Queue) -> None:
+    """Reads the NRRD file at `path` and sends its raw bytes through `queue`.
+
+    This is meant to be the target of a separate process. The data is flattened in
+    Fortran order and sent as consecutive `bytes` chunks which the receiver is
+    expected to concatenate.
+
+    Args:
+        path (Path): Path to the NRRD file to read.
+        queue (Queue): Queue the chunks are put on.
+    """
+    _module_logger.debug("Starting to load volume.")
+
+    # Reduce chunk size to considerably reduce memory footprint of large files
+    backup_chunksize = nrrd.reader._READ_CHUNKSIZE
+    nrrd.reader._READ_CHUNKSIZE = 2**16
+    try:
+        data, _ = nrrd.read(path)
+    finally:
+        # Restore the module-level value even when reading fails
+        nrrd.reader._READ_CHUNKSIZE = backup_chunksize
+
+    # Transfer as a 1D array through the queue
+    data = data.ravel(order="F")
+
+    chunk_size = 2**25  # Seems like a good trade-off between speed and memory
+    while data.shape != (0,):
+        buffer, data = data[:chunk_size], data[chunk_size:]
+
+        queue.put(buffer.tobytes())
+
+    _module_logger.debug("Finished loading volume.")
diff --git a/src/histalign/frontend/__init__.py b/src/histalign/frontend/__init__.py
index 3566e02..2348b0e 100644
--- a/src/histalign/frontend/__init__.py
+++ b/src/histalign/frontend/__init__.py
@@ -335,7 +335,7 @@ def prepare_gui_for_new_project(self) -> None: tab.alignment_widget.reset_volume() tab.alignment_widget.reset_histology() - def switch_workspace(self) -> None: + def switch_workspace(self) -> bool: """Handles a change in the current workspace. Note that when this function is called, the workspace should already have been @@ -353,7 +353,7 @@ def switch_workspace(self) -> None: # Update the registration tab tab = self.registration_tab - tab.load_atlas() + return tab.load_atlas() def propagate_workspace(self) -> None: """Ensures workspace models are properly shared with all that rely on it.""" @@ -412,7 +412,9 @@ def _create_project(self, settings: ProjectSettings) -> None: # Initialise a new workspace self.workspace = Workspace(settings) - self.switch_workspace() + if not self.switch_workspace(): + # The user cancelled the download/load of the atlas + return # Update workspace state self.workspace_is_dirty = True @@ -455,7 +457,9 @@ def _open_project(self, path: str | Path) -> None: except ValueError as e: _module_logger.error(f"Failed to load project from '{path}': {e}") return InvalidProjectFileDialog(self).open() - self.switch_workspace() + if not self.switch_workspace(): + # The user cancelled the download/load of the atlas + return # Restore registration saved state tab = self.registration_tab diff --git a/src/histalign/frontend/dialogs.py b/src/histalign/frontend/dialogs.py index 9f4f056..fced49e 100644 --- a/src/histalign/frontend/dialogs.py +++ b/src/histalign/frontend/dialogs.py @@ -36,19 +36,6 @@ def __init__(self, parent: Optional[QtWidgets.QWidget] = None) -> None: self.setMaximum(0) self.setCancelButton(None) # type: ignore[arg-type] - def closeEvent(self, event: QtGui.QCloseEvent) -> None: - # Ugly but not all platforms support having a frame and no close button - event.ignore() - return - - def keyPressEvent(self, event: QtGui.QKeyEvent) -> None: - if event.key() == QtCore.Qt.Key.Key_Escape: - # Disable closing dialog with Escape - 
event.accept() - return - - super().keyPressEvent(event) - class InvalidProjectFileDialog(QtWidgets.QMessageBox): def __init__(self, parent: Optional[QtWidgets.QWidget] = None) -> None: diff --git a/src/histalign/frontend/registration/__init__.py b/src/histalign/frontend/registration/__init__.py index c9a8421..8a3a0cb 100644 --- a/src/histalign/frontend/registration/__init__.py +++ b/src/histalign/frontend/registration/__init__.py @@ -341,7 +341,7 @@ def clear_volume_state(self) -> None: """Clears any atlas on the GUI.""" self.alignment_widget.reset_volume() - def load_atlas(self) -> int: + def load_atlas(self) -> bool: """Loads the atlas and annotations volumes.""" _module_logger.debug("Loading atlas and annotations.") @@ -355,7 +355,6 @@ def load_atlas(self) -> int: annotation_volume.loaded.connect( lambda: _module_logger.debug("Annotations loaded.") ) - self.annotation_volume = annotation_volume atlas_volume = unwrap(self.alignment_widget.volume_slicer).volume # Set up the dialog and loader threads @@ -363,6 +362,12 @@ def load_atlas(self) -> int: annotation_loader_thread = VolumeLoaderThread(annotation_volume) atlas_loader_thread = VolumeLoaderThread(atlas_volume) + dialog.canceled.connect(annotation_loader_thread.requestInterruption) + dialog.canceled.connect(atlas_loader_thread.requestInterruption) + dialog.canceled.connect( + lambda: _module_logger.debug("Volume (down)loading cancelled by user.") + ) + atlas_volume.downloaded.connect( lambda: dialog.setLabelText("Loading atlas"), type=QtCore.Qt.ConnectionType.QueuedConnection, @@ -379,13 +384,17 @@ def load_atlas(self) -> int: annotation_loader_thread.start() atlas_loader_thread.start() - result = dialog.exec() # Blocking + result = QtWidgets.QDialog.DialogCode(dialog.exec()) # Blocking # Ensure we wait for the threads to be destroyed annotation_loader_thread.wait() atlas_loader_thread.wait() - return result + successful = result == QtWidgets.QDialog.DialogCode.Accepted + if successful: + 
+            self.annotation_volume = annotation_volume
+
+        return successful
 
     def locate_mouse(self) -> None:
         if self.annotation_volume is None:
diff --git a/src/histalign/io/__init__.py b/src/histalign/io/__init__.py
index e19a9f4..7e8fb39 100644
--- a/src/histalign/io/__init__.py
+++ b/src/histalign/io/__init__.py
@@ -148,7 +148,19 @@ def load_volume(
     except UnknownFileFormatError:
         suffix = Path(path).suffix
         if suffix == ".nrrd":
-            array = nrrd.read(path)[0]
+            # `nrrd` normally loads files in chunks of 4GiB but the `zlib` decompressor
+            # memory usage blows up (uses something around 10x the size of the array)
+            # when loading the 10 microns annotation volume. This reduces the chunk size
+            # significantly to reduce the memory usage to about the size of the array
+            # we're trying to load. `nrrd` still ends up using twice that temporarily
+            # when it creates the `numpy` array but not much we can do about that.
+            backup_value = nrrd.reader._READ_CHUNKSIZE
+            nrrd.reader._READ_CHUNKSIZE = 2**16
+            try:
+                array = nrrd.read(path)[0]
+            finally:
+                # Always restore the module-level default, even if the read fails
+                nrrd.reader._READ_CHUNKSIZE = backup_value
         else:
             # Continue raising
             raise