pytorch:实现简单的GAN示例(MNIST数据集)
我就废话不多说了,直接上代码吧!
#-*-coding:utf-8-*- """ CreatedonSatOct1310:22:452018 @author:www """ importtorch fromtorchimportnn fromtorch.autogradimportVariable importtorchvision.transformsastfs fromtorch.utils.dataimportDataLoader,sampler fromtorchvision.datasetsimportMNIST importnumpyasnp importmatplotlib.pyplotasplt importmatplotlib.gridspecasgridspec plt.rcParams['figure.figsize']=(10.0,8.0)#设置画图的尺寸 plt.rcParams['image.interpolation']='nearest' plt.rcParams['image.cmap']='gray' defshow_images(images):#定义画图工具 images=np.reshape(images,[images.shape[0],-1]) sqrtn=int(np.ceil(np.sqrt(images.shape[0]))) sqrtimg=int(np.ceil(np.sqrt(images.shape[1]))) fig=plt.figure(figsize=(sqrtn,sqrtn)) gs=gridspec.GridSpec(sqrtn,sqrtn) gs.update(wspace=0.05,hspace=0.05) fori,imginenumerate(images): ax=plt.subplot(gs[i]) plt.axis('off') ax.set_xticklabels([]) ax.set_yticklabels([]) ax.set_aspect('equal') plt.imshow(img.reshape([sqrtimg,sqrtimg])) return defpreprocess_img(x): x=tfs.ToTensor()(x) return(x-0.5)/0.5 defdeprocess_img(x): return(x+1.0)/2.0 classChunkSampler(sampler.Sampler):#定义一个取样的函数 """Sampleselementssequentiallyfromsomeoffset. Arguments: num_samples:#ofdesireddatapoints start:offsetwhereweshouldstartselectingfrom """ def__init__(self,num_samples,start=0): self.num_samples=num_samples self.start=start def__iter__(self): returniter(range(self.start,self.start+self.num_samples)) def__len__(self): returnself.num_samples NUM_TRAIN=50000 NUM_VAL=5000 NOISE_DIM=96 batch_size=128 train_set=MNIST('E:/data',train=True,transform=preprocess_img) train_data=DataLoader(train_set,batch_size=batch_size,sampler=ChunkSampler(NUM_TRAIN,0)) val_set=MNIST('E:/data',train=True,transform=preprocess_img) val_data=DataLoader(val_set,batch_size=batch_size,sampler=ChunkSampler(NUM_VAL,NUM_TRAIN)) imgs=deprocess_img(train_data.__iter__().next()[0].view(batch_size,784)).numpy().squeeze()#可视化图片效果 show_images(imgs) #判别网络 defdiscriminator(): net=nn.Sequential( nn.Linear(784,256), nn.LeakyReLU(0.2), nn.Linear(256,256), nn.LeakyReLU(0.2), nn.Linear(256,1) ) returnnet #生成网络 defgenerator(noise_dim=NOISE_DIM): net=nn.Sequential( nn.Linear(noise_dim,1024), nn.ReLU(True), nn.Linear(1024,1024), nn.ReLU(True), nn.Linear(1024,784), nn.Tanh() ) returnnet #判别器的loss就是将真实数据的得分判断为1,假的数据的得分判断为0,而生成器的loss就是将假的数据判断为1 bce_loss=nn.BCEWithLogitsLoss()#交叉熵损失函数 defdiscriminator_loss(logits_real,logits_fake):#判别器的loss size=logits_real.shape[0] true_labels=Variable(torch.ones(size,1)).float() false_labels=Variable(torch.zeros(size,1)).float() loss=bce_loss(logits_real,true_labels)+bce_loss(logits_fake,false_labels) returnloss defgenerator_loss(logits_fake):#生成器的loss size=logits_fake.shape[0] true_labels=Variable(torch.ones(size,1)).float() loss=bce_loss(logits_fake,true_labels) returnloss #使用adam来进行训练,学习率是3e-4,beta1是0.5,beta2是0.999 defget_optimizer(net): optimizer=torch.optim.Adam(net.parameters(),lr=3e-4,betas=(0.5,0.999)) returnoptimizer deftrain_a_gan(D_net,G_net,D_optimizer,G_optimizer,discriminator_loss,generator_loss,show_every=250, noise_size=96,num_epochs=10): iter_count=0 forepochinrange(num_epochs): forx,_intrain_data: bs=x.shape[0] #判别网络 real_data=Variable(x).view(bs,-1)#真实数据 logits_real=D_net(real_data)#判别网络得分 sample_noise=(torch.rand(bs,noise_size)-0.5)/0.5#-1~1的均匀分布 g_fake_seed=Variable(sample_noise) fake_images=G_net(g_fake_seed)#生成的假的数据 logits_fake=D_net(fake_images)#判别网络得分 d_total_error=discriminator_loss(logits_real,logits_fake)#判别器的loss D_optimizer.zero_grad() d_total_error.backward() D_optimizer.step()#优化判别网络 #生成网络 g_fake_seed=Variable(sample_noise) fake_images=G_net(g_fake_seed)#生成的假的数据 gen_logits_fake=D_net(fake_images) g_error=generator_loss(gen_logits_fake)#生成网络的loss G_optimizer.zero_grad() g_error.backward() G_optimizer.step()#优化生成网络 if(iter_count%show_every==0): print('Iter:{},D:{:.4},G:{:.4}'.format(iter_count,d_total_error.item(),g_error.item())) imgs_numpy=deprocess_img(fake_images.data.cpu().numpy()) show_images(imgs_numpy[0:16]) plt.show() print() iter_count+=1 D=discriminator() G=generator() D_optim=get_optimizer(D) G_optim=get_optimizer(G) train_a_gan(D,G,D_optim,G_optim,discriminator_loss,generator_loss)
以上这篇pytorch:实现简单的GAN示例(MNIST数据集)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。