337 changes: 148 additions & 189 deletions docs/precision_checker_guide.md

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions example/gpt2/main.cc
@@ -26,7 +26,9 @@
#include "infini_train/include/profiler.h"
#endif
#include "infini_train/include/nn/parallel/utils.h"
#include "infini_train/include/utils/global_module_hook_registry.h"
#include "infini_train/include/utils/precision_check_config.h"
#include "infini_train/include/utils/precision_checker.h"

#include "example/common/tiny_shakespeare_dataset.h"
#include "example/common/tokenizer.h"
@@ -257,6 +259,9 @@ void Train(const nn::parallel::Rank &rank) {
LOG(INFO) << "start training";

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
// Reset precision check counters at start of each iteration for file overwrite
utils::PrecisionChecker::ResetCounters();

const bool last_step = step == FLAGS_num_iteration;

const auto iter_start = std::chrono::high_resolution_clock::now();
5 changes: 5 additions & 0 deletions example/llama3/main.cc
@@ -25,7 +25,9 @@
#include "infini_train/include/nn/parallel/global.h"
#include "infini_train/include/nn/parallel/process_group.h"
#include "infini_train/include/nn/parallel/utils.h"
#include "infini_train/include/utils/global_module_hook_registry.h"
#include "infini_train/include/utils/precision_check_config.h"
#include "infini_train/include/utils/precision_checker.h"

#include "example/common/tiny_shakespeare_dataset.h"
#include "example/common/tokenizer.h"
@@ -232,6 +234,9 @@ void Train(const nn::parallel::Rank &rank) {
LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training";

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
// Reset precision check counters at start of each iteration for file overwrite
utils::PrecisionChecker::ResetCounters();

const bool last_step = step == FLAGS_num_iteration;

const auto iter_start = std::chrono::high_resolution_clock::now();
1 change: 0 additions & 1 deletion infini_train/include/nn/modules/module.h
@@ -95,7 +95,6 @@ class Module : public std::enable_shared_from_this<Module> {
std::vector<ModulePostHook> forward_post_hooks_;
std::vector<ModulePreHook> backward_pre_hooks_;
std::vector<ModulePostHook> backward_post_hooks_;
bool precision_check_registered_ = false;

private:
std::unordered_map<std::string, std::shared_ptr<Module>>
55 changes: 55 additions & 0 deletions infini_train/include/utils/global_module_hook_registry.h
@@ -0,0 +1,55 @@
#pragma once

#include "infini_train/include/common/hook.h"
#include "infini_train/include/tensor.h"
#include <functional>
#include <memory>
#include <mutex>
#include <vector>

namespace infini_train {
namespace nn {
class Module;
}

namespace utils {

// Global Module Hook Registry
// Global hooks that are executed on every forward/backward pass
class GlobalModuleHookRegistry {
public:
using ModuleForwardPreHook = std::function<void(nn::Module *, const std::vector<std::shared_ptr<Tensor>> &inputs)>;

using ModuleForwardHook = std::function<void(nn::Module *, const std::vector<std::shared_ptr<Tensor>> &inputs,
const std::vector<std::shared_ptr<Tensor>> &outputs)>;

using ModuleFullBackwardHook
= std::function<void(nn::Module *, const std::vector<std::shared_ptr<Tensor>> &grad_outputs,
const std::vector<std::shared_ptr<Tensor>> &grad_inputs)>;

static GlobalModuleHookRegistry &Instance();

// PyTorch-style registration: RegisterModule* prefix
std::unique_ptr<HookHandle> RegisterModuleForwardPreHook(ModuleForwardPreHook hook);
std::unique_ptr<HookHandle> RegisterModuleForwardHook(ModuleForwardHook hook);
std::unique_ptr<HookHandle> RegisterModuleFullBackwardHook(ModuleFullBackwardHook hook);

// Call hooks (called by Module::operator())
void CallModuleForwardPreHooks(nn::Module *module, const std::vector<std::shared_ptr<Tensor>> &inputs);
void CallModuleForwardHooks(nn::Module *module, const std::vector<std::shared_ptr<Tensor>> &inputs,
const std::vector<std::shared_ptr<Tensor>> &outputs);
void CallModuleFullBackwardHooks(nn::Module *module, const std::vector<std::shared_ptr<Tensor>> &grad_outputs,
const std::vector<std::shared_ptr<Tensor>> &grad_inputs);
bool HasModuleBackwardHooks() const;

private:
GlobalModuleHookRegistry() = default;

std::vector<ModuleForwardPreHook> module_forward_pre_hooks_;
std::vector<ModuleForwardHook> module_forward_hooks_;
std::vector<ModuleFullBackwardHook> module_full_backward_hooks_;
mutable std::mutex mutex_;
};

} // namespace utils
} // namespace infini_train
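
Editor's note, not part of this diff: a minimal usage sketch of the registry declared above. It assumes only the API shown in this header (Instance(), RegisterModuleForwardHook, the ModuleForwardHook signature); the logging body and the InstallLoggingHook wrapper are hypothetical.

```cpp
#include <iostream>
#include <memory>
#include <vector>

#include "infini_train/include/utils/global_module_hook_registry.h"

using infini_train::Tensor;
using infini_train::utils::GlobalModuleHookRegistry;

// Hypothetical helper: log tensor counts for every module's forward pass, process-wide.
// The registry returns a HookHandle (PyTorch-style); its exact lifetime semantics live in
// common/hook.h, which is not part of this diff.
std::unique_ptr<infini_train::HookHandle> InstallLoggingHook() {
    return GlobalModuleHookRegistry::Instance().RegisterModuleForwardHook(
        [](infini_train::nn::Module * /*module*/, const std::vector<std::shared_ptr<Tensor>> &inputs,
           const std::vector<std::shared_ptr<Tensor>> &outputs) {
            std::cout << "forward: " << inputs.size() << " input(s), " << outputs.size() << " output(s)\n";
        });
}
```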
16 changes: 12 additions & 4 deletions infini_train/include/utils/precision_check_config.h
@@ -1,6 +1,7 @@
#pragma once

#include <string>
#include <unordered_map>

namespace infini_train {
namespace utils {
@@ -9,10 +10,11 @@ enum class PrecisionCheckLevel { OFF = 0, MODULE = 1, FUNCTION = 2 };

struct PrecisionCheckConfig {
PrecisionCheckLevel level = PrecisionCheckLevel::OFF;
std::string output_path = ""; // empty=console(rank0), non-empty=file(all ranks)
bool output_md5 = false; // output MD5 hash or tensor values
std::string format = "simple"; // "simple" or "table"
std::string baseline_path = ""; // baseline file path for comparison
std::string output_path = "./log_precision_check"; // Output path (default)
std::string format = "simple"; // "simple" or "md5"
bool save_tensors = false; // Whether to output .npy file
double md5_tolerance = 0.0; // MD5 tolerance for quantization (e.g., 1e-3)
// 0 means no quantization (original precision)

// Parse from "key=value,key=value" string
static PrecisionCheckConfig Parse(const std::string &config_str);
@@ -23,10 +25,16 @@ class PrecisionCheckEnv {
static PrecisionCheckEnv &Instance();
void Init(const PrecisionCheckConfig &config);
const PrecisionCheckConfig &GetConfig() const;
const std::string &GetOutputPath() const;

// Tensor counter management for file overwrite across iterations (thread-local)
static int GetAndIncrementCounter(const std::string &key);
static void ResetCounters();

private:
PrecisionCheckEnv() = default;
PrecisionCheckConfig config_;
std::string timestamped_path_ = ""; // Actual output path (with timestamp)
};

} // namespace utils
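
Editor's note, not part of this diff: a sketch of how the configuration above might be built from a "key=value,key=value" string. The key names below are an assumption that they mirror the struct fields; the keys actually accepted by Parse() are defined in precision_check_config.cc, which is not shown here.

```cpp
#include "infini_train/include/utils/precision_check_config.h"

using infini_train::utils::PrecisionCheckConfig;
using infini_train::utils::PrecisionCheckEnv;

void SetUpPrecisionCheck() {
    // Assumed keys mirroring the struct fields; adjust to whatever Parse() actually accepts.
    const auto config = PrecisionCheckConfig::Parse(
        "level=module,format=md5,save_tensors=true,md5_tolerance=1e-3");
    // Init() stores the config process-wide; per precision_checker.h below, module-level
    // checking is wired up automatically when level >= MODULE.
    PrecisionCheckEnv::Instance().Init(config);
}
```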
9 changes: 9 additions & 0 deletions infini_train/include/utils/precision_checker.h
@@ -4,6 +4,8 @@
#include <string>
#include <vector>

#include "infini_train/include/utils/precision_check_config.h"

namespace infini_train {
class Tensor;
class HookHandle;
@@ -32,13 +34,20 @@ class PrecisionChecker {
return default_config;
}

// Initialize global module-level precision checking
// Called automatically by PrecisionCheckEnv::Init when level >= MODULE
static void Init(const PrecisionCheckConfig &global_config, const Config &config = DefaultConfig());

static void RegisterForFunction(autograd::Function *func, const std::string &name = "",
const Config &config = DefaultConfig());

// Register hooks for a Module (checks forward inputs/outputs)
static void RegisterForModule(nn::Module *module, const std::string &name = "",
const Config &config = DefaultConfig());

// Reset tensor counters (call at start of each iteration for file overwrite)
static void ResetCounters();

private:
static void CheckTensors(const std::string &stage, const std::string &name,
const std::vector<std::shared_ptr<Tensor>> &tensors, const Config &config);
44 changes: 27 additions & 17 deletions infini_train/src/nn/modules/module.cc
@@ -12,8 +12,7 @@
#include "infini_train/include/device.h"
#include "infini_train/include/nn/parallel/global.h"
#include "infini_train/include/tensor.h"
#include "infini_train/include/utils/precision_check_config.h"
#include "infini_train/include/utils/precision_checker.h"
#include "infini_train/include/utils/global_module_hook_registry.h"

#ifndef UNLIKELY
#define UNLIKELY(x) __builtin_expect(!!(x), 0)
@@ -135,37 +134,37 @@ std::vector<std::shared_ptr<Tensor>> Module::Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
}

std::vector<std::shared_ptr<Tensor>> Module::operator()(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
// Register precision check hooks if enabled and not already registered
// TODO(cx): move RegisterForModule to PrecisionChecker and avoid duplicate registration
if (!precision_check_registered_) {
auto precision_level = utils::PrecisionCheckEnv::Instance().GetConfig().level;
if (precision_level == utils::PrecisionCheckLevel::MODULE) {
utils::PrecisionChecker::RegisterForModule(this);
precision_check_registered_ = true;
}
}
// 1. Call global module forward pre-hooks
utils::GlobalModuleHookRegistry::Instance().CallModuleForwardPreHooks(this, input_tensors);

// Call forward pre-hooks
// 2. Call local forward pre-hooks
for (const auto &hook : forward_pre_hooks_) {
if (hook) {
hook(this, input_tensors);
}
}

// Call actual Forward implementation
// 3. Call actual Forward implementation
auto output_tensors = Forward(input_tensors);

// Call forward post-hooks
// 4. Call local forward post-hooks
for (const auto &hook : forward_post_hooks_) {
if (hook) {
hook(this, input_tensors, output_tensors);
}
}

// Register backward hooks on output tensors' grad_fn
if (UNLIKELY(!backward_pre_hooks_.empty() || !backward_post_hooks_.empty())) {
// 5. Call global module forward hooks
utils::GlobalModuleHookRegistry::Instance().CallModuleForwardHooks(this, input_tensors, output_tensors);

// 6. Register backward hooks on output tensors' grad_fn
const bool has_local_backward_hooks = !backward_pre_hooks_.empty() || !backward_post_hooks_.empty();
const bool has_global_backward_hooks = utils::GlobalModuleHookRegistry::Instance().HasModuleBackwardHooks();

if (UNLIKELY(has_local_backward_hooks || has_global_backward_hooks)) {
for (const auto &output : output_tensors) {
if (output && output->grad_fn()) {
if (output && output->output_idx() == 0 && output->grad_fn()) {
// Local backward prehooks
if (!backward_pre_hooks_.empty()) {
output->grad_fn()->RegisterBackwardPreHook(
[this](autograd::Function *, const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
@@ -176,6 +175,7 @@ std::vector<std::shared_ptr<Tensor>> Module::operator()(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
}
});
}
// Local backward post-hooks
if (!backward_post_hooks_.empty()) {
output->grad_fn()->RegisterBackwardPostHook(
[this](autograd::Function *, const std::vector<std::shared_ptr<Tensor>> &grad_inputs,
@@ -187,6 +187,16 @@ std::vector<std::shared_ptr<Tensor>> Module::operator()(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
}
});
}
// Global backward hooks
if (has_global_backward_hooks) {
output->grad_fn()->RegisterBackwardPostHook(
[this](autograd::Function *, const std::vector<std::shared_ptr<Tensor>> &grad_inputs,
const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
// Registry convention: (grad_outputs, grad_inputs) - PyTorch style
utils::GlobalModuleHookRegistry::Instance().CallModuleFullBackwardHooks(this, grad_outputs,
grad_inputs);
});
}
}
}
}
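
Editor's note, not part of this diff: the registry convention called out above ((grad_outputs, grad_inputs), PyTorch style) is the one detail to keep in mind when writing a global full backward hook. A sketch using only the registry API from this PR; the inspection body is a placeholder.

```cpp
#include <memory>
#include <vector>

#include "infini_train/include/utils/global_module_hook_registry.h"

using infini_train::Tensor;
using infini_train::utils::GlobalModuleHookRegistry;

// Hypothetical hook: runs after each module's backward, receiving grad_outputs first
// and grad_inputs second, matching the (grad_outputs, grad_inputs) convention above.
inline auto RegisterGradInspector() {
    return GlobalModuleHookRegistry::Instance().RegisterModuleFullBackwardHook(
        [](infini_train::nn::Module * /*module*/,
           const std::vector<std::shared_ptr<Tensor>> &grad_outputs,
           const std::vector<std::shared_ptr<Tensor>> &grad_inputs) {
            // Inspect gradients here (e.g. norms or NaN checks).
            (void)grad_outputs;
            (void)grad_inputs;
        });
}
```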
80 changes: 80 additions & 0 deletions infini_train/src/utils/global_module_hook_registry.cc
@@ -0,0 +1,80 @@
#include "infini_train/include/utils/global_module_hook_registry.h"

namespace infini_train::utils {

GlobalModuleHookRegistry &GlobalModuleHookRegistry::Instance() {
static GlobalModuleHookRegistry instance;
return instance;
}

std::unique_ptr<HookHandle> GlobalModuleHookRegistry::RegisterModuleForwardPreHook(ModuleForwardPreHook hook) {
std::lock_guard<std::mutex> lock(mutex_);
module_forward_pre_hooks_.push_back(std::move(hook));
return std::make_unique<HookHandleImpl<ModuleForwardPreHook>>(&module_forward_pre_hooks_,
module_forward_pre_hooks_.size() - 1);
}

std::unique_ptr<HookHandle> GlobalModuleHookRegistry::RegisterModuleForwardHook(ModuleForwardHook hook) {
std::lock_guard<std::mutex> lock(mutex_);
module_forward_hooks_.push_back(std::move(hook));
return std::make_unique<HookHandleImpl<ModuleForwardHook>>(&module_forward_hooks_,
module_forward_hooks_.size() - 1);
}

std::unique_ptr<HookHandle> GlobalModuleHookRegistry::RegisterModuleFullBackwardHook(ModuleFullBackwardHook hook) {
std::lock_guard<std::mutex> lock(mutex_);
module_full_backward_hooks_.push_back(std::move(hook));
return std::make_unique<HookHandleImpl<ModuleFullBackwardHook>>(&module_full_backward_hooks_,
module_full_backward_hooks_.size() - 1);
}

void GlobalModuleHookRegistry::CallModuleForwardPreHooks(nn::Module *module,
const std::vector<std::shared_ptr<Tensor>> &inputs) {
std::vector<ModuleForwardPreHook> snapshot;
{
std::lock_guard<std::mutex> lock(mutex_);
snapshot = module_forward_pre_hooks_;
}
for (const auto &hook : snapshot) {
if (hook) {
hook(module, inputs);
}
}
}

void GlobalModuleHookRegistry::CallModuleForwardHooks(nn::Module *module,
const std::vector<std::shared_ptr<Tensor>> &inputs,
const std::vector<std::shared_ptr<Tensor>> &outputs) {
std::vector<ModuleForwardHook> snapshot;
{
std::lock_guard<std::mutex> lock(mutex_);
snapshot = module_forward_hooks_;
}
for (const auto &hook : snapshot) {
if (hook) {
hook(module, inputs, outputs);
}
}
}

void GlobalModuleHookRegistry::CallModuleFullBackwardHooks(nn::Module *module,
const std::vector<std::shared_ptr<Tensor>> &grad_outputs,
const std::vector<std::shared_ptr<Tensor>> &grad_inputs) {
std::vector<ModuleFullBackwardHook> snapshot;
{
std::lock_guard<std::mutex> lock(mutex_);
snapshot = module_full_backward_hooks_;
}
for (const auto &hook : snapshot) {
if (hook) {
hook(module, grad_outputs, grad_inputs);
}
}
}

bool GlobalModuleHookRegistry::HasModuleBackwardHooks() const {
std::lock_guard<std::mutex> lock(mutex_);
return !module_full_backward_hooks_.empty();
}

} // namespace infini_train::utils