From fff69264736f6b478b5a8cf37ea297b2aa9f4214 Mon Sep 17 00:00:00 2001 From: "yuanman.ym" <yuanman.ym@alibaba-inc.com> Date: Mon, 26 Feb 2024 11:38:49 +0800 Subject: [PATCH 1/3] Support default values for filename dataset --- hybridbackend/tensorflow/common/arrow.cc | 150 +++++++++++++----- hybridbackend/tensorflow/common/arrow.h | 10 +- hybridbackend/tensorflow/data/__init__.py | 1 - hybridbackend/tensorflow/data/dataframe.py | 149 +++++------------ .../tensorflow/data/tabular/dataset.cc | 71 ++++++--- .../tensorflow/data/tabular/dataset_v1.py | 12 +- .../tensorflow/data/tabular/dataset_v2.py | 12 +- hybridbackend/tensorflow/data/tabular/orc.cc | 107 +++++++------ hybridbackend/tensorflow/data/tabular/orc.h | 1 + .../tensorflow/data/tabular/parquet.cc | 94 ++++++----- .../tensorflow/data/tabular/parquet.h | 1 + .../tensorflow/data/tabular/table.cc | 17 +- hybridbackend/tensorflow/data/tabular/table.h | 13 +- .../tensorflow/data/tabular/table.py | 25 ++- .../tests/parquet_dataset_deduplicate_test.py | 1 - .../data/tests/parquet_dataset_test.py | 31 +++- 16 files changed, 402 insertions(+), 293 deletions(-) diff --git a/hybridbackend/tensorflow/common/arrow.cc b/hybridbackend/tensorflow/common/arrow.cc index 0fcdb6e1..829fd363 100644 --- a/hybridbackend/tensorflow/common/arrow.cc +++ b/hybridbackend/tensorflow/common/arrow.cc @@ -381,42 +381,104 @@ Status MakeTensorsFromArrowArray( return Status::OK(); } +#define CASE_TENSOR_FILL(ENUM, OUT, IN) \ + case ENUM: { \ + OUT.flat<EnumToDataType<ENUM>::Type>().setConstant( \ + IN.scalar<EnumToDataType<ENUM>::Type>()()); \ + break; \ + } + +Status MakeTensorsFromRecordDefaultValue(const DataType type, + const int32 ragged_rank, + const PartialTensorShape& shape, + const int64 actual_batch_size, + const Tensor& record_default, + std::vector<Tensor>* output_tensors) { + TensorShape actual_shape; + if (!TF_PREDICT_TRUE(PartialTensorShape({actual_batch_size}) + .Concatenate(shape) + .AsTensorShape(&actual_shape))) { + return errors::InvalidArgument( + "Calculated shape of input batch is not fully defined"); + } + Tensor values_tensor(type, actual_shape); + switch (type) { + CASE_TENSOR_FILL(DT_INT8, values_tensor, record_default); + CASE_TENSOR_FILL(DT_UINT8, values_tensor, record_default); + CASE_TENSOR_FILL(DT_INT32, values_tensor, record_default); + CASE_TENSOR_FILL(DT_UINT32, values_tensor, record_default); + CASE_TENSOR_FILL(DT_INT64, values_tensor, record_default); + CASE_TENSOR_FILL(DT_UINT64, values_tensor, record_default); + CASE_TENSOR_FILL(DT_HALF, values_tensor, record_default); + CASE_TENSOR_FILL(DT_FLOAT, values_tensor, record_default); + CASE_TENSOR_FILL(DT_DOUBLE, values_tensor, record_default); + CASE_TENSOR_FILL(DT_STRING, values_tensor, record_default); + default: + return errors::Unimplemented("Data type ", DataTypeString(type), + " not supported."); + } + output_tensors->emplace_back(std::move(values_tensor)); + int32 remained_ragged_rank = ragged_rank; + if (remained_ragged_rank > 0) { + Tensor split_tensor(DT_INT32, {2}); + auto split_tensor_flat = split_tensor.tensor<int32, 1>(); + split_tensor_flat(0) = 0; + split_tensor_flat(1) = shape.num_elements(); + output_tensors->emplace_back(std::move(split_tensor)); + remained_ragged_rank--; + } + while (remained_ragged_rank > 0) { + Tensor split_tensor(DT_INT32, {2}); + auto split_tensor_flat = split_tensor.tensor<int32, 1>(); + split_tensor_flat(0) = 0; + split_tensor_flat(1) = 1; + output_tensors->emplace_back(std::move(split_tensor)); + remained_ragged_rank--; + } + + return Status::OK(); +} + Status ValidateSchema(const string& filename, const std::vector<string>& field_names, const DataTypeVector& field_dtypes, const std::vector<int32>& field_ragged_ranks, std::shared_ptr<::arrow::Schema>& schema, - std::vector<int>* out_column_indices) { + std::vector<int>* columns, + std::vector<int>* field_column_indices) { + if (TF_PREDICT_FALSE(columns == nullptr || field_column_indices == nullptr)) { + return errors::Internal("columns and field_column_indices must be valid"); + } if (TF_PREDICT_FALSE(!schema->HasDistinctFieldNames())) { return errors::InvalidArgument(filename, " must has distinct column names"); } + int column_idx = 0; for (size_t i = 0; i < field_names.size(); ++i) { auto& cname = field_names[i]; - int column_index = schema->GetFieldIndex(cname); - if (TF_PREDICT_FALSE(column_index < 0)) { - return errors::NotFound("No column called `", cname, "` found in ", - filename); - } - if (out_column_indices != nullptr) { - out_column_indices->push_back(column_index); - } - const auto& expected_dtype = field_dtypes[i]; - const auto& expected_ragged_rank = field_ragged_ranks[i]; - DataType actual_dtype; - int32 actual_ragged_rank = 0; - TF_RETURN_IF_ERROR(MakeDataTypeAndRaggedRankFromArrowDataType( - schema->field(column_index)->type(), &actual_dtype, - &actual_ragged_rank)); - if (TF_PREDICT_FALSE(actual_dtype != expected_dtype)) { - return errors::InvalidArgument( - "Field ", cname, " in ", filename, " has unexpected data type ", - DataTypeString(actual_dtype), ", which should be ", - DataTypeString(expected_dtype)); - } - if (TF_PREDICT_FALSE(actual_ragged_rank != expected_ragged_rank)) { - return errors::InvalidArgument( - "Field ", cname, " in ", filename, " has unexpected ragged rank ", - actual_ragged_rank, ", which should be ", expected_ragged_rank); + int column = schema->GetFieldIndex(cname); + if (TF_PREDICT_FALSE(column < 0)) { + field_column_indices->push_back(-1); + } else { + columns->push_back(column); + field_column_indices->push_back(column_idx); + column_idx++; + const auto& expected_dtype = field_dtypes[i]; + const auto& expected_ragged_rank = field_ragged_ranks[i]; + DataType actual_dtype; + int32 actual_ragged_rank = 0; + TF_RETURN_IF_ERROR(MakeDataTypeAndRaggedRankFromArrowDataType( + schema->field(column)->type(), &actual_dtype, &actual_ragged_rank)); + if (TF_PREDICT_FALSE(actual_dtype != expected_dtype)) { + return errors::InvalidArgument( + "Field ", cname, " in ", filename, " has unexpected data type ", + DataTypeString(actual_dtype), ", which should be ", + DataTypeString(expected_dtype)); + } + if (TF_PREDICT_FALSE(actual_ragged_rank != expected_ragged_rank)) { + return errors::InvalidArgument( + "Field ", cname, " in ", filename, " has unexpected ragged rank ", + actual_ragged_rank, ", which should be ", expected_ragged_rank); + } } } return Status::OK(); @@ -424,10 +486,12 @@ Status ValidateSchema(const string& filename, Status ReadRecordBatch(::arrow::RecordBatchReader* batch_reader, const string& filename, const int64 batch_size, + const std::vector<Tensor>& record_defaults, const std::vector<string>& field_names, const DataTypeVector& field_dtypes, const std::vector<int32>& field_ragged_ranks, const std::vector<PartialTensorShape>& field_shapes, + const std::vector<int>& field_column_indices, const bool drop_remainder, const int64 row_limit, std::vector<Tensor>* output_tensors, int64* row_counter) { @@ -450,19 +514,33 @@ Status ReadRecordBatch(::arrow::RecordBatchReader* batch_reader, // Populate tensors from record batch. auto arrays = batch->columns(); - for (size_t i = 0; i < arrays.size(); ++i) { - auto s = - MakeTensorsFromArrowArray(field_dtypes[i], field_ragged_ranks[i], - field_shapes[i], arrays[i], output_tensors); - if (!s.ok()) { - return errors::DataLoss("Failed to parse row #", *row_counter, " - #", - (*row_counter) + batch->num_rows(), " at column ", - field_names[i], " (#", i, ") in ", filename, ": ", - s.error_message()); + const int64 actual_batch_size = batch->num_rows(); + for (size_t i = 0; i < field_names.size(); ++i) { + int column_idx = field_column_indices[i]; + if (column_idx == -1) { + auto s = MakeTensorsFromRecordDefaultValue( + field_dtypes[i], field_ragged_ranks[i], field_shapes[i], + actual_batch_size, record_defaults[i], output_tensors); + if (!s.ok()) { + return errors::DataLoss( + "Failed to populate default value for row #", *row_counter, " - #", + (*row_counter) + actual_batch_size, " at column ", field_names[i], + " (#", i, ") in ", filename, ": ", s.error_message()); + } + } else { + auto s = MakeTensorsFromArrowArray(field_dtypes[i], field_ragged_ranks[i], + field_shapes[i], arrays[column_idx], + output_tensors); + if (!s.ok()) { + return errors::DataLoss("Failed to parse row #", *row_counter, " - #", + (*row_counter) + actual_batch_size, + " at column ", field_names[i], " (#", i, + ") in ", filename, ": ", s.error_message()); + } } } - (*row_counter) += batch->num_rows(); + (*row_counter) += actual_batch_size; return Status::OK(); #else return errors::Unimplemented("HYBRIDBACKEND_WITH_ARROW must be ON"); diff --git a/hybridbackend/tensorflow/common/arrow.h b/hybridbackend/tensorflow/common/arrow.h index 55765449..3f668fed 100644 --- a/hybridbackend/tensorflow/common/arrow.h +++ b/hybridbackend/tensorflow/common/arrow.h @@ -126,19 +126,27 @@ Status MakeTensorsFromArrowArray( const std::shared_ptr<::arrow::Array>& arrow_array, std::vector<Tensor>* output_tensors); +Status MakeTensorsFromDefaultValue(const DataType type, const int32 ragged_rank, + const PartialTensorShape& shape, + const Tensor& default_value, + std::vector<Tensor>* output_tensors); + Status ValidateSchema(const string& filename, const std::vector<string>& field_names, const DataTypeVector& field_dtypes, const std::vector<int32>& field_ragged_ranks, std::shared_ptr<::arrow::Schema>& schema, - std::vector<int>* out_column_indices); + std::vector<int>* columns, + std::vector<int>* field_column_indices); Status ReadRecordBatch(::arrow::RecordBatchReader* batch_reader, const string& filename, const int64 batch_size, + const std::vector<Tensor>& record_defaults, const std::vector<string>& field_names, const DataTypeVector& field_dtypes, const std::vector<int32>& field_ragged_ranks, const std::vector<PartialTensorShape>& field_shapes, + const std::vector<int>& field_column_indices, const bool drop_remainder, const int64 row_limit, std::vector<Tensor>* output_tensors, int64* row_counter); diff --git a/hybridbackend/tensorflow/data/__init__.py b/hybridbackend/tensorflow/data/__init__.py index 74610a26..ad49f79d 100644 --- a/hybridbackend/tensorflow/data/__init__.py +++ b/hybridbackend/tensorflow/data/__init__.py @@ -23,7 +23,6 @@ from hybridbackend.tensorflow.data.dataframe import DataFrame from hybridbackend.tensorflow.data.dataframe import parse -from hybridbackend.tensorflow.data.dataframe import populate_defaults from hybridbackend.tensorflow.data.dataframe import unbatch_and_to_sparse from hybridbackend.tensorflow.data.deduplicate.dataset import deduplicate from hybridbackend.tensorflow.data.prefetch.iterator import Iterator diff --git a/hybridbackend/tensorflow/data/dataframe.py b/hybridbackend/tensorflow/data/dataframe.py index e30aca60..231d6c8e 100644 --- a/hybridbackend/tensorflow/data/dataframe.py +++ b/hybridbackend/tensorflow/data/dataframe.py @@ -169,18 +169,16 @@ def __init__( else: raise self._ragged_rank = ragged_rank - if shape: + if shape is not None: shape = tensor_shape.TensorShape(shape) for d in shape: if d.value is None: raise ValueError( f'Field {name} has incomplete shape: {shape}') - if ragged_rank is not None and ragged_rank > 1: + if shape.ndims > 0 and ragged_rank is not None and ragged_rank > 1: raise ValueError( - f'Field {name} is a nested list ({ragged_rank}) ' - f'with shape {shape}') - else: - shape = tensor_shape.TensorShape([]) + f'Field {name} with shape {shape} should be a fixed-length list ' + f'not a nested list ({ragged_rank})') self._shape = shape self._default_value = default_value self._restore_idx_field = None @@ -209,6 +207,17 @@ def shape(self): def default_value(self): return self._default_value + @property + def default_value_tensor(self): + default_value = 0 + if self._default_value is not None: + default_value = self._default_value + elif self._dtype == dtypes.string: + default_value = '' + elif self._dtype == dtypes.bool: + default_value = False + return ops.convert_to_tensor(default_value, dtype=self._dtype) + @property def restore_idx_field(self): return self._restore_idx_field @@ -455,51 +464,6 @@ def parse(cls, features, pad=False, restore_idx=None): return features raise ValueError(f'{features} not supported') - @classmethod - def populate_defaults(cls, features, all_fields, batch_size): - r'''Populate default values. - ''' - if not isinstance(features, dict): - raise ValueError('Inputs should be a dict') - populated = dict(features) - for f in all_fields: - if f.name not in features and f.default_value is not None: - if batch_size is None: - populated[f.name] = array_ops.expand_dims( - ops.convert_to_tensor(f.default_value, dtype=f.dtype), - axis=0) - else: - if isinstance(f.default_value, sparse_tensor.SparseTensor): - indices_dtype = f.default_value.indices.dtype - indices_car = math_ops.cast( - array_ops.reshape( - array_ops.tile( - math_ops.range(batch_size), - [f.default_value.indices.shape[0]]), - [-1, 1]), - indices_dtype) - indices_cdr = array_ops.tile( - f.default_value.indices, [batch_size, 1]) - batched_indices = array_ops.concat( - [indices_car, indices_cdr], axis=1) - batched_values = array_ops.tile( - f.default_value.values, [batch_size]) - batched_dense_shape = array_ops.concat( - [ops.convert_to_tensor([batch_size], indices_dtype), - f.default_value.dense_shape], - axis=0) - populated[f.name] = sparse_tensor.SparseTensor( - indices=batched_indices, - values=batched_values, - dense_shape=batched_dense_shape) - else: - value = ops.convert_to_tensor(f.default_value, dtype=f.dtype) - populated[f.name] = array_ops.tile( - array_ops.expand_dims(value, axis=0), - tensor_shape.TensorShape([batch_size]).concatenate( - value.shape)) - return populated - @classmethod def to_sparse(cls, features): r'''Convert DataFrame values to tensors or sparse tensors. @@ -565,16 +529,6 @@ def _apply_fn(dataset): return _apply_fn -def populate_defaults(all_fields, batch_size, num_parallel_calls=None): - r'''Populate default values. - ''' - def _apply_fn(dataset): - return dataset.map( - lambda t: DataFrame.populate_defaults(t, all_fields, batch_size), - num_parallel_calls=num_parallel_calls) - return _apply_fn - - def unbatch_and_to_sparse(num_parallel_calls=None): r'''Unbatch and convert a row to tensors or sparse tensors from input dataset. ''' @@ -607,7 +561,7 @@ def input_fields(input_dataset, fields=None): return fields -def build_fields(filename, fn, fields=None, lower=False): +def build_fields(filename=None, fn=None, fields=None, lower=False): r'''Get fields from a file. Args: @@ -618,8 +572,12 @@ def build_fields(filename, fn, fields=None, lower=False): Returns: Field definitions. ''' - logging.info(f'Reading fields from {filename} ...') - all_field_tuples = fn(filename) # pylint: disable=c-extension-no-member + all_field_tuples = [] + if filename is not None and fn is not None: + logging.info(f'Reading fields from {filename} ...') + all_field_tuples = fn(filename) # pylint: disable=c-extension-no-member + if not all_field_tuples: + raise ValueError(f'No field found in file {filename}') all_fields = { f[0]: {'dtype': f[1], 'ragged_rank': f[2]} for f in all_field_tuples} @@ -635,20 +593,17 @@ def build_fields(filename, fn, fields=None, lower=False): dtype=f.dtype, shape=f.shape, ragged_rank=f.ragged_rank) - if f.name not in all_fields: - if f.default_value is None: - raise ValueError( - f'Field {f.name} not found in file {filename}') + if (filename is not None and + f.name not in all_fields and + f.default_value is None): + raise ValueError( + f'Field {f.name} without default value not found in file {filename}') dtype = f.dtype - + ragged_rank = f.ragged_rank if f.name in all_fields: actual_dtype = np.dtype(all_fields[f.name]['dtype']) if dtype is None: dtype = actual_dtype - elif dtype != actual_dtype: - raise ValueError( - f'Field {f.name} dtype should be {actual_dtype} not {dtype}') - ragged_rank = f.ragged_rank actual_ragged_rank = all_fields[f.name]['ragged_rank'] if ragged_rank is None: ragged_rank = actual_ragged_rank @@ -656,46 +611,24 @@ def build_fields(filename, fn, fields=None, lower=False): raise ValueError( f'Field {f.name} ragged_rank should be {actual_ragged_rank} ' f'not {ragged_rank}') - else: + elif filename is not None: if f.default_value is None: raise ValueError( f'Field {f.name} not found in file {filename}') - if isinstance(f.default_value, sparse_tensor.SparseTensor): - actual_dtype = np.dtype(f.default_value.dtype) - if dtype is None: - dtype = actual_dtype - elif dtype != actual_dtype: - raise ValueError( - f'Field {f.name} dtype should be {actual_dtype} not {dtype}') - elif isinstance(f.default_value, ops.Tensor): - actual_dtype = np.dtype(f.default_value.dtype) - if dtype is None: - dtype = actual_dtype - elif dtype != actual_dtype: - raise ValueError( - f'Field {f.name} dtype should be {actual_dtype} not {dtype}') - actual_ragged_rank = 0 - if ragged_rank is None: - ragged_rank = actual_ragged_rank - elif ragged_rank != actual_ragged_rank: - raise ValueError( - f'Field {f.name} ragged_rank should be {actual_ragged_rank} ' - f'not {ragged_rank}') - else: - try: - with ops.name_scope('default_values/'): - _ = ops.convert_to_tensor( - f.default_value, dtype=dtype) - except (TypeError, ValueError) as ex: - raise ValueError( - f'Field {f.name} default_value {f.default_value} ' - f'should be a SparseTensor or Tensor: {ex}') from ex + if f.dtype is None: + raise ValueError( + f'Data type of field {f.name} with default value ' + 'must be specified') + if f.ragged_rank is None: + raise ValueError( + f'Ragged rank of field {f.name} with default value ' + 'must be specified') f = DataFrame.Field( f.name, dtype=dtype, ragged_rank=ragged_rank, - shape=f.shape, - default_value=None if f.name in all_fields else f.default_value) + shape=tensor_shape.TensorShape([]) if f.shape is None else f.shape, + default_value=f.default_value) new_fields.append(f) continue if not isinstance(f, string): @@ -710,7 +643,7 @@ def build_fields(filename, fn, fields=None, lower=False): f, dtype=np.dtype(all_fields[f]['dtype']), ragged_rank=all_fields[f]['ragged_rank'], - shape=None)) + shape=tensor_shape.TensorShape([]))) return tuple(new_fields) @@ -752,6 +685,7 @@ def build_filenames_and_fields(filenames, fn, fields, lower=False): if f.incomplete: raise ValueError( f'Field {f} is incomplete, please specify dtype and ragged_rank') + fields = build_fields(fields=fields) elif isinstance(filenames, ops.Tensor): if filenames.dtype != dtypes.string: raise TypeError( @@ -766,6 +700,7 @@ def build_filenames_and_fields(filenames, fn, fields, lower=False): if f.incomplete: raise ValueError( f'Field {f} is incomplete, please specify dtype and ragged_rank') + fields = build_fields(fields=fields) else: raise ValueError( f'`filenames` {filenames} must be a `tf.data.Dataset` of scalar ' diff --git a/hybridbackend/tensorflow/data/tabular/dataset.cc b/hybridbackend/tensorflow/data/tabular/dataset.cc index 51643c63..83b43ebb 100644 --- a/hybridbackend/tensorflow/data/tabular/dataset.cc +++ b/hybridbackend/tensorflow/data/tabular/dataset.cc @@ -41,6 +41,7 @@ REGISTER_OP("HbTabularDataset") .Output("handle: variant") .Input("filename: string") .Input("batch_size: int64") + .Input("record_defaults: field_dtypes") .Attr("format: int") .Attr("field_names: list(string) >= 1") .Attr("field_dtypes: list(type) >= 1") @@ -53,8 +54,18 @@ REGISTER_OP("HbTabularDataset") .SetIsStateful() .SetShapeFn([](shape_inference::InferenceContext* c) { shape_inference::ShapeHandle unused; - // batch_size should be a scalar. + // `batch_size` must be a scalar. TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + // `record_defaults` must be a list of scalars + for (size_t i = 2; i < c->num_inputs(); ++i) { + shape_inference::ShapeHandle v; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v)); + if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) { + return errors::InvalidArgument( + "Shape of a default must be a length-0 or length-1 vector, or a " + "scalar."); + } + } return shape_inference::ScalarShape(c); }) .Doc(R"doc( @@ -63,6 +74,7 @@ A dataset that outputs batches from a file. handle: The handle to reference the dataset. filename: Path of file to read. batch_size: Maxium number of samples in an output batch. +record_defaults: Default values of a sample in an output batch. format: File format to use. field_names: List of field names to read. field_dtypes: List of data types for each field. @@ -157,6 +169,13 @@ class TabularDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(b->AddScalar(access_->filename(), &filename)); Node* batch_size; TF_RETURN_IF_ERROR(b->AddScalar(access_->batch_size(), &batch_size)); + std::vector<Node*> record_defaults; + record_defaults.reserve(access_->record_defaults().size()); + for (const Tensor& t : access_->record_defaults()) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + record_defaults.emplace_back(node); + } AttrValue format; b->BuildAttrValue(access_->format(), &format); AttrValue field_names; @@ -175,18 +194,18 @@ class TabularDatasetOp::Dataset : public DatasetBase { b->BuildAttrValue(partition_count_, &partition_count); AttrValue partition_index; b->BuildAttrValue(partition_index_, &partition_index); - TF_RETURN_IF_ERROR( - b->AddDataset(this, {{0, filename}, {1, batch_size}}, {}, - {{"format", format}, - {"field_names", field_names}, - {"field_dtypes", field_dtypes}, - {"field_ragged_ranks", field_ragged_ranks}, - {"field_shapes", field_shapes}, - {"drop_remainder", drop_remainder}, - {"skip_corrupted_data", skip_corrupted_data}, - {"partition_count", partition_count}, - {"partition_index", partition_index}}, - output)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, {{0, filename}, {1, batch_size}}, {{2, record_defaults}}, + {{"format", format}, + {"field_names", field_names}, + {"field_dtypes", field_dtypes}, + {"field_ragged_ranks", field_ragged_ranks}, + {"field_shapes", field_shapes}, + {"drop_remainder", drop_remainder}, + {"skip_corrupted_data", skip_corrupted_data}, + {"partition_count", partition_count}, + {"partition_index", partition_index}}, + output)); return Status::OK(); } @@ -197,7 +216,7 @@ class TabularDatasetOp::Dataset : public DatasetBase { std::unique_ptr<TableAccess> access_; int64 partition_count_; int64 partition_index_; -}; +}; // namespace hybridbackend void TabularDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) { string filename; @@ -214,10 +233,26 @@ void TabularDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) { OP_REQUIRES(ctx, partition_index_ >= 0, errors::InvalidArgument("Partition index ", partition_index_, "must be greater than 0")); - TableAccess* access = - TableAccess::Create(ctx, format_, filename, batch_size, field_names_, - field_dtypes_, field_ragged_ranks_, field_shapes_, - drop_remainder_, skip_corrupted_data_); + + OpInputList record_defaults_list; + std::vector<Tensor> record_defaults; + OP_REQUIRES_OK(ctx, + ctx->input_list("record_defaults", &record_defaults_list)); + for (int i = 0; i < record_defaults_list.size(); ++i) { + OP_REQUIRES(ctx, record_defaults_list[i].dims() <= 1, + errors::InvalidArgument( + "Each record default should be at most rank 1")); + OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2, + errors::InvalidArgument( + "There should only be 1 default per field but field ", i, + " has ", record_defaults_list[i].NumElements())); + record_defaults.push_back(record_defaults_list[i]); + } + + TableAccess* access = TableAccess::Create( + ctx, format_, filename, batch_size, std::move(record_defaults), + field_names_, field_dtypes_, field_ragged_ranks_, field_shapes_, + drop_remainder_, skip_corrupted_data_); int64 count = access->Count(); if (TF_PREDICT_FALSE(partition_count_ > 1)) { int64 partition_size = count / partition_count_; diff --git a/hybridbackend/tensorflow/data/tabular/dataset_v1.py b/hybridbackend/tensorflow/data/tabular/dataset_v1.py index 8f4483a0..e8c7ea2b 100644 --- a/hybridbackend/tensorflow/data/tabular/dataset_v1.py +++ b/hybridbackend/tensorflow/data/tabular/dataset_v1.py @@ -71,8 +71,9 @@ def __init__( self._output_classes = {f.name: f.output_classes for f in self._fields} self._output_types = {f.name: f.output_types for f in self._fields} self._output_shapes = {f.name: f.output_shapes for f in self._fields} + self._record_defaults = nest.flatten( + {f.name: f.default_value_tensor for f in self._fields}) self._field_names = nest.flatten({f.name: f.name for f in self._fields}) - self._field_dtypes = nest.flatten({f.name: f.dtype for f in self._fields}) self._field_ragged_ranks = nest.flatten( {f.name: f.ragged_rank for f in self._fields}) self._field_shapes = nest.flatten({f.name: f.shape for f in self._fields}) @@ -86,9 +87,9 @@ def _as_variant_tensor(self): return _ops.hb_tabular_dataset( self._filename, self._batch_size, + self._record_defaults, format=self._format, field_names=self._field_names, - field_dtypes=self._field_dtypes, field_ragged_ranks=self._field_ragged_ranks, field_shapes=self._field_shapes, partition_count=self._partition_count, @@ -452,9 +453,8 @@ def __init__( can only be set once. if `tf.data.experimental.AUTOTUNE` is used, the number of parsers would be set for best performance. ''' - filenames, self._all_fields = build_filenames_and_fields( + filenames, self._fields = build_filenames_and_fields( filenames, parquet_file_get_fields, fields) - self._fields = [f for f in self._all_fields if f.default_value is None] self._partition_count = partition_count self._partition_index = partition_index self._skip_corrupted_data = skip_corrupted_data @@ -494,10 +494,6 @@ def _create_dataset(f): _create_dataset, filenames, num_parallel_reads, num_sequential_reads) super().__init__() - @property - def all_fields(self): - return self._all_fields - @property def fields(self): return self._fields diff --git a/hybridbackend/tensorflow/data/tabular/dataset_v2.py b/hybridbackend/tensorflow/data/tabular/dataset_v2.py index 054eacc6..01957137 100644 --- a/hybridbackend/tensorflow/data/tabular/dataset_v2.py +++ b/hybridbackend/tensorflow/data/tabular/dataset_v2.py @@ -74,8 +74,9 @@ def __init__( raise ValueError( f'Field {f} is incomplete, please specify dtype and ragged_rank') self._output_specs = {f.name: f.build_spec() for f in self._fields} + self._record_defaults = nest.flatten( + {f.name: f.default_value_tensor for f in self._fields}) self._field_names = nest.flatten({f.name: f.name for f in self._fields}) - self._field_dtypes = nest.flatten({f.name: f.dtype for f in self._fields}) self._field_ragged_ranks = nest.flatten( {f.name: f.ragged_rank for f in self._fields}) self._field_shapes = nest.flatten({f.name: f.shape for f in self._fields}) @@ -87,9 +88,9 @@ def __init__( variant_tensor = _ops.hb_tabular_dataset( self._filename, self._batch_size, + self._record_defaults, format=self._format, field_names=self._field_names, - field_dtypes=self._field_dtypes, field_ragged_ranks=self._field_ragged_ranks, field_shapes=self._field_shapes, partition_count=self._partition_count, @@ -432,9 +433,8 @@ def __init__( can only be set once. if `tf.data.experimental.AUTOTUNE` is used, the number of parsers would be set for best performance. ''' - filenames, self._all_fields = build_filenames_and_fields( + filenames, self._fields = build_filenames_and_fields( filenames, parquet_file_get_fields, fields) - self._fields = [f for f in self._all_fields if f.default_value is None] self._partition_count = partition_count self._partition_index = partition_index self._skip_corrupted_data = skip_corrupted_data @@ -476,10 +476,6 @@ def _create_dataset(f): num_sequential_reads=num_sequential_reads) super().__init__(self._impl._variant_tensor) # pylint: disable=protected-access - @property - def all_fields(self): - return self._all_fields - @property def fields(self): return self._fields diff --git a/hybridbackend/tensorflow/data/tabular/orc.cc b/hybridbackend/tensorflow/data/tabular/orc.cc index 1626a204..239a1267 100644 --- a/hybridbackend/tensorflow/data/tabular/orc.cc +++ b/hybridbackend/tensorflow/data/tabular/orc.cc @@ -17,6 +17,7 @@ limitations under the License. #include <absl/strings/match.h> +#include <algorithm> #include <memory> #include <numeric> #include <string> @@ -30,33 +31,39 @@ namespace hybridbackend { class OrcAccess::Impl { public: Impl(OpKernelContext* ctx, const string& filename, const int64 batch_size, + const std::vector<Tensor>& record_defaults, const std::vector<string>& field_names, const DataTypeVector& field_dtypes, - const std::vector<int32>& field_ragged_ranks) - : field_names_(field_names), + const std::vector<int32>& field_ragged_ranks, + const std::vector<PartialTensorShape>& field_shapes, + const bool drop_remainder) + : filename_(filename), batch_size_(batch_size), + record_defaults_(record_defaults), + field_names_(field_names), + field_dtypes_(field_dtypes), + field_ragged_ranks_(field_ragged_ranks), + field_shapes_(field_shapes), + drop_remainder_(drop_remainder), + start_row_(-1), + end_row_(-1), row_(0), end_(0), - start_row_(-1), - end_row_(-1) { - OP_REQUIRES_OK(ctx, Initialize(filename, field_names, field_dtypes, - field_ragged_ranks)); + row_counter_(0) { + OP_REQUIRES_OK(ctx, Initialize()); } - Status Initialize(const string& filename, - const std::vector<string>& field_names, - const DataTypeVector& field_dtypes, - const std::vector<int32>& field_ragged_ranks) { + Status Initialize() { #if HYBRIDBACKEND_ARROW ::arrow::internal::Uri uri; - if (::arrow::fs::internal::IsLikelyUri(filename)) { - TF_RETURN_IF_ARROW_ERROR(uri.Parse(filename)); + if (::arrow::fs::internal::IsLikelyUri(filename_)) { + TF_RETURN_IF_ARROW_ERROR(uri.Parse(filename_)); } else { if (TF_PREDICT_FALSE( - !::arrow::fs::internal::IsLikelyUri("file://" + filename))) { - return errors::InvalidArgument("File name ", filename, " is illegal"); + !::arrow::fs::internal::IsLikelyUri("file://" + filename_))) { + return errors::InvalidArgument("File name ", filename_, " is illegal"); } - TF_RETURN_IF_ARROW_ERROR(uri.Parse("file://" + filename)); + TF_RETURN_IF_ARROW_ERROR(uri.Parse("file://" + filename_)); } std::vector<std::pair<std::string, std::string>> uri_options; TF_RETURN_IF_ARROW_ERROR(uri.query_items().Value(&uri_options)); @@ -69,14 +76,15 @@ class OrcAccess::Impl { } TF_RETURN_IF_ARROW_ERROR( - ::hybridbackend::OpenArrowFile(&fs_, &file_, filename)); + ::hybridbackend::OpenArrowFile(&fs_, &file_, filename_)); TF_RETURN_IF_ARROW_ERROR( ::hybridbackend::OpenOrcReader(&reader_, file_, true)); std::shared_ptr<::arrow::Schema> schema; TF_RETURN_IF_ARROW_ERROR(reader_->ReadSchema().Value(&schema)); - TF_RETURN_IF_ERROR(ValidateSchema(filename, field_names, field_dtypes, - field_ragged_ranks, schema, &columns_)); + TF_RETURN_IF_ERROR(ValidateSchema(filename_, field_names_, field_dtypes_, + field_ragged_ranks_, schema, &columns_, + &field_column_indices_)); #endif return Status::OK(); } @@ -130,28 +138,25 @@ class OrcAccess::Impl { #endif } - Status Read(const string& filename, const int64 batch_size, - const std::vector<string>& field_names, - const DataTypeVector& field_dtypes, - const std::vector<int32>& field_ragged_ranks, - const std::vector<PartialTensorShape>& field_shapes, - const bool drop_remainder, std::vector<Tensor>* output_tensors) { + Status Read(std::vector<Tensor>* output_tensors) { #if HYBRIDBACKEND_ARROW - auto s = - ReadRecordBatch(batch_reader_.get(), filename, batch_size, field_names, - field_dtypes, field_ragged_ranks, field_shapes, - drop_remainder, end_, output_tensors, &row_counter_); + Status s = ReadRecordBatch(batch_reader_.get(), filename_, batch_size_, + record_defaults_, field_names_, field_dtypes_, + field_ragged_ranks_, field_shapes_, + field_column_indices_, drop_remainder_, end_, + output_tensors, &row_counter_); while (TF_PREDICT_FALSE(errors::IsOutOfRange(s)) && row_counter_ < end_) { TF_RETURN_IF_ARROW_ERROR(reader_->NextStripeReader(batch_size_, columns_) .Value(&batch_reader_)); if (!batch_reader_) { - Close(filename); + Close(filename_); return s; } - s = ReadRecordBatch(batch_reader_.get(), filename, batch_size, - field_names, field_dtypes, field_ragged_ranks, - field_shapes, drop_remainder, end_, output_tensors, - &row_counter_); + s = ReadRecordBatch(batch_reader_.get(), filename_, batch_size_, + record_defaults_, field_names_, field_dtypes_, + field_ragged_ranks_, field_shapes_, + field_column_indices_, drop_remainder_, end_, + output_tensors, &row_counter_); } return s; #else @@ -161,33 +166,43 @@ class OrcAccess::Impl { private: #if HYBRIDBACKEND_ARROW - int64 row_counter_; std::shared_ptr<::arrow::fs::FileSystem> fs_; std::shared_ptr<::arrow::io::RandomAccessFile> file_; std::unique_ptr<arrow::adapters::orc::ORCFileReader> reader_; std::shared_ptr<::arrow::RecordBatchReader> batch_reader_; - std::vector<string> field_names_; +#endif + const string filename_; + const int64 batch_size_; + const std::vector<Tensor> record_defaults_; + const std::vector<string>& field_names_; + const DataTypeVector& field_dtypes_; + const std::vector<int32>& field_ragged_ranks_; + const std::vector<PartialTensorShape>& field_shapes_; + const bool drop_remainder_; + std::vector<int> columns_; - int64 batch_size_; - int64 row_; - int64 end_; + std::vector<int> field_column_indices_; int64 start_row_; int64 end_row_; -#endif + int64 row_; + int64 end_; + int64 row_counter_; }; OrcAccess::OrcAccess(OpKernelContext* ctx, const TableFormat& format, const string& filename, const int64 batch_size, + const std::vector<Tensor>& record_defaults, const std::vector<string>& field_names, const DataTypeVector& field_dtypes, const std::vector<int32>& field_ragged_ranks, const std::vector<PartialTensorShape>& field_shapes, const bool drop_remainder, const bool skip_corrupted_data) - : TableAccess(format, filename, batch_size, field_names, field_dtypes, - field_ragged_ranks, field_shapes, drop_remainder, - skip_corrupted_data), - pimpl_(new OrcAccess::Impl(ctx, filename, batch_size, field_names, - field_dtypes, field_ragged_ranks)) {} + : TableAccess(format, filename, batch_size, record_defaults, field_names, + field_dtypes, field_ragged_ranks, field_shapes, + drop_remainder, skip_corrupted_data), + pimpl_(new OrcAccess::Impl(ctx, filename, batch_size, record_defaults, + field_names, field_dtypes, field_ragged_ranks, + field_shapes, drop_remainder)) {} int64 OrcAccess::Count() const { return pimpl_->Count(); } @@ -198,9 +213,7 @@ Status OrcAccess::Open(const int64 start, const int64 end) { } Status OrcAccess::Read(std::vector<Tensor>* output_tensors) { - return pimpl_->Read(filename(), batch_size(), field_names(), field_dtypes(), - field_ragged_ranks(), field_shapes(), drop_remainder(), - output_tensors); + return pimpl_->Read(output_tensors); } OrcAccess::~OrcAccess() { pimpl_->Close(filename()); } diff --git a/hybridbackend/tensorflow/data/tabular/orc.h b/hybridbackend/tensorflow/data/tabular/orc.h index bef171b3..133a0d8f 100644 --- a/hybridbackend/tensorflow/data/tabular/orc.h +++ b/hybridbackend/tensorflow/data/tabular/orc.h @@ -31,6 +31,7 @@ class OrcAccess : public TableAccess { public: OrcAccess(OpKernelContext* ctx, const TableFormat& format, const string& filename, const int64 batch_size, + const std::vector<Tensor>& record_defaults, const std::vector<string>& field_names, const DataTypeVector& field_dtypes, const std::vector<int32>& field_ragged_ranks, diff --git a/hybridbackend/tensorflow/data/tabular/parquet.cc b/hybridbackend/tensorflow/data/tabular/parquet.cc index 655ad9d0..cb08e854 100644 --- a/hybridbackend/tensorflow/data/tabular/parquet.cc +++ b/hybridbackend/tensorflow/data/tabular/parquet.cc @@ -17,6 +17,7 @@ limitations under the License. #include <absl/strings/match.h> +#include <algorithm> #include <memory> #include <numeric> #include <string> @@ -30,28 +31,37 @@ namespace hybridbackend { class ParquetAccess::Impl { public: Impl(OpKernelContext* ctx, const string& filename, const int64 batch_size, + const std::vector<Tensor>& record_defaults, const std::vector<string>& field_names, const DataTypeVector& field_dtypes, - const std::vector<int32>& field_ragged_ranks) - : start_row_(-1), end_row_(-1) { - OP_REQUIRES_OK(ctx, Initialize(filename, batch_size, field_names, - field_dtypes, field_ragged_ranks)); + const std::vector<int32>& field_ragged_ranks, + const std::vector<PartialTensorShape>& field_shapes, + const bool drop_remainder) + : filename_(filename), + batch_size_(batch_size), + record_defaults_(record_defaults), + field_names_(field_names), + field_dtypes_(field_dtypes), + field_ragged_ranks_(field_ragged_ranks), + field_shapes_(field_shapes), + drop_remainder_(drop_remainder), + start_row_(-1), + end_row_(-1), + row_counter_(0) { + OP_REQUIRES_OK(ctx, Initialize()); } - Status Initialize(const string& filename, const int64 batch_size, - const std::vector<string>& field_names, - const DataTypeVector& field_dtypes, - const std::vector<int32>& field_ragged_ranks) { + Status Initialize() { #if HYBRIDBACKEND_ARROW ::arrow::internal::Uri uri; - if (::arrow::fs::internal::IsLikelyUri(filename)) { - TF_RETURN_IF_ARROW_ERROR(uri.Parse(filename)); + if (::arrow::fs::internal::IsLikelyUri(filename_)) { + TF_RETURN_IF_ARROW_ERROR(uri.Parse(filename_)); } else { if (TF_PREDICT_FALSE( - !::arrow::fs::internal::IsLikelyUri("file://" + filename))) { - return errors::InvalidArgument("File name ", filename, " is illegal"); + !::arrow::fs::internal::IsLikelyUri("file://" + filename_))) { + return errors::InvalidArgument("File name ", filename_, " is illegal"); } - TF_RETURN_IF_ARROW_ERROR(uri.Parse("file://" + filename)); + TF_RETURN_IF_ARROW_ERROR(uri.Parse("file://" + filename_)); } std::vector<std::pair<std::string, std::string>> uri_options; TF_RETURN_IF_ARROW_ERROR(uri.query_items().Value(&uri_options)); @@ -64,14 +74,15 @@ class ParquetAccess::Impl { } TF_RETURN_IF_ARROW_ERROR( - ::hybridbackend::OpenArrowFile(&fs_, &file_, filename)); + ::hybridbackend::OpenArrowFile(&fs_, &file_, filename_)); TF_RETURN_IF_ARROW_ERROR( ::hybridbackend::OpenParquetReader(&reader_, file_, true)); std::shared_ptr<::arrow::Schema> schema; TF_RETURN_IF_ARROW_ERROR(reader_->GetSchema(&schema)); - TF_RETURN_IF_ERROR(ValidateSchema(filename, field_names, field_dtypes, - field_ragged_ranks, schema, &columns_)); - reader_->set_batch_size(batch_size); + TF_RETURN_IF_ERROR(ValidateSchema(filename_, field_names_, field_dtypes_, + field_ragged_ranks_, schema, &columns_, + &field_column_indices_)); + reader_->set_batch_size(batch_size_); #endif return Status::OK(); } @@ -124,19 +135,15 @@ class ParquetAccess::Impl { #endif } - Status Read(const string& filename, const int64 batch_size, - const std::vector<string>& field_names, - const DataTypeVector& field_dtypes, - const std::vector<int32>& field_ragged_ranks, - const std::vector<PartialTensorShape>& field_shapes, - const bool drop_remainder, std::vector<Tensor>* output_tensors) { + Status Read(std::vector<Tensor>* output_tensors) { #if HYBRIDBACKEND_ARROW - auto s = - ReadRecordBatch(batch_reader_.get(), filename, batch_size, field_names, - field_dtypes, field_ragged_ranks, field_shapes, - drop_remainder, -1, output_tensors, &row_counter_); + Status s = ReadRecordBatch(batch_reader_.get(), filename_, batch_size_, + record_defaults_, field_names_, field_dtypes_, + field_ragged_ranks_, field_shapes_, + field_column_indices_, drop_remainder_, -1, + output_tensors, &row_counter_); if (TF_PREDICT_FALSE(errors::IsOutOfRange(s))) { - Close(filename); + Close(filename_); } return s; #else @@ -146,29 +153,40 @@ class ParquetAccess::Impl { private: #if HYBRIDBACKEND_ARROW - int64 row_counter_; std::shared_ptr<::arrow::fs::FileSystem> fs_; std::shared_ptr<::arrow::io::RandomAccessFile> file_; std::unique_ptr<::parquet::arrow::FileReader> reader_; std::unique_ptr<::arrow::RecordBatchReader> batch_reader_; #endif + const string filename_; + const int64 batch_size_; + const std::vector<Tensor> record_defaults_; + const std::vector<string>& field_names_; + const DataTypeVector& field_dtypes_; + const std::vector<int32>& field_ragged_ranks_; + const std::vector<PartialTensorShape>& field_shapes_; + const bool drop_remainder_; + std::vector<int> columns_; + std::vector<int> field_column_indices_; int64 start_row_; int64 end_row_; + int64 row_counter_; }; ParquetAccess::ParquetAccess( OpKernelContext* ctx, const TableFormat& format, const string& filename, - const int64 batch_size, const std::vector<string>& field_names, - const DataTypeVector& field_dtypes, + const int64 batch_size, const std::vector<Tensor>& record_defaults, + const std::vector<string>& field_names, const DataTypeVector& field_dtypes, const std::vector<int32>& field_ragged_ranks, const std::vector<PartialTensorShape>& field_shapes, const bool drop_remainder, const bool skip_corrupted_data) - : TableAccess(format, filename, batch_size, field_names, field_dtypes, - field_ragged_ranks, field_shapes, drop_remainder, - skip_corrupted_data), - pimpl_(new ParquetAccess::Impl(ctx, filename, batch_size, field_names, - field_dtypes, field_ragged_ranks)) {} + : TableAccess(format, filename, batch_size, record_defaults, field_names, + field_dtypes, field_ragged_ranks, field_shapes, + drop_remainder, skip_corrupted_data), + pimpl_(new ParquetAccess::Impl( + ctx, filename, batch_size, record_defaults, field_names, field_dtypes, + field_ragged_ranks, field_shapes, drop_remainder)) {} int64 ParquetAccess::Count() const { return pimpl_->Count(); } @@ -179,9 +197,7 @@ Status ParquetAccess::Open(const int64 start, const int64 end) { } Status ParquetAccess::Read(std::vector<Tensor>* output_tensors) { - return pimpl_->Read(filename(), batch_size(), field_names(), field_dtypes(), - field_ragged_ranks(), field_shapes(), drop_remainder(), - output_tensors); + return pimpl_->Read(output_tensors); } ParquetAccess::~ParquetAccess() { pimpl_->Close(filename()); } diff --git a/hybridbackend/tensorflow/data/tabular/parquet.h b/hybridbackend/tensorflow/data/tabular/parquet.h index 31a24eba..f192f162 100644 --- a/hybridbackend/tensorflow/data/tabular/parquet.h +++ b/hybridbackend/tensorflow/data/tabular/parquet.h @@ -31,6 +31,7 @@ class ParquetAccess : public TableAccess { public: ParquetAccess(OpKernelContext* ctx, const TableFormat& format, const string& filename, const int64 batch_size, + const std::vector<Tensor>& record_defaults, const std::vector<string>& field_names, const DataTypeVector& field_dtypes, const std::vector<int32>& field_ragged_ranks, diff --git a/hybridbackend/tensorflow/data/tabular/table.cc b/hybridbackend/tensorflow/data/tabular/table.cc index 7006096d..2dfaed9e 100644 --- a/hybridbackend/tensorflow/data/tabular/table.cc +++ b/hybridbackend/tensorflow/data/tabular/table.cc @@ -33,21 +33,22 @@ namespace hybridbackend { TableAccess* TableAccess::Create( OpKernelContext* ctx, const TableFormat& format, const string& filename, - const int64 batch_size, const std::vector<string>& field_names, - const DataTypeVector& field_dtypes, + const int64 batch_size, const std::vector<Tensor>& record_defaults, + const std::vector<string>& field_names, const DataTypeVector& field_dtypes, const std::vector<int32>& field_ragged_ranks, const std::vector<PartialTensorShape>& field_shapes, const bool drop_remainder, const bool skip_corrupted_data) { switch (format) { case kParquetFormat: - return new ParquetAccess(ctx, format, filename, batch_size, field_names, - field_dtypes, field_ragged_ranks, field_shapes, - drop_remainder, skip_corrupted_data); + return new ParquetAccess(ctx, format, filename, batch_size, + record_defaults, field_names, field_dtypes, + field_ragged_ranks, field_shapes, drop_remainder, + skip_corrupted_data); break; case kOrcFormat: - return new OrcAccess(ctx, format, filename, batch_size, field_names, - field_dtypes, field_ragged_ranks, field_shapes, - drop_remainder, skip_corrupted_data); + return new OrcAccess(ctx, format, filename, batch_size, record_defaults, + field_names, field_dtypes, field_ragged_ranks, + field_shapes, drop_remainder, skip_corrupted_data); break; default: LOG(ERROR) << "File format " << format << " is not supported"; diff --git a/hybridbackend/tensorflow/data/tabular/table.h b/hybridbackend/tensorflow/data/tabular/table.h index 3499d8ef..7013f78c 100644 --- a/hybridbackend/tensorflow/data/tabular/table.h +++ b/hybridbackend/tensorflow/data/tabular/table.h @@ -35,14 +35,17 @@ class TableAccess { public: static TableAccess* Create( OpKernelContext* ctx, const TableFormat& format, const string& filename, - const int64 batch_size, const std::vector<string>& field_names, + const int64 batch_size, const std::vector<Tensor>& record_defaults, + const std::vector<string>& field_names, const DataTypeVector& field_dtypes, const std::vector<int32>& field_ragged_ranks, const std::vector<PartialTensorShape>& field_shapes, const bool drop_remainder, const bool skip_corrupted_data); TableAccess(const TableFormat& format, const string& filename, - const int64 batch_size, const std::vector<string>& field_names, + const int64 batch_size, + const std::vector<Tensor>& record_defaults, + const std::vector<string>& field_names, const DataTypeVector& field_dtypes, const std::vector<int32>& field_ragged_ranks, const std::vector<PartialTensorShape>& field_shapes, @@ -50,6 +53,7 @@ class TableAccess { : format_(format), filename_(std::move(filename)), batch_size_(batch_size), + record_defaults_(std::move(record_defaults)), field_names_(std::move(field_names)), field_dtypes_(std::move(field_dtypes)), field_ragged_ranks_(std::move(field_ragged_ranks)), @@ -63,6 +67,10 @@ class TableAccess { int64 batch_size() const { return batch_size_; } + const std::vector<Tensor>& record_defaults() const { + return record_defaults_; + } + const std::vector<string>& field_names() const { return field_names_; } const DataTypeVector& field_dtypes() const { return field_dtypes_; } @@ -93,6 +101,7 @@ class TableAccess { const TableFormat format_; const string filename_; const int64 batch_size_; + const std::vector<Tensor> record_defaults_; const std::vector<string> field_names_; const DataTypeVector field_dtypes_; const std::vector<int32> field_ragged_ranks_; diff --git a/hybridbackend/tensorflow/data/tabular/table.py b/hybridbackend/tensorflow/data/tabular/table.py index dc1fc340..52e390f0 100644 --- a/hybridbackend/tensorflow/data/tabular/table.py +++ b/hybridbackend/tensorflow/data/tabular/table.py @@ -29,7 +29,6 @@ from tensorflow.python.framework import ops from hybridbackend.tensorflow.data.dataframe import parse -from hybridbackend.tensorflow.data.dataframe import populate_defaults from hybridbackend.tensorflow.data.deduplicate.dataset import deduplicate from hybridbackend.tensorflow.data.rebatch.dataset import RebatchDataset @@ -81,7 +80,6 @@ def __init__( self._fn = fn self._filenames = filenames self._fields = fields if field_map_fn is None else field_map_fn(fields) - self._valid_fields = [f for f in self._fields if f.default_value is None] self._partition_count = partition_count self._partition_index = partition_index self._skip_corrupted_data = skip_corrupted_data @@ -97,9 +95,9 @@ def __init__( else: num_parallel_reads = 1 if num_parallel_parser_calls is None: - num_parallel_parser_calls = len(self._valid_fields) + num_parallel_parser_calls = len(self._fields) if num_parallel_parser_calls == dataset_ops.AUTOTUNE: - num_parallel_parser_calls = len(self._valid_fields) + num_parallel_parser_calls = len(self._fields) if num_parallel_reads is not None and num_parallel_parser_calls is not None: arrow_num_threads = os.getenv('ARROW_NUM_THREADS', None) if arrow_num_threads is None: @@ -161,7 +159,7 @@ def _create_dataset(self, batch_size): ''' def _creator(filename): filename = ops.convert_to_tensor(filename, dtypes.string, name='filename') - return self._fn(filename, self._valid_fields, batch_size) + return self._fn(filename, self._fields, batch_size) if self._num_parallel_reads == 1: return self._filenames.flat_map(_creator) @@ -186,9 +184,8 @@ def read(self): ds = ds.apply( deduplicate( self._key_idx_field_names, - self._value_field_names, fields=self._valid_fields)) + self._value_field_names, fields=self._fields)) ds = ds.apply(parse(pad=self._to_dense)) - ds = ds.apply(populate_defaults(self._fields, None)) return ds.unbatch() def batch(self, batch_size, drop_remainder=False): @@ -220,12 +217,11 @@ def batch(self, batch_size, drop_remainder=False): ds = ds.apply( deduplicate( self._key_idx_field_names, - self._value_field_names, fields=self._valid_fields)) + self._value_field_names, fields=self._fields)) ds = RebatchDataset( - ds, self._valid_fields, batch_size, + ds, self._fields, batch_size, drop_remainder=drop_remainder) - ds = ds.apply(parse(pad=self._to_dense)) - return ds.apply(populate_defaults(self._fields, batch_size)) + return ds.apply(parse(pad=self._to_dense)) def shuffle_batch( self, batch_size, @@ -264,12 +260,11 @@ def shuffle_batch( ds = ds.apply( deduplicate( self._key_idx_field_names, - self._value_field_names, fields=self._valid_fields)) + self._value_field_names, fields=self._fields)) ds = RebatchDataset( - ds, self._valid_fields, batch_size, + ds, self._fields, batch_size, drop_remainder=drop_remainder, shuffle_buffer_size=buffer_size, shuffle_seed=seed, reshuffle_each_iteration=reshuffle_each_iteration) - ds = ds.apply(parse(pad=self._to_dense)) - return ds.apply(populate_defaults(self._fields, batch_size)) + return ds.apply(parse(pad=self._to_dense)) diff --git a/hybridbackend/tensorflow/data/tests/parquet_dataset_deduplicate_test.py b/hybridbackend/tensorflow/data/tests/parquet_dataset_deduplicate_test.py index a5de4cba..ec35bc69 100644 --- a/hybridbackend/tensorflow/data/tests/parquet_dataset_deduplicate_test.py +++ b/hybridbackend/tensorflow/data/tests/parquet_dataset_deduplicate_test.py @@ -81,7 +81,6 @@ def test_apply_to_tensor(self): batch_size=2) ds = srcds.apply(hb.data.deduplicate(['user_feat_idx'], [['user_feat']])) ds = ds.apply(hb.data.parse(pad=True)) - ds = ds.apply(hb.data.populate_defaults(srcds.all_fields, 2)) batch = tf.data.make_one_shot_iterator(ds).get_next() baseline = tf.ragged.constant( self._data_user_feat_duplicated.to_pylist()).to_tensor() diff --git a/hybridbackend/tensorflow/data/tests/parquet_dataset_test.py b/hybridbackend/tensorflow/data/tests/parquet_dataset_test.py index b4ce6855..eedd6114 100644 --- a/hybridbackend/tensorflow/data/tests/parquet_dataset_test.py +++ b/hybridbackend/tensorflow/data/tests/parquet_dataset_test.py @@ -80,9 +80,11 @@ def test_read_with_defaults(self): with tf.Graph().as_default() as graph: ds = hb.data.Dataset.from_parquet( self._filename, - fields=[hb.data.DataFrame.Field('A', tf.int64), + fields=[hb.data.DataFrame.Field('A', tf.int64, + default_value=123), hb.data.DataFrame.Field('X', tf.int64, - default_value=default_value), + default_value=default_value, + ragged_rank=0), hb.data.DataFrame.Field('C', tf.int64)]) ds = ds.batch(batch_size) ds = ds.prefetch(4) @@ -175,6 +177,31 @@ def gen_filenames(): with self.assertRaises(tf.errors.OutOfRangeError): sess.run(batch) + def test_read_from_generator_with_defaults(self): + num_epochs = 2 + batch_size = 100 + with tf.Graph().as_default() as graph: + def gen_filenames(): + for i in xrange(num_epochs + 1): + if i == num_epochs: + return # raise StopIteration + yield self._filename + filenames = tf.data.Dataset.from_generator( + gen_filenames, tf.string, tf.TensorShape([])) + fields = [ + hb.data.DataFrame.Field('A', tf.int64, 0), + hb.data.DataFrame.Field('C', tf.int64, 0), + hb.data.DataFrame.Field('PI', tf.float32, 0, default_value=3.14)] + ds = filenames.apply(hb.data.read_parquet(batch_size, fields=fields)) + ds = ds.prefetch(4) + batch = tf.data.make_one_shot_iterator(ds).get_next() + + with tf.Session(graph=graph) as sess: + for _ in xrange(len(self._df) * num_epochs // batch_size): + sess.run(batch) + with self.assertRaises(tf.errors.OutOfRangeError): + sess.run(batch) + def test_read_from_generator_parallel(self): num_epochs = 2 batch_size = 100 From c74f1a309b67df51525086a6b760b72d8adeb938 Mon Sep 17 00:00:00 2001 From: "yuanman.ym" <yuanman.ym@alibaba-inc.com> Date: Sat, 20 Apr 2024 01:34:02 +0800 Subject: [PATCH 2/3] fix --- hybridbackend/tensorflow/common/arrow.cc | 2 +- .../data/tests/parquet_dataset_ragged_test.py | 36 ++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/hybridbackend/tensorflow/common/arrow.cc b/hybridbackend/tensorflow/common/arrow.cc index 829fd363..1af6c7ee 100644 --- a/hybridbackend/tensorflow/common/arrow.cc +++ b/hybridbackend/tensorflow/common/arrow.cc @@ -423,7 +423,7 @@ Status MakeTensorsFromRecordDefaultValue(const DataType type, Tensor split_tensor(DT_INT32, {2}); auto split_tensor_flat = split_tensor.tensor<int32, 1>(); split_tensor_flat(0) = 0; - split_tensor_flat(1) = shape.num_elements(); + split_tensor_flat(1) = actual_shape.num_elements(); output_tensors->emplace_back(std::move(split_tensor)); remained_ragged_rank--; } diff --git a/hybridbackend/tensorflow/data/tests/parquet_dataset_ragged_test.py b/hybridbackend/tensorflow/data/tests/parquet_dataset_ragged_test.py index 4ba88e30..4629bbe3 100644 --- a/hybridbackend/tensorflow/data/tests/parquet_dataset_ragged_test.py +++ b/hybridbackend/tensorflow/data/tests/parquet_dataset_ragged_test.py @@ -118,7 +118,41 @@ def test_to_sparse(self): len(set(list(zip(*actual.indices))[0])) + 1, len(expected.nested_row_splits[0])) - def test_map_to_sparse(self): + def test_to_sparse_with_defaults(self): + batch_size = 32 + with tf.Graph().as_default() as graph: + ds = hb.data.Dataset.from_parquet( + [self._filename], + fields = [ + hb.data.DataFrame.Field('col2', tf.int64, 1), + hb.data.DataFrame.Field('col0', tf.int64, 1), + hb.data.DataFrame.Field('nocol', tf.int64, 1, default_value=8)]) + ds = ds.batch(batch_size) + ds = ds.prefetch(4) + batch = tf.data.make_one_shot_iterator(ds).get_next() + + c = self._df['col0'] + with tf.Session(graph=graph) as sess: + for i in xrange(3): + result = sess.run(batch) + start_row = i * batch_size + end_row = (i + 1) * batch_size + expected_items = c[start_row:end_row].to_numpy().tolist() + expected_values = [] + expected_splits = [0] + for item in expected_items: + expected_values.extend(item) + expected_splits.append(expected_splits[-1] + len(item)) + expected = hb.data.DataFrame.Value( + np.array(expected_values), + [np.array(expected_splits, dtype=np.int32)]) + actual = result['col0'] + np.testing.assert_allclose(actual.values, expected.values) + np.testing.assert_equal( + len(set(list(zip(*actual.indices))[0])) + 1, + len(expected.nested_row_splits[0])) + + def xtest_map_to_sparse(self): batch_size = 32 with tf.Graph().as_default() as graph: ds = hb.data.Dataset.from_parquet( From f2c700ecd6e48fcb6016c704443a5ca46a726d45 Mon Sep 17 00:00:00 2001 From: "yuanman.ym" <yuanman.ym@alibaba-inc.com> Date: Sat, 20 Apr 2024 13:31:15 +0800 Subject: [PATCH 3/3] Fix ragged tensor support --- hybridbackend/tensorflow/common/arrow.cc | 16 +++++++--------- hybridbackend/tensorflow/data/dataframe.py | 5 +++++ 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/hybridbackend/tensorflow/common/arrow.cc b/hybridbackend/tensorflow/common/arrow.cc index 1af6c7ee..223d1a20 100644 --- a/hybridbackend/tensorflow/common/arrow.cc +++ b/hybridbackend/tensorflow/common/arrow.cc @@ -419,19 +419,17 @@ Status MakeTensorsFromRecordDefaultValue(const DataType type, } output_tensors->emplace_back(std::move(values_tensor)); int32 remained_ragged_rank = ragged_rank; - if (remained_ragged_rank > 0) { - Tensor split_tensor(DT_INT32, {2}); - auto split_tensor_flat = split_tensor.tensor<int32, 1>(); - split_tensor_flat(0) = 0; - split_tensor_flat(1) = actual_shape.num_elements(); - output_tensors->emplace_back(std::move(split_tensor)); - remained_ragged_rank--; + int64 stride = shape.num_elements(); + if (stride < 1) { + stride = 1; } while (remained_ragged_rank > 0) { - Tensor split_tensor(DT_INT32, {2}); + Tensor split_tensor(DT_INT32, {actual_batch_size + 1}); auto split_tensor_flat = split_tensor.tensor<int32, 1>(); split_tensor_flat(0) = 0; - split_tensor_flat(1) = 1; + for (size_t i = 0; i < actual_batch_size; i++) { + split_tensor_flat(i + 1) = (i + 1) * stride; + } output_tensors->emplace_back(std::move(split_tensor)); remained_ragged_rank--; } diff --git a/hybridbackend/tensorflow/data/dataframe.py b/hybridbackend/tensorflow/data/dataframe.py index 231d6c8e..ad7d7491 100644 --- a/hybridbackend/tensorflow/data/dataframe.py +++ b/hybridbackend/tensorflow/data/dataframe.py @@ -180,6 +180,11 @@ def __init__( f'Field {name} with shape {shape} should be a fixed-length list ' f'not a nested list ({ragged_rank})') self._shape = shape + if (default_value is not None and + not isinstance(default_value, (int, float, bool, str))): + raise ValueError( + f'Default value of field {name} must be a scalar ' + f'instead of {type(default_value)}') self._default_value = default_value self._restore_idx_field = None