Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 20 additions & 46 deletions src/predictor/cpu_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -572,31 +572,6 @@ void PredictBatchByBlockKernel(DataView const &batch, HostModel const &model,
});
}

float FillNodeMeanValues(tree::ScalarTreeView const &tree, bst_node_t nidx,
std::vector<float> *mean_values) {
float result;
auto &node_mean_values = *mean_values;
if (tree.IsLeaf(nidx)) {
result = tree.LeafValue(nidx);
} else {
result = FillNodeMeanValues(tree, tree.LeftChild(nidx), mean_values) *
tree.Stat(tree.LeftChild(nidx)).sum_hess;
result += FillNodeMeanValues(tree, tree.RightChild(nidx), mean_values) *
tree.Stat(tree.RightChild(nidx)).sum_hess;
result /= tree.Stat(nidx).sum_hess;
}
node_mean_values[nidx] = result;
return result;
}

void FillNodeMeanValues(tree::ScalarTreeView const &tree, std::vector<float> *mean_values) {
auto n_nodes = tree.Size();
if (static_cast<decltype(n_nodes)>(mean_values->size()) == n_nodes) {
return;
}
mean_values->resize(n_nodes);
FillNodeMeanValues(tree, 0, mean_values);
}
} // anonymous namespace

/**
Expand Down Expand Up @@ -920,7 +895,6 @@ class CPUPredictor : public Predictor {
void PredictContributionKernel(DataView batch, const MetaInfo &info, HostModel const &h_model,
linalg::VectorView<float const> base_score,
std::vector<bst_float> const *tree_weights,
std::vector<std::vector<float>> *mean_values,
ThreadTmp<1> *feat_vecs, std::vector<bst_float> *contribs,
bool approximate, int condition,
unsigned condition_feature) const {
Expand All @@ -932,6 +906,15 @@ class CPUPredictor : public Predictor {
auto device = ctx_->Device().IsSycl() ? DeviceOrd::CPU() : ctx_->Device();
auto base_margin = info.base_margin_.View(device);

// Preprocess every tree in the ensemble
std::vector<PreprocessedLeaf> preprocessed_leaves;
for (bst_tree_t j = 0; j < h_model.tree_end; ++j) {
auto sc_tree = std::get<tree::ScalarTreeView>(h_model.Trees()[j]);
auto new_leaves = PreprocessTree(j, sc_tree);
preprocessed_leaves.insert(preprocessed_leaves.end(), new_leaves.begin(),
new_leaves.end());
}

// parallel over local batch
common::ParallelFor(batch.Size(), this->ctx_->Threads(), [&](auto i) {
auto row_idx = batch.base_rowid + i;
Expand All @@ -944,26 +927,22 @@ class CPUPredictor : public Predictor {
for (bst_target_t gid = 0; gid < n_groups; ++gid) {
float *p_contribs = &(*contribs)[(row_idx * n_groups + gid) * ncolumns];
batch.Fill(i, &feats);
// calculate contributions
for (bst_tree_t j = 0; j < h_model.tree_end; ++j) {
auto *tree_mean_values = &mean_values->at(j);
std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0);
if (h_model.tree_groups[j] != gid) {
for(auto& leaf : preprocessed_leaves) {
if (h_model.tree_groups[leaf.tree_idx] != gid) {
continue;
}
auto sc_tree = std::get<tree::ScalarTreeView>(h_model.Trees()[j]);
if (!approximate) {
CalculateContributions(sc_tree, feats, tree_mean_values, &this_tree_contribs[0],
condition, condition_feature);
} else {
CalculateContributionsApprox(sc_tree, feats, tree_mean_values, &this_tree_contribs[0]);
}
const auto& tree= std::get<tree::ScalarTreeView>(h_model.Trees()[leaf.tree_idx]);
auto path = ExtractBinaryPath(tree, feats, leaf.leaf_path);
auto tree_weight = (tree_weights == nullptr ? 1 : (*tree_weights)[leaf.tree_idx]);
for (size_t ci = 0; ci < ncolumns; ++ci) {
p_contribs[ci] +=
this_tree_contribs[ci] * (tree_weights == nullptr ? 1 : (*tree_weights)[j]);
if(leaf.S.count(ci) > 0){
p_contribs[ci] += leaf.S[ci][path] * tree_weight;
}
}
p_contribs[ncolumns - 1] += leaf.null_coalition_weight * tree_weight;
}
feats.Drop();

// add base margin to BIAS
if (base_margin.Size() != 0) {
CHECK_EQ(base_margin.Shape(1), n_groups);
Expand Down Expand Up @@ -1102,19 +1081,14 @@ class CPUPredictor : public Predictor {
// make sure contributions is zeroed, we could be reusing a previously
// allocated one
std::fill(contribs.begin(), contribs.end(), 0);
// initialize tree node mean values
std::vector<std::vector<float>> mean_values(ntree_limit);
common::ParallelFor(ntree_limit, n_threads, [&](bst_omp_uint i) {
FillNodeMeanValues(model.trees[i]->HostScView(), &(mean_values[i]));
});

auto const h_model =
HostModel{DeviceOrd::CPU(), model, 0, ntree_limit, &this->mu_, CopyViews{}};
LaunchPredict(this->ctx_, p_fmat, model, [&](auto &&policy) {
policy.ForEachBatch([&](auto &&batch) {
PredictContributionKernel(batch, info, h_model,
model.learner_model_param->BaseScore(DeviceOrd::CPU()),
tree_weights, &mean_values, &feat_vecs, &contribs, approximate,
tree_weights, &feat_vecs, &contribs, approximate,
condition, condition_feature);
});
});
Expand Down
Loading
Loading