Skip to content

Commit 677cd19

Browse files
committed
Let's go vesin-0.5.1
1 parent ec04879 commit 677cd19

3 files changed

Lines changed: 38 additions & 25 deletions

File tree

python/metatomic_torch/metatomic/torch/heat_flux.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44
from metatensor.torch import Labels, TensorBlock, TensorMap
5-
from vesin.metatomic import compute_requested_neighbors
5+
from vesin.metatomic import compute_requested_neighbors_from_options
66

77
from metatomic.torch import (
88
AtomisticModel,
@@ -12,7 +12,7 @@
1212
)
1313

1414

15-
def wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor:
15+
def _wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor:
1616
"""
1717
Wrap positions into the periodic cell.
1818
"""
@@ -23,7 +23,7 @@ def wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor:
2323
return wrapped_positions
2424

2525

26-
def check_collisions(
26+
def _check_collisions(
2727
cell: torch.Tensor, positions: torch.Tensor, cutoff: float, skin: float
2828
) -> tuple[torch.Tensor, torch.Tensor]:
2929
"""
@@ -58,7 +58,7 @@ def check_collisions(
5858
)
5959

6060

61-
def collisions_to_replicas(collisions: torch.Tensor) -> torch.Tensor:
61+
def _collisions_to_replicas(collisions: torch.Tensor) -> torch.Tensor:
6262
"""
6363
Convert boundary-collision flags into a boolean mask over all periodic image
6464
displacements in {0, +1, -1}^3. e.g. for an atom colliding with the x_lo and y_hi
@@ -87,7 +87,7 @@ def collisions_to_replicas(collisions: torch.Tensor) -> torch.Tensor:
8787
return outs.to(device=collisions.device)
8888

8989

90-
def generate_replica_atoms(
90+
def _generate_replica_atoms(
9191
types: torch.Tensor,
9292
positions: torch.Tensor,
9393
cell: torch.Tensor,
@@ -104,26 +104,30 @@ def generate_replica_atoms(
104104
replica_offsets = torch.tensor(
105105
[0, 1, -1], device=positions.device, dtype=positions.dtype
106106
)[replicas[:, 1:]]
107-
replica_positions = positions[replica_idx] + torch.einsum("iA,Aa->ia", replica_offsets, cell)
107+
replica_positions = positions[replica_idx] + torch.einsum(
108+
"iA,Aa->ia", replica_offsets, cell
109+
)
108110

109111
return replica_idx, types[replica_idx], replica_positions
110112

111113

112-
def unfold_system(metatomic_system: System, cutoff: float, skin: float = 0.5) -> System:
114+
def _unfold_system(
115+
metatomic_system: System, cutoff: float, skin: float = 0.5
116+
) -> System:
113117
"""
114118
Unfold a periodic system by generating replica atoms for those near the cell
115119
boundaries within the specified cutoff distance.
116120
The unfolded system has no periodic boundary conditions.
117121
"""
118122

119-
wrapped_positions = wrap_positions(
123+
wrapped_positions = _wrap_positions(
120124
metatomic_system.positions, metatomic_system.cell
121125
)
122-
collisions, _ = check_collisions(
126+
collisions, _ = _check_collisions(
123127
metatomic_system.cell, wrapped_positions, cutoff, skin
124128
)
125-
replicas = collisions_to_replicas(collisions)
126-
replica_idx, replica_types, replica_positions = generate_replica_atoms(
129+
replicas = _collisions_to_replicas(collisions)
130+
replica_idx, replica_types, replica_positions = _generate_replica_atoms(
127131
metatomic_system.types, wrapped_positions, metatomic_system.cell, replicas
128132
)
129133
unfolded_types = torch.cat(
@@ -258,7 +262,7 @@ def __init__(self, model: AtomisticModel, skin: float = 0.5):
258262
def requested_inputs(self) -> Dict[str, ModelOutput]:
259263
return self._requested_inputs
260264

261-
def barycenter_and_atomic_energies(self, system: System, n_atoms: int):
265+
def _barycenter_and_atomic_energies(self, system: System, n_atoms: int):
262266
energy_block = self._model([system], self._unfolded_run_options, False)[
263267
"energy"
264268
].block(0)
@@ -272,22 +276,24 @@ def barycenter_and_atomic_energies(self, system: System, n_atoms: int):
272276

273277
return barycenter, atomic_e, total_e
274278

275-
def calc_unfolded_heat_flux(self, system: System) -> torch.Tensor:
279+
def _calc_unfolded_heat_flux(self, system: System) -> torch.Tensor:
276280
n_atoms = len(system.positions)
277-
unfolded_system = unfold_system(system, self._interaction_range, self.skin).to(
278-
"cpu"
281+
unfolded_system = _unfold_system(system, self._interaction_range, self.skin).to(
282+
system.device
279283
)
280-
compute_requested_neighbors(
281-
unfolded_system, self._unfolded_run_options.length_unit, model=self._model
284+
compute_requested_neighbors_from_options(
285+
[unfolded_system],
286+
self._model.requested_neighbor_lists(),
287+
self._unfolded_run_options.length_unit,
288+
False,
282289
)
283-
unfolded_system = unfolded_system.to(system.device)
284290
velocities: torch.Tensor = (
285291
unfolded_system.get_data("velocities").block().values.reshape(-1, 3)
286292
)
287293
masses: torch.Tensor = (
288294
unfolded_system.get_data("masses").block().values.reshape(-1)
289295
)
290-
barycenter, atomic_e, total_e = self.barycenter_and_atomic_energies(
296+
barycenter, atomic_e, total_e = self._barycenter_and_atomic_energies(
291297
unfolded_system, n_atoms
292298
)
293299

@@ -352,7 +358,7 @@ def forward(
352358
heat_fluxes: List[torch.Tensor] = []
353359
for system in systems:
354360
system.positions.requires_grad_(True)
355-
heat_fluxes.append(self.calc_unfolded_heat_flux(system))
361+
heat_fluxes.append(self._calc_unfolded_heat_flux(system))
356362

357363
samples = Labels(
358364
["system"], torch.arange(len(systems), device=device).reshape(-1, 1)

python/metatomic_torch/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ python_files = ["*.py"]
5858
testpaths = ["tests"]
5959
filterwarnings = [
6060
"error",
61+
"ignore:Found metatomic.torch.*but vesin.metatomic was only tested with:UserWarning",
6162
"ignore:`torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning",
6263
"ignore:`torch.jit.save` is deprecated. Please switch to `torch.export`:DeprecationWarning",
6364
"ignore:`torch.jit.load` is deprecated. Please switch to `torch.export`:DeprecationWarning",

python/metatomic_torch/tests/test_heat_flux.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def model():
4141
def atoms():
4242
cell = np.array([[6.0, 0.0, 0.0], [0.0, 6.0, 0.0], [0.0, 0.0, 6.0]])
4343
positions = np.array([[3.0, 3.0, 3.0]])
44-
atoms = Atoms(f"Ar", scaled_positions=positions, cell=cell, pbc=True).repeat(
44+
atoms = Atoms("Ar", scaled_positions=positions, cell=cell, pbc=True).repeat(
4545
(2, 2, 2)
4646
)
4747
MaxwellBoltzmannDistribution(
@@ -123,9 +123,7 @@ def __call__(self, systems, options, check_consistency):
123123
torch.arange(len(systems), device=values.device).reshape(-1, 1),
124124
),
125125
components=[],
126-
properties=Labels(
127-
["energy"], torch.tensor([[0]], device=values.device)
128-
),
126+
properties=Labels(["energy"], torch.tensor([[0]], device=values.device)),
129127
)
130128
return {
131129
"energy": TensorMap(
@@ -287,7 +285,15 @@ def test_generate_replica_atoms_triclinic_offsets():
287285
assert replica_idx.tolist() == [0, 0, 0, 0, 0, 0, 0]
288286
assert replica_types.tolist() == [1, 1, 1, 1, 1, 1, 1]
289287

290-
expected_offsets = [cell[0], cell[1], cell[2], cell[0] + cell[1], cell[0] + cell[2], cell[1] + cell[2], cell[0] + cell[1] + cell[2]]
288+
expected_offsets = [
289+
cell[0],
290+
cell[1],
291+
cell[2],
292+
cell[0] + cell[1],
293+
cell[0] + cell[2],
294+
cell[1] + cell[2],
295+
cell[0] + cell[1] + cell[2],
296+
]
291297
expected_positions = [positions[0] + offset for offset in expected_offsets]
292298

293299
for expected in expected_positions:

0 commit comments

Comments
 (0)