Pregunta

Tengo varias listas, como [1,2,3,4], [2,3,4], [1,2], [2,3,4,6,8,10], cuyas longitudes son Obviamente inquebrantable.

¿Cómo puedo usar esto como entrada del marcador de posición en TensorFlow?

Como he intentado, la siguiente configuración aumentará el error.

tf.constant ([[1,2], [1,2,3] ...], dtype = tf.int32)

Entonces supongo marcador de posición no se puede establecer mediante la entrada superior de las listas.

¿Hay alguna solución?

Editar:

El siguiente es mi ejemplo. ¿Cómo hacer que se ejecute sin errores?

enter image description here

¿Fue útil?

Solución

Cuando creas una matriz numpy como esta:

x_data = np.array( [[1,2],[4,5,6],[1,2,3,4,5,6]])

El dtype numpy interno es "objeto":

array([[1, 2], [4, 5, 6], [1, 2, 3, 4, 5, 6]], dtype=object)

y esto no puede usarse como tensor en Tensorflow. En cualquier caso, los tensores deben tener el mismo tamaño en cada dimensión, no pueden ser "irregulares" y deben tener una forma definida por un solo número en cada dimensión. TensorFlow básicamente asume esto sobre todos sus tipos de datos. Aunque los diseñadores de TensorFlow podrían escribirlo en teoría, hacer que acepte matrices irregulares e incluya una función de conversión, ese tipo de fundición automática no siempre es una buena idea, porque podría ocultar un problema en el código de entrada.

Por lo tanto, debe rellenar los datos de entrada para que sea una forma utilizable. En una búsqueda rápida, encontré Este enfoque en el desbordamiento de la pila, replicado como un cambio en su código:

import tensorflow as tf
import numpy as np

x = tf.placeholder( tf.int32, [3,None] )
y = x * 2

with tf.Session() as session:
    x_data = np.array( [[1,2],[4,5,6],[1,2,3,4,5,6]] )

    # Get lengths of each row of data
    lens = np.array([len(x_data[i]) for i in range(len(x_data))])

    # Mask of valid places in each row
    mask = np.arange(lens.max()) < lens[:,None]

    # Setup output array and put elements from data into masked positions
    padded = np.zeros(mask.shape)
    padded[mask] = np.hstack((x_data[:]))

    # Call TensorFlow
    result = session.run(y, feed_dict={x:padded})

    # Remove the padding - the list function ensures we 
    # create same datatype as input. It is not necessary in the case
    # where you are happy with a list of Numpy arrays instead
    result_without_padding = np.array(
       [list(result[i,0:lens[i]]) for i in range(lens.size)]
    )
    print( result_without_padding )

La salida es:

[[2, 4] [8, 10, 12] [2, 4, 6, 8, 10, 12]]

No tiene que eliminar el relleno al final, solo haga esto si necesita mostrar su salida en el mismo formato de matriz irregular. También tenga en cuenta que cuando alimenta el resultado padded Los datos a rutinas más complejas, los ceros, u otros datos de relleno, si lo cambia, pueden usarse por cualquier algoritmo que haya implementado.

Si tiene muchos matrices cortas y solo una o dos muy largas, entonces puede considerar usar un representación del tensor escaso Para guardar la memoria y acelerar los cálculos.

Otros consejos

Como alternativa al uso de matrices acolchadas, puede alimentar todos sus datos como una gran cadena de espagueti y luego hacer origami dentro del gráfico TensorFlow

Ejemplo:

import tensorflow as tf
import numpy as np

sess = tf.InteractiveSession()

noodle = tf.placeholder(tf.float32, [None])
chop_indices = tf.placeholder(tf.int32, [None,2])

do_origami = lambda list_idx: tf.gather(noodle, tf.range(chop_indices[list_idx,0], chop_indices[list_idx,1]))

print( [do_origami(list_idx=i).eval({noodle:[1,2,3,2,3,6], chop_indices:[[0,2],[2,3],[3,6]]}).tolist() for i in range(3)] )

Resultado:

[[1.0, 2.0], [3.0], [2.0, 3.0, 6.0]]

Sin embargo, si tiene un número variable de listas internas, entonces buena suerte. No puede devolver una lista de TF.WHILLHHIL_LOOP y no puede usar una comprensión de la lista como se indica anteriormente para que tenga que hacer los cálculos por separado para cada lista interior.

import tensorflow as tf

sess = tf.InteractiveSession()

my_list = tf.Variable(initial_value=[1,2,3,4,5])

init = tf.global_variables_initializer()

sess.run(init)

sess.run(my_list)

Resultado: Array ([1, 2, 3, 4, 5])

Licenciado bajo: CC-BY-SA con atribución
scroll top