Keras自动下载的数据集/模型存放位置介绍
Mac
#数据集
~/.keras/datasets/#模型
~/.keras/models/
Linux
#数据集
~/.keras/datasets/
Windows
#win10
C:\Users\user_name\.keras\datasets
补充知识:Keras_gan生成自己的数据,并保存模型
我就废话不多说了,大家还是直接看代码吧~
from__future__importprint_function,division
fromkeras.datasetsimportmnist
fromkeras.layersimportInput,Dense,Reshape,Flatten,Dropout
fromkeras.layersimportBatchNormalization,Activation,ZeroPadding2D
fromkeras.layers.advanced_activationsimportLeakyReLU
fromkeras.layers.convolutionalimportUpSampling2D,Conv2D
fromkeras.modelsimportSequential,Model
fromkeras.optimizersimportAdam
importos
importmatplotlib.pyplotasplt
importsys
importnumpyasnp
classGAN():
def__init__(self):
self.img_rows=3
self.img_cols=60
self.channels=1
self.img_shape=(self.img_rows,self.img_cols,self.channels)
self.latent_dim=100
optimizer=Adam(0.0002,0.5)
#构建和编译判别器
self.discriminator=self.build_discriminator()
self.discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
#构建生成器
self.generator=self.build_generator()
#生成器输入噪音,生成假的图片
z=Input(shape=(self.latent_dim,))
img=self.generator(z)
#为了组合模型,只训练生成器
self.discriminator.trainable=False
#判别器将生成的图像作为输入并确定有效性
validity=self.discriminator(img)
#Thecombinedmodel(stackedgeneratoranddiscriminator)
#训练生成器骗过判别器
self.combined=Model(z,validity)
self.combined.compile(loss='binary_crossentropy',optimizer=optimizer)
defbuild_generator(self):
model=Sequential()
model.add(Dense(64,input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(128))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
#np.prod(self.img_shape)=3x60x1
model.add(Dense(np.prod(self.img_shape),activation='tanh'))
model.add(Reshape(self.img_shape))
model.summary()
noise=Input(shape=(self.latent_dim,))
img=model(noise)
#输入噪音,输出图片
returnModel(noise,img)
defbuild_discriminator(self):
model=Sequential()
model.add(Flatten(input_shape=self.img_shape))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(128))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(64))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1,activation='sigmoid'))
model.summary()
img=Input(shape=self.img_shape)
validity=model(img)
returnModel(img,validity)
deftrain(self,epochs,batch_size=128,sample_interval=50):
############################################################
#自己数据集此部分需要更改
#加载数据集
data=np.load('data/相对大小分叉.npy')
data=data[:,:,0:60]
#归一化到-1到1
data=data*2-1
data=np.expand_dims(data,axis=3)
############################################################
#Adversarialgroundtruths
valid=np.ones((batch_size,1))
fake=np.zeros((batch_size,1))
forepochinrange(epochs):
#---------------------
#训练判别器
#---------------------
#data.shape[0]为数据集的数量,随机生成batch_size个数量的随机数,作为数据的索引
idx=np.random.randint(0,data.shape[0],batch_size)
#从数据集随机挑选batch_size个数据,作为一个批次训练
imgs=data[idx]
#噪音维度(batch_size,100)
noise=np.random.normal(0,1,(batch_size,self.latent_dim))
#由生成器根据噪音生成假的图片
gen_imgs=self.generator.predict(noise)
#训练判别器,判别器希望真实图片,打上标签1,假的图片打上标签0
d_loss_real=self.discriminator.train_on_batch(imgs,valid)
d_loss_fake=self.discriminator.train_on_batch(gen_imgs,fake)
d_loss=0.5*np.add(d_loss_real,d_loss_fake)
#---------------------
#训练生成器
#---------------------
noise=np.random.normal(0,1,(batch_size,self.latent_dim))
#Trainthegenerator(tohavethediscriminatorlabelsamplesasvalid)
g_loss=self.combined.train_on_batch(noise,valid)
#打印loss值
print("%d[Dloss:%f,acc.:%.2f%%][Gloss:%f]"%(epoch,d_loss[0],100*d_loss[1],g_loss))
#没sample_interval个epoch保存一次生成图片
ifepoch%sample_interval==0:
self.sample_images(epoch)
ifnotos.path.exists("keras_model"):
os.makedirs("keras_model")
self.generator.save_weights("keras_model/G_model%d.hdf5"%epoch,True)
self.discriminator.save_weights("keras_model/D_model%d.hdf5"%epoch,True)
defsample_images(self,epoch):
r,c=10,10
#重新生成一批噪音,维度为(100,100)
noise=np.random.normal(0,1,(r*c,self.latent_dim))
gen_imgs=self.generator.predict(noise)
#将生成的图片重新归整到0-1之间
gen=0.5*gen_imgs+0.5
gen=gen.reshape(-1,3,60)
fig,axs=plt.subplots(r,c)
cnt=0
foriinrange(r):
forjinrange(c):
xy=gen[cnt]
forkinrange(len(xy)):
x=xy[k][0:30]
y=xy[k][30:60]
ifk==0:
axs[i,j].plot(x,y,color='blue')
ifk==1:
axs[i,j].plot(x,y,color='red')
ifk==2:
axs[i,j].plot(x,y,color='green')
plt.xlim(0.,1.)
plt.ylim(0.,1.)
plt.xticks(np.arange(0,1,0.1))
plt.xticks(np.arange(0,1,0.1))
axs[i,j].axis('off')
cnt+=1
ifnotos.path.exists("keras_imgs"):
os.makedirs("keras_imgs")
fig.savefig("keras_imgs/%d.png"%epoch)
plt.close()
deftest(self,gen_nums=100,save=False):
self.generator.load_weights("keras_model/G_model4000.hdf5",by_name=True)
self.discriminator.load_weights("keras_model/D_model4000.hdf5",by_name=True)
noise=np.random.normal(0,1,(gen_nums,self.latent_dim))
gen=self.generator.predict(noise)
gen=0.5*gen+0.5
gen=gen.reshape(-1,3,60)
print(gen.shape)
###############################################################
#直接可视化生成图片
ifsave:
foriinrange(0,len(gen)):
plt.figure(figsize=(128,128),dpi=1)
plt.plot(gen[i][0][0:30],gen[i][0][30:60],color='blue',linewidth=300)
plt.plot(gen[i][1][0:30],gen[i][1][30:60],color='red',linewidth=300)
plt.plot(gen[i][2][0:30],gen[i][2][30:60],color='green',linewidth=300)
plt.axis('off')
plt.xlim(0.,1.)
plt.ylim(0.,1.)
plt.xticks(np.arange(0,1,0.1))
plt.yticks(np.arange(0,1,0.1))
ifnotos.path.exists("keras_gen"):
os.makedirs("keras_gen")
plt.savefig("keras_gen"+os.sep+str(i)+'.jpg',dpi=1)
plt.close()
##################################################################
#重整图片到0-1
else:
foriinrange(len(gen)):
plt.plot(gen[i][0][0:30],gen[i][0][30:60],color='blue')
plt.plot(gen[i][1][0:30],gen[i][1][30:60],color='red')
plt.plot(gen[i][2][0:30],gen[i][2][30:60],color='green')
plt.xlim(0.,1.)
plt.ylim(0.,1.)
plt.xticks(np.arange(0,1,0.1))
plt.xticks(np.arange(0,1,0.1))
plt.show()
if__name__=='__main__':
gan=GAN()
gan.train(epochs=300000,batch_size=32,sample_interval=2000)
#gan.test(save=True)
以上这篇Keras自动下载的数据集/模型存放位置介绍就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。