pytorch:实现简单的GAN示例(MNIST数据集)
我就废话不多说了,直接上代码吧!
# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018

@author: www
"""
import torch
from torch import nn
from torch.autograd import Variable
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# Global matplotlib defaults for the MNIST sample grids drawn below.
plt.rcParams['figure.figsize'] = (10.0, 8.0)  # plot size
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
def show_images(images):
    """Plot a batch of flattened images on a tight square grid.

    images: array-like of shape (N, ...); each row is flattened and then
    reshaped to a square image (assumes a square pixel count per image,
    e.g. 784 -> 28x28).
    """
    images = np.reshape(images, [images.shape[0], -1])
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))   # grid side length
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))  # image side length

    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)

    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg, sqrtimg]))
    return
def preprocess_img(x):
    """Convert a PIL image to a tensor scaled from [0, 1] to [-1, 1].

    The [-1, 1] range matches the Tanh output of the generator.
    """
    x = tfs.ToTensor()(x)  # PIL image -> float tensor in [0, 1]
    return (x - 0.5) / 0.5
def deprocess_img(x):
    """Map values from [-1, 1] back to [0, 1] for display (inverse of preprocess_img's scaling)."""
    return (x + 1.0) / 2.0
class ChunkSampler(sampler.Sampler):
    """Samples elements sequentially from some offset.

    Arguments:
        num_samples: # of desired datapoints
        start: offset where we should start selecting from
    """

    def __init__(self, num_samples, start=0):
        self.num_samples = num_samples
        self.start = start

    def __iter__(self):
        # Sequential indices [start, start + num_samples).
        return iter(range(self.start, self.start + self.num_samples))

    def __len__(self):
        return self.num_samples
NUM_TRAIN = 50000
NUM_VAL = 5000

NOISE_DIM = 96
batch_size = 128

# Both loaders read the MNIST train split: the first NUM_TRAIN samples
# for training, the next NUM_VAL for validation.
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))

val_set = MNIST('E:/data', train=True, transform=preprocess_img)
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))

# Visualize one batch of real images.
# Python 3 fix: the original `train_data.__iter__().next()` is Python 2 only.
imgs = deprocess_img(next(iter(train_data))[0].view(batch_size, 784)).numpy().squeeze()
show_images(imgs)
# Discriminator network
def discriminator():
    """Build a 3-layer MLP mapping a flat 784-dim image to one real/fake logit.

    No final sigmoid — it is paired with BCEWithLogitsLoss below.
    """
    net = nn.Sequential(
        nn.Linear(784, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 1)
    )
    return net
# Generator network
def generator(noise_dim=NOISE_DIM):
    """Build a 3-layer MLP mapping a noise vector to a flat 784-dim image.

    Tanh output keeps fake images in [-1, 1], matching preprocess_img.
    """
    net = nn.Sequential(
        nn.Linear(noise_dim, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 784),
        nn.Tanh()
    )
    return net
# The discriminator should score real data as 1 and fake data as 0,
# while the generator tries to get its fakes scored as 1.
bce_loss = nn.BCEWithLogitsLoss()  # sigmoid + binary cross-entropy in one numerically-stable op


def discriminator_loss(logits_real, logits_fake):
    """Discriminator loss: push real logits toward 1 and fake logits toward 0."""
    size = logits_real.shape[0]
    true_labels = Variable(torch.ones(size, 1)).float()
    false_labels = Variable(torch.zeros(size, 1)).float()
    loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
    return loss


def generator_loss(logits_fake):
    """Generator loss: push fake logits toward 1 (fool the discriminator)."""
    size = logits_fake.shape[0]
    true_labels = Variable(torch.ones(size, 1)).float()
    loss = bce_loss(logits_fake, true_labels)
    return loss
# Train with Adam: lr=3e-4, beta1=0.5, beta2=0.999 (beta1=0.5 is the usual GAN choice).
def get_optimizer(net):
    """Return an Adam optimizer over `net`'s parameters with the GAN hyperparameters above."""
    optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
    return optimizer
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss,
                show_every=250, noise_size=96, num_epochs=10):
    """Alternating GAN training loop over the module-level `train_data` loader.

    D_net / G_net: discriminator and generator modules.
    D_optimizer / G_optimizer: their optimizers.
    discriminator_loss / generator_loss: loss callables as defined above.
    show_every: print losses and plot 16 fakes every this many iterations.
    noise_size: dimensionality of the generator's input noise.
    num_epochs: passes over train_data.
    """
    iter_count = 0
    for epoch in range(num_epochs):
        for x, _ in train_data:
            bs = x.shape[0]

            # --- Discriminator step ---
            real_data = Variable(x).view(bs, -1)   # flatten real images
            logits_real = D_net(real_data)         # discriminator score on real data

            sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5  # uniform in [-1, 1]
            g_fake_seed = Variable(sample_noise)
            fake_images = G_net(g_fake_seed)       # generated fake data
            logits_fake = D_net(fake_images)       # discriminator score on fakes

            d_total_error = discriminator_loss(logits_real, logits_fake)
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()                     # update the discriminator

            # --- Generator step (fresh forward pass so gradients flow to G) ---
            g_fake_seed = Variable(sample_noise)
            fake_images = G_net(g_fake_seed)
            gen_logits_fake = D_net(fake_images)

            g_error = generator_loss(gen_logits_fake)
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()                     # update the generator

            if (iter_count % show_every == 0):
                print('Iter:{},D:{:.4},G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
                imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
                show_images(imgs_numpy[0:16])
                plt.show()
                print()
            iter_count += 1
# Build the two networks, their optimizers, and run training.
D = discriminator()
G = generator()

D_optim = get_optimizer(D)
G_optim = get_optimizer(G)

train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)
以上这篇pytorch:实现简单的GAN示例(MNIST数据集)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。