读取自定义文件和记录格式

前提:

我们将支持自定义文件格式分为两个任务:

  • 文件格式:使用 tf.data.Dataset 阅读器来从文件中读取原始记录(通常以零阶字符串张量(scalar string tensors)表示,也可能有其他结构)。
  • 记录格式:使用解码器或者解析操作将一个字符串记录转换成 TensorFlow 可用的张量(tensor)。

tf.decode_csv

[TOC]

为文件格式编写一个数据集

tf.data.Dataset

每个实现包含了三个相关的类:

  • 一个 tensorflow::DatasetOpKernel 的子类 (如 TextLineDatasetOp),这个类的 MakeDataset() 方法告诉 TensorFlow 怎样根据一个操作的输入和属性生成一个数据集的对象。

  • 一个 tensorflow::GraphDatasetBase 的子类(如 TextLineDatasetOp::Dataset),表示数据集的不可变性定义,这个类的 MakeIterator() 方法告诉 TensorFlow 怎样在数据集上生成迭代器对象。

  • 一个 tensorflow::DatasetIterator<Dataset> 的子类(如 TextLineDatasetOp::Dataset::Iterator),表示特定数据集上的迭代器的可变性,这个类的 GetNextInternal() 方法告诉 TensorFlow 怎样获取迭代器的下一个元素。

其中最重要的方法是 GetNextInternal(),因为它定义了怎样从文件中实际读取记录,并用一个或多个 Tensor 对象来表示它们。

创建一个新的阅读器数据集叫做(比方说)MyReaderDataset,你需要:

  1. 在 C++ 中定义 tensorflow::DatasetOpKerneltensorflow::GraphDatasetBasetensorflow::DatasetIterator<Dataset> 的子类来实现读取逻辑。
  2. 在 C++ 中注册一个新的名叫 "MyReaderDataset" 的阅读器操作和内核。
  3. tf.data.Dataset

添加一个新操作(Op)

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {
namespace {

class MyReaderDatasetOp : public DatasetOpKernel {
 public:

  MyReaderDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
    // 用 `ctx->GetAttr()` 解析并验证定义数据集的属性,并把它们存在成员变量中。
  }

  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
    // 用 `ctx->input()` 或者通用函数 `ParseScalarArgument<T>(ctx, &arg)` 解析并验证定义数据集的输入张量。

    // 创建数据集对象,并根据属性或输入张量传入(已经验证的)参数。
    *output = new Dataset(ctx);
  }

 private:
  class Dataset : public GraphDatasetBase {
   public:
    Dataset(OpKernelContext* ctx) : GraphDatasetBase(ctx) {}

    std::unique_ptr<IteratorBase> MakeIterator(
        const string& prefix) const override {
      return std::unique_ptr<IteratorBase>(
          new Iterator({this, strings::StrCat(prefix, "::MyReader")}));
    }

    // 记录结构:每个记录用一个零阶字符串张量表示。
    //
    // 数据集的元素有固定数量的组件,每个组件有不同的类型和形状;重写以下两个方法来自定义数据集的这方面的设置。
    const DataTypeVector& output_dtypes() const override {
      static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
      return *dtypes;
    }
    const std::vector<PartialTensorShape>& output_shapes() const override {
      static std::vector<PartialTensorShape>* shapes =
          new std::vector<PartialTensorShape>({{}});
      return *shapes;
    }

    string DebugString() override { return "MyReaderDatasetOp::Dataset"; }

   protected:
    // 可选:数据集的 `GraphDef` 序列化。
    //
    // 如果你想保存这个数据集(和它上面的所有迭代器)的实例,实现以下这个方法。
    Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
                              Node** output) const override {
      // 使用 `b->AddScalar()` 和 `b->AddVector()` 来从这个对象的成员变量构建代表输入张量的节点。
      std::vector<Node*> input_tensors;
      TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output));
      return Status::OK();
    }

   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params), i_(0) {}

      // 读取逻辑的实现。
      //
      // 这个文件中的示例实现产生十次 『MyReader!』 字符串。总的来讲有以下三种情况:
      // 1. 如果成功读取一个元素,在 `*out_tensors` 中将它储存为一个或多个张量,设置 `*end_of_sequence = false` 并返回 `Status::OK()`。
      // 2. 如果到达输入的结尾,设置 `*end_of_sequence = true` 并返回 `Status::OK()`。
      // 3. 如果发生了一个错误,通过 "tensorflow/core/lib/core/errors.h" 中的帮助函数返回一个错误状态。
      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        // 注意:`GetNextInternal()` 可能会被并发调用,所以推荐用一个互斥量来保护迭代器的状态。
        mutex_lock l(mu_);
        if (i_ < 10) {
          // 创建一个零阶字符串张量并把它添加到输出中。
          Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
          record_tensor.scalar<string>()() = "MyReader!";
          out_tensors->emplace_back(std::move(record_tensor));
          ++i_;
          *end_of_sequence = false;
        } else {
          *end_of_sequence = true;
        }
        return Status::OK();
      }

     protected:
      // 可选:迭代器的状态序列化。
      //
      // 如果你想保存和恢复这个迭代器的实例,实现以下两个方法。
      Status SaveInternal(IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
        return Status::OK();
      }
      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
        return Status::OK();
      }

     private:
      mutex mu_;
      int64 i_ GUARDED_BY(mu_);
    };
  };
};

// 为 MyReaderDataset 注册操作定义。
//
// 数据集操作通常只有一个类型为 `variant` 的输出,代表结构化的 `Dataset` 对象。
//
// 在这里添加定义数据集的属性和输入张量。
REGISTER_OP("MyReaderDataset")
    .Output("handle: variant")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// 为 MyReaderDataset 注册核心实现。
REGISTER_KERNEL_BUILDER(Name("MyReaderDataset").Device(DEVICE_CPU),
                        MyReaderDatasetOp);

}  // 命名空间
}  // 命名空间 tensorflow

tf.data.Dataset

import tensorflow as tf

# 假设文件在当前工作目录下。
my_reader_dataset_module = tf.load_op_library("./my_reader_dataset_op.so")

class MyReaderDataset(tf.data.Dataset):

  def __init__(self):
    super(MyReaderDataset, self).__init__()
    # 将输入属性或张量作为类的成员变量创建。

  def _as_variant_tensor(self):
    # 为数据集操作构建图形节点
    #
    # 当你在这个数据集或者由它衍生出来的数据集上创建一个迭代器时,
    # 这个方法会被调用。
    return my_reader_dataset_module.my_reader_dataset()

  # 以下属性定义了元素的结构:一个零阶 `tf.string` 张量。
  # 如果你修改了元素的结构,也需要修改这些属性
  # 来与 `MyReaderDataset::Dataset` 中
  # 的 `output_dtypes()` 和 `output_shapes()` 匹配。
  @property
  def output_types(self):
    return tf.string

  @property
  def output_shapes(self):
    return tf.TensorShape([])

  @property
  def output_classes(self):
    return tf.Tensor

if __name__ == "__main__":
  # 创建一个 MyReaderDataset 并打印它的元素。
  with tf.Session() as sess:
    iterator = MyReaderDataset().make_one_shot_iterator()
    next_element = iterator.get_next()
    try:
      while True:
        print(sess.run(next_element))  # 打印十次 『MyReader!』。
    except tf.errors.OutOfRangeError:
      pass

你可以在 tensorflow/python/data/ops/dataset_ops.py 中找到一些 Dataset 封装类的例子。

为记录格式写一个操作

添加一个新操作(Op)

列几个对解码记录有帮助的操作:

tf.reshape