pytorch GAN生成对抗网络实例
我就废话不多说了,直接上代码吧!
importtorch
importtorch.nnasnn
fromtorch.autogradimportVariable
importnumpyasnp
importmatplotlib.pyplotasplt
torch.manual_seed(1)
np.random.seed(1)
BATCH_SIZE=64
LR_G=0.0001
LR_D=0.0001
N_IDEAS=5
ART_COMPONENTS=15
PAINT_POINTS=np.vstack([np.linspace(-1,1,ART_COMPONENTS)for_inrange(BATCH_SIZE)])
defartist_works():
a=np.random.uniform(1,2,size=BATCH_SIZE)[:,np.newaxis]
paintings=a*np.power(PAINT_POINTS,2)+(a-1)
paintings=torch.from_numpy(paintings).float()
returnVariable(paintings)
G=nn.Sequential(
nn.Linear(N_IDEAS,128),
nn.ReLU(),
nn.Linear(128,ART_COMPONENTS),
)
D=nn.Sequential(
nn.Linear(ART_COMPONENTS,128),
nn.ReLU(),
nn.Linear(128,1),
nn.Sigmoid(),
)
opt_D=torch.optim.Adam(D.parameters(),lr=LR_D)
opt_G=torch.optim.Adam(G.parameters(),lr=LR_G)
plt.ion()
forstepinrange(10000):
artist_paintings=artist_works()
G_ideas=Variable(torch.randn(BATCH_SIZE,N_IDEAS))
G_paintings=G(G_ideas)
prob_artist0=D(artist_paintings)
prob_artist1=D(G_paintings)
D_loss=-torch.mean(torch.log(prob_artist0)+torch.log(1-prob_artist1))
G_loss=torch.mean(torch.log(1-prob_artist1))
opt_D.zero_grad()
D_loss.backward(retain_variables=True)
opt_D.step()
opt_G.zero_grad()
G_loss.backward()
opt_G.step()
ifstep%50==0:
plt.cla()
plt.plot(PAINT_POINTS[0],G_paintings.data.numpy()[0],c='#4ad631',lw=3,label='Generatedpainting',)
plt.plot(PAINT_POINTS[0],2*np.power(PAINT_POINTS[0],2)+1,c='#74BCFF',lw=3,label='upperbound',)
plt.plot(PAINT_POINTS[0],1*np.power(PAINT_POINTS[0],2)+0,c='#FF9359',lw=3,label='lowerbound',)
plt.text(-.5,2.3,'Daccuracy=%.2f(0.5forDtoconverge)'%prob_artist0.data.numpy().mean(),fontdict={'size':15})
plt.text(-.5,2,'Dscore=%.2f(-1.38forGtoconverge)'%-D_loss.data.numpy(),fontdict={'size':15})
plt.ylim((0,3))
plt.legend(loc='upperright',fontsize=12)
plt.draw()
plt.pause(0.01)
plt.ioff()
plt.show()
以上这篇pytorchGAN生成对抗网络实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。