
Commit 36d173e

Refactor: add new API alongside legacy interfaces with deprecation warnings

1 parent ea019e0 commit 36d173e

23 files changed: 663 additions & 57 deletions

=0.34.0,

Whitespace-only changes.

csrc/engine/infer_engine.cpp

Lines changed: 41 additions & 4 deletions

@@ -1,27 +1,65 @@
 #include "infer_engine.hpp"
 #include "spdlog/spdlog.h"
-#include <iostream>
 
 namespace infinilm::engine {
 
 //------------------------------------------------------
 // Constructor
 //------------------------------------------------------
+/**
+ * @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
+ *
+ * ⚠️ DEVELOPMENT POLICY:
+ * - NO new development or feature additions permitted on this interface
+ * - Only critical bug fixes (security/stability) allowed until removal
+ * - All new code MUST migrate to the polymorphic overload below
+ *
+ * Replacement: Use the polymorphic overload of this same function name with updated signature
+ * Reason: Legacy signature lacks support for dynamic quantization modes.
+ * Removal target: v0.2.0 (Q2 2026)
+ */
 InferEngine::InferEngine(
+    const InfinilmModel::Config &config,
     const distributed::DistConfig &distributed_config,
     infinicore::Device::Type device_type,
     const cache::CacheConfig *cache_config,
+    bool enable_graph_compiling) // Changed parameter
+    : communication_group_(distributed_config, device_type),
+      legacy_model_config_(config) {
+
+    if (cache_config != nullptr) {
+        cache_config_ = cache_config->unique_copy();
+    }
+    // Create one RankWorker per rank
+    int world_size = communication_group_.get_world_size();
+    barrier_ = std::make_unique<RankBarrier>((size_t)world_size);
+    workers_.reserve(world_size);
+    for (int r = 0; r < world_size; ++r) {
+        workers_.emplace_back(std::make_unique<RankWorker>(
+            legacy_model_config_,
+            communication_group_.get_rank_info(r),
+            cache_config_ != nullptr ? cache_config_.get() : nullptr,
+            barrier_.get(),
+            enable_graph_compiling));
+    }
+
+    // Compile the model on all workers
+    this->compile();
+}
+
+InferEngine::InferEngine(
     const std::string &model_path,
+    const distributed::DistConfig &distributed_config,
+    infinicore::Device::Type device_type,
+    const cache::CacheConfig *cache_config,
     bool enable_graph_compiling) // Changed parameter
     : communication_group_(distributed_config, device_type) {
-
     if (cache_config != nullptr) {
         cache_config_ = cache_config->unique_copy();
    }
 
     // Load model config if model_path is provided, model_path must be valid, and config.json exists
     this->model_config_ = std::make_shared<infinilm::config::ModelConfig>(model_path + "/config.json");
-
     // Create one RankWorker per rank
     int world_size = communication_group_.get_world_size();
     barrier_ = std::make_unique<RankBarrier>((size_t)world_size);
@@ -34,7 +72,6 @@ InferEngine::InferEngine(
             barrier_.get(),
             enable_graph_compiling));
     }
-
     // Compile the model on all workers
     this->compile();
 }

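Taken together, the two constructors above define the migration path: callers that used to hand the engine a prebuilt InfinilmModel::Config should switch to passing a model directory and let the engine load <model_path>/config.json itself. A minimal caller-side sketch of both call sites, assuming the headers in this commit and relying on the defaulted arguments declared in infer_engine.hpp (the config value, include path, namespace qualification, and model path are hypothetical):

#include "infer_engine.hpp" // include path assumed

int main() {
    // Deprecated: caller-constructed model config (scheduled for removal in v0.2.0).
    infinilm::InfinilmModel::Config legacy_config; // hypothetical: qualification and field values depend on the model
    infinilm::engine::InferEngine legacy_engine(legacy_config);

    // Replacement: the engine reads <model_path>/config.json itself.
    infinilm::engine::InferEngine engine("/models/example-llm"); // hypothetical path
    return 0;
}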
csrc/engine/infer_engine.hpp

Lines changed: 20 additions & 0 deletions

@@ -20,11 +20,30 @@ class InferEngine {
     using Output = RankWorker::Output;
 
     // Updated constructor: accept CacheConfig instead of CacheType
+    /**
+     * @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
+     *
+     * ⚠️ DEVELOPMENT POLICY:
+     * - NO new development or feature additions permitted on this interface
+     * - Only critical bug fixes (security/stability) allowed until removal
+     * - All new code MUST migrate to the polymorphic overload below
+     *
+     * Replacement: Use the polymorphic overload of this same function name with updated signature
+     * Reason: Legacy signature lacks support for dynamic quantization modes.
+     * Removal target: v0.2.0 (Q2 2026)
+     */
     InferEngine(
+        const InfinilmModel::Config &config,
         const distributed::DistConfig &distributed_config = distributed::DistConfig(),
         infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
         const cache::CacheConfig *cache_config = nullptr,
+        bool enable_graph_compiling = false);
+
+    InferEngine(
         const std::string &model_path = "",
+        const distributed::DistConfig &distributed_config = distributed::DistConfig(),
+        infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
+        const cache::CacheConfig *cache_config = nullptr,
         bool enable_graph_compiling = false);
 
     // Load a parameter to all workers (each can extract its shard inside RankWorker)
@@ -52,6 +71,7 @@ class InferEngine {
     std::unique_ptr<RankBarrier> barrier_;
     distributed::CommunicationGroup communication_group_;
     std::unique_ptr<cache::CacheConfig> cache_config_;
+    const InfinilmModel::Config &legacy_model_config_ = InfinilmModel::Config();
    std::shared_ptr<infinilm::config::ModelConfig> model_config_;
 };
 
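Since every parameter after the first is defaulted in both declarations, overload resolution hinges on the first argument alone: an InfinilmModel::Config selects the deprecated constructor, while a std::string (including a string literal, via implicit conversion) selects its replacement. A brief illustration under the same hypothetical assumptions as the sketch above:

infinilm::InfinilmModel::Config cfg;          // hypothetical config value
infinilm::engine::InferEngine a(cfg);         // first argument is a Config -> deprecated overload
infinilm::engine::InferEngine b("/models/x"); // const char* converts to std::string -> new overload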
csrc/engine/rank_worker.cpp

Lines changed: 37 additions & 1 deletion

@@ -4,12 +4,48 @@
 
 #include "infinicore/ops.hpp"
 
-#include <iostream>
 #include <spdlog/spdlog.h>
 #include <stdexcept>
 
 namespace infinilm::engine {
 
+/**
+ * @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
+ *
+ * ⚠️ DEVELOPMENT POLICY:
+ * - NO new development or feature additions permitted on this interface
+ * - Only critical bug fixes (security/stability) allowed until removal
+ * - All new code MUST migrate to the polymorphic overload below
+ *
+ * Replacement: Use the polymorphic overload of this same function name with updated signature
+ * Reason: Legacy signature lacks support for dynamic quantization modes.
+ * Removal target: v0.2.0 (Q2 2026)
+ */
+RankWorker::RankWorker(const InfinilmModel::Config &model_config,
+                       const distributed::RankInfo &rank_info,
+                       const cache::CacheConfig *cache_config,
+                       RankBarrier *barrier,
+                       bool enable_graph_compiling)
+    : legacy_model_config_(model_config),
+      rank_info_(rank_info),
+      enable_graph_compiling_(enable_graph_compiling),
+      job_cmd_(Command::INIT),
+      has_job_(false),
+      job_done_(false),
+      should_exit_(false),
+      init_done_(false),
+      barrier_(barrier) {
+    if (cache_config != nullptr) {
+        pending_cache_config_ = cache_config->unique_copy();
+    }
+    // start the thread
+    thread_ = std::thread(&RankWorker::thread_loop, this);
+
+    // Wait until the worker thread finishes initialization (model created)
+    std::unique_lock<std::mutex> lk(mutex_);
+    cv_.wait(lk, [&] { return init_done_; });
+}
+
 RankWorker::RankWorker(
     std::shared_ptr<infinilm::config::ModelConfig> model_config,
     const distributed::RankInfo &rank_info,

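The added constructor spawns thread_ and then blocks on cv_ until the worker thread flips init_done_, a standard condition-variable startup handshake. The thread_loop side falls outside this hunk; the following is a self-contained sketch of the same pattern, with illustrative names rather than the repository's:

#include <condition_variable>
#include <mutex>
#include <thread>

// Stand-alone version of the constructor/worker-thread handshake shown above.
class Worker {
public:
    Worker() {
        thread_ = std::thread(&Worker::thread_loop, this);
        // Block until the worker thread reports that initialization finished.
        std::unique_lock<std::mutex> lk(mutex_);
        cv_.wait(lk, [&] { return init_done_; });
    }
    ~Worker() { thread_.join(); }

private:
    void thread_loop() {
        // ... expensive per-rank initialization would happen here (e.g. model construction) ...
        {
            std::lock_guard<std::mutex> lk(mutex_);
            init_done_ = true;
        }
        cv_.notify_all(); // wake the constructor
    }

    std::thread thread_;
    std::mutex mutex_;
    std::condition_variable cv_;
    bool init_done_ = false;
};

int main() { Worker w; }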
csrc/engine/rank_worker.hpp

Lines changed: 8 additions & 2 deletions

@@ -57,6 +57,12 @@ class RankWorker {
         infinicore::Tensor output_ids;
     };
 
+    RankWorker(const InfinilmModel::Config &model_config,
+               const distributed::RankInfo &rank_info,
+               const cache::CacheConfig *cache_config,
+               RankBarrier *barrier,
+               bool enable_graph_compiling);
+
     RankWorker(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                const distributed::RankInfo &rank_info,
                const cache::CacheConfig *cache_config,
@@ -95,11 +101,11 @@ class RankWorker {
 
 private:
     // Worker properties
-    // const InfinilmModel::Config &model_config_;
+    const InfinilmModel::Config &legacy_model_config_ = InfinilmModel::Config();
+    std::shared_ptr<infinilm::config::ModelConfig> model_config_;
     distributed::RankInfo rank_info_;
     std::shared_ptr<InfinilmModel> model_;
     std::shared_ptr<cache::Cache> cache_;
-    std::shared_ptr<infinilm::config::ModelConfig> model_config_;
 
     // Graph Compiling
     bool enable_graph_compiling_;

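One thing to keep in mind when reviewing the reshuffled member list: C++ initializes non-static data members in declaration order, not in the order they appear in a constructor's initializer list, so moving model_config_ up changes when it is initialized relative to rank_info_ and the other members. A compact stand-alone illustration (hypothetical class, not from the repository):

#include <cstdio>

struct Demo {
    // b_ is listed first in the initializer list, but a_ is still
    // initialized first because it is declared first (compilers
    // typically flag this mismatch with -Wreorder).
    Demo() : b_((std::puts("init b_"), 2)), a_((std::puts("init a_"), 1)) {}
    int a_;
    int b_;
};

int main() { Demo d; } // prints "init a_" then "init b_"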
csrc/layers/fused_linear.cpp

Lines changed: 103 additions & 14 deletions

@@ -6,39 +6,102 @@ namespace infinilm::layers {
 // ---------------------------------------------------------
 // QKV Parallel Linear
 // ---------------------------------------------------------
+/**
+ * @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
+ *
+ * ⚠️ DEVELOPMENT POLICY:
+ * - NO new development or feature additions permitted on this interface
+ * - Only critical bug fixes (security/stability) allowed until removal
+ * - All new code MUST migrate to the polymorphic overload below
+ *
+ * Replacement: Use the polymorphic overload of this same function name with updated signature
+ * Reason: Legacy signature lacks support for dynamic quantization modes.
+ * Removal target: v0.2.0 (Q2 2026)
+ */
 QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
                                      size_t head_dim,
                                      size_t num_q_head,
                                      size_t num_kv_head,
                                      bool bias,
                                      const infinicore::DataType &dtype,
                                      const infinicore::Device &device,
-                                     engine::distributed::RankInfo rank_info,
-                                     std::optional<infinicore::nn::QuantScheme> quant_scheme)
+                                     engine::distributed::RankInfo rank_info)
     : QKVParallelLinear(hidden_size,
                         head_dim, head_dim, head_dim,
                         num_q_head, num_kv_head, num_kv_head,
                         bias, bias, bias,
-                        dtype, device, rank_info,
-                        quant_scheme) {}
+                        dtype, device, rank_info) {}
 
 QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
                                      size_t q_dim, size_t k_dim, size_t v_dim,
                                      size_t num_q_head, size_t num_k_head, size_t num_v_head,
                                      bool q_bias, bool k_bias, bool v_bias,
                                      const infinicore::DataType &dtype,
                                      const infinicore::Device &device,
-                                     engine::distributed::RankInfo rank_info,
-                                     std::optional<infinicore::nn::QuantScheme> quant_scheme)
+                                     engine::distributed::RankInfo rank_info)
     : infinicore::nn::ColumnParallelLinear(
           hidden_size,
           num_q_head * q_dim + num_k_head * k_dim + num_v_head * v_dim,
           (q_bias || k_bias || v_bias),
           dtype,
           device,
           rank_info.tp_rank,
-          rank_info.tp_size,
-          quant_scheme),
+          rank_info.tp_size),
+      q_dim_(q_dim),
+      k_dim_(k_dim),
+      v_dim_(v_dim),
+      num_q_head_(num_q_head),
+      num_k_head_(num_k_head),
+      num_v_head_(num_v_head),
+      q_bias_(q_bias),
+      k_bias_(k_bias),
+      v_bias_(v_bias) {
+    if (num_q_head % tp_size_ != 0 || num_k_head % tp_size_ != 0 || num_v_head % tp_size_ != 0) {
+        throw std::runtime_error("QKVParallelLinear: num_[q|k|v]_head must be divisible by tp_size");
+    }
+
+    if ((q_bias_ != k_bias_) || (k_bias_ != v_bias_)) {
+        throw std::runtime_error("q_bias, k_bias, v_bias must all match");
+    }
+
+    q_out_size_ = num_q_head_ * q_dim_ / tp_size_;
+    k_out_size_ = num_k_head_ * k_dim_ / tp_size_;
+    v_out_size_ = num_v_head_ * v_dim_ / tp_size_;
+}
+
+QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
+                                     size_t head_dim,
+                                     size_t num_q_head,
+                                     size_t num_kv_head,
+                                     infinicore::nn::QuantScheme quant_scheme,
+                                     bool bias,
+                                     const infinicore::DataType &dtype,
+                                     const infinicore::Device &device,
+                                     engine::distributed::RankInfo rank_info)
+    : QKVParallelLinear(hidden_size,
+                        head_dim, head_dim, head_dim,
+                        num_q_head, num_kv_head, num_kv_head,
+                        bias, bias, bias,
+                        quant_scheme,
+                        dtype, device, rank_info) {}
+
+QKVParallelLinear::QKVParallelLinear(size_t hidden_size,
+                                     size_t q_dim, size_t k_dim, size_t v_dim,
+                                     size_t num_q_head, size_t num_k_head, size_t num_v_head,
+                                     bool q_bias, bool k_bias, bool v_bias,
+                                     infinicore::nn::QuantScheme quant_scheme,
+                                     const infinicore::DataType &dtype,
+                                     const infinicore::Device &device,
+                                     engine::distributed::RankInfo rank_info)
+    : infinicore::nn::ColumnParallelLinear(
+          hidden_size,
+          num_q_head * q_dim + num_k_head * k_dim + num_v_head * v_dim,
+          quant_scheme,
+          (q_bias || k_bias || v_bias),
+          dtype,
+          device,
+          rank_info.tp_rank,
+          rank_info.tp_size),
       q_dim_(q_dim),
       k_dim_(k_dim),
       v_dim_(v_dim),
@@ -141,18 +204,44 @@ bool QKVParallelLinear::has_v_bias() const { return v_bias_; }
 // ---------------------------------------------------------
 // Gate-Up Parallel Linear
 // ---------------------------------------------------------
+/**
+ * @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
+ *
+ * ⚠️ DEVELOPMENT POLICY:
+ * - NO new development or feature additions permitted on this interface
+ * - Only critical bug fixes (security/stability) allowed until removal
+ * - All new code MUST migrate to the polymorphic overload below
+ *
+ * Replacement: Use the polymorphic overload of this same function name with updated signature
+ * Reason: Legacy signature lacks support for dynamic quantization modes.
+ * Removal target: v0.2.0 (Q2 2026)
+ */
 GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool bias,
                                            const infinicore::DataType &dtype, const infinicore::Device &device,
-                                           engine::distributed::RankInfo rank_info,
-                                           std::optional<infinicore::nn::QuantScheme> quant_scheme)
-    : GateUpParallelLinear(hidden_size, intermediate_size, bias, bias, dtype, device, rank_info, quant_scheme) {
+                                           engine::distributed::RankInfo rank_info)
+    : GateUpParallelLinear(hidden_size, intermediate_size, bias, bias, dtype, device, rank_info) {
+}
+
+GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool gate_bias, bool up_bias,
+                                           const infinicore::DataType &dtype, const infinicore::Device &device,
+                                           engine::distributed::RankInfo rank_info)
+    : infinicore::nn::ColumnParallelLinear(hidden_size, intermediate_size * 2, gate_bias || up_bias, dtype, device, rank_info.tp_rank, rank_info.tp_size), gate_bias_(gate_bias), up_bias_(up_bias) {
+    if (gate_bias_ != up_bias_) {
+        throw std::runtime_error("Not supported yet: gate_bias and up_bias should be given at the same time");
+    }
+}
+
+GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, infinicore::nn::QuantScheme quant_scheme, bool bias,
+                                           const infinicore::DataType &dtype, const infinicore::Device &device,
+                                           engine::distributed::RankInfo rank_info)
+    : GateUpParallelLinear(hidden_size, intermediate_size, bias, bias, quant_scheme, dtype, device, rank_info) {
 }
 
 GateUpParallelLinear::GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool gate_bias, bool up_bias,
+                                           infinicore::nn::QuantScheme quant_scheme,
                                            const infinicore::DataType &dtype, const infinicore::Device &device,
-                                           engine::distributed::RankInfo rank_info,
-                                           std::optional<infinicore::nn::QuantScheme> quant_scheme)
-    : infinicore::nn::ColumnParallelLinear(hidden_size, intermediate_size * 2, gate_bias || up_bias, dtype, device, rank_info.tp_rank, rank_info.tp_size, quant_scheme), gate_bias_(gate_bias), up_bias_(up_bias) {
+                                           engine::distributed::RankInfo rank_info)
+    : infinicore::nn::ColumnParallelLinear(hidden_size, intermediate_size * 2, quant_scheme, gate_bias || up_bias, dtype, device, rank_info.tp_rank, rank_info.tp_size), gate_bias_(gate_bias), up_bias_(up_bias) {
     if (gate_bias_ != up_bias_) {
         throw std::runtime_error("Not supported yet: gate_bias and up_bias should be given at the same time");
    }