diff --git a/deepmatch/models/gru4rec.py b/deepmatch/models/gru4rec.py
new file mode 100644
index 0000000..2ac9451
--- /dev/null
+++ b/deepmatch/models/gru4rec.py
@@ -0,0 +1,77 @@
+import tensorflow as tf
+from tensorflow.python.keras.models import Model
+from tensorflow.python.keras.layers import GRU, Dense
+from tensorflow.python.keras.activations import sigmoid
+import tensorflow.python.keras.backend as K
+from deepctr.inputs import build_input_features, create_embedding_matrix
+from deepctr.layers.core import PredictionLayer
+
+
+def bpr(yTrue, yhat):
+    """
+    Bayesian Personalized Ranking loss.
+
+    ``yhat[i, i]`` is the score of session i's target item; the other entries
+    of row i are treated as in-batch negative samples.
+    """
+    yhatT = K.transpose(yhat)
+    return K.mean(-K.log(sigmoid(tf.linalg.diag_part(yhat) - yhatT)))
+
+
+def top1(yTrue, yhat):
+    """
+    TOP1 loss proposed for this task in 'Session-based Recommendations with
+    Recurrent Neural Networks'.
+    """
+    yhatT = tf.transpose(yhat)
+    term1 = tf.reduce_mean(tf.nn.sigmoid(-tf.linalg.diag_part(yhat) + yhatT) + tf.nn.sigmoid(yhatT ** 2), axis=0)
+    # Subtract the diagonal's self-contribution; use a graph-safe batch size.
+    term2 = tf.nn.sigmoid(tf.linalg.diag_part(yhat) ** 2) / tf.cast(tf.shape(yhat)[0], tf.float32)
+    return tf.reduce_mean(term1 - term2)
+
+
+def GRU4REC(item_feature_columns, n_classes, gru_units, batch_size, l2_reg_embedding=1e-6, init_std=0.0001,
+            seed=1024):
+    """
+    Instantiates the GRU4Rec model architecture.
+
+    :param item_feature_columns: An iterable containing the item features used by the model.
+    :param n_classes: int, number of label classes.
+    :param gru_units: tuple, number of units in each GRU layer; one layer per entry.
+    :param batch_size: int, number of samples in each batch (fixed, because the GRU layers are stateful).
+    :param l2_reg_embedding: float, L2 regularizer strength applied to the embedding vectors.
+    :param init_std: float, standard deviation used to initialize the embedding vectors.
+    :param seed: int, random seed.
+    :return: A Keras model instance.
+    """
+    item_feature_name = item_feature_columns[0].name
+
+    embedding_matrix_dict = create_embedding_matrix(item_feature_columns, l2_reg_embedding,
+                                                    init_std, seed, prefix="")
+
+    item_features = build_input_features(item_feature_columns)
+    # Stateful GRU layers require a fully-defined input shape, batch dimension included.
+    item_features[item_feature_name].set_shape((batch_size, 1))
+    item_inputs_list = list(item_features.values())
+    item_embedding_matrix = embedding_matrix_dict[item_feature_name]
+
+    item_emb = item_embedding_matrix(item_features[item_feature_name])
+
+    x = item_emb
+    for i, j in enumerate(gru_units):
+        x, gru_states = GRU(j, stateful=True, return_state=True, name='gru_{}'.format(i))(x)
+        # GRU returns a 2D tensor; restore the time axis so a following GRU layer sees 3D input.
+        x = tf.reshape(x, (batch_size, 1, -1))
+
+    x = tf.reshape(x, (batch_size, -1))
+    x = Dense(n_classes, activation='linear')(x)
+
+    output = PredictionLayer("multiclass", False)(x)
+
+    model = Model(inputs=item_inputs_list, outputs=output)
+
+    model.__setattr__("item_input", item_inputs_list)
+    model.__setattr__("item_embedding", item_emb)
+
+    return model
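
For intuition, here is a minimal numpy sanity check of what `bpr` optimizes (illustrative only, not part of the diff): the loss falls as the diagonal, i.e. target-item, scores come to dominate the in-batch negatives.

```python
import numpy as np

def bpr_numpy(scores):
    # Mirrors bpr(): mean(-log(sigmoid(diag(yhat) - yhat^T)))
    diff = np.diag(scores)[None, :] - scores.T
    return np.mean(-np.log(1.0 / (1.0 + np.exp(-diff))))

rng = np.random.RandomState(0)
weak = rng.randn(3, 3)            # random scores for a batch of 3 sessions
strong = weak + 5.0 * np.eye(3)   # boost each session's target item
print(bpr_numpy(weak) > bpr_numpy(strong))  # True: ranking targets higher lowers the loss
```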
+ + """ + data.sort_values([session_key, time_key], inplace=True) + + click_offsets = np.zeros(data[session_key].nunique() + 1, dtype=np.int32) + # group & sort the df by session_key and get the offset values + click_offsets[1:] = data.groupby(session_key).size().cumsum() + + session_idx_arr = np.arange(data[session_key].nunique()) + + iters = np.arange(batch_size) + maxiter = iters.max() + start = click_offsets[session_idx_arr[iters]] + end = click_offsets[session_idx_arr[iters] + 1] + mask = [] # indicator for the sessions to be terminated + finished = False + + while not finished: + minlen = (end - start).min() + # Item indices (for embedding) for clicks where the first sessions start + idx_target = data[item_key].values[start] + for i in range(minlen - 1): + # Build inputs & targets + idx_input = idx_target + idx_target = data[item_key].values[start + i + 1] + inp = idx_input + target = idx_target + yield inp, target, mask + + # click indices where a particular session meets second-to-last element + start = start + (minlen - 1) + # see if how many sessions should terminate + mask = np.arange(len(iters))[(end - start) <= 1] + done_sessions_counter = len(mask) + for idx in mask: + maxiter += 1 + if maxiter >= len(click_offsets) - 1: + finished = True + break + # update the next starting/ending point + iters[idx] = maxiter + start[idx] = click_offsets[session_idx_arr[maxiter]] + end[idx] = click_offsets[session_idx_arr[maxiter] + 1] diff --git a/examples/run_gru4rec.py b/examples/run_gru4rec.py new file mode 100644 index 0000000..19e31f6 --- /dev/null +++ b/examples/run_gru4rec.py @@ -0,0 +1,104 @@ +import numpy as np +import pandas as pd + +from deepctr.inputs import SparseFeat +from sklearn.preprocessing import LabelEncoder +from tensorflow.python.keras import backend as K + +from deepmatch.utils import recall_N +from deepmatch.models.gru4rec import GRU4REC, top1, bpr +from preprocess import gen_model_input_gru4rec, gen_data_set + +if __name__ == "__main__": + debug = True + if debug: + data = pd.read_csvdata = pd.read_csv("./movielens_sample.txt")[['user_id', 'movie_id', 'timestamp']] + batch_size = 3 + else: + data_path = "./" + unames = ['user_id', 'gender', 'age', 'occupation', 'zip'] + user = pd.read_csv(data_path + 'ml-1m/users.dat', sep='::', header=None, names=unames) + rnames = ['user_id', 'movie_id', 'rating', 'timestamp'] + ratings = pd.read_csv(data_path + 'ml-1m/ratings.dat', sep='::', header=None, names=rnames) + mnames = ['movie_id', 'title', 'genres'] + movies = pd.read_csv(data_path + 'ml-1m/movies.dat', sep='::', header=None, names=mnames) + data = pd.merge(pd.merge(ratings, movies), user)[['user_id', 'movie_id', 'timestamp']] + batch_size = 512 + + features = ['user_id', 'movie_id'] + feature_max_idx = {} + for feature in features: + lbe = LabelEncoder() + data[feature] = lbe.fit_transform(data[feature]) + 1 + feature_max_idx[feature] = data[feature].max() + 1 + + data["rank"] = data.groupby("user_id")["timestamp"].rank("first", ascending=False) + test_set = data.loc[data['rank'] <= 2,] + train_set = data.loc[data['rank'] >= 2] + + epochs = 3 + embedding_dim = 128 + gru_units = (128,) + n_classes = feature_max_idx['movie_id'] + loss_fn = 'CrossEntropy' + + test_loader = gen_model_input_gru4rec(test_set, batch_size, 'user_id', 'movie_id', 'timestamp') + item_feature_columns = [SparseFeat('movie_id', feature_max_idx['movie_id'], embedding_dim)] + + K.set_learning_phase(True) + import tensorflow as tf + + if tf.__version__ >= '2.0.0': + 
diff --git a/examples/run_gru4rec.py b/examples/run_gru4rec.py
new file mode 100644
index 0000000..19e31f6
--- /dev/null
+++ b/examples/run_gru4rec.py
@@ -0,0 +1,104 @@
+import numpy as np
+import pandas as pd
+
+from deepctr.inputs import SparseFeat
+from sklearn.preprocessing import LabelEncoder
+from tensorflow.python.keras import backend as K
+
+from deepmatch.utils import recall_N
+from deepmatch.models.gru4rec import GRU4REC, top1, bpr
+from preprocess import gen_model_input_gru4rec
+
+if __name__ == "__main__":
+    debug = True
+    if debug:
+        data = pd.read_csv("./movielens_sample.txt")[['user_id', 'movie_id', 'timestamp']]
+        batch_size = 3
+    else:
+        data_path = "./"
+        unames = ['user_id', 'gender', 'age', 'occupation', 'zip']
+        user = pd.read_csv(data_path + 'ml-1m/users.dat', sep='::', header=None, names=unames)
+        rnames = ['user_id', 'movie_id', 'rating', 'timestamp']
+        ratings = pd.read_csv(data_path + 'ml-1m/ratings.dat', sep='::', header=None, names=rnames)
+        mnames = ['movie_id', 'title', 'genres']
+        movies = pd.read_csv(data_path + 'ml-1m/movies.dat', sep='::', header=None, names=mnames)
+        data = pd.merge(pd.merge(ratings, movies), user)[['user_id', 'movie_id', 'timestamp']]
+        batch_size = 512
+
+    features = ['user_id', 'movie_id']
+    feature_max_idx = {}
+    for feature in features:
+        lbe = LabelEncoder()
+        data[feature] = lbe.fit_transform(data[feature]) + 1
+        feature_max_idx[feature] = data[feature].max() + 1
+
+    # Hold out each user's last two clicks for testing; keep the rest for training.
+    data["rank"] = data.groupby("user_id")["timestamp"].rank("first", ascending=False)
+    test_set = data.loc[data['rank'] <= 2]
+    train_set = data.loc[data['rank'] >= 2]
+
+    epochs = 3
+    embedding_dim = 128
+    gru_units = (128,)
+    n_classes = feature_max_idx['movie_id']
+    loss_fn = 'CrossEntropy'  # one of 'CrossEntropy', 'TOP1', 'BPR'
+
+    test_loader = gen_model_input_gru4rec(test_set, batch_size, 'user_id', 'movie_id', 'timestamp')
+    item_feature_columns = [SparseFeat('movie_id', feature_max_idx['movie_id'], embedding_dim)]
+
+    K.set_learning_phase(True)
+    import tensorflow as tf
+
+    if tf.__version__ >= '2.0.0':
+        tf.compat.v1.disable_eager_execution()
+
+    model = GRU4REC(item_feature_columns, n_classes, gru_units, batch_size)
+
+    if loss_fn == 'CrossEntropy':
+        model.compile(optimizer="adam", loss='sparse_categorical_crossentropy')
+    elif loss_fn == 'TOP1':
+        model.compile(optimizer="adam", loss=top1)
+    elif loss_fn == 'BPR':
+        model.compile(optimizer="adam", loss=bpr)
+
+    model.summary()
+
+    for epoch in range(epochs):
+        step = 0
+        train_loader = gen_model_input_gru4rec(train_set, batch_size, 'user_id', 'movie_id', 'timestamp')
+        for feat, target, mask in train_loader:
+            # Zero the hidden state of every row whose session ended at the previous step.
+            real_mask = np.ones((batch_size, 1))
+            for elt in mask:
+                real_mask[elt, :] = 0
+
+            for i in range(len(gru_units)):
+                hidden_states = K.get_value(model.get_layer('gru_{}'.format(i)).states[0])
+                hidden_states = np.multiply(real_mask, hidden_states)
+                hidden_states = np.array(hidden_states, dtype=np.float32)
+                model.get_layer('gru_{}'.format(i)).reset_states(hidden_states)
+
+            feat = np.array(feat).reshape((-1, 1))
+            target = np.array(target).reshape((-1, 1))
+
+            tr_loss = model.train_on_batch(feat, target)
+            if step % 500 == 0:
+                print("epoch {} step {} loss {}".format(epoch, step, tr_loss))
+            step += 1
+
+    # s = []
+    # hit = 0
+    # total_sample = 0
+    # n = 50
+    # for feat, target, mask in test_loader:
+    #     feat = np.array(feat).reshape((-1, 1))
+    #     target = np.array(target).reshape((-1, 1))
+    #     pred = model.predict(feat, batch_size=batch_size)
+    #     pred = np.array(pred).argsort()[:, ::-1][:, :n]
+    #
+    #     for i in range(len(pred)):
+    #         s.append(recall_N(target[i], pred[i], n))
+    #         if target[i] in pred[i]:
+    #             hit += 1
+    #         total_sample += 1
+    # print("recall", np.mean(s))
+    # print("hit rate", hit / total_sample)
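
The per-row state reset in the training loop is the part that is easiest to get wrong, so here is a standalone sketch of just that mechanism (illustrative only, not part of the diff; it uses eager TF2 and made-up shapes for brevity). Rows of a stateful GRU's hidden state are zeroed when the corresponding session terminates, so the next session in that batch slot starts from a clean state.

```python
import numpy as np
import tensorflow as tf

batch_size, units, n_features = 4, 8, 16
gru = tf.keras.layers.GRU(units, stateful=True)

# One step over a dummy batch builds the layer and populates the recurrent state.
gru(tf.random.normal((batch_size, 1, n_features)))
state = gru.states[0].numpy()

ended = [1, 3]                                   # batch rows whose session just ended
keep = np.ones((batch_size, 1), dtype=np.float32)
keep[ended] = 0.0
gru.reset_states(state * keep)                   # zero those rows, keep the rest

print(gru.states[0].numpy()[ended])              # all zeros: fresh state for new sessions
```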