Skip to content

Commit a684e5a

Browse files
authored
DPL: properly reserve memory when writing (#15510)
1 parent bb290cf commit a684e5a

2 files changed

Lines changed: 65 additions & 19 deletions

File tree

Detectors/AOD/src/AODProducerWorkflowSpec.cxx

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -602,20 +602,20 @@ void AODProducerWorkflowDPL::fillTrackTablesPerCollision(int collisionID,
602602
int end = start + trackRef.getEntriesOfSource(src);
603603
int nToReserve = end - start; // + last index for a given table
604604
if (src == GIndex::Source::MFT) {
605-
mftTracksCursor.reserve(nToReserve + mftTracksCursor.lastIndex());
605+
mftTracksCursor.reserve(nToReserve + mftTracksCursor.lastIndex() + 1);
606606
if (mStoreAllMFTCov) {
607-
mftTracksCovCursor.reserve(nToReserve + mftTracksCovCursor.lastIndex());
607+
mftTracksCovCursor.reserve(nToReserve + mftTracksCovCursor.lastIndex() + 1);
608608
}
609609
} else if (src == GIndex::Source::MCH || src == GIndex::Source::MFTMCH || src == GIndex::Source::MCHMID) {
610-
fwdTracksCursor.reserve(nToReserve + fwdTracksCursor.lastIndex());
611-
fwdTracksCovCursor.reserve(nToReserve + fwdTracksCovCursor.lastIndex());
610+
fwdTracksCursor.reserve(nToReserve + fwdTracksCursor.lastIndex() + 1);
611+
fwdTracksCovCursor.reserve(nToReserve + fwdTracksCovCursor.lastIndex() + 1);
612612
if (!mStoreAllMFTCov && src == GIndex::Source::MFTMCH) {
613-
mftTracksCovCursor.reserve(nToReserve + mftTracksCovCursor.lastIndex());
613+
mftTracksCovCursor.reserve(nToReserve + mftTracksCovCursor.lastIndex() + 1);
614614
}
615615
} else {
616-
tracksCursor.reserve(nToReserve + tracksCursor.lastIndex());
617-
tracksCovCursor.reserve(nToReserve + tracksCovCursor.lastIndex());
618-
tracksExtraCursor.reserve(nToReserve + tracksExtraCursor.lastIndex());
616+
tracksCursor.reserve(nToReserve + tracksCursor.lastIndex() + 1);
617+
tracksCovCursor.reserve(nToReserve + tracksCovCursor.lastIndex() + 1);
618+
tracksExtraCursor.reserve(nToReserve + tracksExtraCursor.lastIndex() + 1);
619619
}
620620
for (int ti = start; ti < end; ti++) {
621621
const auto& trackIndex = GIndices[ti];
@@ -715,9 +715,9 @@ void AODProducerWorkflowDPL::fillTrackTablesPerCollision(int collisionID,
715715
}
716716
/// Add strangeness tracks to the table
717717
auto sTracks = data.getStrangeTracks();
718-
tracksCursor.reserve(mVertexStrLUT[collisionID + 1] + tracksCursor.lastIndex());
719-
tracksCovCursor.reserve(mVertexStrLUT[collisionID + 1] + tracksCovCursor.lastIndex());
720-
tracksExtraCursor.reserve(mVertexStrLUT[collisionID + 1] + tracksExtraCursor.lastIndex());
718+
tracksCursor.reserve(mVertexStrLUT[collisionID + 1] + tracksCursor.lastIndex() + 1);
719+
tracksCovCursor.reserve(mVertexStrLUT[collisionID + 1] + tracksCovCursor.lastIndex() + 1);
720+
tracksExtraCursor.reserve(mVertexStrLUT[collisionID + 1] + tracksExtraCursor.lastIndex() + 1);
721721
for (int iS{mVertexStrLUT[collisionID]}; iS < mVertexStrLUT[collisionID + 1]; ++iS) {
722722
auto& collStrTrk = mCollisionStrTrk[iS];
723723
auto& sTrk = sTracks[collStrTrk.second];
@@ -1236,9 +1236,9 @@ void AODProducerWorkflowDPL::fillMCTrackLabelsTable(MCTrackLabelCursorType& mcTr
12361236
for (int src = GIndex::NSources; src--;) {
12371237
int start = trackRef.getFirstEntryOfSource(src);
12381238
int end = start + trackRef.getEntriesOfSource(src);
1239-
mcMFTTrackLabelCursor.reserve(end - start + mcMFTTrackLabelCursor.lastIndex());
1240-
mcFwdTrackLabelCursor.reserve(end - start + mcFwdTrackLabelCursor.lastIndex());
1241-
mcTrackLabelCursor.reserve(end - start + mcTrackLabelCursor.lastIndex());
1239+
mcMFTTrackLabelCursor.reserve(end - start + mcMFTTrackLabelCursor.lastIndex() + 1);
1240+
mcFwdTrackLabelCursor.reserve(end - start + mcFwdTrackLabelCursor.lastIndex() + 1);
1241+
mcTrackLabelCursor.reserve(end - start + mcTrackLabelCursor.lastIndex() + 1);
12421242
for (int ti = start; ti < end; ti++) {
12431243
const auto trackIndex = primVerGIs[ti];
12441244

@@ -1320,7 +1320,7 @@ void AODProducerWorkflowDPL::fillMCTrackLabelsTable(MCTrackLabelCursorType& mcTr
13201320
auto sTrackLabels = data.getStrangeTracksMCLabels();
13211321
// check if vertexId and vertexId + 1 maps into mVertexStrLUT
13221322
if (!(vertexId < 0 || vertexId >= mVertexStrLUT.size() - 1)) {
1323-
mcTrackLabelCursor.reserve(mVertexStrLUT[vertexId + 1] + mcTrackLabelCursor.lastIndex());
1323+
mcTrackLabelCursor.reserve(mVertexStrLUT[vertexId + 1] + mcTrackLabelCursor.lastIndex() + 1);
13241324
for (int iS{mVertexStrLUT[vertexId]}; iS < mVertexStrLUT[vertexId + 1]; ++iS) {
13251325
auto& collStrTrk = mCollisionStrTrk[iS];
13261326
auto& label = sTrackLabels[collStrTrk.second];
@@ -1448,9 +1448,9 @@ void AODProducerWorkflowDPL::addClustersToFwdTrkClsTable(const o2::globaltrackin
14481448

14491449
if (mchTrackID > -1 && mchTrackID < mchTracks.size()) {
14501450
const auto& mchTrack = mchTracks[mchTrackID];
1451-
fwdTrkClsCursor.reserve(mchTrack.getNClusters() + fwdTrkClsCursor.lastIndex());
14521451
int first = mchTrack.getFirstClusterIdx();
14531452
int last = mchTrack.getLastClusterIdx();
1453+
fwdTrkClsCursor.reserve(last - first + 1 + fwdTrkClsCursor.lastIndex() + 1);
14541454
for (int i = first; i <= last; i++) {
14551455
const auto& cluster = mchClusters[i];
14561456
fwdTrkClsCursor(fwdTrackId,
@@ -1678,10 +1678,10 @@ void AODProducerWorkflowDPL::addToCaloTable(TCaloHandler& caloHandler, TCaloCurs
16781678
auto inputEvent = caloHandler.buildEvent(eventID);
16791679
auto cellsInEvent = inputEvent.mCells; // get cells belonging to current event
16801680
auto cellMClabels = inputEvent.mMCCellLabels; // get MC labels belonging to current event (only implemented for EMCal currently!)
1681-
caloCellCursor.reserve(cellsInEvent.size() + caloCellCursor.lastIndex());
1682-
caloTRGCursor.reserve(cellsInEvent.size() + caloTRGCursor.lastIndex());
1681+
caloCellCursor.reserve(cellsInEvent.size() + caloCellCursor.lastIndex() + 1);
1682+
caloTRGCursor.reserve(cellsInEvent.size() + caloTRGCursor.lastIndex() + 1);
16831683
if (mUseMC) {
1684-
mcCaloCellLabelCursor.reserve(cellsInEvent.size() + mcCaloCellLabelCursor.lastIndex());
1684+
mcCaloCellLabelCursor.reserve(cellsInEvent.size() + mcCaloCellLabelCursor.lastIndex() + 1);
16851685
}
16861686
for (auto iCell = 0U; iCell < cellsInEvent.size(); iCell++) {
16871687
caloCellCursor(bcID,

Framework/Core/test/test_ASoA.cxx

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1376,3 +1376,49 @@ TEST_CASE("TestCombinedGetter")
13761376
++count;
13771377
}
13781378
}
1379+
1380+
TEST_CASE("TestWritingCursorLastIndexAndReserve")
1381+
{
1382+
// Nails down the WritingCursor semantics the AOD-producer reserves depend on:
1383+
// lastIndex() returns the *last index* (rows - 1), not the row count, and
1384+
// reserve(newRows + lastIndex() + 1) reserves exactly the post-batch total so a
1385+
// fully-filled, no-skip batch neither overruns (the fwdTrkCls crash) nor trips
1386+
// the release() / per-row UnsafeAppend guard.
1387+
Produces<o2::aod::Points> cursor; // Points has two persistent columns: X, Y
1388+
auto* builder = new TableBuilder();
1389+
cursor.resetCursor(LifetimeHolder<TableBuilder>(builder));
1390+
1391+
// Empty cursor: no row written, so the last index is -1 and rows == lastIndex()+1 == 0.
1392+
REQUIRE(cursor.lastIndex() == -1);
1393+
1394+
// operator() increments before the append, but only to the index of the row it
1395+
// writes: after N writes lastIndex() == N - 1, NOT N.
1396+
cursor(10, 20);
1397+
REQUIRE(cursor.lastIndex() == 0);
1398+
cursor(11, 21);
1399+
REQUIRE(cursor.lastIndex() == 1);
1400+
cursor(12, 22);
1401+
REQUIRE(cursor.lastIndex() == 2);
1402+
REQUIRE(cursor.lastIndex() + 1 == 3); // rows-so-far == last index + 1
1403+
1404+
// Reserve a second batch the correct way: total = newRows + rowsSoFar
1405+
// = newRows + (lastIndex() + 1).
1406+
// The (buggy) newRows + lastIndex() would reserve 4 here and under-reserve the
1407+
// 5th row; the + 1 makes it exactly 5.
1408+
int64_t const newRows = 2;
1409+
int64_t const reserved = newRows + cursor.lastIndex() + 1; // correct total -> reserve(5)
1410+
cursor.reserve(reserved);
1411+
cursor(13, 23); // row index 3
1412+
cursor(14, 24); // row index 4 — fills the batch exactly (5 rows total)
1413+
REQUIRE(cursor.lastIndex() == 4);
1414+
1415+
// The contract release() enforces: rows filled (lastIndex()+1) must not exceed
1416+
// what was reserved. Correct (+1) gives reserved == 5 -> 5 <= 5 (green); the buggy
1417+
// newRows + lastIndex() reserves only 4 -> 5 <= 4 fails (red).
1418+
REQUIRE(cursor.lastIndex() + 1 <= reserved);
1419+
1420+
auto table = builder->finalize();
1421+
REQUIRE(table->num_rows() == 5);
1422+
REQUIRE(table->num_columns() == 2);
1423+
delete builder;
1424+
}

0 commit comments

Comments
 (0)