22
33import torch
44from metatensor .torch import Labels , TensorBlock , TensorMap
5- from vesin .metatomic import compute_requested_neighbors
5+ from vesin .metatomic import compute_requested_neighbors_from_options
66
77from metatomic .torch import (
88 AtomisticModel ,
1212)
1313
1414
15- def wrap_positions (positions : torch .Tensor , cell : torch .Tensor ) -> torch .Tensor :
15+ def _wrap_positions (positions : torch .Tensor , cell : torch .Tensor ) -> torch .Tensor :
1616 """
1717 Wrap positions into the periodic cell.
1818 """
@@ -23,7 +23,7 @@ def wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor:
2323 return wrapped_positions
2424
2525
26- def check_collisions (
26+ def _check_collisions (
2727 cell : torch .Tensor , positions : torch .Tensor , cutoff : float , skin : float
2828) -> tuple [torch .Tensor , torch .Tensor ]:
2929 """
@@ -58,7 +58,7 @@ def check_collisions(
5858 )
5959
6060
61- def collisions_to_replicas (collisions : torch .Tensor ) -> torch .Tensor :
61+ def _collisions_to_replicas (collisions : torch .Tensor ) -> torch .Tensor :
6262 """
6363 Convert boundary-collision flags into a boolean mask over all periodic image
6464 displacements in {0, +1, -1}^3. e.g. for an atom colliding with the x_lo and y_hi
@@ -87,7 +87,7 @@ def collisions_to_replicas(collisions: torch.Tensor) -> torch.Tensor:
8787 return outs .to (device = collisions .device )
8888
8989
90- def generate_replica_atoms (
90+ def _generate_replica_atoms (
9191 types : torch .Tensor ,
9292 positions : torch .Tensor ,
9393 cell : torch .Tensor ,
@@ -104,26 +104,30 @@ def generate_replica_atoms(
104104 replica_offsets = torch .tensor (
105105 [0 , 1 , - 1 ], device = positions .device , dtype = positions .dtype
106106 )[replicas [:, 1 :]]
107- replica_positions = positions [replica_idx ] + torch .einsum ("iA,Aa->ia" , replica_offsets , cell )
107+ replica_positions = positions [replica_idx ] + torch .einsum (
108+ "iA,Aa->ia" , replica_offsets , cell
109+ )
108110
109111 return replica_idx , types [replica_idx ], replica_positions
110112
111113
112- def unfold_system (metatomic_system : System , cutoff : float , skin : float = 0.5 ) -> System :
114+ def _unfold_system (
115+ metatomic_system : System , cutoff : float , skin : float = 0.5
116+ ) -> System :
113117 """
114118 Unfold a periodic system by generating replica atoms for those near the cell
115119 boundaries within the specified cutoff distance.
116120 The unfolded system has no periodic boundary conditions.
117121 """
118122
119- wrapped_positions = wrap_positions (
123+ wrapped_positions = _wrap_positions (
120124 metatomic_system .positions , metatomic_system .cell
121125 )
122- collisions , _ = check_collisions (
126+ collisions , _ = _check_collisions (
123127 metatomic_system .cell , wrapped_positions , cutoff , skin
124128 )
125- replicas = collisions_to_replicas (collisions )
126- replica_idx , replica_types , replica_positions = generate_replica_atoms (
129+ replicas = _collisions_to_replicas (collisions )
130+ replica_idx , replica_types , replica_positions = _generate_replica_atoms (
127131 metatomic_system .types , wrapped_positions , metatomic_system .cell , replicas
128132 )
129133 unfolded_types = torch .cat (
@@ -258,7 +262,7 @@ def __init__(self, model: AtomisticModel, skin: float = 0.5):
258262 def requested_inputs (self ) -> Dict [str , ModelOutput ]:
259263 return self ._requested_inputs
260264
261- def barycenter_and_atomic_energies (self , system : System , n_atoms : int ):
265+ def _barycenter_and_atomic_energies (self , system : System , n_atoms : int ):
262266 energy_block = self ._model ([system ], self ._unfolded_run_options , False )[
263267 "energy"
264268 ].block (0 )
@@ -272,22 +276,24 @@ def barycenter_and_atomic_energies(self, system: System, n_atoms: int):
272276
273277 return barycenter , atomic_e , total_e
274278
275- def calc_unfolded_heat_flux (self , system : System ) -> torch .Tensor :
279+ def _calc_unfolded_heat_flux (self , system : System ) -> torch .Tensor :
276280 n_atoms = len (system .positions )
277- unfolded_system = unfold_system (system , self ._interaction_range , self .skin ).to (
278- "cpu"
281+ unfolded_system = _unfold_system (system , self ._interaction_range , self .skin ).to (
282+ system . device
279283 )
280- compute_requested_neighbors (
281- unfolded_system , self ._unfolded_run_options .length_unit , model = self ._model
284+ compute_requested_neighbors_from_options (
285+ [unfolded_system ],
286+ self ._model .requested_neighbor_lists (),
287+ self ._unfolded_run_options .length_unit ,
288+ False ,
282289 )
283- unfolded_system = unfolded_system .to (system .device )
284290 velocities : torch .Tensor = (
285291 unfolded_system .get_data ("velocities" ).block ().values .reshape (- 1 , 3 )
286292 )
287293 masses : torch .Tensor = (
288294 unfolded_system .get_data ("masses" ).block ().values .reshape (- 1 )
289295 )
290- barycenter , atomic_e , total_e = self .barycenter_and_atomic_energies (
296+ barycenter , atomic_e , total_e = self ._barycenter_and_atomic_energies (
291297 unfolded_system , n_atoms
292298 )
293299
@@ -352,7 +358,7 @@ def forward(
352358 heat_fluxes : List [torch .Tensor ] = []
353359 for system in systems :
354360 system .positions .requires_grad_ (True )
355- heat_fluxes .append (self .calc_unfolded_heat_flux (system ))
361+ heat_fluxes .append (self ._calc_unfolded_heat_flux (system ))
356362
357363 samples = Labels (
358364 ["system" ], torch .arange (len (systems ), device = device ).reshape (- 1 , 1 )
0 commit comments