[SPARK-18352][SQL] Support parsing multiline json files
## What changes were proposed in this pull request?

If the new option `wholeFile` is set to `true`, the JSON reader parses each file as a single value instead of parsing one value per line. This is done with Jackson streaming, so it should be capable of handling very large documents, assuming each resulting row fits in memory.
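As a hedged illustration (not part of this patch), reading multi-line JSON from the Scala API might look like the following; the session setup and path are assumptions:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("wholeFileJsonExample").getOrCreate()

// With wholeFile=true, each input file is parsed as one JSON value (a single
// object, or an array of objects) instead of one JSON document per line.
val people = spark.read
  .option("wholeFile", "true")
  .json("/path/to/json/dir") // hypothetical directory

people.show()
```

With the default `wholeFile=false`, the same call continues to read JSON Lines input as before.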

Because the file is not buffered in memory, corrupt-record handling is also slightly different when `wholeFile` is enabled: on a parsing failure the corrupt column contains the filename instead of the literal JSON. It would be easy to extend this to add the parser location (line, column and byte offsets) to the output if desired.
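A sketch of how a corrupt record could be inspected under `wholeFile` (the schema, column name and path are illustrative; it reuses the `spark` session from the sketch above):

```scala
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("age", IntegerType),
  StructField("_corrupt_record", StringType)))

// In PERMISSIVE mode, a file that fails to parse yields a row whose corrupt
// column holds the file name rather than the raw JSON text, because the whole
// file is never buffered as a single string.
val df = spark.read
  .schema(schema)
  .option("wholeFile", "true")
  .option("mode", "PERMISSIVE")
  .option("columnNameOfCorruptRecord", "_corrupt_record")
  .json("/path/to/json/dir") // hypothetical directory

df.filter(df("_corrupt_record").isNotNull).show(false)
```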

These changes also allow types other than `String` to be parsed. Support for `UTF8String` and `Text` has been added (alongside `String` and `InputStream`), so these types no longer require a conversion to `String` just for parsing.
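The resulting internal API shape (as exercised by the `JsonToStruct` change later in this diff) looks roughly like this; it is a sketch only, since these classes are `private[sql]` and would have to be used from within the `org.apache.spark.sql` package:

```scala
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

val schema = StructType(Seq(StructField("name", StringType)))
val options = new JSONOptions(Map("mode" -> "FAILFAST"), "UTC")
val parser = new JacksonParser(schema, options)

// Parse a UTF8String record directly; no intermediate java.lang.String copy.
val record = UTF8String.fromString("""{"name": "Michael"}""")
val rows = parser.parse(record, CreateJacksonParser.utf8String, identity[UTF8String])
```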

I've also included a few other changes that generate slightly better bytecode and (in my opinion) make it more obvious when and where boxing occurs in the parser. These are included as separate commits; let me know if they should be squashed into this PR or moved to a new one.

## How was this patch tested?

New and existing unit tests. No performance or load tests have been run.

Author: Nathan Howell <nhowell@godaddy.com>

Closes apache#16386 from NathanHowell/SPARK-18352.
Nathan Howell authored and cloud-fan committed Feb 17, 2017
1 parent dcc2d54 commit 21fde57
Showing 17 changed files with 740 additions and 231 deletions.
@@ -147,7 +147,13 @@ public void writeTo(ByteBuffer buffer) {
buffer.position(pos + numBytes);
}

public void writeTo(OutputStream out) throws IOException {
/**
* Returns a {@link ByteBuffer} wrapping the base object if it is a byte array
* or a copy of the data if the base object is not a byte array.
*
* Unlike getBytes, this will not create a copy of the array if this is a slice.
*/
public @Nonnull ByteBuffer getByteBuffer() {
if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) {
final byte[] bytes = (byte[]) base;

@@ -160,12 +166,20 @@ public void writeTo(OutputStream out) throws IOException {
throw new ArrayIndexOutOfBoundsException();
}

out.write(bytes, (int) arrayOffset, numBytes);
return ByteBuffer.wrap(bytes, (int) arrayOffset, numBytes);
} else {
out.write(getBytes());
return ByteBuffer.wrap(getBytes());
}
}

public void writeTo(OutputStream out) throws IOException {
final ByteBuffer bb = this.getByteBuffer();
assert(bb.hasArray());

// similar to Utils.writeByteBuffer but without the spark-core dependency
out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining());
}

/**
* Returns the number of bytes for a code point with the first byte as `b`
* @param b The first byte of a code point
@@ -29,6 +29,7 @@ import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFil

import org.apache.spark.internal.config
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since

/**
* A general format for reading whole files in as streams, byte arrays,
@@ -175,6 +176,7 @@ class PortableDataStream(
* Create a new DataInputStream from the split and context. The user of this method is responsible
* for closing the stream after usage.
*/
@Since("1.2.0")
def open(): DataInputStream = {
val pathp = split.getPath(index)
val fs = pathp.getFileSystem(conf)
@@ -184,6 +186,7 @@
/**
* Read the file as a byte array
*/
@Since("1.2.0")
def toArray(): Array[Byte] = {
val stream = open()
try {
@@ -193,6 +196,10 @@
}
}

@Since("1.2.0")
def getPath(): String = path

@Since("2.2.0")
def getConfiguration: Configuration = conf
}

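As a hedged aside, the newly annotated accessors can be exercised through `SparkContext.binaryFiles`, which yields `(path, PortableDataStream)` pairs; the path below and the assumption of an existing `SparkContext` named `sc` are illustrative:

```scala
// getPath() and getConfiguration expose the file path and the Hadoop
// configuration without having to open the underlying stream.
val files = sc.binaryFiles("/path/to/binary/files") // hypothetical path
val info = files.map { case (_, pds) =>
  (pds.getPath(), pds.getConfiguration.get("fs.defaultFS"))
}
info.take(5).foreach(println)
```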
13 changes: 8 additions & 5 deletions python/pyspark/sql/readwriter.py
@@ -159,11 +159,12 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
timeZone=None):
timeZone=None, wholeFile=None):
"""
Loads a JSON file (`JSON Lines text format or newline-delimited JSON
<http://jsonlines.org/>`_) or an RDD of Strings storing JSON objects (one object per
record) and returns the result as a :class`DataFrame`.
Loads a JSON file and returns the results as a :class:`DataFrame`.
Both JSON (one record per file) and `JSON Lines <http://jsonlines.org/>`_
(newline-delimited JSON) are supported and can be selected with the `wholeFile` parameter.
If the ``schema`` parameter is not specified, this function goes
through the input once to determine the input schema.
@@ -212,6 +213,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
:param timeZone: sets the string that indicates a timezone to be used to parse timestamps.
If None is set, it uses the default value, session local timezone.
:param wholeFile: parse one record, which may span multiple lines, per file. If None is
set, it uses the default value, ``false``.
>>> df1 = spark.read.json('python/test_support/sql/people.json')
>>> df1.dtypes
@@ -228,7 +231,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat, timeZone=timeZone)
timestampFormat=timestampFormat, timeZone=timeZone, wholeFile=wholeFile)
if isinstance(path, basestring):
path = [path]
if type(path) == list:
14 changes: 9 additions & 5 deletions python/pyspark/sql/streaming.py
@@ -428,11 +428,13 @@ def load(self, path=None, format=None, schema=None, **options):
def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
mode=None, columnNameOfCorruptRecord=None, dateFormat=None,
timestampFormat=None, timeZone=None):
mode=None, columnNameOfCorruptRecord=None, dateFormat=None, timestampFormat=None,
timeZone=None, wholeFile=None):
"""
Loads a JSON file stream (`JSON Lines text format or newline-delimited JSON
<http://jsonlines.org/>`_) and returns a :class`DataFrame`.
Loads a JSON file stream and returns the results as a :class:`DataFrame`.
Both JSON (one record per file) and `JSON Lines <http://jsonlines.org/>`_
(newline-delimited JSON) are supported and can be selected with the `wholeFile` parameter.
If the ``schema`` parameter is not specified, this function goes
through the input once to determine the input schema.
@@ -483,6 +485,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
default value, ``yyyy-MM-dd'T'HH:mm:ss.SSSZZ``.
:param timeZone: sets the string that indicates a timezone to be used to parse timestamps.
If None is set, it uses the default value, session local timezone.
:param wholeFile: parse one record, which may span multiple lines, per file. If None is
set, it uses the default value, ``false``.
>>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema = sdf_schema)
>>> json_sdf.isStreaming
@@ -496,7 +500,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord, dateFormat=dateFormat,
timestampFormat=timestampFormat, timeZone=timeZone)
timestampFormat=timestampFormat, timeZone=timeZone, wholeFile=wholeFile)
if isinstance(path, basestring):
return self._df(self._jreader.json(path))
else:
7 changes: 7 additions & 0 deletions python/pyspark/sql/tests.py
@@ -439,6 +439,13 @@ def test_udf_with_order_by_and_limit(self):
res.explain(True)
self.assertEqual(res.collect(), [Row(id=0, copy=0)])

def test_wholefile_json(self):
from pyspark.sql.types import StringType
people1 = self.spark.read.json("python/test_support/sql/people.json")
people_array = self.spark.read.json("python/test_support/sql/people_array.json",
wholeFile=True)
self.assertEqual(people1.collect(), people_array.collect())

def test_udf_with_input_file_name(self):
from pyspark.sql.functions import udf, input_file_name
from pyspark.sql.types import StringType
13 changes: 13 additions & 0 deletions python/test_support/sql/people_array.json
@@ -0,0 +1,13 @@
[
{
"name": "Michael"
},
{
"name": "Andy",
"age": 30
},
{
"name": "Justin",
"age": 19
}
]
@@ -497,16 +497,20 @@ case class JsonToStruct(
lazy val parser =
new JacksonParser(
schema,
"invalid", // Not used since we force fail fast. Invalid rows will be set to `null`.
new JSONOptions(options ++ Map("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get))
new JSONOptions(options + ("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get))

override def dataType: DataType = schema

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))

override def nullSafeEval(json: Any): Any = {
try parser.parse(json.toString).headOption.orNull catch {
try {
parser.parse(
json.asInstanceOf[UTF8String],
CreateJacksonParser.utf8String,
identity[UTF8String]).headOption.orNull
} catch {
case _: SparkSQLJsonProcessingException => null
}
}
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.json

import java.io.InputStream

import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
import org.apache.hadoop.io.Text

import org.apache.spark.unsafe.types.UTF8String

private[sql] object CreateJacksonParser extends Serializable {
def string(jsonFactory: JsonFactory, record: String): JsonParser = {
jsonFactory.createParser(record)
}

def utf8String(jsonFactory: JsonFactory, record: UTF8String): JsonParser = {
val bb = record.getByteBuffer
assert(bb.hasArray)

jsonFactory.createParser(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
}

def text(jsonFactory: JsonFactory, record: Text): JsonParser = {
jsonFactory.createParser(record.getBytes, 0, record.getLength)
}

def inputStream(jsonFactory: JsonFactory, record: InputStream): JsonParser = {
jsonFactory.createParser(record)
}
}
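These factories let one `JacksonParser` accept different record representations. A rough sketch (not taken from the diff; the parse signature is inferred from the `JsonToStruct` change above) of feeding a whole file through the `inputStream` factory, with the file path used as the corrupt-record literal as described in the PR summary:

```scala
import java.io.InputStream

import org.apache.spark.input.PortableDataStream
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser}
import org.apache.spark.unsafe.types.UTF8String

// Sketch only: parse one whole file per record. If parsing fails, the
// corrupt-record column would carry the file path instead of the JSON text.
def parseWholeFile(parser: JacksonParser, pds: PortableDataStream) = {
  val stream: InputStream = pds.open()
  try {
    parser.parse[InputStream](
      stream,
      CreateJacksonParser.inputStream,
      _ => UTF8String.fromString(pds.getPath()))
  } finally {
    stream.close()
  }
}
```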
@@ -31,11 +31,20 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs
* Most of these map directly to Jackson's internal options, specified in [[JsonParser.Feature]].
*/
private[sql] class JSONOptions(
@transient private val parameters: CaseInsensitiveMap[String], defaultTimeZoneId: String)
@transient private val parameters: CaseInsensitiveMap[String],
defaultTimeZoneId: String,
defaultColumnNameOfCorruptRecord: String)
extends Logging with Serializable {

def this(parameters: Map[String, String], defaultTimeZoneId: String) =
this(CaseInsensitiveMap(parameters), defaultTimeZoneId)
def this(
parameters: Map[String, String],
defaultTimeZoneId: String,
defaultColumnNameOfCorruptRecord: String = "") = {
this(
CaseInsensitiveMap(parameters),
defaultTimeZoneId,
defaultColumnNameOfCorruptRecord)
}

val samplingRatio =
parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
@@ -57,7 +66,8 @@ private[sql] class JSONOptions(
parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false)
val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName)
private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
val columnNameOfCorruptRecord = parameters.get("columnNameOfCorruptRecord")
val columnNameOfCorruptRecord =
parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord)

val timeZone: TimeZone = TimeZone.getTimeZone(parameters.getOrElse("timeZone", defaultTimeZoneId))

@@ -69,6 +79,8 @@
FastDateFormat.getInstance(
parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSZZ"), timeZone, Locale.US)

val wholeFile = parameters.get("wholeFile").map(_.toBoolean).getOrElse(false)

// Parse mode flags
if (!ParseModes.isValidMode(parseMode)) {
logWarning(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.")