
Commit a4ced80

Merge pull request #205 from InfiniTensor/demo131
Demo-131 Cuda graph with optimized paged attention

2 parents 96ecf49 + 04c37f3

54 files changed

Lines changed: 2287 additions & 260 deletions
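Judging from the files shown below, this merge adds a nlohmann/json-backed ModelConfig/QuantConfig layer for parsing config.json, a device-dispatched fused KV-cache update (infinicore::op::kv_caching_) on NVIDIA, Iluvatar, and Metax, a GeneralCompiler that combines the static-batching and paged graph compilers, and new --qy and --ali device flags in the llama example.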

Some content is hidden: only a subset of the 54 changed files appears below.

.gitignore

Lines changed: 2 additions & 0 deletions
```diff
@@ -29,3 +29,5 @@ __pycache__/
 *.txt

 *.http
+
+*.nsys-rep
```

.gitmodules

Lines changed: 3 additions & 0 deletions
```diff
@@ -1,3 +1,6 @@
 [submodule "third_party/spdlog"]
 	path = third_party/spdlog
 	url = https://github.com/gabime/spdlog.git
+[submodule "third_party/json"]
+	path = third_party/json
+	url = https://github.com/nlohmann/json.git
```

README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -71,7 +71,7 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA
 - Single inference test (单次推理测试)
 - llama example (llama示例)
 ```bash
-python examples/llama.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_path=<path/to/model_dir>
+python examples/llama.py [--cpu | --nvidia | --qy | --metax | --moore | --iluvatar | --ali] --model_path=<path/to/model_dir>
 ```
 - For example (例如):
 ```bash
````

csrc/cache/kv_cache.cpp

Lines changed: 23 additions & 13 deletions
```diff
@@ -85,26 +85,36 @@ StaticKVCache::update(size_t layer_idx,
 
     auto batch_size = k->size(0);
     auto update_len = k->size(2);
-    size_t cache_pos = reinterpret_cast<int64_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0];
-    auto result_len = cache_pos + update_len;
-
-    ASSERT(result_len <= cache_len_);
 
     ASSERT_EQ(batch_size, rank_batch_size_);
 
     auto k_cache_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
     auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
 
-    auto k_cache_update = k_cache_layer->narrow({{2, cache_pos, update_len}});
-    auto v_cache_update = v_cache_layer->narrow({{2, cache_pos, update_len}});
-
-    k_cache_update->copy_from(k);
-    v_cache_update->copy_from(v);
-
-    auto k_total = k_cache_layer->narrow({{2, 0, result_len}});
-    auto v_total = v_cache_layer->narrow({{2, 0, result_len}});
+    auto device = k_cache_layer->device();
+
+    if (device.getType() == infinicore::Device::Type::NVIDIA
+        || device.getType() == infinicore::Device::Type::ILUVATAR
+        || device.getType() == infinicore::Device::Type::METAX) {
+        infinicore::op::kv_caching_(
+            k_cache_layer,
+            v_cache_layer,
+            k,
+            v,
+            past_sequence_lengths);
+    } else {
+        size_t cache_pos = reinterpret_cast<int64_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0];
+        auto result_len = cache_pos + update_len;
+        ASSERT(result_len <= cache_len_);
+
+        auto k_cache_update = k_cache_layer->narrow({{2, cache_pos, update_len}});
+        auto v_cache_update = v_cache_layer->narrow({{2, cache_pos, update_len}});
+
+        k_cache_update->copy_from(k);
+        v_cache_update->copy_from(v);
+    }
 
-    return {k_total, v_total};
+    return {k_cache_layer, v_cache_layer};
 }
 
 // ==========================
```
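The previous code always copied past_sequence_lengths to the CPU to read the cache position, a device-to-host sync that cannot be captured in a CUDA graph; on NVIDIA, Iluvatar, and Metax the update is now delegated to the fused infinicore::op::kv_caching_, which (presumably, given the commit title) reads the position on-device and stays graph-capturable. As a minimal CPU reference of the write semantics, assuming a row-major [seq, dim] layout (hypothetical helper, not part of infinicore):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Reference for one (layer, head) cache slab: copy `update_len` timesteps of
// `update` into `cache` starting at `cache_pos`. The fused op performs this
// write for both K and V without reading cache_pos on the host.
void kv_cache_update_ref(std::vector<float> &cache, std::size_t cache_len,
                         const std::vector<float> &update, std::size_t update_len,
                         std::size_t dim, std::size_t cache_pos) {
    assert(cache_pos + update_len <= cache_len);
    for (std::size_t t = 0; t < update_len; ++t) {
        for (std::size_t d = 0; d < dim; ++d) {
            cache[(cache_pos + t) * dim + d] = update[t * dim + d];
        }
    }
}
```

Note also that update() now returns the full {k_cache_layer, v_cache_layer} slabs instead of views truncated to result_len, consistent with the valid length no longer being read on the host.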

csrc/config/model_config.cpp

Lines changed: 88 additions & 0 deletions
```cpp
#include "model_config.hpp"

namespace infinilm::config {
ModelConfig::ModelConfig(const std::string &path) {
    std::ifstream file(path);
    if (file.is_open()) {
        file >> config_json;
        file.close();
    } else {
        throw std::runtime_error("Could not open config file: " + path);
    }
    this->quant_config = QuantConfig(config_json["quantization_config"]);
}

infinicore::quantization::QuantScheme
ModelConfig::get_quant_scheme() const {
    return quant_config.get_quant_scheme();
}

std::shared_ptr<infinicore::nn::RoPE::ScalingConfig>
ModelConfig::get_rope_scaling() const {
    if (!config_json.contains("rope_scaling") || config_json["rope_scaling"].is_null()) {
        return nullptr;
    }

    const auto &rope_scaling = config_json["rope_scaling"];
    if (!rope_scaling.is_object()) {
        throw std::runtime_error("rope_scaling must be an object");
    }

    if (!rope_scaling.contains("type")) {
        throw std::runtime_error("rope_scaling must contain 'type' field");
    }

    std::string type_str = rope_scaling["type"].get<std::string>();
    if (type_str == "longrope") {
        // Required fields for LongRopeConfig
        if (!rope_scaling.contains("short_factor") || !rope_scaling.contains("long_factor") || !rope_scaling.contains("original_max_position_embeddings")) {
            throw std::runtime_error(
                "LongRopeConfig requires 'short_factor', 'long_factor', and 'original_max_position_embeddings'");
        }

        auto short_factor = rope_scaling["short_factor"].get<std::vector<float>>();
        auto long_factor = rope_scaling["long_factor"].get<std::vector<float>>();
        size_t original_max_position_embeddings = rope_scaling["original_max_position_embeddings"].get<size_t>();

        float factor = 1.0f;
        if (rope_scaling.contains("factor")) {
            factor = rope_scaling["factor"].get<float>();
        }

        return std::make_shared<infinicore::nn::RoPE::LongRopeConfig>(
            std::move(short_factor),
            std::move(long_factor),
            original_max_position_embeddings,
            factor);
    } else if (type_str == "default" || type_str == "none") {
        // Default scaling, no scaling applied
        return nullptr;
    } else {
        throw std::runtime_error("Unsupported rope_scaling type: " + type_str);
    }
}

infinicore::DataType
ModelConfig::get_dtype() const {
    try {
        std::string dtype_str = this->get<std::string>("torch_dtype");
        if (dtype_str == "float32") {
            return infinicore::DataType::F32;
        } else if (dtype_str == "float16") {
            return infinicore::DataType::F16;
        } else if (dtype_str == "bfloat16") {
            return infinicore::DataType::BF16;
        } else if (dtype_str == "int8") {
            return infinicore::DataType::I8;
        } else {
            throw std::runtime_error("Unsupported dtype string: " + dtype_str);
        }
    } catch (const std::exception &e) {
        throw std::runtime_error("Error getting dtype from config: " + std::string(e.what()));
    }
}
} // namespace infinilm::config
```
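For reference, a sketch of the "rope_scaling" object shape that get_rope_scaling() accepts on the longrope path; the field values here are illustrative, not taken from this commit:

```cpp
#include <nlohmann/json.hpp>

// Illustrative input for the longrope branch: short_factor/long_factor are
// per-dimension scale lists, and factor defaults to 1.0 when absent.
const auto rope_scaling_example = nlohmann::json::parse(R"({
    "type": "longrope",
    "short_factor": [1.0, 1.0, 1.2],
    "long_factor": [1.0, 2.0, 4.0],
    "original_max_position_embeddings": 4096,
    "factor": 8.0
})");
```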

csrc/config/model_config.hpp

Lines changed: 71 additions & 0 deletions
```cpp
#pragma once

#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"
#include "quant_config.hpp"
#include <fstream>
#include <string>

namespace infinilm::config {
class ModelConfig {
    // Model config is implemented using nlohmann/json and is primarily used for advanced configuration
    // beyond the standard model config. It is initialized via ModelConfig(const std::string& path)
    // and passed through the InferEngine during inference.
public:
    ModelConfig() = default;
    // Not Implemented
    // ModelConfig(const nlohmann::json &json) : config_json(json) {};
    ModelConfig(const std::string &path);

    // Template function to get a value by key with type safety
    template <typename T>
    T get(const std::string &key) const {
        if (!config_json.contains(key)) {
            throw std::out_of_range("Key '" + key + "' not found in config.");
        }
        try {
            return config_json.at(key).get<T>();
        } catch (const nlohmann::json::type_error &e) {
            throw std::runtime_error("Type conversion failed for key '" + key + "': " + std::string(e.what()));
        }
    }

    template <typename T>
    T get_or(const std::string &key, const T &default_value) const {
        if (!config_json.contains(key) || config_json.at(key).is_null()) {
            return default_value;
        }
        try {
            return config_json.at(key).get<T>();
        } catch (const nlohmann::json::type_error &) {
            // If type conversion fails, return default value
            return default_value;
        }
    }

    size_t get_kv_dim() const {
        return get<size_t>("hidden_size") * get<size_t>("num_key_value_heads") / get<size_t>("num_attention_heads");
    }

    size_t get_head_dim() const {
        if (config_json.contains("head_dim")) {
            return get<size_t>("head_dim");
        }
        return get<size_t>("hidden_size") / get<size_t>("num_attention_heads");
    }

    QuantConfig get_quant_config() const {
        return quant_config;
    }

    std::shared_ptr<infinicore::quantization::BaseQuantization> get_quantization_method() const {
        return quant_config.get_quantization_method();
    }

    infinicore::DataType get_dtype() const;
    infinicore::quantization::QuantScheme get_quant_scheme() const;
    std::shared_ptr<infinicore::nn::RoPE::ScalingConfig> get_rope_scaling() const;

private:
    nlohmann::json config_json;
    QuantConfig quant_config;
};
} // namespace infinilm::config
```
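A hypothetical usage sketch of the typed accessors; the key names follow common Hugging Face config.json conventions and are not taken from this commit:

```cpp
#include "model_config.hpp"

void example() {
    infinilm::config::ModelConfig cfg("/path/to/model_dir/config.json");

    // get<T> throws std::out_of_range if the key is missing;
    // get_or<T> falls back to the default on a missing or null key.
    auto n_layers = cfg.get<size_t>("num_hidden_layers");
    auto rms_eps = cfg.get_or<float>("rms_norm_eps", 1e-6f);

    // head_dim falls back to hidden_size / num_attention_heads when absent.
    auto head_dim = cfg.get_head_dim();
    auto kv_dim = cfg.get_kv_dim();

    // Maps the "torch_dtype" string to infinicore::DataType.
    auto dtype = cfg.get_dtype();
    (void)n_layers; (void)rms_eps; (void)head_dim; (void)kv_dim; (void)dtype;
}
```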

csrc/config/quant_config.cpp

Lines changed: 27 additions & 0 deletions
```cpp
#include "quant_config.hpp"

namespace infinilm::config {
QuantConfig::QuantConfig(const nlohmann::json &json) : quantization_config(json) {
    this->quantization_method = get_quantization_method();
}

std::shared_ptr<infinicore::quantization::BaseQuantization>
QuantConfig::get_quantization_method() const {
    if (quantization_config.is_null()) {
        return std::make_shared<infinicore::quantization::NoneQuantization>(quantization_config); // Default case if no matching scheme
    }

    // Determine the quantization scheme from the JSON config.
    // Add other schemes as needed.
    if (quantization_config["quant_method"] == "compressed-tensors") {
        return std::make_shared<infinicore::quantization::CompressedTensors>(quantization_config);
    } else if (quantization_config["quant_method"] == "awq") {
        return std::make_shared<infinicore::quantization::AWQ>(quantization_config);
    } else {
        return std::make_shared<infinicore::quantization::NoneQuantization>(quantization_config);
    }
}
} // namespace infinilm::config
```
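For orientation, an illustrative "quantization_config" input; the field values are hypothetical, and only quant_method drives the dispatch above:

```cpp
#include <nlohmann/json.hpp>

// "compressed-tensors" selects CompressedTensors, "awq" selects AWQ, and
// anything else (including null) falls back to NoneQuantization.
const auto quant_cfg_example = nlohmann::json::parse(R"({
    "quant_method": "awq",
    "bits": 4,
    "group_size": 128
})");
```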

csrc/config/quant_config.hpp

Lines changed: 30 additions & 0 deletions
```cpp
#pragma once
// #include "../quantization/quantization.hpp"
#include "infinicore/quantization.hpp"
#include "nlohmann/json.hpp"

namespace infinilm::config {

class QuantConfig {
    // QuantConfig is used to store and parse the "quantization" field from config.json.
    // This is currently a basic version and will be extended in the future.
public:
    QuantConfig() = default;
    QuantConfig(const nlohmann::json &json);

    std::shared_ptr<infinicore::quantization::BaseQuantization> get_quantization_method() const;

    infinicore::quantization::QuantScheme get_quant_scheme() const {
        if (quantization_method != nullptr) {
            return quantization_method->get_quant_scheme();
        } else {
            return infinicore::quantization::QuantScheme::NONE;
        }
    }

private:
    nlohmann::json quantization_config;
    std::shared_ptr<infinicore::quantization::BaseQuantization> quantization_method;
};

} // namespace infinilm::config
```
general_compiler.cpp

Lines changed: 26 additions & 0 deletions

```cpp
#include "general_compiler.hpp"

namespace infinilm::engine {
GeneralCompiler::GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier) : GraphCompiler(model, barrier) {
    static_batching_compiler_ = std::make_unique<StaticBatchingCompiler>(model_, barrier);
    paged_compiler_ = std::make_unique<PagedCompiler>(model_, barrier);
}

void GeneralCompiler::compile() {
    static_batching_compiler_->compile();
    paged_compiler_->compile();
}

GeneralCompiler::Compiled GeneralCompiler::get_compiled(const InfinilmModel::Input &input) {
    GeneralCompiler::Compiled result = {nullptr, nullptr};

    // Try each compiler; return the first valid result.
    result = static_batching_compiler_->get_compiled(input);
    if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) {
        return result;
    }
    result = paged_compiler_->get_compiled(input);
    return result;
}

} // namespace infinilm::engine
```
general_compiler.hpp

Lines changed: 19 additions & 0 deletions

```cpp
#pragma once

#include "paged_compiler.hpp"
#include "static_batching_compiler.hpp"

namespace infinilm::engine {
class GeneralCompiler : public GraphCompiler {
public:
    GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);

    void compile() override;

    Compiled get_compiled(const InfinilmModel::Input &input) override;

private:
    std::unique_ptr<StaticBatchingCompiler> static_batching_compiler_;
    std::unique_ptr<PagedCompiler> paged_compiler_;
};
} // namespace infinilm::engine
```
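GeneralCompiler compiles both graph flavors up front and returns whichever sub-compiler yields a complete (non-null) pair for a given input. A hypothetical call site, with the surrounding types assumed from headers not shown in this diff:

```cpp
#include "general_compiler.hpp"
#include <tuple>

namespace infinilm::engine {
// Sketch only: InfinilmModel, RankBarrier, and Compiled's element types are
// assumed from the declarations above, not defined in this commit's excerpt.
void run_example(const std::shared_ptr<InfinilmModel> &model,
                 RankBarrier *barrier,
                 const InfinilmModel::Input &input) {
    GeneralCompiler compiler(model, barrier);
    compiler.compile(); // builds both the static-batching and paged graphs

    auto compiled = compiler.get_compiled(input);
    if (std::get<0>(compiled) != nullptr && std::get<1>(compiled) != nullptr) {
        // dispatch on whichever graph matched this input's batching mode
    }
}
} // namespace infinilm::engine
```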
