pytorch实现线性拟合方式
一维线性拟合
数据为y=4x+5加上噪音
结果:
importnumpyasnp
frommpl_toolkits.mplot3dimportAxes3D
frommatplotlibimportpyplotasplt
fromtorch.autogradimportVariable
importtorch
fromtorchimportnn
X=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
Y=4*X+5+torch.rand(X.size())
classLinearRegression(nn.Module):
def__init__(self):
super(LinearRegression,self).__init__()
self.linear=nn.Linear(1,1)#输入和输出的维度都是1
defforward(self,X):
out=self.linear(X)
returnout
model=LinearRegression()
criterion=nn.MSELoss()
optimizer=torch.optim.SGD(model.parameters(),lr=1e-2)
num_epochs=1000
forepochinrange(num_epochs):
inputs=Variable(X)
target=Variable(Y)
#向前传播
out=model(inputs)
loss=criterion(out,target)
#向后传播
optimizer.zero_grad()#注意每次迭代都需要清零
loss.backward()
optimizer.step()
if(epoch+1)%20==0:
print('Epoch[{}/{}],loss:{:.6f}'.format(epoch+1,num_epochs,loss.item()))
model.eval()
predict=model(Variable(X))
predict=predict.data.numpy()
plt.plot(X.numpy(),Y.numpy(),'ro',label='OriginalData')
plt.plot(X.numpy(),predict,label='FittingLine')
plt.show()
多维:
fromitertoolsimportcount
importtorch
importtorch.autograd
importtorch.nn.functionalasF
POLY_DEGREE=3
defmake_features(x):
"""Buildsfeaturesi.e.amatrixwithcolumns[x,x^2,x^3]."""
x=x.unsqueeze(1)
returntorch.cat([x**iforiinrange(1,POLY_DEGREE+1)],1)
W_target=torch.randn(POLY_DEGREE,1)
b_target=torch.randn(1)
deff(x):
returnx.mm(W_target)+b_target.item()
defget_batch(batch_size=32):
random=torch.randn(batch_size)
x=make_features(random)
y=f(x)
returnx,y
#Definemodel
fc=torch.nn.Linear(W_target.size(0),1)
batch_x,batch_y=get_batch()
print(batch_x,batch_y)
forbatch_idxincount(1):
#Getdata
#Resetgradients
fc.zero_grad()
#Forwardpass
output=F.smooth_l1_loss(fc(batch_x),batch_y)
loss=output.item()
#Backwardpass
output.backward()
#Applygradients
forparaminfc.parameters():
param.data.add_(-0.1*param.grad.data)
#Stopcriterion
ifloss<1e-3:
break
defpoly_desc(W,b):
"""Createsastringdescriptionofapolynomial."""
result='y='
fori,winenumerate(W):
result+='{:+.2f}x^{}'.format(w,len(W)-i)
result+='{:+.2f}'.format(b[0])
returnresult
print('Loss:{:.6f}after{}batches'.format(loss,batch_idx))
print('==>Learnedfunction:\t'+poly_desc(fc.weight.view(-1),fc.bias))
print('==>Actualfunction:\t'+poly_desc(W_target.view(-1),b_target))
以上这篇pytorch实现线性拟合方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。
