diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyCharTypePaddingHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyCharTypePaddingHelper.scala index 26c85469e3f32..418349e465fc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyCharTypePaddingHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyCharTypePaddingHelper.scala @@ -51,8 +51,21 @@ object ApplyCharTypePaddingHelper { private[sql] def readSidePadding( relation: LogicalPlan, cleanedRelation: () => LogicalPlan): (LogicalPlan, Seq[(Attribute, Attribute)]) = { + readSideCharTypeHandling(relation, cleanedRelation, CharVarcharUtils.addPaddingForScan) + } + + private[sql] def readSideTrim( + relation: LogicalPlan, + cleanedRelation: () => LogicalPlan): (LogicalPlan, Seq[(Attribute, Attribute)]) = { + readSideCharTypeHandling(relation, cleanedRelation, CharVarcharUtils.trimTrailingSpacesForScan) + } + + private def readSideCharTypeHandling( + relation: LogicalPlan, + cleanedRelation: () => LogicalPlan, + charTypeHandler: Attribute => Expression): (LogicalPlan, Seq[(Attribute, Attribute)]) = { val projectList = relation.output.map { attr => - CharVarcharUtils.addPaddingForScan(attr) match { + charTypeHandler(attr) match { case ne: NamedExpression => ne case other => Alias(other, attr.name)(explicitMetadata = Some(attr.metadata)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala index 9501986bb0c5d..dfc0a13568f7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala @@ -262,6 +262,68 @@ object CharVarcharUtils extends Logging with SparkCharVarcharUtils { }.getOrElse(attr) } + def 
trimTrailingSpacesForScan(attr: Attribute): Expression = { + getRawType(attr.metadata).map { rawType => + trimTrailingSpacesForChar(attr, rawType) + }.getOrElse(attr) + } + + private def trimTrailingSpacesForChar(expr: Expression, dt: DataType): Expression = dt match { + case _: CharType => + StringTrimRight(expr) + + case StructType(fields) => + val struct = CreateNamedStruct(fields.zipWithIndex.flatMap { case (f, i) => + Seq(Literal(f.name), trimTrailingSpacesForChar( + GetStructField(expr, i, Some(f.name)), f.dataType)) + }.toImmutableArraySeq) + if (struct.valExprs.forall(_.isInstanceOf[GetStructField])) { + expr + } else if (expr.nullable) { + If(IsNull(expr), Literal(null, struct.dataType), struct) + } else { + struct + } + + case ArrayType(et, containsNull) => + val param = NamedLambdaVariable("x", replaceCharVarcharWithString(et), containsNull) + val funcBody = trimTrailingSpacesForChar(param, et) + if (funcBody.fastEquals(param)) { + expr + } else { + ArrayTransform(expr, LambdaFunction(funcBody, Seq(param))) + } + + case MapType(kt, vt, valueContainsNull) => + val keys = MapKeys(expr) + val newKeys = { + val param = NamedLambdaVariable("x", replaceCharVarcharWithString(kt), nullable = false) + val funcBody = trimTrailingSpacesForChar(param, kt) + if (funcBody.fastEquals(param)) { + keys + } else { + ArrayTransform(keys, LambdaFunction(funcBody, Seq(param))) + } + } + val values = MapValues(expr) + val newValues = { + val param = NamedLambdaVariable("x", replaceCharVarcharWithString(vt), valueContainsNull) + val funcBody = trimTrailingSpacesForChar(param, vt) + if (funcBody.fastEquals(param)) { + values + } else { + ArrayTransform(values, LambdaFunction(funcBody, Seq(param))) + } + } + if (newKeys.fastEquals(keys) && newValues.fastEquals(values)) { + expr + } else { + MapFromArrays(newKeys, newValues) + } + + case _ => expr + } + /** * Return expressions to apply char type padding for the string comparison between the given * attributes. 
When comparing two char type columns/fields, we need to pad the shorter one to diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 58840ebf09eb5..63a2e2eb5d1ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -6486,6 +6486,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val CHAR_TRIM_TRAILING_SPACES_ON_READ = + buildConf("spark.sql.charTrimTrailingSpacesOnRead") + .doc("When true, Spark trims trailing spaces from CHAR type columns/fields when reading " + + s"table data. This config takes precedence over ${READ_SIDE_CHAR_PADDING.key}.") + .version("4.2.0") + .booleanConf + .createWithDefault(false) + val LEGACY_NO_CHAR_PADDING_IN_PREDICATE = buildConf("spark.sql.legacy.noCharPaddingInPredicate") .internal() .doc("When true, Spark will not apply char type padding for CHAR type columns in string " + @@ -8465,6 +8473,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def readSideCharPadding: Boolean = getConf(SQLConf.READ_SIDE_CHAR_PADDING) + def charTrimTrailingSpacesOnRead: Boolean = getConf(SQLConf.CHAR_TRIM_TRAILING_SPACES_ON_READ) + def cliPrintHeader: Boolean = getConf(SQLConf.CLI_PRINT_HEADER) def legacyIntervalEnabled: Boolean = getConf(LEGACY_INTERVAL_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala index d952927f9d30a..3989179658cbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala @@ -28,6 +28,9 @@ import org.apache.spark.sql.internal.SQLConf /** * This rule performs string padding 
for char type. * + * When reading values from column/field of type CHAR(N), trim trailing spaces if the read-side + * trim config is turned on. + * + * When reading values from column/field of type CHAR(N), right-pad the values to length N, if the + * read-side padding config is turned on. + * @@ -41,7 +44,22 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { return plan } - if (conf.readSideCharPadding) { + if (conf.charTrimTrailingSpacesOnRead) { + plan.resolveOperatorsUpWithNewOutput { + case r: LogicalRelation => + ApplyCharTypePaddingHelper.readSideTrim(r, () => + r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata))) + case r: DataSourceV2Relation => + ApplyCharTypePaddingHelper.readSideTrim(r, () => + r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata))) + case r: HiveTableRelation => + ApplyCharTypePaddingHelper.readSideTrim(r, () => { + val cleanedDataCols = r.dataCols.map(CharVarcharUtils.cleanAttrMetadata) + val cleanedPartCols = r.partitionCols.map(CharVarcharUtils.cleanAttrMetadata) + r.copy(dataCols = cleanedDataCols, partitionCols = cleanedPartCols) + }) + } + } else if (conf.readSideCharPadding) { val newPlan = plan.resolveOperatorsUpWithNewOutput { case r: LogicalRelation => ApplyCharTypePaddingHelper.readSidePadding(r, () => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index 8d22575f5d09b..1034133a4cde3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -87,6 +87,39 @@ trait CharVarcharTestSuite extends QueryTest { } } + test("SPARK-56819: trim trailing spaces from char columns on read") { + withSQLConf(SQLConf.CHAR_TRIM_TRAILING_SPACES_ON_READ.key -> "true") { + withTable("t") { + sql( + s""" + |CREATE TABLE t ( + | c CHAR(4), + | v VARCHAR(4), + | s STRING, + | st STRUCT<c: CHAR(4), v: VARCHAR(4)> + |) USING $format + 
|""".stripMargin) + sql( + """ + |INSERT INTO t VALUES ( + | '12', + | '12 ', + | '12 ', + | named_struct('c', '12', 'v', '12 ') + |) + |""".stripMargin) + + checkAnswer( + sql("SELECT c, v, s, st FROM t"), + Row("12", "12 ", "12 ", Row("12", "12 "))) + checkAnswer( + sql("SELECT length(c), length(v), length(s), length(st.c), length(st.v) FROM t"), + Row(2, 3, 3, 2, 3)) + checkAnswer(sql("SELECT c = '12', c = '12 ' FROM t"), Row(true, false)) + } + } + } + test("preserve char/varchar type info") { Seq(CharType(5), VarcharType(5)).foreach { typ => for {