我有两个张量 a:[batch_size, dim] b:[batch_size, dim]。我想为批处理中的每对制作内部产品,生成 c:[batch_size, 1], , 在哪里 c[i,0]=a[i,:].T*b[i,:]. 。如何?

有帮助吗?

解决方案

没有本地人 .dot_product 方法。但是,两个向量之间的点乘积只是元素乘以求和,因此以下示例有效:

import tensorflow as tf

# Arbitrarity, we'll use placeholders and allow batch size to vary,
# but fix vector dimensions.
# You can change this as you see fit
a = tf.placeholder(tf.float32, shape=(None, 3))
b = tf.placeholder(tf.float32, shape=(None, 3))

c = tf.reduce_sum( tf.multiply( a, b ), 1, keep_dims=True )

with tf.Session() as session:
    print( c.eval(
        feed_dict={ a: [[1,2,3],[4,5,6]], b: [[2,3,4],[5,6,7]] }
    ) )

输出是:

[[ 20.]
 [ 92.]]

其他提示

值得一看的另一个选项是 [tf.einsum][1] - 这本质上是一个简化的版本 爱因斯坦符号.

跟随尼尔和杜姆卡的例子:

import tensorflow as tf

a = tf.placeholder(tf.float32, shape=(None, 3))
b = tf.placeholder(tf.float32, shape=(None, 3))

c = tf.einsum('ij,ij->i', a, b)

with tf.Session() as session:
    print( c.eval(
        feed_dict={ a: [[1,2,3],[4,5,6]], b: [[2,3,4],[5,6,7]] }
    ) )

第一个论点 einsum 是代表要乘以和求和的轴的方程式。方程式的基本规则是:

  1. 输入量通过逗号分隔的维度标签来描述
  2. 重复标签表明相应的维度将乘以
  3. 输出调整器由代表相应输入(或产品)的另一个维度标签的字符串描述
  4. 输出字符串中缺少的标签总结

在我们的情况下 ij,ij->i 意味着我们的输入将为2个相等形状的矩阵 (i,j), ,我们的输出将是形状的向量 (i,).

一旦掌握了它,您会发现 einsum 概括了许多其他操作:

X = [[1, 2]]
Y = [[3, 4], [5, 6]]

einsum('ab->ba', X) == [[1],[2]]   # transpose
einsum('ab->a',  X) ==  [3]        # sum over last dimension
einsum('ab->',   X) ==   3         # sum over both dimensions

einsum('ab,bc->ac',  X, Y) == [[13,16]]          # matrix multiply
einsum('ab,bc->abc', X, Y) == [[[3,4],[10,12]]]  # multiply and broadcast

很遗憾, einsum 与手动倍增+降低相比,相比之下,要进行巨大的性能。在表现至关重要的地方,我绝对建议您坚持使用尼尔的解决方案。

对角线 tf.tensordot 如果将轴设置为

[[1], [1]]

我已经改编了尼尔·斯莱特(Neil Slater)的例子:

import tensorflow as tf

# Arbitrarity, we'll use placeholders and allow batch size to vary,
# but fix vector dimensions.
# You can change this as you see fit
a = tf.placeholder(tf.float32, shape=(None, 3))
b = tf.placeholder(tf.float32, shape=(None, 3))

c = tf.diag_part(tf.tensordot( a, b, axes=[[1],[1]]))

with tf.Session() as session:
    print( c.eval(
        feed_dict={ a: [[1,2,3],[4,5,6]], b: [[2,3,4],[5,6,7]] }
    ) )

现在也给出:

[ 20.  92.]

不过,这可能是大型矩阵的最佳选择(请参阅讨论 这里)

许可以下: CC-BY-SA归因
scroll top