diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
index 607acbb3a0c..973ef4be146 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java
@@ -42,6 +42,7 @@
 import org.apache.sysds.runtime.instructions.ooc.MapMMChainOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction;
 import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
+import org.apache.sysds.runtime.instructions.ooc.AppendOOCInstruction;
 
 public class OOCInstructionParser extends InstructionParser {
 	protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName());
@@ -106,6 +107,8 @@ else if(parts.length == 4)
 				return IndexingOOCInstruction.parseInstruction(str);
 			case Rand:
 				return DataGenOOCInstruction.parseInstruction(str);
+			case Append:
+				return AppendOOCInstruction.parseInstruction(str);
 
 			default:
 				throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AppendOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AppendOOCInstruction.java
new file mode 100644
index 00000000000..2c0e24523cd
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AppendOOCInstruction.java
@@ -0,0 +1,256 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.ooc;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.functionobjects.OffsetColumnIndex;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Function;
+
+/**
+ * Out-of-core append instruction. Only matrix-matrix column binding (cbind)
+ * is currently supported.
+ *
+ * When the left input ends on a block boundary, the blocks of the right input
+ * only get their column index shifted. Otherwise the last block column of the
+ * left input is filled up with leading columns of the right input, and every
+ * following output block is assembled from the trailing columns of one block
+ * of the right input and the leading columns of the next one.
+ */
+public class AppendOOCInstruction extends BinaryOOCInstruction {
+
+	/** Supported append variants. */
+	public enum AppendType {
+		CBIND
+	}
+
+	protected final AppendType _type;
+
+	protected AppendOOCInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, AppendType type,
+		String opcode, String istr) {
+		super(OOCType.Append, op, in1, in2, out, opcode, istr);
+		_type = type;
+	}
+
+	/**
+	 * Parses the string form of an append instruction.
+	 *
+	 * @param str instruction string
+	 * @return the parsed instruction
+	 * @throws DMLRuntimeException if an input is not a matrix or the cbind flag is not set
+	 */
+	public static AppendOOCInstruction parseInstruction(String str) {
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		InstructionUtils.checkNumFields(parts, 5, 4);
+
+		String opcode = parts[0];
+		CPOperand in1 = new CPOperand(parts[1]);
+		CPOperand in2 = new CPOperand(parts[2]);
+		CPOperand out = new CPOperand(parts[parts.length-2]);
+		boolean cbind = Boolean.parseBoolean(parts[parts.length-1]);
+
+		if(in1.getDataType() != Types.DataType.MATRIX || in2.getDataType() != Types.DataType.MATRIX || !cbind){
+			throw new DMLRuntimeException("Only matrix-matrix cbind is supported");
+		}
+		AppendType type = AppendType.CBIND;
+
+		Operator op = new ReorgOperator(OffsetColumnIndex.getOffsetColumnIndexFnObject(-1));
+		return new AppendOOCInstruction(op, in1, in2, out, type, opcode, str);
+	}
+
+	/**
+	 * Builds the output stream of the cbind from the streams of the two inputs.
+	 *
+	 * @param ec execution context providing the input and output matrix objects
+	 */
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		MatrixObject in1 = ec.getMatrixObject(input1);
+		MatrixObject in2 = ec.getMatrixObject(input2);
+		validateInput(in1, in2);
+		if(handleZeroDims(in1, in2, ec))
+			return;
+
+		OOCStream<IndexedMatrixValue> qIn1 = in1.getStreamHandle();
+		OOCStream<IndexedMatrixValue> qIn2 = in2.getStreamHandle();
+
+		int blksize = in1.getBlocksize();
+		// take the remainder in long arithmetic; casting to int first truncates large column counts
+		int rem1 = (int) (in1.getNumColumns()%blksize);
+		int rem2 = (int) (in2.getNumColumns()%blksize);
+		int cblk1 = (int) in1.getDataCharacteristics().getNumColBlocks();
+		int cblk2 = (int) in2.getDataCharacteristics().getNumColBlocks();
+		int cblkRes = (int) Math.ceil((double)(in1.getNumColumns()+in2.getNumColumns())/blksize);
+
+		if(rem1==0){
+			// no shifting needed
+			OOCStream<IndexedMatrixValue> out = new SubscribableTaskQueue<>();
+			mapOOC(qIn2, out, imv -> new IndexedMatrixValue(
+				new MatrixIndexes(imv.getIndexes().getRowIndex(), cblk1+imv.getIndexes().getColumnIndex()), imv.getValue()));
+
+			ec.getMatrixObject(output).setStreamHandle(mergeOOCStreams(List.of(qIn1, out)));
+			return;
+		}
+
+		// left input: partition 1 holds its last block column, partition 0 everything else
+		List<OOCStream<IndexedMatrixValue>> split1 = splitOOCStream(qIn1, imv -> imv.getIndexes().getColumnIndex()==cblk1? 1 : 0, 2);
+		// right input: one partition per block column
+		List<OOCStream<IndexedMatrixValue>> split2 = splitOOCStream(qIn2, imv -> (int) imv.getIndexes().getColumnIndex()-1, cblk2);
+
+		OOCStream<IndexedMatrixValue> head = split1.get(0);
+		OOCStream<IndexedMatrixValue> lastCol = split1.get(1);
+		OOCStream<IndexedMatrixValue> firstCol = split2.get(0);
+
+		// the first block column of the right input is read twice, hence the caching stream
+		CachingStream firstColCache = new CachingStream(firstCol);
+		OOCStream<IndexedMatrixValue> firstColForCritical = firstColCache.getReadStream();
+		OOCStream<IndexedMatrixValue> firstColForTail = firstColCache.getReadStream();
+
+		SubscribableTaskQueue<IndexedMatrixValue> out = new SubscribableTaskQueue<>();
+		// join partners are matched on the block row index only
+		Function<IndexedMatrixValue, MatrixIndexes> rowKey = imv -> new MatrixIndexes(imv.getIndexes().getRowIndex(), 1);
+
+		// number of columns in the last block column of the right input
+		int fullRem2 = rem2==0? blksize : rem2;
+		// combine cols both matrices
+		joinOOC(lastCol, firstColForCritical, out, (left, right) -> {
+			MatrixBlock lb = (MatrixBlock) left.getValue();
+			MatrixBlock rb = (MatrixBlock) right.getValue();
+			// a single right block with fewer columns than the gap is taken completely
+			int stop = cblk2==1 && blksize-rem1>fullRem2? fullRem2 : blksize-rem1;
+			MatrixBlock combined = cbindBlocks(lb, sliceCols(rb, 0, stop));
+			return new IndexedMatrixValue(
+				new MatrixIndexes(left.getIndexes().getRowIndex(), left.getIndexes().getColumnIndex()), combined);
+		}, rowKey);
+
+		List<OOCStream<IndexedMatrixValue>> outStreams = new ArrayList<>();
+		outStreams.add(head);
+		outStreams.add(out);
+
+		// shift cols second matrix
+		OOCStream<IndexedMatrixValue> fst = firstColForTail;
+		OOCStream<IndexedMatrixValue> sec = null;
+		for(int i=0; i<cblk2-1; i++){
+			out = new SubscribableTaskQueue<>();
+			CachingStream secCachingStream = new CachingStream(split2.get(i+1));
+			sec = secCachingStream.getReadStream();
+
+			int finalI = i;
+			joinOOC(fst, sec, out, (left, right) -> {
+				MatrixBlock lb = (MatrixBlock) left.getValue();
+				MatrixBlock rb = (MatrixBlock) right.getValue();
+				// the last right block may hold fewer columns than the gap to fill
+				int stop = finalI+2==cblk2 && blksize-rem1>fullRem2? fullRem2 : blksize-rem1;
+				MatrixBlock combined = cbindBlocks(sliceCols(lb, blksize-rem1, blksize), sliceCols(rb, 0, stop));
+				return new IndexedMatrixValue(
+					new MatrixIndexes(left.getIndexes().getRowIndex(), cblk1 + left.getIndexes().getColumnIndex()),
+					combined);
+			}, rowKey);
+
+			fst = secCachingStream.getReadStream();
+			outStreams.add(out);
+		}
+
+		if(cblk1+cblk2==cblkRes){
+			// overflow
+			int remSize = (rem1+rem2)%blksize;
+			out = new SubscribableTaskQueue<>();
+			mapOOC(fst, out, imv -> new IndexedMatrixValue(
+				new MatrixIndexes(imv.getIndexes().getRowIndex(), cblk1+imv.getIndexes().getColumnIndex()),
+				sliceCols((MatrixBlock) imv.getValue(), fullRem2-remSize, fullRem2)));
+
+			outStreams.add(out);
+		}
+		ec.getMatrixObject(output).setStreamHandle(mergeOOCStreams(outStreams));
+	}
+
+	/** @return the append variant of this instruction */
+	public AppendType getAppendType() {
+		return _type;
+	}
+
+	/**
+	 * Checks that both inputs of a cbind have the same number of rows.
+	 *
+	 * @param m1 left input
+	 * @param m2 right input
+	 * @throws DMLRuntimeException if the row counts differ
+	 */
+	private void validateInput(MatrixObject m1, MatrixObject m2) {
+		if(_type == AppendType.CBIND && m1.getNumRows() != m2.getNumRows()) {
+			throw new DMLRuntimeException(
+				"Append-cbind is not possible for input matrices " + input1.getName() + " and " + input2.getName() +
+					" with different number of rows: " + m1.getNumRows() + " vs " + m2.getNumRows());
+		}
+	}
+
+	/**
+	 * Short-circuits inputs with empty dimensions: without rows, or without columns in both inputs, the
+	 * output gets an empty, closed stream; if only one input has no columns, the other stream is passed on.
+	 *
+	 * @param m1 left input
+	 * @param m2 right input
+	 * @param ec execution context providing the output matrix object
+	 * @return true if the output stream was set here, false if regular processing is required
+	 */
+	private boolean handleZeroDims(MatrixObject m1, MatrixObject m2, ExecutionContext ec) {
+		long rows = m1.getNumRows();
+		long cols1 = m1.getNumColumns();
+		long cols2 = m2.getNumColumns();
+		if(rows == 0 || (cols1 == 0 && cols2 == 0)) {
+			OOCStream<IndexedMatrixValue> empty = createWritableStream();
+			empty.closeInput();
+			ec.getMatrixObject(output).setStreamHandle(empty);
+		}
+		else if(cols1 == 0) {
+			ec.getMatrixObject(output).setStreamHandle(m2.getStreamHandle());
+		}
+		else if(cols2 == 0) {
+			ec.getMatrixObject(output).setStreamHandle(m1.getStreamHandle());
+		}
+		else return false;
+
+		return true;
+	}
+
+	/**
+	 * Extracts the columns colStart (inclusive) to colEndExclusive (exclusive) over all rows of a block.
+	 */
+	private static MatrixBlock sliceCols(MatrixBlock in, int colStart, int colEndExclusive) {
+		// slice is inclusive
+		return in.slice(0, in.getNumRows()-1, colStart, colEndExclusive-1);
+	}
+
+	/** Appends right to left via MatrixBlock.append, writing into a new result block. */
+	private static MatrixBlock cbindBlocks(MatrixBlock left, MatrixBlock right) {
+		return left.append(right, new MatrixBlock());
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
index 7e1bdac73d2..7e3f7e5133a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
@@ -35,6 +35,7 @@
 import org.apache.sysds.runtime.ooc.stream.SourceOOCStream;
 import org.apache.sysds.runtime.ooc.stream.message.OOCGetStreamTypeMessage;
 import org.apache.sysds.runtime.ooc.stream.message.OOCStreamMessage;
+import org.apache.sysds.runtime.ooc.util.OOCUtils;
 import org.apache.sysds.runtime.util.IndexRange;
 
 import shaded.parquet.it.unimi.dsi.fastutil.ints.IntArrayList;
@@ -453,7 +454,7 @@ public void findCachedAsync(MatrixIndexes idx, Consumer 0) {
-			long expected = dc.getNumBlocks();
+			long expected = OOCUtils.getNumBlocks(dc);
 			if (expected >= 0 && _numBlocks != expected) {
 				throw new DMLRuntimeException("CachingStream block count mismatch: expected " + expected +
 					" but saw " + _numBlocks + " (" + dc.getRows() + "x" + dc.getCols() + ")");
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index f7cefe635df..931d45e0f45 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -75,7 +75,7 @@ public abstract class OOCInstruction extends Instruction {
 
 	public enum OOCType {
 		Reblock, Tee, Binary, Ternary, Unary, AggregateUnary, AggregateBinary, AggregateTernary, MAPMM, MMTSJ,
-		MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand
+		MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand, Append
 	}
 
 	protected final OOCInstruction.OOCType _ooctype;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java
index e5c48decdd1..ce796728319 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java
@@ -25,6 +25,7 @@
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.ooc.stream.message.OOCGetStreamTypeMessage;
 import org.apache.sysds.runtime.ooc.stream.message.OOCStreamMessage;
+import org.apache.sysds.runtime.ooc.util.OOCUtils;
 import org.apache.sysds.runtime.util.IndexRange;
 
 import java.util.LinkedList;
@@ -166,7 +167,7 @@ public synchronized void closeInput() {
 	private void validateBlockCountOnClose() {
 		DataCharacteristics dc = getDataCharacteristics();
 		if (dc != null && dc.dimsKnown() && dc.getBlocksize() > 0) {
-			long expected = dc.getNumBlocks();
+			long expected = OOCUtils.getNumBlocks(dc);
 			if (expected >= 0 && _blockCount.get() != expected) {
 				throw new DMLRuntimeException("OOCStream block count mismatch: expected " + expected +
 					" but saw " + _blockCount.get() + " (" + dc.getRows() + "x" + dc.getCols() + ")");
@@ -180,6 +181,7 @@ public void setSubscriber(Consumer> subscriber) {
 			throw new IllegalArgumentException("Cannot set subscriber to null");
 
 		LinkedList data;
+		boolean needsEos;
 
 		synchronized(this) {
 			if(_subscriber != null)
@@ -189,12 +191,20 @@ public void setSubscriber(Consumer> subscriber) {
 				throw _failure;
 			data = _data;
 			_data = new LinkedList<>();
+			// If this stream was already closed with no buffered data, no further
+			// onDeliveryFinished() call will happen, so emit EOS immediately.
+			needsEos = _closed.get() && data.isEmpty() && _availableCtr.get() == 0;
+			if(needsEos)
+				_availableCtr.incrementAndGet(); // route terminal emission via onDeliveryFinished
 		}
 
 		for (T t : data) {
 			subscriber.accept(new SimpleQueueCallback<>(t, _failure));
 			onDeliveryFinished();
 		}
+
+		if(needsEos)
+			onDeliveryFinished();
 	}
 
 	@SuppressWarnings("unchecked")
@@ -214,6 +224,9 @@ private void onDeliveryFinished() {
 
 	@Override
 	public synchronized void propagateFailure(DMLRuntimeException re) {
+		// Ignore late failures
+		if(_closed.get() && _availableCtr.get() == 0)
+			return;
 		super.propagateFailure(re);
 		Consumer> s = _subscriber;
 		if(s != null)
diff --git a/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlock.java b/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlock.java
index 69fd386c5ef..a9912373290 100644
--- a/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlock.java
@@ -97,10 +97,13 @@ public final void writeEmptyMatrixToHDFS(String fname, long rlen, long clen, int
 		FileSystem fs = IOUtilFunctions.getFileSystem(path, job);
 		final Writer writer = IOUtilFunctions.getSeqWriter(path, job, _replication);
 		try {
-			MatrixIndexes index = new MatrixIndexes(1, 1);
-			MatrixBlock block = new MatrixBlock((int) Math.max(Math.min(rlen, blen), 1),
-				(int) Math.max(Math.min(clen, blen), 1), true);
-			writer.append(index, block);
+			// For 0xN or Nx0, emit a valid sequence file header only (no blocks).
+			if(rlen > 0 && clen > 0) {
+				MatrixIndexes index = new MatrixIndexes(1, 1);
+				MatrixBlock block = new MatrixBlock((int) Math.max(Math.min(rlen, blen), 1),
+					(int) Math.max(Math.min(clen, blen), 1), true);
+				writer.append(index, block);
+			}
 		}
 		finally {
 			IOUtilFunctions.closeSilently(writer);
diff --git a/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlockParallel.java b/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlockParallel.java
index c00e58b7fac..88f7c0a690e 100644
--- a/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlockParallel.java
+++ b/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlockParallel.java
@@ -95,6 +95,16 @@ protected void writeBinaryBlockMatrixToHDFS( Path path, JobConf job, MatrixBlock
 
 	public long writeMatrixFromStream(String fname, OOCStream stream, long rlen, long clen, int blen) throws IOException {
 		Path path = new Path(fname);
+
+		// For empty dimensions, no stream tiles are expected but the output must still exist.
+		if(rlen <= 0 || clen <= 0) {
+			while(stream.dequeue() != LocalTaskQueue.NO_MORE_TASKS) {
+				// Drain any unexpected records to keep stream producers unblocked.
+			}
+			writeEmptyMatrixToHDFS(fname, rlen, clen, blen);
+			return 0;
+		}
+
 		long nnz = -1;
 		DataCharacteristics dc = stream.getDataCharacteristics();
 		if(dc != null)
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java b/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java
index 7d0a27932f1..0c036d16c20 100644
--- a/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java
+++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java
@@ -85,6 +85,17 @@ public MergedOOCStream(List> sources) {
 					if(_failed.get())
 						return;
 
+					if(cb instanceof OOCStream.GroupQueueCallback) {
+						OOCStream.GroupQueueCallback group = (OOCStream.GroupQueueCallback) cb;
+						for(int i = 0; i < group.size(); i++) {
+							OOCStream.QueueCallback sub = group.getCallback(i);
+							try(sub) {
+								_taskQueue.enqueue(sub.keepOpen());
+							}
+						}
+						return;
+					}
+
 					_taskQueue.enqueue(cb.keepOpen());
 				}
 			}
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/util/OOCUtils.java b/src/main/java/org/apache/sysds/runtime/ooc/util/OOCUtils.java
index c564748e1da..f33d17ea132 100644
--- a/src/main/java/org/apache/sysds/runtime/ooc/util/OOCUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/ooc/util/OOCUtils.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.runtime.ooc.util;
 
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.util.IndexRange;
 
 import java.util.ArrayList;
@@ -60,4 +61,20 @@ public static Collection getTilesOfRange(IndexRange range, long b
 				list.add(new MatrixIndexes(r, c));
 		return list;
 	}
+
+	/**
+	 * Returns the number of blocks described by the given data characteristics.
+	 *
+	 * @param dc data characteristics, may be null
+	 * @return 0 if rows or columns are zero, otherwise dc.getNumBlocks(); -1 if dc is null, its
+	 *         dimensions are unknown, or its block size is not positive
+	 */
+	public static long getNumBlocks(DataCharacteristics dc) {
+		if (dc != null && dc.dimsKnown() && dc.getBlocksize() > 0) {
+			if(dc.getCols() == 0 || dc.getRows() == 0)
+				return 0;
+			return dc.getNumBlocks();
+		}
+		return -1;
+	}
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/CBindTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/CBindTest.java
new file mode 100644
index 00000000000..3172585c6ee
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/ooc/CBindTest.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.ooc;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.parser.LanguageException;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.io.MatrixWriter;
+import org.apache.sysds.runtime.io.MatrixWriterFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.ArrayList;
+
+/**
+ * Compares the result of cbind executed with the -ooc flag against a run without it, for column counts
+ * below, between and on multiples of the block size, for empty dimensions, and for mismatching row counts.
+ */
+@RunWith(Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class CBindTest extends AutomatedTestBase {
+
+	private static final String TEST_NAME = "CBindTest";
+	private static final String TEST_DIR = "functions/ooc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + CBindTest.class.getSimpleName() + "/";
+
+	private final static double eps = 1e-8;
+	private static final String INPUT_NAME_1 = "A";
+	private static final String INPUT_NAME_2 = "B";
+	private static final String OUTPUT_NAME = "res";
+
+	private final int r1;
+	private final int c1;
+	private final int r2;
+	private final int c2;
+	private final int bsize;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+	}
+
+	public CBindTest(int r1, int c1, int r2, int c2, int bsize) {
+		this.r1 = r1;
+		this.c1 = c1;
+		this.r2 = r2;
+		this.c2 = c2;
+		this.bsize = bsize;
+	}
+
+	@Parameterized.Parameters(name = "{0}x{1} {2}x{3} bsize {4}")
+	public static Iterable<Object[]> getParams() {
+		int[] rows = new int[]{1000, 2000};
+		int[] cols = new int[]{300, 700, 2300, 2700, 3000, 3300};
+		int[] bsizes = new int[]{1000};
+
+		ArrayList<Object[]> params = new ArrayList<>();
+
+		for(int row : rows) {
+			for(int col : cols) {
+				for(int col2 : cols) {
+					for(int bsize : bsizes) {
+						params.add(new Object[] {row, col, row, col2, bsize});
+					}
+				}
+			}
+		}
+
+		// mismatching row counts
+		params.add(new Object[] {10, 1000, 20, 1000, 1000});
+		// zero rows / zero columns
+		params.add(new Object[] {0, 1000, 0, 1000, 1000});
+		params.add(new Object[] {1000, 0, 1000, 1000, 1000});
+		params.add(new Object[] {1000, 1000, 1000, 0, 1000});
+		params.add(new Object[] {1000, 0, 1000, 0, 1000});
+
+		return params;
+	}
+
+	@Test
+	public void runCBindTest() {
+		Types.ExecMode platformOld = rtplatform;
+		rtplatform = Types.ExecMode.SINGLE_NODE;
+
+		try {
+			getAndLoadTestConfiguration(TEST_NAME);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+
+			double[][] A = TestUtils.floor(getRandomMatrix(r1, c1, -1, 1, 1.0, 7));
+			double[][] B = TestUtils.floor(getRandomMatrix(r2, c2, -1, 1, 1.0, 13));
+
+			MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
+			writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(A), input(INPUT_NAME_1), r1, c1, bsize, r1*c1);
+			writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(B), input(INPUT_NAME_2), r2, c2, bsize, r2*c2);
+
+			HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + ".mtd"), Types.ValueType.FP64,
+				new MatrixCharacteristics(r1, c1, bsize, r1*c1), Types.FileFormat.BINARY);
+			HDFSTool.writeMetaDataFile(input(INPUT_NAME_2 + ".mtd"), Types.ValueType.FP64,
+				new MatrixCharacteristics(r2, c2, bsize, r2*c2), Types.FileFormat.BINARY);
+
+
+			programArgs = new String[] {"-explain", "-stats", "-ooc", "-args",
+				input(INPUT_NAME_1), input(INPUT_NAME_2), output(OUTPUT_NAME)};
+
+			if(r1 != r2){
+				runTest(true,true, LanguageException.class,-1);
+				return;
+			}
+
+			runTest(true, false, null, -1);
+			Assert.assertTrue("OOC wasn't used for cbind",
+				heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.APPEND));
+
+			// rerun without ooc flag
+			programArgs = new String[] {"-explain", "-stats", "-args",
+				input(INPUT_NAME_1), input(INPUT_NAME_2), output(OUTPUT_NAME + "_target")};
+			runTest(true, false, null, -1);
+
+			// compare results
+			MatrixBlock ret1 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME),
+				Types.FileFormat.BINARY, r1, c1+c2, bsize);
+			MatrixBlock ret2 = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"),
+				Types.FileFormat.BINARY, r1, c1+c2, bsize);
+			TestUtils.compareMatrices(ret1, ret2, eps);
+		}
+		catch(Exception ex) {
+			Assert.fail(ex.getMessage());
+		}
+		finally {
+			resetExecMode(platformOld);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/ooc/CBindTest.dml b/src/test/scripts/functions/ooc/CBindTest.dml
new file mode 100644
index 00000000000..edfbddafc0f
--- /dev/null
+++ b/src/test/scripts/functions/ooc/CBindTest.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1)
+B = read($2)
+res = cbind(A, B)
+
+write(res, $3, format="binary");