@@ -542,7 +542,7 @@ std::string GenerateConstantTensorCode(const std::pair<std::string, InitializedT
542542 else {
543543 strs << ConvertValuesToString (length, data) << " ;\n " ;
544544 }
545- strs << " const " << type << " * tensor_" + t.first + " = fTensor_" + t.first + " .data();\n " ;
545+ strs << type << " * tensor_" + t.first + " = fTensor_" + t.first + " .data();\n " ;
546546 }
547547 return strs.str ();
548548}
@@ -739,13 +739,16 @@ void RModel::GenerateOutput()
739739 if (!doInferArgs.empty ())
740740 doInferArgs += " ," ;
741741 for (std::string const &name : fOutputTensorNames ) {
742- fGC += SP + " std::vector<" + typeForOutput (GetTensorType (name)) + " > output_tensor_" + name + " ;\n " ;
743- doInferArgs += " output_tensor_" + name + " ," ;
742+ bool isIntermediate = fIntermediateTensorInfos .count (name) > 0 ;
743+ std::string n = isIntermediate ? std::to_string (ConvertShapeToLength (GetTensorShape (name)))
744+ : ConvertDynamicShapeToLength (GetDynamicTensorShape (name));
745+ fGC += SP + " std::vector<" + typeForOutput (GetTensorType (name)) + " > output_tensor_" + name + " (" + n + " );\n " ;
746+ doInferArgs += " output_tensor_" + name + " .data()," ;
744747 }
745748 if (!doInferArgs.empty ())
746749 doInferArgs.back () = ' ' ;
747750
748- fGC += SP + " doInfer(" + doInferArgs + " );\n " ;
751+ fGC += SP + " doInfer(this, " + doInferArgs + " );\n " ;
749752
750753 fGC += SP + " return {" ;
751754 for (size_t i = 0 ; i < fOutputTensorNames .size (); i++) {
@@ -759,23 +762,35 @@ void RModel::GenerateOutput()
759762
760763void RModel::GenerateSessionCode ()
761764{
765+ std::string sessionName;
766+ if (fUseSession && !fIsGNNComponent ) {
767+ sessionName = !fIsSubGraph ? " Session" : " Session_" + fName ;
768+
769+ // forward declare session struct
770+ fGC += " struct " + sessionName + " ;\n " ;
771+ }
772+
762773 // Determine the signature of the actual inference function
763774 std::string doInferSignature = GenerateInferSignature ();
764775 if (!doInferSignature.empty ())
765776 doInferSignature += " , " ;
766777 for (auto const &name : fOutputTensorNames ) {
767- doInferSignature += " std::vector< " + typeForOutput (GetTensorType (name)) + " > &output_tensor_ " + name + " ," ;
778+ doInferSignature += typeForOutput (GetTensorType (name)) + " *tensor_ " + name + " ," ;
768779 }
769780 doInferSignature.back () = ' ' ;
770781
782+ if (fUseSession && !fIsGNNComponent ) {
783+ doInferSignature = sessionName + " *session, " + doInferSignature;
784+ }
785+
771786 doInferSignature = " void doInfer(" + doInferSignature + " )" ;
772787
773788 // define the Session struct (for GNN this is generated in RModel_GNN)
774789 if (fUseSession && !fIsGNNComponent ) {
775- if (! fIsSubGraph )
776- fGC += " struct Session { \n " ;
777- else
778- fGC += " struct Session_ " + fName + " {\n " ;
790+ // forward declare inference implementation to be used in Session
791+ fGC += doInferSignature + " ; \n " ;
792+
793+ fGC += " struct " + sessionName + " {\n " ;
779794 }
780795
781796 // generate code for declaring the initialized tensors
@@ -815,9 +830,6 @@ void RModel::GenerateSessionCode()
815830
816831 // Generate code for Session constructor
817832 if (fUseSession ) {
818- std::string sessionName = " Session" ;
819- if (fIsSubGraph )
820- sessionName += " _" + fName ;
821833 // add here specific operator code that needs to define session data members
822834 fGC += " \n " ;
823835 for (size_t id = 0 ; id < fOperators .size (); id++) {
@@ -868,9 +880,39 @@ void RModel::GenerateSessionCode()
868880 fGC += " }\n\n " ;
869881 }
870882
883+ // generate the inference overload that returns an output struct
884+ GenerateOutput ();
885+
886+ // end of session
887+ if (fUseSession && !fIsGNNComponent ) {
888+ fGC += " }; // end of Session\n\n " ;
889+ }
890+
871891 fGC += doInferSignature + " {\n " ;
872892 fGC += " \n " ;
873893
894+ if (fUseSession && !fIsGNNComponent ) {
895+ fGC += " " + sessionName + " &sess = session[0];\n " ;
896+ std::vector<std::string> names;
897+ for (auto const & it: fInitializedTensors ) {
898+ names.push_back (it.first );
899+ }
900+ for (auto const & it: fIntermediateTensorInfos ) {
901+ names.push_back (it.first );
902+ }
903+ std::vector<std::string> added;
904+ for (auto const & name : names) {
905+ auto found = std::find (fOutputTensorNames .begin (), fOutputTensorNames .end (), name);
906+ auto found2 = std::find (added.begin (), added.end (), name);
907+ // Output tensors are passed directly via the function call
908+ if (found == fOutputTensorNames .end () && found2 == added.end ()) {
909+ fGC += " auto & tensor_" + name + " = sess.tensor_" + name + " ;\n " ;
910+ added.push_back (name);
911+ }
912+ }
913+ fGC += " \n " ;
914+ }
915+
874916 // generate the inference code
875917 if (fVerbose )
876918 std::cout << " Generating main inference code for " << fName << std::endl;
@@ -879,31 +921,11 @@ void RModel::GenerateSessionCode()
879921 throw std::runtime_error (" TMVA-SOFIE: output size=0 are not supported" );
880922
881923 for (size_t op_idx = 0 ; op_idx < fOperators .size (); ++op_idx) {
882- if (fVerbose )
883- std::cout << " Generating code for operator .... " << op_idx << std::endl;
924+ if (fVerbose ) std::cout << " Generating code for operator .... " << op_idx << std::endl;
884925 fGC += (fOperators [op_idx]->Generate (std::to_string (op_idx)));
885926 }
886927
887- fGC += SP + " using TMVA::Experimental::SOFIE::UTILITY::FillOutput;\n\n " ;
888-
889- for (std::string const &name : fOutputTensorNames ) {
890- // need to check is size is the same (don't want to return a vector with
891- // larger size) in that case better to copy
892- bool isIntermediate = fIntermediateTensorInfos .count (name) > 0 ;
893- std::string n = isIntermediate ? std::to_string (ConvertShapeToLength (GetTensorShape (name)))
894- : ConvertDynamicShapeToLength (GetDynamicTensorShape (name));
895- fGC += SP + " FillOutput(tensor_" + name + " , output_tensor_" + name + " , " + n + " );\n " ;
896- }
897-
898- fGC += " }\n\n " ;
899-
900- // generate the inference overload that returns an output struct
901- GenerateOutput ();
902-
903- // end of session
904- if (fUseSession && !fIsGNNComponent ) {
905- fGC += " }; // end of Session\n\n " ;
906- }
928+ fGC += " }\n " ;
907929}
908930
909931void RModel::Generate (std::underlying_type_t <Options> options, int batchSize, long pos, bool verbose)
0 commit comments