import tensorflow as tf
from tensorflow.contrib import rnn, seq2seq


class Model:
    def __init__(self, dtype=tf.float32, **kwargs):
        """
        Args:
            The following kwargs are recognized:

            input_size: dimension of a single input in an input sequence
            output_size: dimension of a single output in an output sequence
            output_sos_id: index of the output start-of-sequence id (fed into
                the decoder at the start; a reserved index that is never
                actually output; default: 0)
            output_eos_id: index of the output end-of-sequence id (default: 1)
            enc_size: number of filters in the convolutional encoder, also
                used as the attention depth (default: 42)
            dec_size: number of units in each decoder GRU cell (default: 96)
        """
        self._input_size = kwargs['input_size']
        self._output_size = kwargs['output_size']
        self._output_sos_id = kwargs.get('output_sos_id', 0)
        self._output_eos_id = kwargs.get('output_eos_id', 1)
        self._enc_size = kwargs.get('enc_size', 42)
        self._dec_size = kwargs.get('dec_size', 96)
        self._dtype = dtype

    def _build_model(self, batch_size, helper_build_fn, decoder_maxiters=None, alignment_history=False):
        # embed input_data into a one-hot representation
        inputs = tf.one_hot(self.input_data, self._input_size, dtype=self._dtype)
        inputs_len = self.input_lengths

        with tf.name_scope('conv-encoder'):
            W = tf.Variable(tf.truncated_normal([3, self._input_size, self._enc_size], stddev=0.1),
                            name='conv-weights')
            b = tf.Variable(tf.truncated_normal([self._enc_size], stddev=0.1), name='conv-bias')
            enc_out = tf.nn.elu(tf.nn.conv1d(inputs, W, stride=1, padding='SAME') + b)

        with tf.name_scope('attn-decoder'):
            dec_cell_in1 = rnn.GRUCell(self._dec_size)
            dec_cell_in2 = rnn.GRUCell(self._dec_size)
            memory = enc_out
            attn_mech = seq2seq.LuongMonotonicAttention(self._enc_size, memory,
                                                        memory_sequence_length=inputs_len,
                                                        sigmoid_noise=0.5, score_bias_init=-4.,
                                                        mode='recursive', scale=True)
            dec_cell_attn = rnn.MultiRNNCell([rnn.GRUCell(self._dec_size),
                                              rnn.GRUCell(self._enc_size)],
                                             state_is_tuple=True)
            dec_cell_attn = seq2seq.AttentionWrapper(dec_cell_attn, attn_mech,
                                                     attention_layer_size=self._enc_size,
                                                     alignment_history=alignment_history)
            dec_cell_out = rnn.GRUCell(self._output_size)
            dec_cell = rnn.MultiRNNCell([dec_cell_in1, dec_cell_in2, dec_cell_attn, dec_cell_out],
                                        state_is_tuple=True)

            dec = seq2seq.BasicDecoder(dec_cell, helper_build_fn(),
                                       dec_cell.zero_state(batch_size, self._dtype))
            dec_out, dec_state, _ = seq2seq.dynamic_decode(dec, output_time_major=False,
                                                           maximum_iterations=decoder_maxiters,
                                                           impute_finished=True)

        self.outputs = dec_out.rnn_output
        self.output_ids = dec_out.sample_id
        self.final_state = dec_state

    def _output_onehot(self, ids):
        return tf.one_hot(ids, self._output_size, dtype=self._dtype)
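    # A rough shape sketch of _build_model (informal, derived from the code
    # above; B = batch size, T_in = input length, T_out = decoded length):
    #   input_data [B, T_in] int32  --one_hot-->  [B, T_in, input_size]
    #   conv1d encoder (width-3 filters, SAME)  ->  [B, T_in, enc_size]
    #   decoder: GRU(dec_size) -> GRU(dec_size)
    #            -> AttentionWrapper(GRU(dec_size) -> GRU(enc_size),
    #                                LuongMonotonicAttention over enc_out)
    #            -> GRU(output_size)
    #   rnn_output [B, T_out, output_size], sample_id [B, T_out] int32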
    def train(self, batch_size, learning_rate=1e-4, out_help=False,
              time_discount=0.4, sampling_probability=0.2):
        """Build model for training.

        Args:
            batch_size: size of a training batch
            learning_rate: initial Adam learning rate (decays by 0.8 every
                4000 global steps)
            out_help: if True, feed the target outputs to the decoder via
                scheduled sampling; otherwise decode greedily from the
                model's own predictions
            time_discount: exponent for down-weighting losses at later
                output time steps (disabled when <= 0)
            sampling_probability: probability of sampling from the model
                instead of reading the target (only used when out_help)
        """
        self.input_data = tf.placeholder(tf.int32, [batch_size, None], name='input_data')
        self.input_lengths = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
        self.output_data = tf.placeholder(tf.int32, [batch_size, None], name='output_data')
        self.output_lengths = tf.placeholder(tf.int32, [batch_size], name='output_lengths')

        output_data_maxlen = tf.shape(self.output_data)[1]

        def infer_helper():
            return seq2seq.GreedyEmbeddingHelper(
                self._output_onehot,
                start_tokens=tf.fill([batch_size], self._output_sos_id),
                end_token=self._output_eos_id)

        def train_helper():
            start_ids = tf.fill([batch_size, 1], self._output_sos_id)
            decoder_input_ids = tf.concat([start_ids, self.output_data], 1)
            decoder_inputs = self._output_onehot(decoder_input_ids)
            return seq2seq.ScheduledEmbeddingTrainingHelper(decoder_inputs, self.output_lengths,
                                                            self._output_onehot,
                                                            sampling_probability)

        helper = train_helper if out_help else infer_helper
        self._build_model(batch_size, helper, decoder_maxiters=output_data_maxlen)

        # the decoder may stop early, so compare only over the overlapping length
        output_maxlen = tf.minimum(tf.shape(self.outputs)[1], output_data_maxlen)
        out_data_slice = tf.slice(self.output_data, [0, 0], [-1, output_maxlen])
        out_logits_slice = tf.slice(self.outputs, [0, 0, 0], [-1, output_maxlen, -1])
        out_pred_slice = tf.slice(self.output_ids, [0, 0], [-1, output_maxlen])

        with tf.name_scope('costs'):
            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=out_logits_slice, labels=out_data_slice)
            length_mask = tf.sequence_mask(self.output_lengths, maxlen=output_maxlen,
                                           dtype=self._dtype)
            losses = losses * length_mask

            # out_id = 2,3,4,5,6 : AA,AE,AH,AO,AW : reduce the cost by 20% for a-confusion
            data_is_a = tf.logical_and(tf.greater_equal(out_data_slice, 2),
                                       tf.less_equal(out_data_slice, 6))
            pred_is_a = tf.logical_and(tf.greater_equal(out_pred_slice, 2),
                                       tf.less_equal(out_pred_slice, 6))
            a_mask = tf.cast(tf.logical_and(data_is_a, pred_is_a), dtype=tf.float32)
            losses = losses * (1.0 - 0.2 * a_mask)

            if time_discount > 0:
                # time discounts (only when using the infer helper?)
                factors = tf.pow(tf.range(1, tf.to_float(output_maxlen + 1), dtype=tf.float32),
                                 -time_discount)
                losses = losses * tf.expand_dims(factors, 0)

            losses = tf.reduce_sum(losses, 1)
            self.losses = tf.reduce_sum(losses)
            tf.summary.scalar('losses', self.losses)

            inequality = tf.cast(tf.not_equal(out_pred_slice, out_data_slice), dtype=tf.float32)
            # reduce the inequality penalty by 10% for a-confusion
            inequality = inequality * (1.0 - 0.1 * a_mask)
            self.accuracy = tf.reduce_mean(1.0 - inequality)
            tf.summary.scalar('accuracy', self.accuracy)

        self.global_step = tf.Variable(0, trainable=False, name='global_step')
        # exponential learning-rate decay: multiply by 0.8 every 4000 steps
        decay_rate = tf.constant(0.8, dtype=tf.float64)
        self.learning_rate = learning_rate * tf.pow(
            decay_rate, tf.floor(tf.cast(self.global_step, tf.float64) / 4000))
        opt = tf.train.AdamOptimizer(self.learning_rate)
        self.train_step = opt.minimize(self.losses, global_step=self.global_step)

    def infer(self, output_maxlen=128):
        """Build model for single-example inference (alignment history is
        recorded so the attention can be visualized)."""
        self.input_data = tf.placeholder(tf.int32, [1, None], name='input_data')
        self.input_lengths = None

        def infer_helper():
            return seq2seq.GreedyEmbeddingHelper(
                self._output_onehot,
                start_tokens=tf.fill([1], self._output_sos_id),
                end_token=self._output_eos_id)

        self._build_model(1, infer_helper, decoder_maxiters=output_maxlen,
                          alignment_history=True)


# See also:
# https://groups.google.com/a/tensorflow.org/forum/#!topic/discuss/dw3Y2lnMAJc
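
# A minimal usage sketch (not part of the original source): the vocabulary
# sizes, batch size, sequence lengths, and random data below are made-up
# placeholder values for illustration only. Assumes TensorFlow 1.x with
# tf.contrib available.
if __name__ == '__main__':
    import numpy as np

    model = Model(input_size=30, output_size=50)
    model.train(batch_size=4)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # labels start at 2 so the reserved sos (0) and eos (1) ids never
        # appear as targets
        feed = {
            model.input_data: np.random.randint(0, 30, size=(4, 10)).astype(np.int32),
            model.input_lengths: np.full(4, 10, dtype=np.int32),
            model.output_data: np.random.randint(2, 50, size=(4, 8)).astype(np.int32),
            model.output_lengths: np.full(4, 8, dtype=np.int32),
        }
        loss, acc, _ = sess.run([model.losses, model.accuracy, model.train_step], feed)
        print('loss: %f  accuracy: %f' % (loss, acc))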