libtorch系列教程2:torch::Tensor的使用
系列教程列表:
这篇文章中,我们暂时忽略网络训练和推理,详细展开Libtorch中Tensor对象的使用,看看将Libtorch当作一个纯粹的Tensor库来使用时,有哪些注意事项。如有未涉及的内容,请访问Libtorch官方文档,通过搜索框获取更多的信息。Libtorch的环境搭建参考上一篇文章。
1. torch::Tensor基本操作
Libtorch中的Tensor是与Pytorch中的Tensor对应的,使用方式上很类似,只在一些Python语法C++不支持的时候有些不同,例如slice操作。
使用Libtorch前需要包含 Libtorch 的头文件torch/torch.h
:
1 |
这篇文章用到的所有函数都在此头文件中声明,而且所有的函数namespace都是torch
,因此都可以以torch::xxx
的形式来调用。
1.1 Tensor创建
Tensor 创建的方式比较多,包括从字面量创建,从C++ 原生的数组创建,从vector创建,从Libtorch自带的函数创建等。
从字面量创建:
1 | torch::Tensor foo = torch::tensor({1.0, 2.0, 3.0, 4.0}); |
从C++ 原生的float数组创建,使用from_blob
函数:
1 | float arr[] = {1.0, 2.0, 3.0, 4.0}; |
其中第二个参数表示创建的Tensor shape,会自动对原生数组进行reshape。
从vector 创建,使用from_blob
函数:
1 | std::vector<float> v = {1.0, 2.0, 3.0, 4.0}; |
还可以用Libtorch的函数创建,跟Numpy和Pytorch类似:
1 | foo = torch::arange(4); |
创建好以后,Tensor对应可以直接用std::cout
来输出:
1 | torch::Tensor foo = torch::tensor({1.0, 2.0, 3.0, 4.0}); |
输出如下:
1 | ==> foo is: |
可以看到最后打印了Tensor的类型。
1.2 Tensor对象的属性函数
创建Tensor后,我们还需要看到它的一些属性,判断是否跟预期相符。注意Libtorch的Tensor是没有公开可访问的属性attribute的,Tensor信息需要属性函数来获取。常见的属性函数包括:
- dim(): Tensor的维度
- sizes(): 跟Pytorch中的shape属性一样
- size(n): 第N个维度的shape
- numel(): 总的元素数目,sizes中的每个元素相乘
- dtype(): 数据类型
- device(): Tensor所在的设备类型,CPU, CUDA, MPS等。
使用方式如下:
1 | // Tensor 属性函数 |
1.3 Tensor对象的索引
Tensor 默认是支持[]
操作符的,因此可以使用这样的方式来获取元素:
1 | auto foo = torch::randn({1, 2, 3, 4}); |
另一种方式是用Tensor对象的index
函数,它的优势是支持slice。
对于单个元素,可以类似Pytorch中,直接用index({i, j, k})
的方式来索引:
1 | auto foo = torch::randn({1, 2, 3, 4}); |
那么python中很常用的slice呢?例如foo[..., :2, 1:, :-1]
,该怎么在Libtorch中表示?
这里需要用到torch::indexing::Slice
对象,来实现Python中的Slice,看看下面的例子你就明白了:
1 | using namespace torch::indexing; |
应该是能满足Python中slice同样的使用场景。
1.4 更新Tensor中元素的值
有了索引之后,我们就可以更新Tensor的值了:
1 | torch::Tensor foo = torch::tensor({1.0, 2.0, 3.0, 4.0}); |
但还没找到用给部分Tensor元素赋值的方法,类似Python中的foo[:2] = bar
,欢迎补充。
1.5 获取Tensor中的数据
Tensor是一个Libtorch的对象,那怎么把它中的数据拿出来保存到文件中或传给别的函数呢?
使用data_ptr
函数就可以:
1 | torch::Tensor foo = torch::randn({3, 3}); |
对于单个元素的Tensor,还可以用item
函数得到具体的数值:
1 | torch::Tensor one_element_tensor = foo.index({Slice(), Slice(0, 1), Slice(0, 1), Slice(0, 1)}); |
1.6 数据类型
Libtorch中支持float16, float32, float64, int8, int16, int32, uint8这几类的Tensor数据类型,可以用to
函数来进行类型转换:
1 | // 数据类型, 参见 https://pytorch.org/cppdocs/api/file_torch_csrc_api_include_torch_types.h.html#variables |
全部数据类型,参见官方文档的数据类型页面。
1.7 设备类型
设备类型是Tensor保存的设备的种类。由于Libtorch不仅仅支持CPU,还支持各种类型的GPU,因此有很多设备类型。
所有的设备类型参见这里。
需要注意的是,设备是跟编译时的配置,机器是否支持强相关的,而且某些设备支持并不好,例如我想用下面的代码将CPU上的Tensor转移到MPS上:
1 | auto foo = torch::randn({3, 3}); |
编译是没有问题的,但运行时会报下面的错:
libc++abi: terminating with uncaught exception of type c10::TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn’t support float64. Please use float32 instead.
提示说MPS不支持float64,但我打印foo
的类型,它其实是float32,本身报错比较奇怪,搜了一圈也没找到怎么解决。
1.8 Tensor 变形函数
很多时候我们需要将Tensor进行形状的修改,这方面Libtorch支持的比较好,这些操作都支持:
- reshape
- flatten
- squeeze
- unsqueeze
- transpose
- cat/concat/concatenate
而且支持torch::reshape
这种静态函数和tensor.reshape
这种对象函数。下面是一些例子:
1 | // 变形操作 |
一个比较特殊的地方是transpose只支持两个轴的交换,多个轴的交换需要调用多次来实现。
1.9 Tensor之间的操作函数
Tensor库中,Tensor和Tensor之间的操作是很常见的,比如求矩阵相乘,内积外积等,有内置的函数支持能避免很多额外的开发工作。这里是一些例子:
1 | foo = torch::randn({3, 3}); |
1.10 线性代数相关函数
torch::linalg
namespace中包含常见的线性代数操作,几个简单的使用例子:
1 | bar = torch::linalg::inv(foo); |
所有支持的函数详见官方文档
1.11 神经网络相关函数
神经网络是torch的核心模块,常见的一些激活函数,卷积层都可以以函数的形式作用在Tensor上,这里写几个简单的例子:
1 | bar = torch::softmax(foo, -1); |