libtorch Tensor张量的常用操作总结(1)

 基于libtorch的深度学习框架,其处理数据的主要基本单位是Tensor张量,我们可以把Tensor张量理解成矩阵,该矩阵的维度可以是1维、2维、3维,或更高维。

本文我们来总结一下Tensor张量的常用操作。

01

打印张量的信息

  • 打印张量的维度信息

要查看张量的维度信息,通常有两种方式:打印张量的sizes;或者直接调用张量类的print函数:

torch::Tensor b = torch::zeros({ 3, 5 });
cout << b.sizes() << endl;   //方式一,只打印维度信息
b.print();    //方式二,除了打印维度信息,数据类型也打印出来

运行结果:

  • 打印张量的内容

torch::Tensor b = torch::zeros({ 3, 5 });
cout << b << endl;

运行结果:

02

定义并初始化张量的值

  • 定义一定维度的张量并初始化全部值为0

torch::Tensor b = torch::zeros({ 5, 7 });  //定义5行7列的0值张量
cout << b << endl;

运行结果如下,得到5行7列的张量:

  • 定义一定维度的张量并初始化全部值为1

auto b = torch::ones({ 3,4 });  //定义3行4列的1值张量
cout << b << endl;

运行结果如下,得到3行4列的张量:

  • 定义一定维度的单位张量

单位张量与单位向量是一个概念,即对角线值为1,其余值全部为0:

auto b = torch::eye(5);  //定义5*5单位张量
cout << b << endl;

运行结果如下,得到5行5列的单位张量:

  • 定义一定维度的张量并设置初始值

auto b = torch::full({ 3,4 }, 10);  //定义3行4列张量,并初始化全部值为0
cout << b << endl;

运行结果如下,得到3行4列的张量:

此外,还可以使用另一个张量的形状作为模板,定义相同形状维度的张量,并填充初始值:

auto b = torch::full({ 3,4 }, 10);  //定义3行4列的张量b,并填充全部值为10
auto a = torch::full_like(b, 2);   //定义与b相同形状的张量a,并填充初始值2
auto a1 = torch::full_like(b, 2.5);  //定义与b相同形状的张量a,并填充初始值2.5
auto a2 = torch::full_like(b.toType(kFloat), 2.5);  //定义与b相同形状的张量a,并填充初始值2.5
cout << b << endl;
cout << a << endl;
cout << a1 << endl;
cout << a2 << endl;

运行结果如下,我们注意到张量b的数据默认为long int型,那么使用full_like定义张量a、a1时,它们的数据类型也默认为long int型,即使a1填充值为2.5,也被自动截断为整型数2了。而a2则强行把b转换为float型,使a2也是float型数据,因此a2可以使用2.5来填充而不被自动截断为整型数了。

  • 定义n行1列的张量,并指定初始值

auto b = torch::tensor({ 1,2,3,4,5,6,7,8 });  //定义8行1列的张量,并指定初始值
cout  << b << endl;

运行结果如下,得到8行1列的张量:

  • 定义一定维度的张量,并使用随机数初始化

//定义3行4列张量,并使用区间[0, 1)的符合均匀分布的随机数初始化
auto r = torch::rand({ 3,4 });
cout  << r << endl;
//定义5行6列张量,并使用符合标准正态分布(均值为0,方差为1,即高斯白噪声)的随机数初始化
r = torch::randn({ 5, 6 });
cout  << r << endl;
//定义5行5列张量,并使用区间[0, 10)的整型数初始化
r = torch::randint(0, 10, { 5,5 });
cout  << r << endl;

运行结果如下:

  • 使用数组或某一段内存初始化张量

使用数组或某一段内存来初始化张量时,通常调用torch::from_blob函数来实现:

//使用数组来初始化张量内容
int aa[4] = { 3,4,6,7 };
auto aaaaa = torch::from_blob(aa, { 2, 2 }, torch::kInt);
cout  << aaaaa << endl;
//使用vector迭代容器来初始化张量内容
vector<float> aaaa = { 3,4,6 };
auto aaa = torch::from_blob(aaaa.data(), { 1, 1, 1, 3 }, torch::kFloat);
cout  << aaa << endl;
//使用Opencv的Mat来初始化张量内容,相当于把Mat转换为Tensor
Mat x = Mat::zeros(5, 5, CV_32FC1);
auto xx = torch::from_blob(x.data, { 1, 1, 5, 5 }, torch::kFloat);
cout  << xx << endl;

运行结果如下:

神经网络的输入通常为一张单通道灰度图或一张三通道的彩色图,如果输入为Opencv Mat格式的三通道彩色图,我们需要格外注意数据维度的顺序,因为Mat格式的三通道图像与libtorch Tensor张量的数据维度是不一样的,前者是[Height, Width, channels],后者是[channels, Height, Width],如果展开成一维向量来看,Opencv Mat存储RGB图像的顺序为(每个R、G、B像素点交替存储)

