Skip to content
73 changes: 62 additions & 11 deletions src/histalign/backend/ccf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pathlib import Path
import shutil
import ssl
from typing import Literal
from typing import Callable, Literal
import urllib.error
from urllib.request import urlopen

Expand Down Expand Up @@ -37,40 +37,72 @@
def download_atlas(
    resolution: Resolution,
    atlas_type: Literal["average_template", "ara_nissl"] = "average_template",
    callback: Callable[[], bool] | None = None,
) -> bool:
    """Downloads the atlas file for the given type and resolution.

    The file is named `{atlas_type}_{resolution.value}.nrrd`. It is fetched from
    the matching sub-path of `BASE_ATLAS_URL` and saved under
    `ATLAS_ROOT_DIRECTORY`.

    Args:
        resolution (Resolution): Resolution of the atlas.
        atlas_type (Literal["average_template", "ara_nissl"], optional):
            Type of the atlas.
        callback (Callable[[], bool] | None, optional):
            Callback to run after every downloaded chunk. It is called without
            arguments and is expected to return a boolean indicating whether to
            cancel (`True`) or continue the download (`False`).

    Returns:
        A boolean indicating whether the download finished successfully. `True` means
        the file was downloaded, `False` means the download was cancelled or the URL
        could not be found.
    """
    atlas_file_name = f"{atlas_type}_{resolution.value}.nrrd"
    url = "/".join([BASE_ATLAS_URL, atlas_type, atlas_file_name])
    atlas_path = ATLAS_ROOT_DIRECTORY / atlas_file_name

    # Propagate the outcome so callers can tell a finished download from a
    # cancelled or failed one.
    return download(url, atlas_path, callback)


def download_annotation_volume(
    resolution: Resolution, callback: Callable[[], bool] | None = None
) -> bool:
    """Downloads the annotation volume file for the given resolution.

    The file is named `annotation_{resolution.value}.nrrd`. It is fetched from
    `BASE_ANNOTATION_URL` and saved under `ANNOTATION_ROOT_DIRECTORY`.

    Args:
        resolution (Resolution): Resolution of the atlas.
        callback (Callable[[], bool] | None, optional):
            Callback to run after every downloaded chunk. It is called without
            arguments and is expected to return a boolean indicating whether to
            cancel (`True`) or continue the download (`False`).

    Returns:
        A boolean indicating whether the download finished successfully. `True` means
        the file was downloaded, `False` means the download was cancelled or the URL
        could not be found.
    """
    # Use the enum's value explicitly, as the atlas file name does. Formatting
    # the member itself can yield "Resolution.<NAME>" depending on the enum
    # base class and the Python version.
    # NOTE(review): assumes `Resolution` defines no custom `__format__` that
    # intentionally differs from its value - confirm.
    volume_file_name = f"annotation_{resolution.value}.nrrd"
    url = "/".join([BASE_ANNOTATION_URL, volume_file_name])
    volume_path = ANNOTATION_ROOT_DIRECTORY / volume_file_name

    return download(url, volume_path, callback)


def download_structure_mask(structure_name: str, resolution: Resolution) -> None:
def download_structure_mask(
structure_name: str, resolution: Resolution, callback: Callable | None = None
) -> bool:
"""Downloads the structure mask file for the given name and resolution.

Args:
structure_name (str): Name of the structure.
resolution (Resolution): Resolution of the atlas.
callback (Callable | None, optional):
Callback to run after every downloaded chunk. This is expected to return
a boolean indicating whether to cancel (`True`) or continue the download
(`False`).

Returns:
A boolean indicating whether the download finished successfully. `True` means
the file was downloaded, `False` means the download was cancelled or the URL
could not be found.
"""
structure_id = get_structure_id(structure_name, resolution)
structure_file_name = f"structure_{structure_id}.nrrd"
Expand All @@ -85,15 +117,24 @@ def download_structure_mask(structure_name: str, resolution: Resolution) -> None

os.makedirs(structure_path.parent, exist_ok=True)

download(url, structure_path)
return download(url, structure_path, callback)


