Skip to content

Commit 6b90cb6

Browse files
committed
[TMVA][SOFIE] Restructure emitted code to be differentiable with Clad
1 parent 6c8a4ac commit 6b90cb6

2 files changed

Lines changed: 73 additions & 34 deletions

File tree

tmva/sofie/inc/TMVA/SOFIE_common.hxx

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,23 @@ inline void FillOutput(std::vector<bool> const &vec, std::vector<std::uint8_t> &
637637
}
638638
}
639639

640+
// Used at the end of infer() to fill the return object.
// Copies n elements from the raw source array into the caller-provided
// output buffer; out must point to at least n writable elements.
template <class T>
void FillOutput(T const *arr, T *out, std::size_t n)
{
   std::size_t idx = 0;
   while (idx < n) {
      out[idx] = arr[idx];
      ++idx;
   }
}
648+
649+
// Special case for std::vector<bool>.
// std::vector<bool> is bit-packed and yields proxy references, so it
// cannot go through the pointer-based overload; copy element-wise into
// a std::uint8_t buffer instead (true -> 1, false -> 0).
inline void FillOutput(std::vector<bool> const &vec, std::uint8_t *out, std::size_t n)
{
   for (std::size_t i = 0; i != n; ++i)
      out[i] = static_cast<std::uint8_t>(vec[i]);
}
656+
640657
} // end namespace UTILITY
641658

642659
namespace BLAS{

tmva/sofie/src/RModel.cxx

Lines changed: 56 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ std::string GenerateConstantTensorCode(const std::pair<std::string, InitializedT
542542
else {
543543
strs << ConvertValuesToString(length, data) << ";\n";
544544
}
545-
strs << "const " << type << " * tensor_" + t.first + " = fTensor_" + t.first + ".data();\n";
545+
strs << type << " * tensor_" + t.first + " = fTensor_" + t.first + ".data();\n";
546546
}
547547
return strs.str();
548548
}
@@ -739,13 +739,16 @@ void RModel::GenerateOutput()
739739
if (!doInferArgs.empty())
740740
doInferArgs += ",";
741741
for (std::string const &name : fOutputTensorNames) {
742-
fGC += SP + "std::vector<" + typeForOutput(GetTensorType(name)) + " > output_tensor_" + name + ";\n";
743-
doInferArgs += " output_tensor_" + name + ",";
742+
bool isIntermediate = fIntermediateTensorInfos.count(name) > 0;
743+
std::string n = isIntermediate ? std::to_string(ConvertShapeToLength(GetTensorShape(name)))
744+
: ConvertDynamicShapeToLength(GetDynamicTensorShape(name));
745+
fGC += SP + "std::vector<" + typeForOutput(GetTensorType(name)) + " > output_tensor_" + name + "(" + n + ");\n";
746+
doInferArgs += " output_tensor_" + name + ".data(),";
744747
}
745748
if (!doInferArgs.empty())
746749
doInferArgs.back() = ' ';
747750

748-
fGC += SP + "doInfer(" + doInferArgs + ");\n";
751+
fGC += SP + "doInfer(this, " + doInferArgs + ");\n";
749752

