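"""Pointer network decoder implemented as a Keras 1.x LSTM subclass.

Each decoder step attends over the encoder sequence and emits a softmax
distribution over the input timesteps (a "pointer"), in the style of
Vinyals et al., "Pointer Networks".
"""
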
from keras import initializations
from keras.layers.recurrent import time_distributed_dense
from keras.activations import tanh, softmax
from keras.layers import LSTM
from keras.engine import InputSpec
import keras.backend as K


class PointerLSTM(LSTM):
    def __init__(self, hidden_shape, *args, **kwargs):
        # hidden_shape: size of the hidden state projected by the
        # attention weights W1 and W2.
        self.hidden_shape = hidden_shape
        self.input_length = []
        super(PointerLSTM, self).__init__(*args, **kwargs)

    def build(self, input_shape):
        super(PointerLSTM, self).build(input_shape)
        self.input_spec = [InputSpec(shape=input_shape)]
        init = initializations.get('orthogonal')
        # Attention parameters: W1 projects the encoder outputs, W2 the
        # decoder state, and vt scores each encoder timestep.
        self.W1 = init((self.hidden_shape, 1))
        self.W2 = init((self.hidden_shape, 1))
        self.vt = init((input_shape[1], 1))
        self.trainable_weights += [self.W1, self.W2, self.vt]

    def call(self, x, mask=None):
        input_shape = self.input_spec[0].shape
        en_seq = x
        # Feed the last encoder output to every decoder timestep.
        x_input = x[:, input_shape[1] - 1, :]
        x_input = K.repeat(x_input, input_shape[1])
        initial_states = self.get_initial_states(x_input)

        # Pass the full encoder sequence to step() as an extra constant.
        constants = super(PointerLSTM, self).get_constants(x_input)
        constants.append(en_seq)
        preprocessed_input = self.preprocess_input(x_input)

        last_output, outputs, states = K.rnn(self.step, preprocessed_input,
                                             initial_states,
                                             go_backwards=self.go_backwards,
                                             constants=constants,
                                             input_length=input_shape[1])

        return outputs

    def step(self, x_input, states):
        # x_input is a 2D tensor (batch_size, features) for one timestep;
        # the encoder sequence rides along as the last constant in states.
        input_shape = self.input_spec[0].shape
        en_seq = states[-1]
        _, [h, c] = super(PointerLSTM, self).step(x_input, states[:-1])

        # Attention scores: u = vt * tanh(W1 * e + W2 * d)
        dec_seq = K.repeat(h, input_shape[1])
        Eij = time_distributed_dense(en_seq, self.W1, output_dim=1)
        Dij = time_distributed_dense(dec_seq, self.W2, output_dim=1)
        U = self.vt * tanh(Eij + Dij)
        U = K.squeeze(U, 2)

        # Softmax over the encoder timesteps yields the pointer distribution.
        pointer = softmax(U)
        return pointer, [h, c]

    def get_output_shape_for(self, input_shape):
        # Each decoder step emits a probability distribution over the
        # input timesteps, so the last axis matches the sequence length.
        return (input_shape[0], input_shape[1], input_shape[1])
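

if __name__ == "__main__":
    # Minimal usage sketch, assuming the Keras 1.x Sequential API; the
    # sizes and random data below are illustrative only.
    import numpy as np
    from keras.models import Sequential

    seq_len, n_features, hidden_size = 10, 2, 64
    model = Sequential()
    # Encoder returns the full sequence for the pointer layer to attend to.
    model.add(LSTM(hidden_size, input_shape=(seq_len, n_features),
                   return_sequences=True))
    # First argument is hidden_shape; the second is the usual LSTM output_dim.
    model.add(PointerLSTM(hidden_size, hidden_size))
    model.compile(optimizer='adam', loss='categorical_crossentropy')

    # Targets are one-hot pointers of shape (batch, seq_len, seq_len).
    x = np.random.random((4, seq_len, n_features))
    y = np.eye(seq_len)[np.random.randint(seq_len, size=(4, seq_len))]
    model.fit(x, y, nb_epoch=1)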