def download(url: str, file_path: str | Path) -> None:
def download(url: str, file_path: str | Path, callback: Callable | None = None) -> bool:
"""Downloads a file from the given URL and saves it to the given path.

Args:
url (str): URL to fetch.
file_path (str | Path): Path to save the result to.
callback (Callable | None, optional):
Callback to run after every downloaded chunk. This is expected to return
a boolean indicating whether to cancel (`True`) or continue the download
(`False`).

Returns:
A boolean indicating whether the download finished successfully. `True` means
the file was downloaded, `False` means the download was cancelled or the URL
could not be found.
"""
# Thin guard to not just download anything...
if not url.startswith(BASE_ATLAS_URL) and not url.startswith(BASE_MASK_URL):
Expand All @@ -106,15 +147,25 @@ def download(url: str, file_path: str | Path) -> None:
# Allen SSL certificate is apparently not valid...
context = get_ssl_context(check_hostname=False, check_certificate=False)
try:
buffer_size = 1024 * 1024 # 1 MiB
with (
urlopen(url, context=context) as response,
open(tmp_file_path, "wb") as handle,
):
shutil.copyfileobj(response, handle)
while True:
buffer = response.read(buffer_size)
if not buffer:
break
handle.write(buffer)

if callback is not None and callback():
return False
except urllib.error.HTTPError:
_module_logger.error(f"URL not found ('{url}').")
return False

shutil.move(tmp_file_path, file_path)
return True


def get_ssl_context(
Expand Down Expand Up @@ -269,7 +320,7 @@ def get_structure_tree(resolution: Resolution) -> StructureTree:
return ReferenceSpaceCache(
resolution=resolution.value,
reference_space_key=os.path.join("annotation", "ccf_2017"),
manifest=str(DATA_ROOT / f"manifest.json"),
manifest=str(DATA_ROOT / "manifest.json"),
).get_structure_tree()


Expand All @@ -281,7 +332,7 @@ def get_structures_hierarchy_path() -> str:
Returns:
The path to the `structures.json` hierarchy file.
"""
path = DATA_ROOT / f"structures.json"
path = DATA_ROOT / "structures.json"

# Easiest option to have the Allen SDK do the work for us
if not path.exists():
Expand Down
23 changes: 18 additions & 5 deletions src/histalign/backend/maths/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,24 +392,37 @@ def get_sk_transform_from_parameters(
return AffineTransform(matrix=matrix)


def normalise_array(
    array: np.ndarray, dtype: Optional[np.dtype] = None, fast: bool = False
) -> np.ndarray:
    """Normalise an array to the range between 0 and the dtype's maximum value.

    Args:
        array (np.ndarray): Array to normalise.
        dtype (np.dtype, optional):
            Target dtype. If `None`, the dtype will be inferred as the dtype of `array`.
        fast (bool, optional):
            Whether to normalise without using intermediary float arrays. This will lead
            to reduced accuracy but no extra memory usage. In this mode `array` is
            modified in place, and values are only ever scaled down (by an integer
            divisor) so that they fit the target dtype; they are never stretched
            up to fill it.

    Returns:
        The normalised array.
    """
    dtype = dtype or array.dtype
    maximum = get_dtype_maximum(dtype)

    if fast:
        # Work directly on the caller's buffer: avoiding a float copy is the
        # whole point of the fast path.
        array -= array.min()
        peak = array.max()
        if peak > maximum:
            # Round the divisor up. With a floored ratio, any peak that is not
            # an exact multiple of `maximum` would still exceed it after the
            # division and overflow on the final cast.
            divisor = int(np.ceil(peak / maximum))
            array //= divisor
    else:
        array = array.astype(np.float64)
        array -= array.min()
        # Guard against dividing by zero for constant arrays.
        array /= max(array.max(), 1)
        array *= maximum

    return array.astype(dtype)

Expand Down
8 changes: 6 additions & 2 deletions src/histalign/backend/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,19 @@ class VolumeSettings(BaseModel, validate_assignment=True):

@property
def shape(self) -> tuple[int, int, int]:
    """Shape of the volume for this instance's `resolution`.

    Delegates to `get_shape_from_resolution` so the lookup is also available
    without an instance.
    """
    return self.get_shape_from_resolution(self.resolution)

@staticmethod
def get_shape_from_resolution(resolution: Resolution) -> tuple[int, int, int]:
match resolution:
case Resolution.MICRONS_100:
return 132, 80, 114
case Resolution.MICRONS_50:
return 264, 160, 228
case Resolution.MICRONS_25:
return 528, 320, 456
case Resolution.MICRONS_10:
return 1320, 800, 114
return 1320, 800, 1140
case _:
raise Exception("ASSERT NOT REACHED")

Expand Down
Loading