在keras 中获取张量 tensor 的维度大小实例
在进行keras网络计算时,有时候需要获取输入张量的维度来定义自己的层。但是由于keras是一个封闭的接口。因此在调用由于是张量不能直接用numpy里的A.shape()。这样的形式来获取。这里需要调用一下keras作为后端的方式来获取。当我们想要操作时第一时间就想到直接用shape()函数。其实keras中真的有shape()这个函数。
shape(x)返回一个张量的符号shape,符号shape的意思是返回值本身也是一个tensor,
示例:
>>>fromkerasimportbackendasK >>>tf_session=K.get_session() >>>val=np.array([[1,2],[3,4]]) >>>kvar=K.variable(value=val) >>>input=keras.backend.placeholder(shape=(2,4,5)) >>>K.shape(kvar)>>>K.shape(input) __Togetintegershape(Instead,youcanuseK.int_shape(x))__ >>>K.shape(kvar).eval(session=tf_session) array([2,2],dtype=int32) >>>K.shape(input).eval(session=tf_session) array([2,4,5],dtype=int32)
如果直接调用这个出的不是我们想要的。我们想要的是tensor各个维度的大小。因此可以直接调用int_shape(x)函数。这个函数才是我们想要的。
>>>fromkerasimportbackendasK >>>input=K.placeholder(shape=(2,4,5)) >>>K.int_shape(input) (2,4,5) >>>val=np.array([[1,2],[3,4]]) >>>kvar=K.variable(value=val) >>>K.int_shape(kvar) (2,2)
最后这样我们就可以直接调用里面的大小。然后定义我们自己的keras层了。
补充知识:获取Tensor的维度(x.shape和x.get_shape()的区别)
tf.shape(a)和a.get_shape()比较
相同点:都可以得到tensora的尺寸
不同点:tf.shape()中a数据的类型可以是tensor,list,array
a.get_shape()中a的数据类型只能是tensor,且返回的是一个元组(tuple)
importtensorflowastf importnumpyasnp x=tf.constant([[1,2,3],[4,5,6]]) y=[[1,2,3],[4,5,6]] z=np.arange(24).reshape([2,3,4]) sess=tf.Session() #tf.shape() x_shape=tf.shape(x)#x_shape是一个tensor y_shape=tf.shape(y)#z_shape=tf.shape(z)# print(sess.run(x_shape))#结果:[23] print(sess.run(y_shape))#结果:[23] print(sess.run(z_shape))#结果:[234] x_shape=x.get_shape() print(x_shape)#返回的是TensorShape([Dimension(2),Dimension(3)]),不能使用sess.run()因为返回的不是tensor或string,而是元组(2,3) x_shape=x.get_shape().as_list() print(x_shape)#可以使用as_list()得到具体的尺寸,x_shape=[23]这是重点返回列表方便参加其他代码的运算 #y_shape=y.get_shape() print(x_shape)#AttributeError:'list'objecthasnoattribute'get_shape' #z_shape=z.get_shape() print(x_shape)#AttributeError:'numpy.ndarray'objecthasnoattribute'get_shape'或者a.shape.as_list()
以上这篇在keras中获取张量tensor的维度大小实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。