Pytorch中index_select() 函数的实现理解
函数形式:
index_select( dim, index )
参数:
- dim:表示从第几维挑选数据,类型为int值;
- index:表示从第一个参数维度中的哪个位置挑选数据,类型为torch.Tensor类的实例;
刚开始学习pytorch,遇到了index_select(),一开始不太明白几个参数的意思,后来查了一下资料,算是明白了一点。
a=torch.linspace(1,12,steps=12).view(3,4) print(a) b=torch.index_select(a,0,torch.tensor([0,2])) print(b) print(a.index_select(0,torch.tensor([0,2]))) c=torch.index_select(a,1,torch.tensor([1,3])) print(c)
先定义了一个tensor,这里用到了linspace和view方法。
第一个参数是索引的对象,第二个参数0表示按行索引,1表示按列进行索引,第三个参数是一个tensor,就是索引的序号,比如b里面tensor[0,2]表示第0行和第2行,c里面tensor[1,3]表示第1列和第3列。
输出结果如下:
tensor([[1., 2., 3., 4.],
[5., 6., 7., 8.],
[9.,10.,11.,12.]])
tensor([[1., 2., 3., 4.],
[9.,10.,11.,12.]])
tensor([[1., 2., 3., 4.],
[9.,10.,11.,12.]])
tensor([[2., 4.],
[6., 8.],
[10.,12.]])
功能:从张量的某个维度的指定位置选取数据。
代码实例:
t=torch.arange(24).reshape(2,3,4)#初始化一个tensor,从0到23,形状为(2,3,4) print("t--->",t) index=torch.tensor([1,2])#要选取数据的位置 print("index--->",index) data1=t.index_select(1,index)#第一个参数:从第1维挑选,第二个参数:从该维中挑选的位置 print("data1--->",data1) data2=t.index_select(2,index)#第一个参数:从第2维挑选,第二个参数:从该维中挑选的位置 print("data2--->",data2)
运行结果:
t--->tensor([[[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9,10,11]],
[[12,13,14,15],
[16,17,18,19],
[20,21,22,23]]])
index--->tensor([1,2])
data1--->tensor([[[4, 5, 6, 7],
[8, 9,10,11]],
[[16,17,18,19],
[20,21,22,23]]])
data2--->tensor([[[1, 2],
[5, 6],
[9,10]],
[[13,14],
[17,18],
[21,22]]])
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。