pytorch GAN伪造手写体mnist数据集方式
一,mnist数据集
形如上图的数字手写体就是mnist数据集。
二,GAN原理(生成对抗网络)
GAN网络一共由两部分组成:一个是伪造器(Generator,简称G),一个是判别器(Discrimniator,简称D)
一开始,G由服从某几个分布(如高斯分布)的噪音组成,生成的图片不断送给D判断是否正确,直到G生成的图片连D都判断以为是真的。D每一轮除了看过G生成的假图片以外,还要见数据集中的真图片,以前者和后者得到的损失函数值为依据更新D网络中的权值。因此G和D都在不停地更新权值。以下图为例:
在v1时的G只不过是一堆噪声,见过数据集(realimages)的D肯定能判断出G所生成的是假的。当然G也能知道D判断它是假的这个结果,因此G就会更新权值,到v2的时候,G就能生成更逼真的图片来让D判断,当然在v2时D也是会先看一次真图片,再去判断G所生成的图片。以此类推,不断循环就是GAN的思想。
三,训练代码
importargparse
importos
importnumpyasnp
importmath
importtorchvision.transformsastransforms
fromtorchvision.utilsimportsave_image
fromtorch.utils.dataimportDataLoader
fromtorchvisionimportdatasets
fromtorch.autogradimportVariable
importtorch.nnasnn
importtorch.nn.functionalasF
importtorch
os.makedirs("images",exist_ok=True)
parser=argparse.ArgumentParser()
parser.add_argument("--n_epochs",type=int,default=200,help="numberofepochsoftraining")
parser.add_argument("--batch_size",type=int,default=64,help="sizeofthebatches")
parser.add_argument("--lr",type=float,default=0.0002,help="adam:learningrate")
parser.add_argument("--b1",type=float,default=0.5,help="adam:decayoffirstordermomentumofgradient")
parser.add_argument("--b2",type=float,default=0.999,help="adam:decayoffirstordermomentumofgradient")
parser.add_argument("--n_cpu",type=int,default=8,help="numberofcputhreadstouseduringbatchgeneration")
parser.add_argument("--latent_dim",type=int,default=100,help="dimensionalityofthelatentspace")
parser.add_argument("--img_size",type=int,default=28,help="sizeofeachimagedimension")
parser.add_argument("--channels",type=int,default=1,help="numberofimagechannels")
parser.add_argument("--sample_interval",type=int,default=400,help="intervalbetwenimagesamples")
opt=parser.parse_args()
print(opt)
img_shape=(opt.channels,opt.img_size,opt.img_size)#确定图片输入的格式为(1,28,28),由于mnist数据集是灰度图所以通道为1
cuda=Trueiftorch.cuda.is_available()elseFalse
classGenerator(nn.Module):
def__init__(self):
super(Generator,self).__init__()
defblock(in_feat,out_feat,normalize=True):
layers=[nn.Linear(in_feat,out_feat)]
ifnormalize:
layers.append(nn.BatchNorm1d(out_feat,0.8))
layers.append(nn.LeakyReLU(0.2,inplace=True))
returnlayers
self.model=nn.Sequential(
*block(opt.latent_dim,128,normalize=False),
*block(128,256),
*block(256,512),
*block(512,1024),
nn.Linear(1024,int(np.prod(img_shape))),
nn.Tanh()
)
defforward(self,z):
img=self.model(z)
img=img.view(img.size(0),*img_shape)
returnimg
classDiscriminator(nn.Module):
def__init__(self):
super(Discriminator,self).__init__()
self.model=nn.Sequential(
nn.Linear(int(np.prod(img_shape)),512),
nn.LeakyReLU(0.2,inplace=True),
nn.Linear(512,256),
nn.LeakyReLU(0.2,inplace=True),
nn.Linear(256,1),
nn.Sigmoid(),
)
defforward(self,img):
img_flat=img.view(img.size(0),-1)
validity=self.model(img_flat)
returnvalidity
#Lossfunction
adversarial_loss=torch.nn.BCELoss()
#Initializegeneratoranddiscriminator
generator=Generator()
discriminator=Discriminator()
ifcuda:
generator.cuda()
discriminator.cuda()
adversarial_loss.cuda()
#Configuredataloader
os.makedirs("../../data/mnist",exist_ok=True)
dataloader=torch.utils.data.DataLoader(
datasets.MNIST(
"../../data/mnist",
train=True,
download=True,
transform=transforms.Compose(
[transforms.Resize(opt.img_size),transforms.ToTensor(),transforms.Normalize([0.5],[0.5])]
),
),
batch_size=opt.batch_size,
shuffle=True,
)
#Optimizers
optimizer_G=torch.optim.Adam(generator.parameters(),lr=opt.lr,betas=(opt.b1,opt.b2))
optimizer_D=torch.optim.Adam(discriminator.parameters(),lr=opt.lr,betas=(opt.b1,opt.b2))
Tensor=torch.cuda.FloatTensorifcudaelsetorch.FloatTensor
#----------
#Training
#----------
if__name__=='__main__':
forepochinrange(opt.n_epochs):
fori,(imgs,_)inenumerate(dataloader):
#print(imgs.shape)
#Adversarialgroundtruths
valid=Variable(Tensor(imgs.size(0),1).fill_(1.0),requires_grad=False)#全1
fake=Variable(Tensor(imgs.size(0),1).fill_(0.0),requires_grad=False)#全0
#Configureinput
real_imgs=Variable(imgs.type(Tensor))
#-----------------
#TrainGenerator
#-----------------
optimizer_G.zero_grad()#清空G网络上一个batch的梯度
#Samplenoiseasgeneratorinput
z=Variable(Tensor(np.random.normal(0,1,(imgs.shape[0],opt.latent_dim))))#生成的噪音,均值为0方差为1维度为(64,100)的噪音
#Generateabatchofimages
gen_imgs=generator(z)
#Lossmeasuresgenerator'sabilitytofoolthediscriminator
g_loss=adversarial_loss(discriminator(gen_imgs),valid)
g_loss.backward()#g_loss用于更新G网络的权值,g_loss于D网络的判断结果有关
optimizer_G.step()
#---------------------
#TrainDiscriminator
#---------------------
optimizer_D.zero_grad()#清空D网络上一个batch的梯度
#Measurediscriminator'sabilitytoclassifyrealfromgeneratedsamples
real_loss=adversarial_loss(discriminator(real_imgs),valid)
fake_loss=adversarial_loss(discriminator(gen_imgs.detach()),fake)
d_loss=(real_loss+fake_loss)/2
d_loss.backward()#d_loss用于更新D网络的权值
optimizer_D.step()
print(
"[Epoch%d/%d][Batch%d/%d][Dloss:%f][Gloss:%f]"
%(epoch,opt.n_epochs,i,len(dataloader),d_loss.item(),g_loss.item())
)
batches_done=epoch*len(dataloader)+i
ifbatches_done%opt.sample_interval==0:
save_image(gen_imgs.data[:25],"images/%d.png"%batches_done,nrow=5,normalize=True)#保存一个batchsize中的25张
if(epoch+1)%2==0:
print('save..')
torch.save(generator,'g%d.pth'%epoch)
torch.save(discriminator,'d%d.pth'%epoch)
运行结果:
一开始时,G生成的全是杂音:
然后逐渐呈现数字的雏形:
最后一次生成的结果:
四,测试代码:
导入最后保存生成器的模型:
fromganimportGenerator,Discriminator
importtorch
importmatplotlib.pyplotasplt
fromtorch.autogradimportVariable
importnumpyasnp
fromtorchvision.utilsimportsave_image
device=torch.device('cuda'iftorch.cuda.is_available()else'cpu')
Tensor=torch.cuda.FloatTensor
g=torch.load('g199.pth')#导入生成器Generator模型
#d=torch.load('d.pth')
g=g.to(device)
#d=d.to(device)
z=Variable(Tensor(np.random.normal(0,1,(64,100))))#输入的噪音
gen_imgs=g(z)#生产图片
save_image(gen_imgs.data[:25],"images.png",nrow=5,normalize=True)
生成结果:
以上这篇pytorchGAN伪造手写体mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。