Python产生batch数据的操作
产生batch数据
输入data中每个样本可以有多个特征,和一个标签,最好都是numpy.array格式。
datas=[data1,data2,…,dataN],labels=[label1,label2,…,labelN],
其中data[i]=[feature1,feature2,…featureM],表示每个样本数据有M个特征。
输入我们方法的数据,all_data=[datas,labels]。
代码实现
通过索引值来产生batch大小的数据,同时提供是否打乱顺序的选择,根据随机产生数据量范围类的索引值来打乱顺序。
importnumpyasnp defbatch_generator(all_data,batch_size,shuffle=True): """ :paramall_data:all_data整个数据集,包含输入和输出标签 :parambatch_size:batch_size表示每个batch的大小 :paramshuffle:是否打乱顺序 :return: """ #输入all_datas的每一项必须是numpy数组,保证后面能按p所示取值 all_data=[np.array(d)fordinall_data] #获取样本大小 data_size=all_data[0].shape[0] print("data_size:",data_size) ifshuffle: #随机生成打乱的索引 p=np.random.permutation(data_size) #重新组织数据 all_data=[d[p]fordinall_data] batch_count=0 whileTrue: #数据一轮循环(epoch)完成,打乱一次顺序 ifbatch_count*batch_size+batch_size>data_size: batch_count=0 ifshuffle: p=np.random.permutation(data_size) all_data=[d[p]fordinall_data] start=batch_count*batch_size end=start+batch_size batch_count+=1 yield[d[start:end]fordinall_data]
测试数据
样本数据x和标签y可以分开输入,也可以同时输入。
#输入x表示有23个样本,每个样本有两个特征 #输出y表示有23个标签,每个标签取值为0或1 x=np.random.random(size=[23,2]) y=np.random.randint(2,size=[23,1]) count=x.shape[0] batch_size=5 epochs=20 batch_num=count//batch_size batch_gen=batch_generator([x,y],batch_size) foriinrange(epochs): print("#####epoch%s#####"%i) forjinrange(batch_num): batch_x,batch_y=next(batch_gen) print("-----epoch=%s,batch=%s-----"%(i,j)) print(batch_x,batch_y)
补充:使用tensorflow.data.Dataset构造batch数据集
importtensorflowastf importnumpyasnp def_parse_function(x): num_list=np.arange(10) returnnum_list def_from_tensor_slice(x): returntf.data.Dataset.from_tensor_slices(x) softmax_data=tf.data.Dataset.range(1000)#构造一个队列 softmax_data=softmax_data.map(lambdax:tf.py_func(_parse_function,[x],[tf.int32]))#将数据进行传入 softmax_data=softmax_data.flat_map(_from_tensor_slice)#将数据进行平铺,将其变为一维的数据,from_tensor_slice将数据可以输出 softmax_data=softmax_data.batch(1)#构造一个batch的数量 softmax_iter=softmax_data.make_initializable_iterator()#构造数据迭代器 softmax_element=softmax_iter.get_next()#获得一个batch的数据 sess=tf.Session() sess.run(softmax_iter.initializer)#数据迭代器的初始化操作 print(sess.run(softmax_element))#实际获得一个数据 print(sess.run(softmax_data))
以上为个人经验,希望能给大家一个参考,也希望大家多多支持毛票票。如有错误或未考虑完全的地方,望不吝赐教。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。