From 21fde57f15db974b710e7b00e72c744da7c1ac3c Mon Sep 17 00:00:00 2001 From: Nathan Howell Date: Thu, 16 Feb 2017 20:51:19 -0800 Subject: [PATCH] [SPARK-18352][SQL] Support parsing multiline json files ## What changes were proposed in this pull request? If a new option `wholeFile` is set to `true` the JSON reader will parse each file (instead of a single line) as a value. This is done with Jackson streaming and it should be capable of parsing very large documents, assuming the row will fit in memory. Because the file is not buffered in memory the corrupt record handling is also slightly different when `wholeFile` is enabled: the corrupt column will contain the filename instead of the literal JSON if there is a parsing failure. It would be easy to extend this to add the parser location (line, column and byte offsets) to the output if desired. These changes have allowed types other than `String` to be parsed. Support for `UTF8String` and `Text` have been added (alongside `String` and `InputFormat`) and no longer require a conversion to `String` just for parsing. I've also included a few other changes that generate slightly better bytecode and (imo) make it more obvious when and where boxing is occurring in the parser. These are included as separate commits, let me know if they should be flattened into this PR or moved to a new one. ## How was this patch tested? New and existing unit tests. No performance or load tests have been run. Author: Nathan Howell Closes #16386 from NathanHowell/SPARK-18352. --- .../apache/spark/unsafe/types/UTF8String.java | 20 +- .../spark/input/PortableDataStream.scala | 7 + python/pyspark/sql/readwriter.py | 13 +- python/pyspark/sql/streaming.py | 14 +- python/pyspark/sql/tests.py | 7 + python/test_support/sql/people_array.json | 13 + .../expressions/jsonExpressions.scala | 10 +- .../catalyst/json/CreateJacksonParser.scala | 46 +++ .../spark/sql/catalyst/json/JSONOptions.scala | 20 +- .../sql/catalyst/json/JacksonParser.scala | 287 ++++++++++-------- .../apache/spark/sql/DataFrameReader.scala | 32 +- .../execution/datasources/CodecStreams.scala | 17 +- .../datasources/json/JsonDataSource.scala | 216 +++++++++++++ .../datasources/json/JsonFileFormat.scala | 96 ++---- .../datasources/json/JsonInferSchema.scala | 13 +- .../sql/streaming/DataStreamReader.scala | 8 +- .../datasources/json/JsonSuite.scala | 152 +++++++++- 17 files changed, 740 insertions(+), 231 deletions(-) create mode 100644 python/test_support/sql/people_array.json create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 3800d53c02f4c..87b9e8eb445aa 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -147,7 +147,13 @@ public void writeTo(ByteBuffer buffer) { buffer.position(pos + numBytes); } - public void writeTo(OutputStream out) throws IOException { + /** + * Returns a {@link ByteBuffer} wrapping the base object if it is a byte array + * or a copy of the data if the base object is not a byte array. + * + * Unlike getBytes this will not create a copy the array if this is a slice. 
+ */ + public @Nonnull ByteBuffer getByteBuffer() { if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) { final byte[] bytes = (byte[]) base; @@ -160,12 +166,20 @@ public void writeTo(OutputStream out) throws IOException { throw new ArrayIndexOutOfBoundsException(); } - out.write(bytes, (int) arrayOffset, numBytes); + return ByteBuffer.wrap(bytes, (int) arrayOffset, numBytes); } else { - out.write(getBytes()); + return ByteBuffer.wrap(getBytes()); } } + public void writeTo(OutputStream out) throws IOException { + final ByteBuffer bb = this.getByteBuffer(); + assert(bb.hasArray()); + + // similar to Utils.writeByteBuffer but without the spark-core dependency + out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()); + } + /** * Returns the number of bytes for a code point with the first byte as `b` * @param b The first byte of a code point diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 59404e08895a3..9606c4754314f 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFil import org.apache.spark.internal.config import org.apache.spark.SparkContext +import org.apache.spark.annotation.Since /** * A general format for reading whole files in as streams, byte arrays, @@ -175,6 +176,7 @@ class PortableDataStream( * Create a new DataInputStream from the split and context. The user of this method is responsible * for closing the stream after usage. */ + @Since("1.2.0") def open(): DataInputStream = { val pathp = split.getPath(index) val fs = pathp.getFileSystem(conf) @@ -184,6 +186,7 @@ class PortableDataStream( /** * Read the file as a byte array */ + @Since("1.2.0") def toArray(): Array[Byte] = { val stream = open() try { @@ -193,6 +196,10 @@ class PortableDataStream( } } + @Since("1.2.0") def getPath(): String = path + + @Since("2.2.0") + def getConfiguration: Configuration = conf } diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 167833488980a..6bed390e60c96 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -159,11 +159,12 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, - timeZone=None): + timeZone=None, wholeFile=None): """ - Loads a JSON file (`JSON Lines text format or newline-delimited JSON - `_) or an RDD of Strings storing JSON objects (one object per - record) and returns the result as a :class`DataFrame`. + Loads a JSON file and returns the results as a :class:`DataFrame`. + + Both JSON (one record per file) and `JSON Lines `_ + (newline-delimited JSON) are supported and can be selected with the `wholeFile` parameter. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. @@ -212,6 +213,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. 
If None is set, it uses the default value, session local timezone. + :param wholeFile: parse one record, which may span multiple lines, per file. If None is + set, it uses the default value, ``false``. >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -228,7 +231,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, - timestampFormat=timestampFormat, timeZone=timeZone) + timestampFormat=timestampFormat, timeZone=timeZone, wholeFile=wholeFile) if isinstance(path, basestring): path = [path] if type(path) == list: diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index d988e596a86d9..965c8f6b269e9 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -428,11 +428,13 @@ def load(self, path=None, format=None, schema=None, **options): def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None, allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None, - mode=None, columnNameOfCorruptRecord=None, dateFormat=None, - timestampFormat=None, timeZone=None): + mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None, + timeZone=None, wholeFile=None): """ - Loads a JSON file stream (`JSON Lines text format or newline-delimited JSON - `_) and returns a :class`DataFrame`. + Loads a JSON file stream and returns the results as a :class:`DataFrame`. + + Both JSON (one record per file) and `JSON Lines `_ + (newline-delimited JSON) are supported and can be selected with the `wholeFile` parameter. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. @@ -483,6 +485,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``. :param timeZone: sets the string that indicates a timezone to be used to parse timestamps. If None is set, it uses the default value, session local timezone. + :param wholeFile: parse one record, which may span multiple lines, per file. If None is + set, it uses the default value, ``false``. 
>>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema) >>> json_sdf.isStreaming @@ -496,7 +500,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero, allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter, mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat, - timestampFormat=timestampFormat, timeZone=timeZone) + timestampFormat=timestampFormat, timeZone=timeZone, wholeFile=wholeFile) if isinstance(path, basestring): return self._df(self._jreader.json(path)) else: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index d8b7b3137c1c9..9058443285aca 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -439,6 +439,13 @@ def test_udf_with_order_by_and_limit(self): res.explain(True) self.assertEqual(res.collect(), [Row(id=0, copy=0)]) + def test_wholefile_json(self): + from pyspark.sql.types import StringType + people1 = self.spark.read.json("python/test_support/sql/people.json") + people_array = self.spark.read.json("python/test_support/sql/people_array.json", + wholeFile=True) + self.assertEqual(people1.collect(), people_array.collect()) + def test_udf_with_input_file_name(self): from pyspark.sql.functions import udf, input_file_name from pyspark.sql.types import StringType diff --git a/python/test_support/sql/people_array.json b/python/test_support/sql/people_array.json new file mode 100644 index 0000000000000..c27c48fe343e4 --- /dev/null +++ b/python/test_support/sql/people_array.json @@ -0,0 +1,13 @@ +[ + { + "name": "Michael" + }, + { + "name": "Andy", + "age": 30 + }, + { + "name": "Justin", + "age": 19 + } +] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index bd852a50fe71e..1e690a446951e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -497,8 +497,7 @@ case class JsonToStruct( lazy val parser = new JacksonParser( schema, - "invalid", // Not used since we force fail fast. Invalid rows will be set to `null`. - new JSONOptions(options ++ Map("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get)) + new JSONOptions(options + ("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get)) override def dataType: DataType = schema @@ -506,7 +505,12 @@ case class JsonToStruct( copy(timeZoneId = Option(timeZoneId)) override def nullSafeEval(json: Any): Any = { - try parser.parse(json.toString).headOption.orNull catch { + try { + parser.parse( + json.asInstanceOf[UTF8String], + CreateJacksonParser.utf8String, + identity[UTF8String]).headOption.orNull + } catch { case _: SparkSQLJsonProcessingException => null } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala new file mode 100644 index 0000000000000..e0ed03a68981a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.json + +import java.io.InputStream + +import com.fasterxml.jackson.core.{JsonFactory, JsonParser} +import org.apache.hadoop.io.Text + +import org.apache.spark.unsafe.types.UTF8String + +private[sql] object CreateJacksonParser extends Serializable { + def string(jsonFactory: JsonFactory, record: String): JsonParser = { + jsonFactory.createParser(record) + } + + def utf8String(jsonFactory: JsonFactory, record: UTF8String): JsonParser = { + val bb = record.getByteBuffer + assert(bb.hasArray) + + jsonFactory.createParser(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + } + + def text(jsonFactory: JsonFactory, record: Text): JsonParser = { + jsonFactory.createParser(record.getBytes, 0, record.getLength) + } + + def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = { + jsonFactory.createParser(record) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index 5307ce1cb711d..5a91f9c1939aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -31,11 +31,20 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs * Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]]. 
*/ private[sql] class JSONOptions( - @transient private val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String) + @transient private val parameters: CaseInsensitiveMap[String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String) extends Logging with Serializable { - def this(parameters: Map[String, String], defaultTimeZoneId: String) = - this(CaseInsensitiveMap(parameters), defaultTimeZoneId) + def this( + parameters: Map[String, String], + defaultTimeZoneId: String, + defaultColumnNameOfCorruptRecord: String = "") = { + this( + CaseInsensitiveMap(parameters), + defaultTimeZoneId, + defaultColumnNameOfCorruptRecord) + } val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) @@ -57,7 +66,8 @@ private[sql] class JSONOptions( parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName) private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") - val columnNameOfCorruptRecord = parameters.get("columnNameOfCorruptRecord") + val columnNameOfCorruptRecord = + parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord) val timeZone: TimeZone = TimeZone.getTimeZone(parameters.getOrElse("timeZone", defaultTimeZoneId)) @@ -69,6 +79,8 @@ private[sql] class JSONOptions( FastDateFormat.getInstance( parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US) + val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false) + // Parse mode flags if (!ParseModes.isValidMode(parseMode)) { logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 03e27ba934fb0..995095969d7af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -39,7 +39,6 @@ private[sql] class SparkSQLJsonProcessingException(msg: String) extends RuntimeE */ class JacksonParser( schema: StructType, - columnNameOfCorruptRecord: String, options: JSONOptions) extends Logging { import JacksonUtils._ @@ -48,69 +47,110 @@ class JacksonParser( // A `ValueConverter` is responsible for converting a value from `JsonParser` // to a value in a field for `InternalRow`. 
- private type ValueConverter = (JsonParser) => Any + private type ValueConverter = JsonParser => AnyRef // `ValueConverter`s for the root schema for all fields in the schema - private val rootConverter: ValueConverter = makeRootConverter(schema) + private val rootConverter = makeRootConverter(schema) private val factory = new JsonFactory() options.setJacksonOptions(factory) private val emptyRow: Seq[InternalRow] = Seq(new GenericInternalRow(schema.length)) + private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord) + corruptFieldIndex.foreach(idx => require(schema(idx).dataType == StringType)) + + @transient + private[this] var isWarningPrinted: Boolean = false + @transient - private[this] var isWarningPrintedForMalformedRecord: Boolean = false + private def printWarningForMalformedRecord(record: () => UTF8String): Unit = { + def sampleRecord: String = { + if (options.wholeFile) { + "" + } else { + s"Sample record: ${record()}\n" + } + } + + def footer: String = { + s"""Code example to print all malformed records (scala): + |=================================================== + |// The corrupted record exists in column ${options.columnNameOfCorruptRecord}. + |val parsedJson = spark.read.json("/path/to/json/file/test.json") + | + """.stripMargin + } + + if (options.permissive) { + logWarning( + s"""Found at least one malformed record. The JSON reader will replace + |all malformed records with placeholder null in current $PERMISSIVE_MODE parser mode. + |To find out which corrupted records have been replaced with null, please use the + |default inferred schema instead of providing a custom schema. + | + |${sampleRecord ++ footer} + | + """.stripMargin) + } else if (options.dropMalformed) { + logWarning( + s"""Found at least one malformed record. The JSON reader will drop + |all malformed records in current $DROP_MALFORMED_MODE parser mode. To find out which + |corrupted records have been dropped, please switch the parser mode to $PERMISSIVE_MODE + |mode and use the default inferred schema. + | + |${sampleRecord ++ footer} + | + """.stripMargin) + } + } + + @transient + private def printWarningIfWholeFile(): Unit = { + if (options.wholeFile && corruptFieldIndex.isDefined) { + logWarning( + s"""Enabling wholeFile mode and defining columnNameOfCorruptRecord may result + |in very large allocations or OutOfMemoryExceptions being raised. + | + """.stripMargin) + } + } /** * This function deals with the cases it fails to parse. This function will be called * when exceptions are caught during converting. This functions also deals with `mode` option. */ - private def failedRecord(record: String): Seq[InternalRow] = { - // create a row even if no corrupt record column is present - if (options.failFast) { - throw new SparkSQLJsonProcessingException(s"Malformed line in FAILFAST mode: $record") - } - if (options.dropMalformed) { - if (!isWarningPrintedForMalformedRecord) { - logWarning( - s"""Found at least one malformed records (sample: $record). The JSON reader will drop - |all malformed records in current $DROP_MALFORMED_MODE parser mode. To find out which - |corrupted records have been dropped, please switch the parser mode to $PERMISSIVE_MODE - |mode and use the default inferred schema. 
- | - |Code example to print all malformed records (scala): - |=================================================== - |// The corrupted record exists in column ${columnNameOfCorruptRecord} - |val parsedJson = spark.read.json("/path/to/json/file/test.json") - | - """.stripMargin) - isWarningPrintedForMalformedRecord = true - } - Nil - } else if (schema.getFieldIndex(columnNameOfCorruptRecord).isEmpty) { - if (!isWarningPrintedForMalformedRecord) { - logWarning( - s"""Found at least one malformed records (sample: $record). The JSON reader will replace - |all malformed records with placeholder null in current $PERMISSIVE_MODE parser mode. - |To find out which corrupted records have been replaced with null, please use the - |default inferred schema instead of providing a custom schema. - | - |Code example to print all malformed records (scala): - |=================================================== - |// The corrupted record exists in column ${columnNameOfCorruptRecord}. - |val parsedJson = spark.read.json("/path/to/json/file/test.json") - | - """.stripMargin) - isWarningPrintedForMalformedRecord = true - } - emptyRow - } else { - val row = new GenericInternalRow(schema.length) - for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecord)) { - require(schema(corruptIndex).dataType == StringType) - row.update(corruptIndex, UTF8String.fromString(record)) - } - Seq(row) + private def failedRecord(record: () => UTF8String): Seq[InternalRow] = { + corruptFieldIndex match { + case _ if options.failFast => + if (options.wholeFile) { + throw new SparkSQLJsonProcessingException("Malformed line in FAILFAST mode") + } else { + throw new SparkSQLJsonProcessingException(s"Malformed line in FAILFAST mode: ${record()}") + } + + case _ if options.dropMalformed => + if (!isWarningPrinted) { + printWarningForMalformedRecord(record) + isWarningPrinted = true + } + Nil + + case None => + if (!isWarningPrinted) { + printWarningForMalformedRecord(record) + isWarningPrinted = true + } + emptyRow + + case Some(corruptIndex) => + if (!isWarningPrinted) { + printWarningIfWholeFile() + isWarningPrinted = true + } + val row = new GenericInternalRow(schema.length) + row.update(corruptIndex, record()) + Seq(row) } } @@ -119,11 +159,11 @@ class JacksonParser( * to a value according to a desired schema. This is a wrapper for the method * `makeConverter()` to handle a row wrapped with an array. */ - private def makeRootConverter(st: StructType): ValueConverter = { + private def makeRootConverter(st: StructType): JsonParser => Seq[InternalRow] = { val elementConverter = makeConverter(st) - val fieldConverters = st.map(_.dataType).map(makeConverter) - (parser: JsonParser) => parseJsonToken(parser, st) { - case START_OBJECT => convertObject(parser, st, fieldConverters) + val fieldConverters = st.map(_.dataType).map(makeConverter).toArray + (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, st) { + case START_OBJECT => convertObject(parser, st, fieldConverters) :: Nil // SPARK-3308: support reading top level JSON arrays and take every element // in such an array as a row // @@ -137,7 +177,15 @@ class JacksonParser( // List([str_a_1,null]) // List([str_a_2,null], [null,str_b_3]) // - case START_ARRAY => convertArray(parser, elementConverter) + case START_ARRAY => + val array = convertArray(parser, elementConverter) + // Here, as we support reading top level JSON arrays and take every element + // in such an array as a row, this case is possible. 
+ if (array.numElements() == 0) { + Nil + } else { + array.toArray[InternalRow](schema).toSeq + } } } @@ -145,35 +193,35 @@ class JacksonParser( * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. */ - private[sql] def makeConverter(dataType: DataType): ValueConverter = dataType match { + def makeConverter(dataType: DataType): ValueConverter = dataType match { case BooleanType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[java.lang.Boolean](parser, dataType) { case VALUE_TRUE => true case VALUE_FALSE => false } case ByteType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[java.lang.Byte](parser, dataType) { case VALUE_NUMBER_INT => parser.getByteValue } case ShortType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[java.lang.Short](parser, dataType) { case VALUE_NUMBER_INT => parser.getShortValue } case IntegerType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[java.lang.Integer](parser, dataType) { case VALUE_NUMBER_INT => parser.getIntValue } case LongType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) { case VALUE_NUMBER_INT => parser.getLongValue } case FloatType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[java.lang.Float](parser, dataType) { case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT => parser.getFloatValue @@ -193,7 +241,7 @@ class JacksonParser( } case DoubleType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[java.lang.Double](parser, dataType) { case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT => parser.getDoubleValue @@ -213,7 +261,7 @@ class JacksonParser( } case StringType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[UTF8String](parser, dataType) { case VALUE_STRING => UTF8String.fromString(parser.getText) @@ -227,66 +275,71 @@ class JacksonParser( } case TimestampType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) { case VALUE_STRING => + val stringValue = parser.getText // This one will lose microseconds parts. // See https://issues.apache.org/jira/browse/SPARK-10681. - Try(options.timestampFormat.parse(parser.getText).getTime * 1000L) - .getOrElse { - // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards - // compatibility. - DateTimeUtils.stringToTime(parser.getText).getTime * 1000L - } + Long.box { + Try(options.timestampFormat.parse(stringValue).getTime * 1000L) + .getOrElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + DateTimeUtils.stringToTime(stringValue).getTime * 1000L + } + } case VALUE_NUMBER_INT => parser.getLongValue * 1000000L } case DateType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[java.lang.Integer](parser, dataType) { case VALUE_STRING => val stringValue = parser.getText // This one will lose microseconds parts. 
// See https://issues.apache.org/jira/browse/SPARK-10681.x - Try(DateTimeUtils.millisToDays(options.dateFormat.parse(parser.getText).getTime)) - .getOrElse { - // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards - // compatibility. - Try(DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(parser.getText).getTime)) + Int.box { + Try(DateTimeUtils.millisToDays(options.dateFormat.parse(stringValue).getTime)) + .orElse { + // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards + // compatibility. + Try(DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(stringValue).getTime)) + } .getOrElse { - // In Spark 1.5.0, we store the data as number of days since epoch in string. - // So, we just convert it to Int. - stringValue.toInt - } + // In Spark 1.5.0, we store the data as number of days since epoch in string. + // So, we just convert it to Int. + stringValue.toInt + } } } case BinaryType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[Array[Byte]](parser, dataType) { case VALUE_STRING => parser.getBinaryValue } case dt: DecimalType => - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[Decimal](parser, dataType) { case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) => Decimal(parser.getDecimalValue, dt.precision, dt.scale) } case st: StructType => - val fieldConverters = st.map(_.dataType).map(makeConverter) - (parser: JsonParser) => parseJsonToken(parser, dataType) { + val fieldConverters = st.map(_.dataType).map(makeConverter).toArray + (parser: JsonParser) => parseJsonToken[InternalRow](parser, dataType) { case START_OBJECT => convertObject(parser, st, fieldConverters) } case at: ArrayType => val elementConverter = makeConverter(at.elementType) - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[ArrayData](parser, dataType) { case START_ARRAY => convertArray(parser, elementConverter) } case mt: MapType => val valueConverter = makeConverter(mt.valueType) - (parser: JsonParser) => parseJsonToken(parser, dataType) { + (parser: JsonParser) => parseJsonToken[MapData](parser, dataType) { case START_OBJECT => convertMap(parser, valueConverter) } @@ -298,7 +351,7 @@ class JacksonParser( // Here, we pass empty `PartialFunction` so that this case can be // handled as a failed conversion. It will throw an exception as // long as the value is not null. - parseJsonToken(parser, dataType)(PartialFunction.empty[JsonToken, Any]) + parseJsonToken[AnyRef](parser, dataType)(PartialFunction.empty[JsonToken, AnyRef]) } /** @@ -306,14 +359,14 @@ class JacksonParser( * to parse the JSON token using given function `f`. If the `f` failed to parse and convert the * token, call `failedConversion` to handle the token. */ - private def parseJsonToken( + private def parseJsonToken[R >: Null]( parser: JsonParser, - dataType: DataType)(f: PartialFunction[JsonToken, Any]): Any = { + dataType: DataType)(f: PartialFunction[JsonToken, R]): R = { parser.getCurrentToken match { case FIELD_NAME => // There are useless FIELD_NAMEs between START_OBJECT and END_OBJECT tokens parser.nextToken() - parseJsonToken(parser, dataType)(f) + parseJsonToken[R](parser, dataType)(f) case null | VALUE_NULL => null @@ -325,9 +378,9 @@ class JacksonParser( * This function throws an exception for failed conversion, but returns null for empty string, * to guard the non string types. 
*/ - private def failedConversion( + private def failedConversion[R >: Null]( parser: JsonParser, - dataType: DataType): PartialFunction[JsonToken, Any] = { + dataType: DataType): PartialFunction[JsonToken, R] = { case VALUE_STRING if parser.getTextLength < 1 => // If conversion is failed, this produces `null` rather than throwing exception. // This will protect the mismatch of types. @@ -348,7 +401,7 @@ class JacksonParser( private def convertObject( parser: JsonParser, schema: StructType, - fieldConverters: Seq[ValueConverter]): InternalRow = { + fieldConverters: Array[ValueConverter]): InternalRow = { val row = new GenericInternalRow(schema.length) while (nextUntil(parser, JsonToken.END_OBJECT)) { schema.getFieldIndex(parser.getCurrentName) match { @@ -394,36 +447,30 @@ class JacksonParser( } /** - * Parse the string JSON input to the set of [[InternalRow]]s. + * Parse the JSON input to the set of [[InternalRow]]s. + * + * @param recordLiteral an optional function that will be used to generate + * the corrupt record text instead of record.toString */ - def parse(input: String): Seq[InternalRow] = { - if (input.trim.isEmpty) { - Nil - } else { - try { - Utils.tryWithResource(factory.createParser(input)) { parser => - parser.nextToken() - rootConverter.apply(parser) match { - case null => failedRecord(input) - case row: InternalRow => row :: Nil - case array: ArrayData => - // Here, as we support reading top level JSON arrays and take every element - // in such an array as a row, this case is possible. - if (array.numElements() == 0) { - Nil - } else { - array.toArray[InternalRow](schema) - } - case _ => - failedRecord(input) + def parse[T]( + record: T, + createParser: (JsonFactory, T) => JsonParser, + recordLiteral: T => UTF8String): Seq[InternalRow] = { + try { + Utils.tryWithResource(createParser(factory, record)) { parser => + // a null first token is equivalent to testing for input.trim.isEmpty + // but it works on any token stream and not just strings + parser.nextToken() match { + case null => Nil + case _ => rootConverter.apply(parser) match { + case null => throw new SparkSQLJsonProcessingException("Root converter returned null") + case rows => rows } } - } catch { - case _: JsonProcessingException => - failedRecord(input) - case _: SparkSQLJsonProcessingException => - failedRecord(input) } + } catch { + case _: JsonProcessingException | _: SparkSQLJsonProcessingException => + failedRecord(() => recordLiteral(record)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 780fe51ac699d..cb9493a575643 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -26,14 +26,14 @@ import org.apache.spark.internal.Logging import org.apache.spark.Partition import org.apache.spark.annotation.InterfaceStability import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.JsonInferSchema import 
org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String /** * Interface used to load a [[Dataset]] from external storage systems (e.g. file systems, @@ -261,8 +261,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } /** - * Loads a JSON file (JSON Lines text format or - * newline-delimited JSON) and returns the result as a `DataFrame`. + * Loads a JSON file and returns the results as a `DataFrame`. + * + * Both JSON (one record per file) and JSON Lines + * (newline-delimited JSON) are supported and can be selected with the `wholeFile` option. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -301,6 +303,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * `java.text.SimpleDateFormat`. This applies to timestamp type. *
   * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
   * to be used to parse timestamps.</li>
+   * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines,
+   * per file</li>
  • * * * @since 2.0.0 @@ -332,20 +336,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @since 1.4.0 */ def json(jsonRDD: RDD[String]): DataFrame = { - val parsedOptions: JSONOptions = - new JSONOptions(extraOptions.toMap, sparkSession.sessionState.conf.sessionLocalTimeZone) - val columnNameOfCorruptRecord = - parsedOptions.columnNameOfCorruptRecord - .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val parsedOptions = new JSONOptions( + extraOptions.toMap, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val createParser = CreateJacksonParser.string _ + val schema = userSpecifiedSchema.getOrElse { JsonInferSchema.infer( jsonRDD, - columnNameOfCorruptRecord, - parsedOptions) + parsedOptions, + createParser) } + val parsed = jsonRDD.mapPartitions { iter => - val parser = new JacksonParser(schema, columnNameOfCorruptRecord, parsedOptions) - iter.flatMap(parser.parse) + val parser = new JacksonParser(schema, parsedOptions) + iter.flatMap(parser.parse(_, createParser, UTF8String.fromString)) } Dataset.ofRows( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala index 900263aeb21d6..0762d1b7daaea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CodecStreams.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.execution.datasources -import java.io.{OutputStream, OutputStreamWriter} +import java.io.{InputStream, OutputStream, OutputStreamWriter} import java.nio.charset.{Charset, StandardCharsets} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.compress._ import org.apache.hadoop.mapreduce.JobContext @@ -27,6 +28,20 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.util.ReflectionUtils object CodecStreams { + private def getDecompressionCodec(config: Configuration, file: Path): Option[CompressionCodec] = { + val compressionCodecs = new CompressionCodecFactory(config) + Option(compressionCodecs.getCodec(file)) + } + + def createInputStream(config: Configuration, file: Path): InputStream = { + val fs = file.getFileSystem(config) + val inputStream: InputStream = fs.open(file) + + getDecompressionCodec(config, file) + .map(codec => codec.createInputStream(inputStream)) + .getOrElse(inputStream) + } + private def getCompressionCodec( context: JobContext, file: Option[Path] = None): Option[CompressionCodec] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala new file mode 100644 index 0000000000000..3e984effcb8d8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.json + +import java.io.InputStream + +import scala.reflect.ClassTag + +import com.fasterxml.jackson.core.{JsonFactory, JsonParser} +import com.google.common.io.ByteStreams +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat} + +import org.apache.spark.TaskContext +import org.apache.spark.input.{PortableDataStream, StreamInputFormat} +import org.apache.spark.rdd.{BinaryFileRDD, RDD} +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.execution.datasources.{CodecStreams, HadoopFileLinesReader, PartitionedFile} +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +/** + * Common functions for parsing JSON files + * @tparam T A datatype containing the unparsed JSON, such as [[Text]] or [[String]] + */ +abstract class JsonDataSource[T] extends Serializable { + def isSplitable: Boolean + + /** + * Parse a [[PartitionedFile]] into 0 or more [[InternalRow]] instances + */ + def readFile( + conf: Configuration, + file: PartitionedFile, + parser: JacksonParser): Iterator[InternalRow] + + /** + * Create an [[RDD]] that handles the preliminary parsing of [[T]] records + */ + protected def createBaseRdd( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus]): RDD[T] + + /** + * A generic wrapper to invoke the correct [[JsonFactory]] method to allocate a [[JsonParser]] + * for an instance of [[T]] + */ + def createParser(jsonFactory: JsonFactory, value: T): JsonParser + + final def infer( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus], + parsedOptions: JSONOptions): Option[StructType] = { + if (inputPaths.nonEmpty) { + val jsonSchema = JsonInferSchema.infer( + createBaseRdd(sparkSession, inputPaths), + parsedOptions, + createParser) + checkConstraints(jsonSchema) + Some(jsonSchema) + } else { + None + } + } + + /** Constraints to be imposed on schema to be stored. */ + private def checkConstraints(schema: StructType): Unit = { + if (schema.fieldNames.length != schema.fieldNames.distinct.length) { + val duplicateColumns = schema.fieldNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => "\"" + x + "\"" + }.mkString(", ") + throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + + s"cannot save to JSON format") + } + } +} + +object JsonDataSource { + def apply(options: JSONOptions): JsonDataSource[_] = { + if (options.wholeFile) { + WholeFileJsonDataSource + } else { + TextInputJsonDataSource + } + } + + /** + * Create a new [[RDD]] via the supplied callback if there is at least one file to process, + * otherwise an [[org.apache.spark.rdd.EmptyRDD]] will be returned. 
+ */ + def createBaseRdd[T : ClassTag]( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus])( + fn: (Configuration, String) => RDD[T]): RDD[T] = { + val paths = inputPaths.map(_.getPath) + + if (paths.nonEmpty) { + val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) + FileInputFormat.setInputPaths(job, paths: _*) + fn(job.getConfiguration, paths.mkString(",")) + } else { + sparkSession.sparkContext.emptyRDD[T] + } + } +} + +object TextInputJsonDataSource extends JsonDataSource[Text] { + override val isSplitable: Boolean = { + // splittable if the underlying source is + true + } + + override protected def createBaseRdd( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus]): RDD[Text] = { + JsonDataSource.createBaseRdd(sparkSession, inputPaths) { + case (conf, name) => + sparkSession.sparkContext.newAPIHadoopRDD( + conf, + classOf[TextInputFormat], + classOf[LongWritable], + classOf[Text]) + .setName(s"JsonLines: $name") + .values // get the text column + } + } + + override def readFile( + conf: Configuration, + file: PartitionedFile, + parser: JacksonParser): Iterator[InternalRow] = { + val linesReader = new HadoopFileLinesReader(file, conf) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) + linesReader.flatMap(parser.parse(_, createParser, textToUTF8String)) + } + + private def textToUTF8String(value: Text): UTF8String = { + UTF8String.fromBytes(value.getBytes, 0, value.getLength) + } + + override def createParser(jsonFactory: JsonFactory, value: Text): JsonParser = { + CreateJacksonParser.text(jsonFactory, value) + } +} + +object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] { + override val isSplitable: Boolean = { + false + } + + override protected def createBaseRdd( + sparkSession: SparkSession, + inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = { + JsonDataSource.createBaseRdd(sparkSession, inputPaths) { + case (conf, name) => + new BinaryFileRDD( + sparkSession.sparkContext, + classOf[StreamInputFormat], + classOf[String], + classOf[PortableDataStream], + conf, + sparkSession.sparkContext.defaultMinPartitions) + .setName(s"JsonFile: $name") + .values + } + } + + private def createInputStream(config: Configuration, path: String): InputStream = { + val inputStream = CodecStreams.createInputStream(config, new Path(path)) + Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => inputStream.close())) + inputStream + } + + override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = { + CreateJacksonParser.inputStream( + jsonFactory, + createInputStream(record.getConfiguration, record.getPath())) + } + + override def readFile( + conf: Configuration, + file: PartitionedFile, + parser: JacksonParser): Iterator[InternalRow] = { + def partitionedFileString(ignored: Any): UTF8String = { + Utils.tryWithResource(createInputStream(conf, file.filePath)) { inputStream => + UTF8String.fromBytes(ByteStreams.toByteArray(inputStream)) + } + } + + parser.parse( + createInputStream(conf, file.filePath), + CreateJacksonParser.inputStream, + partitionedFileString).toIterator + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index b4a8ff2cf01ad..2cbf4ea7beaca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -19,15 +19,10 @@ package org.apache.spark.sql.execution.datasources.json import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapred.{JobConf, TextInputFormat} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import org.apache.spark.TaskContext import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.CompressionCodecs @@ -37,29 +32,30 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { + override val shortName: String = "json" - override def shortName(): String = "json" + override def isSplitable( + sparkSession: SparkSession, + options: Map[String, String], + path: Path): Boolean = { + val parsedOptions = new JSONOptions( + options, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val jsonDataSource = JsonDataSource(parsedOptions) + jsonDataSource.isSplitable && super.isSplitable(sparkSession, options, path) + } override def inferSchema( sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - if (files.isEmpty) { - None - } else { - val parsedOptions: JSONOptions = - new JSONOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) - val columnNameOfCorruptRecord = - parsedOptions.columnNameOfCorruptRecord - .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) - val jsonSchema = JsonInferSchema.infer( - createBaseRdd(sparkSession, files), - columnNameOfCorruptRecord, - parsedOptions) - checkConstraints(jsonSchema) - - Some(jsonSchema) - } + val parsedOptions = new JSONOptions( + options, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) + JsonDataSource(parsedOptions).infer( + sparkSession, files, parsedOptions) } override def prepareWrite( @@ -68,8 +64,10 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { val conf = job.getConfiguration - val parsedOptions: JSONOptions = - new JSONOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) + val parsedOptions = new JSONOptions( + options, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) parsedOptions.compressionCodec.foreach { codec => CompressionCodecs.setCodecConfiguration(conf, codec) } @@ -99,47 +97,17 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - val parsedOptions: JSONOptions = - new JSONOptions(options, sparkSession.sessionState.conf.sessionLocalTimeZone) - val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord - 
.getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) + val parsedOptions = new JSONOptions( + options, + sparkSession.sessionState.conf.sessionLocalTimeZone, + sparkSession.sessionState.conf.columnNameOfCorruptRecord) (file: PartitionedFile) => { - val linesReader = new HadoopFileLinesReader(file, broadcastedHadoopConf.value.value) - Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) - val lines = linesReader.map(_.toString) - val parser = new JacksonParser(requiredSchema, columnNameOfCorruptRecord, parsedOptions) - lines.flatMap(parser.parse) - } - } - - private def createBaseRdd( - sparkSession: SparkSession, - inputPaths: Seq[FileStatus]): RDD[String] = { - val job = Job.getInstance(sparkSession.sessionState.newHadoopConf()) - val conf = job.getConfiguration - - val paths = inputPaths.map(_.getPath) - - if (paths.nonEmpty) { - FileInputFormat.setInputPaths(job, paths: _*) - } - - sparkSession.sparkContext.hadoopRDD( - conf.asInstanceOf[JobConf], - classOf[TextInputFormat], - classOf[LongWritable], - classOf[Text]).map(_._2.toString) // get the text line - } - - /** Constraints to be imposed on schema to be stored. */ - private def checkConstraints(schema: StructType): Unit = { - if (schema.fieldNames.length != schema.fieldNames.distinct.length) { - val duplicateColumns = schema.fieldNames.groupBy(identity).collect { - case (x, ys) if ys.length > 1 => "\"" + x + "\"" - }.mkString(", ") - throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + - s"cannot save to JSON format") + val parser = new JacksonParser(requiredSchema, parsedOptions) + JsonDataSource(parsedOptions).readFile( + broadcastedHadoopConf.value.value, + file, + parser) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index f51c18d46f45d..ab09358115c0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -36,13 +36,14 @@ private[sql] object JsonInferSchema { * 2. Merge types by choosing the lowest type necessary to cover equal keys * 3. 
Replace any remaining null fields with string, the top type */ - def infer( - json: RDD[String], - columnNameOfCorruptRecord: String, - configOptions: JSONOptions): StructType = { + def infer[T]( + json: RDD[T], + configOptions: JSONOptions, + createParser: (JsonFactory, T) => JsonParser): StructType = { require(configOptions.samplingRatio > 0, s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0") val shouldHandleCorruptRecord = configOptions.permissive + val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord val schemaData = if (configOptions.samplingRatio > 0.99) { json } else { @@ -55,7 +56,7 @@ private[sql] object JsonInferSchema { configOptions.setJacksonOptions(factory) iter.flatMap { row => try { - Utils.tryWithResource(factory.createParser(row)) { parser => + Utils.tryWithResource(createParser(factory, row)) { parser => parser.nextToken() Some(inferField(parser, configOptions)) } @@ -79,7 +80,7 @@ private[sql] object JsonInferSchema { private[this] val structFieldComparator = new Comparator[StructField] { override def compare(o1: StructField, o2: StructField): Int = { - o1.name.compare(o2.name) + o1.name.compareTo(o2.name) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 4e706da184c0b..99943944f3c6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -141,8 +141,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } /** - * Loads a JSON file stream (JSON Lines text format or - * newline-delimited JSON) and returns the result as a `DataFrame`. + * Loads a JSON file stream and returns the results as a `DataFrame`. + * + * Both JSON (one record per file) and JSON Lines + * (newline-delimited JSON) are supported and can be selected with the `wholeFile` option. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. @@ -183,6 +185,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * `java.text.SimpleDateFormat`. This applies to timestamp type. *
   * <li>`timeZone` (default session local timezone): sets the string that indicates a timezone
   * to be used to parse timestamps.</li>
+   * <li>`wholeFile` (default `false`): parse one record, which may span multiple lines,
+   * per file</li>
  • * * * @since 2.0.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 9344aeda00175..05aa2ab2ce2d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -28,8 +28,8 @@ import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.rdd.RDD import org.apache.spark.SparkException -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} +import org.apache.spark.sql.{functions => F, _} +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType @@ -64,7 +64,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val dummyOption = new JSONOptions(Map.empty[String, String], "GMT") val dummySchema = StructType(Seq.empty) - val parser = new JacksonParser(dummySchema, "", dummyOption) + val parser = new JacksonParser(dummySchema, dummyOption) Utils.tryWithResource(factory.createParser(writer.toString)) { jsonParser => jsonParser.nextToken() @@ -1367,7 +1367,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-6245 JsonRDD.inferSchema on empty RDD") { // This is really a test that it doesn't throw an exception val emptySchema = JsonInferSchema.infer( - empty, "", new JSONOptions(Map.empty[String, String], "GMT")) + empty, + new JSONOptions(Map.empty[String, String], "GMT"), + CreateJacksonParser.string) assert(StructType(Seq()) === emptySchema) } @@ -1392,7 +1394,9 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-8093 Erase empty structs") { val emptySchema = JsonInferSchema.infer( - emptyRecords, "", new JSONOptions(Map.empty[String, String], "GMT")) + emptyRecords, + new JSONOptions(Map.empty[String, String], "GMT"), + CreateJacksonParser.string) assert(StructType(Seq()) === emptySchema) } @@ -1802,4 +1806,142 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { val df2 = spark.read.option("PREfersdecimaL", "true").json(records) assert(df2.schema == schema) } + + test("SPARK-18352: Parse normal multi-line JSON files (compressed)") { + withTempPath { dir => + val path = dir.getCanonicalPath + primitiveFieldAndType + .toDF("value") + .write + .option("compression", "GzIp") + .text(path) + + assert(new File(path).listFiles().exists(_.getName.endsWith(".gz"))) + + val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDir = new File(dir, "json").getCanonicalPath + jsonDF.coalesce(1).write + .option("compression", "gZiP") + .json(jsonDir) + + assert(new File(jsonDir).listFiles().exists(_.getName.endsWith(".json.gz"))) + + val originalData = spark.read.json(primitiveFieldAndType) + checkAnswer(jsonDF, originalData) + checkAnswer(spark.read.schema(originalData.schema).json(jsonDir), originalData) + } + } + + test("SPARK-18352: Parse normal multi-line JSON files (uncompressed)") { + withTempPath { dir => + val path = dir.getCanonicalPath + primitiveFieldAndType + .toDF("value") + .write + .text(path) + + val jsonDF = spark.read.option("wholeFile", true).json(path) + val jsonDir = new 
File(dir, "json").getCanonicalPath + jsonDF.coalesce(1).write.json(jsonDir) + + val compressedFiles = new File(jsonDir).listFiles() + assert(compressedFiles.exists(_.getName.endsWith(".json"))) + + val originalData = spark.read.json(primitiveFieldAndType) + checkAnswer(jsonDF, originalData) + checkAnswer(spark.read.schema(originalData.schema).json(jsonDir), originalData) + } + } + + test("SPARK-18352: Expect one JSON document per file") { + // the json parser terminates as soon as it sees a matching END_OBJECT or END_ARRAY token. + // this might not be the optimal behavior but this test verifies that only the first value + // is parsed and the rest are discarded. + + // alternatively the parser could continue parsing following objects, which may further reduce + // allocations by skipping the line reader entirely + + withTempPath { dir => + val path = dir.getCanonicalPath + spark + .createDataFrame(Seq(Tuple1("{}{invalid}"))) + .coalesce(1) + .write + .text(path) + + val jsonDF = spark.read.option("wholeFile", true).json(path) + // no corrupt record column should be created + assert(jsonDF.schema === StructType(Seq())) + // only the first object should be read + assert(jsonDF.count() === 1) + } + } + + test("SPARK-18352: Handle multi-line corrupt documents (PERMISSIVE)") { + withTempPath { dir => + val path = dir.getCanonicalPath + val corruptRecordCount = additionalCorruptRecords.count().toInt + assert(corruptRecordCount === 5) + + additionalCorruptRecords + .toDF("value") + // this is the minimum partition count that avoids hash collisions + .repartition(corruptRecordCount * 4, F.hash($"value")) + .write + .text(path) + + val jsonDF = spark.read.option("wholeFile", true).option("mode", "PERMISSIVE").json(path) + assert(jsonDF.count() === corruptRecordCount) + assert(jsonDF.schema === new StructType() + .add("_corrupt_record", StringType) + .add("dummy", StringType)) + val counts = jsonDF + .join( + additionalCorruptRecords.toDF("value"), + F.regexp_replace($"_corrupt_record", "(^\\s+|\\s+$)", "") === F.trim($"value"), + "outer") + .agg( + F.count($"dummy").as("valid"), + F.count($"_corrupt_record").as("corrupt"), + F.count("*").as("count")) + checkAnswer(counts, Row(1, 4, 6)) + } + } + + test("SPARK-18352: Handle multi-line corrupt documents (FAILFAST)") { + withTempPath { dir => + val path = dir.getCanonicalPath + val corruptRecordCount = additionalCorruptRecords.count().toInt + assert(corruptRecordCount === 5) + + additionalCorruptRecords + .toDF("value") + // this is the minimum partition count that avoids hash collisions + .repartition(corruptRecordCount * 4, F.hash($"value")) + .write + .text(path) + + val schema = new StructType().add("dummy", StringType) + + // `FAILFAST` mode should throw an exception for corrupt records. + val exceptionOne = intercept[SparkException] { + spark.read + .option("wholeFile", true) + .option("mode", "FAILFAST") + .json(path) + .collect() + } + assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode")) + + val exceptionTwo = intercept[SparkException] { + spark.read + .option("wholeFile", true) + .option("mode", "FAILFAST") + .schema(schema) + .json(path) + .collect() + } + assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode")) + } + } }
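
For quick reference, a minimal usage sketch of the new `wholeFile` option, mirroring the reader calls exercised in the tests above (the input path and application name are hypothetical placeholders):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("wholeFileJsonExample").getOrCreate()

// With wholeFile=true each file under the (hypothetical) path is parsed as a
// single JSON document, which may span multiple lines; the default behavior
// (one JSON record per line) is unchanged when the option is omitted.
val df = spark.read
  .option("wholeFile", true)     // option introduced by this patch
  .option("mode", "PERMISSIVE")  // corrupt records carry the file name, not the raw JSON text
  .json("/path/to/json/files")

df.printSchema()
df.show()
```

The PySpark reader accepts the same setting as the keyword argument `wholeFile=True`, as shown in the `readwriter.py` and `streaming.py` changes.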