@@ -15,7 +15,7 @@
  * limitations under the License.
  */

-package org.apache.spark.sql.connect.client
+package org.apache.spark.sql.util

 private[sql] trait CloseableIterator[E] extends Iterator[E] with AutoCloseable { self =>
   def asJava: java.util.Iterator[E] = new java.util.Iterator[E] with AutoCloseable {
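For orientation, here is a minimal, self-contained sketch of the idea behind `CloseableIterator`: an iterator that also owns a resource and must be closed by the consumer. The `CloseableLines` stand-in, the `readLines` helper, and the file path are illustrative, not Spark's actual code.

```scala
import scala.io.Source

// Illustrative stand-in for the trait above: an Iterator that also owns a resource.
trait CloseableLines extends Iterator[String] with AutoCloseable

object CloseableIteratorSketch {
  // Wrap a Source so callers can iterate lines and then release the file handle.
  def readLines(path: String): CloseableLines = new CloseableLines {
    private val source = Source.fromFile(path)   // hypothetical resource
    private val lines = source.getLines()
    override def hasNext: Boolean = lines.hasNext
    override def next(): String = lines.next()
    override def close(): Unit = source.close()
  }

  def main(args: Array[String]): Unit = {
    val it = readLines("/tmp/example.txt")        // hypothetical path
    try it.take(3).foreach(println)
    finally it.close()                            // release the handle even if iteration fails
  }
}
```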
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.spark.sql.connect.client.arrow
+package org.apache.spark.sql.util

 import java.io.{InputStream, IOException}
 import java.nio.channels.Channels
@@ -34,7 +34,7 @@ import org.apache.arrow.vector.types.pojo.Schema
  * closes its messages when it consumes them. In order to prevent that from happening in
  * non-destructive mode we clone the messages before passing them to the reading logic.
  */
-class ConcatenatingArrowStreamReader(
+private[sql] class ConcatenatingArrowStreamReader(
     allocator: BufferAllocator,
     input: Iterator[AbstractMessageIterator],
     destructive: Boolean)
@@ -128,7 +128,7 @@ class ConcatenatingArrowStreamReader(
   override def closeReadSource(): Unit = ()
 }

-trait AbstractMessageIterator extends Iterator[ArrowMessage] {
+private[sql] trait AbstractMessageIterator extends Iterator[ArrowMessage] {
   def schema: Schema
   def bytesRead: Long
 }
@@ -137,7 +137,7 @@ trait AbstractMessageIterator extends Iterator[ArrowMessage] {
  * Decode an Arrow IPC stream into individual messages. Please note that this iterator MUST have a
  * valid IPC stream as its input, otherwise construction will fail.
  */
-class MessageIterator(input: InputStream, allocator: BufferAllocator)
+private[sql] class MessageIterator(input: InputStream, allocator: BufferAllocator)
     extends AbstractMessageIterator {
   private[this] val in = new ReadChannel(Channels.newChannel(input))
   private[this] val reader = new MessageChannelReader(in, allocator)
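A hedged sketch of how these pieces fit together after the move: each serialized Arrow IPC payload is wrapped in a `MessageIterator`, and a `ConcatenatingArrowStreamReader` stitches them into one logical Arrow stream. The constructor shapes follow the signatures visible in this diff; the payload sequence, allocator sizing, the `destructive` choice, and the assumption that the reader follows the standard `ArrowReader` API (`loadNextBatch`/`getVectorSchemaRoot`) are illustrative. Since these classes are `private[sql]`, code like this only compiles inside Spark's `sql` packages.

```scala
import java.io.ByteArrayInputStream

import org.apache.arrow.memory.RootAllocator
import org.apache.spark.sql.util.{ConcatenatingArrowStreamReader, MessageIterator}

object ConcatenatedIpcSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative placeholders: several serialized Arrow IPC streams (e.g. one per response).
    val ipcPayloads: Seq[Array[Byte]] = Seq.empty
    val allocator = new RootAllocator(Long.MaxValue)

    // Each payload becomes a MessageIterator; the concatenating reader then exposes the
    // whole sequence as a single Arrow stream to an ArrowReader-style consumer.
    val messages = ipcPayloads.iterator.map { bytes =>
      new MessageIterator(new ByteArrayInputStream(bytes), allocator)
    }
    val reader = new ConcatenatingArrowStreamReader(allocator, messages, destructive = true)
    try {
      while (reader.loadNextBatch()) {
        println(reader.getVectorSchemaRoot.getRowCount)
      }
    } finally {
      reader.close()
    }
  }
}
```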
@@ -41,10 +41,10 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND
 import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
 import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils._
 import org.apache.spark.sql.catalyst.util.SparkIntervalUtils._
-import org.apache.spark.sql.connect.client.CloseableIterator
 import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum
 import org.apache.spark.sql.connect.test.ConnectFunSuite
 import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, Geography, Geometry, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType}
+import org.apache.spark.sql.util.CloseableIterator
 import org.apache.spark.unsafe.types.VariantVal
 import org.apache.spark.util.{MaybeNull, SparkStringUtils}
@@ -51,13 +51,13 @@ import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BoxedLongEncoder, UnboundRowEncoder}
 import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
 import org.apache.spark.sql.connect.ConnectConversions._
-import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
+import org.apache.spark.sql.connect.client.{ClassFinder, SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
 import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf, SubqueryExpression}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.ExecutionListenerManager
+import org.apache.spark.sql.util.{CloseableIterator, ExecutionListenerManager}
 import org.apache.spark.util.ArrayImplicits._

 /**
@@ -23,9 +23,9 @@ import scala.jdk.CollectionConverters._

 import org.apache.spark.connect.proto.{Command, ExecutePlanResponse, Plan, StreamingQueryEventType}
 import org.apache.spark.internal.{Logging, LogKeys}
-import org.apache.spark.sql.connect.client.CloseableIterator
 import org.apache.spark.sql.streaming.StreamingQueryListener
 import org.apache.spark.sql.streaming.StreamingQueryListener.{Event, QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}
+import org.apache.spark.sql.util.CloseableIterator

 class StreamingQueryListenerBus(sparkSession: SparkSession) extends Logging {
   private val listeners = new CopyOnWriteArrayList[StreamingQueryListener]()
@@ -21,6 +21,7 @@ import scala.jdk.CollectionConverters._
 import io.grpc.ManagedChannel

 import org.apache.spark.connect.proto._
+import org.apache.spark.sql.util.CloseableIterator

 private[connect] class CustomSparkConnectBlockingStub(
     channel: ManagedChannel,
@@ -28,6 +28,7 @@ import io.grpc.stub.StreamObserver
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connect.client.GrpcRetryHandler.RetryException
+import org.apache.spark.sql.util.WrappedCloseableIterator

 /**
  * Retryable iterator of ExecutePlanResponses to an ExecutePlan call.
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException,
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.streaming.StreamingQueryException
+import org.apache.spark.sql.util.{CloseableIterator, WrappedCloseableIterator}
 import org.apache.spark.util.ArrayImplicits._

 /**
@@ -23,6 +23,7 @@ import io.grpc.stub.StreamObserver

 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.{ERROR, NUM_RETRY, POLICY, RETRY_WAIT_TIME}
+import org.apache.spark.sql.util.{CloseableIterator, WrappedCloseableIterator}

 private[sql] class GrpcRetryHandler(
     private val policies: Seq[RetryPolicy],
@@ -23,6 +23,7 @@ import io.grpc.{Status, StatusRuntimeException}
 import io.grpc.stub.StreamObserver

 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.util.{CloseableIterator, WrappedCloseableIterator}

 // This is common logic to be shared between different stub instances to keep the server-side
 // session id and to validate responses as seen by the client.
@@ -40,6 +40,7 @@ import org.apache.spark.internal.LogKeys.{ERROR, RATIO, SIZE, TIME}
 import org.apache.spark.sql.connect.RuntimeConfig
 import org.apache.spark.sql.connect.common.ProtoUtils
 import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.util.CloseableIterator
 import org.apache.spark.util.SparkSystemUtils

 /**
@@ -35,10 +35,10 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder}
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, ArrowDeserializingIterator, ConcatenatingArrowStreamReader, MessageIterator}
+import org.apache.spark.sql.connect.client.arrow.ArrowDeserializingIterator
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter}
 import org.apache.spark.sql.types.{DataType, StructType}
-import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.util.{AbstractMessageIterator, ArrowUtils, CloseableIterator, ConcatenatingArrowStreamReader, MessageIterator}

 private[sql] class SparkResult[T](
     responses: CloseableIterator[proto.ExecutePlanResponse],
@@ -37,9 +37,9 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-import org.apache.spark.sql.connect.client.CloseableIterator
 import org.apache.spark.sql.errors.{CompilationErrors, ExecutionErrors}
 import org.apache.spark.sql.types.Decimal
+import org.apache.spark.sql.util.{CloseableIterator, ConcatenatingArrowStreamReader, MessageIterator}
 import org.apache.spark.unsafe.types.VariantVal

 /**
@@ -38,10 +38,9 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils}
-import org.apache.spark.sql.connect.client.CloseableIterator
 import org.apache.spark.sql.errors.ExecutionErrors
 import org.apache.spark.sql.types.Decimal
-import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.util.{ArrowUtils, CloseableIterator}
 import org.apache.spark.unsafe.types.VariantVal

 /**
@@ -29,7 +29,7 @@ import com.google.protobuf.{Any => ProtoAny, ByteString}
 import io.grpc.{Context, Status, StatusRuntimeException}
 import io.grpc.stub.StreamObserver

-import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException}
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
@@ -1491,9 +1491,12 @@ class SparkConnectPlanner(
     }

     if (rel.hasData) {
-      val (rows, structType) =
-        ArrowConverters.fromIPCStream(rel.getData.toByteArray, TaskContext.get())
-      buildLocalRelationFromRows(rows, structType, Option(schema))
+      val (rows, structType) = ArrowConverters.fromIPCStream(rel.getData.toByteArray)
+      try {
+        buildLocalRelationFromRows(rows, structType, Option(schema))
+      } finally {
+        rows.close()
+      }
     } else {
       if (schema == null) {
         throw InvalidInputErrors.schemaRequiredForLocalRelation()
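The new shape ensures the row iterator returned by `ArrowConverters.fromIPCStream` is closed even when `buildLocalRelationFromRows` throws. The same guarantee can be written with `scala.util.Using`; the `Rows` and `buildRelation` names below are placeholders, not the planner's API.

```scala
import scala.util.Using

object CloseAfterUseSketch {
  // Placeholder stand-ins for the closeable row iterator and the relation builder.
  final class Rows extends AutoCloseable {
    override def close(): Unit = println("rows closed")
  }
  def buildRelation(rows: Rows): String = "relation"

  def main(args: Array[String]): Unit = {
    // Equivalent to the try/finally above: close() runs even if the builder throws.
    val relation = Using.resource(new Rows)(buildRelation)
    println(relation)
  }
}
```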
@@ -1564,28 +1567,13 @@
     }

     // Load and combine all batches
-    var combinedRows: Iterator[InternalRow] = Iterator.empty
-    var structType: StructType = null
-
-    for ((dataHash, batchIndex) <- dataHashes.zipWithIndex) {
-      val dataBytes = readChunkedCachedLocalRelationBlock(dataHash)
-      val (batchRows, batchStructType) =
-        ArrowConverters.fromIPCStream(dataBytes, TaskContext.get())
-
-      // For the first batch, set the schema; for subsequent batches, verify compatibility
-      if (batchIndex == 0) {
-        structType = batchStructType
-        combinedRows = batchRows
-
-      } else {
-        if (batchStructType != structType) {
-          throw InvalidInputErrors.chunkedCachedLocalRelationChunksWithDifferentSchema()
-        }
-        combinedRows = combinedRows ++ batchRows
-      }
-    }
-
-    buildLocalRelationFromRows(combinedRows, structType, Option(schema))
+    val (rows, structType) =
+      ArrowConverters.fromIPCStream(dataHashes.iterator.map(readChunkedCachedLocalRelationBlock))
+    try {
+      buildLocalRelationFromRows(rows, structType, Option(schema))
+    } finally {
+      rows.close()
+    }
   }

   private def toStructTypeOrWrap(dt: DataType): StructType = dt match {

Contributor Author, on the removed schema-mismatch check: An error like this is thrown in the iterator. We may want to make this nicer though...
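The eager loop is replaced by handing `fromIPCStream` a lazy iterator over the chunk payloads, so blocks are only read and decoded as the combined row iterator is consumed; per the review note above, a schema mismatch now surfaces from inside that iterator rather than from an explicit check in the planner. A rough sketch of that lazy style, with hypothetical `loadChunk`/`decode` helpers standing in for the real block reads and Arrow decoding:

```scala
object LazyChunkConcatSketch {
  // Hypothetical helpers: loadChunk fetches one cached block, decode turns it into a
  // (schema, rows) pair. Only the shape matters here, not the real Spark signatures.
  final case class Chunk(schema: String, rows: Iterator[Int])
  def loadChunk(hash: String): Array[Byte] = Array.emptyByteArray
  def decode(bytes: Array[Byte]): Chunk = Chunk("schema", Iterator.empty)

  def rowsFromChunks(hashes: Seq[String]): Iterator[Int] = {
    var expectedSchema: Option[String] = None
    // flatMap over a lazy iterator: nothing is fetched or decoded until rows are pulled.
    hashes.iterator.flatMap { hash =>
      val chunk = decode(loadChunk(hash))
      expectedSchema match {
        case Some(s) if s != chunk.schema =>
          // Mirrors the review comment: the mismatch error is thrown lazily, while the
          // combined iterator is being consumed, not when the relation is first built.
          throw new IllegalArgumentException(s"chunk schema ${chunk.schema} does not match $s")
        case None => expectedSchema = Some(chunk.schema)
        case _ => ()
      }
      chunk.rows
    }
  }
}
```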
@@ -30,14 +30,15 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.classic
 import org.apache.spark.sql.connect
-import org.apache.spark.sql.connect.client.{CloseableIterator, CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, RetryPolicy, SparkConnectClient, SparkConnectStubState}
+import org.apache.spark.sql.connect.client.{CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, RetryPolicy, SparkConnectClient, SparkConnectStubState}
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
 import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.dsl.MockRemoteSession
 import org.apache.spark.sql.connect.dsl.plans._
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionKey, SparkConnectService}
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.util.CloseableIterator

 /**
  * Base class and utilities for a test suite that starts and tests the real SparkConnectService