libtorch张量存储RGB图像的顺序为(依次存储所有的R、G、B像素点)

因此将Mat格式的三通道图像转换为Tensor张量时,我们应该首先把[Height, Width, Channels]的Mat格式数据转换为[Height, Width, Channels]的Tensor张量,然后再调用Tensor张量的permute函数把数据的维度顺序调整为[Channels, Height, Width]即可

Mat x1 = (Mat_<uchar>(5, 5) << 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25);
Mat x2 = x1.clone();
Mat x3 = x1.clone();
vector<Mat> channels;
channels.push_back(x1.clone());
channels.push_back(x2.clone());
channels.push_back(x3.clone());
Mat x123;
merge(channels, x123);  //合并成一张三通道图像
cout << "[Height, Width, channels]格式的Mat:" << endl;
cout << x123 << endl;
//错误维度顺序示范
auto x_t = torch::from_blob(x123.data, { 3, 5, 5 }, torch::kByte);
cout << "直接将[Height, Width, channels]格式Mat转换为的[channels, Height, Width]格式的Tensor,维度不对应:" << endl;
cout << x_t << endl;
//建议的做法,确保Tensor张量的维度顺序为[channels, Height, Width]
x_t = torch::from_blob(x123.data, { 5, 5, 3 }, torch::kByte);
cout << "先将[Height, Width, channels]格式Mat转换为的[Height, Width, channels]格式的Tensor:" << endl;
cout << x_t << endl;
x_t = x_t.permute({ 2, 0, 1 });
cout << "再调整Tensor的维度顺序:[Height, Width, channels]-->[channels, Height, Width]:" << endl;
cout << x_t << endl;

运行结果:

此外,调用torch::from_blob函数建立的Tensor张量,与传入的指针是共用内存的,该张量并没有重新开辟一段内存,这一点需要注意。如果需要开辟内存,则通过调用clone函数来执行深拷贝:

int aa[4] = { 3,4,6,7 };
auto aaaaa = torch::from_blob(aa, { 2, 2 }, torch::kInt).clone();

03

张量的拼接

  • 按列拼接

两个张量可以按列拼接的前提条件是它们的行数一样,否则拼接会出错:

torch::Tensor a1 = torch::rand({ 2,3 });  //2行3列
torch::Tensor a2 = torch::rand({ 2,1 });  //2行1列
torch::Tensor cat_1 = torch::cat({ a1, a2 }, 1); //dim参数为1表示按列拼接
std::cout << a1 << std::endl;
std::cout << a2 << std::endl;
std::cout << cat_1 << std::endl;

运行结果:

  • 按行拼接

两个张量可以按行拼接的前提条件是它们的列数一样,否则拼接会出错:

torch::Tensor a1 = torch::rand({ 2,3 });  //2行3列
torch::Tensor a2 = torch::rand({ 1,3 });  //1行3列
torch::Tensor cat_1 = torch::cat({ a1, a2 }, 0);  //dim参数为0表示按行拼接
std::cout << a1 << std::endl;
std::cout << a2 << std::endl;
std::cout << cat_1 << std::endl;

运行结果:

04

张量的切片与索引

所谓切片,我们可以把张量理解成一个蛋糕,把蛋糕切成一块块的操作就相当于切片。索引也好理解,索引就是一个地址,通过该地址我们可以定位到张量内部的某一个数值或某一部分数值。

下面我们以三维张量[channels, Height, Width]为例说明张量的常见切片、索引操作。注意所有维度的序号均从0开始。

1. 索引操作

  • 对第1维度、第2维度的所有索引,取第3维度索引号范围i~j的数据

//linspace(1, 75, 75)为取范围再1~75之间、长度为75的数组,也即1、2、3、...、75
auto a = torch::linspace(1, 75, 75).reshape({ 3, 5, 5 });  //start -- end -- length
cout << a << endl;
//对于所有第1维度、第2维度,取第3维度索引号为2的数据
auto bx = a.index({ "...", 2 });  
cout << bx << endl;

运行结果:

  • 对第1维度的所有索引,取第2维度索引号为i、第3维度索引号为j的数据

auto a = torch::linspace(1, 75, 75).reshape({ 3, 5, 5 });  
cout << a << endl;
auto bx = a.index({ "...", 2, 3 });  //对所有第1维度,取第2维度索引号为2、第3维度索引号为3的数据
cout << bx << endl;

运行结果:

  • 对第1维度的索引号i,取第2维度、第3维度的所有索引号的数据

该操作相当于取channels个Height*Width矩阵中的第i个Height*Width矩阵。

