Create the out-of-vocabulary accumulator in word_vec() explicitly as float32
(np.zeros defaults to float64), and add a test asserting that OOV and
in-vocabulary lookups return embeddings of the same dtype.

diff --git a/gensim/models/wrappers/fasttext.py b/gensim/models/wrappers/fasttext.py
index 76fdb0e34b..bfbb4986ee 100644
--- a/gensim/models/wrappers/fasttext.py
+++ b/gensim/models/wrappers/fasttext.py
@@ -90,7 +90,7 @@ def word_vec(self, word, use_norm=False):
         if word in self.vocab:
             return super(FastTextKeyedVectors, self).word_vec(word, use_norm)
         else:
-            word_vec = np.zeros(self.syn0_ngrams.shape[1])
+            word_vec = np.zeros(self.syn0_ngrams.shape[1], dtype=np.float32)
             ngrams = compute_ngrams(word, self.min_n, self.max_n)
             ngrams = [ng for ng in ngrams if ng in self.ngrams]
             if use_norm:
diff --git a/gensim/test/test_fasttext_wrapper.py b/gensim/test/test_fasttext_wrapper.py
index 7bd73e5fd2..a673eca490 100644
--- a/gensim/test/test_fasttext_wrapper.py
+++ b/gensim/test/test_fasttext_wrapper.py
@@ -337,6 +337,17 @@ def testHash(self):
         ft_hash = fasttext.ft_hash('word')
         self.assertEqual(ft_hash, 1788406269)
 
+    def testConsistentDtype(self):
+        """Test that the same dtype is returned for OOV words as for words in the vocabulary"""
+        vocab_word = 'night'
+        oov_word = 'wordnotpresentinvocabulary'
+        self.assertIn(vocab_word, self.test_model.wv.vocab)
+        self.assertNotIn(oov_word, self.test_model.wv.vocab)
+
+        vocab_embedding = self.test_model[vocab_word]
+        oov_embedding = self.test_model[oov_word]
+        self.assertEqual(vocab_embedding.dtype, oov_embedding.dtype)
+
 
 if __name__ == '__main__':
     logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)