如何在TensorFlow中进行批处理内部产品?
-
16-10-2019 - |
题
我有两个张量 a:[batch_size, dim]
b:[batch_size, dim]
。我想为批处理中的每对制作内部产品,生成 c:[batch_size, 1]
, , 在哪里 c[i,0]=a[i,:].T*b[i,:]
. 。如何?
解决方案
没有本地人 .dot_product
方法。但是,两个向量之间的点乘积只是元素乘以求和,因此以下示例有效:
import tensorflow as tf
# Arbitrarity, we'll use placeholders and allow batch size to vary,
# but fix vector dimensions.
# You can change this as you see fit
a = tf.placeholder(tf.float32, shape=(None, 3))
b = tf.placeholder(tf.float32, shape=(None, 3))
c = tf.reduce_sum( tf.multiply( a, b ), 1, keep_dims=True )
with tf.Session() as session:
print( c.eval(
feed_dict={ a: [[1,2,3],[4,5,6]], b: [[2,3,4],[5,6,7]] }
) )
输出是:
[[ 20.]
[ 92.]]
其他提示
值得一看的另一个选项是 [tf.einsum][1]
- 这本质上是一个简化的版本 爱因斯坦符号.
跟随尼尔和杜姆卡的例子:
import tensorflow as tf
a = tf.placeholder(tf.float32, shape=(None, 3))
b = tf.placeholder(tf.float32, shape=(None, 3))
c = tf.einsum('ij,ij->i', a, b)
with tf.Session() as session:
print( c.eval(
feed_dict={ a: [[1,2,3],[4,5,6]], b: [[2,3,4],[5,6,7]] }
) )
第一个论点 einsum
是代表要乘以和求和的轴的方程式。方程式的基本规则是:
- 输入量通过逗号分隔的维度标签来描述
- 重复标签表明相应的维度将乘以
- 输出调整器由代表相应输入(或产品)的另一个维度标签的字符串描述
- 输出字符串中缺少的标签总结
在我们的情况下 ij,ij->i
意味着我们的输入将为2个相等形状的矩阵 (i,j)
, ,我们的输出将是形状的向量 (i,)
.
一旦掌握了它,您会发现 einsum
概括了许多其他操作:
X = [[1, 2]]
Y = [[3, 4], [5, 6]]
einsum('ab->ba', X) == [[1],[2]] # transpose
einsum('ab->a', X) == [3] # sum over last dimension
einsum('ab->', X) == 3 # sum over both dimensions
einsum('ab,bc->ac', X, Y) == [[13,16]] # matrix multiply
einsum('ab,bc->abc', X, Y) == [[[3,4],[10,12]]] # multiply and broadcast
很遗憾, einsum
与手动倍增+降低相比,相比之下,要进行巨大的性能。在表现至关重要的地方,我绝对建议您坚持使用尼尔的解决方案。
对角线 tf.tensordot 如果将轴设置为
[[1], [1]]
我已经改编了尼尔·斯莱特(Neil Slater)的例子:
import tensorflow as tf
# Arbitrarity, we'll use placeholders and allow batch size to vary,
# but fix vector dimensions.
# You can change this as you see fit
a = tf.placeholder(tf.float32, shape=(None, 3))
b = tf.placeholder(tf.float32, shape=(None, 3))
c = tf.diag_part(tf.tensordot( a, b, axes=[[1],[1]]))
with tf.Session() as session:
print( c.eval(
feed_dict={ a: [[1,2,3],[4,5,6]], b: [[2,3,4],[5,6,7]] }
) )
现在也给出:
[ 20. 92.]
不过,这可能是大型矩阵的最佳选择(请参阅讨论 这里)