Torch.cat funziona con backpropagation?
Domanda
Mi chiedevo se andasse bene usare Torch.cat nella mia funzione in avanti. Lo sto facendo perché voglio che le prime due colonne del mio input saltino i livelli nascosti centrali e vadano direttamente al livello finale.
Ecco il mio codice: puoi vedere che uso Torch.cat all'ultimo momento per fare XCAT.
Il gradiente si propaga indietro? O la torcia.cat copre cosa è successo alle mie variabili nascoste?
class LinearRegressionForce(nn.Module):
def __init__(self, focus_input_size, rest_input_size, hidden_size_1, hidden_size_2, output_size):
super(LinearRegressionForce, self).__init__()
self.in1 = nn.Linear(rest_input_size, hidden_size_1)
self.middle1 = nn.Linear(hidden_size_1,hidden_size_2)
self.out4 = nn.Linear(focus_input_size + hidden_size_2,output_size)
def forward(self, inputs):
focus_inputs = inputs[:,0:focus_input_size]
rest_inputs = inputs[:,focus_input_size:(rest_input_size+focus_input_size)]
x = self.in1(rest_inputs).clamp(min=0)
x = self.middle1(x).clamp(min=0)
xcat = torch.cat((focus_inputs,x),1)
out = self.out4(xcat).clamp(min=0)
return out
Lo chiamo così:
rest_inputs = Variable(torch.from_numpy(rest_x_train))
focus_x_train_ones = np.concatenate((focus_x_train, np.ones((n,1))), axis=1)
focus_inputs = Variable(torch.from_numpy(focus_x_train_ones)).float()
inputs = torch.cat((focus_inputs,rest_inputs),1)
predicted = model(inputs).data.numpy()
Nessuna soluzione corretta
Autorizzato sotto: CC-BY-SA insieme a attribuzione
Non affiliato a datascience.stackexchange