pytorch GAN伪造手写体mnist数据集方式
一,mnist数据集
形如上图的数字手写体就是mnist数据集。
二,GAN原理(生成对抗网络)
GAN网络一共由两部分组成:一个是伪造器(Generator,简称G),一个是判别器(Discrimniator,简称D)
一开始,G由服从某几个分布(如高斯分布)的噪音组成,生成的图片不断送给D判断是否正确,直到G生成的图片连D都判断以为是真的。D每一轮除了看过G生成的假图片以外,还要见数据集中的真图片,以前者和后者得到的损失函数值为依据更新D网络中的权值。因此G和D都在不停地更新权值。以下图为例:
在v1时的G只不过是一堆噪声,见过数据集(realimages)的D肯定能判断出G所生成的是假的。当然G也能知道D判断它是假的这个结果,因此G就会更新权值,到v2的时候,G就能生成更逼真的图片来让D判断,当然在v2时D也是会先看一次真图片,再去判断G所生成的图片。以此类推,不断循环就是GAN的思想。
三,训练代码
importargparse importos importnumpyasnp importmath importtorchvision.transformsastransforms fromtorchvision.utilsimportsave_image fromtorch.utils.dataimportDataLoader fromtorchvisionimportdatasets fromtorch.autogradimportVariable importtorch.nnasnn importtorch.nn.functionalasF importtorch os.makedirs("images",exist_ok=True) parser=argparse.ArgumentParser() parser.add_argument("--n_epochs",type=int,default=200,help="numberofepochsoftraining") parser.add_argument("--batch_size",type=int,default=64,help="sizeofthebatches") parser.add_argument("--lr",type=float,default=0.0002,help="adam:learningrate") parser.add_argument("--b1",type=float,default=0.5,help="adam:decayoffirstordermomentumofgradient") parser.add_argument("--b2",type=float,default=0.999,help="adam:decayoffirstordermomentumofgradient") parser.add_argument("--n_cpu",type=int,default=8,help="numberofcputhreadstouseduringbatchgeneration") parser.add_argument("--latent_dim",type=int,default=100,help="dimensionalityofthelatentspace") parser.add_argument("--img_size",type=int,default=28,help="sizeofeachimagedimension") parser.add_argument("--channels",type=int,default=1,help="numberofimagechannels") parser.add_argument("--sample_interval",type=int,default=400,help="intervalbetwenimagesamples") opt=parser.parse_args() print(opt) img_shape=(opt.channels,opt.img_size,opt.img_size)#确定图片输入的格式为(1,28,28),由于mnist数据集是灰度图所以通道为1 cuda=Trueiftorch.cuda.is_available()elseFalse classGenerator(nn.Module): def__init__(self): super(Generator,self).__init__() defblock(in_feat,out_feat,normalize=True): layers=[nn.Linear(in_feat,out_feat)] ifnormalize: layers.append(nn.BatchNorm1d(out_feat,0.8)) layers.append(nn.LeakyReLU(0.2,inplace=True)) returnlayers self.model=nn.Sequential( *block(opt.latent_dim,128,normalize=False), *block(128,256), *block(256,512), *block(512,1024), nn.Linear(1024,int(np.prod(img_shape))), nn.Tanh() ) defforward(self,z): img=self.model(z) img=img.view(img.size(0),*img_shape) returnimg classDiscriminator(nn.Module): def__init__(self): super(Discriminator,self).__init__() self.model=nn.Sequential( nn.Linear(int(np.prod(img_shape)),512), nn.LeakyReLU(0.2,inplace=True), nn.Linear(512,256), nn.LeakyReLU(0.2,inplace=True), nn.Linear(256,1), nn.Sigmoid(), ) defforward(self,img): img_flat=img.view(img.size(0),-1) validity=self.model(img_flat) returnvalidity #Lossfunction adversarial_loss=torch.nn.BCELoss() #Initializegeneratoranddiscriminator generator=Generator() discriminator=Discriminator() ifcuda: generator.cuda() discriminator.cuda() adversarial_loss.cuda() #Configuredataloader os.makedirs("../../data/mnist",exist_ok=True) dataloader=torch.utils.data.DataLoader( datasets.MNIST( "../../data/mnist", train=True, download=True, transform=transforms.Compose( [transforms.Resize(opt.img_size),transforms.ToTensor(),transforms.Normalize([0.5],[0.5])] ), ), batch_size=opt.batch_size, shuffle=True, ) #Optimizers optimizer_G=torch.optim.Adam(generator.parameters(),lr=opt.lr,betas=(opt.b1,opt.b2)) optimizer_D=torch.optim.Adam(discriminator.parameters(),lr=opt.lr,betas=(opt.b1,opt.b2)) Tensor=torch.cuda.FloatTensorifcudaelsetorch.FloatTensor #---------- #Training #---------- if__name__=='__main__': forepochinrange(opt.n_epochs): fori,(imgs,_)inenumerate(dataloader): #print(imgs.shape) #Adversarialgroundtruths valid=Variable(Tensor(imgs.size(0),1).fill_(1.0),requires_grad=False)#全1 fake=Variable(Tensor(imgs.size(0),1).fill_(0.0),requires_grad=False)#全0 #Configureinput real_imgs=Variable(imgs.type(Tensor)) #----------------- #TrainGenerator #----------------- optimizer_G.zero_grad()#清空G网络上一个batch的梯度 #Samplenoiseasgeneratorinput z=Variable(Tensor(np.random.normal(0,1,(imgs.shape[0],opt.latent_dim))))#生成的噪音,均值为0方差为1维度为(64,100)的噪音 #Generateabatchofimages gen_imgs=generator(z) #Lossmeasuresgenerator'sabilitytofoolthediscriminator g_loss=adversarial_loss(discriminator(gen_imgs),valid) g_loss.backward()#g_loss用于更新G网络的权值,g_loss于D网络的判断结果有关 optimizer_G.step() #--------------------- #TrainDiscriminator #--------------------- optimizer_D.zero_grad()#清空D网络上一个batch的梯度 #Measurediscriminator'sabilitytoclassifyrealfromgeneratedsamples real_loss=adversarial_loss(discriminator(real_imgs),valid) fake_loss=adversarial_loss(discriminator(gen_imgs.detach()),fake) d_loss=(real_loss+fake_loss)/2 d_loss.backward()#d_loss用于更新D网络的权值 optimizer_D.step() print( "[Epoch%d/%d][Batch%d/%d][Dloss:%f][Gloss:%f]" %(epoch,opt.n_epochs,i,len(dataloader),d_loss.item(),g_loss.item()) ) batches_done=epoch*len(dataloader)+i ifbatches_done%opt.sample_interval==0: save_image(gen_imgs.data[:25],"images/%d.png"%batches_done,nrow=5,normalize=True)#保存一个batchsize中的25张 if(epoch+1)%2==0: print('save..') torch.save(generator,'g%d.pth'%epoch) torch.save(discriminator,'d%d.pth'%epoch)
运行结果:
一开始时,G生成的全是杂音:
然后逐渐呈现数字的雏形:
最后一次生成的结果:
四,测试代码:
导入最后保存生成器的模型:
fromganimportGenerator,Discriminator importtorch importmatplotlib.pyplotasplt fromtorch.autogradimportVariable importnumpyasnp fromtorchvision.utilsimportsave_image device=torch.device('cuda'iftorch.cuda.is_available()else'cpu') Tensor=torch.cuda.FloatTensor g=torch.load('g199.pth')#导入生成器Generator模型 #d=torch.load('d.pth') g=g.to(device) #d=d.to(device) z=Variable(Tensor(np.random.normal(0,1,(64,100))))#输入的噪音 gen_imgs=g(z)#生产图片 save_image(gen_imgs.data[:25],"images.png",nrow=5,normalize=True)
生成结果:
以上这篇pytorchGAN伪造手写体mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。