Pytorch 多维数组运算过程的索引处理方式
背景:对python不熟悉,能看懂代码,也能实现一些简单的功能,但是对python的核心思想和编程技巧不熟,所以使Pytorch写loss的时候遇到很多麻烦,尤其是在batch_size>1的时候,做矩阵乘法之类的运算会觉得特别不顺手。
所幸,在边查边写的过程中,理解了python中多维运算的实现规则。
1、python的基本索引规则
从0开始
对于给定的范围,如b=a[m:n],那么b为由(n-m)个数据组成的新数组,由a[m],a[m+1],...,a[n-1]构成。(若n 2、单个tensor运算,使用dim参数 torch中对tensor的操作方法,若不加dim参数表示对整体的tensr进行操作,若增加dim参则表示按维操作。 例: 注:torch.mean()是一个降维的操作,所以不会出现在取均值后保持跟原Tensor同维的情况。dim参数存在时降一维,不存在时得到的是整个Tensor的均值。 3、两个tensor运算,构造对应形状 以乘法为例: 3.1矩阵乘向量 计算乘法c=a@b 若a拓展为(N,3,2)N为batch_size,计算c2=a@b 若a,b同时拓展,变成(N,2),那么需要做一个变换b=b.view(N,2,1),计算c3=a@b 3.2矩阵乘矩阵 计算乘法c=a@b 若a拓展为(N,3,2)N为batch_size,计算c2=a@b 若a,b同时拓展,变成(N,2,2),计算c3=a@b 以上这篇Pytorch多维数组运算过程的索引处理方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。 声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。
a=[[1,2],[3,4],[5,6]](torch.tensor)
torch.mean(a)=>3.5
torch.mean(a,dim=0)=>[1.5,3.5,5.5]
torch.mean(a,dim=1)=>[[3],[4]]
torch.mean(a,dim=0)=>[3,4]
torch.mean(a,dim=1)=>[1.5,3.5,5.5]
a=[[1,2],[3,4],[5,6]]
b=[1,1]
a=[[1,2],[3,4],[5,6]]
b=[[1,1],[1,1]]