基于PyTorch的SSD模型分析
本文参考GitHub上的SSD实现，对模型进行分析，主要分析模型的组成以及各部分输入输出的大小。SSD网络结构如下图所示（原图未随本文保留）：
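对每张输入图像（300×300），SSD共输出8732个默认框（default box）的预测。这个数目由6个用于预测的特征图尺寸（见下文sources的输出）与每个位置的默认框个数mbox['300']=[4,6,6,6,4,4]得到：$38^2\times4+19^2\times6+10^2\times6+5^2\times6+3^2\times4+1^2\times4=5776+2166+600+150+36+4=8732$。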
importtorch importtorch.nnasnn importtorch.nn.functionalasF fromtorch.autogradimportVariable #fromlayersimport* fromdataimportvoc,coco importos
# Per-resolution layer specs, keyed by the input size as a string.
# Only the 300x300 variants are filled in; the '512' entries are empty placeholders.

# VGG16 backbone: ints are 3x3 conv widths, 'M' is a 2x2 max-pool and 'C'
# the same pool with ceil_mode=True (see vgg()).
base = {
    '300': ([64] * 2 + ['M']
            + [128] * 2 + ['M']
            + [256] * 3 + ['C']
            + [512] * 3 + ['M']
            + [512] * 3),
    '512': [],
}

# Extra feature layers: ints are conv widths, 'S' marks a stride-2 conv whose
# output width is the entry right after it (see add_extras()).
extras = {
    '300': [256, 'S', 512] + [128, 'S', 256] + [128, 256] * 2,
    '512': [],
}

# Number of default boxes per feature-map location, one entry per source map.
mbox = {
    '300': [4, 6, 6, 6, 4, 4],
    '512': [],
}
VGG基础网络结构:
def vgg(cfg, i, batch_norm=False):
    """Build the VGG-style backbone used by SSD as a flat list of layers.

    Args:
        cfg: layer spec. An int is a 3x3 conv (padding 1) with that many
            output channels, 'M' is a 2x2 / stride-2 max-pool and 'C' is the
            same pool with ceil_mode=True.
        i: number of input channels (3 for RGB images).
        batch_norm: if True, insert a BatchNorm2d after every cfg conv.

    Returns:
        list of nn.Module: the cfg layers (each conv followed by an in-place
        ReLU), then pool5 (3x3, stride 1, padding 1), the dilated conv6
        (3x3, dilation 6, 1024 channels) and conv7 (1x1, 1024 channels),
        each of the latter two followed by an in-place ReLU.
    """
    layers = []
    in_channels = i
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        elif v == 'C':
            # ceil_mode=True rounds odd sizes up, e.g. 75x75 -> 38x38 for a
            # 300x300 input, giving the 38x38 conv4_3 map.
            layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
    # conv6 consumes whatever width the last cfg conv produced (512 for the
    # stock configs) rather than a hard-coded 512, so other cfgs stay valid.
    conv6 = nn.Conv2d(in_channels, 1024, kernel_size=3, padding=6, dilation=6)
    conv7 = nn.Conv2d(1024, 1024, kernel_size=1)
    layers += [pool5, conv6,
               nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)]
    return layers
size = 300
# Keep the built layer list under its own name: rebinding ``vgg`` here would
# shadow the vgg() builder, and the later ``vgg(base[str(size)], 3)`` call
# would fail with "'list' object is not callable".
vgg_layers = vgg(base[str(size)], 3)
print(vgg_layers)
输出为:
Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ReLU(inplace)
MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))
ReLU(inplace)
Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
ReLU(inplace)
SSD在VGG之后额外添加的网络层（extra layers）
add_extras函数构建这些额外的卷积层：1×1卷积与3×3卷积交替出现，配置中的'S'表示紧随其后的3×3卷积使用stride=2、padding=1进行下采样，其输出通道数取'S'之后的那个数。
def add_extras(cfg, i, batch_norm=False):
    """Build the extra feature layers SSD appends after the VGG backbone.

    Convolutions alternate between 1x1 and 3x3 kernels, starting with 1x1.
    An 'S' entry in ``cfg`` produces a stride-2, padding-1 convolution whose
    output width is the entry that follows it; that following entry is then
    consumed and does not produce a layer of its own.

    Args:
        cfg: sequence of output widths (ints) and 'S' markers.
        i: number of input channels of the first layer.
        batch_norm: accepted for signature symmetry with ``vgg`` but unused.

    Returns:
        list of nn.Conv2d without activations (``SSD.forward`` applies ReLU).
    """
    extra_layers = []
    prev = i
    use_3x3 = False
    for idx, entry in enumerate(cfg):
        if prev == 'S':
            # This width was already used as the output of the stride-2 conv.
            prev = entry
            continue
        kernel = 3 if use_3x3 else 1
        if entry == 'S':
            conv = nn.Conv2d(prev, cfg[idx + 1], kernel_size=kernel,
                             stride=2, padding=1)
        else:
            conv = nn.Conv2d(prev, entry, kernel_size=kernel)
        extra_layers.append(conv)
        use_3x3 = not use_3x3
        prev = entry
    return extra_layers
extra_layers = add_extras(extras[str(size)], 1024)
# Print one module per line, exactly as a per-layer print loop would.
if extra_layers:
    print(*extra_layers, sep='\n')
输出为:
Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
multibox函数为每个用于预测的特征图构建两个3×3卷积预测头：位置回归网络loc（每个默认框输出4个位置参数）和分类得分网络conf（每个默认框输出num_classes个类别得分）。
def multibox(vgg, extra_layers, cfg, num_classes):
    """Create SSD's localization and confidence heads for every source map.

    Args:
        vgg: list of backbone layers as returned by ``vgg()``. With
            batch_norm=False, index 21 is the conv4_3 convolution and index
            -2 is conv7 (the last conv before the final ReLU).
        extra_layers: layers from ``add_extras()``; every second one
            (indices 1, 3, 5, ...) feeds a head.
        cfg: number of default boxes per location for each source map, in
            order (e.g. ``mbox['300']``).
        num_classes: number of classes scored per box.

    Returns:
        ``(vgg, extra_layers, (loc_layers, conf_layers))``: the first two
        inputs unchanged, plus two lists of 3x3/padding-1 convs. Each loc conv
        outputs ``boxes * 4`` channels, each conf conv ``boxes * num_classes``.
    """
    # The convs whose out_channels set each head's input width, in the order
    # the heads are applied: conv4_3, conv7, then every 2nd extra layer.
    head_sources = [vgg[21], vgg[-2]]
    head_sources.extend(extra_layers[1::2])

    loc_layers, conf_layers = [], []
    for idx, src in enumerate(head_sources):
        n_boxes = cfg[idx]
        width = src.out_channels
        # loc is created before conf for each source (keeps init RNG order).
        loc_layers.append(
            nn.Conv2d(width, n_boxes * 4, kernel_size=3, padding=1))
        conf_layers.append(
            nn.Conv2d(width, n_boxes * num_classes, kernel_size=3, padding=1))
    return vgg, extra_layers, (loc_layers, conf_layers)
# 21 classes: selects the VOC config in SSD.__init__ and matches the conf
# channel counts printed below (e.g. 84 = 4 boxes * 21 classes).
num_classes = 21
base_, extras_, head_ = multibox(
    vgg(base[str(size)], 3),              # VGG16 backbone: 13 cfg convs + conv6/conv7
    add_extras(extras[str(size)], 1024),  # extra layers fed by 1024-channel conv7
    mbox[str(size)],                      # boxes per location: [4, 6, 6, 6, 4, 4]
    num_classes,
)
得到的输出为:
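base_为上述vgg网络，extras_为extra_layers网络，head_为(loc_layers, conf_layers)二元组。以num_classes=21为例，loc层的输出通道数为默认框个数×4（如4×4=16、6×4=24），conf层的输出通道数为默认框个数×21（如4×21=84、6×21=126）：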
([Conv2d(512, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  Conv2d(1024, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  Conv2d(512, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  Conv2d(256, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))],
 [Conv2d(512, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  Conv2d(1024, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  Conv2d(512, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  Conv2d(256, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  Conv2d(256, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  Conv2d(256, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))])
SSD网络及forward函数为:
class SSD(nn.Module):
    """Single Shot Multibox Architecture.

    The network is composed of a base VGG network followed by the added
    multibox conv layers. Each multibox layer branches into
        1) conv2d for class confidence scores
        2) conv2d for localization predictions
        3) associated prior-box layer to produce default bounding boxes
           specific to the layer's feature map size.
    See: https://arxiv.org/pdf/1512.02325.pdf for more details.

    Args:
        phase: (string) Can be "test" or "train".
        size: input image size.
        base: VGG16 layers for input, size of either 300 or 512.
        extras: extra layers that feed to multibox loc and conf layers.
        head: "multibox head" consisting of loc and conf conv layers.
        num_classes: number of classes; 21 selects the ``voc`` config,
            any other value the ``coco`` config.
    """

    def __init__(self, phase, size, base, extras, head, num_classes):
        super(SSD, self).__init__()
        self.phase = phase
        self.num_classes = num_classes
        self.cfg = (coco, voc)[num_classes == 21]
        self.priorbox = PriorBox(self.cfg)
        # Prior boxes are constants. ``Variable(..., volatile=True)`` has had
        # no effect since PyTorch 0.4; no_grad() is the supported way to build
        # them without autograd tracking.
        with torch.no_grad():
            self.priors = self.priorbox.forward()
        self.size = size

        # SSD network
        self.vgg = nn.ModuleList(base)
        # Layer learns to scale the l2 normalized features from conv4_3
        self.L2Norm = L2Norm(512, 20)
        self.extras = nn.ModuleList(extras)

        self.loc = nn.ModuleList(head[0])
        self.conf = nn.ModuleList(head[1])

        if phase == 'test':
            self.softmax = nn.Softmax(dim=-1)
            self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)

    def forward(self, x):
        """Apply network layers and ops on input image(s) x.

        Args:
            x: input image or batch of images. Shape: [batch, 3, 300, 300].

        Return:
            Depending on phase:
            test:
                the result of ``self.detect`` applied to the box predictions,
                the softmax-normalized class scores and the prior boxes (cast
                to the dtype/device of the predictions). The original docs
                describe this as detections of shape [batch, topk, 7]; that
                depends on Detect, which is not shown here.
            train:
                tuple of
                    1: localization predictions, Shape: [batch, num_priors, 4]
                    2: confidence predictions, Shape: [batch, num_priors, num_classes]
                    3: prior boxes (``self.priors``)
        """
        sources = list()
        loc = list()
        conf = list()

        # apply vgg up to conv4_3 relu; x is [1, 512, 38, 38] for a 1x3x300x300 input
        for k in range(23):
            x = self.vgg[k](x)

        s = self.L2Norm(x)
        sources.append(s)

        # apply vgg up to fc7; x is [1, 1024, 19, 19] for a 1x3x300x300 input
        for k in range(23, len(self.vgg)):
            x = self.vgg[k](x)
        sources.append(x)

        # apply extra layers and cache source layer outputs; every second
        # layer's output is a source. For a 1x3x300x300 input these are
        # [1, 512, 10, 10], [1, 256, 5, 5], [1, 256, 3, 3], [1, 256, 1, 1].
        for k, v in enumerate(self.extras):
            x = F.relu(v(x), inplace=True)
            if k % 2 == 1:
                sources.append(x)

        # apply multibox head to source layers (channels moved last so the
        # flattened order is location-major, then box, then value)
        for (x, l, c) in zip(sources, self.loc, self.conf):
            loc.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf.append(c(x).permute(0, 2, 3, 1).contiguous())

        loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
        conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
        if self.phase == "test":
            output = self.detect(
                loc.view(loc.size(0), -1, 4),                   # loc preds
                self.softmax(conf.view(conf.size(0), -1,
                             self.num_classes)),                # conf preds
                # default boxes: type(x.data) is plain torch.Tensor since
                # PyTorch 0.4, so .type(type(x.data)) cast the priors to the
                # default CPU float type whatever the model's device/dtype;
                # type_as follows the predictions instead.
                self.priors.type_as(loc)
            )
        else:
            output = (
                loc.view(loc.size(0), -1, 4),                   # [1, 8732, 4] for SSD300
                conf.view(conf.size(0), -1, self.num_classes),  # [1, 8732, 21] with 21 classes
                self.priors
            )
        return output
上述代码中sources保存的是用于边框预测的6个特征图，在输入为1×3×300×300时其尺寸如下：
torch.Size([1, 512, 38, 38])
torch.Size([1, 1024, 19, 19])
torch.Size([1, 512, 10, 10])
torch.Size([1, 256, 5, 5])
torch.Size([1, 256, 3, 3])
torch.Size([1, 256, 1, 1])
模型输入为
# A dummy batch of one 3x300x300 image. Tensors have been usable as autograd
# inputs directly since PyTorch 0.4, so the deprecated Variable wrapper is not needed.
x = torch.randn(1, 3, 300, 300)
以上这篇基于PytorchSSD模型分析就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。