This is a PyTorch (and HuggingFace) implementation of the GAN-BERT method from https://github.com/crux82/ganbert, which was originally released in TensorFlow. While the original GAN-BERT was an extension of BERT, this implementation can be adapted to several architectures, ranging from RoBERTa to ALBERT!
IMPORTANT: Since this implementation differs slightly from the original TensorFlow one, some results may vary. Any feedback or suggestions for improving this first version would be appreciated.
This is the code for the paper "GAN-BERT: Generative Adversarial Learning for Robust Text Classification with a Bunch of Labeled Examples", published as a short paper at ACL 2020 by Danilo Croce (Tor Vergata, University of Rome), Giuseppe Castellucci (Amazon) and Roberto Basili (Tor Vergata, University of Rome).
GAN-BERT is an extension of BERT which uses a Generative Adversarial setting to implement an effective semi-supervised learning schema. It allows training BERT with datasets composed of a limited amount of labeled examples and larger subsets of unlabeled material. GAN-BERT can be used in sequence classification tasks (also involving text pairs).
As in the original TensorFlow implementation, this code runs the GAN-BERT experiment over the TREC dataset for the fine-grained Question Classification task. This package provides both the code and the data for running an experiment that uses 2% of the labeled material (109 examples) and 5343 unlabeled examples. The test set is composed of 500 annotated examples.
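The 2% figure follows from the sizes above: 109 labeled plus 5343 unlabeled examples give a pool of 5452, and 2% of that rounds to 109. The snippet below is a minimal sketch of how such a split could be produced. It is not the procedure used to build the files shipped with this package; the function name `split_labeled_unlabeled`, the seed, and the placeholder data are all assumptions made for illustration.

```python
import random

def split_labeled_unlabeled(examples, labeled_ratio, seed=42):
    """Keep labels for a random `labeled_ratio` share; hide them for the rest."""
    rng = random.Random(seed)
    shuffled = examples[:]
    rng.shuffle(shuffled)
    n_labeled = round(len(shuffled) * labeled_ratio)
    labeled = shuffled[:n_labeled]
    # Unlabeled examples keep their text but lose their label.
    unlabeled = [(text, None) for text, _ in shuffled[n_labeled:]]
    return labeled, unlabeled

# Placeholder data with the same size as the pool described above (109 + 5343).
pool = [(f"question {i}", f"class_{i % 5}") for i in range(5452)]
labeled, unlabeled = split_labeled_unlabeled(pool, labeled_ratio=0.02)
print(len(labeled), len(unlabeled))  # 109 5343
```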
GAN-BERT is an extension of the BERT model within the Generative Adversarial Network (GAN) framework (Goodfellow et al., 2014). In particular, the Semi-Supervised GAN (Salimans et al., 2016) is used to make BERT fine-tuning robust in training scenarios where annotated material is hard to obtain. When fine-tuned with very few labeled examples, the BERT model does not reach satisfactory performance. With GAN-BERT we extend the fine-tuning stage by introducing a Discriminator-Generator setting, where:
- the Generator G is devoted to producing "fake" vector representations of sentences;
- the Discriminator D is a BERT-based classifier over k+1 categories.
D classifies each example with respect to the k categories of the task of interest, and it should also recognize the examples generated by G (the k+1 category). G, instead, must produce representations that are as similar as possible to the ones the model produces for the "real" examples. G is penalized when D correctly classifies an example as fake.
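To make the k+1 output concrete, here is a minimal sketch of how the Discriminator's logits could be read. It uses plain Python rather than PyTorch tensors to stay dependency-free, and it assumes the fake category sits in the last position; the helper name `read_discriminator_output` and the logit values are hypothetical, not taken from this repository.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def read_discriminator_output(logits):
    """Interpret k+1 logits: indices 0..k-1 are task classes, index k is 'fake'."""
    probs = softmax(logits)
    k = len(logits) - 1
    p_fake = probs[k]
    # The task prediction ignores the fake slot: argmax over the first k entries.
    task_label = max(range(k), key=lambda i: probs[i])
    # D flags an example as generated when 'fake' is the most probable category.
    flagged_fake = p_fake == max(probs)
    return task_label, p_fake, flagged_fake

# Hypothetical logits for k = 3 task classes plus the fake category.
real_like = [2.0, 0.5, 0.1, -1.0]
fake_like = [0.1, 0.2, 0.0, 3.0]
print(read_discriminator_output(real_like))
print(read_discriminator_output(fake_like))
```

At inference time only the task label matters, since every input is a real sentence; the fake probability is used during training.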
In this context, the model is trained on both labeled and unlabeled examples. The labeled examples contribute to the loss computed over the k task categories. The unlabeled examples contribute to the loss because they should not be incorrectly classified as belonging to the k+1 category (i.e., the fake category).
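The loss terms described above can be sketched as follows, in the spirit of the Semi-Supervised GAN formulation. This is an illustration in plain Python, not the code of this repository: the function names, the `EPS` guard, and the use of `None` to mark unlabeled examples are assumptions. The sketch applies the "should not look fake" term to every real example, labeled or not, and covers only the terms mentioned in the text.

```python
import math

EPS = 1e-8  # guards the logarithms; the value is an assumption of this sketch

def softmax(logits):
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def mean(values):
    return sum(values) / len(values)

def discriminator_loss(real_logits, labels, fake_logits):
    """real_logits / fake_logits: lists of (k+1)-long logit lists.
    labels: class index in [0, k) for labeled examples, None for unlabeled ones."""
    # Supervised term: cross-entropy over the k task classes, labeled examples only.
    supervised = [
        -math.log(softmax(logits[:-1])[y] + EPS)
        for logits, y in zip(real_logits, labels)
        if y is not None
    ]
    sup_loss = mean(supervised) if supervised else 0.0
    # Unsupervised term 1: real examples should not be assigned to the fake category.
    real_term = mean([-math.log(1.0 - softmax(l)[-1] + EPS) for l in real_logits])
    # Unsupervised term 2: generated examples should be assigned to the fake category.
    fake_term = mean([-math.log(softmax(l)[-1] + EPS) for l in fake_logits])
    return sup_loss + real_term + fake_term

def generator_loss(fake_logits):
    """G is penalized when D assigns its outputs to the fake category."""
    return mean([-math.log(1.0 - softmax(l)[-1] + EPS) for l in fake_logits])

# Hypothetical batch: k = 2 task classes plus the fake category.
real = [[3.0, 0.0, -2.0], [0.0, 0.0, 0.0]]
labels = [0, None]  # the second real example is unlabeled
fake = [[0.0, 0.0, 2.0]]
print(discriminator_loss(real, labels, fake))
print(generator_loss(fake))
```

Because an unlabeled example has no entry in the supervised term, it influences D only through the requirement that it not be classified as fake.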
The resulting model has been shown to learn text classification tasks starting from very few labeled examples (50-60 examples) and to outperform classical fine-tuned BERT models by a large margin in this setting.
More details are available at https://github.com/crux82/ganbert
If this software is useful for your research, please cite the following paper:
@inproceedings{croce-etal-2020-gan,
title = "{GAN}-{BERT}: Generative Adversarial Learning for Robust Text Classification with a Bunch of Labeled Examples",
author = "Croce, Danilo and
Castellucci, Giuseppe and
Basili, Roberto",
booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
month = jul,
year = "2020",
address = "Online",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/2020.acl-main.191",
pages = "2114--2119"
}
We would like to thank Osman Mutlu and Ali Hürriyetoğlu for their implementation of GAN-BERT in PyTorch, which inspired our port. You can find their initial repository at this link. We would also like to thank Claudia Breazzano (Tor Vergata, University of Rome), who supported this port.