Question

As a way to improve my model, I want to average GloVe vectors over a sentence. However, I can't get np.mean to work. The following code, copied from elsewhere, works when I don't average over the words:

import string

import numpy as np

embeddings_dict = {}
with open("glove.6B.50d.txt", 'r') as f:
    for line in f:
        values = line.split()
        word = values[0]
        vector = np.asarray(values[1:], "float32")
        embeddings_dict[word] = vector

g = open("input.txt", 'r')

vector_lines = []

for line in g:
    clean_line = line.translate(str.maketrans('', '', string.punctuation))
    array_line = clean_line.lower().split()
    vec_line = []
    i = 0
    for word in array_line:
        i += 1
        try:
            vec_line.append(embeddings_dict[word])
        except KeyError:
            vec_line.append(embeddings_dict["unk"])
    while i < 30:  # pad up to thirty words with zero vectors
        vec_line.append(np.zeros(50))
        i += 1
    vector_lines.append(np.asarray(vec_line))

X = np.asarray(vector_lines)
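To check the padded version without the GloVe file, a tiny made-up embedding dictionary can stand in for embeddings_dict. This is a minimal sketch under stated assumptions: TOY_EMBEDDINGS and sentence_matrix are hypothetical names, the vectors are 3-dimensional instead of 50, the maximum length is 4 words instead of 30, and, unlike the original loop, it truncates longer lines so that every sentence produces an array of the same shape.

```python
import string

import numpy as np

# Hypothetical toy embeddings standing in for the GloVe dictionary (dim 3, not 50).
TOY_EMBEDDINGS = {
    "the": np.array([1.0, 0.0, 0.0], dtype="float32"),
    "cat": np.array([0.0, 1.0, 0.0], dtype="float32"),
    "sat": np.array([0.0, 0.0, 1.0], dtype="float32"),
    "unk": np.array([0.5, 0.5, 0.5], dtype="float32"),
}
DIM = 3
MAX_WORDS = 4


def sentence_matrix(line, embeddings, max_words=MAX_WORDS, dim=DIM):
    """Return a (max_words, dim) array: one row per word, zero rows as padding."""
    clean = line.translate(str.maketrans("", "", string.punctuation))
    words = clean.lower().split()[:max_words]  # assumption: truncate long lines
    # Unknown words fall back to the "unk" vector, like the try/except above.
    rows = [embeddings.get(w, embeddings["unk"]) for w in words]
    # Pad with zero vectors up to max_words rows.
    rows += [np.zeros(dim, dtype="float32")] * (max_words - len(rows))
    return np.asarray(rows)


texts = ["The cat sat.", "The dog!"]
X = np.asarray([sentence_matrix(text, TOY_EMBEDDINGS) for text in texts])
print(X.shape)  # (2, 4, 3)
```

Because every sentence is padded to the same number of rows, np.asarray can stack them into one 3-D array of shape (sentences, words, dimensions).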

To average over the words instead, I changed a small part of the code:

        try:
            vec_line.append(embeddings_dict[word])
        except KeyError:
            vec_line.append(embeddings_dict["unk"])

    #padding is not necessary anymore
    #while i < 30: 
    #    vec_line.append(np.zeros(50))
    #    i += 1

    vec_mean = np.mean(vec_line, axis=0, keepdims=True)[:,None]
    vector_lines.append(vec_mean)

X = np.asarray(vector_lines)

This gives me the error "ValueError: could not broadcast input array from shape (1,50) into shape (1,1)". I feel as if I have tried every possible modification to this code, but I keep getting shape errors. What is causing them?
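One way to see where the extra axes come from is to print the shape after each step on toy data. This sketch uses made-up 3-dimensional vectors. It also shows one input that yields a (1, 1) array like the one named in the error: a line with no words, for which np.mean of an empty list returns nan (with a RuntimeWarning) rather than a vector. Whether input.txt actually contains such a line is an assumption; the question does not say.

```python
import warnings

import numpy as np

# Made-up 3-dimensional word vectors for a two-word sentence.
vec_line = [np.array([1.0, 2.0, 3.0]), np.array([4.0, 5.0, 6.0])]

step1 = np.mean(vec_line, axis=0, keepdims=True)  # keepdims keeps a length-1 word axis
step2 = step1[:, None]                            # [:, None] inserts another length-1 axis
print(step1.shape, step2.shape)                   # (1, 3) (1, 1, 3)

# An empty line gives an empty vec_line; its mean is nan, not a vector.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", RuntimeWarning)  # silence "Mean of empty slice"
    empty = np.mean([], axis=0, keepdims=True)[:, None]
print(empty.shape)                                # (1, 1)
```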


Solution

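keepdims=True keeps the averaged word axis as a length-1 dimension, and [:,None] inserts another one, so each sentence vector comes out with shape (1, 1, 50) instead of (50,). Change this: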

import numpy as np
a = np.array([1,2,3])
b = np.array([4,5,6])
vec_line = [a,b]

print(np.mean(vec_line, axis=0, keepdims=True)[:,None])
>>[[[2.5 3.5 4.5]]]
print(np.mean(vec_line, axis=0, keepdims=True)[:,None].shape)
>>(1, 1, 3)

To this:

print(np.mean(vec_line, axis=0))
>>[2.5 3.5 4.5]
print(np.mean(vec_line, axis=0).shape)
>>(3,)
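Put back into the question's loop, this means each line contributes one vector of shape (50,), so X ends up with shape (number of lines, 50). The sketch below is a minimal version of that loop under stated assumptions: average_vectors and TOY_EMBEDDINGS are hypothetical names, the vectors are 3-dimensional rather than 50, and mapping a line with no words to a zero vector is an assumed choice (skipping such lines would be another option).

```python
import string

import numpy as np

# Hypothetical toy embeddings standing in for the GloVe dictionary (dim 3, not 50).
TOY_EMBEDDINGS = {
    "the": np.array([1.0, 0.0, 0.0], dtype="float32"),
    "cat": np.array([0.0, 1.0, 0.0], dtype="float32"),
    "sat": np.array([0.0, 0.0, 1.0], dtype="float32"),
    "unk": np.array([0.5, 0.5, 0.5], dtype="float32"),
}
DIM = 3


def average_vectors(lines, embeddings, dim=DIM):
    """Return an (n_lines, dim) array holding the mean word vector of each line."""
    vector_lines = []
    for line in lines:
        clean = line.translate(str.maketrans("", "", string.punctuation))
        vec_line = [embeddings.get(w, embeddings["unk"]) for w in clean.lower().split()]
        if vec_line:
            # No keepdims and no [:, None]: the mean is a flat (dim,) vector.
            vector_lines.append(np.mean(vec_line, axis=0))
        else:
            # Assumption: an empty line becomes a zero vector instead of nan.
            vector_lines.append(np.zeros(dim, dtype="float32"))
    return np.asarray(vector_lines)


X = average_vectors(["The cat sat.", "", "The dog!"], TOY_EMBEDDINGS)
print(X.shape)  # (3, 3)
```

Since every entry of vector_lines has the same (dim,) shape, np.asarray stacks them into a plain 2-D array with no broadcasting error.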
Licensed under: CC-BY-SA with attribution