解决Pytorch自定义层出现多Variable共享内存错误问题
错误信息:
RuntimeError:in-placeoperationscanbeonlyusedonvariablesthatdon'tsharestoragewithanyothervariables,butdetectedthatthereare4objectssharingit
自动求导是很方便,但是想想,如果两个Variable共享内存,再对这个共享的内存的数据进行修改,就会引起错误!
一般是由于inplace操作或是indexing或是转置.这些都是共享内存的.
@staticmethod defbackward(ctx,grad_output): ind_lst=ctx.ind_lst flag=ctx.flag c=grad_output.size(1) grad_former_all=grad_output[:,0:c//3,:,:] grad_latter_all=grad_output[:,c//3:c*2//3,:,:] grad_swapped_all=grad_output[:,c*2//3:c,:,:] spatial_size=ctx.h*ctx.w W_mat_all=Variable(ctx.Tensor(ctx.bz,spatial_size,spatial_size).zero_()) foridxinrange(ctx.bz): W_mat=W_mat_all.select(0,idx) forcntinrange(spatial_size): indS=ind_lst[idx][cnt] ifflag[cnt]==1: #这里W_mat是W_mat_all通过select出来的,他们共享内存. W_mat[cnt,indS]=1 W_mat_t=W_mat.t() grad_swapped_weighted=torch.mm(W_mat_t,grad_swapped_all[idx].view(c//3,-1).t()) grad_swapped_weighted=grad_swapped_weighted.t().contiguous().view(1,c//3,ctx.h,ctx.w) grad_latter_all[idx]=torch.add(grad_latter_all[idx],grad_swapped_weighted.mul(ctx.triple_w))
由于这里W_mat是W_mat_all通过select出来的,他们共享内存.所以当对这个共享的内存进行修改W_mat[cnt,indS]=1,就会出错.此时我们可以通过clone()将W_mat和W_mat_all独立出来.这样的话,梯度也会通过clone()操作将W_mat的梯度正确反传到W_mat_all中.
@staticmethod defbackward(ctx,grad_output): ind_lst=ctx.ind_lst flag=ctx.flag c=grad_output.size(1) grad_former_all=grad_output[:,0:c//3,:,:] grad_latter_all=grad_output[:,c//3:c*2//3,:,:] grad_swapped_all=grad_output[:,c*2//3:c,:,:] spatial_size=ctx.h*ctx.w W_mat_all=Variable(ctx.Tensor(ctx.bz,spatial_size,spatial_size).zero_()) foridxinrange(ctx.bz): #这里使用clone了 W_mat=W_mat_all.select(0,idx).clone() forcntinrange(spatial_size): indS=ind_lst[idx][cnt] ifflag[cnt]==1: W_mat[cnt,indS]=1 W_mat_t=W_mat.t() grad_swapped_weighted=torch.mm(W_mat_t,grad_swapped_all[idx].view(c//3,-1).t()) grad_swapped_weighted=grad_swapped_weighted.t().contiguous().view(1,c//3,ctx.h,ctx.w) #这句话删了不会出错,加上就吹出错 grad_latter_all[idx]=torch.add(grad_latter_all[idx],grad_swapped_weighted.mul(ctx.triple_w))
但是现在却出现4个objects共享内存.如果将最后一句话删掉,那么则不会出错.
如果没有最后一句话,我们看到
grad_swapped_weighted=torch.mm(W_mat_t,grad_swapped_all[idx].view(c//3,-1).t())
grad_swapped_weighted=grad_swapped_weighted.t().contiguous().view(1,c//3,ctx.h,ctx.w)
grad_swapped_weighted一个新的Variable,因此并没有和其他Variable共享内存,所以不会出错.但是最后一句话,
grad_latter_all[idx]=torch.add(grad_latter_all[idx],grad_swapped_weighted.mul(ctx.triple_w))
你可能会说,不对啊,修改grad_latter_all[idx]又没有创建新的Variable,怎么会出错.这是因为grad_latter_all和grad_output是共享内存的.因为grad_latter_all=grad_output[:,c//3:c*2//3,:,:],所以这里的解决方案是:
@staticmethod defbackward(ctx,grad_output): ind_lst=ctx.ind_lst flag=ctx.flag c=grad_output.size(1) grad_former_all=grad_output[:,0:c//3,:,:] #这两个后面修改值了,所以也要加clone,防止它们与grad_output共享内存 grad_latter_all=grad_output[:,c//3:c*2//3,:,:].clone() grad_swapped_all=grad_output[:,c*2//3:c,:,:].clone() spatial_size=ctx.h*ctx.w W_mat_all=Variable(ctx.Tensor(ctx.bz,spatial_size,spatial_size).zero_()) foridxinrange(ctx.bz): W_mat=W_mat_all.select(0,idx).clone() forcntinrange(spatial_size): indS=ind_lst[idx][cnt] ifflag[cnt]==1: W_mat[cnt,indS]=1 W_mat_t=W_mat.t() grad_swapped_weighted=torch.mm(W_mat_t,grad_swapped_all[idx].view(c//3,-1).t()) grad_swapped_weighted=grad_swapped_weighted.t().contiguous().view(1,c//3,ctx.h,ctx.w) grad_latter_all[idx]=torch.add(grad_latter_all[idx],grad_swapped_weighted.mul(ctx.triple_w)) grad_input=torch.cat([grad_former_all,grad_latter_all],1) returngrad_input,None,None,None,None,None,None,None,None,None,None
补充知识:Pytorch中expand,expand_as是共享内存的,只是原始数据的一个视图view
如下所示:
mask=mask_miss.expand_as(sxing).clone()#type:torch.Tensor
mask[:,:,-2,:,:]=1#exceptforpersonmaskchannel
为了避免对expand后对某个channel操作会影响原始tensor的全部元素,需要使用clone()
如果没有clone(),对mask_miss的某个通道赋值后,所有通道上的tensor都会变成1!
#Notice!expanddoesnotallocatemorememorybutjustmakethetensorlookasifyouexpandedit.
#Youshouldcall.clone()ontheresultingtensorifyouplanonmodifyingit
#https://discuss.pytorch.org/t/very-strange-behavior-change-one-element-of-a-tensor-will-influence-all-elements/41190
以上这篇解决Pytorch自定义层出现多Variable共享内存错误问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。