浅析PyTorch中nn.Module的使用
torch.nn.Modules相当于是对网络某种层的封装,包括网络结构以及网络参数和一些操作
torch.nn.Module是所有神经网络单元的基类
查看源码
初始化部分:
def__init__(self): self._backend=thnn_backend self._parameters=OrderedDict() self._buffers=OrderedDict() self._backward_hooks=OrderedDict() self._forward_hooks=OrderedDict() self._forward_pre_hooks=OrderedDict() self._state_dict_hooks=OrderedDict() self._load_state_dict_pre_hooks=OrderedDict() self._modules=OrderedDict() self.training=True
属性解释:
- _parameters:字典,保存用户直接设置的Parameter
- _modules:子module,即子类构造函数中的内容
- _buffers:缓存
- _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
- training:判断值来决定前向传播策略
方法定义:
defforward(self,*input): raiseNotImplementedError
没有实际内容,用于被子类的forward()方法覆盖
且forward方法在__call__方法中被调用:
def__call__(self,*input,**kwargs): forhookinself._forward_pre_hooks.values(): hook(self,input) iftorch._C._get_tracing_state(): result=self._slow_forward(*input,**kwargs) else: result=self.forward(*input,**kwargs) ... ...
实例展示
简单搭建:
importtorch.nnasnn importtorch.nn.functionalasF classNet(nn.Module): def__init__(self,n_feature,n_hidden,n_output): super(Net,self).__init__() self.hidden=nn.Linear(n_feature,n_hidden) self.out=nn.Linear(n_hidden,n_output) defforward(self,x): x=F.relu(self.hidden(x)) x=self.out(x) returnx
Net类继承了torch的Module和__init__功能
hidden是隐藏层线性输出
out是输出层线性输出
打印出网络的结构:
>>>net=Net(n_feature=10,n_hidden=30,n_output=15) >>>print(net) Net( (hidden):Linear(in_features=10,out_features=30,bias=True) (out):Linear(in_features=30,out_features=15,bias=True) )
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。