diff --git a/octis/models/VONT.py b/octis/models/VONT.py
new file mode 100644
index 00000000..08d5fa57
--- /dev/null
+++ b/octis/models/VONT.py
@@ -0,0 +1,584 @@
+# Organizing the imports
+# Standard libraries
+import string
+import pickle
+from collections import defaultdict
+import math
+
+# Libraries for Deep Learning and ML
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+from torch.nn import init
+from sklearn.cluster import KMeans
+from sklearn import metrics
+from scipy import sparse
+from octis.models.vONTSS_model.hyperspherical_vae.distributions.von_mises_fisher import VonMisesFisher, HypersphericalUniform
+from octis.models.model import AbstractModel
+from torch.distributions.kl import register_kl
+
+# Libraries for NLP
+import nltk
+from nltk.corpus import stopwords, wordnet
+from nltk.stem import WordNetLemmatizer
+from nltk.tokenize import word_tokenize
+from nltk import pos_tag
+import gensim.downloader
+import gensim
+
+# Other utilities
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+# import seaborn as sns
+# from datasets import Dataset
+from octis.models.vONTSS_model.utils import kld_normal
+from octis.models.vONTSS_model.preprocess import TextProcessor
+
+# The optimal-transport prior (semi-supervised mode) needs POT; keep the import
+# optional so that unsupervised training works without the extra dependency.
+try:
+    import ot
+except ImportError:
+    ot = None
+
+
+@register_kl(VonMisesFisher, HypersphericalUniform)
+def _kl_vmf_uniform(vmf, hyu):
+    # KL(vMF || uniform on the hypersphere) = -H(vMF) + H(uniform)
+    return -vmf.entropy() + hyu.entropy()
+
+
+class EmbTopic(nn.Module):
+    """
+    Decoder for embedded topic modeling.
+    Reimplementation of: https://github.com/lffloyd/embedded-topic-model
+
+    Attributes
+    ----------
+    topic_emb: nn.Parameter
+        the topic embeddings (num_topics x embedding_dim)
+
+    Methods
+    -------
+    forward(logit)
+        log-probability over the vocabulary for each document
+    get_topics()
+        topic-word distribution (before the log)
+    """
+    def __init__(self, embedding, k, normalize=False):
+        super(EmbTopic, self).__init__()
+        self.embedding = embedding
+        n_vocab, topic_dim = embedding.weight.size()
+        self.k = k
+        self.topic_emb = nn.Parameter(torch.Tensor(k, topic_dim))
+        self.reset_parameters()
+        self.normalize = normalize
+
+    def forward(self, logit):
+        # return the log_prob of the vocabulary distribution
+        if self.normalize:
+            val = F.normalize(self.topic_emb, dim=-1) @ self.embedding.weight.transpose(0, 1)
+        else:
+            val = self.topic_emb @ self.embedding.weight.transpose(0, 1)
+        beta = F.softmax(val, dim=1)
+        return torch.log(torch.matmul(logit, beta) + 1e-10)
+
+    def get_topics(self):
+        return F.softmax(self.topic_emb @ self.embedding.weight.transpose(0, 1), dim=1)
+
+    def get_rank(self):
+        return F.normalize(self.topic_emb, dim=-1) @ self.embedding.weight.transpose(0, 1)
+
+    def 
reset_parameters(self): + init.normal_(self.topic_emb) + # init.kaiming_uniform_(self.topic_emb, a=math.sqrt(5)) + # init.normal_(self.embedding.weight, std=0.01) + + def extra_repr(self): + k, d = self.topic_emb.size() + return 'topic_emb: Parameter({}, {})'.format(k, d) + + + +def topic_covariance_penalty(topic_emb, EPS=1e-12): + """topic_emb: T x topic_dim.""" + #normalized the topic + normalized_topic = topic_emb / (torch.norm(topic_emb, dim=-1, keepdim=True) + EPS) + #get topic similarity absolute value + cosine = (normalized_topic @ normalized_topic.transpose(0, 1)).abs() + #average similarity + mean = cosine.mean() + #variance + var = ((cosine - mean) ** 2).mean() + return mean - var, var, mean + +class NormalParameter(nn.Module): + def __init__(self, in_features, out_features): + super(NormalParameter, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.mu = nn.Linear(in_features, out_features) + self.log_sigma = nn.Linear(in_features, out_features) + self.reset_parameters() + + def forward(self, h): + return self.mu(h), self.log_sigma(h) + + def reset_parameters(self): + init.zeros_(self.log_sigma.weight) + init.zeros_(self.log_sigma.bias) + +class NTM(nn.Module): + """NTM that keeps track of output + """ + def __init__(self, hidden, normal, h_to_z, topics): + super(NTM, self).__init__() + self.hidden = hidden + self.normal = normal + self.h_to_z = h_to_z + self.topics = topics + self.output = None + self.drop = nn.Dropout(p=0.5) + def forward(self, x, n_sample=1): + h = self.hidden(x) + h = self.drop(h) + mu, log_sigma = self.normal(h) + #identify how far it is away from normal distribution + kld = kld_normal(mu, log_sigma) + #print(kld.shape) + rec_loss = 0 + for i in range(n_sample): + #reparametrician trick + z = torch.zeros_like(mu).normal_() * torch.exp(0.5*log_sigma) + mu + #decode + + z = self.h_to_z(z) + self.output = z + #print(z) + #z = self.drop(z) + #get log probability for reconstruction loss + log_prob = self.topics(z) + rec_loss = rec_loss - (log_prob * x).sum(dim=-1) + #average reconstruction loss + rec_loss = rec_loss / n_sample + #print(rec_loss.shape) + minus_elbo = rec_loss + kld + + return { + 'loss': minus_elbo, + 'minus_elbo': minus_elbo, + 'rec_loss': rec_loss, + 'kld': kld + } + + def get_topics(self): + return self.topics.get_topics() + +def optimal_transport_prior(softmax_top, index, + lambda_sh = 1): + """ add prior as a semi-supervised loss + + parameters + ---------- + softmax_top: softmax results from decoder + index: list: a list of list with number as index + embedding: numpy array, word embedding trained by spherical word embeddings + beta: float, weights for prior loss + gamma: float, weights for negative sampling + iter2: int, how many epochs to train for third phase + sample: int, sample number + lambda_sh: low means high entrophy + + Returns: + -------- + int + loss functions + + """ + + m = - torch.log(softmax_top + 1e-12) + loss = torch.cat([m[:, i].mean(axis = 1).reshape(1, -1) for i in index]).to(m.device) + #print(loss.shape) + b = torch.ones(loss.shape[1]).to(m.device) + a = torch.ones(loss.shape[0]).to(m.device) + + return ot.sinkhorn(a, b, loss, lambda_sh).sum() + +class VNTM(nn.Module): + """NTM that keeps track of output + """ + def __init__(self, hidden, normal, h_to_z, topics, layer, top_number, penalty, beta = 1, index = None, temp=10): + super(VNTM, self).__init__() + self.hidden = hidden + #self.normal = normal + self.h_to_z = h_to_z + self.topics = topics + self.output = None + 
self.index = index + self.drop = nn.Dropout(p=0.3) + self.fc_mean = nn.Linear(layer, top_number) + self.fc_var = nn.Linear(layer, 1) + self.num = top_number + self.penalty = penalty + self.temp = temp + self.beta = beta + + #self.dirichlet = torch.distributions.dirichlet.Dirichlet((torch.ones(self.topics.k)/self.topics.k).cuda()) + def forward(self, x, device, n_sample=1, epoch = 0): + h = self.hidden(x) + h = self.drop(h) + z_mean = self.fc_mean(h) + z_mean = z_mean / z_mean.norm(dim=-1, keepdim=True) + # the `+ 1` prevent collapsing behaviors + z_var = F.softplus(self.fc_var(h)) + 1 + + q_z = VonMisesFisher(z_mean, z_var) + p_z = HypersphericalUniform(self.num - 1, device=device) + kld = torch.distributions.kl.kl_divergence(q_z, p_z).mean().to(device) + #print(q_z) + #mu, log_sigma = self.normal(h) + #identify how far it is away from normal distribution + + #print(kld.shape) + rec_loss = 0 + for i in range(n_sample): + #reparametrician trick + z = q_z.rsample() + #z = nn.Softmax()(z) + #decode + #print(z) + + z = self.h_to_z(self.temp * z) + self.output = z + #print(z) + + #get log probability for reconstruction loss + log_prob = self.topics(z) + rec_loss = rec_loss - (log_prob * x).sum(dim=-1) + #average reconstruction loss + rec_loss = rec_loss / n_sample + #print(rec_loss.shape) + minus_elbo = rec_loss + kld + penalty, var, mean = topic_covariance_penalty(self.topics.topic_emb) + if self.index is not None: + sinkhorn = optimal_transport_prior(self.topics.get_topics(), self.index) + else: + sinkhorn = 0 + + return { + 'loss': minus_elbo + penalty * self.penalty + sinkhorn * self.beta, + 'minus_elbo': minus_elbo, + 'rec_loss': rec_loss, + 'kld': kld + } + + def get_topics(self): + return self.topics.get_topics() + +def get_mlp(features, activate): + """features: mlp size of each layer, append activation in each layer except for the first layer.""" + if isinstance(activate, str): + activate = getattr(nn, activate) + layers = [] + for in_f, out_f in zip(features[:-1], features[1:]): + layers.append(nn.Linear(in_f, out_f)) + layers.append(activate()) + return nn.Sequential(*layers) + +class GSM(NTM): + def __init__(self, hidden, normal, h_to_z, topics, penalty): + # h_to_z will output probabilities over topics + super(GSM, self).__init__(hidden, normal, h_to_z, topics) + self.penalty = penalty + + def forward(self, x, device, n_sample=1): + stat = super(GSM, self).forward(x, n_sample) + loss = stat['loss'].to(device) + penalty, var, mean = topic_covariance_penalty(self.topics.topic_emb) + + stat.update({ + 'loss': loss #+ penalty.to(device) * self.penalty, + # 'penalty_mean': mean, + # 'penalty_var': var, + # 'penalty': penalty.to(device) * self.penalty, + }) + + return stat + +class Topics(nn.Module): + def __init__(self, k, vocab_size, bias=True): + super(Topics, self).__init__() + self.k = k + self.vocab_size = vocab_size + self.topic = nn.Linear(k, vocab_size, bias=bias) + + def forward(self, logit): + # return the log_prob of vocab distribution + return torch.log_softmax(self.topic(logit), dim=-1) + + def get_topics(self): + return torch.softmax(self.topic.weight.data.transpose(0, 1), dim=-1) + + def get_topic_word_logit(self): + """topic x V. 
+ Return the logits instead of probability distribution + """ + return self.topic.weight.transpose(0, 1) + + +class VONT(AbstractModel): + def __init__(self, epochs=20, batch_size=256, gpu_num=1, numb_embeddings=20, + learning_rate=0.002, weight_decay=1.2e-6, penalty=1, beta = 1, temp = 10, + top_n_words=20, num_representative_docs=5, top_n_topics=100, embedding_dim=100): + + self.dataset = None + self.epochs = epochs + self.batch_size = batch_size + self.gpu_num = gpu_num + self.numb_embeddings = numb_embeddings + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.penalty = penalty + self.top_n_words = top_n_words + self.num_representative_docs = num_representative_docs + self.top_n_topics = top_n_topics + self.embedding_dim = embedding_dim + self.device = torch.device("cpu") + self.beta = beta + self.temp = temp + self.z = None + self.model = None + + def train(self, X, batch_size): + self.model.train() + total_nll = 0.0 + total_kld = 0.0 + + indices = torch.randperm(X.shape[0]) + indices = torch.split(indices, batch_size) + length = len(indices) + for idx, ind in enumerate(indices): + data_batch = X[ind].to(self.device).float() + + d = self.model(x = data_batch, device = self.device) + + total_nll += d['rec_loss'].sum().item() / batch_size + total_kld += d['kld'].sum().item() / batch_size + loss = d['loss'] + + self.optimizer.zero_grad() + loss.sum().backward() + self.optimizer.step() + self.scheduler.step() + + print(total_nll/length, total_kld/length) + + def fit_transform(self, dataset, index = []): + self.dataset = dataset + self.tp = TextProcessor(self.dataset) + self.tp.process() + bag_of_words = torch.tensor(self.tp.bow) + if index != []: + index_words = [[self.tp.word_to_index[word] for word in ind if word in self.tp.word_to_index] for ind in index] + else: + index_words = None + print(index_words) + #print(bag_of_words.shape) + # rest of your initialization code here + layer = bag_of_words.shape[1]//16 + hidden = get_mlp([bag_of_words.shape[1], bag_of_words.shape[1]//4, layer], nn.GELU) + normal = NormalParameter(layer, self.numb_embeddings) + h_to_z = nn.Softmax() + embedding = nn.Embedding(bag_of_words.shape[1], 100) + # p1d = (0, 0, 0, 10000 - company1.embeddings.shape[0]) # pad last dim by 1 on each side + # out = F.pad(company1.embeddings, p1d, "constant", 0) # effectively zero padding + + glove_vectors = gensim.downloader.load('glove-wiki-gigaword-100') + embed = np.asarray([glove_vectors[self.tp.index_to_word[i]] if self.tp.index_to_word[i] in glove_vectors else np.asarray([1]*100) for i in self.tp.index_to_word ]) + print(embed.shape) + embedding.weight = torch.nn.Parameter(torch.from_numpy(embed).float()) + embedding.weight.requires_grad=True + + + + topics = EmbTopic(embedding = embedding, + k = self.numb_embeddings, normalize = False) + + + + + self.model = VNTM(hidden = hidden, + normal = normal, + h_to_z = h_to_z, + topics = topics, + layer = layer, + top_number = self.numb_embeddings, + index = index_words, + penalty = self.penalty, + beta = self.beta, + temp = self.temp, + ).to(self.device).float() + + #batch_size = 256 + self.optimizer = optim.Adam(self.model.parameters(), + lr=self.learning_rate, + weight_decay=self.weight_decay) + + + + + self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer, max_lr=0.002, steps_per_epoch=int(bag_of_words.shape[0]/self.batch_size) + 1, epochs=self.epochs) + + # Initialize and train your model + for epoch in range(self.epochs): + self.train(bag_of_words, self.batch_size) + + # Store the 
topics + emb = self.model.topics.get_topics().cpu().detach().numpy() + self.topics = [[self.tp.index_to_word[ind] for ind in np.argsort(emb[i])[::-1][:self.top_n_topics]] for i in range(self.numb_embeddings)] #100 can be specified + self.topics_score = [[score for score in np.sort(emb[i])[::-1]] for i in range(self.numb_embeddings)] + # Compute and store the documents-topics distributions + data_batch = bag_of_words.float() + self.model.cpu() + + z = self.model.hidden(data_batch) + z_mean = self.model.fc_mean(z) + z_mean = z_mean / z_mean.norm(dim=-1, keepdim=True) + self.z = self.model.h_to_z(z_mean).detach().numpy() + self.topic_doc = [[ind for ind in np.argsort(self.z[:, i])[::-1][:100] ] for i in range(self.numb_embeddings)] #100 can be specified + self.topic_doc_score = [[ind for ind in np.sort(self.z[:, i])[::-1][:100] ] for i in range(self.numb_embeddings)] #100 can be specified + + + return self.topics, self.z + + def get_topics(self, index): + return [(i, j) for i, j in zip(self.topics[index], self.topics_score[index])][:self.top_n_words] + + def get_representative_docs(self, index): + return [(self.dataset[i], j) for i, j in zip(self.topic_doc[index], self.topic_doc_score[index])][:self.num_representative_docs] + + def topic_word_matrix(self): + return self.model.topics.get_topics().cpu().detach().numpy() + + def topic_keywords(self): + return self.topics + +# def visualize_topic_similarity(self): +# # Compute m similarity matrix +# topic_word_matrix = self.model.topics.topic_emb.detach().numpy() +# similarity_matrix = np.matmul(topic_word_matrix, topic_word_matrix.T) + +# # Plot the similarity matrix as a heatmap +# plt.figure(figsize=(10, 10)) +# sns.heatmap(similarity_matrix, cmap="YlGnBu", square=True) +# plt.title('Topic Similarity Heatmap') +# plt.xlabel('Topic IDs') +# plt.ylabel('Topic IDs') +# plt.show() + + def visualize_topic_keywords(self, topic_id, num_keywords=10): + # Get top keywords for the given topic + topic_keywords = self.get_topics(topic_id)[:num_keywords] + words, scores = zip(*topic_keywords) + + # Generate the bar plot + plt.figure(figsize=(10, 5)) + plt.barh(words, scores, color='skyblue') + plt.xlabel("Keyword Importance") + plt.title(f"Top {num_keywords} Keywords for Topic {topic_id}") + plt.gca().invert_yaxis() + plt.show() + + def get_document_info(self, top_n_words=10): + data = [] + for topic_id in range(self.numb_embeddings): + topic_keywords = self.get_topics(topic_id)[:top_n_words] + topic_keywords_str = "_".join([word for word, _ in topic_keywords[:3]]) + + # Get the document that has the highest probability for this topic + doc_indices = np.argsort(self.z[:, topic_id])[::-1] + representative_doc_index = doc_indices[0] + representative_doc = self.dataset[representative_doc_index] + + # Count the number of documents that have this topic as their dominant topic + dominant_topics = np.argmax(self.z, axis=1) + num_docs = np.sum(dominant_topics == topic_id) + + data.append([topic_id, f"{topic_id}_{topic_keywords_str}", topic_keywords_str, representative_doc, num_docs]) + + df = pd.DataFrame(data, columns=["Topic", "Name", "Top_n_words", "Representative_Doc", "Num_Docs"]) + return df + + def train_model(self, dataset, hyperparameters={}, top_words=10): + + self.top_n_words = top_words + # Extract hyperparameters and set them as attributes + if 'epochs' in hyperparameters: + self.epochs = hyperparameters['epochs'] + if 'batch_size' in hyperparameters: + self.batch_size = hyperparameters['batch_size'] + if 'gpu_num' in hyperparameters: + self.gpu_num = 
hyperparameters['gpu_num'] + if 'numb_embeddings' in hyperparameters: + self.numb_embeddings = hyperparameters['numb_embeddings'] + if 'learning_rate' in hyperparameters: + self.learning_rate = hyperparameters['learning_rate'] + if 'weight_decay' in hyperparameters: + self.weight_decay = hyperparameters['weight_decay'] + if 'penalty' in hyperparameters: + self.penalty = hyperparameters['penalty'] + if 'beta' in hyperparameters: + self.beta = hyperparameters['beta'] + if 'temp' in hyperparameters: + self.temp = hyperparameters['temp'] + + if 'num_representative_docs' in hyperparameters: + self.num_representative_docs = hyperparameters['num_representative_docs'] + if 'top_n_topics' in hyperparameters: + self.top_n_topics = hyperparameters['top_n_topics'] + if 'embedding_dim' in hyperparameters: + self.embedding_dim = hyperparameters['embedding_dim'] + + # Check if the model has been trained + if self.z is None: + self.fit_transform(dataset) + + # Create the model output + model_output = {} + model_output['topics'] = [i[:top_words] for i in self.topics] + model_output['topic-word-matrix'] = self.model.topics.get_topics().cpu().detach().numpy() + model_output['topic-document-matrix'] = self.z.T + + return model_output + diff --git a/octis/models/vONTSS_model/hyperspherical_vae/__init__.py b/octis/models/vONTSS_model/hyperspherical_vae/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/octis/models/vONTSS_model/hyperspherical_vae/distributions/__init__.py b/octis/models/vONTSS_model/hyperspherical_vae/distributions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/octis/models/vONTSS_model/hyperspherical_vae/distributions/hyperspherical_uniform.py b/octis/models/vONTSS_model/hyperspherical_vae/distributions/hyperspherical_uniform.py new file mode 100644 index 00000000..520eeb28 --- /dev/null +++ b/octis/models/vONTSS_model/hyperspherical_vae/distributions/hyperspherical_uniform.py @@ -0,0 +1,56 @@ +import math +import torch + + +class HypersphericalUniform(torch.distributions.Distribution): + + support = torch.distributions.constraints.real + has_rsample = False + _mean_carrier_measure = 0 + + @property + def dim(self): + return self._dim + + @property + def device(self): + return self._device + + @device.setter + def device(self, val): + self._device = val if isinstance(val, torch.device) else torch.device(val) + + def __init__(self, dim, validate_args=None, device="cpu"): + super(HypersphericalUniform, self).__init__( + torch.Size([dim]), validate_args=validate_args + ) + self._dim = dim + self.device = device + + def sample(self, shape=torch.Size()): + output = ( + torch.distributions.Normal(0, 1) + .sample( + (shape if isinstance(shape, torch.Size) else torch.Size([shape])) + + torch.Size([self._dim + 1]) + ) + .to(self.device) + ) + + return output / output.norm(dim=-1, keepdim=True) + + def entropy(self): + return self.__log_surface_area() + + def log_prob(self, x): + return -torch.ones(x.shape[:-1], device=self.device) * self.__log_surface_area() + + def __log_surface_area(self): + if torch.__version__ >= "1.0.0": + lgamma = torch.lgamma(torch.tensor([(self._dim + 1) / 2]).to(self.device)) + else: + lgamma = torch.lgamma( + torch.Tensor([(self._dim + 1) / 2], device=self.device) + ) + return math.log(2) + ((self._dim + 1) / 2) * math.log(math.pi) - lgamma + \ No newline at end of file diff --git a/octis/models/vONTSS_model/hyperspherical_vae/distributions/von_mises_fisher.py 
b/octis/models/vONTSS_model/hyperspherical_vae/distributions/von_mises_fisher.py new file mode 100644 index 00000000..7c3d8d84 --- /dev/null +++ b/octis/models/vONTSS_model/hyperspherical_vae/distributions/von_mises_fisher.py @@ -0,0 +1,343 @@ +import math +import torch +from torch.distributions.kl import register_kl + +import math +import torch + + +class HypersphericalUniform(torch.distributions.Distribution): + + support = torch.distributions.constraints.real + has_rsample = False + _mean_carrier_measure = 0 + + @property + def dim(self): + return self._dim + + @property + def device(self): + return self._device + + @device.setter + def device(self, val): + self._device = val if isinstance(val, torch.device) else torch.device(val) + + def __init__(self, dim, validate_args=None, device="cpu"): + super(HypersphericalUniform, self).__init__( + torch.Size([dim]), validate_args=validate_args + ) + self._dim = dim + self.device = device + + def sample(self, shape=torch.Size()): + output = ( + torch.distributions.Normal(0, 1) + .sample( + (shape if isinstance(shape, torch.Size) else torch.Size([shape])) + + torch.Size([self._dim + 1]) + ) + .to(self.device) + ) + + return output / output.norm(dim=-1, keepdim=True) + + def entropy(self): + return self.__log_surface_area() + + def log_prob(self, x): + return -torch.ones(x.shape[:-1], device=self.device) * self.__log_surface_area() + + def __log_surface_area(self): + if torch.__version__ >= "1.0.0": + lgamma = torch.lgamma(torch.tensor([(self._dim + 1) / 2]).to(self.device)) + else: + lgamma = torch.lgamma( + torch.Tensor([(self._dim + 1) / 2], device=self.device) + ) + return math.log(2) + ((self._dim + 1) / 2) * math.log(math.pi) - lgamma + +class VonMisesFisher(torch.distributions.Distribution): + + arg_constraints = { + "loc": torch.distributions.constraints.real, + "scale": torch.distributions.constraints.positive, + } + support = torch.distributions.constraints.real + has_rsample = True + _mean_carrier_measure = 0 + + @property + def mean(self): + # option 1: + # return self.loc * ( + # ive(self.__m / 2, self.scale) / ive(self.__m / 2 - 1, self.scale) + # ) + # option 2: + return self.loc * ive_fraction_approx(torch.tensor(self.__m / 2), self.scale) + # options 3: + # return self.loc * ive_fraction_approx2(torch.tensor(self.__m / 2), self.scale) + + @property + def stddev(self): + return self.scale + + def __init__(self, loc, scale, validate_args=None, k=1): + self.dtype = loc.dtype + self.loc = loc + self.scale = scale + self.device = loc.device + self.__m = loc.shape[-1] + self.__e1 = (torch.Tensor([1.0] + [0] * (loc.shape[-1] - 1))).to(self.device) + self.k = k + + super().__init__(self.loc.size(), validate_args=validate_args) + + def sample(self, shape=torch.Size()): + with torch.no_grad(): + return self.rsample(shape) + + def rsample(self, shape=torch.Size()): + shape = shape if isinstance(shape, torch.Size) else torch.Size([shape]) + + w = ( + self.__sample_w3(shape=shape) + if self.__m == 3 + else self.__sample_w_rej(shape=shape) + ) + + v = ( + torch.distributions.Normal(0, 1) + .sample(shape + torch.Size(self.loc.shape)) + .to(self.device) + .transpose(0, -1)[1:] + ).transpose(0, -1) + v = v / v.norm(dim=-1, keepdim=True) + + w_ = torch.sqrt(torch.clamp(1 - (w ** 2), 1e-10)) + x = torch.cat((w, w_ * v), -1) + z = self.__householder_rotation(x) + + return z.type(self.dtype) + + def __sample_w3(self, shape): + shape = shape + torch.Size(self.scale.shape) + u = torch.distributions.Uniform(0, 1).sample(shape).to(self.device) + self.__w 
= ( + 1 + + torch.stack( + [torch.log(u), torch.log(1 - u) - 2 * self.scale], dim=0 + ).logsumexp(0) + / self.scale + ) + return self.__w + + def __sample_w_rej(self, shape): + c = torch.sqrt((4 * (self.scale ** 2)) + (self.__m - 1) ** 2) + b_true = (-2 * self.scale + c) / (self.__m - 1) + + # using Taylor approximation with a smooth swift from 10 < scale < 11 + # to avoid numerical errors for large scale + b_app = (self.__m - 1) / (4 * self.scale) + s = torch.min( + torch.max( + torch.tensor([0.0], dtype=self.dtype, device=self.device), + self.scale - 10, + ), + torch.tensor([1.0], dtype=self.dtype, device=self.device), + ) + b = b_app * s + b_true * (1 - s) + + a = (self.__m - 1 + 2 * self.scale + c) / 4 + d = (4 * a * b) / (1 + b) - (self.__m - 1) * math.log(self.__m - 1) + + self.__b, (self.__e, self.__w) = b, self.__while_loop(b, a, d, shape, k=self.k) + return self.__w + + @staticmethod + def first_nonzero(x, dim, invalid_val=-1): + mask = x > 0 + idx = torch.where( + mask.any(dim=dim), + mask.float().argmax(dim=1).squeeze(), + torch.tensor(invalid_val, device=x.device), + ) + return idx + + def __while_loop(self, b, a, d, shape, k=20, eps=1e-20): + # matrix while loop: samples a matrix of [A, k] samples, to avoid looping all together + b, a, d = [ + e.repeat(*shape, *([1] * len(self.scale.shape))).reshape(-1, 1) + for e in (b, a, d) + ] + w, e, bool_mask = ( + torch.zeros_like(b).to(self.device), + torch.zeros_like(b).to(self.device), + (torch.ones_like(b) == 1).to(self.device), + ) + + sample_shape = torch.Size([b.shape[0], k]) + shape = shape + torch.Size(self.scale.shape) + + while bool_mask.sum() != 0: + con1 = torch.tensor((self.__m - 1) / 2, dtype=torch.float64) + con2 = torch.tensor((self.__m - 1) / 2, dtype=torch.float64) + e_ = ( + torch.distributions.Beta(con1, con2) + .sample(sample_shape) + .to(self.device) + .type(self.dtype) + ) + + u = ( + torch.distributions.Uniform(0 + eps, 1 - eps) + .sample(sample_shape) + .to(self.device) + .type(self.dtype) + ) + + w_ = (1 - (1 + b) * e_) / (1 - (1 - b) * e_) + t = (2 * a * b) / (1 - (1 - b) * e_) + + accept = ((self.__m - 1.0) * t.log() - t + d) > torch.log(u) + accept_idx = self.first_nonzero(accept, dim=-1, invalid_val=-1).unsqueeze(1) + accept_idx_clamped = accept_idx.clamp(0) + # we use .abs(), in order to not get -1 index issues, the -1 is still used afterwards + w_ = w_.gather(1, accept_idx_clamped.view(-1, 1)) + e_ = e_.gather(1, accept_idx_clamped.view(-1, 1)) + + reject = accept_idx < 0 + accept = ~reject if torch.__version__ >= "1.2.0" else 1 - reject + + w[bool_mask * accept] = w_[bool_mask * accept] + e[bool_mask * accept] = e_[bool_mask * accept] + + bool_mask[bool_mask * accept] = reject[bool_mask * accept] + + return e.reshape(shape), w.reshape(shape) + + def __householder_rotation(self, x): + u = self.__e1 - self.loc + u = u / (u.norm(dim=-1, keepdim=True) + 1e-5) + z = x - 2 * (x * u).sum(-1, keepdim=True) * u + return z + + def entropy(self): + # option 1: + # output = ( + # -self.scale + # * ive(self.__m / 2, self.scale) + # / ive((self.__m / 2) - 1, self.scale) + # ) + # option 2: + output = - self.scale * ive_fraction_approx(torch.tensor(self.__m / 2), self.scale) + # option 3: + # output = - self.scale * ive_fraction_approx2(torch.tensor(self.__m / 2), self.scale) + + return output.view(*(output.shape[:-1])) #+ self._log_normalization() + + def log_prob(self, x): + return self._log_unnormalized_prob(x) - self._log_normalization() + + def _log_unnormalized_prob(self, x): + output = self.scale * (self.loc * 
x).sum(-1, keepdim=True) + + return output.view(*(output.shape[:-1])) + + def _log_normalization(self): + output = -( + (self.__m / 2 - 1) * torch.log(self.scale) + - (self.__m / 2) * math.log(2 * math.pi) + - (self.scale + torch.log(ive(self.__m / 2 - 1, self.scale))) + ) + + return output.view(*(output.shape[:-1])) + + +@register_kl(VonMisesFisher, HypersphericalUniform) +def _kl_vmf_uniform(vmf, hyu): + #print(vmf.entropy() , hyu.entropy()) + return -vmf.entropy() + hyu.entropy() + + + +import torch +import numpy as np +import scipy.special +from numbers import Number + + +class IveFunction(torch.autograd.Function): + @staticmethod + def forward(self, v, z): + + assert isinstance(v, Number), "v must be a scalar" + + self.save_for_backward(z) + self.v = v + z_cpu = z.data.cpu().numpy() + + if np.isclose(v, 0): + output = scipy.special.i0e(z_cpu, dtype=z_cpu.dtype) + elif np.isclose(v, 1): + output = scipy.special.i1e(z_cpu, dtype=z_cpu.dtype) + else: # v > 0 + output = scipy.special.ive(v, z_cpu, dtype=z_cpu.dtype) + # else: + # print(v, type(v), np.isclose(v, 0)) + # raise RuntimeError('v must be >= 0, it is {}'.format(v)) + + return torch.Tensor(output).to(z.device) + + @staticmethod + def backward(self, grad_output): + z = self.saved_tensors[-1] + return ( + None, + grad_output * (ive(self.v - 1, z) - ive(self.v, z) * (self.v + z) / z), + ) + + +class Ive(torch.nn.Module): + def __init__(self, v): + super(Ive, self).__init__() + self.v = v + + def forward(self, z): + return ive(self.v, z) + + +ive = IveFunction.apply + + +# ######### +# The below provided approximations were provided in the +# respective source papers, to improve the stability of +# the Bessel fractions. +# I_(v/2)(k) / I_(v/2 - 1)(k) + +# source: https://arxiv.org/pdf/1606.02008.pdf +def ive_fraction_approx(v, z): + # I_(v/2)(k) / I_(v/2 - 1)(k) >= z / (v-1 + ((v+1)^2 + z^2)^0.5 + return z / (v - 1 + torch.pow(torch.pow(v + 1, 2) + torch.pow(z, 2), 0.5)) + + +# source: https://arxiv.org/pdf/1902.02603.pdf +def ive_fraction_approx2(v, z, eps=1e-20): + def delta_a(a): + lamb = v + (a - 1.0) / 2.0 + return (v - 0.5) + lamb / ( + 2 * torch.sqrt((torch.pow(lamb, 2) + torch.pow(z, 2)).clamp(eps)) + ) + + delta_0 = delta_a(0.0) + delta_2 = delta_a(2.0) + B_0 = z / ( + delta_0 + torch.sqrt((torch.pow(delta_0, 2) + torch.pow(z, 2))).clamp(eps) + ) + B_2 = z / ( + delta_2 + torch.sqrt((torch.pow(delta_2, 2) + torch.pow(z, 2))).clamp(eps) + ) + + return (B_0 + B_2) / 2.0 diff --git a/octis/models/vONTSS_model/hyperspherical_vae/ops/__init__.py b/octis/models/vONTSS_model/hyperspherical_vae/ops/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/octis/models/vONTSS_model/hyperspherical_vae/ops/ive.py b/octis/models/vONTSS_model/hyperspherical_vae/ops/ive.py new file mode 100644 index 00000000..63b1d05f --- /dev/null +++ b/octis/models/vONTSS_model/hyperspherical_vae/ops/ive.py @@ -0,0 +1,79 @@ +import torch +import numpy as np +import scipy.special +from numbers import Number + + +class IveFunction(torch.autograd.Function): + @staticmethod + def forward(self, v, z): + + assert isinstance(v, Number), "v must be a scalar" + + self.save_for_backward(z) + self.v = v + z_cpu = z.data.cpu().numpy() + + if np.isclose(v, 0): + output = scipy.special.i0e(z_cpu, dtype=z_cpu.dtype) + elif np.isclose(v, 1): + output = scipy.special.i1e(z_cpu, dtype=z_cpu.dtype) + else: # v > 0 + output = scipy.special.ive(v, z_cpu, dtype=z_cpu.dtype) + # else: + # print(v, type(v), np.isclose(v, 0)) + # raise RuntimeError('v must be >= 0, 
it is {}'.format(v)) + + return torch.Tensor(output).to(z.device) + + @staticmethod + def backward(self, grad_output): + z = self.saved_tensors[-1] + return ( + None, + grad_output * (ive(self.v - 1, z) - ive(self.v, z) * (self.v + z) / z), + ) + + +class Ive(torch.nn.Module): + def __init__(self, v): + super(Ive, self).__init__() + self.v = v + + def forward(self, z): + return ive(self.v, z) + + +ive = IveFunction.apply + + +########## +# The below provided approximations were provided in the +# respective source papers, to improve the stability of +# the Bessel fractions. +# I_(v/2)(k) / I_(v/2 - 1)(k) + +# source: https://arxiv.org/pdf/1606.02008.pdf +def ive_fraction_approx(v, z): + # I_(v/2)(k) / I_(v/2 - 1)(k) >= z / (v-1 + ((v+1)^2 + z^2)^0.5 + return z / (v - 1 + torch.pow(torch.pow(v + 1, 2) + torch.pow(z, 2), 0.5)) + + +# source: https://arxiv.org/pdf/1902.02603.pdf +def ive_fraction_approx2(v, z, eps=1e-20): + def delta_a(a): + lamb = v + (a - 1.0) / 2.0 + return (v - 0.5) + lamb / ( + 2 * torch.sqrt((torch.pow(lamb, 2) + torch.pow(z, 2)).clamp(eps)) + ) + + delta_0 = delta_a(0.0) + delta_2 = delta_a(2.0) + B_0 = z / ( + delta_0 + torch.sqrt((torch.pow(delta_0, 2) + torch.pow(z, 2))).clamp(eps) + ) + B_2 = z / ( + delta_2 + torch.sqrt((torch.pow(delta_2, 2) + torch.pow(z, 2))).clamp(eps) + ) + + return (B_0 + B_2) / 2.0 diff --git a/octis/models/vONTSS_model/preprocess.py b/octis/models/vONTSS_model/preprocess.py new file mode 100644 index 00000000..81477742 --- /dev/null +++ b/octis/models/vONTSS_model/preprocess.py @@ -0,0 +1,142 @@ +import argparse +from collections import defaultdict +from datasets import load_dataset, Dataset +from itertools import chain +import multiprocessing +from multiprocessing import Pool +import nltk +from nltk.corpus import stopwords, wordnet +from nltk.stem import WordNetLemmatizer +from nltk.tokenize import word_tokenize +from nltk import pos_tag +import numpy as np +import pandas as pd +from sklearn.feature_extraction.text import TfidfVectorizer +import string +import pickle + +class TextProcessor: + + def __init__(self, data): + self.data = data + self.bow = None + self.word_to_index = None + self.index_to_word = None + self.lemmatized_sentences = None + nltk.download('punkt') + nltk.download('wordnet') + nltk.download('averaged_perceptron_tagger') + nltk.download('stopwords') + + def __str__(self): + """String representation of TextProcessor""" + return f'TextProcessor(len(data)={len(self.data)})' + + + def get_wordnet_pos(self, word): + tag = pos_tag([word])[0][1][0].upper() + tag_dict = {"J": wordnet.ADJ, + "N": wordnet.NOUN, + "V": wordnet.VERB, + "R": wordnet.ADV} + + return tag_dict.get(tag, wordnet.NOUN) + + def convert_to_bag_of_words(self, list_of_lists, min_freq=10, max_freq_ratio=0.05): + # Your existing implementation here + word_freq = defaultdict(int) + for lst in list_of_lists: + for word in lst: + word_freq[word] += 1 + max_freq = len(list_of_lists) * max_freq_ratio + vocabulary = {word for word, count in word_freq.items() if min_freq <= count < max_freq} + word_to_index = {word: i for i, word in enumerate(vocabulary)} + index_to_word = {i: word for word, i in word_to_index.items()} + num_lists = len(list_of_lists) + vocab_size = len(vocabulary) + bag_of_words = [[0] * vocab_size for _ in range(num_lists)] + self.lemmas = [] + for i, lst in enumerate(list_of_lists): + lemma = [] + for word in lst: + lemma = [word for word in lst if word in word_to_index] + if word in word_to_index: + index = word_to_index[word] + 
bag_of_words[i][index] += 1 + + self.lemmas.append(lemma) + + self.bow, self.word_to_index, self.index_to_word = bag_of_words, word_to_index, index_to_word + + def extract_important_words(self, tfidf_vector, feature_names): + # Your existing implementation here + coo_matrix = tfidf_vector.tocoo() + sorted_items = sorted(zip(coo_matrix.col, coo_matrix.data), key=lambda x: (x[1], x[0]), reverse=True) + + return self.extract_topn_from_vector(feature_names, sorted_items) + + def extract_topn_from_vector(self, feature_names, sorted_items, topn=20): + sorted_items = sorted_items[:topn] + + score_vals = [] + feature_vals = [] + + for idx, score in sorted_items: + score_vals.append(round(score, 3)) + feature_vals.append(feature_names[idx]) + + results = {} + for idx in range(len(feature_vals)): + results[feature_vals[idx]] = score_vals[idx] + + return results + + # def lemmatize_sentences(self): + # lemmatizer = WordNetLemmatizer() + # lemmatized_sentences = [] + # stop_words = set(stopwords.words('english')) + # table = str.maketrans(string.punctuation, ' ' * len(string.punctuation)) + + # for index, sentence in enumerate(self.data): + # if index % 100 == 0: + # print(index) + # sentence = sentence.translate(table).lower().replace(" ", " ") + + # words = word_tokenize(sentence) + # lemmatized_words = [lemmatizer.lemmatize(word, self.get_wordnet_pos(word)) for word in words if word not in stop_words and lemmatizer.lemmatize(word, self.get_wordnet_pos(word)) != '' ] + # lemmatized_words = [i for i in lemmatized_words if len(i) >= 3 and (not i.isdigit()) and ' ' not in i] + # lemmatized_sentences.append(lemmatized_words) + + # self.lemmatized_sentences = lemmatized_sentences + + + def worker(self, data_chunk): + lemmatizer = WordNetLemmatizer() + stop_words = set(stopwords.words('english')) + table = str.maketrans(string.punctuation, ' ' * len(string.punctuation)) + lemmatized_chunk = [] + + for sentence in data_chunk: + sentence = sentence.translate(table).lower().replace(" ", " ") + words = word_tokenize(sentence) + lemmatized_words = [lemmatizer.lemmatize(word, self.get_wordnet_pos(word)) for word in words if word not in stop_words and lemmatizer.lemmatize(word, self.get_wordnet_pos(word)) != '' ] + lemmatized_words = [i for i in lemmatized_words if len(i) >= 3 and (not i.isdigit()) and ' ' not in i] + lemmatized_chunk.append(lemmatized_words) + + return lemmatized_chunk + + def lemmatize_sentences(self): + num_processes = multiprocessing.cpu_count() + pool = Pool(num_processes) + data_chunks = np.array_split(self.data, num_processes) + results = pool.map(self.worker, data_chunks) + pool.close() + pool.join() + #print(results) + self.lemmatized_sentences = [j for i in results for j in i] + + print(len(results), len(results[0])) + + def process(self): + self.lemmatize_sentences() + self.convert_to_bag_of_words(self.lemmatized_sentences) \ No newline at end of file diff --git a/octis/models/vONTSS_model/utils.py b/octis/models/vONTSS_model/utils.py new file mode 100644 index 00000000..815a78c7 --- /dev/null +++ b/octis/models/vONTSS_model/utils.py @@ -0,0 +1,12 @@ +import torch + +def kld_normal(mu, log_sigma): + """KL divergence to standard normal distribution. 
+ mu: batch_size x dim + log_sigma: batch_size x dim + """ + #normal distribution KL divergence of two gaussian + #https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians + return -0.5 * (1 - mu ** 2 + 2 * log_sigma - torch.exp(2 * log_sigma)).sum(dim=-1) + + diff --git a/requirements.txt b/requirements.txt index 9f672164..5f476fa9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ flask sentence_transformers requests tomotopy +datasets==2.11.0
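A minimal usage sketch of the new model, for reviewers (not part of the diff). It assumes the caller passes a plain list of raw document strings, since fit_transform() indexes self.dataset[i] directly and hands the list to TextProcessor, and it assumes gensim can fetch the glove-wiki-gigaword-100 vectors. The 20 Newsgroups slice is only an example corpus, chosen so the min_freq=10 vocabulary filter in TextProcessor keeps a usable vocabulary.

    from sklearn.datasets import fetch_20newsgroups
    from octis.models.VONT import VONT

    # Any list of raw document strings works; 20 Newsgroups is used here purely as an example.
    docs = fetch_20newsgroups(subset="train", remove=("headers", "footers", "quotes")).data[:2000]

    model = VONT(numb_embeddings=10, epochs=5, batch_size=64)
    output = model.train_model(docs, hyperparameters={"learning_rate": 0.002}, top_words=10)

    print(output["topics"][0])                    # ten keywords for topic 0
    print(output["topic-word-matrix"].shape)      # (num_topics, vocabulary_size)
    print(output["topic-document-matrix"].shape)  # (num_topics, num_documents)

The keys of the returned dictionary ("topics", "topic-word-matrix", "topic-document-matrix") follow the model_output contract built in train_model above, so the result can be fed to the OCTIS evaluation metrics unchanged.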