750753
fGC += SP + "return {";
751754
for (size_t i = 0; i < fOutputTensorNames.size(); i++) {
@@ -759,23 +762,35 @@ void RModel::GenerateOutput()
759762

760763
void RModel::GenerateSessionCode()
761764
{
765+
std::string sessionName;
766+
if (fUseSession && !fIsGNNComponent) {
767+
sessionName = !fIsSubGraph ? "Session" : "Session_" + fName;
768+
769+
// forward declare session struct
770+
fGC += "struct " + sessionName + ";\n";
771+
}
772+
762773
// Determine the signature of the actual inference function
763774
std::string doInferSignature = GenerateInferSignature();
764775
if (!doInferSignature.empty())
765776
doInferSignature += ", ";
766777
for (auto const &name : fOutputTensorNames) {
767-
doInferSignature += " std::vector<" + typeForOutput(GetTensorType(name)) + "> &output_tensor_" + name + ",";
778+
doInferSignature += typeForOutput(GetTensorType(name)) + " *tensor_" + name + ",";
768779
}
769780
doInferSignature.back() = ' ';
770781

782+
if (fUseSession && !fIsGNNComponent) {
783+
doInferSignature = sessionName + " *session, " + doInferSignature;
784+
}
785+
771786
doInferSignature = "void doInfer(" + doInferSignature + ")";
772787

773788
// define the Session struct (for GNN this is generated in RModel_GNN)
774789
if (fUseSession && !fIsGNNComponent) {
775-
if (!fIsSubGraph)
776-
fGC += "struct Session {\n";
777-
else
778-
fGC += "struct Session_" + fName + " {\n";
790+
// forward declare inference implementation to be used in Session
791+
fGC += doInferSignature + ";\n";
792+
793+
fGC += "struct " + sessionName + " {\n";
779794
}
780795

781796
// generate code for declaring the initialized tensors
@@ -815,9 +830,6 @@ void RModel::GenerateSessionCode()
815830

816831
// Generate code for Session constructor
817832
if (fUseSession) {
818-
std::string sessionName = "Session";
819-
if (fIsSubGraph)
820-
sessionName += "_" + fName;
821833
// add here specific operator code that needs to define session data members
822834
fGC += "\n";
823835
for (size_t id = 0; id < fOperators.size(); id++) {
@@ -868,9 +880,39 @@ void RModel::GenerateSessionCode()
868880
fGC += "}\n\n";
869881
}
870882

883+
// generate the inference overload that returns an output struct
884+
GenerateOutput();
885+
886+
// end of session
887+
if (fUseSession && !fIsGNNComponent) {
888+
fGC += "}; // end of Session\n\n";
889+
}
890+
871891
fGC += doInferSignature + "{\n";
872892
fGC += "\n";
873893

894+
if (fUseSession && !fIsGNNComponent) {
895+
fGC += " " + sessionName + " &sess = session[0];\n";
896+
std::vector<std::string> names;
897+
for (auto const& it: fInitializedTensors) {
898+
names.push_back(it.first);
899+
}
900+
for (auto const& it: fIntermediateTensorInfos) {
901+
names.push_back(it.first);
902+
}
903+
std::vector<std::string> added;
904+
for (auto const& name : names) {
905+
auto found = std::find(fOutputTensorNames.begin(), fOutputTensorNames.end(), name);
906+
auto found2 = std::find(added.begin(), added.end(), name);
907+
// Output tensors are passed directly via the function call
908+
if(found == fOutputTensorNames.end() && found2 == added.end()) {
909+
fGC += " auto & tensor_" + name + " = sess.tensor_" + name + ";\n";
910+
added.push_back(name);
911+
}
912+
}
913+
fGC += "\n";
914+
}
915+
874916
// generate the inference code
875917
if (fVerbose)
876918
std::cout << "Generating main inference code for " << fName << std::endl;
@@ -879,31 +921,11 @@ void RModel::GenerateSessionCode()
879921
throw std::runtime_error("TMVA-SOFIE: output size=0 are not supported");
880922

881923
for (size_t op_idx = 0; op_idx < fOperators.size(); ++op_idx) {
882-
if (fVerbose)
883-
std::cout << "Generating code for operator .... " << op_idx << std::endl;
924+
if (fVerbose) std::cout << "Generating code for operator .... " << op_idx << std::endl;
884925
fGC += (fOperators[op_idx]->Generate(std::to_string(op_idx)));
885926
}
886927

887-
fGC += SP + "using TMVA::Experimental::SOFIE::UTILITY::FillOutput;\n\n";
888-
889-
for (std::string const &name : fOutputTensorNames) {
890-
// need to check is size is the same (don't want to return a vector with
891-
// larger size) in that case better to copy
892-
bool isIntermediate = fIntermediateTensorInfos.count(name) > 0;
893-
std::string n = isIntermediate ? std::to_string(ConvertShapeToLength(GetTensorShape(name)))
894-
: ConvertDynamicShapeToLength(GetDynamicTensorShape(name));
895-
fGC += SP + "FillOutput(tensor_" + name + ", output_tensor_" + name + ", " + n + ");\n";
896-
}
897-
898-
fGC += "}\n\n";
899-
900-
// generate the inference overload that returns an output struct
901-
GenerateOutput();
902-
903-
// end of session
904-
if (fUseSession && !fIsGNNComponent) {
905-
fGC += "}; // end of Session\n\n";
906-
}
928+
fGC += "}\n";
907929
}
908930

909931
void RModel::Generate(std::underlying_type_t<Options> options, int batchSize, long pos, bool verbose)

0 commit comments

Comments (0)