Pytorch实现WGAN用于动漫头像生成
WGAN与GAN的不同
- 去除sigmoid
- 使用具有动量的优化方法,比如使用RMSProp
- 要对Discriminator的权重做修整限制以确保lipschitz连续约
WGAN实战卷积生成动漫头像
importtorch
importtorch.nnasnn
importtorchvision.transformsastransforms
fromtorch.utils.dataimportDataLoader
fromtorchvision.utilsimportsave_image
importos
fromanime_face_generator.datasetimportImageDataset
batch_size=32
num_epoch=100
z_dimension=100
dir_path='./wgan_img'
#创建文件夹
ifnotos.path.exists(dir_path):
os.mkdir(dir_path)
defto_img(x):
"""因为我们在生成器里面用了tanh"""
out=0.5*(x+1)
returnout
dataset=ImageDataset()
dataloader=DataLoader(dataset,batch_size=32,shuffle=False)
classGenerator(nn.Module):
def__init__(self):
super().__init__()
self.gen=nn.Sequential(
#输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的featuremap
nn.ConvTranspose2d(100,512,4,1,0,bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
#上一步的输出形状:(512)x4x4
nn.ConvTranspose2d(512,256,4,2,1,bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
#上一步的输出形状:(256)x8x8
nn.ConvTranspose2d(256,128,4,2,1,bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
#上一步的输出形状:(256)x16x16
nn.ConvTranspose2d(128,64,4,2,1,bias=False),
nn.BatchNorm2d(64),
nn.ReLU(True),
#上一步的输出形状:(256)x32x32
nn.ConvTranspose2d(64,3,5,3,1,bias=False),
nn.Tanh()#输出范围-1~1故而采用Tanh
#nn.Sigmoid()
#输出形状:3x96x96
)
defforward(self,x):
x=self.gen(x)
returnx
defweight_init(m):
#weight_initialization:importantforwgan
class_name=m.__class__.__name__
ifclass_name.find('Conv')!=-1:
m.weight.data.normal_(0,0.02)
elifclass_name.find('Norm')!=-1:
m.weight.data.normal_(1.0,0.02)
classDiscriminator(nn.Module):
def__init__(self):
super().__init__()
self.dis=nn.Sequential(
nn.Conv2d(3,64,5,3,1,bias=False),
nn.LeakyReLU(0.2,inplace=True),
#输出(64)x32x32
nn.Conv2d(64,128,4,2,1,bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2,inplace=True),
#输出(128)x16x16
nn.Conv2d(128,256,4,2,1,bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2,inplace=True),
#输出(256)x8x8
nn.Conv2d(256,512,4,2,1,bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2,inplace=True),
#输出(512)x4x4
nn.Conv2d(512,1,4,1,0,bias=False),
nn.Flatten(),
#nn.Sigmoid()#输出一个数(概率)
)
defforward(self,x):
x=self.dis(x)
returnx
defweight_init(m):
#weight_initialization:importantforwgan
class_name=m.__class__.__name__
ifclass_name.find('Conv')!=-1:
m.weight.data.normal_(0,0.02)
elifclass_name.find('Norm')!=-1:
m.weight.data.normal_(1.0,0.02)
defsave(model,filename="model.pt",out_dir="out/"):
ifmodelisnotNone:
ifnotos.path.exists(out_dir):
os.mkdir(out_dir)
torch.save({'model':model.state_dict()},out_dir+filename)
else:
print("[ERROR]:Pleasebuildamodel!!!")
importQuickModelBuilderasbuilder
if__name__=='__main__':
one=torch.FloatTensor([1]).cuda()
mone=-1*one
is_print=True
#创建对象
D=Discriminator()
G=Generator()
D.weight_init()
G.weight_init()
iftorch.cuda.is_available():
D=D.cuda()
G=G.cuda()
lr=2e-4
d_optimizer=torch.optim.RMSprop(D.parameters(),lr=lr,)
g_optimizer=torch.optim.RMSprop(G.parameters(),lr=lr,)
d_scheduler=torch.optim.lr_scheduler.ExponentialLR(d_optimizer,gamma=0.99)
g_scheduler=torch.optim.lr_scheduler.ExponentialLR(g_optimizer,gamma=0.99)
fake_img=None
###########################进入训练##判别器的判断过程#####################
forepochinrange(num_epoch):#进行多个epoch的训练
pbar=builder.MyTqdm(epoch=epoch,maxval=len(dataloader))
fori,imginenumerate(dataloader):
num_img=img.size(0)
real_img=img.cuda()#将tensor变成Variable放入计算图中
#这里的优化器是D的优化器
forparaminD.parameters():
param.requires_grad=True
#########判别器训练train#####################
#分为两部分:1、真的图像判别为真;2、假的图像判别为假
#计算真实图片的损失
d_optimizer.zero_grad()#在反向传播之前,先将梯度归0
real_out=D(real_img)#将真实图片放入判别器中
d_loss_real=real_out.mean(0).view(1)
d_loss_real.backward(one)
#计算生成图片的损失
z=torch.randn(num_img,z_dimension).cuda()#随机生成一些噪声
z=z.reshape(num_img,z_dimension,1,1)
fake_img=G(z).detach()#随机噪声放入生成网络中,生成一张假的图片。#避免梯度传到G,因为G不用更新,detach分离
fake_out=D(fake_img)#判别器判断假的图片,
d_loss_fake=fake_out.mean(0).view(1)
d_loss_fake.backward(mone)
d_loss=d_loss_fake-d_loss_real
d_optimizer.step()#更新参数
#每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c=0.01
forparminD.parameters():
parm.data.clamp_(-0.01,0.01)
#==================训练生成器============================
################################生成网络的训练###############################
forparaminD.parameters():
param.requires_grad=False
#这里的优化器是G的优化器,所以不需要冻结D的梯度,因为不是D的优化器,不会更新D
g_optimizer.zero_grad()#梯度归0
z=torch.randn(num_img,z_dimension).cuda()
z=z.reshape(num_img,z_dimension,1,1)
fake_img=G(z)#随机噪声输入到生成器中,得到一副假的图片
output=D(fake_img)#经过判别器得到的结果
#g_loss=criterion(output,real_label)#得到的假的图片与真实的图片的label的loss
g_loss=torch.mean(output).view(1)
#bpandoptimize
g_loss.backward(one)#进行反向传播
g_optimizer.step()#.step()一般用在反向传播后面,用于更新生成网络的参数
#打印中间的损失
pbar.set_right_info(d_loss=d_loss.data.item(),
g_loss=g_loss.data.item(),
real_scores=real_out.data.mean().item(),
fake_scores=fake_out.data.mean().item(),
)
pbar.update()
try:
fake_images=to_img(fake_img.cpu())
save_image(fake_images,dir_path+'/fake_images-{}.png'.format(epoch+1))
except:
pass
ifis_print:
is_print=False
real_images=to_img(real_img.cpu())
save_image(real_images,dir_path+'/real_images.png')
pbar.finish()
d_scheduler.step()
g_scheduler.step()
save(D,"wgan_D.pt")
save(G,"wgan_G.pt")
到此这篇关于Pytorch实现WGAN用于动漫头像生成的文章就介绍到这了,更多相关Pytorch实现WGAN用于动漫头像生成内容请搜索毛票票以前的文章或继续浏览下面的相关文章希望大家以后多多支持毛票票!
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。