TensorFlow Source Code Analysis: Tensor

 Tensor

As the name suggests, TensorFlow is made up of "tensor" and "flow". This article looks at what a Tensor is and how it is implemented.

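A tensor in TensorFlow can be viewed abstractly as a triple <n-dimensional array, element type (dtype), size of each dimension (shape)>, together with a set of operations on that triple: create, destroy, copy, reshape, slice, and so on. A simplified definition in C++: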

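#include <vector>

struct Tensor {
    std::vector<int> shape; // size of each dimension, e.g. a 3-D array: shape={2,3,4}
    int dtype;              // element type; tells us which array type to cast data to
    void *data;             // contiguous memory holding every element of the array
};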

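shape can be changed: a 2x3 array can also be viewed as 3x2, as long as the number of elements stays the same.

data is a contiguous block of memory, laid out like the C++ array T[2][3]. If dtype is an integer type, it corresponds to int data[2][3].

data is a pointer. Cast it to int* and data, data+1, data+2, ..., data+5 address the six elements in turn.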
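The reshape rule and the flat layout described above can be sketched in a few lines. This is only an illustration under simplifying assumptions: SimpleTensor, NumElements, Reshape and Offset2D are hypothetical names, the element type is fixed to int, and the layout is row-major.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical simplified tensor: int elements only, row-major layout.
struct SimpleTensor {
  std::vector<int> shape;  // size of each dimension
  std::vector<int> data;   // all elements, stored contiguously
};

// Number of elements implied by a shape (product of all dimension sizes).
std::size_t NumElements(const std::vector<int>& shape) {
  std::size_t n = 1;
  for (int d : shape) n *= static_cast<std::size_t>(d);
  return n;
}

// Reshape only rewrites the shape; the element count must stay the same.
void Reshape(SimpleTensor& t, const std::vector<int>& new_shape) {
  assert(NumElements(new_shape) == t.data.size());
  t.shape = new_shape;
}

// Row-major offset of element (i, j) in a 2-D tensor: i * cols + j.
std::size_t Offset2D(const SimpleTensor& t, int i, int j) {
  assert(t.shape.size() == 2);
  return static_cast<std::size_t>(i) * static_cast<std::size_t>(t.shape[1]) +
         static_cast<std::size_t>(j);
}
```

Reshaping never touches data, so element (1, 0) sits at offset 3 while the shape is {2, 3} and at offset 2 once the shape becomes {3, 2}.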

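Slicing is also possible: for example, taking index 1 along the first dimension gives data[1]. Because the array is 2x3, data[0] and data[1] each hold 3 elements. The sliced tensor still references the original tensor's memory, and reference counting ensures that the slice remains usable even after the original tensor itself has been destroyed.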
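A minimal sketch of this sharing behaviour follows. It is not TensorFlow's implementation: SharedTensor and MakeTensor are hypothetical names, and std::shared_ptr stands in for the reference-counted buffer.

```cpp
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical 2-D view over a shared, reference-counted buffer.
struct SharedTensor {
  std::shared_ptr<std::vector<int>> buf;  // shared storage
  std::size_t offset;                     // first element of this view in buf
  std::size_t rows;
  std::size_t cols;

  // Element (i, j) of this view, row-major.
  int& at(std::size_t i, std::size_t j) {
    return (*buf)[offset + i * cols + j];
  }

  // Rows [start, limit) of the first dimension; no element is copied.
  SharedTensor Slice(std::size_t start, std::size_t limit) const {
    return SharedTensor{buf, offset + start * cols, limit - start, cols};
  }
};

// Allocates a zero-filled rows x cols tensor.
SharedTensor MakeTensor(std::size_t rows, std::size_t cols) {
  return SharedTensor{std::make_shared<std::vector<int>>(rows * cols, 0), 0,
                      rows, cols};
}
```

When the original SharedTensor goes out of scope, the slice holds the last reference, so the buffer is freed only once the slice is gone too.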

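Tensor implementation

A real implementation has more to consider. data must be aligned, for example to 8 bytes. The way data is allocated also matters: TensorFlow defines an Allocator interface so that different allocation strategies can be plugged in. And because model parameters and checkpoints have to be saved, tensors must support serialization; TensorFlow uses protobuf for this.

Tensor is implemented in: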

ls tensorflow/core/framework/tensor.*
tensor.cc     tensor.h      tensor.proto  

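Basic operations

Tensor members

  TensorShape shape_; // shape
  TensorBuffer* buf_; // data

Construction

  • Default construction: the result is not a scalar but an empty one-dimensional tensor with shape {0} and NumElements() == 0.
  • From type + shape, which allocates memory: Tensor(DataType type, const TensorShape& shape); uses the CPU allocator by default.
  • From allocator + type + shape: Tensor(Allocator* a, DataType type, const TensorShape& shape);
  • From an existing buffer: Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf);
  • From a constant (scalar) value, with many overloads such as explicit Tensor(float scalar_value)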

Slicing

Slices along the first dimension without copying data. The result is not guaranteed to be aligned; check with IsAligned().

Tensor Slice(int64_t dim0_start, int64_t dim0_limit) const;
Tensor SubSlice(int64_t index) const;
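Why a slice can lose alignment is easy to see from the byte offsets involved. The sketch below is an illustration only: kAlignment = 64 is an assumed value, and IsAlignedOffset and SliceByteOffset are hypothetical helpers, not TensorFlow functions. It assumes the parent buffer starts on an aligned address.

```cpp
#include <cstddef>

// Assumed alignment requirement, for illustration only.
constexpr std::size_t kAlignment = 64;

// Byte offset at which row `start` begins in a row-major [rows, cols] array
// whose elements are elem_size bytes each.
std::size_t SliceByteOffset(std::size_t start, std::size_t cols,
                            std::size_t elem_size) {
  return start * cols * elem_size;
}

// A slice of an aligned buffer is aligned only if its byte offset is a
// multiple of the alignment.
bool IsAlignedOffset(std::size_t byte_offset) {
  return byte_offset % kAlignment == 0;
}
```

With 4-byte elements, a row of 3 elements is 12 bytes long, so the slice starting at row 1 begins 12 bytes into the buffer and is not 64-byte aligned. A row of 16 elements is 64 bytes long, and every row start stays aligned.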

Serialization

  bool FromProto(const TensorProto& other) TF_MUST_USE_RESULT;
  bool FromProto(Allocator* a, const TensorProto& other) TF_MUST_USE_RESULT;

  /// \brief Fills in `proto` with `*this` tensor's content.
  ///
  /// `AsProtoField()` fills in the repeated field for `proto.dtype()`, while
  /// `AsProtoTensorContent()` encodes the content in `proto.tensor_content()`
  /// in a compact form.
  void AsProtoField(TensorProto* proto) const;
  void AsProtoTensorContent(TensorProto* proto) const;
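The difference between the two encodings can be illustrated without protobuf. The sketch below is not the TensorFlow code: FakeTensorProto, AsField, AsContent and FromContent are hypothetical names, only float data is handled, and the bytes are stored in the host's native byte order.

```cpp
#include <cstring>
#include <string>
#include <vector>

// Hypothetical stand-in for the two float encodings in TensorProto.
struct FakeTensorProto {
  std::vector<float> float_val;  // typed repeated field, one entry per element
  std::string tensor_content;    // raw bytes of the whole buffer
};

// Field-style encoding: copy the values into the typed repeated field.
void AsField(const std::vector<float>& data, FakeTensorProto* proto) {
  proto->float_val = data;
}

// Content-style encoding: copy the raw bytes of the buffer in one go.
void AsContent(const std::vector<float>& data, FakeTensorProto* proto) {
  proto->tensor_content.assign(reinterpret_cast<const char*>(data.data()),
                               data.size() * sizeof(float));
}

// Decode the raw bytes back into floats.
std::vector<float> FromContent(const FakeTensorProto& proto) {
  std::vector<float> out(proto.tensor_content.size() / sizeof(float));
  if (!out.empty()) {
    std::memcpy(out.data(), proto.tensor_content.data(),
                out.size() * sizeof(float));
  }
  return out;
}
```

The byte-string form is a single memcpy in each direction, which is why it only makes sense for element types that can be copied bytewise.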

Copying

  1. Both copy construction and move construction are supported
  2. operator= supports both copy and move assignment

Access

/// Returns the data type.
DataType dtype() const { return shape_.data_type(); }

/// Returns the shape of the tensor.
const TensorShape& shape() const { return shape_; }

/// \brief Convenience accessor for the tensor shape.
///
/// For all shape accessors, see comments for relevant methods of
/// `TensorShape` in `tensor_shape.h`.
int dims() const { return shape().dims(); }

/// Convenience accessor for the tensor shape.
int64_t dim_size(int d) const { return shape().dim_size(d); }

/// Convenience accessor for the tensor shape.
int64_t NumElements() const { return shape().num_elements(); }

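size_t AllocatedBytes() const;
bool IsAligned() const;
bool CopyFrom(const Tensor& other, const TensorShape& shape);

// Usage sketch: t0, t1 and t2 stand for float tensors holding a single
// element, one dimension and two dimensions respectively.
auto s = t0.scalar<float>();  // access as a scalar: s()
auto v = t1.vec<float>();     // access as a 1-D array: v(0)
auto m = t2.matrix<float>();  // access as a matrix: m(2, 3)

// Element-by-element access through the flattened view
auto flat = t2.flat<float>();
float* d = flat.data();
for (int64_t i = 0; i < t2.NumElements(); i++) d[i] = 0.0f;
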
  template <typename T>
  typename TTypes<T>::Flat flat() {
    return shaped<T, 1>({NumElements()});
  }

  template <typename T>
  typename TTypes<T>::UnalignedFlat unaligned_flat() {
    return unaligned_shaped<T, 1>({NumElements()});
  }

// For use with memcpy
/// REQUIRES: `DataTypeCanUseMemcpy(dtype())`.
StringPiece tensor_data() const;
void* data() const;

Debug information


  std::string SummarizeValue(int64_t max_entries, bool print_v2 = false) const;
  std::string DebugString(int num_values) const;
  std::string DebugString() const { return DebugString(3); }
  std::string DeviceSafeDebugString() const;
  void FillDescription(TensorDescription* description) const;

Tensor, shape, and type are implemented in the files listed below.

tensor.h: TensorBuffer refers to the underlying data memory.

$ ls tensorflow/core/framework/tensor* 
tensorflow/core/framework/tensor.cc                 tensorflow/core/framework/tensor_shape.proto    tensorflow/core/framework/tensor_testutil.h
tensorflow/core/framework/tensor.h                  tensorflow/core/framework/tensor_shape_test.cc  tensorflow/core/framework/tensor_testutil_test.cc
tensorflow/core/framework/tensor.proto              tensorflow/core/framework/tensor_slice.cc       tensorflow/core/framework/tensor_types.h
tensorflow/core/framework/tensor_description.proto  tensorflow/core/framework/tensor_slice.h        tensorflow/core/framework/tensor_util.cc
tensorflow/core/framework/tensor_key.h              tensorflow/core/framework/tensor_slice.proto    tensorflow/core/framework/tensor_util.h
tensorflow/core/framework/tensor_reference.h        tensorflow/core/framework/tensor_slice_test.cc  tensorflow/core/framework/tensor_util_test.cc
tensorflow/core/framework/tensor_shape.cc           tensorflow/core/framework/tensor_test.cc
tensorflow/core/framework/tensor_shape.h            tensorflow/core/framework/tensor_testutil.cc

$ ls tensorflow/core/framework/shape* 
tensorflow/core/framework/shape_inference.cc  tensorflow/core/framework/shape_inference_test.cc      tensorflow/core/framework/shape_inference_testutil.h
tensorflow/core/framework/shape_inference.h   tensorflow/core/framework/shape_inference_testutil.cc  tensorflow/core/framework/shape_inference_testutil_test.cc

$ ls tensorflow/core/framework/type* 
tensorflow/core/framework/type_index.h   tensorflow/core/framework/typed_allocator.cc  tensorflow/core/framework/types.cc  tensorflow/core/framework/types.proto
tensorflow/core/framework/type_traits.h  tensorflow/core/framework/typed_allocator.h   tensorflow/core/framework/types.h   tensorflow/core/framework/types_test.cc

Data types supported by Tensor

Defined in tensorflow/core/framework/types.proto

enum DataType {
  // Not a legal value for DataType.  Used to indicate a DataType field
  // has not been set.
  DT_INVALID = 0;

  // Data types that all computation devices are expected to be
  // capable to support.
  DT_FLOAT = 1;
  DT_DOUBLE = 2;
  DT_INT32 = 3;
  DT_UINT8 = 4;
  DT_INT16 = 5;
  DT_INT8 = 6;
  DT_STRING = 7;
  DT_COMPLEX64 = 8;  // Single-precision complex
  DT_INT64 = 9;
  DT_BOOL = 10;
  DT_QINT8 = 11;     // Quantized int8
  DT_QUINT8 = 12;    // Quantized uint8
  DT_QINT32 = 13;    // Quantized int32
  DT_BFLOAT16 = 14;  // Float32 truncated to 16 bits.  Only for cast ops.
  DT_QINT16 = 15;    // Quantized int16
  DT_QUINT16 = 16;   // Quantized uint16
  DT_UINT16 = 17;
  DT_COMPLEX128 = 18;  // Double-precision complex
  DT_HALF = 19;
  DT_RESOURCE = 20;
  DT_VARIANT = 21;  // Arbitrary C++ data types
  DT_UINT32 = 22;
  DT_UINT64 = 23;
}
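The dtype value is what lets code interpret the raw buffer, starting with how many bytes one element occupies. The sketch below covers only a subset of the enum above. ElementSize and BufferBytes are hypothetical helpers written for this illustration, and returning 0 for types without a fixed size is an assumption of the sketch.

```cpp
#include <cstddef>
#include <cstdint>

// Subset of the DataType values listed above, with the same numbering.
enum DataType {
  DT_INVALID = 0,
  DT_FLOAT = 1,
  DT_DOUBLE = 2,
  DT_INT32 = 3,
  DT_UINT8 = 4,
  DT_INT16 = 5,
  DT_INT8 = 6,
  DT_STRING = 7,
  DT_INT64 = 9,
  DT_BOOL = 10
};

// Bytes per element for fixed-size types; 0 for anything else
// (for example DT_STRING, whose elements vary in length).
std::size_t ElementSize(DataType dt) {
  switch (dt) {
    case DT_FLOAT:  return sizeof(float);
    case DT_DOUBLE: return sizeof(double);
    case DT_INT32:  return sizeof(std::int32_t);
    case DT_UINT8:  return sizeof(std::uint8_t);
    case DT_INT16:  return sizeof(std::int16_t);
    case DT_INT8:   return sizeof(std::int8_t);
    case DT_INT64:  return sizeof(std::int64_t);
    case DT_BOOL:   return sizeof(bool);
    default:        return 0;
  }
}

// Total buffer size for a tensor with num_elements elements of type dt.
std::size_t BufferBytes(DataType dt, std::size_t num_elements) {
  return ElementSize(dt) * num_elements;
}
```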

Serialization: tensor.proto

A tensor can be serialized into a protocol buffer (PB) message.


// Protocol buffer representing a tensor.
message TensorProto {
  DataType dtype = 1;
  TensorShapeProto tensor_shape = 2;

  int32 version_number = 3;

  bytes tensor_content = 4;
  
  repeated int32 half_val = 13 [packed = true];

  // DT_FLOAT.
  repeated float float_val = 5 [packed = true];

  // DT_DOUBLE.
  repeated double double_val = 6 [packed = true];

  // DT_INT32, DT_INT16, DT_UINT16, DT_INT8, DT_UINT8.
  repeated int32 int_val = 7 [packed = true];

  // DT_STRING
  repeated bytes string_val = 8;

  // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real
  // and imaginary parts of i-th single precision complex.
  repeated float scomplex_val = 9 [packed = true];

  // DT_INT64
  repeated int64 int64_val = 10 [packed = true];

  // DT_BOOL
  repeated bool bool_val = 11 [packed = true];

  // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real
  // and imaginary parts of i-th double precision complex.
  repeated double dcomplex_val = 12 [packed = true];

  // DT_RESOURCE
  repeated ResourceHandleProto resource_handle_val = 14;

  // DT_VARIANT
  repeated VariantTensorDataProto variant_val = 15;

  // DT_UINT32
  repeated uint32 uint32_val = 16 [packed = true];

  // DT_UINT64
  repeated uint64 uint64_val = 17 [packed = true];
}

// Protocol buffer representing the serialization format of DT_VARIANT tensors.
message VariantTensorDataProto {
  // Name of the type of objects being serialized.
  string type_name = 1;
  // Portions of the object that are not Tensors.
  bytes metadata = 2;
  // Tensors contained within objects being serialized.
  repeated TensorProto tensors = 3;
}
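The comments on scomplex_val and dcomplex_val describe an interleaved layout: entry 2*i holds the real part and entry 2*i+1 the imaginary part of the i-th value. A small sketch of that packing, with hypothetical helper names PackComplex and UnpackComplex:

```cpp
#include <complex>
#include <cstddef>
#include <vector>

// Interleave real and imaginary parts: out[2*i] = real, out[2*i+1] = imag.
std::vector<float> PackComplex(
    const std::vector<std::complex<float>>& values) {
  std::vector<float> out;
  out.reserve(values.size() * 2);
  for (const auto& c : values) {
    out.push_back(c.real());
    out.push_back(c.imag());
  }
  return out;
}

// Inverse of PackComplex; a trailing unpaired entry is ignored.
std::vector<std::complex<float>> UnpackComplex(
    const std::vector<float>& packed) {
  std::vector<std::complex<float>> out;
  out.reserve(packed.size() / 2);
  for (std::size_t i = 0; i + 1 < packed.size(); i += 2) {
    out.emplace_back(packed[i], packed[i + 1]);
  }
  return out;
}
```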

tensor_util.h and tensor_util_test.cc contain examples of how tensors are used

They cover the following functionality:

  • Deep copy of a tensor
  • Deep copy of a slice
  • Concat: concatenation
  • Split: splitting
  • ConcatSplitStrings: concatenating strings
  • CreatesStringTensorProto: deserializes a tensor with dtype=DT_STRING from a protobuf
  • CreatesInt32TensorProto
  • CreatesInt64TensorProto
  • CreatesUInt32TensorProto
  • CreatesUInt64TensorProto
  • ...every type has a corresponding deserialization case
  • CompressTensorProtoInPlaceTooSmall: various tensor proto compression cases
  • CompressTensorProtoInPlaceAllEqual
  • CompressTensorProtoConstantTail
  • CompressTensorProtoNegatizeZero
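The idea behind a "constant tail" compression can be sketched as follows. This is a guess at the technique suggested by the test names, not the TensorFlow code: it assumes the decoder pads missing trailing entries with the last stored value, and CompressConstantTail, Expand and SameBits are hypothetical names. Values are compared by bit pattern so that 0.0 and -0.0 stay distinct.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstring>
#include <vector>

// True if a and b have identical bit patterns (so 0.0f and -0.0f differ).
bool SameBits(float a, float b) {
  return std::memcmp(&a, &b, sizeof(float)) == 0;
}

// Drop trailing entries that repeat the value before them.
std::vector<float> CompressConstantTail(std::vector<float> values) {
  while (values.size() > 1 &&
         SameBits(values[values.size() - 1], values[values.size() - 2])) {
    values.pop_back();
  }
  return values;
}

// Rebuild n elements, padding with the last stored value (assumed rule).
std::vector<float> Expand(const std::vector<float>& stored, std::size_t n) {
  const std::size_t keep = std::min(stored.size(), n);
  std::vector<float> out(stored.begin(),
                         stored.begin() + static_cast<std::ptrdiff_t>(keep));
  const float fill = out.empty() ? 0.0f : out.back();
  out.resize(n, fill);
  return out;
}
```

A tensor whose values are all equal shrinks to a single stored entry under this scheme.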

Tensor memory allocation: Allocator

Interface

tensorflow/core/framework/allocator.h

class Allocator {
 public:
  static constexpr size_t kAllocatorAlignment = 64;
  virtual ~Allocator();
  virtual std::string Name() = 0;
  virtual void* AllocateRaw(size_t alignment, size_t num_bytes) = 0;
  virtual void* AllocateRaw(size_t alignment, size_t num_bytes,
                            const AllocationAttributes& allocation_attr) {
    return AllocateRaw(alignment, num_bytes);
  }
  virtual void DeallocateRaw(void* ptr) = 0;

  virtual size_t AllocatedSize(const void* ptr) const {
    return RequestedSize(ptr);
  }
};

You can subclass it to implement your own allocator
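As a rough illustration of what such a subclass might look like, the sketch below defines a trimmed-down interface and one implementation. SimpleAllocator and MallocAllocator are hypothetical names; the interface only mirrors the three pure-virtual methods in the excerpt, and the aligned allocation is done by hand on top of malloc rather than with TensorFlow's port helpers.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <string>

// Hypothetical, trimmed-down interface modelled on the Allocator excerpt.
class SimpleAllocator {
 public:
  virtual ~SimpleAllocator() {}
  virtual std::string Name() = 0;
  virtual void* AllocateRaw(std::size_t alignment, std::size_t num_bytes) = 0;
  virtual void DeallocateRaw(void* ptr) = 0;
};

// Over-allocates with malloc and stores the original pointer just before
// the aligned block so that DeallocateRaw can find it again.
class MallocAllocator : public SimpleAllocator {
 public:
  std::string Name() override { return "malloc_sketch"; }

  // `alignment` must be a power of two.
  void* AllocateRaw(std::size_t alignment, std::size_t num_bytes) override {
    void* raw = std::malloc(num_bytes + alignment + sizeof(void*));
    if (raw == nullptr) return nullptr;
    const std::uintptr_t start =
        reinterpret_cast<std::uintptr_t>(raw) + sizeof(void*);
    const std::uintptr_t mask = static_cast<std::uintptr_t>(alignment) - 1;
    const std::uintptr_t aligned = (start + mask) & ~mask;
    reinterpret_cast<void**>(aligned)[-1] = raw;  // remember malloc's pointer
    ++live_allocations_;
    return reinterpret_cast<void*>(aligned);
  }

  void DeallocateRaw(void* ptr) override {
    if (ptr == nullptr) return;
    std::free(static_cast<void**>(ptr)[-1]);
    --live_allocations_;
  }

  // Number of blocks allocated and not yet freed.
  int live_allocations() const { return live_allocations_; }

 private:
  int live_allocations_ = 0;
};
```

Callers only see the base-class interface, so an implementation like this could add bookkeeping (here a live-allocation counter) without any change at the call sites.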

CPUAllocator

tensorflow/core/framework/cpu_allocator_impl.h


class CPUAllocator : public Allocator {
 public:
  CPUAllocator()
      : single_allocation_warning_count_(0),
        total_allocation_warning_count_(0) {}

  ~CPUAllocator() override {}

  string Name() override { return "cpu"; }

  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
    void* p = port::AlignedMalloc(num_bytes, alignment);
    
    return p;
  }

  void DeallocateRaw(void* ptr) override {
    port::AlignedFree(ptr);
  }
};

// Register the CPU allocator
REGISTER_MEM_ALLOCATOR("DefaultCPUAllocator", 100, CPUAllocatorFactory);

Allocator registration

tensorflow/core/framework/allocator_registry.h

class AllocatorFactoryRegistry {
 public:
  AllocatorFactoryRegistry() {}
  ~AllocatorFactoryRegistry() {}

  void Register(const char* source_file, int source_line, const string& name,
                int priority, AllocatorFactory* factory);
  Allocator* GetAllocator();

 private:
  std::vector<FactoryEntry> factories_ TF_GUARDED_BY(mu_);
};
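The Register signature takes a priority, which suggests that the registry picks among several registered factories. The sketch below shows one plausible selection rule, highest priority wins, and that rule is an assumption. SimpleRegistry is a hypothetical class, its factories return a plain string instead of an Allocator, and the locking implied by TF_GUARDED_BY(mu_) is left out.

```cpp
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical registry: factories are registered with a priority and the
// highest-priority one is used (assumed rule).
class SimpleRegistry {
 public:
  void Register(const std::string& name, int priority,
                std::function<std::string()> create) {
    entries_.push_back(Entry{name, priority, std::move(create)});
  }

  // Returns the product of the highest-priority factory, or "" if no
  // factory has been registered. Ties go to the earliest registration.
  std::string Create() const {
    const Entry* best = nullptr;
    for (const Entry& e : entries_) {
      if (best == nullptr || e.priority > best->priority) best = &e;
    }
    return best != nullptr ? best->create() : std::string();
  }

 private:
  struct Entry {
    std::string name;
    int priority;
    std::function<std::string()> create;
  };
  std::vector<Entry> entries_;
};
```

Under this rule, registering a second factory with a priority above 100 would take precedence over the "DefaultCPUAllocator" entry shown earlier.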