Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/KOKKOS/pair_metatomic_kokkos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,12 @@ void PairMetatomicKokkos<DeviceType>::init_style() {
this->type_mapping_kk = Kokkos::View<int32_t*, Kokkos::LayoutRight, DeviceType>("type_mapping_kk", atom->ntypes + 1);
Kokkos::deep_copy(this->type_mapping_kk, type_mapping_kk_host);

using NCMode = PairMetatomicData::NonConservativeMode;
auto options = MetatomicSystemOptions{
this->type_mapping_kk.data(),
mta_data->max_cutoff,
mta_data->check_consistency,
!(mta_data->non_conservative),
mta_data->non_conservative != NCMode::ON, // autograd needed for OFF/FORCES/STRESS
};

// override the system adaptor with the kokkos version
Expand Down Expand Up @@ -112,6 +113,7 @@ void PairMetatomicKokkos<DeviceType>::pick_device(torch::Device& device, const c

template<class DeviceType>
void PairMetatomicKokkos<DeviceType>::store_forces(const at::Tensor& forces_tensor) {
using NCMode = PairMetatomicData::NonConservativeMode;
assert(forces_tensor.scalar_type() == torch::kFloat64);
auto forces = forces_tensor.contiguous();

Expand All @@ -131,8 +133,8 @@ void PairMetatomicKokkos<DeviceType>::store_forces(const at::Tensor& forces_tens
}
);

// in non-conservative mode we do not need to update forces on ghost atoms
if (!mta_data->non_conservative) {
// ghost atom forces only exist when forces come from autograd
if (mta_data->non_conservative == NCMode::OFF || mta_data->non_conservative == NCMode::STRESS) {
auto system_adaptor_kk = dynamic_cast<MetatomicSystemAdaptorKokkos<DeviceType>*>(this->system_adaptor.get());
assert(system_adaptor_kk != nullptr);
auto mta_to_lmp_kk = UnmanagedView<int32_t*, DeviceType>(
Expand Down
5 changes: 3 additions & 2 deletions src/ML-METATOMIC/metatomic_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ struct PairMetatomicData: public CommonMetatomicData {
metatomic_torch::ModelOutput nc_forces_output;
metatomic_torch::ModelOutput nc_stress_output;

// whether non-conservative forces and stresses should be used
bool non_conservative = false;
// which non-conservative outputs to use
enum class NonConservativeMode { OFF, ON, FORCES, STRESS };
NonConservativeMode non_conservative = NonConservativeMode::OFF;

// energy key for the model
std::string energy_key;
Expand Down
167 changes: 113 additions & 54 deletions src/ML-METATOMIC/pair_metatomic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,17 @@ void PairMetatomic::settings(int argc, char ** argv) {
i += 1;
} else if (strcmp(argv[i], "non_conservative") == 0) {
if (i == argc - 1) {
error->one(FLERR, "expected <on/off> after 'non_conservative' in pair_style metatomic, got nothing");
error->one(FLERR, "expected <on/off/forces/stress> after 'non_conservative' in pair_style metatomic, got nothing");
} else if (strcmp(argv[i + 1], "on") == 0) {
mta_data->non_conservative = true;
mta_data->non_conservative = PairMetatomicData::NonConservativeMode::ON;
} else if (strcmp(argv[i + 1], "off") == 0) {
mta_data->non_conservative = false;
mta_data->non_conservative = PairMetatomicData::NonConservativeMode::OFF;
} else if (strcmp(argv[i + 1], "forces") == 0) {
mta_data->non_conservative = PairMetatomicData::NonConservativeMode::FORCES;
} else if (strcmp(argv[i + 1], "stress") == 0) {
mta_data->non_conservative = PairMetatomicData::NonConservativeMode::STRESS;
} else {
error->one(FLERR, "expected <on/off> after 'non_conservative' in pair_style metatomic, got '{}'", argv[i + 1]);
error->one(FLERR, "expected <on/off/forces/stress> after 'non_conservative' in pair_style metatomic, got '{}'", argv[i + 1]);
}

i += 1;
Expand Down Expand Up @@ -267,29 +271,68 @@ void PairMetatomic::settings(int argc, char ** argv) {
}

// Handle non-conservative variants
if (mta_data->non_conservative) {
// Error if *both* nc-force and nc-stress were provided by user AND one is Null
bool user_set_forces = (variant_nc_forces != nullptr);
bool user_set_stress = (variant_nc_stress != nullptr);

if (user_set_forces && user_set_stress) {

bool forces_none = !normalize_variant(variant_nc_forces).has_value();
bool stress_none = !normalize_variant(variant_nc_stress).has_value();
using NCMode = PairMetatomicData::NonConservativeMode;
const auto nc_mode = mta_data->non_conservative;

bool user_set_forces = (variant_nc_forces != nullptr);
bool user_set_stress = (variant_nc_stress != nullptr);

// Warn if the user set an explicit variant for an output that the chosen
// mode does not use.
if (user_set_forces && nc_mode != NCMode::ON && nc_mode != NCMode::FORCES) {
error->warning(FLERR,
"'variant/non_conservative_forces' was set but the current 'non_conservative' mode "
"does not use non-conservative forces; the variant will be ignored."
);
}
if (user_set_stress && nc_mode != NCMode::ON && nc_mode != NCMode::STRESS) {
error->warning(FLERR,
"'variant/non_conservative_stress' was set but the current 'non_conservative' mode "
"does not use non-conservative stress; the variant will be ignored."
);
}

if (forces_none != stress_none) {
error->one(FLERR,
"if both 'variant/non_conservative_stress' and "
"'variant/non_conservative_forces' are present, they "
"must either both be 'off' or both not 'off'");
}
// Error if *both* nc-force and nc-stress were provided by user AND one is Null
if (nc_mode == NCMode::ON && user_set_forces && user_set_stress) {
bool forces_none = !normalize_variant(variant_nc_forces).has_value();
bool stress_none = !normalize_variant(variant_nc_stress).has_value();
if (forces_none != stress_none) {
error->one(FLERR,
"if both 'variant/non_conservative_stress' and "
"'variant/non_conservative_forces' are present with 'non_conservative on', "
"they must either both be 'off' or both not 'off'");
}
}

bool do_nc_forces = (nc_mode == NCMode::ON || nc_mode == NCMode::FORCES);
if (do_nc_forces) {
try {
mta_data->nc_forces_key = pick_output("non_conservative_forces", outputs, v_nc_forces);
} catch (std::exception& e) {
error->one(FLERR,
"{}Failed to select 'non_conservative_forces' output. "
"If the model does not support non-conservative forces, use "
"'non_conservative stress' or 'non_conservative off'. "
"If the model provides multiple variants, select one with "
"'variant/non_conservative_forces <name>'.",
e.what()
);
}
}

bool do_nc_stress = (nc_mode == NCMode::ON || nc_mode == NCMode::STRESS);
if (do_nc_stress) {
try {
mta_data->nc_stress_key = pick_output("non_conservative_stress", outputs, v_nc_stress);
} catch (std::exception& e) {
error->one(FLERR, e.what());
error->one(FLERR,
"{}Failed to select 'non_conservative_stress' output. "
"If the model does not support non-conservative stress, use "
"'non_conservative forces' or 'non_conservative off'. "
"If the model provides multiple variants, select one with "
"'variant/non_conservative_stress <name>'.",
e.what()
);
Comment on lines +328 to +335
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this looks like if the user made a typo in the variant name?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Arghh yes this might look. Let me try if we can make this better.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the error message it now looks like:

ERROR on proc 0: output 'non_conservative_stress' not found in outputs
Exception raised from pick_output at /Users/runner/work/metatomic/metatomic/metatomic-torch/src/misc.cpp:188 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>) + 56 (0x105f3fd7c in libc10.dylib)
frame #1: metatomic_torch::pick_output(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, c10::Dict<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, c10::intrusive_ptr<metatomic_torch::ModelOutputHolder, c10::detail::intrusive_target_default_null_type<metatomic_torch::ModelOutputHolder>>>, std::__1::optional<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>>) + 1548 (0x1064189e0 in libmetatomic_torch.dylib)
frame #2: LAMMPS_NS::PairMetatomic::settings(int, char**) + 4152 (0x104e32df8 in lmp)
frame #3: LAMMPS_NS::Input::pair_style() + 1460 (0x104b2ef78 in lmp)
frame #4: LAMMPS_NS::Input::execute_command() + 2380 (0x104b26b08 in lmp)
frame #5: LAMMPS_NS::Input::file() + 768 (0x104b259c0 in lmp)
frame #6: main + 80 (0x10490af98 in lmp)
frame #7: start + 6992 (0x18925fda4 in dyld)
Failed to select 'non_conservative_stress' output. If the model does not support non-conservative stress, use 'non_conservative forces' or 'non_conservative off'. If the model provides multiple variants, select one with 'variant/non_conservative_stress <name>'. (src/ML-METATOMIC/pair_metatomic.cpp:328)

}
}

Expand Down Expand Up @@ -331,7 +374,7 @@ void PairMetatomic::settings(int argc, char ** argv) {
}
}

if (mta_data->non_conservative) {
if (do_nc_forces) {
auto nc_forces = outputs.find(mta_data->nc_forces_key);
if (nc_forces == outputs.end()) {
error->one(FLERR,
Expand All @@ -348,12 +391,13 @@ void PairMetatomic::settings(int argc, char ** argv) {
mta_data->nc_forces_key, model_path
);
}

mta_data->nc_forces_output = torch::make_intrusive<metatomic_torch::ModelOutputHolder>();
mta_data->nc_forces_output->set_quantity("force");
mta_data->nc_forces_output->set_unit(this->energy_unit + "/" + this->length_unit);
mta_data->nc_forces_output->per_atom = true;
}

if (do_nc_stress) {
auto nc_stress = outputs.find(mta_data->nc_stress_key);
if (nc_stress != outputs.end()) {
mta_data->nc_stress_output = torch::make_intrusive<metatomic_torch::ModelOutputHolder>();
Expand Down Expand Up @@ -505,6 +549,9 @@ void PairMetatomic::coeff(int argc, char ** argv) {

// called when the run starts
void PairMetatomic::init_style() {
using NCMode = PairMetatomicData::NonConservativeMode;
const auto nc_mode = mta_data->non_conservative;

// Require newton pair on since we need to communicate forces accumulated on
// ghost atoms to neighboring domains. These forces contributions come from
// gradient of a local descriptor w.r.t. domain ghosts (periodic images
Expand Down Expand Up @@ -549,7 +596,7 @@ void PairMetatomic::init_style() {
this->type_mapping,
mta_data->max_cutoff,
mta_data->check_consistency,
!(mta_data->non_conservative),
nc_mode != NCMode::ON, // autograd needed for OFF/FORCES/STRESS
};
this->system_adaptor = std::make_unique<MetatomicSystemAdaptor>(lmp, options);

Expand All @@ -576,6 +623,9 @@ void PairMetatomic::init_list(int id, NeighList *ptr) {
}

void PairMetatomic::compute(int eflag, int vflag) {
using NCMode = PairMetatomicData::NonConservativeMode;
const auto nc_mode = mta_data->non_conservative;

if (std::getenv("LAMMPS_METATOMIC_PROFILE") != nullptr) {
MetatomicTimer::enable(true);
} else {
Expand All @@ -589,8 +639,15 @@ void PairMetatomic::compute(int eflag, int vflag) {
mta_data->evaluation_options->outputs.clear();
// we need an energy output if the energy was explicitly requested (through
// `eflag_either`), or when running in standard/conservative mode, because
// we'll get the forces as the gradient of the energy through autodiff.
if (eflag_either || !mta_data->non_conservative) {
// we'll get the forces and stress as the gradient of the energy through autodiff.
auto need_energy_for_autograd = (nc_mode == NCMode::OFF
|| nc_mode == NCMode::STRESS
|| (nc_mode == NCMode::FORCES && vflag_global));

auto do_nc_forces = nc_mode == NCMode::ON || nc_mode == NCMode::FORCES;
auto do_nc_stress = nc_mode == NCMode::ON || nc_mode == NCMode::STRESS;

if (eflag_either || need_energy_for_autograd) {
if (eflag_atom) {
if (!mta_data->is_energy_output_per_atom) {
error->one(FLERR,
Expand All @@ -609,18 +666,11 @@ void PairMetatomic::compute(int eflag, int vflag) {
mta_data->evaluation_options->outputs.insert(mta_data->energy_uq_key, mta_data->uncertainty_output);
}

if (mta_data->non_conservative) {
if (do_nc_forces) {
mta_data->evaluation_options->outputs.insert(mta_data->nc_forces_key, mta_data->nc_forces_output);
if (vflag_global) {
if (mta_data->nc_stress_output == nullptr) {
error->one(FLERR,
"the model at '{}' does not have a '{}' output, "
"we can not run non_conservative simulations that require computing the stress/virial",
mta_data->model_path, mta_data->nc_stress_key
);
}
mta_data->evaluation_options->outputs.insert(mta_data->nc_stress_key, mta_data->nc_stress_output);
}
}
if (vflag_global && do_nc_stress) {
mta_data->evaluation_options->outputs.insert(mta_data->nc_stress_key, mta_data->nc_stress_output);
}

auto dtype = torch::kFloat64;
Expand All @@ -635,7 +685,7 @@ void PairMetatomic::compute(int eflag, int vflag) {
// transform from LAMMPS to metatomic System
auto system = this->system_adaptor->system_from_lmp(
mta_list,
static_cast<bool>(vflag_global),
vflag_global && !do_nc_stress,
dtype,
mta_data->device
);
Expand Down Expand Up @@ -712,7 +762,7 @@ void PairMetatomic::compute(int eflag, int vflag) {

// get the energy if we need to compute the energy, or if we are using it to
// get the forces/virial with autograd
if (eflag_either || !mta_data->non_conservative) {
if (eflag_either || need_energy_for_autograd) {
auto energy = results.at(mta_data->energy_key).toCustomClass<metatensor_torch::TensorMapHolder>();
auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(energy, 0);
energy_tensor = energy_block->values();
Expand All @@ -722,30 +772,37 @@ void PairMetatomic::compute(int eflag, int vflag) {
torch::Tensor forces_tensor;
torch::Tensor virial_tensor;

if (mta_data->non_conservative) {
// get non-conservative forces
if (do_nc_forces) {
auto forces = results.at(mta_data->nc_forces_key).toCustomClass<metatensor_torch::TensorMapHolder>();
auto forces_block = metatensor_torch::TensorMapHolder::block_by_id(forces, 0);
forces_tensor = forces_block->values().squeeze(-1);
forces_tensor = forces_tensor.to(torch::kCPU).to(torch::kFloat64);
}

if (vflag_global) {
auto stress = results.at(mta_data->nc_stress_key).toCustomClass<metatensor_torch::TensorMapHolder>();
auto stress_block = metatensor_torch::TensorMapHolder::block_by_id(stress, 0);
auto stress_tensor = stress_block->values().squeeze(0).squeeze(-1);
virial_tensor = - stress_tensor * compute_volume(domain);
virial_tensor = virial_tensor.to(torch::kCPU).to(torch::kFloat64);
}
} else {
// compute forces/virial on device with backward propagation
// reset gradients to zero before calling backward
// get non-conservative stress
if (vflag_global && do_nc_stress) {
auto stress = results.at(mta_data->nc_stress_key).toCustomClass<metatensor_torch::TensorMapHolder>();
auto stress_block = metatensor_torch::TensorMapHolder::block_by_id(stress, 0);
auto stress_tensor = stress_block->values().squeeze(0).squeeze(-1);
virial_tensor = - stress_tensor * compute_volume(domain);
virial_tensor = virial_tensor.to(torch::kCPU).to(torch::kFloat64);
}

// compute conservative quantities through autograd
if (need_energy_for_autograd) {
this->system_adaptor->positions.mutable_grad() = torch::Tensor();
this->system_adaptor->strain.mutable_grad() = torch::Tensor();

auto _ = MetatomicTimer("running Model::backward");
energy_tensor.backward(-torch::ones_like(energy_tensor));

forces_tensor = this->system_adaptor->positions.grad();
virial_tensor = this->system_adaptor->strain.grad();
if (!do_nc_forces) {
forces_tensor = this->system_adaptor->positions.grad();
}
if (vflag_global && !do_nc_stress) {
virial_tensor = this->system_adaptor->strain.grad();
}
}

{
Expand Down Expand Up @@ -802,7 +859,7 @@ void PairMetatomic::compute(int eflag, int vflag) {

assert(!vflag_fdotr);

if (vflag_global) {
if (vflag_global && virial_tensor.defined()) {
auto virial_cpu = virial_tensor.to(torch::kCPU);
assert(virial_cpu.is_cpu() && virial_cpu.scalar_type() == torch::kFloat64);

Expand Down Expand Up @@ -833,8 +890,10 @@ void PairMetatomic::store_forces(const at::Tensor& forces_tensor) {
atom->f[i][2] += this->scale * forces[i][2];
}

// in non-conservative mode we do not need to update forces on ghost atoms
if (!mta_data->non_conservative) {
// ghost atom forces only exist when forces come from autograd
using NCMode = PairMetatomicData::NonConservativeMode;
const auto nc_mode = mta_data->non_conservative;
if (nc_mode == NCMode::OFF || nc_mode == NCMode::STRESS) {
const auto& mta_to_lmp = this->system_adaptor->mta_to_lmp;
for (int i=atom->nlocal; i<forces.size(0); i++) {
atom->f[mta_to_lmp[i]][0] += this->scale * forces[i][0];
Expand Down