读取自定义文件和记录格式
前提:
- 对 C++ 有一定程度的了解。
- 通过源码安装 TensorFlow
我们将支持自定义文件格式分为两个任务:
- 文件格式:使用
tf.data.Dataset
阅读器来从文件中读取原始记录(通常以零阶字符串张量(scalar string tensors)表示,也可能有其他结构)。 - 记录格式:使用解码器或者解析操作将一个字符串记录转换成 TensorFlow 可用的张量(tensor)。
[TOC]
为文件格式编写一个数据集
每个实现包含了三个相关的类:
一个
tensorflow::DatasetOpKernel
的子类 (如TextLineDatasetOp
),这个类的MakeDataset()
方法告诉 TensorFlow 怎样根据一个操作的输入和属性生成一个数据集的对象。一个
tensorflow::GraphDatasetBase
的子类(如TextLineDatasetOp::Dataset
),表示数据集的不可变性定义,这个类的MakeIterator()
方法告诉 TensorFlow 怎样在数据集上生成迭代器对象。一个
tensorflow::DatasetIterator<Dataset>
的子类(如TextLineDatasetOp::Dataset::Iterator
),表示特定数据集上的迭代器的可变性,这个类的GetNextInternal()
方法告诉 TensorFlow 怎样获取迭代器的下一个元素。
其中最重要的方法是 GetNextInternal()
,因为它定义了怎样从文件中实际读取记录,并用一个或多个 Tensor
对象来表示它们。
创建一个新的阅读器数据集叫做(比方说)MyReaderDataset
,你需要:
- 在 C++ 中定义
tensorflow::DatasetOpKernel
、tensorflow::GraphDatasetBase
和tensorflow::DatasetIterator<Dataset>
的子类来实现读取逻辑。 - 在 C++ 中注册一个新的名叫
"MyReaderDataset"
的阅读器操作和内核。 tf.data.Dataset
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace {
class MyReaderDatasetOp : public DatasetOpKernel {
public:
MyReaderDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
// 用 `ctx->GetAttr()` 解析并验证定义数据集的属性,并把它们存在成员变量中。
}
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
// 用 `ctx->input()` 或者通用函数 `ParseScalarArgument<T>(ctx, &arg)` 解析并验证定义数据集的输入张量。
// 创建数据集对象,并根据属性或输入张量传入(已经验证的)参数。
*output = new Dataset(ctx);
}
private:
class Dataset : public GraphDatasetBase {
public:
Dataset(OpKernelContext* ctx) : GraphDatasetBase(ctx) {}
std::unique_ptr<IteratorBase> MakeIterator(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::MyReader")}));
}
// 记录结构:每个记录用一个零阶字符串张量表示。
//
// 数据集的元素有固定数量的组件,每个组件有不同的类型和形状;重写以下两个方法来自定义数据集的这方面的设置。
const DataTypeVector& output_dtypes() const override {
static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
return *dtypes;
}
const std::vector<PartialTensorShape>& output_shapes() const override {
static std::vector<PartialTensorShape>* shapes =
new std::vector<PartialTensorShape>({{}});
return *shapes;
}
string DebugString() override { return "MyReaderDatasetOp::Dataset"; }
protected:
// 可选:数据集的 `GraphDef` 序列化。
//
// 如果你想保存这个数据集(和它上面的所有迭代器)的实例,实现以下这个方法。
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
Node** output) const override {
// 使用 `b->AddScalar()` 和 `b->AddVector()` 来从这个对象的成员变量构建代表输入张量的节点。
std::vector<Node*> input_tensors;
TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output));
return Status::OK();
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params), i_(0) {}
// 读取逻辑的实现。
//
// 这个文件中的示例实现产生十次 『MyReader!』 字符串。总的来讲有以下三种情况:
// 1. 如果成功读取一个元素,在 `*out_tensors` 中将它储存为一个或多个张量,设置 `*end_of_sequence = false` 并返回 `Status::OK()`。
// 2. 如果到达输入的结尾,设置 `*end_of_sequence = true` 并返回 `Status::OK()`。
// 3. 如果发生了一个错误,通过 "tensorflow/core/lib/core/errors.h" 中的帮助函数返回一个错误状态。
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
// 注意:`GetNextInternal()` 可能会被并发调用,所以推荐用一个互斥量来保护迭代器的状态。
mutex_lock l(mu_);
if (i_ < 10) {
// 创建一个零阶字符串张量并把它添加到输出中。
Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
record_tensor.scalar<string>()() = "MyReader!";
out_tensors->emplace_back(std::move(record_tensor));
++i_;
*end_of_sequence = false;
} else {
*end_of_sequence = true;
}
return Status::OK();
}
protected:
// 可选:迭代器的状态序列化。
//
// 如果你想保存和恢复这个迭代器的实例,实现以下两个方法。
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
return Status::OK();
}
private:
mutex mu_;
int64 i_ GUARDED_BY(mu_);
};
};
};
// 为 MyReaderDataset 注册操作定义。
//
// 数据集操作通常只有一个类型为 `variant` 的输出,代表结构化的 `Dataset` 对象。
//
// 在这里添加定义数据集的属性和输入张量。
REGISTER_OP("MyReaderDataset")
.Output("handle: variant")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
// 为 MyReaderDataset 注册核心实现。
REGISTER_KERNEL_BUILDER(Name("MyReaderDataset").Device(DEVICE_CPU),
MyReaderDatasetOp);
} // 命名空间
} // 命名空间 tensorflow
import tensorflow as tf
# 假设文件在当前工作目录下。
my_reader_dataset_module = tf.load_op_library("./my_reader_dataset_op.so")
class MyReaderDataset(tf.data.Dataset):
def __init__(self):
super(MyReaderDataset, self).__init__()
# 将输入属性或张量作为类的成员变量创建。
def _as_variant_tensor(self):
# 为数据集操作构建图形节点
#
# 当你在这个数据集或者由它衍生出来的数据集上创建一个迭代器时,
# 这个方法会被调用。
return my_reader_dataset_module.my_reader_dataset()
# 以下属性定义了元素的结构:一个零阶 `tf.string` 张量。
# 如果你修改了元素的结构,也需要修改这些属性
# 来与 `MyReaderDataset::Dataset` 中
# 的 `output_dtypes()` 和 `output_shapes()` 匹配。
@property
def output_types(self):
return tf.string
@property
def output_shapes(self):
return tf.TensorShape([])
@property
def output_classes(self):
return tf.Tensor
if __name__ == "__main__":
# 创建一个 MyReaderDataset 并打印它的元素。
with tf.Session() as sess:
iterator = MyReaderDataset().make_one_shot_iterator()
next_element = iterator.get_next()
try:
while True:
print(sess.run(next_element)) # 打印十次 『MyReader!』。
except tf.errors.OutOfRangeError:
pass
你可以在 tensorflow/python/data/ops/dataset_ops.py
中找到一些 Dataset
封装类的例子。
为记录格式写一个操作
列几个对解码记录有帮助的操作: