
Commit 2b5ad33

refactor: replace old Device runtime calls with DeviceGuard/impl operations
1 parent 0b9afd5 commit 2b5ad33

59 files changed

Lines changed: 806 additions & 738 deletions
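
The heart of the refactor, visible in the example diffs below: call sites stop downcasting to a concrete `CudaDevice` and instead fetch a per-device-type impl via `core::GetDeviceGuardImpl(device.type())`, then invoke backend operations through it. The following is a minimal sketch of that dispatch pattern; only `GetDeviceGuardImpl`, `ResetMemPoolHighWatermarks`, `GetMemPoolPeakMB`, and the `Device` accessors appear in this commit, so the interface and registry details below are assumptions for illustration.

```cpp
// Minimal sketch of the DeviceGuard/impl dispatch the new call sites imply.
// Everything beyond the names visible in this commit is assumed.
#include <array>
#include <cstddef>
#include <utility>

enum class DeviceType : int { kCPU = 0, kCUDA = 1 };
inline constexpr int kNumDeviceTypes = 2;

class Device {
public:
    explicit Device(DeviceType type) : type_(type) {}
    DeviceType type() const { return type_; }
    bool IsCUDA() const { return type_ == DeviceType::kCUDA; }

private:
    DeviceType type_;
};

namespace core {

// One implementation per backend; the CPU backend can keep the no-op defaults.
class DeviceGuardImplInterface {
public:
    virtual ~DeviceGuardImplInterface() = default;
    virtual void ResetMemPoolHighWatermarks(const Device &) const {}
    virtual std::pair<std::size_t, std::size_t> GetMemPoolPeakMB(const Device &) const {
        return {0, 0}; // no pool stats on backends without a caching allocator
    }
};

// Flat registry indexed by device type; each backend fills its slot at startup,
// e.g. from a static registrar inside the CUDA library.
inline std::array<const DeviceGuardImplInterface *, kNumDeviceTypes> g_device_impls{};

inline const DeviceGuardImplInterface *GetDeviceGuardImpl(DeviceType type) {
    return g_device_impls[static_cast<int>(type)];
}

} // namespace core
```

The call-site shape then matches the diffs: `auto impl = core::GetDeviceGuardImpl(device.type()); if (device.IsCUDA()) { impl->ResetMemPoolHighWatermarks(device); }`, so example code no longer needs a `dynamic_cast` or a compile-time dependency on the CUDA backend.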


CMakeLists.txt

Lines changed: 152 additions & 55 deletions
@@ -1,10 +1,11 @@
+cmake_minimum_required(VERSION 3.28)
+
 option(USE_CUDA "Support NVIDIA CUDA" OFF)
 option(PROFILE_MODE "ENABLE PROFILE MODE" OFF)
 option(USE_OMP "Use OpenMP as backend for Eigen" ON)
 option(USE_NCCL "Build project for distributed running" ON)
-cmake_minimum_required(VERSION 3.28)
 
-project(infini_train VERSION 0.3.0 LANGUAGES CXX)
+project(infini_train VERSION 0.5.0 LANGUAGES CXX)
 
 set(CMAKE_CXX_STANDARD 20)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
@@ -13,90 +14,186 @@ set(CMAKE_CXX_EXTENSIONS OFF)
 # Generate compile_commands.json
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 
-# Add gflags
+# ------------------------------------------------------------------------------
+# Third-party deps
+# ------------------------------------------------------------------------------
+
+# gflags
 add_subdirectory(third_party/gflags)
 include_directories(${gflags_SOURCE_DIR}/include)
 
+# glog
 set(WITH_GFLAGS OFF CACHE BOOL "Disable glog finding system gflags" FORCE)
 set(WITH_GTEST OFF CACHE BOOL "Disable glog finding system gtest" FORCE)
-
-# Add glog
 add_subdirectory(third_party/glog)
 include_directories(${glog_SOURCE_DIR}/src)
 
-# Add eigen
+# eigen
 if(USE_OMP)
-find_package(OpenMP REQUIRED)
+    find_package(OpenMP REQUIRED)
 endif()
-# find_package(OpenBLAS REQUIRED)
-# include_directories(${OpenBLAS_INCLUDE_DIR})
 add_subdirectory(third_party/eigen)
 include_directories(${PROJECT_SOURCE_DIR}/third_party/eigen)
-# add_definitions(-DEIGEN_USE_BLAS)
 
 include_directories(${PROJECT_SOURCE_DIR})
-file(GLOB_RECURSE SRC ${PROJECT_SOURCE_DIR}/infini_train/src/*.cc)
-list(FILTER SRC EXCLUDE REGEX ".*kernels/cpu/.*")
 
 if(PROFILE_MODE)
-add_compile_definitions(PROFILE_MODE=1)
+    add_compile_definitions(PROFILE_MODE=1)
 endif()
 
-file (GLOB_RECURSE CPU_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/kernels/cpu/*.cc)
+# ------------------------------------------------------------------------------
+# Sources
+# ------------------------------------------------------------------------------
+
+# Framework core sources (*.cc), excluding cpu kernels (they are built separately)
+file(GLOB_RECURSE SRC ${PROJECT_SOURCE_DIR}/infini_train/src/*.cc)
+list(FILTER SRC EXCLUDE REGEX ".*kernels/cpu/.*")
+
+# CPU kernels (*.cc)
+file(GLOB_RECURSE CPU_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/kernels/cpu/*.cc)
+
+# ------------------------------------------------------------------------------
+# CPU kernels library
+# ------------------------------------------------------------------------------
+
 add_library(infini_train_cpu_kernels STATIC ${CPU_KERNELS})
-target_link_libraries(infini_train_cpu_kernels glog Eigen3::Eigen)
+target_link_libraries(infini_train_cpu_kernels PUBLIC glog Eigen3::Eigen)
+
 if(USE_OMP)
-add_compile_definitions(USE_OMP=1)
-target_link_libraries(infini_train_cpu_kernels OpenMP::OpenMP_CXX)
+    add_compile_definitions(USE_OMP=1)
+    target_link_libraries(infini_train_cpu_kernels PUBLIC OpenMP::OpenMP_CXX)
+endif()
+
+# ------------------------------------------------------------------------------
+# CUDA kernels library (optional)
+# ------------------------------------------------------------------------------
+
+if(USE_CUDA)
+    add_compile_definitions(USE_CUDA=1)
+    enable_language(CUDA)
+    find_package(CUDAToolkit REQUIRED)
+    include_directories(${CUDAToolkit_INCLUDE_DIRS})
+
+    # CUDA compilation options
+    set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr")
+
+    # Only compile CUDA kernels / cuda sources here (your original used src/*.cu)
+    file(GLOB_RECURSE CUDA_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/*.cu)
+
+    add_library(infini_train_cuda_kernels STATIC ${CUDA_KERNELS})
+    set_target_properties(infini_train_cuda_kernels PROPERTIES CUDA_ARCHITECTURES "75;80;90")
+
+    target_link_libraries(infini_train_cuda_kernels
+        PUBLIC
+            glog
+            CUDA::cudart
+            CUDA::cublas
+            CUDA::cuda_driver
+    )
+
+    if(USE_NCCL)
+        message(STATUS "Add USE_NCCL, use NCCL with CUDA")
+        list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
+        find_package(NCCL REQUIRED)
+        add_compile_definitions(USE_NCCL=1)
+        target_link_libraries(infini_train_cuda_kernels PUBLIC nccl)
+    endif()
 endif()
 
+# ------------------------------------------------------------------------------
+# Main framework library
+# ------------------------------------------------------------------------------
+
+add_library(infini_train STATIC ${SRC})
+target_link_libraries(infini_train
+    PUBLIC
+        glog
+        gflags
+        infini_train_cpu_kernels
+)
+
 if(USE_CUDA)
-add_compile_definitions(USE_CUDA=1)
-enable_language(CUDA)
-find_package(CUDAToolkit REQUIRED)
-include_directories(${CUDAToolkit_INCLUDE_DIRS})
-
-# enable CUDA-related compilation options
-set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda --expt-relaxed-constexpr")
-file(GLOB_RECURSE CUDA_KERNELS ${PROJECT_SOURCE_DIR}/infini_train/src/*.cu)
-add_library(infini_train_cuda_kernels STATIC ${CUDA_KERNELS})
-set_target_properties(infini_train_cuda_kernels PROPERTIES CUDA_ARCHITECTURES "75;80;90")
-target_link_libraries(infini_train_cuda_kernels glog CUDA::cudart CUDA::cublas CUDA::cuda_driver)
-
-add_library(infini_train STATIC ${SRC})
-target_link_libraries(infini_train glog gflags "-Wl,--whole-archive" infini_train_cpu_kernels infini_train_cuda_kernels "-Wl,--no-whole-archive")
-
-if (USE_NCCL)
-message(STATUS "Add USE_NCCL, use NCCL with CUDA")
-list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
-find_package(NCCL REQUIRED)
-add_compile_definitions(USE_NCCL=1)
-target_link_libraries(infini_train nccl)
-endif()
-else()
-add_library(infini_train STATIC ${SRC})
-target_link_libraries(infini_train glog gflags "-Wl,--whole-archive" infini_train_cpu_kernels "-Wl,--no-whole-archive")
+    # infini_train contains cuda runtime wrappers (*.cc) like cuda_blas_handle.cc/cuda_guard.cc
+    # Those may need CUDA runtime/driver/cublas symbols at final link, so attach them here too.
+    target_link_libraries(infini_train
+        PUBLIC
+            infini_train_cuda_kernels
+            CUDA::cudart
+            CUDA::cublas
+            CUDA::cuda_driver
+    )
+
+    if(USE_NCCL)
+        # If your core library code also directly references NCCL symbols (not only kernels),
+        # keep this. Otherwise it's harmless.
+        target_link_libraries(infini_train PUBLIC nccl)
+    endif()
 endif()
 
+# ------------------------------------------------------------------------------
+# Helper: link libraries in a group to fix static lib one-pass resolution
+# (THIS is what fixes "undefined reference" from cuda_kernels -> core symbols)
+# ------------------------------------------------------------------------------
+function(link_infini_train_exe target_name)
+    if(USE_CUDA)
+        target_link_libraries(${target_name} PRIVATE
+            "-Wl,--start-group"
+            "-Wl,--whole-archive"
+            infini_train
+            infini_train_cpu_kernels
+            infini_train_cuda_kernels
+            "-Wl,--no-whole-archive"
+            "-Wl,--end-group"
+        )
+    else()
+        target_link_libraries(${target_name} PRIVATE
+            "-Wl,--start-group"
+            "-Wl,--whole-archive"
+            infini_train
+            infini_train_cpu_kernels
+            "-Wl,--no-whole-archive"
+            "-Wl,--end-group"
+        )
+    endif()
+endfunction()
+
+
+# ------------------------------------------------------------------------------
 # Examples
-add_executable(mnist example/mnist/main.cc example/mnist/dataset.cc example/mnist/net.cc)
-target_link_libraries(mnist infini_train)
+# ------------------------------------------------------------------------------
 
-add_executable(gpt2 example/gpt2/main.cc example/common/tiny_shakespeare_dataset.cc example/common/utils.cc example/gpt2/net.cc example/common/tokenizer.cc)
-target_link_libraries(gpt2 infini_train)
+add_executable(mnist
+    example/mnist/main.cc
+    example/mnist/dataset.cc
+    example/mnist/net.cc
+)
+link_infini_train_exe(mnist)
+
+add_executable(gpt2
+    example/gpt2/main.cc
+    example/common/tiny_shakespeare_dataset.cc
+    example/common/utils.cc
+    example/gpt2/net.cc
+    example/common/tokenizer.cc
+)
+link_infini_train_exe(gpt2)
+
+add_executable(llama3
+    example/llama3/main.cc
+    example/common/tiny_shakespeare_dataset.cc
+    example/common/utils.cc
+    example/llama3/net.cc
+    example/common/tokenizer.cc
+)
+link_infini_train_exe(llama3)
 
-add_executable(llama3 example/llama3/main.cc example/common/tiny_shakespeare_dataset.cc example/common/utils.cc example/llama3/net.cc example/common/tokenizer.cc)
-target_link_libraries(llama3 infini_train)
+# Tools
+add_subdirectory(tools/infini_run)
+set_target_properties(infini_run PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
 
+# Tests
 add_executable(test_hook test/hook/test_hook.cc)
 target_link_libraries(test_hook infini_train)
 
 add_executable(test_precision_check test/hook/test_precision_check.cc)
 target_link_libraries(test_precision_check infini_train)
-
-add_subdirectory(tools/infini_run)
-
-set_target_properties(infini_run PROPERTIES
-    RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}
-)
-
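A note on the new `link_infini_train_exe` helper above: static kernel libraries like these are typically pulled in through self-registering objects, and nothing in an executable references those objects by name, so a single-pass linker scan would silently drop them. Below is a generic illustration of that failure mode, using a hypothetical registrar; this is not InfiniTrain's actual kernel-registration API.

```cpp
// Generic self-registration pattern that static archives break without
// "-Wl,--whole-archive": no symbol in the executable references this
// translation unit, so the linker discards it from a static lib and the
// constructor below never runs.
// (Hypothetical registrar for illustration; not InfiniTrain's actual API.)
#include <functional>
#include <map>
#include <string>
#include <utility>

using Kernel = std::function<void()>;

// Meyers-singleton registry, constructed on first use.
inline std::map<std::string, Kernel> &KernelRegistry() {
    static std::map<std::string, Kernel> registry;
    return registry;
}

struct KernelRegistrar {
    KernelRegistrar(std::string name, Kernel kernel) {
        KernelRegistry().emplace(std::move(name), std::move(kernel));
    }
};

// Runs during static initialization, but only if the linker keeps this object file.
static KernelRegistrar add_cpu_registrar{"add.cpu", [] { /* launch CPU add */ }};
```

`-Wl,--whole-archive` forces every such object file into the binary, and `-Wl,--start-group`/`--end-group` makes the linker re-scan the archives, so cross-references among `infini_train`, `infini_train_cpu_kernels`, and `infini_train_cuda_kernels` resolve regardless of listing order.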
example/gpt2/main.cc

Lines changed: 6 additions & 5 deletions
@@ -10,6 +10,7 @@
 #include "glog/logging.h"
 
 #include "infini_train/include/autocast.h"
+#include "infini_train/include/core/device_guard.h"
 #include "infini_train/include/dataloader.h"
 #include "infini_train/include/device.h"
 #include "infini_train/include/nn/modules/loss.h"
@@ -272,7 +273,7 @@ void Train(const nn::parallel::Rank &rank) {
     loss_fn->To(device);
     LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training";
 
-    auto cuda_device = device->IsCUDA() ? dynamic_cast<const CudaDevice *>(device) : nullptr;
+    auto impl = core::GetDeviceGuardImpl(device.type());
 
     LOG(INFO) << "start training";
 
@@ -282,8 +283,8 @@ void Train(const nn::parallel::Rank &rank) {
 
         const bool last_step = step == FLAGS_num_iteration;
 
-        if (cuda_device) {
-            cuda_device->ResetMemPoolHighWatermarks();
+        if (device.IsCUDA()) {
+            impl->ResetMemPoolHighWatermarks(device);
         }
 
         const auto iter_start = std::chrono::high_resolution_clock::now();
@@ -375,8 +376,8 @@ void Train(const nn::parallel::Rank &rank) {
 
         if (rank.IsLastRank()) {
             size_t used_mb = 0, reserved_mb = 0;
-            if (cuda_device) {
-                std::tie(used_mb, reserved_mb) = cuda_device->GetMemPoolPeakMB();
+            if (device.IsCUDA()) {
+                std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);
             }
 
             LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "

example/gpt2/net.cc

Lines changed: 2 additions & 2 deletions
@@ -199,8 +199,8 @@ GPT2FirstStage::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>>
     int tp_rank = 0;
     if (tp_world_size > 1) {
         auto tp_group = nn::parallel::ProcessGroupFactory::Instance()->Get(
-            nn::parallel::GetTensorParallelProcessGroupName(device->rank().GlobalRank()));
-        tp_rank = tp_group->GetGroupRank(device->rank().GlobalRank());
+            nn::parallel::GetTensorParallelProcessGroupName(device.Rank().GlobalRank()));
+        tp_rank = tp_group->GetGroupRank(device.Rank().GlobalRank());
     }
     int64_t t_local = sequence_parallel_enabled ? x1->Dims()[1] / tp_world_size : x1->Dims()[1];
     int64_t start = sequence_parallel_enabled ? tp_rank * t_local : 0;
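
The change above is part of a broader shift visible throughout the commit: `device` goes from a pointer (`device->rank()`, `dynamic_cast<const CudaDevice *>(device)`) to a value with accessors (`device.Rank()`, `device.IsCUDA()`, `device.type()`). A sketch of such a handle follows; the field layout and the `Rank` type are assumptions, since only the accessors appear in the diff.

```cpp
// Sketch of Device as a small copyable handle rather than a pointer to a
// polymorphic object. Only Rank().GlobalRank(), IsCUDA(), and type() are
// visible in this commit; everything else is assumed.
namespace nn::parallel {
struct Rank {
    int GlobalRank() const { return global_rank; }
    int global_rank = 0;
};
} // namespace nn::parallel

enum class DeviceType : int { kCPU, kCUDA };

class Device {
public:
    Device(DeviceType type, nn::parallel::Rank rank) : type_(type), rank_(rank) {}
    DeviceType type() const { return type_; }
    bool IsCUDA() const { return type_ == DeviceType::kCUDA; }
    const nn::parallel::Rank &Rank() const { return rank_; }

private:
    DeviceType type_;
    nn::parallel::Rank rank_;
};

// Usage matching the new call sites: device.Rank().GlobalRank()
```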

example/llama3/main.cc

Lines changed: 12 additions & 11 deletions
@@ -8,27 +8,28 @@
 #include "glog/logging.h"
 
 #include "infini_train/include/autocast.h"
+#include "infini_train/include/core/device_guard.h"
 #include "infini_train/include/dataloader.h"
 #include "infini_train/include/device.h"
 #include "infini_train/include/nn/modules/loss.h"
 #include "infini_train/include/nn/modules/module.h"
 #include "infini_train/include/nn/parallel/ddp/distributed_data_parallel.h"
 #include "infini_train/include/nn/parallel/ddp/distributed_optimizer.h"
+#include "infini_train/include/nn/parallel/global.h"
 #include "infini_train/include/nn/parallel/parallel_functional.h"
 #include "infini_train/include/nn/parallel/pp/pipeline_parallel.h"
+#include "infini_train/include/nn/parallel/process_group.h"
 #include "infini_train/include/nn/parallel/rank.h"
 #include "infini_train/include/nn/parallel/reduce_op_type.h"
 #include "infini_train/include/nn/parallel/tensor_parallel.h"
-#include "infini_train/include/optimizer.h"
-#ifdef PROFILE_MODE
-#include "infini_train/include/profiler.h"
-#endif
-#include "infini_train/include/nn/parallel/global.h"
-#include "infini_train/include/nn/parallel/process_group.h"
 #include "infini_train/include/nn/parallel/utils.h"
+#include "infini_train/include/optimizer.h"
 #include "infini_train/include/utils/global_module_hook_registry.h"
 #include "infini_train/include/utils/precision_check_config.h"
 #include "infini_train/include/utils/precision_checker.h"
+#ifdef PROFILE_MODE
+#include "infini_train/include/profiler.h"
+#endif
 
 #include "example/common/tiny_shakespeare_dataset.h"
 #include "example/common/tokenizer.h"
@@ -250,16 +251,16 @@ void Train(const nn::parallel::Rank &rank) {
     loss_fn->To(device);
     LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training";
 
-    auto cuda_device = device->IsCUDA() ? dynamic_cast<const CudaDevice *>(device) : nullptr;
+    auto impl = core::GetDeviceGuardImpl(device.type());
 
     for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
         // Reset precision check counters at start of each iteration for file overwrite
         utils::PrecisionChecker::ResetCounters();
 
         const bool last_step = step == FLAGS_num_iteration;
 
-        if (cuda_device) {
-            cuda_device->ResetMemPoolHighWatermarks();
+        if (device.IsCUDA()) {
+            impl->ResetMemPoolHighWatermarks(device);
         }
 
         const auto iter_start = std::chrono::high_resolution_clock::now();
@@ -351,8 +352,8 @@ void Train(const nn::parallel::Rank &rank) {
 
         if (rank.IsLastRank()) {
             size_t used_mb = 0, reserved_mb = 0;
-            if (cuda_device) {
-                std::tie(used_mb, reserved_mb) = cuda_device->GetMemPoolPeakMB();
+            if (device.IsCUDA()) {
+                std::tie(used_mb, reserved_mb) = impl->GetMemPoolPeakMB(device);
             }
 
             LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "

infini_train/include/autocast.h

Lines changed: 4 additions & 9 deletions
@@ -3,15 +3,10 @@
 #include <string_view>
 #include <unordered_map>
 
-#include "common/common.h"
-#include "datatype.h"
-#include "device.h"
-#include "tensor.h"
-
-#ifdef USE_CUDA
-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
-#endif
+#include "infini_train/include/common/common.h"
+#include "infini_train/include/datatype.h"
+#include "infini_train/include/device.h"
+#include "infini_train/include/tensor.h"
 
 namespace infini_train {
 namespace {

infini_train/include/common/common.h

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,8 @@
 #pragma once
 
+#include <cstdint>
+#include <vector>
+
 #include "glog/logging.h"
 
 #include "infini_train/include/datatype.h"
