pytorch GAN生成对抗网络实例
我就废话不多说了,直接上代码吧!
importtorch importtorch.nnasnn fromtorch.autogradimportVariable importnumpyasnp importmatplotlib.pyplotasplt torch.manual_seed(1) np.random.seed(1) BATCH_SIZE=64 LR_G=0.0001 LR_D=0.0001 N_IDEAS=5 ART_COMPONENTS=15 PAINT_POINTS=np.vstack([np.linspace(-1,1,ART_COMPONENTS)for_inrange(BATCH_SIZE)]) defartist_works(): a=np.random.uniform(1,2,size=BATCH_SIZE)[:,np.newaxis] paintings=a*np.power(PAINT_POINTS,2)+(a-1) paintings=torch.from_numpy(paintings).float() returnVariable(paintings) G=nn.Sequential( nn.Linear(N_IDEAS,128), nn.ReLU(), nn.Linear(128,ART_COMPONENTS), ) D=nn.Sequential( nn.Linear(ART_COMPONENTS,128), nn.ReLU(), nn.Linear(128,1), nn.Sigmoid(), ) opt_D=torch.optim.Adam(D.parameters(),lr=LR_D) opt_G=torch.optim.Adam(G.parameters(),lr=LR_G) plt.ion() forstepinrange(10000): artist_paintings=artist_works() G_ideas=Variable(torch.randn(BATCH_SIZE,N_IDEAS)) G_paintings=G(G_ideas) prob_artist0=D(artist_paintings) prob_artist1=D(G_paintings) D_loss=-torch.mean(torch.log(prob_artist0)+torch.log(1-prob_artist1)) G_loss=torch.mean(torch.log(1-prob_artist1)) opt_D.zero_grad() D_loss.backward(retain_variables=True) opt_D.step() opt_G.zero_grad() G_loss.backward() opt_G.step() ifstep%50==0: plt.cla() plt.plot(PAINT_POINTS[0],G_paintings.data.numpy()[0],c='#4ad631',lw=3,label='Generatedpainting',) plt.plot(PAINT_POINTS[0],2*np.power(PAINT_POINTS[0],2)+1,c='#74BCFF',lw=3,label='upperbound',) plt.plot(PAINT_POINTS[0],1*np.power(PAINT_POINTS[0],2)+0,c='#FF9359',lw=3,label='lowerbound',) plt.text(-.5,2.3,'Daccuracy=%.2f(0.5forDtoconverge)'%prob_artist0.data.numpy().mean(),fontdict={'size':15}) plt.text(-.5,2,'Dscore=%.2f(-1.38forGtoconverge)'%-D_loss.data.numpy(),fontdict={'size':15}) plt.ylim((0,3)) plt.legend(loc='upperright',fontsize=12) plt.draw() plt.pause(0.01) plt.ioff() plt.show()
以上这篇pytorchGAN生成对抗网络实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。