yiyixuxu/n-grammer-flax

n-grammer-flax

Implementation of N-Grammer: Augmenting Transformers with latent n-grams in Flax

Usage

from n_grammer_flax.n_grammer_flax import PQNgrammer
import jax

# Separate PRNG keys for the input, the parameters, and the batch statistics
key0, key1, key2 = jax.random.split(jax.random.PRNGKey(0), 3)

init_rngs = {'params': key1,
             'batch_stats': key2}

# Input of shape (batch, sequence length, num_heads * dim_per_head)
x = jax.random.normal(key0, shape=(1, 1024, 32 * 16))

pq_ngram = PQNgrammer(
    num_clusters = 1024, # number of clusters
    num_heads = 16, # number of attention heads
    dim_per_head = 32, # dimensions of each attention head
    ngram_vocab_size = 768 * 256, # n-gram vocabulary size
    ngram_emb_dim = 16, # n-gram embedding dimension
    decay = 0.99)

init_variables = pq_ngram.init(init_rngs, x)
# 'batch_stats' is marked mutable, so apply also returns the updated collection
out, mutated_variables = pq_ngram.apply(init_variables, x, mutable=['batch_stats'])

print('mutated variables.shape:\n', jax.tree_util.tree_map(lambda x: x.shape, mutated_variables))
print('output.shape:\n', out.shape)

Output:

mutated variables.shape:
 FrozenDict({
    batch_stats: {
        ProductQuantization_0: {
            means: (32, 1024, 16),
        },
    },
})
output.shape:
 (1, 1024, 512)
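
For readers unfamiliar with the idea behind the module, the following is a minimal sketch of the two steps N-Grammer builds on: assigning each head's vector to its nearest cluster centre (product quantization), then combining neighbouring cluster ids into bigram ids. It is not this repository's implementation. The function names `assign_clusters` and `bigram_ids` are hypothetical, the `(num_heads, num_clusters, dim_per_head)` layout for the means is an assumption made for the sketch, and the plain modulo used to fold bigram ids into the n-gram vocabulary is a simplified stand-in for a proper hash function.

```python
import jax
import jax.numpy as jnp


def assign_clusters(x, means):
    """Return the nearest-centroid id for every position and head.

    x:     (batch, seq_len, num_heads, dim_per_head)
    means: (num_heads, num_clusters, dim_per_head)  # layout assumed for this sketch
    """
    # Squared Euclidean distance, expanded as |x|^2 - 2 x.m + |m|^2
    x_sq = jnp.sum(x ** 2, axis=-1, keepdims=True)       # (b, n, h, 1)
    m_sq = jnp.sum(means ** 2, axis=-1)                   # (h, k)
    dots = jnp.einsum('bnhd,hkd->bnhk', x, means)         # (b, n, h, k)
    dists = x_sq - 2.0 * dots + m_sq[None, None, :, :]
    return jnp.argmin(dists, axis=-1).astype(jnp.int32)   # (b, n, h)


def bigram_ids(cluster_ids, num_clusters, ngram_vocab_size):
    """Combine each cluster id with the id at the previous position."""
    # Shift ids right by one along the sequence axis; position 0 pairs with id 0
    prev = jnp.pad(cluster_ids, ((0, 0), (1, 0), (0, 0)))[:, :-1, :]
    combined = cluster_ids + prev * num_clusters
    # Simplified stand-in for hashing into the n-gram vocabulary (assumption)
    return combined % ngram_vocab_size


# Small illustrative sizes, chosen only to keep the example fast
key_x, key_m = jax.random.split(jax.random.PRNGKey(0))
num_heads, dim_per_head, num_clusters = 4, 8, 16
x = jax.random.normal(key_x, (2, 10, num_heads * dim_per_head))
means = jax.random.normal(key_m, (num_heads, num_clusters, dim_per_head))

x_heads = x.reshape(2, 10, num_heads, dim_per_head)  # split features into heads
ids = assign_clusters(x_heads, means)
ngrams = bigram_ids(ids, num_clusters, ngram_vocab_size=64)
print(ids.shape, ngrams.shape)
```

In a full model the bigram ids would index an embedding table of size `ngram_vocab_size`, and the looked-up n-gram embeddings would be combined with the original token representation.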

Acknowledgement

This project was made possible by the TRC program. Thank you, Google!

Reference

Thanks to lucidrains for his concise implementation of N-Grammer in PyTorch. n-grammer-flax is inspired by and tested against his project (see n_grammer_flax_test.py for details).

It is also inspired by the official JAX implementation: https://github.com/tensorflow/lingvo/tree/master/lingvo/jax

Citations

@inproceedings{roy2022ngrammer,
    title   = {N-grammer: Augmenting Transformers with latent n-grams},
    author  = {Aurko Roy and Rohan Anil and Guangda Lai and Benjamin Lee and Jeffrey Zhao and Shuyuan Zhang and Shibo Wang and Ye Zhang and Shen Wu and Rigel Swavely and Tao (Alex) Yu and Phuong Dao and Christopher Fifty and Zhifeng Chen and Yonghui Wu},
    year    = {2022},
    url     = {https://arxiv.org/abs/2207.06366}
}
