Как использовать списки в TensorFlow?
-
16-10-2019 - |
Вопрос
У меня есть ряд списков, таких как [1,2,3,4], [2,3,4], [1,2], [2,3,4,6,8,10], длина которых равен Очевидно, ненадежный.
Как я могу использовать это в качестве ввода заполнителя в Tensorflow?
Как я пытался, следующая настройка вызовет ошибку.
tf.constant ([[1,2], [1,2,3] ...], dtype = tf.int32)
Так что я думаю заполнитель не может быть установлен верхним входом списков.
Есть какое -либо решение?
Редактировать:
Ниже приведен мой пример. Как заставить его работать без ошибок?
Решение
Когда вы создаете массив Numpy, как это:
x_data = np.array( [[1,2],[4,5,6],[1,2,3,4,5,6]])
Внутренний Numpy dtype - это «объект»:
array([[1, 2], [4, 5, 6], [1, 2, 3, 4, 5, 6]], dtype=object)
И это не может быть использовано в качестве тензора в тензоре. В любом случае, тензоры должны иметь одинаковый размер в каждом измерении, они не могут быть «рваными» и должны иметь форму, определенную одним числом в каждом измерении. Tensorflow в основном предполагает это обо всех своих типах данных. Хотя дизайнеры Tensorflow могут написать его теоретически, заставляя его принимать рваные массивы и включать функцию преобразования, такого рода автоматическое кассообразование не всегда является хорошей идеей, потому что она может скрыть проблему в входном коде.
Таким образом, вам нужно подумать о входных данных, чтобы сделать их полезной формой. При быстром поиске я нашел Этот подход в переполнении стека, воспроизводится как изменение в вашем коде:
import tensorflow as tf
import numpy as np
x = tf.placeholder( tf.int32, [3,None] )
y = x * 2
with tf.Session() as session:
x_data = np.array( [[1,2],[4,5,6],[1,2,3,4,5,6]] )
# Get lengths of each row of data
lens = np.array([len(x_data[i]) for i in range(len(x_data))])
# Mask of valid places in each row
mask = np.arange(lens.max()) < lens[:,None]
# Setup output array and put elements from data into masked positions
padded = np.zeros(mask.shape)
padded[mask] = np.hstack((x_data[:]))
# Call TensorFlow
result = session.run(y, feed_dict={x:padded})
# Remove the padding - the list function ensures we
# create same datatype as input. It is not necessary in the case
# where you are happy with a list of Numpy arrays instead
result_without_padding = np.array(
[list(result[i,0:lens[i]]) for i in range(lens.size)]
)
print( result_without_padding )
Вывод:
[[2, 4] [8, 10, 12] [2, 4, 6, 8, 10, 12]]
Вам не нужно удалять прокладку в конце - сделайте это только в том случае, если вам нужно показать свой выход в одном и том же формате массива. Также обратите внимание, что при подавлении полученного padded
Данные к более сложным подпрограммам, Zeros - или другие данные заполнения, если вы его измените - могут использоваться любым алгоритмом, который вы реализовали.
Если у вас много коротких массивов и всего один или два очень длинных, то вы можете рассмотреть возможность использования Разреженное тензорное представление Чтобы сохранить память и ускорить расчеты.
Другие советы
В качестве альтернативы использованию мягких массивов, вы можете просто подавать все данные в качестве одной большой строки спагетти, а затем делать оригами внутри графика TensorFlow
Пример:
import tensorflow as tf
import numpy as np
sess = tf.InteractiveSession()
noodle = tf.placeholder(tf.float32, [None])
chop_indices = tf.placeholder(tf.int32, [None,2])
do_origami = lambda list_idx: tf.gather(noodle, tf.range(chop_indices[list_idx,0], chop_indices[list_idx,1]))
print( [do_origami(list_idx=i).eval({noodle:[1,2,3,2,3,6], chop_indices:[[0,2],[2,3],[3,6]]}).tolist() for i in range(3)] )
Результат:
[[1.0, 2.0], [3.0], [2.0, 3.0, 6.0]]
Если у вас есть переменное количество внутренних списков, тогда удачи. Вы не можете вернуть список из tf.while_loop, и вы не можете просто использовать понимание списка, как указано выше, поэтому вам придется делать вычисления отдельно для каждого внутреннего списка.
import tensorflow as tf
sess = tf.InteractiveSession()
my_list = tf.Variable(initial_value=[1,2,3,4,5])
init = tf.global_variables_initializer()
sess.run(init)
sess.run(my_list)
Результат: массив ([1, 2, 3, 4, 5])