Getting started with TensorFlow: using TFRecord and tf.data.TFRecordDataset
1. Creating a TFRecord file
TFRecord supports writing three kinds of data: string, int64 and float32. Each is written into a tf.train.Feature as a list, via tf.train.BytesList, tf.train.Int64List and tf.train.FloatList respectively, as shown below:
tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))  # feature is usually a multi-dimensional array, so serialize it to bytes first
tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape)))   # tostring() drops the shape information, so write the shape as well
tf.train.Feature(float_list=tf.train.FloatList(value=[label]))
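As an aside, ndarray.tostring() simply returns the raw bytes of the array; on newer NumPy versions it is deprecated in favor of the equivalent ndarray.tobytes(). A minimal sketch with a dummy array:
import numpy as np
import tensorflow as tf
feature = np.arange(6, dtype=np.float32).reshape(2, 3)  # dummy array standing in for a real sample
raw = feature.tobytes()  # same bytes as feature.tostring()
feat = tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))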
Using the snippets above, gather everything to be written into a dict, build a tf.train.Features from it, and then build a tf.train.Example, as follows:
def get_tfrecords_example(feature, label):
    tfrecords_features = {}
    feat_shape = feature.shape
    tfrecords_features['feature'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))
    tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
    tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
    return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
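For instance (a minimal sketch with dummy data; get_tfrecords_example expects a numpy array and a list-like label):
import numpy as np
img = np.zeros((784,), dtype=np.float32)  # dummy flattened image
lbl = [0.0] * 9 + [1.0]                   # dummy one-hot label
exmp = get_tfrecords_example(img, lbl)
print(len(exmp.SerializeToString()))      # size of the serialized record in bytes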
Serialize the resulting tf.train.Example and it can be written to a TFRecord file through tf.python_io.TFRecordWriter, as follows:
tfrecord_wrt = tf.python_io.TFRecordWriter('xxx.tfrecord')  # create the TFRecord writer; the file is named xxx.tfrecord
exmp = get_tfrecords_example(feats[inx], labels[inx])       # pack one sample into an Example
exmp_serial = exmp.SerializeToString()                      # serialize the Example
tfrecord_wrt.write(exmp_serial)                             # write it to the TFRecord file
tfrecord_wrt.close()                                        # close the writer when finished
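To check that the file was written as expected, the serialized records can be counted back (a minimal sketch; tf.python_io.tf_record_iterator is the TF 1.x API):
n_records = sum(1 for _ in tf.python_io.tf_record_iterator('xxx.tfrecord'))
print('records written: %d' % n_records)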
Putting it all together:
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

mnist = read_data_sets("MNIST_data/", one_hot=True)

# pack one sample into an Example
def get_tfrecords_example(feature, label):
    tfrecords_features = {}
    feat_shape = feature.shape
    tfrecords_features['feature'] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))
    tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
    tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
    return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

# write all samples into a TFRecord file
def make_tfrecord(data, outf_nm='mnist-train'):
    feats, labels = data
    outf_nm += '.tfrecord'
    tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
    ndatas = len(labels)
    for inx in range(ndatas):
        exmp = get_tfrecords_example(feats[inx], labels[inx])
        exmp_serial = exmp.SerializeToString()
        tfrecord_wrt.write(exmp_serial)
    tfrecord_wrt.close()

import random
nDatas = len(mnist.train.labels)
inx_lst = list(range(nDatas))  # wrap in list() so it can be shuffled under Python 3
random.shuffle(inx_lst)
ntrains = int(0.85 * nDatas)

# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]],
        [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')

# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]],
        [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')

# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')
2. Using the TFRecord file: tf.data.TFRecordDataset
Create a TFRecordDataset from the TFRecord file:
dataset = tf.data.TFRecordDataset('xxx.tfrecord')
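TFRecordDataset also accepts a list of file names, and recent TF 1.x versions accept a compression_type argument for compressed files; a minimal sketch:
files = ['mnist-train.tfrecord', 'mnist-val.tfrecord']
dataset = tf.data.TFRecordDataset(files)  # records from the files are read in sequence
# dataset = tf.data.TFRecordDataset(files, compression_type='GZIP')  # if the files were written gzip-compressed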
Each record in the TFRecord file is a serialized tf.train.Example; parse it with tf.parse_single_example:
feats = tf.parse_single_example(serial_exmp, features=data_dict)
Here data_dict is a dict whose keys are the keys used when the TFRecord file was written; the corresponding values are tf.FixedLenFeature([], tf.string), tf.FixedLenFeature([], tf.int64) and tf.FixedLenFeature([], tf.float32), one for each data type. Putting this together:
def parse_exmp(serial_exmp):
    # 'label' uses [10] because each label is a list with 10 elements (one-hot);
    # [x] for 'shape' is the length of the shape list, i.e. the number of dimensions of the stored feature
    feats = tf.parse_single_example(serial_exmp, features={'feature': tf.FixedLenFeature([], tf.string),
                                                           'label': tf.FixedLenFeature([10], tf.float32),
                                                           'shape': tf.FixedLenFeature([x], tf.int64)})
    image = tf.decode_raw(feats['feature'], tf.float32)
    label = feats['label']
    shape = tf.cast(feats['shape'], tf.int32)
    return image, label, shape
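Because tf.decode_raw returns a flat 1-D tensor, the stored shape can be used to restore the original layout, assuming 'shape' was parsed as a vector of dimension sizes as above (a minimal sketch):
image = tf.decode_raw(feats['feature'], tf.float32)
shape = tf.cast(feats['shape'], tf.int32)
image = tf.reshape(image, shape)  # recover the original multi-dimensional layout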
To parse every record in the TFRecord file, use the dataset's map method:
dataset = dataset.map(parse_exmp)
map accepts an arbitrary function for transforming the elements of the dataset. In addition, the repeat, shuffle and batch methods repeat, shuffle and batch the dataset; repeat duplicates the dataset so it can be iterated for multiple epochs, as follows:
dataset = dataset.repeat(epochs).shuffle(buffer_size).batch(batch_size)
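On TF versions that support them (roughly 1.4 and later), num_parallel_calls and prefetch can speed up the input pipeline; a minimal sketch:
dataset = dataset.map(parse_exmp, num_parallel_calls=4)  # parse several records in parallel
dataset = dataset.repeat(epochs).shuffle(buffer_size).batch(batch_size)
dataset = dataset.prefetch(1)  # prepare the next batch while the current one is being consumed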
With parsing in place, data is pulled out of the dataset by creating an iterator, as follows:
iterator = dataset.make_one_shot_iterator()
batch_image, batch_label, batch_shape = iterator.get_next()
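get_next() only builds tensors; the actual values are obtained with sess.run, and a one-shot iterator raises tf.errors.OutOfRangeError once all the data (including repeats) has been consumed. A minimal usage sketch:
with tf.Session() as sess:
    try:
        while True:
            img, lbl, shp = sess.run([batch_image, batch_label, batch_shape])
            # ... use the batch here ...
    except tf.errors.OutOfRangeError:
        pass  # the one-shot iterator is exhausted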
To feed data from different datasets into the same model, first create an iterator handle, i.e. an iterator placeholder, as follows:
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle,
                                               dataset_train.output_types, dataset_train.output_shapes)
image, label, shape = iterator.get_next()
Then create a string handle for each dataset and pass the one you want into the placeholder via feed_dict, as follows:
with tf.Session() as sess:
    handle_train, handle_val, handle_test = sess.run(
        [x.string_handle() for x in [iter_train, iter_val, iter_test]])
    sess.run([loss, train_op], feed_dict={handle: handle_train})
Putting it together:
import tensorflow as tf

train_f, val_f, test_f = ['mnist-%s.tfrecord' % i for i in ['train', 'val', 'test']]

def parse_exmp(serial_exmp):
    feats = tf.parse_single_example(serial_exmp, features={'feature': tf.FixedLenFeature([], tf.string),
                                                           'label': tf.FixedLenFeature([10], tf.float32),
                                                           'shape': tf.FixedLenFeature([], tf.int64)})
    image = tf.decode_raw(feats['feature'], tf.float32)
    label = feats['label']
    shape = tf.cast(feats['shape'], tf.int32)
    return image, label, shape

def get_dataset(fname):
    dataset = tf.data.TFRecordDataset(fname)
    return dataset.map(parse_exmp)  # use the padded_batch method if padding is needed

epochs = 16
batch_size = 50  # when nDatas is not divisible by batch_size (e.g. 56),
                 # the last batch holds fewer than batch_size samples

# training dataset
nDatasTrain = 46750
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).batch(batch_size)  # make sure repeat comes before batch
# this differs from dataset.shuffle(1000).batch(batch_size).repeat(epochs),
# where every epoch ends with a smaller batch whenever batch_size does not divide nDatas
nBatchs = nDatasTrain * epochs // batch_size

# validation dataset
nDatasVal = 8250
dataset_val = get_dataset(val_f)
dataset_val = dataset_val.batch(nDatasVal).repeat(nBatchs // 100 * 2)

# test dataset
nDatasTest = 10000
dataset_test = get_dataset(test_f)
dataset_test = dataset_test.batch(nDatasTest)

# make dataset iterators
iter_train = dataset_train.make_one_shot_iterator()
iter_val = dataset_val.make_one_shot_iterator()
iter_test = dataset_test.make_one_shot_iterator()

# make a feedable iterator
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle,
                                               dataset_train.output_types, dataset_train.output_shapes)
x, y_, _ = iterator.get_next()

train_op, loss, eval_op = model(x, y_)  # model() and its keep_prob placeholder are defined in the MNIST experiment below
init = tf.global_variables_initializer()  # replaces the deprecated tf.initialize_all_variables()

# summary
logdir = './logs/m4d2a'
def summary_op(datapart='train'):
    tf.summary.scalar(datapart + '-loss', loss)
    tf.summary.scalar(datapart + '-eval', eval_op)
    return tf.summary.merge_all()
summary_op_train = summary_op()
summary_op_test = summary_op('val')

with tf.Session() as sess:
    sess.run(init)
    handle_train, handle_val, handle_test = sess.run(
        [x.string_handle() for x in [iter_train, iter_val, iter_test]])
    _, cur_loss, cur_train_eval, summary = sess.run([train_op, loss, eval_op, summary_op_train],
                                                    feed_dict={handle: handle_train, keep_prob: 0.5})
    cur_val_loss, cur_val_eval, summary = sess.run([loss, eval_op, summary_op_test],
                                                   feed_dict={handle: handle_val, keep_prob: 1.0})
3. MNIST experiment
import tensorflow as tf

train_f, val_f, test_f = ['mnist-%s.tfrecord' % i for i in ['train', 'val', 'test']]

def parse_exmp(serial_exmp):
    feats = tf.parse_single_example(serial_exmp, features={'feature': tf.FixedLenFeature([], tf.string),
                                                           'label': tf.FixedLenFeature([10], tf.float32),
                                                           'shape': tf.FixedLenFeature([], tf.int64)})
    image = tf.decode_raw(feats['feature'], tf.float32)
    label = feats['label']
    shape = tf.cast(feats['shape'], tf.int32)
    return image, label, shape

def get_dataset(fname):
    dataset = tf.data.TFRecordDataset(fname)
    return dataset.map(parse_exmp)  # use the padded_batch method if padding is needed

epochs = 16
batch_size = 50  # when nDatas is not divisible by batch_size (e.g. 56),
                 # the last batch holds fewer than batch_size samples

# training dataset
nDatasTrain = 46750
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).batch(batch_size)  # make sure repeat comes before batch
# this differs from dataset.shuffle(1000).batch(batch_size).repeat(epochs),
# where every epoch ends with a smaller batch whenever batch_size does not divide nDatas
nBatchs = nDatasTrain * epochs // batch_size

# validation dataset
nDatasVal = 8250
dataset_val = get_dataset(val_f)
dataset_val = dataset_val.batch(nDatasVal).repeat(nBatchs // 100 * 2)

# test dataset
nDatasTest = 10000
dataset_test = get_dataset(test_f)
dataset_test = dataset_test.batch(nDatasTest)

# make dataset iterators
iter_train = dataset_train.make_one_shot_iterator()
iter_val = dataset_val.make_one_shot_iterator()
iter_test = dataset_test.make_one_shot_iterator()

# make a feedable iterator, i.e. an iterator placeholder
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle,
                                               dataset_train.output_types, dataset_train.output_shapes)
x, y_, _ = iterator.get_next()

# cnn
x_image = tf.reshape(x, [-1, 28, 28, 1])
w_init = tf.truncated_normal_initializer(stddev=0.1, seed=9)
b_init = tf.constant_initializer(0.1)
cnn1 = tf.layers.conv2d(x_image, 32, (5, 5), padding='same', activation=tf.nn.relu,
                        kernel_initializer=w_init, bias_initializer=b_init)
mxpl1 = tf.layers.max_pooling2d(cnn1, 2, strides=2, padding='same')
cnn2 = tf.layers.conv2d(mxpl1, 64, (5, 5), padding='same', activation=tf.nn.relu,
                        kernel_initializer=w_init, bias_initializer=b_init)
mxpl2 = tf.layers.max_pooling2d(cnn2, 2, strides=2, padding='same')
mxpl2_flat = tf.reshape(mxpl2, [-1, 7 * 7 * 64])
fc1 = tf.layers.dense(mxpl2_flat, 1024, activation=tf.nn.relu,
                      kernel_initializer=w_init, bias_initializer=b_init)
keep_prob = tf.placeholder('float')
fc1_drop = tf.nn.dropout(fc1, keep_prob)
logits = tf.layers.dense(fc1_drop, 10, kernel_initializer=w_init, bias_initializer=b_init)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_))
optmz = tf.train.AdamOptimizer(1e-4)
train_op = optmz.minimize(loss)

def get_eval_op(logits, labels):
    corr_prd = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    return tf.reduce_mean(tf.cast(corr_prd, 'float'))
eval_op = get_eval_op(logits, y_)
init = tf.global_variables_initializer()  # replaces the deprecated tf.initialize_all_variables()

# summary
logdir = './logs/m4d2a'
def summary_op(datapart='train'):
    tf.summary.scalar(datapart + '-loss', loss)
    tf.summary.scalar(datapart + '-eval', eval_op)
    return tf.summary.merge_all()
summary_op_train = summary_op()
summary_op_val = summary_op('val')

# whether to restore or not
ckpts_dir = 'ckpts/'
ckpt_nm = 'cnn-ckpt'
saver = tf.train.Saver(max_to_keep=50)  # defaults to saving all variables; pass a dict {'x': x, ...} to save specific ones
restore_step = ''
start_step = 0
train_steps = nBatchs
best_loss = 1e6
best_step = 0

# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# config = tf.ConfigProto()
# config.gpu_options.per_process_gpu_memory_fraction = 0.9
# config.gpu_options.allow_growth = True  # allocate GPU memory only when needed
# with tf.Session(config=config) as sess:
with tf.Session() as sess:
    sess.run(init)
    handle_train, handle_val, handle_test = sess.run(
        [x.string_handle() for x in [iter_train, iter_val, iter_test]])
    if restore_step:
        ckpt = tf.train.get_checkpoint_state(ckpts_dir)
        if ckpt and ckpt.model_checkpoint_path:  # ckpt.model_checkpoint_path is the latest checkpoint
            if restore_step == 'latest':
                ckpt_f = tf.train.latest_checkpoint(ckpts_dir)
                start_step = int(ckpt_f.split('-')[-1]) + 1
            else:
                ckpt_f = ckpts_dir + ckpt_nm + '-' + restore_step
            print('loading wgt file: ' + ckpt_f)
            saver.restore(sess, ckpt_f)
    summary_wrt = tf.summary.FileWriter(logdir, sess.graph)
    if restore_step in ['', 'latest']:
        for i in range(start_step, train_steps):
            _, cur_loss, cur_train_eval, summary = sess.run([train_op, loss, eval_op, summary_op_train],
                                                            feed_dict={handle: handle_train, keep_prob: 0.5})
            # log to stdout and evaluate the validation set
            if i % 100 == 0 or i == train_steps - 1:
                saver.save(sess, ckpts_dir + ckpt_nm, global_step=i)  # save variables
                summary_wrt.add_summary(summary, global_step=i)
                cur_val_loss, cur_val_eval, summary = sess.run([loss, eval_op, summary_op_val],
                                                               feed_dict={handle: handle_val, keep_prob: 1.0})
                if cur_val_loss < best_loss:  # assumed completion: the original snippet is cut off at this point
                    best_loss = cur_val_loss
                    best_step = i
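    # a hedged sketch (not in the original, which is truncated above): after training,
    # evaluate the test split once through its own handle; dataset_test holds all
    # nDatasTest samples in a single batch, so one run covers the whole test set
    test_loss, test_eval = sess.run([loss, eval_op],
                                    feed_dict={handle: handle_test, keep_prob: 1.0})
    print('test loss %.4f, test accuracy %.4f' % (test_loss, test_eval))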
Experiment results:
This concludes the introduction to TFRecord and tf.data.TFRecordDataset; I hope it serves as a useful reference.