keras K.function获取某层的输出操作
如下所示:
fromkerasimportbackendasK
fromkeras.modelsimportload_model
models=load_model('models.hdf5')
image=r'image.png'
images=cv2.imread(r'image.png')
image_arr=process_image(image,(224,224,3))
image_arr=np.expand_dims(image_arr,axis=0)
layer_1=K.function([base_model.get_input_at(0)],[base_model.get_layer('layer_name').output])
f1=layer_1([image_arr])[0]
加载训练好并保存的网络模型
加载数据(图像),并将数据处理成array形式
指定输出层
将处理后的数据输入,然后获取输出
其中,K.function有两种不同的写法:
1.获取名为layer_name的层的输出
layer_1=K.function([base_model.get_input_at(0)],[base_model.get_layer('layer_name').output]) #指定输出层的名称
2.获取第n层的输出
layer_1=K.function([model.get_input_at(0)],[model.layers[5].output]) #指定输出层的序号(层号从0开始)
另外,需要注意的是,书写不规范会导致报错:
报错:
TypeError:inputstoaTensorFlowbackendfunctionshouldbealistortuple
将该句:
f1=layer_1(image_arr)[0]
修改为:
f1=layer_1([image_arr])[0]
补充知识:keras.backend.function()
如下所示:
deffunction(inputs,outputs,updates=None,**kwargs):
"""InstantiatesaKerasfunction.
Arguments:
inputs:Listofplaceholdertensors.
outputs:Listofoutputtensors.
updates:Listofupdateops.
**kwargs:Passedto`tf.Session.run`.
Returns:
OutputvaluesasNumpyarrays.
Raises:
ValueError:ifinvalidkwargsarepassedin.
"""
ifkwargs:
forkeyinkwargs:
if(keynotintf_inspect.getargspec(session_module.Session.run)[0]and
keynotintf_inspect.getargspec(Function.__init__)[0]):
msg=('Invalidargument"%s"passedtoK.functionwithTensorflow'
'backend')%key
raiseValueError(msg)
returnFunction(inputs,outputs,updates=updates,**kwargs)
这是keras.backend.function()的源码。其中函数定义开头的注释就是官方文档对该函数的解释。
我们可以发现function()函数返回的是一个Function对象。下面是Function类的定义。
classFunction(object):
"""Runsacomputationgraph.
Arguments:
inputs:Feedplaceholderstothecomputationgraph.
outputs:Outputtensorstofetch.
updates:Additionalupdateopstoberunatfunctioncall.
name:anametohelpusersidentifywhatthisfunctiondoes.
"""
def__init__(self,inputs,outputs,updates=None,name=None,
**session_kwargs):
updates=updatesor[]
ifnotisinstance(inputs,(list,tuple)):
raiseTypeError('`inputs`toaTensorFlowbackendfunction'
'shouldbealistortuple.')
ifnotisinstance(outputs,(list,tuple)):
raiseTypeError('`outputs`ofaTensorFlowbackendfunction'
'shouldbealistortuple.')
ifnotisinstance(updates,(list,tuple)):
raiseTypeError('`updates`inaTensorFlowbackendfunction'
'shouldbealistortuple.')
self.inputs=list(inputs)
self.outputs=list(outputs)
withops.control_dependencies(self.outputs):
updates_ops=[]
forupdateinupdates:
ifisinstance(update,tuple):
p,new_p=update
updates_ops.append(state_ops.assign(p,new_p))
else:
#assumedalreadyanop
updates_ops.append(update)
self.updates_op=control_flow_ops.group(*updates_ops)
self.name=name
self.session_kwargs=session_kwargs
def__call__(self,inputs):
ifnotisinstance(inputs,(list,tuple)):
raiseTypeError('`inputs`shouldbealistortuple.')
feed_dict={}
fortensor,valueinzip(self.inputs,inputs):
ifis_sparse(tensor):
sparse_coo=value.tocoo()
indices=np.concatenate((np.expand_dims(sparse_coo.row,1),
np.expand_dims(sparse_coo.col,1)),1)
value=(indices,sparse_coo.data,sparse_coo.shape)
feed_dict[tensor]=value
session=get_session()
updated=session.run(
self.outputs+[self.updates_op],
feed_dict=feed_dict,
**self.session_kwargs)
returnupdated[:len(self.outputs)]
所以,function函数利用我们之前已经创建好的comuptationgraph。遵循计算图,从输入到定义的输出。这也是为什么该函数经常用于提取中间层结果。
以上这篇kerasK.function获取某层的输出操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。