auto a = torch::linspace(1, 75, 75).reshape({ 3, 5, 5 });  
cout << a << endl;
//对索引号为2的第1维度,取所有第2维度、第3维度数据
auto bx = a.index({ 2, "..."});  
cout << bx << endl;

运行结果:

  • 对第1维度的索引号i、第3维度的索引号j,取第2维度所有索引的数据

该操作相当于取channels个Height*Width矩阵中的第i个Height*Width矩阵的第j列。

auto a = torch::linspace(1, 75, 75).reshape({ 3, 5, 5 });  
cout << a << endl;
//对索引号为2的第1维度、索引号为3的第3维度,取所有第2维度数据
auto bx = a.index({ 2, "...", 3 });  
cout << bx << endl;

运行结果:

  • 直接指定张量各维度的索引

cout << "**************************" << endl;


Tensor a = torch::linspace(1, 25, 25).reshape({ 5, 5 });
cout << a << endl;
//取第1维度索引号为1、第2维度索引号为2的所有数据
auto b = a.index({ 1, 2 });  
cout << b << endl;


cout << "**************************" << endl;


a = torch::linspace(1, 27, 27).reshape({ 3, 3, 3 });
cout << a << endl;
//取第1维度索引号为1、第2维度索引号为2的所有数据
b = a.index({ 1, 2 });  
cout << b << endl;


cout << "**************************" << endl;
  
a = torch::linspace(1, 75, 75).reshape({ 3, 5, 5 });
cout << a << endl;
//取第1维度索引号为1、第2维度索引号为2、第3维度索引号为3的所有数据
b = a.index({ 1, 2, 3 });  
cout << b << endl;

运行结果:

  • 通过索引赋值

auto a = torch::linspace(1, 4, 4).reshape({ 2, 2 });
cout << a << endl;
//将第1维度的索引号1、第2维度的索引号1处赋值为100
a.index_put_({ 1, 1 }, 100);
cout << a << endl;
//将第1维度的索引号0、第2维度的索引号0处赋值为101
a.index_put_({ 0, 0 }, 101);
cout << a << endl;
//将第1维度的索引号1、第2维度的索引号0处赋值为102
a.index_put_({ 1, 0 }, 102);
cout << a << endl;


a = torch::linspace(1, 9, 9).reshape({ 3, 3 });
cout << a << endl;
//将第1维度的所有索引号、第2维度的索引号1处赋值为100
a.index_put_({ "...", 1 }, 100);
cout << a << endl;

运行结果:

2. 切片操作

切片操作主要通过调用Slice函数实现。首先我们介绍一下该函数。

Slice函数主要用于对张量的某一维进行切片时,指定切片的开始索引、结束索引,以及切片步长(切片步长默认为1):

Slice(
    //开始索引,若设置为None,则从索引0开始
    c10::optional<int64_t> start_index = c10::nullopt,   
    //结束索引。若设置为None,则到最大索引号结束
    c10::optional<int64_t> stop_index = c10::nullopt,   
    //切片步长,默认为1
    c10::optional<int64_t> step_index = c10::nullopt    
    )
  • 对第1维度,第2维度的所有索引号,从[第3维度的索引号i~第3维度的最大索引号]开始切片,默认切片步长为1

auto a = torch::linspace(1, 75, 75).reshape({ 3, 5, 5 });
cout << a << endl;
//从第3维度的索引号1开始切片,一直切到第3维度的最大索引号,也即取第1~第4列
//这里只Slice只设置开始索引,结束索引和切片步长都默认
auto b = a.index({ "...", Slice(1) });   
cout << b << endl;

运行结果:

  • 对第1维度,第2维度的所有索引号,从[第3维度的索引号i~第3维度的最大索引号]开始切片,并设定切片步长为2

auto a = torch::linspace(1, 75, 75).reshape({ 3, 5, 5 });
cout << a << endl;
//设定切片步长为2,从第3维度的索引号1开始切片,一直切到第3维度的最大索引号,也即取第1~第4列
auto b = a.index({ "...", Slice(1, None, 2) });  
cout << b << endl;

运行结果:

  • 对第1维度,第2维度的所有索引号,从[第3维度的索引号0~第3维度的索引号i]开始切片,默认切片步长为1

注意:切片不包括第3维度的索引号i,比如i为3,则切片索引号0、1、2。

auto a = torch::linspace(1, 75, 75).reshape({ 3, 5, 5 });
cout << a << endl;
//对第1维度,第2维度的所有索引号,从[第3维度的索引号0~第3维度的索引号2]开始切片,默认切片步长为1
auto b = a.index({ "...", Slice({None, 2}) });   
cout << b << endl;

运行结果:

好了,本文就总结到这里,下文我们继续哈~

欢迎扫码关注本微信公众号,您的支持是我坚持下去的最大动力~

posted @ 2021-06-29 20:57  萌萌哒程序猴  阅读(1730)  评论(0编辑  收藏  举报