PyTorch中Tensor的拼接与拆分的实现
拼接张量:torch.cat()、torch.stack()
- torch.cat(inputs,dimension=0)→Tensor
在给定维度上对输入的张量序列seq进行连接操作
举个例子:
>>>importtorch >>>x=torch.randn(2,3) >>>x tensor([[-0.1997,-0.6900,0.7039], [0.0268,-1.0140,-2.9764]]) >>>torch.cat((x,x,x),0)#在0维(纵向)进行拼接 tensor([[-0.1997,-0.6900,0.7039], [0.0268,-1.0140,-2.9764], [-0.1997,-0.6900,0.7039], [0.0268,-1.0140,-2.9764], [-0.1997,-0.6900,0.7039], [0.0268,-1.0140,-2.9764]]) >>>torch.cat((x,x,x),1)#在1维(横向)进行拼接 tensor([[-0.1997,-0.6900,0.7039,-0.1997,-0.6900,0.7039,-0.1997,-0.6900, 0.7039], [0.0268,-1.0140,-2.9764,0.0268,-1.0140,-2.9764,0.0268,-1.0140, -2.9764]]) >>>y1=torch.randn(5,3,6) >>>y2=torch.randn(5,3,6) >>>torch.cat([y1,y2],2).size() torch.Size([5,3,12]) >>>torch.cat([y1,y2],1).size() torch.Size([5,6,6])
对于需要拼接的张量,维度数量必须相同,进行拼接的维度的尺寸可以不同,但是其它维度的尺寸必须相同。
- torch.stack(sequence,dim=0)
沿着一个新维度对输入张量序列进行连接。序列中所有的张量都应该为相同形状
举个例子:
>>>x1=torch.randn(2,3) >>>x2=torch.randn(2,3) >>>torch.stack((x1,x2),0).size()#在0维插入一个维度,进行区分拼接 torch.Size([2,2,3]) >>>torch.stack((x1,x2),1).size()#在1维插入一个维度,进行组合拼接 torch.Size([2,2,3]) >>>torch.stack((x1,x2),2).size() torch.Size([2,3,2]) >>>torch.stack((x1,x2),0) tensor([[[-0.3499,-0.6124,1.4332], [0.1516,-1.5439,-0.1758]], [[-0.4678,-1.1430,-0.5279], [-0.4917,-0.6504,2.2512]]]) >>>torch.stack((x1,x2),1) tensor([[[-0.3499,-0.6124,1.4332], [-0.4678,-1.1430,-0.5279]], [[0.1516,-1.5439,-0.1758], [-0.4917,-0.6504,2.2512]]]) >>>torch.stack((x1,x2),2) tensor([[[-0.3499,-0.4678], [-0.6124,-1.1430], [1.4332,-0.5279]], [[0.1516,-0.4917], [-1.5439,-0.6504], [-0.1758,2.2512]]])
把相同形状的张量合并,并根据提供的维度序列在相应位置插入维度,方法会根据位置来排列数据。代码中,根据第0维和第1维来进行合并时,虽然合并后的张量维度和尺寸相等,但是数据的位置并不是相同的。
拆分张量:torch.split()、torch.chunk()
- torch.split(tensor,split_size,dim=0)
将输入张量分割成相等形状的chunks(如果可分)。如果沿指定维的张量形状大小不能被split_size整分,则最后一个分块会小于其它分块。
举个例子:
>>>x=torch.randn(3,10,6) >>>a,b,c=x.split(1,0)#在0维进行间隔维1的拆分 >>>a.size(),b.size(),c.size() (torch.Size([1,10,6]),torch.Size([1,10,6]),torch.Size([1,10,6])) >>>d,e=x.split(2,0)#在0维进行间隔维2的拆分 >>>d.size(),e.size() (torch.Size([2,10,6]),torch.Size([1,10,6]))
把张量在0维度上以间隔1来拆分时,其中x在0维度上的尺寸为3,就可以分成3份。
把张量在0维度上以间隔2来拆分时,只能分成2份,且只能把前面部分先以间隔2来拆分,后面不足2的部分就直接作为一个分块。
- torch.chunk(tensor,chunks,dim=0)
在给定维度(轴)上将输入张量进行分块儿
直接用上面的数据来举个例子:
>>>l,m,n=x.chunk(3,0)#在0维上拆分成3份 >>>l.size(),m.size(),n.size() (torch.Size([1,10,6]),torch.Size([1,10,6]),torch.Size([1,10,6])) >>>u,v=x.chunk(2,0)#在0维上拆分成2份 >>>u.size(),v.size() (torch.Size([2,10,6]),torch.Size([1,10,6]))
把张量在0维度上拆分成3部分时,因为尺寸正好为3,所以每个分块的间隔相等,都为1。
把张量在0维度上拆分成2部分时,无法平均分配,以上面的结果来看,可以看成是,用0维度的尺寸除以需要拆分的份数,把余数作为最后一个分块的间隔大小,再把前面的分块以相同的间隔拆分。
在某一维度上拆分的份数不能比这一维度的尺寸大
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。