From 8c8799fc5313ee54e532ed2f7b322f9579098757 Mon Sep 17 00:00:00 2001 From: cheneydon Date: Tue, 25 Oct 2022 09:54:46 +0800 Subject: [PATCH 1/5] update dataset for fednlp --- .../core/auxiliaries/data_builder.py | 2 + federatedscope/core/configs/cfg_aggregator.py | 22 + federatedscope/core/configs/cfg_data.py | 11 + federatedscope/core/configs/cfg_evaluation.py | 4 + federatedscope/core/configs/cfg_model.py | 11 + federatedscope/core/configs/cfg_training.py | 10 +- federatedscope/nlp/dataloader/__init__.py | 9 +- .../nlp/dataloader/data_collator.py | 237 +++++++++++ .../nlp/dataloader/hfl_dataloader.py | 389 ++++++++++++++++++ federatedscope/nlp/dataset/agnews.py | 95 +++++ federatedscope/nlp/dataset/cnndm.py | 186 +++++++++ federatedscope/nlp/dataset/imdb.py | 93 +++++ federatedscope/nlp/dataset/msqg.py | 189 +++++++++ federatedscope/nlp/dataset/newsqa.py | 362 ++++++++++++++++ .../nlp/dataset/preprocess/get_hfl_data.py | 141 +++++++ federatedscope/nlp/dataset/squad.py | 356 ++++++++++++++++ federatedscope/nlp/dataset/utils.py | 56 +++ federatedscope/nlp/trainer/utils.py | 75 ++++ .../fedavg/config_client_fedavg.yaml | 94 +++++ .../fedavg/config_fedavg.yaml | 39 ++ scripts/fednlp_exp_scripts/fedavg/run.sh | 10 + 21 files changed, 2388 insertions(+), 3 deletions(-) create mode 100644 federatedscope/core/configs/cfg_aggregator.py create mode 100644 federatedscope/nlp/dataloader/data_collator.py create mode 100644 federatedscope/nlp/dataloader/hfl_dataloader.py create mode 100644 federatedscope/nlp/dataset/agnews.py create mode 100644 federatedscope/nlp/dataset/cnndm.py create mode 100644 federatedscope/nlp/dataset/imdb.py create mode 100644 federatedscope/nlp/dataset/msqg.py create mode 100644 federatedscope/nlp/dataset/newsqa.py create mode 100644 federatedscope/nlp/dataset/preprocess/get_hfl_data.py create mode 100644 federatedscope/nlp/dataset/squad.py create mode 100644 federatedscope/nlp/trainer/utils.py create mode 100644 scripts/fednlp_exp_scripts/fedavg/config_client_fedavg.yaml create mode 100644 scripts/fednlp_exp_scripts/fedavg/config_fedavg.yaml create mode 100644 scripts/fednlp_exp_scripts/fedavg/run.sh diff --git a/federatedscope/core/auxiliaries/data_builder.py b/federatedscope/core/auxiliaries/data_builder.py index 617d9689e..798f0bd24 100644 --- a/federatedscope/core/auxiliaries/data_builder.py +++ b/federatedscope/core/auxiliaries/data_builder.py @@ -4,6 +4,7 @@ from federatedscope.core.data.utils import RegexInverseMap, load_dataset, \ convert_data_mode from federatedscope.core.auxiliaries.utils import setup_seed +from federatedscope.nlp.dataloader import * import federatedscope.register as register @@ -16,6 +17,7 @@ f'{error} in `federatedscope.contrib.data`, some modules are not ' f'available.') + # TODO: Add PyGNodeDataTranslator and PyGLinkDataTranslator # TODO: move splitter to PyGNodeDataTranslator and PyGLinkDataTranslator TRANS_DATA_MAP = { diff --git a/federatedscope/core/configs/cfg_aggregator.py b/federatedscope/core/configs/cfg_aggregator.py new file mode 100644 index 000000000..e70df3e9f --- /dev/null +++ b/federatedscope/core/configs/cfg_aggregator.py @@ -0,0 +1,22 @@ +from federatedscope.core.configs.config import CN +from federatedscope.register import register_config + + +def extend_aggregator_cfg(cfg): + cfg.aggregator = CN() + cfg.aggregator.num_agg_groups = None + cfg.aggregator.num_agg_topk = None + cfg.aggregator.inside_weight = None + cfg.aggregator.outside_weight = None + cfg.aggregator.proto_weight = None + cfg.aggregator.synth_ratio = None + + # 
--------------- register corresponding check function ---------- + cfg.register_cfg_check_fun(assert_aggregator_cfg) + + +def assert_aggregator_cfg(cfg): + pass + + +register_config('aggregator', extend_aggregator_cfg) diff --git a/federatedscope/core/configs/cfg_data.py b/federatedscope/core/configs/cfg_data.py index b9eb0d48c..f09e13b95 100644 --- a/federatedscope/core/configs/cfg_data.py +++ b/federatedscope/core/configs/cfg_data.py @@ -55,6 +55,17 @@ def extend_data_cfg(cfg): cfg.data.quadratic.min_curv = 0.02 cfg.data.quadratic.max_curv = 12.5 + # fednlp + cfg.data.datasets = [] + cfg.data.num_grouped_clients = [] + cfg.data.max_seq_len = 0 + cfg.data.max_tgt_len = 0 + cfg.data.max_query_len = 0 + cfg.data.trunc_stride = 0 + cfg.data.cache_dir = '' + cfg.data.num_contrast = 0 + cfg.data.debug = False + # --------------- outdated configs --------------- # TODO: delete this code block cfg.data.loader = '' diff --git a/federatedscope/core/configs/cfg_evaluation.py b/federatedscope/core/configs/cfg_evaluation.py index 09b9cdd48..e3a238e71 100644 --- a/federatedscope/core/configs/cfg_evaluation.py +++ b/federatedscope/core/configs/cfg_evaluation.py @@ -22,6 +22,10 @@ def extend_evaluation_cfg(cfg): cfg.eval.count_flops = True + # fednlp + cfg.eval.result_path = '' + cfg.eval.temp_dir = '' + # ---------------------------------------------------------------------- # # wandb related options # ---------------------------------------------------------------------- # diff --git a/federatedscope/core/configs/cfg_model.py b/federatedscope/core/configs/cfg_model.py index b11b082b9..56fe98412 100644 --- a/federatedscope/core/configs/cfg_model.py +++ b/federatedscope/core/configs/cfg_model.py @@ -24,6 +24,17 @@ def extend_model_cfg(cfg): cfg.model.num_user = 0 cfg.model.input_shape = () # A tuple, e.g., (in_channel, h, w) + # fednlp + cfg.model.model_type = '' + cfg.model.bos_token = '' + cfg.model.eos_token = '' + cfg.model.eoq_token = '' + cfg.model.pad_token = '' + cfg.model.bos_token_id = -1 + cfg.model.eos_token_id = -1 + cfg.model.eoq_token_id = -1 + cfg.model.pad_token_id = -1 + # ---------------------------------------------------------------------- # # Criterion related options # ---------------------------------------------------------------------- # diff --git a/federatedscope/core/configs/cfg_training.py b/federatedscope/core/configs/cfg_training.py index e4089a15a..57a1d2e44 100644 --- a/federatedscope/core/configs/cfg_training.py +++ b/federatedscope/core/configs/cfg_training.py @@ -7,9 +7,14 @@ def extend_training_cfg(cfg): # Trainer related options # ---------------------------------------------------------------------- # cfg.trainer = CN() - cfg.trainer.type = 'general' + # fednlp + cfg.trainer.disp_freq = 50 + cfg.trainer.val_freq = 100000000 # eval freq across batches + cfg.trainer.grad_accum_count = 1 + cfg.trainer.train_steps = 1 + # ---------------------------------------------------------------------- # # Training related options # ---------------------------------------------------------------------- # @@ -26,6 +31,9 @@ def extend_training_cfg(cfg): cfg.train.scheduler = CN(new_allowed=True) cfg.train.scheduler.type = '' + # fednlp + cfg.train.scheduler.warmup_ratio = 0. 
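+    # Note: warmup_ratio is assumed here to denote the fraction of total
+    # training steps used for learning-rate warmup, i.e. roughly
+    #   num_warmup_steps = int(warmup_ratio * total_train_steps);
+    # the exact interpretation is left to the NLP trainer's scheduler.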
+ # ---------------------------------------------------------------------- # # Finetune related options # ---------------------------------------------------------------------- # diff --git a/federatedscope/nlp/dataloader/__init__.py b/federatedscope/nlp/dataloader/__init__.py index 079fd07cb..c0b31382d 100644 --- a/federatedscope/nlp/dataloader/__init__.py +++ b/federatedscope/nlp/dataloader/__init__.py @@ -1,3 +1,8 @@ -from federatedscope.nlp.dataloader.dataloader import load_nlp_dataset +from os.path import dirname, basename, isfile, join +import glob -__all__ = ['load_nlp_dataset'] +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] diff --git a/federatedscope/nlp/dataloader/data_collator.py b/federatedscope/nlp/dataloader/data_collator.py new file mode 100644 index 000000000..34fc024ac --- /dev/null +++ b/federatedscope/nlp/dataloader/data_collator.py @@ -0,0 +1,237 @@ +import math +import numpy as np +import torch +from numpy.random import permutation, poisson + + +class DataCollatorForMLM(object): + def __init__(self, tokenizer, mlm_probability=0.15): + self.tokenizer = tokenizer + self.mlm_probability = mlm_probability + + def __call__(self, examples): + """ Prepare masked tokens inputs/labels for masked language + modeling: 80% MASK, 10% random, 10% original. """ + examples = {k: torch.stack([x[k] for x in examples]) + for k in examples[0].keys()} + token_ids = examples['token_ids'] + attention_mask = examples['attention_mask'] + labels = token_ids.clone() + + # We sample a few tokens in each sequence for masked-LM training + # (with probability self.mlm_probability defaults to 0.15 in + # Bert/RoBERTa) + probability_matrix = torch.full(labels.shape, self.mlm_probability) + special_tokens_mask = [ + self.tokenizer.get_special_tokens_mask( + val, already_has_special_tokens=True) for val in labels.tolist() + ] + probability_matrix.masked_fill_( + torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) + if self.tokenizer._pad_token is not None: + padding_mask = labels.eq(self.tokenizer.pad_token_id) + probability_matrix.masked_fill_(padding_mask, value=0.0) + masked_indices = torch.bernoulli(probability_matrix).bool() + labels[~masked_indices] = -100 # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with + # tokenizer.mask_token ([MASK]) + indices_replaced = torch.bernoulli( + torch.full(labels.shape, 0.8)).bool() & masked_indices + token_ids[indices_replaced] = self.tokenizer.convert_tokens_to_ids( + self.tokenizer.mask_token) + + # 10% of the time, we replace masked input tokens with random word + indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() \ + & masked_indices & ~indices_replaced + random_words = torch.randint(len(self.tokenizer), labels.shape, + dtype=torch.long) + token_ids[indices_random] = random_words[indices_random] + + # The rest of the time (10% of the time) we keep the masked input + # tokens unchanged + return {'token_ids': token_ids, + 'attention_mask': attention_mask, + 'labels': labels, + 'example_indices': examples['example_indices']} + + +class DataCollatorForDenoisingTasks(object): + """Data collator used denoising language modeling task in BART. + The implementation is based on + https://github.com/pytorch/fairseq/blob/ + 1bba712622b8ae4efb3eb793a8a40da386fe11d0/fairseq/data/denoising_dataset.py. + The default paramters is based on BART paper + https://arxiv.org/abs/1910.13461. 
+ """ + def __init__(self, tokenizer, mask_ratio=0.3, poisson_lambda=3.0, + permutate_sentence_ratio=1.0): + self.tokenizer = tokenizer + self.mask_ratio = mask_ratio + self.poisson_lambda = poisson_lambda + self.permutate_sentence_ratio = permutate_sentence_ratio + + def __call__(self, examples): + examples = {k: torch.stack([x[k] for x in examples]) + for k in examples[0].keys()} + token_ids = examples['token_ids'].numpy() + attention_mask = examples['attention_mask'].numpy() + labels = token_ids.copy() + + do_permutate = False + if self.permutate_sentence_ratio > 0.0: + permute_sent = self.permutate_sentences(token_ids[:, 1:]) + for i, s in enumerate(permute_sent): + token_ids[i, 1:] = s + do_permutate = True + + if self.mask_ratio: + token_ids, _ = self.add_whole_word_mask(token_ids, do_permutate) + num_non_padding = np.sum( + token_ids != self.tokenizer.pad_token_id, axis=-1) + for i in range(len(attention_mask)): + attention_mask[i][num_non_padding[i]:] = 0 + + token_ids = torch.from_numpy(token_ids) + attention_mask = torch.from_numpy(attention_mask) + labels = torch.from_numpy(labels) + return {'token_ids': token_ids, + 'attention_mask': attention_mask, + 'labels': labels, + 'example_indices': examples['example_indices']} + + def permutate_sentences(self, inputs): + results = inputs.copy() + + for i in range(inputs.shape[0]): + full_stops = (inputs[i] == self.tokenizer.eoq_token_id) | ( + inputs[i] == self.tokenizer.eos_token_id) + full_stops = full_stops[None, :] + sentence_ends = np.argwhere(full_stops[:, 1:] * ~full_stops[:, :-1]) + if len(sentence_ends) == 0: + continue + + sentence_ends[:, 1] += 2 + num_sentences = np.unique( + sentence_ends[:, 0], return_counts=True)[1] + num_to_permute = np.ceil( + (num_sentences * 2 * self.permutate_sentence_ratio) / + 2.0).astype(int) + sentence_ends = np.split( + sentence_ends[:, 1], np.unique( + sentence_ends[:, 0], return_index=True)[1][1:]) + + substitutions = np.random.permutation(num_sentences[0])[ + :num_to_permute[0]] + ordering = np.arange(0, num_sentences[0]) + ordering[substitutions] = substitutions[np.random.permutation( + num_to_permute[0])] + + index = 0 + for j in ordering: + sentence = inputs[i, (sentence_ends[0][j - 1] if j > 0 else + 0) : sentence_ends[0][j]] + results[i, index : index + sentence.shape[0]] = sentence + index += sentence.shape[0] + + num_non_padding = np.sum(results != self.tokenizer.pad_token_id, + axis=-1) + eos_indices = np.where(results == self.tokenizer.eos_token_id)[1] + for i, (idx1, idx2) in enumerate(zip(eos_indices, num_non_padding)): + results[i][idx1] = self.tokenizer.eoq_token_id + results[i][idx2 - 1] = self.tokenizer.eos_token_id + + return results + + def add_whole_word_mask(self, inputs, do_permutate): + labels = inputs.copy() + inputs = inputs.copy() + + special_tokens_mask = [ + self.tokenizer.get_special_tokens_mask( + val,already_has_special_tokens=True) for val in labels.tolist() + ] + special_tokens_mask = np.array(special_tokens_mask, dtype=bool) + + # determine how many tokens we need to mask in total + is_token = ~(labels == self.tokenizer.pad_token_id) & \ + ~special_tokens_mask + num_to_mask = int(math.ceil(is_token.astype(float).sum() * + self.mask_ratio)) + if num_to_mask == 0: + return inputs, labels + + # generate a sufficient number of span lengths + lengths = poisson(lam=self.poisson_lambda, size=(num_to_mask,)) + while np.cumsum(lengths, 0)[-1] < num_to_mask: + lengths = np.concatenate([lengths, poisson( + lam=self.poisson_lambda, size=(num_to_mask,))]) + + # remove all 
spans of length 0 + # Note that BART inserts additional mask tokens where length == 0, + # which we do not implement for now as it adds additional complexity + lengths = lengths[lengths > 0] + + # trim to about num_to_mask tokens + idx = np.argmin(np.abs(np.cumsum(lengths, 0) - num_to_mask)) + 1 + lengths = lengths[: idx + 1] + + # select span start indices + token_indices = np.argwhere(is_token == 1) + span_starts = permutation(token_indices.shape[0])[: lengths.shape[0]] + + # prepare mask + masked_indices = np.array(token_indices[span_starts]) + mask = np.full_like(labels, fill_value=False) + + # mask span start indices + for mi in masked_indices: + mask[tuple(mi)] = True + lengths -= 1 + + # fill up spans + max_index = labels.shape[1] - 1 + remaining = (lengths > 0) & (masked_indices[:, 1] < max_index) + while np.any(remaining): + masked_indices[remaining, 1] += 1 + for mi in masked_indices: + mask[tuple(mi)] = True + lengths -= 1 + remaining = (lengths > 0) & (masked_indices[:, 1] < max_index) + + # place the mask tokens + mask[np.where(special_tokens_mask)] = False + inputs[np.where(mask)] = self.tokenizer.mask_token_id + + if not do_permutate: + labels[np.where(mask)] = -100 + else: + labels[np.where(special_tokens_mask)] = -100 + + # remove mask tokens that are not starts of spans + to_remove = (mask == 1) & np.roll((mask == 1), 1, 1) + new_inputs = np.full_like( + labels, fill_value=self.tokenizer.pad_token_id) + + # splits = list(map(lambda x: x.reshape(-1), np.split(inputs_copy, + # indices_or_sections=2, axis=0)) + for i, example in enumerate(np.split( + inputs, indices_or_sections=new_inputs.shape[0], axis=0)): + new_example = example[0][~to_remove[i]] + new_inputs[i, 0 : new_example.shape[0]] = new_example + + # batching now fixed + return new_inputs, labels + + +class DataCollatorForPFedNLP(object): + def __init__(self, tokenizer, mlm_probability=0.15, mask_ratio=0.3, + poisson_lambda=3.0, permutate_sentence_ratio=1.0): + self.mlm_collator = DataCollatorForMLM(tokenizer, mlm_probability) + self.denoise_collator = DataCollatorForDenoisingTasks( + tokenizer, mask_ratio, poisson_lambda, permutate_sentence_ratio) + + def __call__(self, examples): + mlm_results = self.mlm_collator(examples) + denoise_results = self.denoise_collator(examples) + return {'mlm': mlm_results, 'denoise': denoise_results} diff --git a/federatedscope/nlp/dataloader/hfl_dataloader.py b/federatedscope/nlp/dataloader/hfl_dataloader.py new file mode 100644 index 000000000..e9f90c0b2 --- /dev/null +++ b/federatedscope/nlp/dataloader/hfl_dataloader.py @@ -0,0 +1,389 @@ +import os +import logging +import copy +from tqdm import tqdm +from torch.utils.data import DataLoader +from federatedscope.register import register_data +from federatedscope.nlp.dataset.preprocess.get_hfl_data import HFLDataProcessor +from federatedscope.nlp.dataset.utils import setup_tokenizer +from federatedscope.nlp.dataset.imdb import create_imdb_dataset +from federatedscope.nlp.dataset.agnews import create_agnews_dataset +from federatedscope.nlp.dataset.squad import create_squad_dataset +from federatedscope.nlp.dataset.newsqa import create_newsqa_dataset +from federatedscope.nlp.dataset.cnndm import create_cnndm_dataset +from federatedscope.nlp.dataset.msqg import create_msqg_dataset +from federatedscope.nlp.dataloader.data_collator import DataCollatorForMLM, \ + DataCollatorForDenoisingTasks, DataCollatorForPFedNLP + +logger = logging.getLogger(__name__) + + +def extend_cfg(cfg, cfg_client): + # config + cfg.eval.result_path = cfg.outdir 
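+    # (Assumption based on the option names) result_path and temp_dir point
+    # to directories for evaluation outputs and intermediate files.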
+ cfg.eval.temp_dir = os.path.join(cfg.outdir, 'temp') + os.makedirs(cfg.eval.temp_dir, exist_ok=True) + if cfg.federate.save_to: + cfg.federate.save_to = os.path.join(cfg.outdir, cfg.federate.save_to) + save_dir = cfg.federate.save_to + os.makedirs(save_dir, exist_ok=True) + + if cfg.model.task == 'pretrain': + downstream_tasks = [] + for group_id, num_clients in enumerate(cfg.data.num_grouped_clients): + downstream_tasks += [cfg.model.downstream_tasks[group_id]] * \ + num_clients + cfg.model.downstream_tasks = downstream_tasks + elif cfg.aggregator.num_agg_topk is not None: + num_agg_topk = [] + for group_id, num_clients in enumerate(cfg.data.num_grouped_clients): + if isinstance(cfg.aggregator.num_agg_topk, list): + num_agg_topk += [cfg.aggregator.num_agg_topk[group_id]] * \ + num_clients + else: + num_agg_topk += [cfg.aggregator.num_agg_topk] * num_clients + cfg.aggregator.num_agg_topk = num_agg_topk + + tokenizer = setup_tokenizer(cfg) + cfg.model.bos_token_id = tokenizer.bos_token_id + cfg.model.eos_token_id = tokenizer.eos_token_id + cfg.model.eoq_token_id = tokenizer.eoq_token_id + cfg.model.pad_token_id = tokenizer.pad_token_id + + if cfg.data.debug: + if cfg.federate.total_round_num > 5: + cfg.federate.total_round_num = 5 + cfg.federate.save_to = '' + if cfg.data.num_contrast is not None and cfg.data.num_contrast > 20: + cfg.data.num_contrast = 20 + cfg.data.cache_dir = '' + cfg.trainer.train_steps = 5 + + # client_config + with open(os.path.join(cfg.outdir, 'config_client.yaml'), 'w') as outfile: + from contextlib import redirect_stdout + with redirect_stdout(outfile): + tmp_cfg = copy.deepcopy(cfg_client) + tmp_cfg.cfg_check_funcs = [] + print(tmp_cfg.dump()) + + num_grouped_clients = cfg.data.num_grouped_clients + client_start_id = 1 + for group_id, num_clients in enumerate(num_grouped_clients): + group_cfg = cfg_client['client_group_{}'.format(group_id + 1)] + if cfg.data.debug: + group_cfg.trainer.train_steps = 5 + for client_id in range(client_start_id, client_start_id + num_clients): + cfg_client['client_{}'.format(client_id)] = group_cfg + client_start_id += num_clients + + return cfg, cfg_client + + +def create_data(data, split, tokenizer, task, model_type, max_seq_len, + max_query_len, trunc_stride, max_tgt_len, cache_dir, + client_id, pretrain, debug): + if task == 'imdb': + create_dataset_func = create_imdb_dataset + elif task == 'agnews': + create_dataset_func = create_agnews_dataset + elif task == 'squad': + create_dataset_func = create_squad_dataset + elif task == 'newsqa': + create_dataset_func = create_newsqa_dataset + elif task == 'cnndm': + create_dataset_func = create_cnndm_dataset + elif task == 'msqg': + create_dataset_func = create_msqg_dataset + else: + raise ValueError(f'No HFL dataset named {task}') + + return create_dataset_func(data=data, + split=split, + tokenizer=tokenizer, + max_seq_len=max_seq_len, + max_query_len=max_query_len, + trunc_stride=trunc_stride, + max_src_len=max_seq_len, + max_tgt_len=max_tgt_len, + model_type=model_type, + cache_dir=cache_dir, + raw_cache_dir=cache_dir, + client_id=client_id, + pretrain=pretrain, + debug=debug) + + +def load_fednlp_data(config, client_config): + extend_cfg(config, client_config) + model_type = config.model.model_type + tokenizer = setup_tokenizer(config) + pretrain = config.model.task == 'pretrain' + cache_dir = config.data.cache_dir if config.data.cache_dir else '' + debug = config.data.debug + data_collator = None + if pretrain: + if config.model.pretrain_task == 'mlm': + data_collator = 
DataCollatorForMLM(tokenizer=tokenizer) + elif config.model.pretrain_task == 'denoise': + data_collator = DataCollatorForDenoisingTasks(tokenizer=tokenizer) + else: + raise ValueError(f'Pretrain task {config.model.pretrain_task} is ' + f'not supported') + + logger.info(f'Preprocessing dataset {config.data.type}') + data_processor = HFLDataProcessor(config) + all_data = data_processor.get_data() + all_data_dict = {'train': all_data[0], + 'val': all_data[1], + 'test': all_data[2]} + + data_dict = dict() + for client_id in tqdm(range(1, config.federate.client_num + 1)): + cfg_client = config if pretrain else \ + client_config['client_{}'.format(client_id)] + cur_task = cfg_client.model.downstream_tasks[client_id - 1] \ + if pretrain else cfg_client.model.task + train_data, val_data, test_data = [ + create_data( + data=all_data_dict[split][client_id - 1], + split=split, + tokenizer=tokenizer, + task=cur_task, + model_type=model_type, + max_seq_len=getattr(cfg_client.data, 'max_seq_len', None), + max_query_len=getattr(cfg_client.data, 'max_query_len', None), + trunc_stride=getattr(cfg_client.data, 'trunc_stride', None), + max_tgt_len=getattr(cfg_client.data, 'max_tgt_len', None), + cache_dir=cache_dir, + client_id=client_id, + pretrain=pretrain, + debug=debug) + for split in ['train', 'val', 'test'] + ] + + dataloader_dict = { + 'train': { + 'dataloader': DataLoader( + dataset=train_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=config.data.shuffle, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), + 'encoded': train_data[1], + 'examples': train_data[2]}, + 'val': { + 'dataloader': DataLoader( + dataset=val_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), + 'encoded': val_data[1], + 'examples': val_data[2]}, + 'test': { + 'dataloader': DataLoader( + dataset=test_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), + 'encoded': test_data[1], + 'examples': test_data[2]}, + } + data_dict[client_id] = dataloader_dict + + return data_dict, config + + +def load_pfednlp_data(config, client_config): + extend_cfg(config, client_config) + model_type = config.model.model_type + tokenizer = setup_tokenizer(config) + pretrain = config.model.task == 'pretrain' + cache_dir = config.data.cache_dir if config.data.cache_dir else '' + debug = config.data.debug + data_collator = DataCollatorForPFedNLP(tokenizer=tokenizer) \ + if pretrain else None + + logger.info(f'Preprocessing dataset {config.data.type}') + data_processor = HFLDataProcessor(config) + all_data = data_processor.get_data() + all_data_dict = {'train': all_data[0], + 'val': all_data[1], + 'test': all_data[2]} + + data_dict = dict() + for client_id in tqdm(range(1, config.federate.client_num + 1)): + cfg_client = config if pretrain else \ + client_config['client_{}'.format(client_id)] + cur_task = cfg_client.model.downstream_tasks[client_id - 1] \ + if pretrain else cfg_client.model.task + train_data, val_data, test_data = [ + create_data( + data=all_data_dict[split][client_id - 1], + split=split, + tokenizer=tokenizer, + task=cur_task, + model_type=model_type, + max_seq_len=getattr(cfg_client.data, 'max_seq_len', None), + max_query_len=getattr(cfg_client.data, 'max_query_len', None), + trunc_stride=getattr(cfg_client.data, 'trunc_stride', None), + 
max_tgt_len=getattr(cfg_client.data, 'max_tgt_len', None), + cache_dir=cache_dir, + client_id=client_id, + pretrain=pretrain, + debug=debug) + for split in ['train', 'val', 'test'] + ] + + dataloader_dict = { + 'train': { + 'dataloader': DataLoader( + dataset=train_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=config.data.shuffle, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), + 'encoded': train_data[1], + 'examples': train_data[2]}, + 'val': { + 'dataloader': DataLoader( + dataset=val_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), + 'encoded': val_data[1], + 'examples': val_data[2]}, + 'test': { + 'dataloader': DataLoader( + dataset=test_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), + 'encoded': test_data[1], + 'examples': test_data[2]}, + } + data_dict[client_id] = dataloader_dict + + return data_dict, config + + +def load_pfednlp_contrast_data(config, client_config): + extend_cfg(config, client_config) + model_type = config.model.model_type + tokenizer = setup_tokenizer(config) + pretrain = config.model.task == 'pretrain' + cache_dir = config.data.cache_dir if config.data.cache_dir else '' + debug = config.data.debug + data_collator = DataCollatorForPFedNLP(tokenizer=tokenizer) \ + if pretrain else None + + logger.info(f'Preprocessing dataset {config.data.type}') + data_processor = HFLDataProcessor(config) + all_data = data_processor.get_data() + all_data_dict = {'train': all_data[0], + 'val': all_data[1], + 'test': all_data[2]} + + data_dict = dict() + for client_id in tqdm(range(1, config.federate.client_num + 1)): + cfg_client = config if pretrain else \ + client_config['client_{}'.format(client_id)] + cur_task = cfg_client.model.downstream_tasks[client_id - 1] \ + if pretrain else cfg_client.model.task + train_data, val_data, test_data = [ + create_data( + data=all_data_dict[split][client_id - 1], + split=split, + tokenizer=tokenizer, + task=cur_task, + model_type=model_type, + max_seq_len=getattr(cfg_client.data, 'max_seq_len', None), + max_query_len=getattr(cfg_client.data, 'max_query_len', None), + trunc_stride=getattr(cfg_client.data, 'trunc_stride', None), + max_tgt_len=getattr(cfg_client.data, 'max_tgt_len', None), + cache_dir=cache_dir, + client_id=client_id, + pretrain=pretrain, + debug=debug) + for split in ['train', 'val', 'test'] + ] + + dataloader_dict = { + 'train_raw': { + 'dataloader': DataLoader( + dataset=train_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=config.data.shuffle, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), + 'encoded': train_data[1], + 'examples': train_data[2]}, + 'train_contrast': { + 'dataloader': DataLoader( + dataset=train_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), + 'encoded': train_data[1], + 'examples': train_data[2]}, + 'val': { + 'dataloader': DataLoader( + dataset=val_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), + 'encoded': val_data[1], + 'examples': val_data[2]}, + 'test': { + 'dataloader': DataLoader( + dataset=test_data[0], + 
batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), + 'encoded': test_data[1], + 'examples': test_data[2]}, + } + data_dict[client_id] = dataloader_dict + + return data_dict, config + + +def call_fednlp_data(config, client_config): + if config.data.type == 'fednlp_data': + data, modified_config = load_fednlp_data(config, client_config) + return data, modified_config + + +def call_pfednlp_data(config, client_config): + if config.data.type == 'pfednlp_data': + data, modified_config = load_pfednlp_data(config, client_config) + return data, modified_config + + +def call_pfednlp_contrast_data(config, client_config): + if config.data.type == 'pfednlp_contrast_data': + data, modified_config = load_pfednlp_contrast_data( + config, client_config) + return data, modified_config + + +register_data('fednlp_data', call_fednlp_data) +register_data('pfednlp_data', call_pfednlp_data) +register_data('pfednlp_contrast_data', call_pfednlp_contrast_data) diff --git a/federatedscope/nlp/dataset/agnews.py b/federatedscope/nlp/dataset/agnews.py new file mode 100644 index 000000000..c5ed9dcf4 --- /dev/null +++ b/federatedscope/nlp/dataset/agnews.py @@ -0,0 +1,95 @@ +import os +import os.path as osp +import logging +import torch +from federatedscope.nlp.dataset.utils import split_sent, DictDataset, NUM_DEBUG + +logger = logging.getLogger(__name__) + + +def create_agnews_examples(data, debug=False): + if debug: + data = data[:NUM_DEBUG] + examples = [] + for ex in data: + examples.append((ex['text'], ex['label'])) + return examples + + +def create_agnews_dataset(data, split, tokenizer, max_seq_len, cache_dir='', + client_id=None, pretrain=False, debug=False, + **kwargs): + if pretrain: + return create_agnews_pretrain_dataset( + data, split, tokenizer, max_seq_len, cache_dir, client_id, debug) + + save_dir = osp.join(cache_dir, 'finetune', str(client_id)) + cache_file = osp.join(save_dir, split + '.pt') + if osp.exists(cache_file): + logger.info('Loading cache file from \'{}\''.format(cache_file)) + cache_data = torch.load(cache_file) + examples = cache_data['examples'] + encoded_inputs = cache_data['encoded_inputs'] + else: + examples = create_agnews_examples(data, debug) + texts = [ex[0] for ex in examples] + encoded_inputs = tokenizer(texts, + padding='max_length', + truncation=True, + max_length=max_seq_len, + return_tensors='pt') + + if cache_dir: + logger.info('Saving cache file to \'{}\''.format(cache_file)) + os.makedirs(save_dir, exist_ok=True) + torch.save({'examples': examples, + 'encoded_inputs': encoded_inputs}, cache_file) + + labels = [ex[1] for ex in examples] + example_indices = torch.arange(encoded_inputs.input_ids.size(0), + dtype=torch.long) + dataset = DictDataset({'token_ids': encoded_inputs.input_ids, + 'token_type_ids': encoded_inputs.token_type_ids, + 'attention_mask': encoded_inputs.attention_mask, + 'labels': torch.LongTensor(labels), + 'example_indices': example_indices}) + return dataset, encoded_inputs, examples + + +def create_agnews_pretrain_dataset(data, split, tokenizer, max_seq_len, + cache_dir='', client_id=None, debug=False): + save_dir = osp.join(cache_dir, 'pretrain', str(client_id)) + cache_file = osp.join(save_dir, split + '.pt') + + if osp.exists(cache_file): + logger.info('Loading cache file from \'{}\''.format(cache_file)) + cache_data = torch.load(cache_file) + examples = cache_data['examples'] + encoded_inputs = cache_data['encoded_inputs'] + else: + examples = 
create_agnews_examples(data, debug) + texts = [ex[0] for ex in examples] + texts = split_sent(texts, eoq=tokenizer.eoq_token) + encoded_inputs = tokenizer(texts, + padding='max_length', + truncation=True, + max_length=max_seq_len, + return_tensors='pt') + num_non_padding = (encoded_inputs.input_ids != + tokenizer.pad_token_id).sum(dim=-1) + for i, pad_idx in enumerate(num_non_padding): + encoded_inputs.input_ids[i, 0] = tokenizer.bos_token_id + encoded_inputs.input_ids[i, pad_idx - 1] = tokenizer.eos_token_id + + if cache_dir: + logger.info('Saving cache file to \'{}\''.format(cache_file)) + os.makedirs(save_dir, exist_ok=True) + torch.save({'examples': examples, + 'encoded_inputs': encoded_inputs}, cache_file) + + example_indices = torch.arange(encoded_inputs.input_ids.size(0), + dtype=torch.long) + dataset = DictDataset({'token_ids': encoded_inputs.input_ids, + 'attention_mask': encoded_inputs.attention_mask, + 'example_indices': example_indices}) + return dataset, encoded_inputs, examples diff --git a/federatedscope/nlp/dataset/cnndm.py b/federatedscope/nlp/dataset/cnndm.py new file mode 100644 index 000000000..f03a9871d --- /dev/null +++ b/federatedscope/nlp/dataset/cnndm.py @@ -0,0 +1,186 @@ +import os +import os.path as osp +import logging +import torch +import numpy as np +from federatedscope.nlp.dataset.utils import split_sent, DictDataset, NUM_DEBUG + +logger = logging.getLogger(__name__) + + +def create_cnndm_examples(data, debug=False): + if debug: + data = data[:NUM_DEBUG] + src_examples, tgt_examples = [], [] + for ex in data: + src_examples.append(ex['src']) + tgt_examples.append(ex['tgt']) + return src_examples, tgt_examples + + +def create_cnndm_dataset(data, split, tokenizer, max_src_len, max_tgt_len, + raw_cache_dir='', client_id=None, pretrain=False, + debug=False, **kwargs): + if pretrain: + return create_cnndm_pretrain_dataset( + data, split, tokenizer, max_src_len, raw_cache_dir, client_id, + debug) + + cache_dir = osp.join(raw_cache_dir, 'finetune', str(client_id), split) + src_examples, tgt_examples = create_cnndm_examples(data, debug) + if osp.exists(cache_dir): + logger.info('Loading cache file from \'{}\''.format(cache_dir)) + token_ids = np.memmap( + filename=osp.join(cache_dir, 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + token_type_ids = np.memmap( + filename=osp.join(cache_dir, 'token_type_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + attention_mask = np.memmap( + filename=osp.join(cache_dir, 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + labels = np.memmap( + filename=osp.join(cache_dir, 'labels.memmap'), + shape=(len(src_examples), max_tgt_len), + mode='r', + dtype=np.int64) + + token_ids = torch.from_numpy(token_ids) + token_type_ids = torch.from_numpy(token_type_ids) + attention_mask = torch.from_numpy(attention_mask) + labels = torch.from_numpy(labels) + else: + src_encoded = tokenizer(src_examples, + padding='max_length', + truncation=True, + max_length=max_src_len, + return_tensors='pt') + tgt_examples = split_sent(tgt_examples, eoq=tokenizer.eoq_token) + tgt_encoded = tokenizer(tgt_examples, + padding='max_length', + truncation=True, + max_length=max_tgt_len, + return_tensors='pt') + num_non_padding = (tgt_encoded.input_ids != + tokenizer.pad_token_id).sum(dim=-1) + for i, pad_idx in enumerate(num_non_padding): + tgt_encoded.input_ids[i, 0] = tokenizer.bos_token_id + tgt_encoded.input_ids[i, pad_idx - 1] = 
tokenizer.eos_token_id + + if raw_cache_dir: + logger.info('Saving cache file to \'{}\''.format(cache_dir)) + os.makedirs(cache_dir, exist_ok=True) + token_ids = np.memmap( + filename=osp.join(cache_dir, 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + token_type_ids = np.memmap( + filename=osp.join(cache_dir, 'token_type_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + attention_mask = np.memmap( + filename=osp.join(cache_dir, 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + labels = np.memmap( + filename=osp.join(cache_dir, 'labels.memmap'), + shape=(len(src_examples), max_tgt_len), + mode='w+', + dtype=np.int64) + + for i in range(len(src_examples)): + token_ids[i] = src_encoded.input_ids[i] + token_type_ids[i] = src_encoded.token_type_ids[i] + attention_mask[i] = src_encoded.attention_mask[i] + labels[i] = tgt_encoded.input_ids[i] + + token_ids = torch.from_numpy(token_ids) + token_type_ids = torch.from_numpy(token_type_ids) + attention_mask = torch.from_numpy(attention_mask) + labels = torch.from_numpy(labels) + else: + token_ids = src_encoded.input_ids + token_type_ids = src_encoded.token_type_ids + attention_mask = src_encoded.attention_mask + labels = tgt_encoded.input_ids + + example_indices = torch.arange(token_ids.size(0), dtype=torch.long) + dataset = DictDataset({'token_ids': token_ids, + 'token_type_ids': token_type_ids, + 'attention_mask': attention_mask, + 'labels': labels, + 'example_indices': example_indices}) + return dataset, None, None + + +def create_cnndm_pretrain_dataset(data, split, tokenizer, max_src_len, + raw_cache_dir='', client_id=None, + debug=False): + cache_dir = osp.join(raw_cache_dir, 'pretrain', str(client_id), split) + src_examples, tgt_examples = create_cnndm_examples(data, debug) + if osp.exists(cache_dir): + logger.info('Loading cache file from \'{}\''.format(cache_dir)) + token_ids = np.memmap( + filename=osp.join(cache_dir, 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + attention_mask = np.memmap( + filename=osp.join(cache_dir, 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + + token_ids = torch.from_numpy(token_ids) + attention_mask = torch.from_numpy(attention_mask) + else: + src_examples = split_sent(src_examples, eoq=tokenizer.eoq_token) + src_encoded = tokenizer(src_examples, + padding='max_length', + truncation=True, + max_length=max_src_len, + return_tensors='pt') + num_non_padding = (src_encoded.input_ids != + tokenizer.pad_token_id).sum(dim=-1) + for i, pad_idx in enumerate(num_non_padding): + src_encoded.input_ids[i, 0] = tokenizer.bos_token_id + src_encoded.input_ids[i, pad_idx - 1] = tokenizer.eos_token_id + + if raw_cache_dir: + logger.info('Saving cache file to \'{}\''.format(cache_dir)) + os.makedirs(cache_dir, exist_ok=True) + token_ids = np.memmap( + filename=osp.join(cache_dir, 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + attention_mask = np.memmap( + filename=osp.join(cache_dir, 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + + for i in range(len(src_examples)): + token_ids[i] = src_encoded.input_ids[i] + attention_mask[i] = src_encoded.attention_mask[i] + + token_ids = torch.from_numpy(token_ids) + attention_mask = torch.from_numpy(attention_mask) + else: + token_ids = src_encoded.input_ids + 
attention_mask = src_encoded.attention_mask + + example_indices = torch.arange(token_ids.size(0), dtype=torch.long) + dataset = DictDataset({'token_ids': token_ids, + 'attention_mask': attention_mask, + 'example_indices': example_indices}) + return dataset, None, None diff --git a/federatedscope/nlp/dataset/imdb.py b/federatedscope/nlp/dataset/imdb.py new file mode 100644 index 000000000..b514037f8 --- /dev/null +++ b/federatedscope/nlp/dataset/imdb.py @@ -0,0 +1,93 @@ +import os +import os.path as osp +import logging +import torch +from federatedscope.nlp.dataset.utils import split_sent, DictDataset, NUM_DEBUG + +logger = logging.getLogger(__name__) + + +def create_imdb_examples(data, debug=False): + if debug: + data = data[:NUM_DEBUG] + examples = [] + for ex in data: + examples.append((ex['text'], ex['label'])) + return examples + + +def create_imdb_dataset(data, split, tokenizer, max_seq_len, cache_dir='', + client_id=None, pretrain=False, debug=False, **kwargs): + if pretrain: + return create_imdb_pretrain_dataset( + data, split, tokenizer, max_seq_len, cache_dir, client_id, debug) + + save_dir = osp.join(cache_dir, 'finetune', str(client_id)) + cache_file = osp.join(save_dir, split + '.pt') + if osp.exists(cache_file): + logger.info('Loading cache file from \'{}\''.format(cache_file)) + cache_data = torch.load(cache_file) + examples = cache_data['examples'] + encoded_inputs = cache_data['encoded_inputs'] + else: + examples = create_imdb_examples(data, debug) + texts = [ex[0] for ex in examples] + encoded_inputs = tokenizer(texts, + padding='max_length', + truncation=True, + max_length=max_seq_len, + return_tensors='pt') + + if cache_dir: + logger.info('Saving cache file to \'{}\''.format(cache_file)) + os.makedirs(save_dir, exist_ok=True) + torch.save({'examples': examples, + 'encoded_inputs': encoded_inputs}, cache_file) + + labels = [ex[1] for ex in examples] + example_indices = torch.arange(encoded_inputs.input_ids.size(0), + dtype=torch.long) + dataset = DictDataset({'token_ids': encoded_inputs.input_ids, + 'token_type_ids': encoded_inputs.token_type_ids, + 'attention_mask': encoded_inputs.attention_mask, + 'labels': torch.LongTensor(labels), + 'example_indices': example_indices}) + return dataset, encoded_inputs, examples + + +def create_imdb_pretrain_dataset(data, split, tokenizer, max_seq_len, + cache_dir='', client_id=None, debug=False): + save_dir = osp.join(cache_dir, 'pretrain', str(client_id)) + cache_file = osp.join(save_dir, split + '.pt') + if osp.exists(cache_file): + logger.info('Loading cache file from \'{}\''.format(cache_file)) + cache_data = torch.load(cache_file) + examples = cache_data['examples'] + encoded_inputs = cache_data['encoded_inputs'] + else: + examples = create_imdb_examples(data, debug) + texts = [ex[0] for ex in examples] + texts = split_sent(texts, eoq=tokenizer.eoq_token) + encoded_inputs = tokenizer(texts, + padding='max_length', + truncation=True, + max_length=max_seq_len, + return_tensors='pt') + num_non_padding = (encoded_inputs.input_ids != + tokenizer.pad_token_id).sum(dim=-1) + for i, pad_idx in enumerate(num_non_padding): + encoded_inputs.input_ids[i, 0] = tokenizer.bos_token_id + encoded_inputs.input_ids[i, pad_idx - 1] = tokenizer.eos_token_id + + if cache_dir: + logger.info('Saving cache file to \'{}\''.format(cache_file)) + os.makedirs(save_dir, exist_ok=True) + torch.save({'examples': examples, + 'encoded_inputs': encoded_inputs}, cache_file) + + example_indices = torch.arange(encoded_inputs.input_ids.size(0), + dtype=torch.long) + 
dataset = DictDataset({'token_ids': encoded_inputs.input_ids, + 'attention_mask': encoded_inputs.attention_mask, + 'example_indices': example_indices}) + return dataset, encoded_inputs, examples diff --git a/federatedscope/nlp/dataset/msqg.py b/federatedscope/nlp/dataset/msqg.py new file mode 100644 index 000000000..dab24eac7 --- /dev/null +++ b/federatedscope/nlp/dataset/msqg.py @@ -0,0 +1,189 @@ +import os +import os.path as osp +import logging +import torch +import numpy as np +from federatedscope.nlp.dataset.utils import split_sent, DictDataset, NUM_DEBUG + +logger = logging.getLogger(__name__) + + +def create_msqg_examples(data, debug=False): + if debug: + data = data[:NUM_DEBUG] + src_examples, tgt_examples = [], [] + for ex in data: + src_examples.append(ex['src']) + tgt_examples.append(ex['tgt']) + return src_examples, tgt_examples + + +def create_msqg_dataset(data, split, tokenizer, max_src_len, max_tgt_len, + raw_cache_dir='', client_id=None, pretrain=False, + debug=False, **kwargs): + if pretrain: + return create_msqg_pretrain_dataset( + data, split, tokenizer, max_src_len, raw_cache_dir, client_id, + debug) + + cache_dir = osp.join(raw_cache_dir, 'finetune', str(client_id), split) + src_examples, tgt_examples = create_msqg_examples(data, debug) + if osp.exists(cache_dir): + logger.info('Loading cache file from \'{}\''.format(cache_dir)) + token_ids = np.memmap( + filename=osp.join(cache_dir, 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + token_type_ids = np.memmap( + filename=osp.join(cache_dir, 'token_type_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + attention_mask = np.memmap( + filename=osp.join(cache_dir, 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + labels = np.memmap( + filename=osp.join(cache_dir, 'labels.memmap'), + shape=(len(src_examples), max_tgt_len), + mode='r', + dtype=np.int64) + + token_ids = torch.from_numpy(token_ids) + token_type_ids = torch.from_numpy(token_type_ids) + attention_mask = torch.from_numpy(attention_mask) + labels = torch.from_numpy(labels) + else: + src_encoded = tokenizer(src_examples, + padding='max_length', + truncation=True, + max_length=max_src_len, + return_tensors='pt') + tgt_examples = split_sent(tgt_examples, + eoq=tokenizer.eoq_token, + tokenize=False) + tgt_encoded = tokenizer(tgt_examples, + padding='max_length', + truncation=True, + max_length=max_tgt_len, + return_tensors='pt') + num_non_padding = (tgt_encoded.input_ids != + tokenizer.pad_token_id).sum(dim=-1) + for i, pad_idx in enumerate(num_non_padding): + tgt_encoded.input_ids[i, 0] = tokenizer.bos_token_id + tgt_encoded.input_ids[i, pad_idx - 1] = tokenizer.eos_token_id + + if raw_cache_dir: + logger.info('Saving cache file to \'{}\''.format(cache_dir)) + os.makedirs(cache_dir, exist_ok=True) + token_ids = np.memmap( + filename=osp.join(cache_dir, 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + token_type_ids = np.memmap( + filename=osp.join(cache_dir, 'token_type_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + attention_mask = np.memmap( + filename=osp.join(cache_dir, 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + labels = np.memmap( + filename=osp.join(cache_dir, 'labels.memmap'), + shape=(len(src_examples), max_tgt_len), + mode='w+', + dtype=np.int64) + + for i in range(len(src_examples)): 
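+                # Copy each encoded example row by row into the on-disk
+                # memory-mapped arrays created above.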
+ token_ids[i] = src_encoded.input_ids[i] + token_type_ids[i] = src_encoded.token_type_ids[i] + attention_mask[i] = src_encoded.attention_mask[i] + labels[i] = tgt_encoded.input_ids[i] + + token_ids = torch.from_numpy(token_ids) + token_type_ids = torch.from_numpy(token_type_ids) + attention_mask = torch.from_numpy(attention_mask) + labels = torch.from_numpy(labels) + + else: + token_ids = src_encoded.input_ids + token_type_ids = src_encoded.token_type_ids + attention_mask = src_encoded.attention_mask + labels = tgt_encoded.input_ids + + example_indices = torch.arange(token_ids.size(0), dtype=torch.long) + dataset = DictDataset({'token_ids': token_ids, + 'token_type_ids': token_type_ids, + 'attention_mask': attention_mask, + 'labels': labels, + 'example_indices': example_indices}) + return dataset, None, None + + +def create_msqg_pretrain_dataset(data, split, tokenizer, max_src_len, + raw_cache_dir='', client_id=None, debug=False): + cache_dir = osp.join(raw_cache_dir, 'pretrain', str(client_id), split) + src_examples, tgt_examples = create_msqg_examples(data, debug) + if osp.exists(cache_dir): + logger.info('Loading cache file from \'{}\''.format(cache_dir)) + token_ids = np.memmap( + filename=osp.join(cache_dir, 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + attention_mask = np.memmap( + filename=osp.join(cache_dir, 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + token_ids = torch.from_numpy(token_ids) + attention_mask = torch.from_numpy(attention_mask) + else: + src_examples = split_sent(src_examples, + eoq=tokenizer.eoq_token, + tokenize=False) + src_encoded = tokenizer(src_examples, + padding='max_length', + truncation=True, + max_length=max_src_len, + return_tensors='pt') + num_non_padding = (src_encoded.input_ids != + tokenizer.pad_token_id).sum(dim=-1) + for i, pad_idx in enumerate(num_non_padding): + src_encoded.input_ids[i, 0] = tokenizer.bos_token_id + src_encoded.input_ids[i, pad_idx - 1] = tokenizer.eos_token_id + + if raw_cache_dir: + logger.info('Saving cache file to \'{}\''.format(cache_dir)) + os.makedirs(cache_dir, exist_ok=True) + token_ids = np.memmap( + filename=osp.join(cache_dir, 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + attention_mask = np.memmap( + filename=osp.join(cache_dir, 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + + for i in range(len(src_examples)): + token_ids[i] = src_encoded.input_ids[i] + attention_mask[i] = src_encoded.attention_mask[i] + + token_ids = torch.from_numpy(token_ids) + attention_mask = torch.from_numpy(attention_mask) + else: + token_ids = src_encoded.input_ids + attention_mask = src_encoded.attention_mask + + example_indices = torch.arange(token_ids.size(0), dtype=torch.long) + dataset = DictDataset({'token_ids': token_ids, + 'attention_mask': attention_mask, + 'example_indices': example_indices}) + return dataset, None, None diff --git a/federatedscope/nlp/dataset/newsqa.py b/federatedscope/nlp/dataset/newsqa.py new file mode 100644 index 000000000..67523b682 --- /dev/null +++ b/federatedscope/nlp/dataset/newsqa.py @@ -0,0 +1,362 @@ +import os +import os.path as osp +import torch +import logging +from federatedscope.nlp.dataset.utils import split_sent, DictDataset, NUM_DEBUG + +logger = logging.getLogger(__name__) + + +class NewsQAExample(object): + def __init__(self, qa_id, question, context, train_answer, val_answer, + start_pos, 
end_pos, context_tokens, is_impossible): + self.qa_id = qa_id + self.question = question + self.context = context + self.train_answer = train_answer + self.val_answer = val_answer + self.start_position = start_pos + self.end_position = end_pos + self.context_tokens = context_tokens + self.is_impossible = is_impossible + + +class NewsQAEncodedInput(object): + def __init__(self, token_ids, token_type_ids, attention_mask, + overflow_token_ids): + self.token_ids = token_ids + self.token_type_ids = token_type_ids + self.attention_mask = attention_mask + self.overflow_token_ids = overflow_token_ids + + +class NewsQAResult(object): + def __init__(self, unique_id, start_logits, end_logits): + self.unique_id = unique_id + self.start_logits = start_logits + self.end_logits = end_logits + + +def refine_subtoken_position(context_subtokens, subtoken_start_pos, + subtoken_end_pos, tokenizer, annotated_answer): + subtoken_answer = ' '.join(tokenizer.tokenize(annotated_answer)) + for new_st in range(subtoken_start_pos, subtoken_end_pos + 1): + for new_ed in range(subtoken_end_pos, subtoken_start_pos - 1, -1): + text_span = ' '.join(context_subtokens[new_st:(new_ed + 1)]) + if text_span == subtoken_answer: + return new_st, new_ed + return subtoken_start_pos, subtoken_end_pos + + +def get_char_to_word_positions(context, answer, start_char_pos, is_impossible): + context_tokens = [] + char_to_word_offset = [] + is_prev_whitespace = True + for c in context: + is_whitespace = (c == ' ' or c == '\t' or c == '\r' or c == '\n' or + ord(c) == 0x202F) + if is_whitespace: + is_prev_whitespace = True + else: + if is_prev_whitespace: + context_tokens.append(c) + else: + context_tokens[-1] += c + is_prev_whitespace = False + char_to_word_offset.append(len(context_tokens) - 1) + + start_pos, end_pos = 0, 0 + if start_char_pos is not None and not is_impossible: + start_pos = char_to_word_offset[start_char_pos] + end_pos = char_to_word_offset[start_char_pos + len(answer) - 1] + return start_pos, end_pos, context_tokens + + +def check_max_context_token(all_spans, cur_span_idx, pos): + best_score, best_span_idx = None, None + for span_idx, span in enumerate(all_spans): + end = span.context_start_position + span.context_len - 1 + if pos < span.context_start_position or pos > end: + continue + num_left_context = pos - span.context_start_position + num_right_context = end - pos + score = min(num_left_context, num_right_context) + 0.01 * \ + span.context_len + if best_score is None or score > best_score: + best_score = score + best_span_idx = span_idx + return cur_span_idx == best_span_idx + + +def encode(tokenizer, text_a, text_b, max_seq_len, max_query_len, + added_trunc_size): + def _get_token_ids(text): + if isinstance(text, str): + return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)) + elif isinstance(text, (list, tuple)) and len(text) > 0 and \ + isinstance(text[0], str): + return tokenizer.convert_tokens_to_ids(text) + elif isinstance(text, (list, tuple)) and len(text) > 0 and \ + isinstance(text[0], int): + return text + else: + raise ValueError('Input is not valid, should be a string, ' + 'a list/tuple of strings or a list/tuple of ' + 'integers.') + + token_ids_a = _get_token_ids(text_a) + token_ids_b = _get_token_ids(text_b) + + # Truncate + overflow_token_ids = None + len_a = len(token_ids_a) + 2 + total_len = len(token_ids_a) + len(token_ids_b) + 3 + if len_a > max_query_len: + num_remove = len_a - max_query_len + token_ids_a = token_ids_a[:-num_remove] + if total_len > max_seq_len: + num_remove = 
total_len - max_seq_len + trunc_size = min(len(token_ids_b), added_trunc_size + num_remove) + overflow_token_ids = token_ids_b[-trunc_size:] + token_ids_b = token_ids_b[:-num_remove] + + # Combine and pad + token_ids = [tokenizer.cls_token_id] + \ + token_ids_a + [tokenizer.sep_token_id] + token_type_ids = [0] * len(token_ids) + token_ids += token_ids_b + [tokenizer.sep_token_id] + token_type_ids += [1] * (len(token_ids_b) + 1) + attention_mask = [1] * len(token_ids) + if len(token_ids) < max_seq_len: + dif = max_seq_len - len(token_ids) + token_ids += [tokenizer.pad_token_id] * dif + token_type_ids += [0] * dif + attention_mask += [0] * dif + + return NewsQAEncodedInput(token_ids, token_type_ids, attention_mask, + overflow_token_ids) + + +def create_newsqa_examples(data, split, debug=False): + if debug: + data = data[:NUM_DEBUG] + examples = [] + for para in data: + context = para['context'] + qa = para['qa'] + qa_id = qa['qid'] + question = qa['question'] + start_char_pos = None + train_answer = None + val_answer = [] + + is_impossible = qa['is_impossible'] if 'is_impossible' in qa else False + if not is_impossible: + answers = qa['detected_answers'] + spans = sorted([span for spans in answers + for span in spans['char_spans']]) + if split == 'train': + train_answer = context[spans[0][0]: spans[0][1] + 1] + start_char_pos = spans[0][0] + else: + val_answer = [{'text': context[spans[i][0]: spans[i][1] + 1], + 'answer_start': spans[i][0]} + for i in range(len(spans))] + + start_pos, end_pos, context_tokens = get_char_to_word_positions( + context, train_answer, start_char_pos, is_impossible) + examples.append(NewsQAExample(qa_id, question, context, train_answer, + val_answer, start_pos, end_pos, + context_tokens, is_impossible)) + return examples + + +def create_newsqa_dataset(data, split, tokenizer, max_seq_len, max_query_len, + trunc_stride, cache_dir='', client_id=None, + pretrain=False, debug=False, **kwargs): + if pretrain: + return create_newsqa_pretrain_dataset( + data, split, tokenizer, max_seq_len, cache_dir, client_id, debug) + + save_dir = osp.join(cache_dir, 'finetune', str(client_id)) + cache_file = osp.join(save_dir, split + '.pt') + if osp.exists(cache_file): + logger.info('Loading cache file from \'{}\''.format(cache_file)) + cache_data = torch.load(cache_file) + examples = cache_data['examples'] + encoded_inputs = cache_data['encoded_inputs'] + else: + examples = create_newsqa_examples(data, split, debug) + unique_id = 1000000000 + encoded_inputs = [] + for example_idx, example in enumerate(examples): + if split == 'train' and not example.is_impossible: + start_pos = example.start_position + end_pos = example.end_position + actual_answer = ' '.join(example.context_tokens[ + start_pos:(end_pos + 1)]) + cleaned_answer = ' '.join(example.train_answer.strip().split()) + if actual_answer.find(cleaned_answer) == -1: + logger.info('Could not find answer: {} vs. 
{}'.format( + actual_answer, cleaned_answer)) + continue + + tok_to_subtok_idx = [] + subtok_to_tok_idx = [] + context_subtokens = [] + for i, token in enumerate(example.context_tokens): + tok_to_subtok_idx.append(len(context_subtokens)) + subtokens = tokenizer.tokenize(token) + for subtoken in subtokens: + subtok_to_tok_idx.append(i) + context_subtokens.append(subtoken) + + if split == 'train' and not example.is_impossible: + subtoken_start_pos = tok_to_subtok_idx[example.start_position] + if example.end_position < len(example.context_tokens) - 1: + subtoken_end_pos = tok_to_subtok_idx[ + example.end_position + 1] - 1 + else: + subtoken_end_pos = len(context_subtokens) - 1 + subtoken_start_pos, subtoken_end_pos = refine_subtoken_position( + context_subtokens, subtoken_start_pos, subtoken_end_pos, + tokenizer, example.train_answer) + + truncated_context = context_subtokens + len_question = min(len(tokenizer.tokenize(example.question)), + max_query_len - 2) + added_trunc_size = max_seq_len - trunc_stride - len_question - 3 + spans = [] + while len(spans) * trunc_stride < len(context_subtokens): + text_a = example.question + text_b = truncated_context + encoded_input = encode(tokenizer, text_a, text_b, max_seq_len, + max_query_len, added_trunc_size) + context_start_pos = len(spans) * trunc_stride + context_len = min(len(context_subtokens) - context_start_pos, + max_seq_len - len_question - 3) + context_end_pos = context_start_pos + context_len - 1 + + if tokenizer.pad_token_id in encoded_input.token_ids: + non_padded_ids = encoded_input.token_ids[ + :encoded_input.token_ids.index( + tokenizer.pad_token_id)] + else: + non_padded_ids = encoded_input.token_ids + tokens = tokenizer.convert_ids_to_tokens(non_padded_ids) + + context_subtok_to_tok_idx = {} + for i in range(context_len): + context_idx = len_question + i + 2 + context_subtok_to_tok_idx[context_idx] = \ + subtok_to_tok_idx[context_start_pos + i] + + start_pos, end_pos = 0, 0 + span_is_impossible = example.is_impossible + if split == 'train' and not span_is_impossible: + # For training, if our document chunk does not contain + # an annotation we throw it out, since there is nothing + # to predict. 
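+                    # (Here the span is actually kept but marked impossible,
+                    # with its start/end positions set to 0.)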
+ if subtoken_start_pos >= context_start_pos and \ + subtoken_end_pos <= context_end_pos: + context_offset = len_question + 2 + start_pos = subtoken_start_pos - context_start_pos + \ + context_offset + end_pos = subtoken_end_pos - context_start_pos + \ + context_offset + else: + start_pos = 0 + end_pos = 0 + span_is_impossible = True + + encoded_input.start_position = start_pos + encoded_input.end_position = end_pos + encoded_input.is_impossible = span_is_impossible + + # For computing metrics + encoded_input.example_index = example_idx + encoded_input.context_start_position = context_start_pos + encoded_input.context_len = context_len + encoded_input.tokens = tokens + encoded_input.context_subtok_to_tok_idx = \ + context_subtok_to_tok_idx + encoded_input.is_max_context_token = {} + encoded_input.unique_id = unique_id + spans.append(encoded_input) + unique_id += 1 + + if encoded_input.overflow_token_ids is None: + break + truncated_context = encoded_input.overflow_token_ids + + for span_idx in range(len(spans)): + for context_idx in range(spans[span_idx].context_len): + is_max_context_token = check_max_context_token( + spans, span_idx, span_idx * trunc_stride + context_idx) + idx = len_question + context_idx + 2 + spans[span_idx].is_max_context_token[idx] = \ + is_max_context_token + encoded_inputs.extend(spans) + + if cache_dir: + logger.info('Saving cache file to \'{}\''.format(cache_file)) + os.makedirs(save_dir, exist_ok=True) + torch.save({'examples': examples, + 'encoded_inputs': encoded_inputs}, cache_file) + + token_ids = torch.LongTensor([inp.token_ids for inp in encoded_inputs]) + token_type_ids = torch.LongTensor([inp.token_type_ids + for inp in encoded_inputs]) + attention_mask = torch.LongTensor([inp.attention_mask + for inp in encoded_inputs]) + start_positions = torch.LongTensor([inp.start_position + for inp in encoded_inputs]) + end_positions = torch.LongTensor([inp.end_position for + inp in encoded_inputs]) + + example_indices = torch.arange(token_ids.size(0), dtype=torch.long) + dataset = DictDataset({'token_ids': token_ids, + 'token_type_ids': token_type_ids, + 'attention_mask': attention_mask, + 'start_positions': start_positions, + 'end_positions': end_positions, + 'example_indices': example_indices}) + return dataset, encoded_inputs, examples + + +def create_newsqa_pretrain_dataset(data, split, tokenizer, max_seq_len, + cache_dir='', client_id=None, debug=False): + save_dir = osp.join(cache_dir, 'pretrain', str(client_id)) + cache_file = osp.join(save_dir, split + '.pt') + if osp.exists(cache_file): + logger.info('Loading cache file from \'{}\''.format(cache_file)) + cache_data = torch.load(cache_file) + examples = cache_data['examples'] + encoded_inputs = cache_data['encoded_inputs'] + else: + examples = create_newsqa_examples(data, split, debug) + texts = split_sent([e.context for e in examples], + eoq=tokenizer.eoq_token) + encoded_inputs = tokenizer(texts, + padding='max_length', + truncation=True, + max_length=max_seq_len, + return_tensors='pt') + num_non_padding = (encoded_inputs.input_ids != + tokenizer.pad_token_id).sum(dim=-1) + for i, pad_idx in enumerate(num_non_padding): + encoded_inputs.input_ids[i, 0] = tokenizer.bos_token_id + encoded_inputs.input_ids[i, pad_idx - 1] = tokenizer.eos_token_id + + if cache_dir: + logger.info('Saving cache file to \'{}\''.format(cache_file)) + os.makedirs(save_dir, exist_ok=True) + torch.save({'examples': examples, + 'encoded_inputs': encoded_inputs}, cache_file) + + example_indices = 
torch.arange(encoded_inputs.input_ids.size(0), + dtype=torch.long) + dataset = DictDataset({'token_ids': encoded_inputs.input_ids, + 'attention_mask': encoded_inputs.attention_mask, + 'example_indices': example_indices}) + return dataset, encoded_inputs, examples diff --git a/federatedscope/nlp/dataset/preprocess/get_hfl_data.py b/federatedscope/nlp/dataset/preprocess/get_hfl_data.py new file mode 100644 index 000000000..9757011ca --- /dev/null +++ b/federatedscope/nlp/dataset/preprocess/get_hfl_data.py @@ -0,0 +1,141 @@ +import os +import logging +import random +import csv +import json +import gzip +import zipfile +import shutil +from federatedscope.core.auxiliaries.utils import download_url + +HFL_NAMES = ['imdb', 'agnews', 'squad', 'newsqa', 'cnndm', 'msqg'] +logger = logging.getLogger(__name__) + + +class HFLDataProcessor(object): + def __init__(self, config, train_frac=0.9): + self.data_dir = config.data.root + self.datasets = config.data.datasets + self.total_client_num = config.federate.client_num + self.num_grouped_clients = config.data.num_grouped_clients + self.train_frac = train_frac + + def get_data(self): + all_train_data = [] + all_val_data = [] + all_test_data = [] + for i, dataset in enumerate(self.datasets): + if dataset not in HFL_NAMES: + raise ValueError(f'No HFL dataset named {dataset}') + train_val_data = self._load_data( + dataset, 'train', self.num_grouped_clients[i]) + train_data = [data[:int(self.train_frac * len(data))] + for data in train_val_data] + val_data = [data[int(self.train_frac * len(data)):] + for data in train_val_data] + test_data = self._load_data( + dataset, 'test', self.num_grouped_clients[i]) + all_train_data.extend(train_data) + all_val_data.extend(val_data) + all_test_data.extend(test_data) + return all_train_data, all_val_data, all_test_data + + def _load_data(self, dataset, split, num_clients): + data_dir = os.path.join(self.data_dir, dataset) + if not os.path.exists(data_dir): + self._download(dataset) + self._extract(dataset) + + # read data + data = [] + if dataset == 'imdb': + pos_files = os.listdir(os.path.join(data_dir, split, 'pos')) + neg_files = os.listdir(os.path.join(data_dir, split, 'neg')) + for file in pos_files: + path = os.path.join(data_dir, split, 'pos', file) + with open(path) as f: + line = f.readline() + data.append({'text': line, 'label': 1}) + for file in neg_files: + path = os.path.join(data_dir, split, 'neg', file) + with open(path) as f: + line = f.readline() + data.append({'text': line, 'label': 0}) + random.shuffle(data) + + elif dataset == 'agnews': + with open(os.path.join(data_dir, split + '.csv'), + encoding="utf-8") as csv_file: + csv_reader = csv.reader(csv_file, + quotechar='"', + delimiter=",", + quoting=csv.QUOTE_ALL, + skipinitialspace=True) + for i, row in enumerate(csv_reader): + label, title, description = row + label = int(label) - 1 + text = ' [SEP] '.join((title, description)) + data.append({'text': text, 'label': label}) + + elif dataset == 'squad': + with open(os.path.join(data_dir, split + '.json'), 'r', + encoding='utf-8') as reader: + raw_data = json.load(reader)['data'] + for line in raw_data: + for para in line['paragraphs']: + context = para['context'] + for qa in para['qas']: + data.append({'context': context, 'qa': qa}) + + elif dataset == 'newsqa': + with gzip.GzipFile(os.path.join(data_dir, split + '.jsonl.gz'), + 'r') as reader: + content = reader.read().decode('utf-8').strip().split('\n')[1:] + raw_data = [json.loads(line) for line in content] + for line in raw_data: + context = 
line['context'] + for qa in line['qas']: + data.append({'context': context, 'qa': qa}) + + elif dataset in {'cnndm', 'msqg'}: + src_file = os.path.join(data_dir, split + '.src') + tgt_file = os.path.join(data_dir, split + '.tgt') + with open(src_file) as f: + src_data = [line.strip().replace('', '[SEP]') + for line in f] + with open(tgt_file) as f: + tgt_data = [line.strip().replace('', '[SEP]') + for line in f] + for src, tgt in zip(src_data, tgt_data): + data.append({'src': src, 'tgt': tgt}) + + # split data + logger.info(f'Spliting dataset {dataset} ({split})') + all_split_data = [] + n = len(data) // num_clients + data_idx = 0 + for i in range(num_clients): + num_split = n if i < num_clients - 1 else \ + len(data) - n * (num_clients - 1) + cur_data = data[data_idx: data_idx + num_split] + data_idx += num_split + all_split_data.append(cur_data) + logger.info(f'Client id: {i + 1}, num samples: {num_split}') + return all_split_data + + def _download(self, dataset): + url = 'https://federatedscope.oss-cn-beijing.aliyuncs.com' + os.makedirs(self.data_dir, exist_ok=True) + download_url(f'{url}/{dataset}.zip', self.data_dir) + + def _extract(self, dataset): + raw_dir = os.path.join(self.data_dir, dataset + '_raw') + extract_dir = os.path.join(self.data_dir, dataset) + with zipfile.ZipFile(os.path.join(self.data_dir, f'{dataset}.zip'), + 'r') as zip_ref: + zip_ref.extractall(raw_dir) + shutil.move(os.path.join(raw_dir, dataset), self.data_dir) + if os.path.exists(os.path.join(extract_dir, '.DS_Store')): + os.remove(os.path.join(extract_dir, '.DS_Store')) + os.remove(os.path.join(self.data_dir, f'{dataset}.zip')) + shutil.rmtree(raw_dir) diff --git a/federatedscope/nlp/dataset/squad.py b/federatedscope/nlp/dataset/squad.py new file mode 100644 index 000000000..b15ff79c5 --- /dev/null +++ b/federatedscope/nlp/dataset/squad.py @@ -0,0 +1,356 @@ +import os +import os.path as osp +import torch +import logging +from federatedscope.nlp.dataset.utils import split_sent, DictDataset, NUM_DEBUG + +logger = logging.getLogger(__name__) + + +class SquadExample(object): + def __init__(self, qa_id, question, context, train_answer, val_answer, + start_pos, end_pos, context_tokens, is_impossible): + self.qa_id = qa_id + self.question = question + self.context = context + self.train_answer = train_answer + self.val_answer = val_answer + self.start_position = start_pos + self.end_position = end_pos + self.context_tokens = context_tokens + self.is_impossible = is_impossible + + +class SquadEncodedInput(object): + def __init__(self, token_ids, token_type_ids, attention_mask, + overflow_token_ids): + self.token_ids = token_ids + self.token_type_ids = token_type_ids + self.attention_mask = attention_mask + self.overflow_token_ids = overflow_token_ids + + +class SquadResult(object): + def __init__(self, unique_id, start_logits, end_logits): + self.unique_id = unique_id + self.start_logits = start_logits + self.end_logits = end_logits + + +def refine_subtoken_position(context_subtokens, subtoken_start_pos, + subtoken_end_pos, tokenizer, annotated_answer): + subtoken_answer = ' '.join(tokenizer.tokenize(annotated_answer)) + for new_st in range(subtoken_start_pos, subtoken_end_pos + 1): + for new_ed in range(subtoken_end_pos, subtoken_start_pos - 1, -1): + text_span = ' '.join(context_subtokens[new_st:(new_ed + 1)]) + if text_span == subtoken_answer: + return new_st, new_ed + return subtoken_start_pos, subtoken_end_pos + + +def get_char_to_word_positions(context, answer, start_char_pos, is_impossible): + context_tokens = 
[] + char_to_word_offset = [] + is_prev_whitespace = True + for c in context: + is_whitespace = (c == ' ' or c == '\t' or c == '\r' or c == '\n' or + ord(c) == 0x202F) + if is_whitespace: + is_prev_whitespace = True + else: + if is_prev_whitespace: + context_tokens.append(c) + else: + context_tokens[-1] += c + is_prev_whitespace = False + char_to_word_offset.append(len(context_tokens) - 1) + + start_pos, end_pos = 0, 0 + if start_char_pos is not None and not is_impossible: + start_pos = char_to_word_offset[start_char_pos] + end_pos = char_to_word_offset[start_char_pos + len(answer) - 1] + return start_pos, end_pos, context_tokens + + +def check_max_context_token(all_spans, cur_span_idx, pos): + best_score, best_span_idx = None, None + for span_idx, span in enumerate(all_spans): + end = span.context_start_position + span.context_len - 1 + if pos < span.context_start_position or pos > end: + continue + num_left_context = pos - span.context_start_position + num_right_context = end - pos + score = min(num_left_context, num_right_context) + 0.01 * \ + span.context_len + if best_score is None or score > best_score: + best_score = score + best_span_idx = span_idx + return cur_span_idx == best_span_idx + + +def encode(tokenizer, text_a, text_b, max_seq_len, max_query_len, + added_trunc_size): + def _get_token_ids(text): + if isinstance(text, str): + return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)) + elif isinstance(text, (list, tuple)) and len(text) > 0 and \ + isinstance(text[0], str): + return tokenizer.convert_tokens_to_ids(text) + elif isinstance(text, (list, tuple)) and len(text) > 0 and \ + isinstance(text[0], int): + return text + else: + raise ValueError('Input is not valid, should be a string, ' + 'a list/tuple of strings or a list/tuple of ' + 'integers.') + + token_ids_a = _get_token_ids(text_a) + token_ids_b = _get_token_ids(text_b) + + # Truncate + overflow_token_ids = None + len_a = len(token_ids_a) + 2 + total_len = len(token_ids_a) + len(token_ids_b) + 3 + if len_a > max_query_len: + num_remove = len_a - max_query_len + token_ids_a = token_ids_a[:-num_remove] + if total_len > max_seq_len: + num_remove = total_len - max_seq_len + trunc_size = min(len(token_ids_b), added_trunc_size + num_remove) + overflow_token_ids = token_ids_b[-trunc_size:] + token_ids_b = token_ids_b[:-num_remove] + + # Combine and pad + token_ids = [tokenizer.cls_token_id] + \ + token_ids_a + [tokenizer.sep_token_id] + token_type_ids = [0] * len(token_ids) + token_ids += token_ids_b + [tokenizer.sep_token_id] + token_type_ids += [1] * (len(token_ids_b) + 1) + attention_mask = [1] * len(token_ids) + if len(token_ids) < max_seq_len: + dif = max_seq_len - len(token_ids) + token_ids += [tokenizer.pad_token_id] * dif + token_type_ids += [0] * dif + attention_mask += [0] * dif + + return SquadEncodedInput(token_ids, token_type_ids, attention_mask, + overflow_token_ids) + + +def create_squad_examples(data, split, debug=False): + if debug: + data = data[:NUM_DEBUG] + examples = [] + for para in data: + context = para['context'] + qa = para['qa'] + qa_id = qa['id'] + question = qa['question'] + start_char_pos = None + train_answer = None + val_answer = [] + + is_impossible = qa['is_impossible'] if 'is_impossible' in qa else False + if not is_impossible: + if split == 'train': + train_answer = qa['answers'][0]['text'] + start_char_pos = qa['answers'][0]['answer_start'] + else: + val_answer = qa['answers'] + + start_pos, end_pos, context_tokens = get_char_to_word_positions( + context, train_answer, 
start_char_pos, is_impossible) + examples.append(SquadExample(qa_id, question, context, train_answer, + val_answer, start_pos, end_pos, + context_tokens, is_impossible)) + return examples + + +def create_squad_dataset(data, split, tokenizer, max_seq_len, max_query_len, + trunc_stride, cache_dir='', client_id=None, + pretrain=False, debug=False, **kwargs): + if pretrain: + return create_squad_pretrain_dataset( + data, split, tokenizer, max_seq_len, cache_dir, client_id, debug) + + save_dir = osp.join(cache_dir, 'finetune', str(client_id)) + cache_file = osp.join(save_dir, split + '.pt') + if osp.exists(cache_file): + logger.info('Loading cache file from \'{}\''.format(cache_file)) + cache_data = torch.load(cache_file) + examples = cache_data['examples'] + encoded_inputs = cache_data['encoded_inputs'] + else: + examples = create_squad_examples(data, split, debug) + unique_id = 1000000000 + encoded_inputs = [] + for example_idx, example in enumerate(examples): + if split == 'train' and not example.is_impossible: + start_pos = example.start_position + end_pos = example.end_position + actual_answer = ' '.join(example.context_tokens[ + start_pos:(end_pos + 1)]) + cleaned_answer = ' '.join(example.train_answer.strip().split()) + if actual_answer.find(cleaned_answer) == -1: + logger.info('Could not find answer: {} vs. {}'.format( + actual_answer, cleaned_answer)) + continue + + tok_to_subtok_idx = [] + subtok_to_tok_idx = [] + context_subtokens = [] + for i, token in enumerate(example.context_tokens): + tok_to_subtok_idx.append(len(context_subtokens)) + subtokens = tokenizer.tokenize(token) + for subtoken in subtokens: + subtok_to_tok_idx.append(i) + context_subtokens.append(subtoken) + + if split == 'train' and not example.is_impossible: + subtoken_start_pos = tok_to_subtok_idx[example.start_position] + if example.end_position < len(example.context_tokens) - 1: + subtoken_end_pos = tok_to_subtok_idx[ + example.end_position + 1] - 1 + else: + subtoken_end_pos = len(context_subtokens) - 1 + subtoken_start_pos, subtoken_end_pos = refine_subtoken_position( + context_subtokens, subtoken_start_pos, subtoken_end_pos, + tokenizer, example.train_answer) + + truncated_context = context_subtokens + len_question = min(len(tokenizer.tokenize(example.question)), + max_query_len - 2) + added_trunc_size = max_seq_len - trunc_stride - len_question - 3 + spans = [] + while len(spans) * trunc_stride < len(context_subtokens): + text_a = example.question + text_b = truncated_context + encoded_input = encode(tokenizer, text_a, text_b, max_seq_len, + max_query_len, added_trunc_size) + context_start_pos = len(spans) * trunc_stride + context_len = min(len(context_subtokens) - context_start_pos, + max_seq_len - len_question - 3) + context_end_pos = context_start_pos + context_len - 1 + + if tokenizer.pad_token_id in encoded_input.token_ids: + non_padded_ids = encoded_input.token_ids[ + :encoded_input.token_ids.index( + tokenizer.pad_token_id)] + else: + non_padded_ids = encoded_input.token_ids + tokens = tokenizer.convert_ids_to_tokens(non_padded_ids) + + context_subtok_to_tok_idx = {} + for i in range(context_len): + context_idx = len_question + i + 2 + context_subtok_to_tok_idx[context_idx] = \ + subtok_to_tok_idx[context_start_pos + i] + + start_pos, end_pos = 0, 0 + span_is_impossible = example.is_impossible + if split == 'train' and not span_is_impossible: + # For training, if our document chunk does not contain an annotation + # we throw it out, since there is nothing to predict. 
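+                    # The kept start/end positions are window-relative and
+                    # offset by the question length plus two special tokens
+                    # ([CLS] and the [SEP] after the question); spans whose
+                    # answer lies outside the window are flagged impossible.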
+ if subtoken_start_pos >= context_start_pos and \ + subtoken_end_pos <= context_end_pos: + context_offset = len_question + 2 + start_pos = subtoken_start_pos - context_start_pos + \ + context_offset + end_pos = subtoken_end_pos - context_start_pos + \ + context_offset + else: + start_pos = 0 + end_pos = 0 + span_is_impossible = True + + encoded_input.start_position = start_pos + encoded_input.end_position = end_pos + encoded_input.is_impossible = span_is_impossible + + # For computing metrics + encoded_input.example_index = example_idx + encoded_input.context_start_position = context_start_pos + encoded_input.context_len = context_len + encoded_input.tokens = tokens + encoded_input.context_subtok_to_tok_idx = \ + context_subtok_to_tok_idx + encoded_input.is_max_context_token = {} + encoded_input.unique_id = unique_id + spans.append(encoded_input) + unique_id += 1 + + if encoded_input.overflow_token_ids is None: + break + truncated_context = encoded_input.overflow_token_ids + + for span_idx in range(len(spans)): + for context_idx in range(spans[span_idx].context_len): + is_max_context_token = check_max_context_token( + spans, span_idx, span_idx * trunc_stride + context_idx) + idx = len_question + context_idx + 2 + spans[span_idx].is_max_context_token[idx] = \ + is_max_context_token + encoded_inputs.extend(spans) + + if cache_dir: + logger.info('Saving cache file to \'{}\''.format(cache_file)) + os.makedirs(save_dir, exist_ok=True) + torch.save({'examples': examples, + 'encoded_inputs': encoded_inputs}, cache_file) + + token_ids = torch.LongTensor([inp.token_ids for inp in encoded_inputs]) + token_type_ids = torch.LongTensor([inp.token_type_ids + for inp in encoded_inputs]) + attention_mask = torch.LongTensor([inp.attention_mask + for inp in encoded_inputs]) + start_positions = torch.LongTensor([inp.start_position + for inp in encoded_inputs]) + end_positions = torch.LongTensor([inp.end_position + for inp in encoded_inputs]) + + example_indices = torch.arange(token_ids.size(0), dtype=torch.long) + dataset = DictDataset({'token_ids': token_ids, + 'token_type_ids': token_type_ids, + 'attention_mask': attention_mask, + 'start_positions': start_positions, + 'end_positions': end_positions, + 'example_indices': example_indices}) + return dataset, encoded_inputs, examples + + +def create_squad_pretrain_dataset(data, split, tokenizer, max_seq_len, + cache_dir='', client_id=None, debug=False): + save_dir = osp.join(cache_dir, 'pretrain', str(client_id)) + cache_file = osp.join(save_dir, split + '.pt') + if osp.exists(cache_file): + logger.info('Loading cache file from \'{}\''.format(cache_file)) + cache_data = torch.load(cache_file) + examples = cache_data['examples'] + encoded_inputs = cache_data['encoded_inputs'] + else: + examples = create_squad_examples(data, split, debug) + texts = split_sent([e.context for e in examples], + eoq=tokenizer.eoq_token) + encoded_inputs = tokenizer(texts, + padding='max_length', + truncation=True, + max_length=max_seq_len, + return_tensors='pt') + num_non_padding = (encoded_inputs.input_ids != + tokenizer.pad_token_id).sum(dim=-1) + for i, pad_idx in enumerate(num_non_padding): + encoded_inputs.input_ids[i, 0] = tokenizer.bos_token_id + encoded_inputs.input_ids[i, pad_idx - 1] = tokenizer.eos_token_id + + if cache_dir: + logger.info('Saving cache file to \'{}\''.format(cache_file)) + os.makedirs(save_dir, exist_ok=True) + torch.save({'examples': examples, + 'encoded_inputs': encoded_inputs}, cache_file) + + example_indices = 
torch.arange(encoded_inputs.input_ids.size(0), + dtype=torch.long) + dataset = DictDataset({'token_ids': encoded_inputs.input_ids, + 'attention_mask': encoded_inputs.attention_mask, + 'example_indices': example_indices}) + return dataset, encoded_inputs, examples diff --git a/federatedscope/nlp/dataset/utils.py b/federatedscope/nlp/dataset/utils.py index e31e306ed..94b88cf95 100644 --- a/federatedscope/nlp/dataset/utils.py +++ b/federatedscope/nlp/dataset/utils.py @@ -6,7 +6,12 @@ import re import numpy as np +import json from collections import Counter +from nltk.tokenize import sent_tokenize +from torch.utils.data.dataset import Dataset +from transformers.models.bert import BertTokenizerFast + # ------------------------ # utils for shakespeare dataset @@ -14,6 +19,7 @@ ALL_LETTERS = "\n !\"&'(),-.0123456789:;>?ABCDEFGHIJKLMNOPQRSTUVWXYZ[" \ "]abcdefghijklmnopqrstuvwxyz}" NUM_LETTERS = len(ALL_LETTERS) +NUM_DEBUG = 20 def _one_hot(index, size): @@ -88,3 +94,53 @@ def label_to_index(labels): sorted_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True) label_list = [x[0] for x in sorted_tuples] return [label_list.index(x) for x in labels] + + +def split_sent(examples, eoq='[unused2]', tokenize=True): + new_examples = [] + for e in examples: + if tokenize: + e = f' {eoq} '.join(sent_tokenize(e)) + else: + e = e.replace('[SEP]', eoq) + new_examples.append(e) + return new_examples + + +class DictDataset(Dataset): + def __init__(self, inputs): + super().__init__() + assert all(list(inputs.values())[0].size(0) == v.size(0) + for v in inputs.values()), "Size mismatch between tensors" + self.inputs = inputs + + def __getitem__(self, index): + return {k: v[index] for k, v in self.inputs.items()} + + def __len__(self): + return list(self.inputs.values())[0].size(0) + + +def setup_tokenizer(config): + bos_token, eos_token, eoq_token = \ + config.model.bos_token, config.model.eos_token, config.model.eoq_token + try: + tokenizer = BertTokenizerFast.from_pretrained( + config.model.model_type, + additional_special_tokens=[bos_token, eos_token, eoq_token], + skip_special_tokens=True, + local_files_only=True, + ) + except: + tokenizer = BertTokenizerFast.from_pretrained( + config.model.model_type, + additional_special_tokens=[bos_token, eos_token, eoq_token], + skip_special_tokens=True, + ) + tokenizer.bos_token = bos_token + tokenizer.eos_token = eos_token + tokenizer.eoq_token = eoq_token + tokenizer.bos_token_id = tokenizer.vocab[bos_token] + tokenizer.eos_token_id = tokenizer.vocab[eos_token] + tokenizer.eoq_token_id = tokenizer.vocab[eoq_token] + return tokenizer diff --git a/federatedscope/nlp/trainer/utils.py b/federatedscope/nlp/trainer/utils.py new file mode 100644 index 000000000..0b610a8e6 --- /dev/null +++ b/federatedscope/nlp/trainer/utils.py @@ -0,0 +1,75 @@ + +class AverageMeter(object): + def __init__(self): + self.reset() + + def reset(self): + self.avg = 0 + self.sum = 0 + self.cnt = 0 + self.val = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.cnt += n + self.avg = self.sum / self.cnt + + +class ContrastiveMonitor(object): + def __init__(self, stat=1, enc_hidden=None, synth_tokens=None, + dec_hidden=None, dec_out=None, all_group_ids=None, topk_group_ids=None): + self.stat = stat + self.enc_hidden = enc_hidden + self.synth_tokens = synth_tokens + self.dec_hidden = dec_hidden + self.dec_out = dec_out + self.all_group_ids = all_group_ids + self.topk_group_ids = topk_group_ids + + def update_stat(self, status): + self.stat = status + + def 
update_all_group_ids(self, group_ids): + self.all_group_ids = group_ids + + def update_topk_group_ids(self, group_ids): + self.topk_group_ids = group_ids + + def update_enc_hidden(self, enc_hidden, k=None): + if k is None: + self.enc_hidden = enc_hidden + else: + if self.enc_hidden is None: + self.enc_hidden = {} + self.enc_hidden[k] = enc_hidden + + def update_synth_tokens(self, synth_tokens, k=None): + if k is None: + self.synth_tokens = synth_tokens + else: + if self.synth_tokens is None: + self.synth_tokens = {} + self.synth_tokens[k] = synth_tokens + + def update_dec_hidden(self, dec_hidden, k=None): + if k is None: + self.dec_hidden = dec_hidden + else: + if self.dec_hidden is None: + self.dec_hidden = {} + self.dec_hidden[k] = dec_hidden + + def update_dec_out(self, dec_out, k=None): + if k is None: + self.dec_out = dec_out + else: + if self.dec_out is None: + self.dec_out = {} + self.dec_out[k] = dec_out + + def reset(self): + self.stat = 1 + self.dec_hidden = None + self.dec_out = None + self.group_ids = None diff --git a/scripts/fednlp_exp_scripts/fedavg/config_client_fedavg.yaml b/scripts/fednlp_exp_scripts/fedavg/config_client_fedavg.yaml new file mode 100644 index 000000000..b7632e0da --- /dev/null +++ b/scripts/fednlp_exp_scripts/fedavg/config_client_fedavg.yaml @@ -0,0 +1,94 @@ +client_group_1: + data: + batch_size: 32 + max_seq_len: 128 + model: + task: imdb + num_labels: 2 + trainer: + train_steps: 200 + grad_accum_count: 1 + eval: + metrics: ['acc'] +client_group_2: + data: + batch_size: 32 + max_seq_len: 128 + model: + task: agnews + num_labels: 4 + trainer: + train_steps: 200 + grad_accum_count: 1 + eval: + metrics: ['acc'] +client_group_3: + data: + batch_size: 32 + max_seq_len: 384 + max_query_len: 128 + trunc_stride: 128 + model: + task: squad + num_labels: 2 + n_best_size: 20 + max_answer_len: 30 + null_score_diff_threshold: 0.0 + trainer: + train_steps: 200 + grad_accum_count: 1 + eval: + metrics: ['squad'] +client_group_4: + data: + batch_size: 32 + max_seq_len: 384 + max_query_len: 128 + trunc_stride: 128 + model: + task: newsqa + num_labels: 2 + n_best_size: 20 + max_answer_len: 30 + null_score_diff_threshold: 0.0 + trainer: + train_steps: 200 + grad_accum_count: 1 + eval: + metrics: ['newsqa'] +client_group_5: + data: + batch_size: 32 + max_seq_len: 384 + max_tgt_len: 128 + model: + task: cnndm + max_length: 150 + min_length: 50 + no_repeat_ngram_size: 3 + length_penalty: 2.0 + num_beams: 5 + label_smoothing: 0.1 + trainer: + train_steps: 200 + grad_accum_count: 1 + eval: + metrics: ['cnndm'] +client_group_6: + data: + batch_size: 32 + max_seq_len: 384 + max_tgt_len: 64 + model: + task: msqg + max_length: 100 + min_length: 1 + no_repeat_ngram_size: 3 + length_penalty: 2.0 + num_beams: 5 + label_smoothing: 0.1 + trainer: + train_steps: 200 + grad_accum_count: 1 + eval: + metrics: ['msqg'] diff --git a/scripts/fednlp_exp_scripts/fedavg/config_fedavg.yaml b/scripts/fednlp_exp_scripts/fedavg/config_fedavg.yaml new file mode 100644 index 000000000..c7aee3685 --- /dev/null +++ b/scripts/fednlp_exp_scripts/fedavg/config_fedavg.yaml @@ -0,0 +1,39 @@ +use_gpu: True +device: 0 +seed: 12345 +outdir: exp/fedavg/ +federate: + mode: standalone + method: fedavg + total_round_num: 100 + client_num: 18 +data: + type: fednlp_data + root: datasets/ + datasets: ['imdb', 'agnews', 'squad', 'newsqa', 'cnndm', 'msqg'] + num_grouped_clients: [1, 3, 3, 2, 5, 4] + num_workers: 0 + cache_dir: cache/ +model: + type: fednlp_model + model_type: google/bert_uncased_L-2_H-128_A-2 + bos_token: 
'[unused0]' + eos_token: '[unused1]' + eoq_token: '[unused2]' +personalization: + local_param: ['classifier', 'encoder.pooler', 'decoder'] +trainer: + type: fednlp_trainer +train: + optimizer: + type: AdamW + lr: 5e-4 + weight_decay: 0.01 + grad_clip: 1.0 + scheduler: + type: step + warmup_ratio: 0.1 +eval: + split: ['test'] + report: ['group_avg'] + freq: 100000000 # eval freq across rounds diff --git a/scripts/fednlp_exp_scripts/fedavg/run.sh b/scripts/fednlp_exp_scripts/fedavg/run.sh new file mode 100644 index 000000000..40a2c5a5e --- /dev/null +++ b/scripts/fednlp_exp_scripts/fedavg/run.sh @@ -0,0 +1,10 @@ +DEVICE=$1 +CFG_DIR="$( dirname -- "$0"; )" +EXP_DIR="exp/fedavg/" + +python federatedscope/main.py \ + --cfg $CFG_DIR/config_fedavg.yaml \ + --client_cfg $CFG_DIR/config_client_fedavg.yaml \ + outdir $EXP_DIR/train/ \ + device $DEVICE \ + data.debug True \ From 3cc21f355da983d024f04b8061dadcb7fe386094 Mon Sep 17 00:00:00 2001 From: cheneydon Date: Tue, 25 Oct 2022 10:20:11 +0800 Subject: [PATCH 2/5] update dataset for hetero-fednlp --- .../core/auxiliaries/data_builder.py | 1 - federatedscope/core/configs/cfg_training.py | 1 + .../nlp/dataloader/data_collator.py | 11 ++++--- federatedscope/nlp/dataset/newsqa.py | 27 +++++++++------- federatedscope/nlp/dataset/squad.py | 32 +++++++++++-------- 5 files changed, 43 insertions(+), 29 deletions(-) diff --git a/federatedscope/core/auxiliaries/data_builder.py b/federatedscope/core/auxiliaries/data_builder.py index 798f0bd24..3bcaa0381 100644 --- a/federatedscope/core/auxiliaries/data_builder.py +++ b/federatedscope/core/auxiliaries/data_builder.py @@ -17,7 +17,6 @@ f'{error} in `federatedscope.contrib.data`, some modules are not ' f'available.') - # TODO: Add PyGNodeDataTranslator and PyGLinkDataTranslator # TODO: move splitter to PyGNodeDataTranslator and PyGLinkDataTranslator TRANS_DATA_MAP = { diff --git a/federatedscope/core/configs/cfg_training.py b/federatedscope/core/configs/cfg_training.py index 57a1d2e44..17eca3882 100644 --- a/federatedscope/core/configs/cfg_training.py +++ b/federatedscope/core/configs/cfg_training.py @@ -7,6 +7,7 @@ def extend_training_cfg(cfg): # Trainer related options # ---------------------------------------------------------------------- # cfg.trainer = CN() + cfg.trainer.type = 'general' # fednlp diff --git a/federatedscope/nlp/dataloader/data_collator.py b/federatedscope/nlp/dataloader/data_collator.py index 34fc024ac..289806763 100644 --- a/federatedscope/nlp/dataloader/data_collator.py +++ b/federatedscope/nlp/dataloader/data_collator.py @@ -24,7 +24,8 @@ def __call__(self, examples): probability_matrix = torch.full(labels.shape, self.mlm_probability) special_tokens_mask = [ self.tokenizer.get_special_tokens_mask( - val, already_has_special_tokens=True) for val in labels.tolist() + val, already_has_special_tokens=True) + for val in labels.tolist() ] probability_matrix.masked_fill_( torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) @@ -42,8 +43,9 @@ def __call__(self, examples): self.tokenizer.mask_token) # 10% of the time, we replace masked input tokens with random word - indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() \ - & masked_indices & ~indices_replaced + indices_random = \ + torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & \ + masked_indices & ~indices_replaced random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long) token_ids[indices_random] = random_words[indices_random] @@ -107,7 +109,8 @@ def permutate_sentences(self, 
inputs): full_stops = (inputs[i] == self.tokenizer.eoq_token_id) | ( inputs[i] == self.tokenizer.eos_token_id) full_stops = full_stops[None, :] - sentence_ends = np.argwhere(full_stops[:, 1:] * ~full_stops[:, :-1]) + sentence_ends = np.argwhere( + full_stops[:, 1:] * ~full_stops[:, :-1]) if len(sentence_ends) == 0: continue diff --git a/federatedscope/nlp/dataset/newsqa.py b/federatedscope/nlp/dataset/newsqa.py index 67523b682..04eb44110 100644 --- a/federatedscope/nlp/dataset/newsqa.py +++ b/federatedscope/nlp/dataset/newsqa.py @@ -80,8 +80,8 @@ def check_max_context_token(all_spans, cur_span_idx, pos): continue num_left_context = pos - span.context_start_position num_right_context = end - pos - score = min(num_left_context, num_right_context) + 0.01 * \ - span.context_len + score = \ + min(num_left_context, num_right_context) + 0.01 * span.context_len if best_score is None or score > best_score: best_score = score best_span_idx = span_idx @@ -121,8 +121,8 @@ def _get_token_ids(text): token_ids_b = token_ids_b[:-num_remove] # Combine and pad - token_ids = [tokenizer.cls_token_id] + \ - token_ids_a + [tokenizer.sep_token_id] + token_ids = \ + [tokenizer.cls_token_id] + token_ids_a + [tokenizer.sep_token_id] token_type_ids = [0] * len(token_ids) token_ids += token_ids_b + [tokenizer.sep_token_id] token_type_ids += [1] * (len(token_ids_b) + 1) @@ -218,9 +218,12 @@ def create_newsqa_dataset(data, split, tokenizer, max_seq_len, max_query_len, example.end_position + 1] - 1 else: subtoken_end_pos = len(context_subtokens) - 1 - subtoken_start_pos, subtoken_end_pos = refine_subtoken_position( - context_subtokens, subtoken_start_pos, subtoken_end_pos, - tokenizer, example.train_answer) + subtoken_start_pos, subtoken_end_pos = \ + refine_subtoken_position(context_subtokens, + subtoken_start_pos, + subtoken_end_pos, + tokenizer, + example.train_answer) truncated_context = context_subtokens len_question = min(len(tokenizer.tokenize(example.question)), @@ -260,10 +263,12 @@ def create_newsqa_dataset(data, split, tokenizer, max_seq_len, max_query_len, if subtoken_start_pos >= context_start_pos and \ subtoken_end_pos <= context_end_pos: context_offset = len_question + 2 - start_pos = subtoken_start_pos - context_start_pos + \ - context_offset - end_pos = subtoken_end_pos - context_start_pos + \ - context_offset + start_pos = \ + subtoken_start_pos - context_start_pos + \ + context_offset + end_pos = \ + subtoken_end_pos - context_start_pos + \ + context_offset else: start_pos = 0 end_pos = 0 diff --git a/federatedscope/nlp/dataset/squad.py b/federatedscope/nlp/dataset/squad.py index b15ff79c5..3e30d7b71 100644 --- a/federatedscope/nlp/dataset/squad.py +++ b/federatedscope/nlp/dataset/squad.py @@ -80,8 +80,8 @@ def check_max_context_token(all_spans, cur_span_idx, pos): continue num_left_context = pos - span.context_start_position num_right_context = end - pos - score = min(num_left_context, num_right_context) + 0.01 * \ - span.context_len + score = \ + min(num_left_context, num_right_context) + 0.01 * span.context_len if best_score is None or score > best_score: best_score = score best_span_idx = span_idx @@ -121,8 +121,8 @@ def _get_token_ids(text): token_ids_b = token_ids_b[:-num_remove] # Combine and pad - token_ids = [tokenizer.cls_token_id] + \ - token_ids_a + [tokenizer.sep_token_id] + token_ids = \ + [tokenizer.cls_token_id] + token_ids_a + [tokenizer.sep_token_id] token_type_ids = [0] * len(token_ids) token_ids += token_ids_b + [tokenizer.sep_token_id] token_type_ids += [1] * (len(token_ids_b) + 
1) @@ -213,9 +213,12 @@ def create_squad_dataset(data, split, tokenizer, max_seq_len, max_query_len, example.end_position + 1] - 1 else: subtoken_end_pos = len(context_subtokens) - 1 - subtoken_start_pos, subtoken_end_pos = refine_subtoken_position( - context_subtokens, subtoken_start_pos, subtoken_end_pos, - tokenizer, example.train_answer) + subtoken_start_pos, subtoken_end_pos = \ + refine_subtoken_position(context_subtokens, + subtoken_start_pos, + subtoken_end_pos, + tokenizer, + example.train_answer) truncated_context = context_subtokens len_question = min(len(tokenizer.tokenize(example.question)), @@ -249,15 +252,18 @@ def create_squad_dataset(data, split, tokenizer, max_seq_len, max_query_len, start_pos, end_pos = 0, 0 span_is_impossible = example.is_impossible if split == 'train' and not span_is_impossible: - # For training, if our document chunk does not contain an annotation - # we throw it out, since there is nothing to predict. + # For training, if our document chunk does not contain + # an annotation we throw it out, since there is nothing + # to predict. if subtoken_start_pos >= context_start_pos and \ subtoken_end_pos <= context_end_pos: context_offset = len_question + 2 - start_pos = subtoken_start_pos - context_start_pos + \ - context_offset - end_pos = subtoken_end_pos - context_start_pos + \ - context_offset + start_pos = \ + subtoken_start_pos - context_start_pos + \ + context_offset + end_pos = \ + subtoken_end_pos - context_start_pos + \ + context_offset else: start_pos = 0 end_pos = 0 From 6a1b1f901f4dee5a6ef134427bd8ac2bf2de238b Mon Sep 17 00:00:00 2001 From: cheneydon Date: Tue, 25 Oct 2022 12:10:05 +0800 Subject: [PATCH 3/5] update dataset for hetero-fednlp --- .../nlp/dataloader/data_collator.py | 117 +++++--- .../nlp/dataloader/hfl_dataloader.py | 284 +++++++++--------- federatedscope/nlp/dataset/agnews.py | 56 ++-- federatedscope/nlp/dataset/cnndm.py | 163 +++++----- federatedscope/nlp/dataset/imdb.py | 57 ++-- federatedscope/nlp/dataset/msqg.py | 164 +++++----- federatedscope/nlp/dataset/newsqa.py | 121 +++++--- .../nlp/dataset/preprocess/get_hfl_data.py | 35 ++- federatedscope/nlp/dataset/squad.py | 108 ++++--- federatedscope/nlp/dataset/utils.py | 6 +- federatedscope/nlp/trainer/utils.py | 11 +- 11 files changed, 641 insertions(+), 481 deletions(-) diff --git a/federatedscope/nlp/dataloader/data_collator.py b/federatedscope/nlp/dataloader/data_collator.py index 289806763..d7498d7c1 100644 --- a/federatedscope/nlp/dataloader/data_collator.py +++ b/federatedscope/nlp/dataloader/data_collator.py @@ -12,8 +12,10 @@ def __init__(self, tokenizer, mlm_probability=0.15): def __call__(self, examples): """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. 
""" - examples = {k: torch.stack([x[k] for x in examples]) - for k in examples[0].keys()} + examples = { + k: torch.stack([x[k] for x in examples]) + for k in examples[0].keys() + } token_ids = examples['token_ids'] attention_mask = examples['attention_mask'] labels = token_ids.clone() @@ -27,8 +29,9 @@ def __call__(self, examples): val, already_has_special_tokens=True) for val in labels.tolist() ] - probability_matrix.masked_fill_( - torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) + probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, + dtype=torch.bool), + value=0.0) if self.tokenizer._pad_token is not None: padding_mask = labels.eq(self.tokenizer.pad_token_id) probability_matrix.masked_fill_(padding_mask, value=0.0) @@ -37,8 +40,8 @@ def __call__(self, examples): # 80% of the time, we replace masked input tokens with # tokenizer.mask_token ([MASK]) - indices_replaced = torch.bernoulli( - torch.full(labels.shape, 0.8)).bool() & masked_indices + indices_replaced = torch.bernoulli(torch.full( + labels.shape, 0.8)).bool() & masked_indices token_ids[indices_replaced] = self.tokenizer.convert_tokens_to_ids( self.tokenizer.mask_token) @@ -46,16 +49,19 @@ def __call__(self, examples): indices_random = \ torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & \ masked_indices & ~indices_replaced - random_words = torch.randint(len(self.tokenizer), labels.shape, + random_words = torch.randint(len(self.tokenizer), + labels.shape, dtype=torch.long) token_ids[indices_random] = random_words[indices_random] # The rest of the time (10% of the time) we keep the masked input # tokens unchanged - return {'token_ids': token_ids, - 'attention_mask': attention_mask, - 'labels': labels, - 'example_indices': examples['example_indices']} + return { + 'token_ids': token_ids, + 'attention_mask': attention_mask, + 'labels': labels, + 'example_indices': examples['example_indices'] + } class DataCollatorForDenoisingTasks(object): @@ -66,7 +72,10 @@ class DataCollatorForDenoisingTasks(object): The default paramters is based on BART paper https://arxiv.org/abs/1910.13461. 
""" - def __init__(self, tokenizer, mask_ratio=0.3, poisson_lambda=3.0, + def __init__(self, + tokenizer, + mask_ratio=0.3, + poisson_lambda=3.0, permutate_sentence_ratio=1.0): self.tokenizer = tokenizer self.mask_ratio = mask_ratio @@ -74,8 +83,10 @@ def __init__(self, tokenizer, mask_ratio=0.3, poisson_lambda=3.0, self.permutate_sentence_ratio = permutate_sentence_ratio def __call__(self, examples): - examples = {k: torch.stack([x[k] for x in examples]) - for k in examples[0].keys()} + examples = { + k: torch.stack([x[k] for x in examples]) + for k in examples[0].keys() + } token_ids = examples['token_ids'].numpy() attention_mask = examples['attention_mask'].numpy() labels = token_ids.copy() @@ -89,52 +100,55 @@ def __call__(self, examples): if self.mask_ratio: token_ids, _ = self.add_whole_word_mask(token_ids, do_permutate) - num_non_padding = np.sum( - token_ids != self.tokenizer.pad_token_id, axis=-1) + num_non_padding = np.sum(token_ids != self.tokenizer.pad_token_id, + axis=-1) for i in range(len(attention_mask)): attention_mask[i][num_non_padding[i]:] = 0 token_ids = torch.from_numpy(token_ids) attention_mask = torch.from_numpy(attention_mask) labels = torch.from_numpy(labels) - return {'token_ids': token_ids, - 'attention_mask': attention_mask, - 'labels': labels, - 'example_indices': examples['example_indices']} + return { + 'token_ids': token_ids, + 'attention_mask': attention_mask, + 'labels': labels, + 'example_indices': examples['example_indices'] + } def permutate_sentences(self, inputs): results = inputs.copy() for i in range(inputs.shape[0]): full_stops = (inputs[i] == self.tokenizer.eoq_token_id) | ( - inputs[i] == self.tokenizer.eos_token_id) + inputs[i] == self.tokenizer.eos_token_id) full_stops = full_stops[None, :] - sentence_ends = np.argwhere( - full_stops[:, 1:] * ~full_stops[:, :-1]) + sentence_ends = np.argwhere(full_stops[:, 1:] * + ~full_stops[:, :-1]) if len(sentence_ends) == 0: continue sentence_ends[:, 1] += 2 - num_sentences = np.unique( - sentence_ends[:, 0], return_counts=True)[1] + num_sentences = np.unique(sentence_ends[:, 0], + return_counts=True)[1] num_to_permute = np.ceil( (num_sentences * 2 * self.permutate_sentence_ratio) / 2.0).astype(int) sentence_ends = np.split( - sentence_ends[:, 1], np.unique( - sentence_ends[:, 0], return_index=True)[1][1:]) + sentence_ends[:, 1], + np.unique(sentence_ends[:, 0], return_index=True)[1][1:]) - substitutions = np.random.permutation(num_sentences[0])[ - :num_to_permute[0]] + substitutions = np.random.permutation( + num_sentences[0])[:num_to_permute[0]] ordering = np.arange(0, num_sentences[0]) ordering[substitutions] = substitutions[np.random.permutation( num_to_permute[0])] index = 0 for j in ordering: - sentence = inputs[i, (sentence_ends[0][j - 1] if j > 0 else - 0) : sentence_ends[0][j]] - results[i, index : index + sentence.shape[0]] = sentence + sentence = inputs[i, ( + sentence_ends[0][j - + 1] if j > 0 else 0):sentence_ends[0][j]] + results[i, index:index + sentence.shape[0]] = sentence index += sentence.shape[0] num_non_padding = np.sum(results != self.tokenizer.pad_token_id, @@ -152,23 +166,26 @@ def add_whole_word_mask(self, inputs, do_permutate): special_tokens_mask = [ self.tokenizer.get_special_tokens_mask( - val,already_has_special_tokens=True) for val in labels.tolist() + val, already_has_special_tokens=True) + for val in labels.tolist() ] special_tokens_mask = np.array(special_tokens_mask, dtype=bool) # determine how many tokens we need to mask in total is_token = ~(labels == 
self.tokenizer.pad_token_id) & \ ~special_tokens_mask - num_to_mask = int(math.ceil(is_token.astype(float).sum() * - self.mask_ratio)) + num_to_mask = int( + math.ceil(is_token.astype(float).sum() * self.mask_ratio)) if num_to_mask == 0: return inputs, labels # generate a sufficient number of span lengths - lengths = poisson(lam=self.poisson_lambda, size=(num_to_mask,)) + lengths = poisson(lam=self.poisson_lambda, size=(num_to_mask, )) while np.cumsum(lengths, 0)[-1] < num_to_mask: - lengths = np.concatenate([lengths, poisson( - lam=self.poisson_lambda, size=(num_to_mask,))]) + lengths = np.concatenate([ + lengths, + poisson(lam=self.poisson_lambda, size=(num_to_mask, )) + ]) # remove all spans of length 0 # Note that BART inserts additional mask tokens where length == 0, @@ -177,11 +194,11 @@ def add_whole_word_mask(self, inputs, do_permutate): # trim to about num_to_mask tokens idx = np.argmin(np.abs(np.cumsum(lengths, 0) - num_to_mask)) + 1 - lengths = lengths[: idx + 1] + lengths = lengths[:idx + 1] # select span start indices token_indices = np.argwhere(is_token == 1) - span_starts = permutation(token_indices.shape[0])[: lengths.shape[0]] + span_starts = permutation(token_indices.shape[0])[:lengths.shape[0]] # prepare mask masked_indices = np.array(token_indices[span_starts]) @@ -213,23 +230,29 @@ def add_whole_word_mask(self, inputs, do_permutate): # remove mask tokens that are not starts of spans to_remove = (mask == 1) & np.roll((mask == 1), 1, 1) - new_inputs = np.full_like( - labels, fill_value=self.tokenizer.pad_token_id) + new_inputs = np.full_like(labels, + fill_value=self.tokenizer.pad_token_id) # splits = list(map(lambda x: x.reshape(-1), np.split(inputs_copy, # indices_or_sections=2, axis=0)) - for i, example in enumerate(np.split( - inputs, indices_or_sections=new_inputs.shape[0], axis=0)): + for i, example in enumerate( + np.split(inputs, + indices_or_sections=new_inputs.shape[0], + axis=0)): new_example = example[0][~to_remove[i]] - new_inputs[i, 0 : new_example.shape[0]] = new_example + new_inputs[i, 0:new_example.shape[0]] = new_example # batching now fixed return new_inputs, labels class DataCollatorForPFedNLP(object): - def __init__(self, tokenizer, mlm_probability=0.15, mask_ratio=0.3, - poisson_lambda=3.0, permutate_sentence_ratio=1.0): + def __init__(self, + tokenizer, + mlm_probability=0.15, + mask_ratio=0.3, + poisson_lambda=3.0, + permutate_sentence_ratio=1.0): self.mlm_collator = DataCollatorForMLM(tokenizer, mlm_probability) self.denoise_collator = DataCollatorForDenoisingTasks( tokenizer, mask_ratio, poisson_lambda, permutate_sentence_ratio) diff --git a/federatedscope/nlp/dataloader/hfl_dataloader.py b/federatedscope/nlp/dataloader/hfl_dataloader.py index e9f90c0b2..d387e985f 100644 --- a/federatedscope/nlp/dataloader/hfl_dataloader.py +++ b/federatedscope/nlp/dataloader/hfl_dataloader.py @@ -81,8 +81,8 @@ def extend_cfg(cfg, cfg_client): def create_data(data, split, tokenizer, task, model_type, max_seq_len, - max_query_len, trunc_stride, max_tgt_len, cache_dir, - client_id, pretrain, debug): + max_query_len, trunc_stride, max_tgt_len, cache_dir, client_id, + pretrain, debug): if task == 'imdb': create_dataset_func = create_imdb_dataset elif task == 'agnews': @@ -134,9 +134,11 @@ def load_fednlp_data(config, client_config): logger.info(f'Preprocessing dataset {config.data.type}') data_processor = HFLDataProcessor(config) all_data = data_processor.get_data() - all_data_dict = {'train': all_data[0], - 'val': all_data[1], - 'test': all_data[2]} + all_data_dict 
= { + 'train': all_data[0], + 'val': all_data[1], + 'test': all_data[2] + } data_dict = dict() for client_id in tqdm(range(1, config.federate.client_num + 1)): @@ -145,54 +147,56 @@ def load_fednlp_data(config, client_config): cur_task = cfg_client.model.downstream_tasks[client_id - 1] \ if pretrain else cfg_client.model.task train_data, val_data, test_data = [ - create_data( - data=all_data_dict[split][client_id - 1], - split=split, - tokenizer=tokenizer, - task=cur_task, - model_type=model_type, - max_seq_len=getattr(cfg_client.data, 'max_seq_len', None), - max_query_len=getattr(cfg_client.data, 'max_query_len', None), - trunc_stride=getattr(cfg_client.data, 'trunc_stride', None), - max_tgt_len=getattr(cfg_client.data, 'max_tgt_len', None), - cache_dir=cache_dir, - client_id=client_id, - pretrain=pretrain, - debug=debug) - for split in ['train', 'val', 'test'] + create_data(data=all_data_dict[split][client_id - 1], + split=split, + tokenizer=tokenizer, + task=cur_task, + model_type=model_type, + max_seq_len=getattr(cfg_client.data, 'max_seq_len', + None), + max_query_len=getattr(cfg_client.data, 'max_query_len', + None), + trunc_stride=getattr(cfg_client.data, 'trunc_stride', + None), + max_tgt_len=getattr(cfg_client.data, 'max_tgt_len', + None), + cache_dir=cache_dir, + client_id=client_id, + pretrain=pretrain, + debug=debug) for split in ['train', 'val', 'test'] ] dataloader_dict = { 'train': { - 'dataloader': DataLoader( - dataset=train_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=config.data.shuffle, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=train_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=config.data.shuffle, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': train_data[1], - 'examples': train_data[2]}, + 'examples': train_data[2] + }, 'val': { - 'dataloader': DataLoader( - dataset=val_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=val_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': val_data[1], - 'examples': val_data[2]}, + 'examples': val_data[2] + }, 'test': { - 'dataloader': DataLoader( - dataset=test_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=test_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': test_data[1], - 'examples': test_data[2]}, + 'examples': test_data[2] + }, } data_dict[client_id] = dataloader_dict @@ -212,9 +216,11 @@ def load_pfednlp_data(config, client_config): logger.info(f'Preprocessing dataset {config.data.type}') data_processor = HFLDataProcessor(config) all_data = data_processor.get_data() - all_data_dict = {'train': all_data[0], - 'val': all_data[1], - 'test': all_data[2]} + all_data_dict = { + 'train': all_data[0], + 'val': all_data[1], + 'test': all_data[2] + } data_dict = dict() for client_id in tqdm(range(1, config.federate.client_num + 1)): @@ -223,54 +229,56 @@ def load_pfednlp_data(config, client_config): cur_task = 
cfg_client.model.downstream_tasks[client_id - 1] \ if pretrain else cfg_client.model.task train_data, val_data, test_data = [ - create_data( - data=all_data_dict[split][client_id - 1], - split=split, - tokenizer=tokenizer, - task=cur_task, - model_type=model_type, - max_seq_len=getattr(cfg_client.data, 'max_seq_len', None), - max_query_len=getattr(cfg_client.data, 'max_query_len', None), - trunc_stride=getattr(cfg_client.data, 'trunc_stride', None), - max_tgt_len=getattr(cfg_client.data, 'max_tgt_len', None), - cache_dir=cache_dir, - client_id=client_id, - pretrain=pretrain, - debug=debug) - for split in ['train', 'val', 'test'] + create_data(data=all_data_dict[split][client_id - 1], + split=split, + tokenizer=tokenizer, + task=cur_task, + model_type=model_type, + max_seq_len=getattr(cfg_client.data, 'max_seq_len', + None), + max_query_len=getattr(cfg_client.data, 'max_query_len', + None), + trunc_stride=getattr(cfg_client.data, 'trunc_stride', + None), + max_tgt_len=getattr(cfg_client.data, 'max_tgt_len', + None), + cache_dir=cache_dir, + client_id=client_id, + pretrain=pretrain, + debug=debug) for split in ['train', 'val', 'test'] ] dataloader_dict = { 'train': { - 'dataloader': DataLoader( - dataset=train_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=config.data.shuffle, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=train_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=config.data.shuffle, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': train_data[1], - 'examples': train_data[2]}, + 'examples': train_data[2] + }, 'val': { - 'dataloader': DataLoader( - dataset=val_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=val_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': val_data[1], - 'examples': val_data[2]}, + 'examples': val_data[2] + }, 'test': { - 'dataloader': DataLoader( - dataset=test_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=test_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': test_data[1], - 'examples': test_data[2]}, + 'examples': test_data[2] + }, } data_dict[client_id] = dataloader_dict @@ -290,9 +298,11 @@ def load_pfednlp_contrast_data(config, client_config): logger.info(f'Preprocessing dataset {config.data.type}') data_processor = HFLDataProcessor(config) all_data = data_processor.get_data() - all_data_dict = {'train': all_data[0], - 'val': all_data[1], - 'test': all_data[2]} + all_data_dict = { + 'train': all_data[0], + 'val': all_data[1], + 'test': all_data[2] + } data_dict = dict() for client_id in tqdm(range(1, config.federate.client_num + 1)): @@ -301,64 +311,66 @@ def load_pfednlp_contrast_data(config, client_config): cur_task = cfg_client.model.downstream_tasks[client_id - 1] \ if pretrain else cfg_client.model.task train_data, val_data, test_data = [ - create_data( - data=all_data_dict[split][client_id - 1], - split=split, - tokenizer=tokenizer, - 
task=cur_task, - model_type=model_type, - max_seq_len=getattr(cfg_client.data, 'max_seq_len', None), - max_query_len=getattr(cfg_client.data, 'max_query_len', None), - trunc_stride=getattr(cfg_client.data, 'trunc_stride', None), - max_tgt_len=getattr(cfg_client.data, 'max_tgt_len', None), - cache_dir=cache_dir, - client_id=client_id, - pretrain=pretrain, - debug=debug) - for split in ['train', 'val', 'test'] + create_data(data=all_data_dict[split][client_id - 1], + split=split, + tokenizer=tokenizer, + task=cur_task, + model_type=model_type, + max_seq_len=getattr(cfg_client.data, 'max_seq_len', + None), + max_query_len=getattr(cfg_client.data, 'max_query_len', + None), + trunc_stride=getattr(cfg_client.data, 'trunc_stride', + None), + max_tgt_len=getattr(cfg_client.data, 'max_tgt_len', + None), + cache_dir=cache_dir, + client_id=client_id, + pretrain=pretrain, + debug=debug) for split in ['train', 'val', 'test'] ] dataloader_dict = { 'train_raw': { - 'dataloader': DataLoader( - dataset=train_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=config.data.shuffle, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=train_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=config.data.shuffle, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': train_data[1], - 'examples': train_data[2]}, + 'examples': train_data[2] + }, 'train_contrast': { - 'dataloader': DataLoader( - dataset=train_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=train_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': train_data[1], - 'examples': train_data[2]}, + 'examples': train_data[2] + }, 'val': { - 'dataloader': DataLoader( - dataset=val_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=val_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': val_data[1], - 'examples': val_data[2]}, + 'examples': val_data[2] + }, 'test': { - 'dataloader': DataLoader( - dataset=test_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=test_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': test_data[1], - 'examples': test_data[2]}, + 'examples': test_data[2] + }, } data_dict[client_id] = dataloader_dict diff --git a/federatedscope/nlp/dataset/agnews.py b/federatedscope/nlp/dataset/agnews.py index c5ed9dcf4..9765d3492 100644 --- a/federatedscope/nlp/dataset/agnews.py +++ b/federatedscope/nlp/dataset/agnews.py @@ -16,12 +16,19 @@ def create_agnews_examples(data, debug=False): return examples -def create_agnews_dataset(data, split, tokenizer, max_seq_len, cache_dir='', - client_id=None, pretrain=False, debug=False, +def create_agnews_dataset(data, + split, + tokenizer, + max_seq_len, + 
cache_dir='', + client_id=None, + pretrain=False, + debug=False, **kwargs): if pretrain: - return create_agnews_pretrain_dataset( - data, split, tokenizer, max_seq_len, cache_dir, client_id, debug) + return create_agnews_pretrain_dataset(data, split, tokenizer, + max_seq_len, cache_dir, + client_id, debug) save_dir = osp.join(cache_dir, 'finetune', str(client_id)) cache_file = osp.join(save_dir, split + '.pt') @@ -42,22 +49,31 @@ def create_agnews_dataset(data, split, tokenizer, max_seq_len, cache_dir='', if cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_file)) os.makedirs(save_dir, exist_ok=True) - torch.save({'examples': examples, - 'encoded_inputs': encoded_inputs}, cache_file) + torch.save({ + 'examples': examples, + 'encoded_inputs': encoded_inputs + }, cache_file) labels = [ex[1] for ex in examples] example_indices = torch.arange(encoded_inputs.input_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': encoded_inputs.input_ids, - 'token_type_ids': encoded_inputs.token_type_ids, - 'attention_mask': encoded_inputs.attention_mask, - 'labels': torch.LongTensor(labels), - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': encoded_inputs.input_ids, + 'token_type_ids': encoded_inputs.token_type_ids, + 'attention_mask': encoded_inputs.attention_mask, + 'labels': torch.LongTensor(labels), + 'example_indices': example_indices + }) return dataset, encoded_inputs, examples -def create_agnews_pretrain_dataset(data, split, tokenizer, max_seq_len, - cache_dir='', client_id=None, debug=False): +def create_agnews_pretrain_dataset(data, + split, + tokenizer, + max_seq_len, + cache_dir='', + client_id=None, + debug=False): save_dir = osp.join(cache_dir, 'pretrain', str(client_id)) cache_file = osp.join(save_dir, split + '.pt') @@ -84,12 +100,16 @@ def create_agnews_pretrain_dataset(data, split, tokenizer, max_seq_len, if cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_file)) os.makedirs(save_dir, exist_ok=True) - torch.save({'examples': examples, - 'encoded_inputs': encoded_inputs}, cache_file) + torch.save({ + 'examples': examples, + 'encoded_inputs': encoded_inputs + }, cache_file) example_indices = torch.arange(encoded_inputs.input_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': encoded_inputs.input_ids, - 'attention_mask': encoded_inputs.attention_mask, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': encoded_inputs.input_ids, + 'attention_mask': encoded_inputs.attention_mask, + 'example_indices': example_indices + }) return dataset, encoded_inputs, examples diff --git a/federatedscope/nlp/dataset/cnndm.py b/federatedscope/nlp/dataset/cnndm.py index f03a9871d..36e52f6d9 100644 --- a/federatedscope/nlp/dataset/cnndm.py +++ b/federatedscope/nlp/dataset/cnndm.py @@ -18,38 +18,43 @@ def create_cnndm_examples(data, debug=False): return src_examples, tgt_examples -def create_cnndm_dataset(data, split, tokenizer, max_src_len, max_tgt_len, - raw_cache_dir='', client_id=None, pretrain=False, - debug=False, **kwargs): +def create_cnndm_dataset(data, + split, + tokenizer, + max_src_len, + max_tgt_len, + raw_cache_dir='', + client_id=None, + pretrain=False, + debug=False, + **kwargs): if pretrain: - return create_cnndm_pretrain_dataset( - data, split, tokenizer, max_src_len, raw_cache_dir, client_id, - debug) + return create_cnndm_pretrain_dataset(data, split, tokenizer, + max_src_len, raw_cache_dir, + client_id, debug) cache_dir = osp.join(raw_cache_dir, 'finetune', 
str(client_id), split) src_examples, tgt_examples = create_cnndm_examples(data, debug) if osp.exists(cache_dir): logger.info('Loading cache file from \'{}\''.format(cache_dir)) - token_ids = np.memmap( - filename=osp.join(cache_dir, 'token_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) - token_type_ids = np.memmap( - filename=osp.join(cache_dir, 'token_type_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) - attention_mask = np.memmap( - filename=osp.join(cache_dir, 'attention_mask.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) - labels = np.memmap( - filename=osp.join(cache_dir, 'labels.memmap'), - shape=(len(src_examples), max_tgt_len), - mode='r', - dtype=np.int64) + token_ids = np.memmap(filename=osp.join(cache_dir, 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + token_type_ids = np.memmap(filename=osp.join(cache_dir, + 'token_type_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + attention_mask = np.memmap(filename=osp.join(cache_dir, + 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + labels = np.memmap(filename=osp.join(cache_dir, 'labels.memmap'), + shape=(len(src_examples), max_tgt_len), + mode='r', + dtype=np.int64) token_ids = torch.from_numpy(token_ids) token_type_ids = torch.from_numpy(token_type_ids) @@ -76,26 +81,25 @@ def create_cnndm_dataset(data, split, tokenizer, max_src_len, max_tgt_len, if raw_cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_dir)) os.makedirs(cache_dir, exist_ok=True) - token_ids = np.memmap( - filename=osp.join(cache_dir, 'token_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) - token_type_ids = np.memmap( - filename=osp.join(cache_dir, 'token_type_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) - attention_mask = np.memmap( - filename=osp.join(cache_dir, 'attention_mask.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) - labels = np.memmap( - filename=osp.join(cache_dir, 'labels.memmap'), - shape=(len(src_examples), max_tgt_len), - mode='w+', - dtype=np.int64) + token_ids = np.memmap(filename=osp.join(cache_dir, + 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + token_type_ids = np.memmap(filename=osp.join( + cache_dir, 'token_type_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + attention_mask = np.memmap(filename=osp.join( + cache_dir, 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + labels = np.memmap(filename=osp.join(cache_dir, 'labels.memmap'), + shape=(len(src_examples), max_tgt_len), + mode='w+', + dtype=np.int64) for i in range(len(src_examples)): token_ids[i] = src_encoded.input_ids[i] @@ -114,31 +118,36 @@ def create_cnndm_dataset(data, split, tokenizer, max_src_len, max_tgt_len, labels = tgt_encoded.input_ids example_indices = torch.arange(token_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': token_ids, - 'token_type_ids': token_type_ids, - 'attention_mask': attention_mask, - 'labels': labels, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': token_ids, + 'token_type_ids': token_type_ids, + 'attention_mask': attention_mask, + 'labels': labels, + 'example_indices': example_indices + }) 
return dataset, None, None -def create_cnndm_pretrain_dataset(data, split, tokenizer, max_src_len, - raw_cache_dir='', client_id=None, +def create_cnndm_pretrain_dataset(data, + split, + tokenizer, + max_src_len, + raw_cache_dir='', + client_id=None, debug=False): cache_dir = osp.join(raw_cache_dir, 'pretrain', str(client_id), split) src_examples, tgt_examples = create_cnndm_examples(data, debug) if osp.exists(cache_dir): logger.info('Loading cache file from \'{}\''.format(cache_dir)) - token_ids = np.memmap( - filename=osp.join(cache_dir, 'token_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) - attention_mask = np.memmap( - filename=osp.join(cache_dir, 'attention_mask.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) + token_ids = np.memmap(filename=osp.join(cache_dir, 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + attention_mask = np.memmap(filename=osp.join(cache_dir, + 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) token_ids = torch.from_numpy(token_ids) attention_mask = torch.from_numpy(attention_mask) @@ -158,16 +167,16 @@ def create_cnndm_pretrain_dataset(data, split, tokenizer, max_src_len, if raw_cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_dir)) os.makedirs(cache_dir, exist_ok=True) - token_ids = np.memmap( - filename=osp.join(cache_dir, 'token_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) - attention_mask = np.memmap( - filename=osp.join(cache_dir, 'attention_mask.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) + token_ids = np.memmap(filename=osp.join(cache_dir, + 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + attention_mask = np.memmap(filename=osp.join( + cache_dir, 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) for i in range(len(src_examples)): token_ids[i] = src_encoded.input_ids[i] @@ -180,7 +189,9 @@ def create_cnndm_pretrain_dataset(data, split, tokenizer, max_src_len, attention_mask = src_encoded.attention_mask example_indices = torch.arange(token_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': token_ids, - 'attention_mask': attention_mask, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': token_ids, + 'attention_mask': attention_mask, + 'example_indices': example_indices + }) return dataset, None, None diff --git a/federatedscope/nlp/dataset/imdb.py b/federatedscope/nlp/dataset/imdb.py index b514037f8..63db181ff 100644 --- a/federatedscope/nlp/dataset/imdb.py +++ b/federatedscope/nlp/dataset/imdb.py @@ -16,11 +16,19 @@ def create_imdb_examples(data, debug=False): return examples -def create_imdb_dataset(data, split, tokenizer, max_seq_len, cache_dir='', - client_id=None, pretrain=False, debug=False, **kwargs): +def create_imdb_dataset(data, + split, + tokenizer, + max_seq_len, + cache_dir='', + client_id=None, + pretrain=False, + debug=False, + **kwargs): if pretrain: - return create_imdb_pretrain_dataset( - data, split, tokenizer, max_seq_len, cache_dir, client_id, debug) + return create_imdb_pretrain_dataset(data, split, tokenizer, + max_seq_len, cache_dir, client_id, + debug) save_dir = osp.join(cache_dir, 'finetune', str(client_id)) cache_file = osp.join(save_dir, split + '.pt') @@ -41,22 +49,31 @@ def create_imdb_dataset(data, split, tokenizer, 
max_seq_len, cache_dir='', if cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_file)) os.makedirs(save_dir, exist_ok=True) - torch.save({'examples': examples, - 'encoded_inputs': encoded_inputs}, cache_file) + torch.save({ + 'examples': examples, + 'encoded_inputs': encoded_inputs + }, cache_file) labels = [ex[1] for ex in examples] example_indices = torch.arange(encoded_inputs.input_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': encoded_inputs.input_ids, - 'token_type_ids': encoded_inputs.token_type_ids, - 'attention_mask': encoded_inputs.attention_mask, - 'labels': torch.LongTensor(labels), - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': encoded_inputs.input_ids, + 'token_type_ids': encoded_inputs.token_type_ids, + 'attention_mask': encoded_inputs.attention_mask, + 'labels': torch.LongTensor(labels), + 'example_indices': example_indices + }) return dataset, encoded_inputs, examples -def create_imdb_pretrain_dataset(data, split, tokenizer, max_seq_len, - cache_dir='', client_id=None, debug=False): +def create_imdb_pretrain_dataset(data, + split, + tokenizer, + max_seq_len, + cache_dir='', + client_id=None, + debug=False): save_dir = osp.join(cache_dir, 'pretrain', str(client_id)) cache_file = osp.join(save_dir, split + '.pt') if osp.exists(cache_file): @@ -82,12 +99,16 @@ def create_imdb_pretrain_dataset(data, split, tokenizer, max_seq_len, if cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_file)) os.makedirs(save_dir, exist_ok=True) - torch.save({'examples': examples, - 'encoded_inputs': encoded_inputs}, cache_file) + torch.save({ + 'examples': examples, + 'encoded_inputs': encoded_inputs + }, cache_file) example_indices = torch.arange(encoded_inputs.input_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': encoded_inputs.input_ids, - 'attention_mask': encoded_inputs.attention_mask, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': encoded_inputs.input_ids, + 'attention_mask': encoded_inputs.attention_mask, + 'example_indices': example_indices + }) return dataset, encoded_inputs, examples diff --git a/federatedscope/nlp/dataset/msqg.py b/federatedscope/nlp/dataset/msqg.py index dab24eac7..6040d047c 100644 --- a/federatedscope/nlp/dataset/msqg.py +++ b/federatedscope/nlp/dataset/msqg.py @@ -18,38 +18,43 @@ def create_msqg_examples(data, debug=False): return src_examples, tgt_examples -def create_msqg_dataset(data, split, tokenizer, max_src_len, max_tgt_len, - raw_cache_dir='', client_id=None, pretrain=False, - debug=False, **kwargs): +def create_msqg_dataset(data, + split, + tokenizer, + max_src_len, + max_tgt_len, + raw_cache_dir='', + client_id=None, + pretrain=False, + debug=False, + **kwargs): if pretrain: - return create_msqg_pretrain_dataset( - data, split, tokenizer, max_src_len, raw_cache_dir, client_id, - debug) + return create_msqg_pretrain_dataset(data, split, tokenizer, + max_src_len, raw_cache_dir, + client_id, debug) cache_dir = osp.join(raw_cache_dir, 'finetune', str(client_id), split) src_examples, tgt_examples = create_msqg_examples(data, debug) if osp.exists(cache_dir): logger.info('Loading cache file from \'{}\''.format(cache_dir)) - token_ids = np.memmap( - filename=osp.join(cache_dir, 'token_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) - token_type_ids = np.memmap( - filename=osp.join(cache_dir, 'token_type_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) - 
attention_mask = np.memmap( - filename=osp.join(cache_dir, 'attention_mask.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) - labels = np.memmap( - filename=osp.join(cache_dir, 'labels.memmap'), - shape=(len(src_examples), max_tgt_len), - mode='r', - dtype=np.int64) + token_ids = np.memmap(filename=osp.join(cache_dir, 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + token_type_ids = np.memmap(filename=osp.join(cache_dir, + 'token_type_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + attention_mask = np.memmap(filename=osp.join(cache_dir, + 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + labels = np.memmap(filename=osp.join(cache_dir, 'labels.memmap'), + shape=(len(src_examples), max_tgt_len), + mode='r', + dtype=np.int64) token_ids = torch.from_numpy(token_ids) token_type_ids = torch.from_numpy(token_type_ids) @@ -78,26 +83,25 @@ def create_msqg_dataset(data, split, tokenizer, max_src_len, max_tgt_len, if raw_cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_dir)) os.makedirs(cache_dir, exist_ok=True) - token_ids = np.memmap( - filename=osp.join(cache_dir, 'token_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) - token_type_ids = np.memmap( - filename=osp.join(cache_dir, 'token_type_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) - attention_mask = np.memmap( - filename=osp.join(cache_dir, 'attention_mask.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) - labels = np.memmap( - filename=osp.join(cache_dir, 'labels.memmap'), - shape=(len(src_examples), max_tgt_len), - mode='w+', - dtype=np.int64) + token_ids = np.memmap(filename=osp.join(cache_dir, + 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + token_type_ids = np.memmap(filename=osp.join( + cache_dir, 'token_type_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + attention_mask = np.memmap(filename=osp.join( + cache_dir, 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + labels = np.memmap(filename=osp.join(cache_dir, 'labels.memmap'), + shape=(len(src_examples), max_tgt_len), + mode='w+', + dtype=np.int64) for i in range(len(src_examples)): token_ids[i] = src_encoded.input_ids[i] @@ -117,30 +121,36 @@ def create_msqg_dataset(data, split, tokenizer, max_src_len, max_tgt_len, labels = tgt_encoded.input_ids example_indices = torch.arange(token_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': token_ids, - 'token_type_ids': token_type_ids, - 'attention_mask': attention_mask, - 'labels': labels, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': token_ids, + 'token_type_ids': token_type_ids, + 'attention_mask': attention_mask, + 'labels': labels, + 'example_indices': example_indices + }) return dataset, None, None -def create_msqg_pretrain_dataset(data, split, tokenizer, max_src_len, - raw_cache_dir='', client_id=None, debug=False): +def create_msqg_pretrain_dataset(data, + split, + tokenizer, + max_src_len, + raw_cache_dir='', + client_id=None, + debug=False): cache_dir = osp.join(raw_cache_dir, 'pretrain', str(client_id), split) src_examples, tgt_examples = create_msqg_examples(data, debug) if osp.exists(cache_dir): logger.info('Loading cache file from 
\'{}\''.format(cache_dir)) - token_ids = np.memmap( - filename=osp.join(cache_dir, 'token_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) - attention_mask = np.memmap( - filename=osp.join(cache_dir, 'attention_mask.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) + token_ids = np.memmap(filename=osp.join(cache_dir, 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + attention_mask = np.memmap(filename=osp.join(cache_dir, + 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) token_ids = torch.from_numpy(token_ids) attention_mask = torch.from_numpy(attention_mask) else: @@ -161,16 +171,16 @@ def create_msqg_pretrain_dataset(data, split, tokenizer, max_src_len, if raw_cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_dir)) os.makedirs(cache_dir, exist_ok=True) - token_ids = np.memmap( - filename=osp.join(cache_dir, 'token_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) - attention_mask = np.memmap( - filename=osp.join(cache_dir, 'attention_mask.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) + token_ids = np.memmap(filename=osp.join(cache_dir, + 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + attention_mask = np.memmap(filename=osp.join( + cache_dir, 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) for i in range(len(src_examples)): token_ids[i] = src_encoded.input_ids[i] @@ -183,7 +193,9 @@ def create_msqg_pretrain_dataset(data, split, tokenizer, max_src_len, attention_mask = src_encoded.attention_mask example_indices = torch.arange(token_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': token_ids, - 'attention_mask': attention_mask, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': token_ids, + 'attention_mask': attention_mask, + 'example_indices': example_indices + }) return dataset, None, None diff --git a/federatedscope/nlp/dataset/newsqa.py b/federatedscope/nlp/dataset/newsqa.py index 04eb44110..08aa36903 100644 --- a/federatedscope/nlp/dataset/newsqa.py +++ b/federatedscope/nlp/dataset/newsqa.py @@ -53,8 +53,8 @@ def get_char_to_word_positions(context, answer, start_char_pos, is_impossible): char_to_word_offset = [] is_prev_whitespace = True for c in context: - is_whitespace = (c == ' ' or c == '\t' or c == '\r' or c == '\n' or - ord(c) == 0x202F) + is_whitespace = (c == ' ' or c == '\t' or c == '\r' or c == '\n' + or ord(c) == 0x202F) if is_whitespace: is_prev_whitespace = True else: @@ -153,30 +153,40 @@ def create_newsqa_examples(data, split, debug=False): is_impossible = qa['is_impossible'] if 'is_impossible' in qa else False if not is_impossible: answers = qa['detected_answers'] - spans = sorted([span for spans in answers - for span in spans['char_spans']]) + spans = sorted( + [span for spans in answers for span in spans['char_spans']]) if split == 'train': - train_answer = context[spans[0][0]: spans[0][1] + 1] + train_answer = context[spans[0][0]:spans[0][1] + 1] start_char_pos = spans[0][0] else: - val_answer = [{'text': context[spans[i][0]: spans[i][1] + 1], - 'answer_start': spans[i][0]} - for i in range(len(spans))] + val_answer = [{ + 'text': context[spans[i][0]:spans[i][1] + 1], + 'answer_start': spans[i][0] + } for i in range(len(spans))] start_pos, end_pos, context_tokens = 
get_char_to_word_positions( context, train_answer, start_char_pos, is_impossible) - examples.append(NewsQAExample(qa_id, question, context, train_answer, - val_answer, start_pos, end_pos, - context_tokens, is_impossible)) + examples.append( + NewsQAExample(qa_id, question, context, train_answer, val_answer, + start_pos, end_pos, context_tokens, is_impossible)) return examples -def create_newsqa_dataset(data, split, tokenizer, max_seq_len, max_query_len, - trunc_stride, cache_dir='', client_id=None, - pretrain=False, debug=False, **kwargs): +def create_newsqa_dataset(data, + split, + tokenizer, + max_seq_len, + max_query_len, + trunc_stride, + cache_dir='', + client_id=None, + pretrain=False, + debug=False, + **kwargs): if pretrain: - return create_newsqa_pretrain_dataset( - data, split, tokenizer, max_seq_len, cache_dir, client_id, debug) + return create_newsqa_pretrain_dataset(data, split, tokenizer, + max_seq_len, cache_dir, + client_id, debug) save_dir = osp.join(cache_dir, 'finetune', str(client_id)) cache_file = osp.join(save_dir, split + '.pt') @@ -193,8 +203,8 @@ def create_newsqa_dataset(data, split, tokenizer, max_seq_len, max_query_len, if split == 'train' and not example.is_impossible: start_pos = example.start_position end_pos = example.end_position - actual_answer = ' '.join(example.context_tokens[ - start_pos:(end_pos + 1)]) + actual_answer = ' '.join( + example.context_tokens[start_pos:(end_pos + 1)]) cleaned_answer = ' '.join(example.train_answer.strip().split()) if actual_answer.find(cleaned_answer) == -1: logger.info('Could not find answer: {} vs. {}'.format( @@ -214,8 +224,8 @@ def create_newsqa_dataset(data, split, tokenizer, max_seq_len, max_query_len, if split == 'train' and not example.is_impossible: subtoken_start_pos = tok_to_subtok_idx[example.start_position] if example.end_position < len(example.context_tokens) - 1: - subtoken_end_pos = tok_to_subtok_idx[ - example.end_position + 1] - 1 + subtoken_end_pos = tok_to_subtok_idx[example.end_position + + 1] - 1 else: subtoken_end_pos = len(context_subtokens) - 1 subtoken_start_pos, subtoken_end_pos = \ @@ -236,14 +246,16 @@ def create_newsqa_dataset(data, split, tokenizer, max_seq_len, max_query_len, encoded_input = encode(tokenizer, text_a, text_b, max_seq_len, max_query_len, added_trunc_size) context_start_pos = len(spans) * trunc_stride - context_len = min(len(context_subtokens) - context_start_pos, - max_seq_len - len_question - 3) + context_len = min( + len(context_subtokens) - context_start_pos, + max_seq_len - len_question - 3) context_end_pos = context_start_pos + context_len - 1 if tokenizer.pad_token_id in encoded_input.token_ids: - non_padded_ids = encoded_input.token_ids[ - :encoded_input.token_ids.index( - tokenizer.pad_token_id)] + non_padded_ids = encoded_input.token_ids[:encoded_input. + token_ids.index( + tokenizer. 
+ pad_token_id)] else: non_padded_ids = encoded_input.token_ids tokens = tokenizer.convert_ids_to_tokens(non_padded_ids) @@ -306,31 +318,40 @@ def create_newsqa_dataset(data, split, tokenizer, max_seq_len, max_query_len, if cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_file)) os.makedirs(save_dir, exist_ok=True) - torch.save({'examples': examples, - 'encoded_inputs': encoded_inputs}, cache_file) + torch.save({ + 'examples': examples, + 'encoded_inputs': encoded_inputs + }, cache_file) token_ids = torch.LongTensor([inp.token_ids for inp in encoded_inputs]) - token_type_ids = torch.LongTensor([inp.token_type_ids - for inp in encoded_inputs]) - attention_mask = torch.LongTensor([inp.attention_mask - for inp in encoded_inputs]) - start_positions = torch.LongTensor([inp.start_position - for inp in encoded_inputs]) - end_positions = torch.LongTensor([inp.end_position for - inp in encoded_inputs]) + token_type_ids = torch.LongTensor( + [inp.token_type_ids for inp in encoded_inputs]) + attention_mask = torch.LongTensor( + [inp.attention_mask for inp in encoded_inputs]) + start_positions = torch.LongTensor( + [inp.start_position for inp in encoded_inputs]) + end_positions = torch.LongTensor( + [inp.end_position for inp in encoded_inputs]) example_indices = torch.arange(token_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': token_ids, - 'token_type_ids': token_type_ids, - 'attention_mask': attention_mask, - 'start_positions': start_positions, - 'end_positions': end_positions, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': token_ids, + 'token_type_ids': token_type_ids, + 'attention_mask': attention_mask, + 'start_positions': start_positions, + 'end_positions': end_positions, + 'example_indices': example_indices + }) return dataset, encoded_inputs, examples -def create_newsqa_pretrain_dataset(data, split, tokenizer, max_seq_len, - cache_dir='', client_id=None, debug=False): +def create_newsqa_pretrain_dataset(data, + split, + tokenizer, + max_seq_len, + cache_dir='', + client_id=None, + debug=False): save_dir = osp.join(cache_dir, 'pretrain', str(client_id)) cache_file = osp.join(save_dir, split + '.pt') if osp.exists(cache_file): @@ -356,12 +377,16 @@ def create_newsqa_pretrain_dataset(data, split, tokenizer, max_seq_len, if cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_file)) os.makedirs(save_dir, exist_ok=True) - torch.save({'examples': examples, - 'encoded_inputs': encoded_inputs}, cache_file) + torch.save({ + 'examples': examples, + 'encoded_inputs': encoded_inputs + }, cache_file) example_indices = torch.arange(encoded_inputs.input_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': encoded_inputs.input_ids, - 'attention_mask': encoded_inputs.attention_mask, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': encoded_inputs.input_ids, + 'attention_mask': encoded_inputs.attention_mask, + 'example_indices': example_indices + }) return dataset, encoded_inputs, examples diff --git a/federatedscope/nlp/dataset/preprocess/get_hfl_data.py b/federatedscope/nlp/dataset/preprocess/get_hfl_data.py index 9757011ca..71e5f1b8f 100644 --- a/federatedscope/nlp/dataset/preprocess/get_hfl_data.py +++ b/federatedscope/nlp/dataset/preprocess/get_hfl_data.py @@ -27,14 +27,18 @@ def get_data(self): for i, dataset in enumerate(self.datasets): if dataset not in HFL_NAMES: raise ValueError(f'No HFL dataset named {dataset}') - train_val_data = self._load_data( - dataset, 'train', 
self.num_grouped_clients[i]) - train_data = [data[:int(self.train_frac * len(data))] - for data in train_val_data] - val_data = [data[int(self.train_frac * len(data)):] - for data in train_val_data] - test_data = self._load_data( - dataset, 'test', self.num_grouped_clients[i]) + train_val_data = self._load_data(dataset, 'train', + self.num_grouped_clients[i]) + train_data = [ + data[:int(self.train_frac * len(data))] + for data in train_val_data + ] + val_data = [ + data[int(self.train_frac * len(data)):] + for data in train_val_data + ] + test_data = self._load_data(dataset, 'test', + self.num_grouped_clients[i]) all_train_data.extend(train_data) all_val_data.extend(val_data) all_test_data.extend(test_data) @@ -78,7 +82,8 @@ def _load_data(self, dataset, split, num_clients): data.append({'text': text, 'label': label}) elif dataset == 'squad': - with open(os.path.join(data_dir, split + '.json'), 'r', + with open(os.path.join(data_dir, split + '.json'), + 'r', encoding='utf-8') as reader: raw_data = json.load(reader)['data'] for line in raw_data: @@ -101,11 +106,13 @@ def _load_data(self, dataset, split, num_clients): src_file = os.path.join(data_dir, split + '.src') tgt_file = os.path.join(data_dir, split + '.tgt') with open(src_file) as f: - src_data = [line.strip().replace('', '[SEP]') - for line in f] + src_data = [ + line.strip().replace('', '[SEP]') for line in f + ] with open(tgt_file) as f: - tgt_data = [line.strip().replace('', '[SEP]') - for line in f] + tgt_data = [ + line.strip().replace('', '[SEP]') for line in f + ] for src, tgt in zip(src_data, tgt_data): data.append({'src': src, 'tgt': tgt}) @@ -117,7 +124,7 @@ def _load_data(self, dataset, split, num_clients): for i in range(num_clients): num_split = n if i < num_clients - 1 else \ len(data) - n * (num_clients - 1) - cur_data = data[data_idx: data_idx + num_split] + cur_data = data[data_idx:data_idx + num_split] data_idx += num_split all_split_data.append(cur_data) logger.info(f'Client id: {i + 1}, num samples: {num_split}') diff --git a/federatedscope/nlp/dataset/squad.py b/federatedscope/nlp/dataset/squad.py index 3e30d7b71..56b22ca40 100644 --- a/federatedscope/nlp/dataset/squad.py +++ b/federatedscope/nlp/dataset/squad.py @@ -53,8 +53,8 @@ def get_char_to_word_positions(context, answer, start_char_pos, is_impossible): char_to_word_offset = [] is_prev_whitespace = True for c in context: - is_whitespace = (c == ' ' or c == '\t' or c == '\r' or c == '\n' or - ord(c) == 0x202F) + is_whitespace = (c == ' ' or c == '\t' or c == '\r' or c == '\n' + or ord(c) == 0x202F) if is_whitespace: is_prev_whitespace = True else: @@ -160,18 +160,27 @@ def create_squad_examples(data, split, debug=False): start_pos, end_pos, context_tokens = get_char_to_word_positions( context, train_answer, start_char_pos, is_impossible) - examples.append(SquadExample(qa_id, question, context, train_answer, - val_answer, start_pos, end_pos, - context_tokens, is_impossible)) + examples.append( + SquadExample(qa_id, question, context, train_answer, val_answer, + start_pos, end_pos, context_tokens, is_impossible)) return examples -def create_squad_dataset(data, split, tokenizer, max_seq_len, max_query_len, - trunc_stride, cache_dir='', client_id=None, - pretrain=False, debug=False, **kwargs): +def create_squad_dataset(data, + split, + tokenizer, + max_seq_len, + max_query_len, + trunc_stride, + cache_dir='', + client_id=None, + pretrain=False, + debug=False, + **kwargs): if pretrain: - return create_squad_pretrain_dataset( - data, split, tokenizer, 
max_seq_len, cache_dir, client_id, debug) + return create_squad_pretrain_dataset(data, split, tokenizer, + max_seq_len, cache_dir, client_id, + debug) save_dir = osp.join(cache_dir, 'finetune', str(client_id)) cache_file = osp.join(save_dir, split + '.pt') @@ -188,8 +197,8 @@ def create_squad_dataset(data, split, tokenizer, max_seq_len, max_query_len, if split == 'train' and not example.is_impossible: start_pos = example.start_position end_pos = example.end_position - actual_answer = ' '.join(example.context_tokens[ - start_pos:(end_pos + 1)]) + actual_answer = ' '.join( + example.context_tokens[start_pos:(end_pos + 1)]) cleaned_answer = ' '.join(example.train_answer.strip().split()) if actual_answer.find(cleaned_answer) == -1: logger.info('Could not find answer: {} vs. {}'.format( @@ -209,8 +218,8 @@ def create_squad_dataset(data, split, tokenizer, max_seq_len, max_query_len, if split == 'train' and not example.is_impossible: subtoken_start_pos = tok_to_subtok_idx[example.start_position] if example.end_position < len(example.context_tokens) - 1: - subtoken_end_pos = tok_to_subtok_idx[ - example.end_position + 1] - 1 + subtoken_end_pos = tok_to_subtok_idx[example.end_position + + 1] - 1 else: subtoken_end_pos = len(context_subtokens) - 1 subtoken_start_pos, subtoken_end_pos = \ @@ -231,14 +240,16 @@ def create_squad_dataset(data, split, tokenizer, max_seq_len, max_query_len, encoded_input = encode(tokenizer, text_a, text_b, max_seq_len, max_query_len, added_trunc_size) context_start_pos = len(spans) * trunc_stride - context_len = min(len(context_subtokens) - context_start_pos, - max_seq_len - len_question - 3) + context_len = min( + len(context_subtokens) - context_start_pos, + max_seq_len - len_question - 3) context_end_pos = context_start_pos + context_len - 1 if tokenizer.pad_token_id in encoded_input.token_ids: - non_padded_ids = encoded_input.token_ids[ - :encoded_input.token_ids.index( - tokenizer.pad_token_id)] + non_padded_ids = encoded_input.token_ids[:encoded_input. + token_ids.index( + tokenizer. 
+ pad_token_id)] else: non_padded_ids = encoded_input.token_ids tokens = tokenizer.convert_ids_to_tokens(non_padded_ids) @@ -301,31 +312,40 @@ def create_squad_dataset(data, split, tokenizer, max_seq_len, max_query_len, if cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_file)) os.makedirs(save_dir, exist_ok=True) - torch.save({'examples': examples, - 'encoded_inputs': encoded_inputs}, cache_file) + torch.save({ + 'examples': examples, + 'encoded_inputs': encoded_inputs + }, cache_file) token_ids = torch.LongTensor([inp.token_ids for inp in encoded_inputs]) - token_type_ids = torch.LongTensor([inp.token_type_ids - for inp in encoded_inputs]) - attention_mask = torch.LongTensor([inp.attention_mask - for inp in encoded_inputs]) - start_positions = torch.LongTensor([inp.start_position - for inp in encoded_inputs]) - end_positions = torch.LongTensor([inp.end_position - for inp in encoded_inputs]) + token_type_ids = torch.LongTensor( + [inp.token_type_ids for inp in encoded_inputs]) + attention_mask = torch.LongTensor( + [inp.attention_mask for inp in encoded_inputs]) + start_positions = torch.LongTensor( + [inp.start_position for inp in encoded_inputs]) + end_positions = torch.LongTensor( + [inp.end_position for inp in encoded_inputs]) example_indices = torch.arange(token_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': token_ids, - 'token_type_ids': token_type_ids, - 'attention_mask': attention_mask, - 'start_positions': start_positions, - 'end_positions': end_positions, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': token_ids, + 'token_type_ids': token_type_ids, + 'attention_mask': attention_mask, + 'start_positions': start_positions, + 'end_positions': end_positions, + 'example_indices': example_indices + }) return dataset, encoded_inputs, examples -def create_squad_pretrain_dataset(data, split, tokenizer, max_seq_len, - cache_dir='', client_id=None, debug=False): +def create_squad_pretrain_dataset(data, + split, + tokenizer, + max_seq_len, + cache_dir='', + client_id=None, + debug=False): save_dir = osp.join(cache_dir, 'pretrain', str(client_id)) cache_file = osp.join(save_dir, split + '.pt') if osp.exists(cache_file): @@ -351,12 +371,16 @@ def create_squad_pretrain_dataset(data, split, tokenizer, max_seq_len, if cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_file)) os.makedirs(save_dir, exist_ok=True) - torch.save({'examples': examples, - 'encoded_inputs': encoded_inputs}, cache_file) + torch.save({ + 'examples': examples, + 'encoded_inputs': encoded_inputs + }, cache_file) example_indices = torch.arange(encoded_inputs.input_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': encoded_inputs.input_ids, - 'attention_mask': encoded_inputs.attention_mask, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': encoded_inputs.input_ids, + 'attention_mask': encoded_inputs.attention_mask, + 'example_indices': example_indices + }) return dataset, encoded_inputs, examples diff --git a/federatedscope/nlp/dataset/utils.py b/federatedscope/nlp/dataset/utils.py index 94b88cf95..71820359c 100644 --- a/federatedscope/nlp/dataset/utils.py +++ b/federatedscope/nlp/dataset/utils.py @@ -12,7 +12,6 @@ from torch.utils.data.dataset import Dataset from transformers.models.bert import BertTokenizerFast - # ------------------------ # utils for shakespeare dataset @@ -110,8 +109,9 @@ def split_sent(examples, eoq='[unused2]', tokenize=True): class DictDataset(Dataset): def 
__init__(self, inputs): super().__init__() - assert all(list(inputs.values())[0].size(0) == v.size(0) - for v in inputs.values()), "Size mismatch between tensors" + assert all( + list(inputs.values())[0].size(0) == v.size(0) + for v in inputs.values()), "Size mismatch between tensors" self.inputs = inputs def __getitem__(self, index): diff --git a/federatedscope/nlp/trainer/utils.py b/federatedscope/nlp/trainer/utils.py index 0b610a8e6..5a5127304 100644 --- a/federatedscope/nlp/trainer/utils.py +++ b/federatedscope/nlp/trainer/utils.py @@ -1,4 +1,3 @@ - class AverageMeter(object): def __init__(self): self.reset() @@ -17,8 +16,14 @@ def update(self, val, n=1): class ContrastiveMonitor(object): - def __init__(self, stat=1, enc_hidden=None, synth_tokens=None, - dec_hidden=None, dec_out=None, all_group_ids=None, topk_group_ids=None): + def __init__(self, + stat=1, + enc_hidden=None, + synth_tokens=None, + dec_hidden=None, + dec_out=None, + all_group_ids=None, + topk_group_ids=None): self.stat = stat self.enc_hidden = enc_hidden self.synth_tokens = synth_tokens From b7590481ce84f635a32215af8a86422994d14649 Mon Sep 17 00:00:00 2001 From: cheneydon Date: Tue, 25 Oct 2022 12:10:05 +0800 Subject: [PATCH 4/5] update dataset for hetero-fednlp --- .../core/auxiliaries/model_builder.py | 4 +- federatedscope/core/configs/cfg_aggregator.py | 12 +- federatedscope/core/configs/cfg_data.py | 12 +- federatedscope/core/configs/cfg_model.py | 26 +- .../nlp/dataloader/data_collator.py | 117 ++-- .../nlp/dataloader/hfl_dataloader.py | 284 ++++----- federatedscope/nlp/dataset/agnews.py | 56 +- federatedscope/nlp/dataset/cnndm.py | 163 ++--- federatedscope/nlp/dataset/imdb.py | 57 +- federatedscope/nlp/dataset/msqg.py | 164 ++--- federatedscope/nlp/dataset/newsqa.py | 121 ++-- .../nlp/dataset/preprocess/get_hfl_data.py | 53 +- federatedscope/nlp/dataset/squad.py | 108 ++-- federatedscope/nlp/dataset/utils.py | 6 +- federatedscope/nlp/loss/label_smooth_loss.py | 31 + federatedscope/nlp/model/__init__.py | 10 +- federatedscope/nlp/model/hfl_model.py | 561 ++++++++++++++++++ federatedscope/nlp/trainer/utils.py | 11 +- 18 files changed, 1284 insertions(+), 512 deletions(-) create mode 100644 federatedscope/nlp/loss/label_smooth_loss.py create mode 100644 federatedscope/nlp/model/hfl_model.py diff --git a/federatedscope/core/auxiliaries/model_builder.py b/federatedscope/core/auxiliaries/model_builder.py index 522159153..ae75c64fd 100644 --- a/federatedscope/core/auxiliaries/model_builder.py +++ b/federatedscope/core/auxiliaries/model_builder.py @@ -1,8 +1,6 @@ import logging - -import numpy as np - import federatedscope.register as register +from federatedscope.nlp.model import * logger = logging.getLogger(__name__) diff --git a/federatedscope/core/configs/cfg_aggregator.py b/federatedscope/core/configs/cfg_aggregator.py index e70df3e9f..3b3b9d4b7 100644 --- a/federatedscope/core/configs/cfg_aggregator.py +++ b/federatedscope/core/configs/cfg_aggregator.py @@ -4,12 +4,12 @@ def extend_aggregator_cfg(cfg): cfg.aggregator = CN() - cfg.aggregator.num_agg_groups = None - cfg.aggregator.num_agg_topk = None - cfg.aggregator.inside_weight = None - cfg.aggregator.outside_weight = None - cfg.aggregator.proto_weight = None - cfg.aggregator.synth_ratio = None + cfg.aggregator.num_agg_groups = 1 + cfg.aggregator.num_agg_topk = 100 + cfg.aggregator.inside_weight = 1.0 + cfg.aggregator.outside_weight = 0.0 + cfg.aggregator.proto_weight = 0.0 + cfg.aggregator.synth_ratio = 0.5 # --------------- register corresponding check function 
---------- cfg.register_cfg_check_fun(assert_aggregator_cfg) diff --git a/federatedscope/core/configs/cfg_data.py b/federatedscope/core/configs/cfg_data.py index f09e13b95..69de4f334 100644 --- a/federatedscope/core/configs/cfg_data.py +++ b/federatedscope/core/configs/cfg_data.py @@ -56,12 +56,12 @@ def extend_data_cfg(cfg): cfg.data.quadratic.max_curv = 12.5 # fednlp - cfg.data.datasets = [] - cfg.data.num_grouped_clients = [] - cfg.data.max_seq_len = 0 - cfg.data.max_tgt_len = 0 - cfg.data.max_query_len = 0 - cfg.data.trunc_stride = 0 + cfg.data.datasets = ['imdb', 'agnews', 'squad', 'newsqa', 'cnndm', 'msqg'] + cfg.data.num_grouped_clients = [1, 3, 3, 2, 5, 4] + cfg.data.max_seq_len = 384 + cfg.data.max_tgt_len = 128 + cfg.data.max_query_len = 128 + cfg.data.trunc_stride = 128 cfg.data.cache_dir = '' cfg.data.num_contrast = 0 cfg.data.debug = False diff --git a/federatedscope/core/configs/cfg_model.py b/federatedscope/core/configs/cfg_model.py index 56fe98412..eb315f98e 100644 --- a/federatedscope/core/configs/cfg_model.py +++ b/federatedscope/core/configs/cfg_model.py @@ -25,15 +25,31 @@ def extend_model_cfg(cfg): cfg.model.input_shape = () # A tuple, e.g., (in_channel, h, w) # fednlp - cfg.model.model_type = '' - cfg.model.bos_token = '' - cfg.model.eos_token = '' - cfg.model.eoq_token = '' - cfg.model.pad_token = '' + cfg.model.model_type = 'google/bert_uncased_L-2_H-128_A-2' + cfg.model.bos_token = '[unused0]' + cfg.model.eos_token = '[unused1]' + cfg.model.eoq_token = '[unused2]' cfg.model.bos_token_id = -1 cfg.model.eos_token_id = -1 cfg.model.eoq_token_id = -1 cfg.model.pad_token_id = -1 + cfg.model.task = '' + cfg.model.pretrain_task = '' + cfg.model.pretrain_tasks = [] + cfg.model.downstream_tasks = [] + cfg.model.num_labels = 1 + cfg.model.max_length = 200 + cfg.model.min_length = 1 + cfg.model.no_repeat_ngram_size = 3 + cfg.model.length_penalty = 2.0 + cfg.model.num_beams = 5 + cfg.model.label_smoothing = 0.1 + cfg.model.n_best_size = 20 + cfg.model.max_answer_len = 30 + cfg.model.null_score_diff_threshold = 0.0 + cfg.model.train_contrast = False + cfg.model.contrast_topk = 100 + cfg.model.contrast_temp = 1.0 # ---------------------------------------------------------------------- # # Criterion related options diff --git a/federatedscope/nlp/dataloader/data_collator.py b/federatedscope/nlp/dataloader/data_collator.py index 289806763..d7498d7c1 100644 --- a/federatedscope/nlp/dataloader/data_collator.py +++ b/federatedscope/nlp/dataloader/data_collator.py @@ -12,8 +12,10 @@ def __init__(self, tokenizer, mlm_probability=0.15): def __call__(self, examples): """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. 
""" - examples = {k: torch.stack([x[k] for x in examples]) - for k in examples[0].keys()} + examples = { + k: torch.stack([x[k] for x in examples]) + for k in examples[0].keys() + } token_ids = examples['token_ids'] attention_mask = examples['attention_mask'] labels = token_ids.clone() @@ -27,8 +29,9 @@ def __call__(self, examples): val, already_has_special_tokens=True) for val in labels.tolist() ] - probability_matrix.masked_fill_( - torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) + probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, + dtype=torch.bool), + value=0.0) if self.tokenizer._pad_token is not None: padding_mask = labels.eq(self.tokenizer.pad_token_id) probability_matrix.masked_fill_(padding_mask, value=0.0) @@ -37,8 +40,8 @@ def __call__(self, examples): # 80% of the time, we replace masked input tokens with # tokenizer.mask_token ([MASK]) - indices_replaced = torch.bernoulli( - torch.full(labels.shape, 0.8)).bool() & masked_indices + indices_replaced = torch.bernoulli(torch.full( + labels.shape, 0.8)).bool() & masked_indices token_ids[indices_replaced] = self.tokenizer.convert_tokens_to_ids( self.tokenizer.mask_token) @@ -46,16 +49,19 @@ def __call__(self, examples): indices_random = \ torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & \ masked_indices & ~indices_replaced - random_words = torch.randint(len(self.tokenizer), labels.shape, + random_words = torch.randint(len(self.tokenizer), + labels.shape, dtype=torch.long) token_ids[indices_random] = random_words[indices_random] # The rest of the time (10% of the time) we keep the masked input # tokens unchanged - return {'token_ids': token_ids, - 'attention_mask': attention_mask, - 'labels': labels, - 'example_indices': examples['example_indices']} + return { + 'token_ids': token_ids, + 'attention_mask': attention_mask, + 'labels': labels, + 'example_indices': examples['example_indices'] + } class DataCollatorForDenoisingTasks(object): @@ -66,7 +72,10 @@ class DataCollatorForDenoisingTasks(object): The default paramters is based on BART paper https://arxiv.org/abs/1910.13461. 
""" - def __init__(self, tokenizer, mask_ratio=0.3, poisson_lambda=3.0, + def __init__(self, + tokenizer, + mask_ratio=0.3, + poisson_lambda=3.0, permutate_sentence_ratio=1.0): self.tokenizer = tokenizer self.mask_ratio = mask_ratio @@ -74,8 +83,10 @@ def __init__(self, tokenizer, mask_ratio=0.3, poisson_lambda=3.0, self.permutate_sentence_ratio = permutate_sentence_ratio def __call__(self, examples): - examples = {k: torch.stack([x[k] for x in examples]) - for k in examples[0].keys()} + examples = { + k: torch.stack([x[k] for x in examples]) + for k in examples[0].keys() + } token_ids = examples['token_ids'].numpy() attention_mask = examples['attention_mask'].numpy() labels = token_ids.copy() @@ -89,52 +100,55 @@ def __call__(self, examples): if self.mask_ratio: token_ids, _ = self.add_whole_word_mask(token_ids, do_permutate) - num_non_padding = np.sum( - token_ids != self.tokenizer.pad_token_id, axis=-1) + num_non_padding = np.sum(token_ids != self.tokenizer.pad_token_id, + axis=-1) for i in range(len(attention_mask)): attention_mask[i][num_non_padding[i]:] = 0 token_ids = torch.from_numpy(token_ids) attention_mask = torch.from_numpy(attention_mask) labels = torch.from_numpy(labels) - return {'token_ids': token_ids, - 'attention_mask': attention_mask, - 'labels': labels, - 'example_indices': examples['example_indices']} + return { + 'token_ids': token_ids, + 'attention_mask': attention_mask, + 'labels': labels, + 'example_indices': examples['example_indices'] + } def permutate_sentences(self, inputs): results = inputs.copy() for i in range(inputs.shape[0]): full_stops = (inputs[i] == self.tokenizer.eoq_token_id) | ( - inputs[i] == self.tokenizer.eos_token_id) + inputs[i] == self.tokenizer.eos_token_id) full_stops = full_stops[None, :] - sentence_ends = np.argwhere( - full_stops[:, 1:] * ~full_stops[:, :-1]) + sentence_ends = np.argwhere(full_stops[:, 1:] * + ~full_stops[:, :-1]) if len(sentence_ends) == 0: continue sentence_ends[:, 1] += 2 - num_sentences = np.unique( - sentence_ends[:, 0], return_counts=True)[1] + num_sentences = np.unique(sentence_ends[:, 0], + return_counts=True)[1] num_to_permute = np.ceil( (num_sentences * 2 * self.permutate_sentence_ratio) / 2.0).astype(int) sentence_ends = np.split( - sentence_ends[:, 1], np.unique( - sentence_ends[:, 0], return_index=True)[1][1:]) + sentence_ends[:, 1], + np.unique(sentence_ends[:, 0], return_index=True)[1][1:]) - substitutions = np.random.permutation(num_sentences[0])[ - :num_to_permute[0]] + substitutions = np.random.permutation( + num_sentences[0])[:num_to_permute[0]] ordering = np.arange(0, num_sentences[0]) ordering[substitutions] = substitutions[np.random.permutation( num_to_permute[0])] index = 0 for j in ordering: - sentence = inputs[i, (sentence_ends[0][j - 1] if j > 0 else - 0) : sentence_ends[0][j]] - results[i, index : index + sentence.shape[0]] = sentence + sentence = inputs[i, ( + sentence_ends[0][j - + 1] if j > 0 else 0):sentence_ends[0][j]] + results[i, index:index + sentence.shape[0]] = sentence index += sentence.shape[0] num_non_padding = np.sum(results != self.tokenizer.pad_token_id, @@ -152,23 +166,26 @@ def add_whole_word_mask(self, inputs, do_permutate): special_tokens_mask = [ self.tokenizer.get_special_tokens_mask( - val,already_has_special_tokens=True) for val in labels.tolist() + val, already_has_special_tokens=True) + for val in labels.tolist() ] special_tokens_mask = np.array(special_tokens_mask, dtype=bool) # determine how many tokens we need to mask in total is_token = ~(labels == 
self.tokenizer.pad_token_id) & \ ~special_tokens_mask - num_to_mask = int(math.ceil(is_token.astype(float).sum() * - self.mask_ratio)) + num_to_mask = int( + math.ceil(is_token.astype(float).sum() * self.mask_ratio)) if num_to_mask == 0: return inputs, labels # generate a sufficient number of span lengths - lengths = poisson(lam=self.poisson_lambda, size=(num_to_mask,)) + lengths = poisson(lam=self.poisson_lambda, size=(num_to_mask, )) while np.cumsum(lengths, 0)[-1] < num_to_mask: - lengths = np.concatenate([lengths, poisson( - lam=self.poisson_lambda, size=(num_to_mask,))]) + lengths = np.concatenate([ + lengths, + poisson(lam=self.poisson_lambda, size=(num_to_mask, )) + ]) # remove all spans of length 0 # Note that BART inserts additional mask tokens where length == 0, @@ -177,11 +194,11 @@ def add_whole_word_mask(self, inputs, do_permutate): # trim to about num_to_mask tokens idx = np.argmin(np.abs(np.cumsum(lengths, 0) - num_to_mask)) + 1 - lengths = lengths[: idx + 1] + lengths = lengths[:idx + 1] # select span start indices token_indices = np.argwhere(is_token == 1) - span_starts = permutation(token_indices.shape[0])[: lengths.shape[0]] + span_starts = permutation(token_indices.shape[0])[:lengths.shape[0]] # prepare mask masked_indices = np.array(token_indices[span_starts]) @@ -213,23 +230,29 @@ def add_whole_word_mask(self, inputs, do_permutate): # remove mask tokens that are not starts of spans to_remove = (mask == 1) & np.roll((mask == 1), 1, 1) - new_inputs = np.full_like( - labels, fill_value=self.tokenizer.pad_token_id) + new_inputs = np.full_like(labels, + fill_value=self.tokenizer.pad_token_id) # splits = list(map(lambda x: x.reshape(-1), np.split(inputs_copy, # indices_or_sections=2, axis=0)) - for i, example in enumerate(np.split( - inputs, indices_or_sections=new_inputs.shape[0], axis=0)): + for i, example in enumerate( + np.split(inputs, + indices_or_sections=new_inputs.shape[0], + axis=0)): new_example = example[0][~to_remove[i]] - new_inputs[i, 0 : new_example.shape[0]] = new_example + new_inputs[i, 0:new_example.shape[0]] = new_example # batching now fixed return new_inputs, labels class DataCollatorForPFedNLP(object): - def __init__(self, tokenizer, mlm_probability=0.15, mask_ratio=0.3, - poisson_lambda=3.0, permutate_sentence_ratio=1.0): + def __init__(self, + tokenizer, + mlm_probability=0.15, + mask_ratio=0.3, + poisson_lambda=3.0, + permutate_sentence_ratio=1.0): self.mlm_collator = DataCollatorForMLM(tokenizer, mlm_probability) self.denoise_collator = DataCollatorForDenoisingTasks( tokenizer, mask_ratio, poisson_lambda, permutate_sentence_ratio) diff --git a/federatedscope/nlp/dataloader/hfl_dataloader.py b/federatedscope/nlp/dataloader/hfl_dataloader.py index e9f90c0b2..d387e985f 100644 --- a/federatedscope/nlp/dataloader/hfl_dataloader.py +++ b/federatedscope/nlp/dataloader/hfl_dataloader.py @@ -81,8 +81,8 @@ def extend_cfg(cfg, cfg_client): def create_data(data, split, tokenizer, task, model_type, max_seq_len, - max_query_len, trunc_stride, max_tgt_len, cache_dir, - client_id, pretrain, debug): + max_query_len, trunc_stride, max_tgt_len, cache_dir, client_id, + pretrain, debug): if task == 'imdb': create_dataset_func = create_imdb_dataset elif task == 'agnews': @@ -134,9 +134,11 @@ def load_fednlp_data(config, client_config): logger.info(f'Preprocessing dataset {config.data.type}') data_processor = HFLDataProcessor(config) all_data = data_processor.get_data() - all_data_dict = {'train': all_data[0], - 'val': all_data[1], - 'test': all_data[2]} + all_data_dict 
= { + 'train': all_data[0], + 'val': all_data[1], + 'test': all_data[2] + } data_dict = dict() for client_id in tqdm(range(1, config.federate.client_num + 1)): @@ -145,54 +147,56 @@ def load_fednlp_data(config, client_config): cur_task = cfg_client.model.downstream_tasks[client_id - 1] \ if pretrain else cfg_client.model.task train_data, val_data, test_data = [ - create_data( - data=all_data_dict[split][client_id - 1], - split=split, - tokenizer=tokenizer, - task=cur_task, - model_type=model_type, - max_seq_len=getattr(cfg_client.data, 'max_seq_len', None), - max_query_len=getattr(cfg_client.data, 'max_query_len', None), - trunc_stride=getattr(cfg_client.data, 'trunc_stride', None), - max_tgt_len=getattr(cfg_client.data, 'max_tgt_len', None), - cache_dir=cache_dir, - client_id=client_id, - pretrain=pretrain, - debug=debug) - for split in ['train', 'val', 'test'] + create_data(data=all_data_dict[split][client_id - 1], + split=split, + tokenizer=tokenizer, + task=cur_task, + model_type=model_type, + max_seq_len=getattr(cfg_client.data, 'max_seq_len', + None), + max_query_len=getattr(cfg_client.data, 'max_query_len', + None), + trunc_stride=getattr(cfg_client.data, 'trunc_stride', + None), + max_tgt_len=getattr(cfg_client.data, 'max_tgt_len', + None), + cache_dir=cache_dir, + client_id=client_id, + pretrain=pretrain, + debug=debug) for split in ['train', 'val', 'test'] ] dataloader_dict = { 'train': { - 'dataloader': DataLoader( - dataset=train_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=config.data.shuffle, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=train_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=config.data.shuffle, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': train_data[1], - 'examples': train_data[2]}, + 'examples': train_data[2] + }, 'val': { - 'dataloader': DataLoader( - dataset=val_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=val_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': val_data[1], - 'examples': val_data[2]}, + 'examples': val_data[2] + }, 'test': { - 'dataloader': DataLoader( - dataset=test_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=test_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': test_data[1], - 'examples': test_data[2]}, + 'examples': test_data[2] + }, } data_dict[client_id] = dataloader_dict @@ -212,9 +216,11 @@ def load_pfednlp_data(config, client_config): logger.info(f'Preprocessing dataset {config.data.type}') data_processor = HFLDataProcessor(config) all_data = data_processor.get_data() - all_data_dict = {'train': all_data[0], - 'val': all_data[1], - 'test': all_data[2]} + all_data_dict = { + 'train': all_data[0], + 'val': all_data[1], + 'test': all_data[2] + } data_dict = dict() for client_id in tqdm(range(1, config.federate.client_num + 1)): @@ -223,54 +229,56 @@ def load_pfednlp_data(config, client_config): cur_task = 
cfg_client.model.downstream_tasks[client_id - 1] \ if pretrain else cfg_client.model.task train_data, val_data, test_data = [ - create_data( - data=all_data_dict[split][client_id - 1], - split=split, - tokenizer=tokenizer, - task=cur_task, - model_type=model_type, - max_seq_len=getattr(cfg_client.data, 'max_seq_len', None), - max_query_len=getattr(cfg_client.data, 'max_query_len', None), - trunc_stride=getattr(cfg_client.data, 'trunc_stride', None), - max_tgt_len=getattr(cfg_client.data, 'max_tgt_len', None), - cache_dir=cache_dir, - client_id=client_id, - pretrain=pretrain, - debug=debug) - for split in ['train', 'val', 'test'] + create_data(data=all_data_dict[split][client_id - 1], + split=split, + tokenizer=tokenizer, + task=cur_task, + model_type=model_type, + max_seq_len=getattr(cfg_client.data, 'max_seq_len', + None), + max_query_len=getattr(cfg_client.data, 'max_query_len', + None), + trunc_stride=getattr(cfg_client.data, 'trunc_stride', + None), + max_tgt_len=getattr(cfg_client.data, 'max_tgt_len', + None), + cache_dir=cache_dir, + client_id=client_id, + pretrain=pretrain, + debug=debug) for split in ['train', 'val', 'test'] ] dataloader_dict = { 'train': { - 'dataloader': DataLoader( - dataset=train_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=config.data.shuffle, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=train_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=config.data.shuffle, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': train_data[1], - 'examples': train_data[2]}, + 'examples': train_data[2] + }, 'val': { - 'dataloader': DataLoader( - dataset=val_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=val_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': val_data[1], - 'examples': val_data[2]}, + 'examples': val_data[2] + }, 'test': { - 'dataloader': DataLoader( - dataset=test_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=test_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': test_data[1], - 'examples': test_data[2]}, + 'examples': test_data[2] + }, } data_dict[client_id] = dataloader_dict @@ -290,9 +298,11 @@ def load_pfednlp_contrast_data(config, client_config): logger.info(f'Preprocessing dataset {config.data.type}') data_processor = HFLDataProcessor(config) all_data = data_processor.get_data() - all_data_dict = {'train': all_data[0], - 'val': all_data[1], - 'test': all_data[2]} + all_data_dict = { + 'train': all_data[0], + 'val': all_data[1], + 'test': all_data[2] + } data_dict = dict() for client_id in tqdm(range(1, config.federate.client_num + 1)): @@ -301,64 +311,66 @@ def load_pfednlp_contrast_data(config, client_config): cur_task = cfg_client.model.downstream_tasks[client_id - 1] \ if pretrain else cfg_client.model.task train_data, val_data, test_data = [ - create_data( - data=all_data_dict[split][client_id - 1], - split=split, - tokenizer=tokenizer, - 
task=cur_task, - model_type=model_type, - max_seq_len=getattr(cfg_client.data, 'max_seq_len', None), - max_query_len=getattr(cfg_client.data, 'max_query_len', None), - trunc_stride=getattr(cfg_client.data, 'trunc_stride', None), - max_tgt_len=getattr(cfg_client.data, 'max_tgt_len', None), - cache_dir=cache_dir, - client_id=client_id, - pretrain=pretrain, - debug=debug) - for split in ['train', 'val', 'test'] + create_data(data=all_data_dict[split][client_id - 1], + split=split, + tokenizer=tokenizer, + task=cur_task, + model_type=model_type, + max_seq_len=getattr(cfg_client.data, 'max_seq_len', + None), + max_query_len=getattr(cfg_client.data, 'max_query_len', + None), + trunc_stride=getattr(cfg_client.data, 'trunc_stride', + None), + max_tgt_len=getattr(cfg_client.data, 'max_tgt_len', + None), + cache_dir=cache_dir, + client_id=client_id, + pretrain=pretrain, + debug=debug) for split in ['train', 'val', 'test'] ] dataloader_dict = { 'train_raw': { - 'dataloader': DataLoader( - dataset=train_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=config.data.shuffle, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=train_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=config.data.shuffle, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': train_data[1], - 'examples': train_data[2]}, + 'examples': train_data[2] + }, 'train_contrast': { - 'dataloader': DataLoader( - dataset=train_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=train_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': train_data[1], - 'examples': train_data[2]}, + 'examples': train_data[2] + }, 'val': { - 'dataloader': DataLoader( - dataset=val_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=val_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': val_data[1], - 'examples': val_data[2]}, + 'examples': val_data[2] + }, 'test': { - 'dataloader': DataLoader( - dataset=test_data[0], - batch_size=cfg_client.data.batch_size, - shuffle=False, - num_workers=config.data.num_workers, - collate_fn=data_collator, - pin_memory=config.use_gpu), + 'dataloader': DataLoader(dataset=test_data[0], + batch_size=cfg_client.data.batch_size, + shuffle=False, + num_workers=config.data.num_workers, + collate_fn=data_collator, + pin_memory=config.use_gpu), 'encoded': test_data[1], - 'examples': test_data[2]}, + 'examples': test_data[2] + }, } data_dict[client_id] = dataloader_dict diff --git a/federatedscope/nlp/dataset/agnews.py b/federatedscope/nlp/dataset/agnews.py index c5ed9dcf4..9765d3492 100644 --- a/federatedscope/nlp/dataset/agnews.py +++ b/federatedscope/nlp/dataset/agnews.py @@ -16,12 +16,19 @@ def create_agnews_examples(data, debug=False): return examples -def create_agnews_dataset(data, split, tokenizer, max_seq_len, cache_dir='', - client_id=None, pretrain=False, debug=False, +def create_agnews_dataset(data, + split, + tokenizer, + max_seq_len, + 
cache_dir='', + client_id=None, + pretrain=False, + debug=False, **kwargs): if pretrain: - return create_agnews_pretrain_dataset( - data, split, tokenizer, max_seq_len, cache_dir, client_id, debug) + return create_agnews_pretrain_dataset(data, split, tokenizer, + max_seq_len, cache_dir, + client_id, debug) save_dir = osp.join(cache_dir, 'finetune', str(client_id)) cache_file = osp.join(save_dir, split + '.pt') @@ -42,22 +49,31 @@ def create_agnews_dataset(data, split, tokenizer, max_seq_len, cache_dir='', if cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_file)) os.makedirs(save_dir, exist_ok=True) - torch.save({'examples': examples, - 'encoded_inputs': encoded_inputs}, cache_file) + torch.save({ + 'examples': examples, + 'encoded_inputs': encoded_inputs + }, cache_file) labels = [ex[1] for ex in examples] example_indices = torch.arange(encoded_inputs.input_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': encoded_inputs.input_ids, - 'token_type_ids': encoded_inputs.token_type_ids, - 'attention_mask': encoded_inputs.attention_mask, - 'labels': torch.LongTensor(labels), - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': encoded_inputs.input_ids, + 'token_type_ids': encoded_inputs.token_type_ids, + 'attention_mask': encoded_inputs.attention_mask, + 'labels': torch.LongTensor(labels), + 'example_indices': example_indices + }) return dataset, encoded_inputs, examples -def create_agnews_pretrain_dataset(data, split, tokenizer, max_seq_len, - cache_dir='', client_id=None, debug=False): +def create_agnews_pretrain_dataset(data, + split, + tokenizer, + max_seq_len, + cache_dir='', + client_id=None, + debug=False): save_dir = osp.join(cache_dir, 'pretrain', str(client_id)) cache_file = osp.join(save_dir, split + '.pt') @@ -84,12 +100,16 @@ def create_agnews_pretrain_dataset(data, split, tokenizer, max_seq_len, if cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_file)) os.makedirs(save_dir, exist_ok=True) - torch.save({'examples': examples, - 'encoded_inputs': encoded_inputs}, cache_file) + torch.save({ + 'examples': examples, + 'encoded_inputs': encoded_inputs + }, cache_file) example_indices = torch.arange(encoded_inputs.input_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': encoded_inputs.input_ids, - 'attention_mask': encoded_inputs.attention_mask, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': encoded_inputs.input_ids, + 'attention_mask': encoded_inputs.attention_mask, + 'example_indices': example_indices + }) return dataset, encoded_inputs, examples diff --git a/federatedscope/nlp/dataset/cnndm.py b/federatedscope/nlp/dataset/cnndm.py index f03a9871d..36e52f6d9 100644 --- a/federatedscope/nlp/dataset/cnndm.py +++ b/federatedscope/nlp/dataset/cnndm.py @@ -18,38 +18,43 @@ def create_cnndm_examples(data, debug=False): return src_examples, tgt_examples -def create_cnndm_dataset(data, split, tokenizer, max_src_len, max_tgt_len, - raw_cache_dir='', client_id=None, pretrain=False, - debug=False, **kwargs): +def create_cnndm_dataset(data, + split, + tokenizer, + max_src_len, + max_tgt_len, + raw_cache_dir='', + client_id=None, + pretrain=False, + debug=False, + **kwargs): if pretrain: - return create_cnndm_pretrain_dataset( - data, split, tokenizer, max_src_len, raw_cache_dir, client_id, - debug) + return create_cnndm_pretrain_dataset(data, split, tokenizer, + max_src_len, raw_cache_dir, + client_id, debug) cache_dir = osp.join(raw_cache_dir, 'finetune', 
str(client_id), split) src_examples, tgt_examples = create_cnndm_examples(data, debug) if osp.exists(cache_dir): logger.info('Loading cache file from \'{}\''.format(cache_dir)) - token_ids = np.memmap( - filename=osp.join(cache_dir, 'token_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) - token_type_ids = np.memmap( - filename=osp.join(cache_dir, 'token_type_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) - attention_mask = np.memmap( - filename=osp.join(cache_dir, 'attention_mask.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) - labels = np.memmap( - filename=osp.join(cache_dir, 'labels.memmap'), - shape=(len(src_examples), max_tgt_len), - mode='r', - dtype=np.int64) + token_ids = np.memmap(filename=osp.join(cache_dir, 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + token_type_ids = np.memmap(filename=osp.join(cache_dir, + 'token_type_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + attention_mask = np.memmap(filename=osp.join(cache_dir, + 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + labels = np.memmap(filename=osp.join(cache_dir, 'labels.memmap'), + shape=(len(src_examples), max_tgt_len), + mode='r', + dtype=np.int64) token_ids = torch.from_numpy(token_ids) token_type_ids = torch.from_numpy(token_type_ids) @@ -76,26 +81,25 @@ def create_cnndm_dataset(data, split, tokenizer, max_src_len, max_tgt_len, if raw_cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_dir)) os.makedirs(cache_dir, exist_ok=True) - token_ids = np.memmap( - filename=osp.join(cache_dir, 'token_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) - token_type_ids = np.memmap( - filename=osp.join(cache_dir, 'token_type_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) - attention_mask = np.memmap( - filename=osp.join(cache_dir, 'attention_mask.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) - labels = np.memmap( - filename=osp.join(cache_dir, 'labels.memmap'), - shape=(len(src_examples), max_tgt_len), - mode='w+', - dtype=np.int64) + token_ids = np.memmap(filename=osp.join(cache_dir, + 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + token_type_ids = np.memmap(filename=osp.join( + cache_dir, 'token_type_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + attention_mask = np.memmap(filename=osp.join( + cache_dir, 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + labels = np.memmap(filename=osp.join(cache_dir, 'labels.memmap'), + shape=(len(src_examples), max_tgt_len), + mode='w+', + dtype=np.int64) for i in range(len(src_examples)): token_ids[i] = src_encoded.input_ids[i] @@ -114,31 +118,36 @@ def create_cnndm_dataset(data, split, tokenizer, max_src_len, max_tgt_len, labels = tgt_encoded.input_ids example_indices = torch.arange(token_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': token_ids, - 'token_type_ids': token_type_ids, - 'attention_mask': attention_mask, - 'labels': labels, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': token_ids, + 'token_type_ids': token_type_ids, + 'attention_mask': attention_mask, + 'labels': labels, + 'example_indices': example_indices + }) 
return dataset, None, None -def create_cnndm_pretrain_dataset(data, split, tokenizer, max_src_len, - raw_cache_dir='', client_id=None, +def create_cnndm_pretrain_dataset(data, + split, + tokenizer, + max_src_len, + raw_cache_dir='', + client_id=None, debug=False): cache_dir = osp.join(raw_cache_dir, 'pretrain', str(client_id), split) src_examples, tgt_examples = create_cnndm_examples(data, debug) if osp.exists(cache_dir): logger.info('Loading cache file from \'{}\''.format(cache_dir)) - token_ids = np.memmap( - filename=osp.join(cache_dir, 'token_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) - attention_mask = np.memmap( - filename=osp.join(cache_dir, 'attention_mask.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) + token_ids = np.memmap(filename=osp.join(cache_dir, 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + attention_mask = np.memmap(filename=osp.join(cache_dir, + 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) token_ids = torch.from_numpy(token_ids) attention_mask = torch.from_numpy(attention_mask) @@ -158,16 +167,16 @@ def create_cnndm_pretrain_dataset(data, split, tokenizer, max_src_len, if raw_cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_dir)) os.makedirs(cache_dir, exist_ok=True) - token_ids = np.memmap( - filename=osp.join(cache_dir, 'token_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) - attention_mask = np.memmap( - filename=osp.join(cache_dir, 'attention_mask.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) + token_ids = np.memmap(filename=osp.join(cache_dir, + 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + attention_mask = np.memmap(filename=osp.join( + cache_dir, 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) for i in range(len(src_examples)): token_ids[i] = src_encoded.input_ids[i] @@ -180,7 +189,9 @@ def create_cnndm_pretrain_dataset(data, split, tokenizer, max_src_len, attention_mask = src_encoded.attention_mask example_indices = torch.arange(token_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': token_ids, - 'attention_mask': attention_mask, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': token_ids, + 'attention_mask': attention_mask, + 'example_indices': example_indices + }) return dataset, None, None diff --git a/federatedscope/nlp/dataset/imdb.py b/federatedscope/nlp/dataset/imdb.py index b514037f8..63db181ff 100644 --- a/federatedscope/nlp/dataset/imdb.py +++ b/federatedscope/nlp/dataset/imdb.py @@ -16,11 +16,19 @@ def create_imdb_examples(data, debug=False): return examples -def create_imdb_dataset(data, split, tokenizer, max_seq_len, cache_dir='', - client_id=None, pretrain=False, debug=False, **kwargs): +def create_imdb_dataset(data, + split, + tokenizer, + max_seq_len, + cache_dir='', + client_id=None, + pretrain=False, + debug=False, + **kwargs): if pretrain: - return create_imdb_pretrain_dataset( - data, split, tokenizer, max_seq_len, cache_dir, client_id, debug) + return create_imdb_pretrain_dataset(data, split, tokenizer, + max_seq_len, cache_dir, client_id, + debug) save_dir = osp.join(cache_dir, 'finetune', str(client_id)) cache_file = osp.join(save_dir, split + '.pt') @@ -41,22 +49,31 @@ def create_imdb_dataset(data, split, tokenizer, 
max_seq_len, cache_dir='', if cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_file)) os.makedirs(save_dir, exist_ok=True) - torch.save({'examples': examples, - 'encoded_inputs': encoded_inputs}, cache_file) + torch.save({ + 'examples': examples, + 'encoded_inputs': encoded_inputs + }, cache_file) labels = [ex[1] for ex in examples] example_indices = torch.arange(encoded_inputs.input_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': encoded_inputs.input_ids, - 'token_type_ids': encoded_inputs.token_type_ids, - 'attention_mask': encoded_inputs.attention_mask, - 'labels': torch.LongTensor(labels), - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': encoded_inputs.input_ids, + 'token_type_ids': encoded_inputs.token_type_ids, + 'attention_mask': encoded_inputs.attention_mask, + 'labels': torch.LongTensor(labels), + 'example_indices': example_indices + }) return dataset, encoded_inputs, examples -def create_imdb_pretrain_dataset(data, split, tokenizer, max_seq_len, - cache_dir='', client_id=None, debug=False): +def create_imdb_pretrain_dataset(data, + split, + tokenizer, + max_seq_len, + cache_dir='', + client_id=None, + debug=False): save_dir = osp.join(cache_dir, 'pretrain', str(client_id)) cache_file = osp.join(save_dir, split + '.pt') if osp.exists(cache_file): @@ -82,12 +99,16 @@ def create_imdb_pretrain_dataset(data, split, tokenizer, max_seq_len, if cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_file)) os.makedirs(save_dir, exist_ok=True) - torch.save({'examples': examples, - 'encoded_inputs': encoded_inputs}, cache_file) + torch.save({ + 'examples': examples, + 'encoded_inputs': encoded_inputs + }, cache_file) example_indices = torch.arange(encoded_inputs.input_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': encoded_inputs.input_ids, - 'attention_mask': encoded_inputs.attention_mask, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': encoded_inputs.input_ids, + 'attention_mask': encoded_inputs.attention_mask, + 'example_indices': example_indices + }) return dataset, encoded_inputs, examples diff --git a/federatedscope/nlp/dataset/msqg.py b/federatedscope/nlp/dataset/msqg.py index dab24eac7..6040d047c 100644 --- a/federatedscope/nlp/dataset/msqg.py +++ b/federatedscope/nlp/dataset/msqg.py @@ -18,38 +18,43 @@ def create_msqg_examples(data, debug=False): return src_examples, tgt_examples -def create_msqg_dataset(data, split, tokenizer, max_src_len, max_tgt_len, - raw_cache_dir='', client_id=None, pretrain=False, - debug=False, **kwargs): +def create_msqg_dataset(data, + split, + tokenizer, + max_src_len, + max_tgt_len, + raw_cache_dir='', + client_id=None, + pretrain=False, + debug=False, + **kwargs): if pretrain: - return create_msqg_pretrain_dataset( - data, split, tokenizer, max_src_len, raw_cache_dir, client_id, - debug) + return create_msqg_pretrain_dataset(data, split, tokenizer, + max_src_len, raw_cache_dir, + client_id, debug) cache_dir = osp.join(raw_cache_dir, 'finetune', str(client_id), split) src_examples, tgt_examples = create_msqg_examples(data, debug) if osp.exists(cache_dir): logger.info('Loading cache file from \'{}\''.format(cache_dir)) - token_ids = np.memmap( - filename=osp.join(cache_dir, 'token_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) - token_type_ids = np.memmap( - filename=osp.join(cache_dir, 'token_type_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) - 
attention_mask = np.memmap( - filename=osp.join(cache_dir, 'attention_mask.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) - labels = np.memmap( - filename=osp.join(cache_dir, 'labels.memmap'), - shape=(len(src_examples), max_tgt_len), - mode='r', - dtype=np.int64) + token_ids = np.memmap(filename=osp.join(cache_dir, 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + token_type_ids = np.memmap(filename=osp.join(cache_dir, + 'token_type_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + attention_mask = np.memmap(filename=osp.join(cache_dir, + 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + labels = np.memmap(filename=osp.join(cache_dir, 'labels.memmap'), + shape=(len(src_examples), max_tgt_len), + mode='r', + dtype=np.int64) token_ids = torch.from_numpy(token_ids) token_type_ids = torch.from_numpy(token_type_ids) @@ -78,26 +83,25 @@ def create_msqg_dataset(data, split, tokenizer, max_src_len, max_tgt_len, if raw_cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_dir)) os.makedirs(cache_dir, exist_ok=True) - token_ids = np.memmap( - filename=osp.join(cache_dir, 'token_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) - token_type_ids = np.memmap( - filename=osp.join(cache_dir, 'token_type_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) - attention_mask = np.memmap( - filename=osp.join(cache_dir, 'attention_mask.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) - labels = np.memmap( - filename=osp.join(cache_dir, 'labels.memmap'), - shape=(len(src_examples), max_tgt_len), - mode='w+', - dtype=np.int64) + token_ids = np.memmap(filename=osp.join(cache_dir, + 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + token_type_ids = np.memmap(filename=osp.join( + cache_dir, 'token_type_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + attention_mask = np.memmap(filename=osp.join( + cache_dir, 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + labels = np.memmap(filename=osp.join(cache_dir, 'labels.memmap'), + shape=(len(src_examples), max_tgt_len), + mode='w+', + dtype=np.int64) for i in range(len(src_examples)): token_ids[i] = src_encoded.input_ids[i] @@ -117,30 +121,36 @@ def create_msqg_dataset(data, split, tokenizer, max_src_len, max_tgt_len, labels = tgt_encoded.input_ids example_indices = torch.arange(token_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': token_ids, - 'token_type_ids': token_type_ids, - 'attention_mask': attention_mask, - 'labels': labels, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': token_ids, + 'token_type_ids': token_type_ids, + 'attention_mask': attention_mask, + 'labels': labels, + 'example_indices': example_indices + }) return dataset, None, None -def create_msqg_pretrain_dataset(data, split, tokenizer, max_src_len, - raw_cache_dir='', client_id=None, debug=False): +def create_msqg_pretrain_dataset(data, + split, + tokenizer, + max_src_len, + raw_cache_dir='', + client_id=None, + debug=False): cache_dir = osp.join(raw_cache_dir, 'pretrain', str(client_id), split) src_examples, tgt_examples = create_msqg_examples(data, debug) if osp.exists(cache_dir): logger.info('Loading cache file from 
\'{}\''.format(cache_dir)) - token_ids = np.memmap( - filename=osp.join(cache_dir, 'token_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) - attention_mask = np.memmap( - filename=osp.join(cache_dir, 'attention_mask.memmap'), - shape=(len(src_examples), max_src_len), - mode='r', - dtype=np.int64) + token_ids = np.memmap(filename=osp.join(cache_dir, 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) + attention_mask = np.memmap(filename=osp.join(cache_dir, + 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='r', + dtype=np.int64) token_ids = torch.from_numpy(token_ids) attention_mask = torch.from_numpy(attention_mask) else: @@ -161,16 +171,16 @@ def create_msqg_pretrain_dataset(data, split, tokenizer, max_src_len, if raw_cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_dir)) os.makedirs(cache_dir, exist_ok=True) - token_ids = np.memmap( - filename=osp.join(cache_dir, 'token_ids.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) - attention_mask = np.memmap( - filename=osp.join(cache_dir, 'attention_mask.memmap'), - shape=(len(src_examples), max_src_len), - mode='w+', - dtype=np.int64) + token_ids = np.memmap(filename=osp.join(cache_dir, + 'token_ids.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) + attention_mask = np.memmap(filename=osp.join( + cache_dir, 'attention_mask.memmap'), + shape=(len(src_examples), max_src_len), + mode='w+', + dtype=np.int64) for i in range(len(src_examples)): token_ids[i] = src_encoded.input_ids[i] @@ -183,7 +193,9 @@ def create_msqg_pretrain_dataset(data, split, tokenizer, max_src_len, attention_mask = src_encoded.attention_mask example_indices = torch.arange(token_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': token_ids, - 'attention_mask': attention_mask, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': token_ids, + 'attention_mask': attention_mask, + 'example_indices': example_indices + }) return dataset, None, None diff --git a/federatedscope/nlp/dataset/newsqa.py b/federatedscope/nlp/dataset/newsqa.py index 04eb44110..08aa36903 100644 --- a/federatedscope/nlp/dataset/newsqa.py +++ b/federatedscope/nlp/dataset/newsqa.py @@ -53,8 +53,8 @@ def get_char_to_word_positions(context, answer, start_char_pos, is_impossible): char_to_word_offset = [] is_prev_whitespace = True for c in context: - is_whitespace = (c == ' ' or c == '\t' or c == '\r' or c == '\n' or - ord(c) == 0x202F) + is_whitespace = (c == ' ' or c == '\t' or c == '\r' or c == '\n' + or ord(c) == 0x202F) if is_whitespace: is_prev_whitespace = True else: @@ -153,30 +153,40 @@ def create_newsqa_examples(data, split, debug=False): is_impossible = qa['is_impossible'] if 'is_impossible' in qa else False if not is_impossible: answers = qa['detected_answers'] - spans = sorted([span for spans in answers - for span in spans['char_spans']]) + spans = sorted( + [span for spans in answers for span in spans['char_spans']]) if split == 'train': - train_answer = context[spans[0][0]: spans[0][1] + 1] + train_answer = context[spans[0][0]:spans[0][1] + 1] start_char_pos = spans[0][0] else: - val_answer = [{'text': context[spans[i][0]: spans[i][1] + 1], - 'answer_start': spans[i][0]} - for i in range(len(spans))] + val_answer = [{ + 'text': context[spans[i][0]:spans[i][1] + 1], + 'answer_start': spans[i][0] + } for i in range(len(spans))] start_pos, end_pos, context_tokens = 
get_char_to_word_positions( context, train_answer, start_char_pos, is_impossible) - examples.append(NewsQAExample(qa_id, question, context, train_answer, - val_answer, start_pos, end_pos, - context_tokens, is_impossible)) + examples.append( + NewsQAExample(qa_id, question, context, train_answer, val_answer, + start_pos, end_pos, context_tokens, is_impossible)) return examples -def create_newsqa_dataset(data, split, tokenizer, max_seq_len, max_query_len, - trunc_stride, cache_dir='', client_id=None, - pretrain=False, debug=False, **kwargs): +def create_newsqa_dataset(data, + split, + tokenizer, + max_seq_len, + max_query_len, + trunc_stride, + cache_dir='', + client_id=None, + pretrain=False, + debug=False, + **kwargs): if pretrain: - return create_newsqa_pretrain_dataset( - data, split, tokenizer, max_seq_len, cache_dir, client_id, debug) + return create_newsqa_pretrain_dataset(data, split, tokenizer, + max_seq_len, cache_dir, + client_id, debug) save_dir = osp.join(cache_dir, 'finetune', str(client_id)) cache_file = osp.join(save_dir, split + '.pt') @@ -193,8 +203,8 @@ def create_newsqa_dataset(data, split, tokenizer, max_seq_len, max_query_len, if split == 'train' and not example.is_impossible: start_pos = example.start_position end_pos = example.end_position - actual_answer = ' '.join(example.context_tokens[ - start_pos:(end_pos + 1)]) + actual_answer = ' '.join( + example.context_tokens[start_pos:(end_pos + 1)]) cleaned_answer = ' '.join(example.train_answer.strip().split()) if actual_answer.find(cleaned_answer) == -1: logger.info('Could not find answer: {} vs. {}'.format( @@ -214,8 +224,8 @@ def create_newsqa_dataset(data, split, tokenizer, max_seq_len, max_query_len, if split == 'train' and not example.is_impossible: subtoken_start_pos = tok_to_subtok_idx[example.start_position] if example.end_position < len(example.context_tokens) - 1: - subtoken_end_pos = tok_to_subtok_idx[ - example.end_position + 1] - 1 + subtoken_end_pos = tok_to_subtok_idx[example.end_position + + 1] - 1 else: subtoken_end_pos = len(context_subtokens) - 1 subtoken_start_pos, subtoken_end_pos = \ @@ -236,14 +246,16 @@ def create_newsqa_dataset(data, split, tokenizer, max_seq_len, max_query_len, encoded_input = encode(tokenizer, text_a, text_b, max_seq_len, max_query_len, added_trunc_size) context_start_pos = len(spans) * trunc_stride - context_len = min(len(context_subtokens) - context_start_pos, - max_seq_len - len_question - 3) + context_len = min( + len(context_subtokens) - context_start_pos, + max_seq_len - len_question - 3) context_end_pos = context_start_pos + context_len - 1 if tokenizer.pad_token_id in encoded_input.token_ids: - non_padded_ids = encoded_input.token_ids[ - :encoded_input.token_ids.index( - tokenizer.pad_token_id)] + non_padded_ids = encoded_input.token_ids[:encoded_input. + token_ids.index( + tokenizer. 
+ pad_token_id)] else: non_padded_ids = encoded_input.token_ids tokens = tokenizer.convert_ids_to_tokens(non_padded_ids) @@ -306,31 +318,40 @@ def create_newsqa_dataset(data, split, tokenizer, max_seq_len, max_query_len, if cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_file)) os.makedirs(save_dir, exist_ok=True) - torch.save({'examples': examples, - 'encoded_inputs': encoded_inputs}, cache_file) + torch.save({ + 'examples': examples, + 'encoded_inputs': encoded_inputs + }, cache_file) token_ids = torch.LongTensor([inp.token_ids for inp in encoded_inputs]) - token_type_ids = torch.LongTensor([inp.token_type_ids - for inp in encoded_inputs]) - attention_mask = torch.LongTensor([inp.attention_mask - for inp in encoded_inputs]) - start_positions = torch.LongTensor([inp.start_position - for inp in encoded_inputs]) - end_positions = torch.LongTensor([inp.end_position for - inp in encoded_inputs]) + token_type_ids = torch.LongTensor( + [inp.token_type_ids for inp in encoded_inputs]) + attention_mask = torch.LongTensor( + [inp.attention_mask for inp in encoded_inputs]) + start_positions = torch.LongTensor( + [inp.start_position for inp in encoded_inputs]) + end_positions = torch.LongTensor( + [inp.end_position for inp in encoded_inputs]) example_indices = torch.arange(token_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': token_ids, - 'token_type_ids': token_type_ids, - 'attention_mask': attention_mask, - 'start_positions': start_positions, - 'end_positions': end_positions, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': token_ids, + 'token_type_ids': token_type_ids, + 'attention_mask': attention_mask, + 'start_positions': start_positions, + 'end_positions': end_positions, + 'example_indices': example_indices + }) return dataset, encoded_inputs, examples -def create_newsqa_pretrain_dataset(data, split, tokenizer, max_seq_len, - cache_dir='', client_id=None, debug=False): +def create_newsqa_pretrain_dataset(data, + split, + tokenizer, + max_seq_len, + cache_dir='', + client_id=None, + debug=False): save_dir = osp.join(cache_dir, 'pretrain', str(client_id)) cache_file = osp.join(save_dir, split + '.pt') if osp.exists(cache_file): @@ -356,12 +377,16 @@ def create_newsqa_pretrain_dataset(data, split, tokenizer, max_seq_len, if cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_file)) os.makedirs(save_dir, exist_ok=True) - torch.save({'examples': examples, - 'encoded_inputs': encoded_inputs}, cache_file) + torch.save({ + 'examples': examples, + 'encoded_inputs': encoded_inputs + }, cache_file) example_indices = torch.arange(encoded_inputs.input_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': encoded_inputs.input_ids, - 'attention_mask': encoded_inputs.attention_mask, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': encoded_inputs.input_ids, + 'attention_mask': encoded_inputs.attention_mask, + 'example_indices': example_indices + }) return dataset, encoded_inputs, examples diff --git a/federatedscope/nlp/dataset/preprocess/get_hfl_data.py b/federatedscope/nlp/dataset/preprocess/get_hfl_data.py index 9757011ca..73bfdf402 100644 --- a/federatedscope/nlp/dataset/preprocess/get_hfl_data.py +++ b/federatedscope/nlp/dataset/preprocess/get_hfl_data.py @@ -19,26 +19,31 @@ def __init__(self, config, train_frac=0.9): self.total_client_num = config.federate.client_num self.num_grouped_clients = config.data.num_grouped_clients self.train_frac = train_frac + self.all_train_data = 
[] + self.all_val_data = [] + self.all_test_data = [] def get_data(self): - all_train_data = [] - all_val_data = [] - all_test_data = [] for i, dataset in enumerate(self.datasets): if dataset not in HFL_NAMES: raise ValueError(f'No HFL dataset named {dataset}') - train_val_data = self._load_data( - dataset, 'train', self.num_grouped_clients[i]) - train_data = [data[:int(self.train_frac * len(data))] - for data in train_val_data] - val_data = [data[int(self.train_frac * len(data)):] - for data in train_val_data] - test_data = self._load_data( - dataset, 'test', self.num_grouped_clients[i]) - all_train_data.extend(train_data) - all_val_data.extend(val_data) - all_test_data.extend(test_data) - return all_train_data, all_val_data, all_test_data + train_val_data = self._load_data(dataset, 'train', + self.num_grouped_clients[i]) + train_data = [ + data[:int(self.train_frac * len(data))] + for data in train_val_data + ] + val_data = [ + data[int(self.train_frac * len(data)):] + for data in train_val_data + ] + test_data = self._load_data(dataset, 'test', + self.num_grouped_clients[i]) + self.all_train_data.extend(train_data) + self.all_val_data.extend(val_data) + self.all_test_data.extend(test_data) + + return self.all_train_data, self.all_val_data, self.all_test_data def _load_data(self, dataset, split, num_clients): data_dir = os.path.join(self.data_dir, dataset) @@ -78,7 +83,8 @@ def _load_data(self, dataset, split, num_clients): data.append({'text': text, 'label': label}) elif dataset == 'squad': - with open(os.path.join(data_dir, split + '.json'), 'r', + with open(os.path.join(data_dir, split + '.json'), + 'r', encoding='utf-8') as reader: raw_data = json.load(reader)['data'] for line in raw_data: @@ -101,11 +107,13 @@ def _load_data(self, dataset, split, num_clients): src_file = os.path.join(data_dir, split + '.src') tgt_file = os.path.join(data_dir, split + '.tgt') with open(src_file) as f: - src_data = [line.strip().replace('', '[SEP]') - for line in f] + src_data = [ + line.strip().replace('', '[SEP]') for line in f + ] with open(tgt_file) as f: - tgt_data = [line.strip().replace('', '[SEP]') - for line in f] + tgt_data = [ + line.strip().replace('', '[SEP]') for line in f + ] for src, tgt in zip(src_data, tgt_data): data.append({'src': src, 'tgt': tgt}) @@ -117,10 +125,11 @@ def _load_data(self, dataset, split, num_clients): for i in range(num_clients): num_split = n if i < num_clients - 1 else \ len(data) - n * (num_clients - 1) - cur_data = data[data_idx: data_idx + num_split] + cur_data = data[data_idx:data_idx + num_split] data_idx += num_split all_split_data.append(cur_data) - logger.info(f'Client id: {i + 1}, num samples: {num_split}') + logger.info(f'Client id: {len(self.all_train_data) + i + 1}, ' + f'num samples: {num_split}') return all_split_data def _download(self, dataset): diff --git a/federatedscope/nlp/dataset/squad.py b/federatedscope/nlp/dataset/squad.py index 3e30d7b71..56b22ca40 100644 --- a/federatedscope/nlp/dataset/squad.py +++ b/federatedscope/nlp/dataset/squad.py @@ -53,8 +53,8 @@ def get_char_to_word_positions(context, answer, start_char_pos, is_impossible): char_to_word_offset = [] is_prev_whitespace = True for c in context: - is_whitespace = (c == ' ' or c == '\t' or c == '\r' or c == '\n' or - ord(c) == 0x202F) + is_whitespace = (c == ' ' or c == '\t' or c == '\r' or c == '\n' + or ord(c) == 0x202F) if is_whitespace: is_prev_whitespace = True else: @@ -160,18 +160,27 @@ def create_squad_examples(data, split, debug=False): start_pos, end_pos, context_tokens = 
get_char_to_word_positions( context, train_answer, start_char_pos, is_impossible) - examples.append(SquadExample(qa_id, question, context, train_answer, - val_answer, start_pos, end_pos, - context_tokens, is_impossible)) + examples.append( + SquadExample(qa_id, question, context, train_answer, val_answer, + start_pos, end_pos, context_tokens, is_impossible)) return examples -def create_squad_dataset(data, split, tokenizer, max_seq_len, max_query_len, - trunc_stride, cache_dir='', client_id=None, - pretrain=False, debug=False, **kwargs): +def create_squad_dataset(data, + split, + tokenizer, + max_seq_len, + max_query_len, + trunc_stride, + cache_dir='', + client_id=None, + pretrain=False, + debug=False, + **kwargs): if pretrain: - return create_squad_pretrain_dataset( - data, split, tokenizer, max_seq_len, cache_dir, client_id, debug) + return create_squad_pretrain_dataset(data, split, tokenizer, + max_seq_len, cache_dir, client_id, + debug) save_dir = osp.join(cache_dir, 'finetune', str(client_id)) cache_file = osp.join(save_dir, split + '.pt') @@ -188,8 +197,8 @@ def create_squad_dataset(data, split, tokenizer, max_seq_len, max_query_len, if split == 'train' and not example.is_impossible: start_pos = example.start_position end_pos = example.end_position - actual_answer = ' '.join(example.context_tokens[ - start_pos:(end_pos + 1)]) + actual_answer = ' '.join( + example.context_tokens[start_pos:(end_pos + 1)]) cleaned_answer = ' '.join(example.train_answer.strip().split()) if actual_answer.find(cleaned_answer) == -1: logger.info('Could not find answer: {} vs. {}'.format( @@ -209,8 +218,8 @@ def create_squad_dataset(data, split, tokenizer, max_seq_len, max_query_len, if split == 'train' and not example.is_impossible: subtoken_start_pos = tok_to_subtok_idx[example.start_position] if example.end_position < len(example.context_tokens) - 1: - subtoken_end_pos = tok_to_subtok_idx[ - example.end_position + 1] - 1 + subtoken_end_pos = tok_to_subtok_idx[example.end_position + + 1] - 1 else: subtoken_end_pos = len(context_subtokens) - 1 subtoken_start_pos, subtoken_end_pos = \ @@ -231,14 +240,16 @@ def create_squad_dataset(data, split, tokenizer, max_seq_len, max_query_len, encoded_input = encode(tokenizer, text_a, text_b, max_seq_len, max_query_len, added_trunc_size) context_start_pos = len(spans) * trunc_stride - context_len = min(len(context_subtokens) - context_start_pos, - max_seq_len - len_question - 3) + context_len = min( + len(context_subtokens) - context_start_pos, + max_seq_len - len_question - 3) context_end_pos = context_start_pos + context_len - 1 if tokenizer.pad_token_id in encoded_input.token_ids: - non_padded_ids = encoded_input.token_ids[ - :encoded_input.token_ids.index( - tokenizer.pad_token_id)] + non_padded_ids = encoded_input.token_ids[:encoded_input. + token_ids.index( + tokenizer. 
+ pad_token_id)] else: non_padded_ids = encoded_input.token_ids tokens = tokenizer.convert_ids_to_tokens(non_padded_ids) @@ -301,31 +312,40 @@ def create_squad_dataset(data, split, tokenizer, max_seq_len, max_query_len, if cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_file)) os.makedirs(save_dir, exist_ok=True) - torch.save({'examples': examples, - 'encoded_inputs': encoded_inputs}, cache_file) + torch.save({ + 'examples': examples, + 'encoded_inputs': encoded_inputs + }, cache_file) token_ids = torch.LongTensor([inp.token_ids for inp in encoded_inputs]) - token_type_ids = torch.LongTensor([inp.token_type_ids - for inp in encoded_inputs]) - attention_mask = torch.LongTensor([inp.attention_mask - for inp in encoded_inputs]) - start_positions = torch.LongTensor([inp.start_position - for inp in encoded_inputs]) - end_positions = torch.LongTensor([inp.end_position - for inp in encoded_inputs]) + token_type_ids = torch.LongTensor( + [inp.token_type_ids for inp in encoded_inputs]) + attention_mask = torch.LongTensor( + [inp.attention_mask for inp in encoded_inputs]) + start_positions = torch.LongTensor( + [inp.start_position for inp in encoded_inputs]) + end_positions = torch.LongTensor( + [inp.end_position for inp in encoded_inputs]) example_indices = torch.arange(token_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': token_ids, - 'token_type_ids': token_type_ids, - 'attention_mask': attention_mask, - 'start_positions': start_positions, - 'end_positions': end_positions, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': token_ids, + 'token_type_ids': token_type_ids, + 'attention_mask': attention_mask, + 'start_positions': start_positions, + 'end_positions': end_positions, + 'example_indices': example_indices + }) return dataset, encoded_inputs, examples -def create_squad_pretrain_dataset(data, split, tokenizer, max_seq_len, - cache_dir='', client_id=None, debug=False): +def create_squad_pretrain_dataset(data, + split, + tokenizer, + max_seq_len, + cache_dir='', + client_id=None, + debug=False): save_dir = osp.join(cache_dir, 'pretrain', str(client_id)) cache_file = osp.join(save_dir, split + '.pt') if osp.exists(cache_file): @@ -351,12 +371,16 @@ def create_squad_pretrain_dataset(data, split, tokenizer, max_seq_len, if cache_dir: logger.info('Saving cache file to \'{}\''.format(cache_file)) os.makedirs(save_dir, exist_ok=True) - torch.save({'examples': examples, - 'encoded_inputs': encoded_inputs}, cache_file) + torch.save({ + 'examples': examples, + 'encoded_inputs': encoded_inputs + }, cache_file) example_indices = torch.arange(encoded_inputs.input_ids.size(0), dtype=torch.long) - dataset = DictDataset({'token_ids': encoded_inputs.input_ids, - 'attention_mask': encoded_inputs.attention_mask, - 'example_indices': example_indices}) + dataset = DictDataset({ + 'token_ids': encoded_inputs.input_ids, + 'attention_mask': encoded_inputs.attention_mask, + 'example_indices': example_indices + }) return dataset, encoded_inputs, examples diff --git a/federatedscope/nlp/dataset/utils.py b/federatedscope/nlp/dataset/utils.py index 94b88cf95..71820359c 100644 --- a/federatedscope/nlp/dataset/utils.py +++ b/federatedscope/nlp/dataset/utils.py @@ -12,7 +12,6 @@ from torch.utils.data.dataset import Dataset from transformers.models.bert import BertTokenizerFast - # ------------------------ # utils for shakespeare dataset @@ -110,8 +109,9 @@ def split_sent(examples, eoq='[unused2]', tokenize=True): class DictDataset(Dataset): def 
__init__(self, inputs):
         super().__init__()
-        assert all(list(inputs.values())[0].size(0) == v.size(0)
-                   for v in inputs.values()), "Size mismatch between tensors"
+        assert all(
+            list(inputs.values())[0].size(0) == v.size(0)
+            for v in inputs.values()), "Size mismatch between tensors"
         self.inputs = inputs
 
     def __getitem__(self, index):
diff --git a/federatedscope/nlp/loss/label_smooth_loss.py b/federatedscope/nlp/loss/label_smooth_loss.py
new file mode 100644
index 000000000..b4a7b9530
--- /dev/null
+++ b/federatedscope/nlp/loss/label_smooth_loss.py
@@ -0,0 +1,31 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LabelSmoothingLoss(nn.Module):
+    """
+    With label smoothing,
+    KL-divergence between q_{smoothed ground truth prob.}(w)
+    and p_{prob. computed by model}(w) is minimized.
+    """
+    def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):
+        assert 0.0 < label_smoothing <= 1.0
+        self.padding_idx = ignore_index
+        super(LabelSmoothingLoss, self).__init__()
+
+        smoothing_value = label_smoothing / (tgt_vocab_size - 2)
+        one_hot = torch.full((tgt_vocab_size, ), smoothing_value)
+        one_hot[self.padding_idx] = 0
+        self.register_buffer('one_hot', one_hot.unsqueeze(0))
+        self.confidence = 1.0 - label_smoothing
+
+    def forward(self, output, target):
+        """
+        output (FloatTensor): batch_size x n_classes
+        target (LongTensor): batch_size
+        """
+        model_prob = self.one_hot.repeat(target.size(0), 1)
+        model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
+        model_prob.masked_fill_((target == self.padding_idx).unsqueeze(1), 0)
+        return F.kl_div(output, model_prob, reduction='sum')
diff --git a/federatedscope/nlp/model/__init__.py b/federatedscope/nlp/model/__init__.py
index 941335213..c0b31382d 100644
--- a/federatedscope/nlp/model/__init__.py
+++ b/federatedscope/nlp/model/__init__.py
@@ -1,4 +1,8 @@
-from federatedscope.nlp.model.rnn import LSTM
-from federatedscope.nlp.model.model_builder import get_rnn, get_transformer
+from os.path import dirname, basename, isfile, join
+import glob
 
-__all__ = ['LSTM', 'get_rnn', 'get_transformer']
+modules = glob.glob(join(dirname(__file__), "*.py"))
+__all__ = [
+    basename(f)[:-3] for f in modules
+    if isfile(f) and not f.endswith('__init__.py')
+]
diff --git a/federatedscope/nlp/model/hfl_model.py b/federatedscope/nlp/model/hfl_model.py
new file mode 100644
index 000000000..1a168a320
--- /dev/null
+++ b/federatedscope/nlp/model/hfl_model.py
@@ -0,0 +1,561 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers.models.bert.modeling_bert import BertLMPredictionHead
+from transformers.models.encoder_decoder import EncoderDecoderModel
+from federatedscope.register import register_model
+from federatedscope.nlp.loss.label_smooth_loss import LabelSmoothingLoss
+
+
+class ModelOutput(object):
+    def __init__(self,
+                 loss=None,
+                 regular_loss=None,
+                 contrast_loss=None,
+                 logits=None,
+                 hidden_states=None,
+                 example_indices=None):
+        self.loss = loss
+        self.regular_loss = regular_loss
+        self.contrast_loss = contrast_loss
+        self.logits = logits
+        self.hidden_states = hidden_states
+        self.example_indices = example_indices
+
+
+class ContrastiveHead(nn.Module):
+    def __init__(self, input_dim, inner_dim, out_dim, dropout_prob):
+        super().__init__()
+
+        self.dense = nn.Linear(input_dim, inner_dim)
+        self.dropout = nn.Dropout(p=dropout_prob)
+        self.out_prj = nn.Linear(inner_dim, out_dim)
+
+    def forward(self, x):
+        x = self.dense(self.dropout(x))
+        x = 
torch.tanh(x) + x = self.out_prj(self.dropout(x)) + return x + + +class FedNLPModel(nn.Module): + def __init__(self, config): + super().__init__() + + self.model = EncoderDecoderModel.from_encoder_decoder_pretrained( + config.model_type, config.model_type) + self.lm_head = BertLMPredictionHead(self.model.encoder.config) + + self.task = config.task + self.pretrain_task = config.pretrain_task + self.pt_cfg = self.model.encoder.config + self.vocab_size = self.pt_cfg.vocab_size + self.hidden_size = self.pt_cfg.hidden_size + self.dropout_prob = self.pt_cfg.hidden_dropout_prob + self.dropout = nn.Dropout(self.dropout_prob) + + self.label_smoothing = config.label_smoothing + self.padding_idx = config.pad_token_id + self.classifier = nn.Linear(self.hidden_size, config.num_labels) \ + if config.num_labels is not None else None + + # for eval generation + self.model.config.decoder_start_token_id = config.bos_token_id + self.model.config.eos_token_id = config.eos_token_id + self.model.config.pad_token_id = config.pad_token_id + self.model.config.vocab_size = self.pt_cfg.vocab_size + self.model.config.max_length = config.max_length + self.model.config.min_length = config.min_length + self.model.config.no_repeat_ngram_size = config.no_repeat_ngram_size + self.model.config.length_penalty = config.length_penalty + self.model.config.num_beams = config.num_beams + + def generate(self, **kwargs): + return self.model.generate(**kwargs) + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + start_positions=None, + end_positions=None, + labels=None, + ): + enc_outputs = self.model.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + ) + + if self.task == 'pretrain': + if self.pretrain_task == 'mlm': + logits = self.lm_head(enc_outputs.last_hidden_state) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size), + labels.view(-1)) + loss = masked_lm_loss + + elif self.pretrain_task == 'denoise': + dec_outputs = self.model.decoder.bert( + input_ids=labels, + encoder_hidden_states=enc_outputs.last_hidden_state, + encoder_attention_mask=attention_mask, + ) + dec_hidden_states = dec_outputs.last_hidden_state + logits = self.model.decoder.cls(dec_hidden_states)[:, :-1, :] + loss_fct = CrossEntropyLoss(ignore_index=self.padding_idx) + loss = loss_fct(logits.contiguous().view(-1, self.vocab_size), + labels[:, 1:].contiguous().view(-1)) + + elif self.task in {'imdb', 'agnews'}: + pooled_output = self.dropout(enc_outputs.pooler_output) + logits = self.classifier(pooled_output) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) + + elif self.task in {'squad', 'newsqa'}: + logits = self.classifier(enc_outputs.last_hidden_state) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + logits = (start_logits, end_logits) + + # sometimes the start/end positions are outside our model + # inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + loss = (start_loss + end_loss) / 2 + + elif self.task in {'cnndm', 'msqg'}: + 
dec_outputs = self.model.decoder.bert( + input_ids=labels, + encoder_hidden_states=enc_outputs.last_hidden_state, + encoder_attention_mask=attention_mask, + ) + dec_hidden_states = dec_outputs.last_hidden_state + logits = self.model.decoder.cls(dec_hidden_states)[:, :-1, :] + + num_tokens = labels[:, 1:].ne(self.padding_idx).sum().item() + label_smoothing = self.label_smoothing if self.training else 0.0 + if label_smoothing > 0: + loss_fct = LabelSmoothingLoss( + label_smoothing, + self.vocab_size, + ignore_index=self.padding_idx, + ).to(logits.device) + loss = loss_fct( + F.log_softmax( + logits.contiguous().view(-1, self.vocab_size), dim=-1), + labels[:, 1:].contiguous().view(-1)) / num_tokens + else: + loss_fct = CrossEntropyLoss(ignore_index=self.padding_idx) + loss = loss_fct(logits.contiguous().view(-1, self.vocab_size), + labels[:, 1:].contiguous().view(-1)) + + return ModelOutput(loss=loss, logits=logits) + + +class PFedNLPModel(nn.Module): + def __init__(self, config): + super().__init__() + + self.model = EncoderDecoderModel.from_encoder_decoder_pretrained( + config.model_type, config.model_type) + self.lm_head = BertLMPredictionHead(self.model.encoder.config) + + self.task = config.task + self.pt_cfg = self.model.encoder.config + self.vocab_size = self.pt_cfg.vocab_size + self.hidden_size = self.pt_cfg.hidden_size + self.dropout_prob = self.pt_cfg.hidden_dropout_prob + self.dropout = nn.Dropout(self.dropout_prob) + + self.label_smoothing = config.label_smoothing + self.padding_idx = config.pad_token_id + self.classifier = nn.Linear(self.hidden_size, config.num_labels) \ + if config.num_labels is not None else None + + # for eval generation + self.model.config.decoder_start_token_id = config.bos_token_id + self.model.config.eos_token_id = config.eos_token_id + self.model.config.pad_token_id = config.pad_token_id + self.model.config.vocab_size = self.pt_cfg.vocab_size + self.model.config.max_length = config.max_length + self.model.config.min_length = config.min_length + self.model.config.no_repeat_ngram_size = config.no_repeat_ngram_size + self.model.config.length_penalty = config.length_penalty + self.model.config.num_beams = config.num_beams + + def generate(self, **kwargs): + return self.model.generate(**kwargs) + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + start_positions=None, + end_positions=None, + labels=None, + pretrain_task=None, + ): + enc_outputs = self.model.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + ) + + if self.task == 'pretrain': + if pretrain_task == 'mlm': + logits = self.lm_head(enc_outputs.last_hidden_state) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size), + labels.view(-1)) + loss = masked_lm_loss + + elif pretrain_task == 'denoise': + dec_outputs = self.model.decoder.bert( + input_ids=labels, + encoder_hidden_states=enc_outputs.last_hidden_state, + encoder_attention_mask=attention_mask, + ) + logits = self.model.decoder.cls( + dec_outputs.last_hidden_state)[:, :-1, :] + loss_fct = CrossEntropyLoss(ignore_index=self.padding_idx) + loss = loss_fct(logits.contiguous().view(-1, self.vocab_size), + labels[:, 1:].contiguous().view(-1)) + + elif self.task in {'imdb', 'agnews'}: + pooled_output = self.dropout(enc_outputs.pooler_output) + logits = self.classifier(pooled_output) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, logits.size(-1)), 
labels.view(-1)) + + elif self.task in {'squad', 'newsqa'}: + logits = self.classifier(enc_outputs.last_hidden_state) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + logits = (start_logits, end_logits) + + # sometimes the start/end positions are outside our model + # inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + loss = (start_loss + end_loss) / 2 + + elif self.task in {'cnndm', 'msqg'}: + dec_outputs = self.model.decoder.bert( + input_ids=labels, + encoder_hidden_states=enc_outputs.last_hidden_state, + encoder_attention_mask=attention_mask, + ) + dec_hidden_states = dec_outputs.last_hidden_state + logits = self.model.decoder.cls(dec_hidden_states)[:, :-1, :] + + num_tokens = labels[:, 1:].ne(self.padding_idx).sum().item() + label_smoothing = self.label_smoothing if self.training else 0.0 + if label_smoothing > 0: + loss_fct = LabelSmoothingLoss( + label_smoothing, + self.vocab_size, + ignore_index=self.padding_idx, + ).to(logits.device) + loss = loss_fct( + F.log_softmax( + logits.contiguous().view(-1, self.vocab_size), dim=-1), + labels[:, 1:].contiguous().view(-1)) / num_tokens + else: + loss_fct = CrossEntropyLoss(ignore_index=self.padding_idx) + loss = loss_fct(logits.contiguous().view(-1, self.vocab_size), + labels[:, 1:].contiguous().view(-1)) + + return ModelOutput(loss=loss, + logits=logits, + hidden_states=enc_outputs.last_hidden_state) + + +class PFedNLPContrastModel(nn.Module): + def __init__(self, config): + super().__init__() + + self.model = EncoderDecoderModel.from_encoder_decoder_pretrained( + config.model_type, config.model_type) + self.lm_head = BertLMPredictionHead(self.model.encoder.config) + + self.client_id = None + self.task = config.task + self.pt_cfg = self.model.encoder.config + self.vocab_size = self.pt_cfg.vocab_size + self.hidden_size = self.pt_cfg.hidden_size + self.dropout_prob = self.pt_cfg.hidden_dropout_prob + self.dropout = nn.Dropout(self.dropout_prob) + + self.label_smoothing = config.label_smoothing + self.padding_idx = config.pad_token_id + self.classifier = nn.Linear(self.hidden_size, config.num_labels) \ + if config.num_labels is not None else None + + self.contrast_topk = config.contrast_topk + self.contrast_temp = config.contrast_temp + self.train_contrast = config.train_contrast + self.contrast_head = ContrastiveHead(input_dim=self.hidden_size, + inner_dim=self.hidden_size, + out_dim=self.hidden_size, + dropout_prob=self.dropout_prob) + + # for eval generation + self.model.config.decoder_start_token_id = config.bos_token_id + self.model.config.eos_token_id = config.eos_token_id + self.model.config.pad_token_id = config.pad_token_id + self.model.config.vocab_size = self.pt_cfg.vocab_size + self.model.config.max_length = config.max_length + self.model.config.min_length = config.min_length + self.model.config.no_repeat_ngram_size = config.no_repeat_ngram_size + self.model.config.length_penalty = config.length_penalty + self.model.config.num_beams = config.num_beams + + def update_client_id(self, client_id): + self.client_id = client_id + + def generate(self, **kwargs): + return self.model.generate(**kwargs) + + def forward( + self, + 
input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + start_positions=None, + end_positions=None, + labels=None, + pretrain_task=None, + contrast_monitor=None, + in_contrast_prepare=None, + example_indices=None, + ): + if in_contrast_prepare: # return dec_hidden_states & dec_out + self.eval() + with torch.no_grad(): + example_indices = torch.stack([ + k for k in example_indices + if k.item() in contrast_monitor.synth_tokens + ]) + synth_input_ids = torch.stack([ + contrast_monitor.synth_tokens[k.item()] + for k in example_indices + ]).to(self.model.device) + + enc_hidden = torch.stack([ + contrast_monitor.enc_hidden[k.item()] + for k in example_indices + ]).to(self.model.device) + outputs = self.model.decoder.bert( + input_ids=synth_input_ids, + encoder_hidden_states=enc_hidden, + ) + logits = self.model.decoder.cls(outputs.last_hidden_state) + dec_hidden = self.contrast_head( + outputs.last_hidden_state).mean(1) + + return ModelOutput(logits=logits.argmax(-1), + hidden_states=dec_hidden, + example_indices=example_indices) + + enc_outputs = self.model.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + ) + + regular_loss, contrast_loss = None, None + if self.task == 'pretrain': + if pretrain_task == 'mlm': + logits = self.lm_head(enc_outputs.last_hidden_state) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size), + labels.view(-1)) + loss = masked_lm_loss + + elif pretrain_task == 'denoise': + dec_outputs = self.model.decoder.bert( + input_ids=labels, + encoder_hidden_states=enc_outputs.last_hidden_state, + encoder_attention_mask=attention_mask, + ) + logits = self.model.decoder.cls( + dec_outputs.last_hidden_state)[:, :-1, :] + loss_fct = CrossEntropyLoss(ignore_index=self.padding_idx) + denoise_loss = loss_fct( + logits.contiguous().view(-1, self.vocab_size), + labels[:, 1:].contiguous().view(-1)) + loss = denoise_loss + + else: + # regular loss + if self.task in {'imdb', 'agnews'}: + pooled_output = self.dropout(enc_outputs.pooler_output) + logits = self.classifier(pooled_output) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, logits.size(-1)), + labels.view(-1)) + + elif self.task in {'squad', 'newsqa'}: + logits = self.classifier(enc_outputs.last_hidden_state) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + logits = (start_logits, end_logits) + + # sometimes the start/end positions are outside our model + # inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + loss = (start_loss + end_loss) / 2 + + elif self.task in {'cnndm', 'msqg'}: + dec_outputs = self.model.decoder.bert( + input_ids=labels, + encoder_hidden_states=enc_outputs.last_hidden_state, + encoder_attention_mask=attention_mask, + ) + dec_hidden_states = dec_outputs.last_hidden_state + logits = self.model.decoder.cls(dec_hidden_states)[:, :-1, :] + + num_tokens = labels[:, 1:].ne(self.padding_idx).sum().item() + label_smoothing = self.label_smoothing if self.training \ + else 0.0 + if label_smoothing > 0: + loss_fct = LabelSmoothingLoss( + label_smoothing, + 
self.vocab_size, + ignore_index=self.padding_idx, + ).to(logits.device) + loss = loss_fct( + F.log_softmax(logits.contiguous().view( + -1, self.vocab_size), + dim=-1), + labels[:, 1:].contiguous().view(-1)) / num_tokens + else: + loss_fct = CrossEntropyLoss(ignore_index=self.padding_idx) + loss = loss_fct( + logits.contiguous().view(-1, self.vocab_size), + labels[:, 1:].contiguous().view(-1)) + regular_loss = loss.clone() + + # contrastive loss + if self.training and self.train_contrast: + example_indices = [ + k for k in example_indices + if k.item() in contrast_monitor.synth_tokens + ] + all_group_ids = contrast_monitor.all_group_ids[self.client_id] + topk_group_ids = \ + contrast_monitor.topk_group_ids[self.client_id] + if len(example_indices) > 0 and len(topk_group_ids) > 1: + example_indices = torch.stack(example_indices) + synth_input_ids = torch.stack([ + contrast_monitor.synth_tokens[k.item()] + for k in example_indices + ]).to(self.model.device) + + contrast_enc_hidden = torch.stack([ + contrast_monitor.enc_hidden[k.item()] + for k in example_indices + ]).to(self.model.device) + contrast_outputs = self.model.decoder.bert( + input_ids=synth_input_ids, + encoder_hidden_states=contrast_enc_hidden, + ) + cur_dec_hidden = self.contrast_head( + contrast_outputs.last_hidden_state).mean(1) + + pos_client_ids = [ + x for x in topk_group_ids[1:self.contrast_topk + 1] + ] + all_dec_hiddens = contrast_monitor.dec_hidden + sim_hiddens = [[ + all_dec_hiddens[cid][k.item()] for k in example_indices + ] for cid in pos_client_ids] + sim_hiddens = torch.stack([ + torch.stack(hid) for hid in sim_hiddens + ]).mean(0).to(self.model.device) + sim_matrix = F.cosine_similarity(cur_dec_hidden, + sim_hiddens, + dim=-1) + nominator = sim_matrix / self.contrast_temp + + neg_client_ids = [ + x for x in all_group_ids[::-1][:self.contrast_topk] + if x not in topk_group_ids + ] + if len(neg_client_ids) > 0: + dissim_hiddens = [[ + all_dec_hiddens[cid][k.item()] + for k in example_indices + ] for cid in neg_client_ids] + dissim_hiddens = torch.stack([ + torch.stack(hid) for hid in dissim_hiddens + ]).to(self.model.device) + dissim_matrix = F.cosine_similarity( + cur_dec_hidden.unsqueeze(0), + dissim_hiddens, + dim=-1) + denominator = torch.exp(dissim_matrix / + self.contrast_temp).sum(0) + contrast_loss = -torch.log( + torch.exp(nominator) / denominator).mean() + else: + contrast_loss = -nominator.mean() + loss += contrast_loss + + return ModelOutput(loss=loss, + regular_loss=regular_loss, + contrast_loss=contrast_loss, + logits=logits) + + +def call_fednlp_model(model_config, local_data): + if model_config.type == 'fednlp_model': + model = FedNLPModel(model_config) + return model + + +def call_pfednlp_model(model_config, local_data): + if model_config.type == 'pfednlp_model': + model = PFedNLPModel(model_config) + return model + + +def call_pfednlp_contrast_model(model_config, local_data): + if model_config.type == 'pfednlp_contrast_model': + model = PFedNLPContrastModel(model_config) + return model + + +register_model('fednlp_model', call_fednlp_model) +register_model('pfednlp_model', call_pfednlp_model) +register_model('pfednlp_contrast_model', call_pfednlp_contrast_model) diff --git a/federatedscope/nlp/trainer/utils.py b/federatedscope/nlp/trainer/utils.py index 0b610a8e6..5a5127304 100644 --- a/federatedscope/nlp/trainer/utils.py +++ b/federatedscope/nlp/trainer/utils.py @@ -1,4 +1,3 @@ - class AverageMeter(object): def __init__(self): self.reset() @@ -17,8 +16,14 @@ def update(self, val, n=1): class 
ContrastiveMonitor(object): - def __init__(self, stat=1, enc_hidden=None, synth_tokens=None, - dec_hidden=None, dec_out=None, all_group_ids=None, topk_group_ids=None): + def __init__(self, + stat=1, + enc_hidden=None, + synth_tokens=None, + dec_hidden=None, + dec_out=None, + all_group_ids=None, + topk_group_ids=None): self.stat = stat self.enc_hidden = enc_hidden self.synth_tokens = synth_tokens From f8e0a317908b16546a859997a016726ab45ab444 Mon Sep 17 00:00:00 2001 From: cheneydon Date: Tue, 25 Oct 2022 12:10:05 +0800 Subject: [PATCH 5/5] update model for hetero-fednlp --- .../core/auxiliaries/model_builder.py | 4 +- federatedscope/core/configs/cfg_aggregator.py | 12 +- federatedscope/core/configs/cfg_data.py | 12 +- federatedscope/core/configs/cfg_model.py | 26 +- federatedscope/nlp/dataloader/__init__.py | 1 + .../nlp/dataset/preprocess/get_hfl_data.py | 18 +- federatedscope/nlp/loss/label_smooth_loss.py | 31 + federatedscope/nlp/model/__init__.py | 11 +- federatedscope/nlp/model/hfl_model.py | 561 ++++++++++++++++++ 9 files changed, 645 insertions(+), 31 deletions(-) create mode 100644 federatedscope/nlp/loss/label_smooth_loss.py create mode 100644 federatedscope/nlp/model/hfl_model.py diff --git a/federatedscope/core/auxiliaries/model_builder.py b/federatedscope/core/auxiliaries/model_builder.py index 522159153..ae75c64fd 100644 --- a/federatedscope/core/auxiliaries/model_builder.py +++ b/federatedscope/core/auxiliaries/model_builder.py @@ -1,8 +1,6 @@ import logging - -import numpy as np - import federatedscope.register as register +from federatedscope.nlp.model import * logger = logging.getLogger(__name__) diff --git a/federatedscope/core/configs/cfg_aggregator.py b/federatedscope/core/configs/cfg_aggregator.py index e70df3e9f..3b3b9d4b7 100644 --- a/federatedscope/core/configs/cfg_aggregator.py +++ b/federatedscope/core/configs/cfg_aggregator.py @@ -4,12 +4,12 @@ def extend_aggregator_cfg(cfg): cfg.aggregator = CN() - cfg.aggregator.num_agg_groups = None - cfg.aggregator.num_agg_topk = None - cfg.aggregator.inside_weight = None - cfg.aggregator.outside_weight = None - cfg.aggregator.proto_weight = None - cfg.aggregator.synth_ratio = None + cfg.aggregator.num_agg_groups = 1 + cfg.aggregator.num_agg_topk = 100 + cfg.aggregator.inside_weight = 1.0 + cfg.aggregator.outside_weight = 0.0 + cfg.aggregator.proto_weight = 0.0 + cfg.aggregator.synth_ratio = 0.5 # --------------- register corresponding check function ---------- cfg.register_cfg_check_fun(assert_aggregator_cfg) diff --git a/federatedscope/core/configs/cfg_data.py b/federatedscope/core/configs/cfg_data.py index f09e13b95..69de4f334 100644 --- a/federatedscope/core/configs/cfg_data.py +++ b/federatedscope/core/configs/cfg_data.py @@ -56,12 +56,12 @@ def extend_data_cfg(cfg): cfg.data.quadratic.max_curv = 12.5 # fednlp - cfg.data.datasets = [] - cfg.data.num_grouped_clients = [] - cfg.data.max_seq_len = 0 - cfg.data.max_tgt_len = 0 - cfg.data.max_query_len = 0 - cfg.data.trunc_stride = 0 + cfg.data.datasets = ['imdb', 'agnews', 'squad', 'newsqa', 'cnndm', 'msqg'] + cfg.data.num_grouped_clients = [1, 3, 3, 2, 5, 4] + cfg.data.max_seq_len = 384 + cfg.data.max_tgt_len = 128 + cfg.data.max_query_len = 128 + cfg.data.trunc_stride = 128 cfg.data.cache_dir = '' cfg.data.num_contrast = 0 cfg.data.debug = False diff --git a/federatedscope/core/configs/cfg_model.py b/federatedscope/core/configs/cfg_model.py index 56fe98412..eb315f98e 100644 --- a/federatedscope/core/configs/cfg_model.py +++ b/federatedscope/core/configs/cfg_model.py @@ 
-25,15 +25,31 @@ def extend_model_cfg(cfg): cfg.model.input_shape = () # A tuple, e.g., (in_channel, h, w) # fednlp - cfg.model.model_type = '' - cfg.model.bos_token = '' - cfg.model.eos_token = '' - cfg.model.eoq_token = '' - cfg.model.pad_token = '' + cfg.model.model_type = 'google/bert_uncased_L-2_H-128_A-2' + cfg.model.bos_token = '[unused0]' + cfg.model.eos_token = '[unused1]' + cfg.model.eoq_token = '[unused2]' cfg.model.bos_token_id = -1 cfg.model.eos_token_id = -1 cfg.model.eoq_token_id = -1 cfg.model.pad_token_id = -1 + cfg.model.task = '' + cfg.model.pretrain_task = '' + cfg.model.pretrain_tasks = [] + cfg.model.downstream_tasks = [] + cfg.model.num_labels = 1 + cfg.model.max_length = 200 + cfg.model.min_length = 1 + cfg.model.no_repeat_ngram_size = 3 + cfg.model.length_penalty = 2.0 + cfg.model.num_beams = 5 + cfg.model.label_smoothing = 0.1 + cfg.model.n_best_size = 20 + cfg.model.max_answer_len = 30 + cfg.model.null_score_diff_threshold = 0.0 + cfg.model.train_contrast = False + cfg.model.contrast_topk = 100 + cfg.model.contrast_temp = 1.0 # ---------------------------------------------------------------------- # # Criterion related options diff --git a/federatedscope/nlp/dataloader/__init__.py b/federatedscope/nlp/dataloader/__init__.py index c0b31382d..05cdc7f3f 100644 --- a/federatedscope/nlp/dataloader/__init__.py +++ b/federatedscope/nlp/dataloader/__init__.py @@ -6,3 +6,4 @@ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py') ] +__all__ += ['load_nlp_dataset'] diff --git a/federatedscope/nlp/dataset/preprocess/get_hfl_data.py b/federatedscope/nlp/dataset/preprocess/get_hfl_data.py index 71e5f1b8f..73bfdf402 100644 --- a/federatedscope/nlp/dataset/preprocess/get_hfl_data.py +++ b/federatedscope/nlp/dataset/preprocess/get_hfl_data.py @@ -19,11 +19,11 @@ def __init__(self, config, train_frac=0.9): self.total_client_num = config.federate.client_num self.num_grouped_clients = config.data.num_grouped_clients self.train_frac = train_frac + self.all_train_data = [] + self.all_val_data = [] + self.all_test_data = [] def get_data(self): - all_train_data = [] - all_val_data = [] - all_test_data = [] for i, dataset in enumerate(self.datasets): if dataset not in HFL_NAMES: raise ValueError(f'No HFL dataset named {dataset}') @@ -39,10 +39,11 @@ def get_data(self): ] test_data = self._load_data(dataset, 'test', self.num_grouped_clients[i]) - all_train_data.extend(train_data) - all_val_data.extend(val_data) - all_test_data.extend(test_data) - return all_train_data, all_val_data, all_test_data + self.all_train_data.extend(train_data) + self.all_val_data.extend(val_data) + self.all_test_data.extend(test_data) + + return self.all_train_data, self.all_val_data, self.all_test_data def _load_data(self, dataset, split, num_clients): data_dir = os.path.join(self.data_dir, dataset) @@ -127,7 +128,8 @@ def _load_data(self, dataset, split, num_clients): cur_data = data[data_idx:data_idx + num_split] data_idx += num_split all_split_data.append(cur_data) - logger.info(f'Client id: {i + 1}, num samples: {num_split}') + logger.info(f'Client id: {len(self.all_train_data) + i + 1}, ' + f'num samples: {num_split}') return all_split_data def _download(self, dataset): diff --git a/federatedscope/nlp/loss/label_smooth_loss.py b/federatedscope/nlp/loss/label_smooth_loss.py new file mode 100644 index 000000000..b4a7b9530 --- /dev/null +++ b/federatedscope/nlp/loss/label_smooth_loss.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class 
LabelSmoothingLoss(nn.Module): + """ + With label smoothing, + KL-divergence between q_{smoothed ground truth prob.}(w) + and p_{prob. computed by model}(w) is minimized. + """ + def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100): + assert 0.0 < label_smoothing <= 1.0 + self.padding_idx = ignore_index + super(LabelSmoothingLoss, self).__init__() + + smoothing_value = label_smoothing / (tgt_vocab_size - 2) + one_hot = torch.full((tgt_vocab_size, ), smoothing_value) + one_hot[self.padding_idx] = 0 + self.register_buffer('one_hot', one_hot.unsqueeze(0)) + self.confidence = 1.0 - label_smoothing + + def forward(self, output, target): + """ + output (FloatTensor): batch_size x n_classes + target (LongTensor): batch_size + """ + model_prob = self.one_hot.repeat(target.size(0), 1) + model_prob.scatter_(1, target.unsqueeze(1), self.confidence) + model_prob.masked_fill_((target == self.padding_idx).unsqueeze(1), 0) + return F.kl_div(output, model_prob, reduction='sum') diff --git a/federatedscope/nlp/model/__init__.py b/federatedscope/nlp/model/__init__.py index 941335213..47dc90a85 100644 --- a/federatedscope/nlp/model/__init__.py +++ b/federatedscope/nlp/model/__init__.py @@ -1,4 +1,9 @@ -from federatedscope.nlp.model.rnn import LSTM -from federatedscope.nlp.model.model_builder import get_rnn, get_transformer +from os.path import dirname, basename, isfile, join +import glob -__all__ = ['LSTM', 'get_rnn', 'get_transformer'] +modules = glob.glob(join(dirname(__file__), "*.py")) +__all__ = [ + basename(f)[:-3] for f in modules + if isfile(f) and not f.endswith('__init__.py') +] +__all__ += ['LSTM', 'get_rnn', 'get_transformer'] diff --git a/federatedscope/nlp/model/hfl_model.py b/federatedscope/nlp/model/hfl_model.py new file mode 100644 index 000000000..1a168a320 --- /dev/null +++ b/federatedscope/nlp/model/hfl_model.py @@ -0,0 +1,561 @@ +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.models.bert.modeling_bert import BertLMPredictionHead +from transformers.models.encoder_decoder import EncoderDecoderModel +from federatedscope.register import register_model +from federatedscope.nlp.loss.label_smooth_loss import LabelSmoothingLoss + + +class ModelOutput(object): + def __init__(self, + loss=None, + regular_loss=None, + contrast_loss=None, + logits=None, + hidden_states=None, + example_indices=None): + self.loss = loss + self.regular_loss = regular_loss + self.contrast_loss = contrast_loss + self.logits = logits + self.hidden_states = hidden_states + self.example_indices = example_indices + + +class ContrastiveHead(nn.Module): + def __init__(self, input_dim, inner_dim, out_dim, dropout_prob): + super().__init__() + + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=dropout_prob) + self.out_prj = nn.Linear(inner_dim, out_dim) + + def forward(self, x): + x = self.dense(self.dropout(x)) + x = torch.tanh(x) + x = self.out_prj(self.dropout(x)) + return x + + +class FedNLPModel(nn.Module): + def __init__(self, config): + super().__init__() + + self.model = EncoderDecoderModel.from_encoder_decoder_pretrained( + config.model_type, config.model_type) + self.lm_head = BertLMPredictionHead(self.model.encoder.config) + + self.task = config.task + self.pretrain_task = config.pretrain_task + self.pt_cfg = self.model.encoder.config + self.vocab_size = self.pt_cfg.vocab_size + self.hidden_size = self.pt_cfg.hidden_size + self.dropout_prob = self.pt_cfg.hidden_dropout_prob + self.dropout = 
nn.Dropout(self.dropout_prob) + + self.label_smoothing = config.label_smoothing + self.padding_idx = config.pad_token_id + self.classifier = nn.Linear(self.hidden_size, config.num_labels) \ + if config.num_labels is not None else None + + # for eval generation + self.model.config.decoder_start_token_id = config.bos_token_id + self.model.config.eos_token_id = config.eos_token_id + self.model.config.pad_token_id = config.pad_token_id + self.model.config.vocab_size = self.pt_cfg.vocab_size + self.model.config.max_length = config.max_length + self.model.config.min_length = config.min_length + self.model.config.no_repeat_ngram_size = config.no_repeat_ngram_size + self.model.config.length_penalty = config.length_penalty + self.model.config.num_beams = config.num_beams + + def generate(self, **kwargs): + return self.model.generate(**kwargs) + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + start_positions=None, + end_positions=None, + labels=None, + ): + enc_outputs = self.model.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + ) + + if self.task == 'pretrain': + if self.pretrain_task == 'mlm': + logits = self.lm_head(enc_outputs.last_hidden_state) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size), + labels.view(-1)) + loss = masked_lm_loss + + elif self.pretrain_task == 'denoise': + dec_outputs = self.model.decoder.bert( + input_ids=labels, + encoder_hidden_states=enc_outputs.last_hidden_state, + encoder_attention_mask=attention_mask, + ) + dec_hidden_states = dec_outputs.last_hidden_state + logits = self.model.decoder.cls(dec_hidden_states)[:, :-1, :] + loss_fct = CrossEntropyLoss(ignore_index=self.padding_idx) + loss = loss_fct(logits.contiguous().view(-1, self.vocab_size), + labels[:, 1:].contiguous().view(-1)) + + elif self.task in {'imdb', 'agnews'}: + pooled_output = self.dropout(enc_outputs.pooler_output) + logits = self.classifier(pooled_output) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) + + elif self.task in {'squad', 'newsqa'}: + logits = self.classifier(enc_outputs.last_hidden_state) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + logits = (start_logits, end_logits) + + # sometimes the start/end positions are outside our model + # inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + loss = (start_loss + end_loss) / 2 + + elif self.task in {'cnndm', 'msqg'}: + dec_outputs = self.model.decoder.bert( + input_ids=labels, + encoder_hidden_states=enc_outputs.last_hidden_state, + encoder_attention_mask=attention_mask, + ) + dec_hidden_states = dec_outputs.last_hidden_state + logits = self.model.decoder.cls(dec_hidden_states)[:, :-1, :] + + num_tokens = labels[:, 1:].ne(self.padding_idx).sum().item() + label_smoothing = self.label_smoothing if self.training else 0.0 + if label_smoothing > 0: + loss_fct = LabelSmoothingLoss( + label_smoothing, + self.vocab_size, + ignore_index=self.padding_idx, + ).to(logits.device) + loss = loss_fct( + F.log_softmax( + 
logits.contiguous().view(-1, self.vocab_size), dim=-1), + labels[:, 1:].contiguous().view(-1)) / num_tokens + else: + loss_fct = CrossEntropyLoss(ignore_index=self.padding_idx) + loss = loss_fct(logits.contiguous().view(-1, self.vocab_size), + labels[:, 1:].contiguous().view(-1)) + + return ModelOutput(loss=loss, logits=logits) + + +class PFedNLPModel(nn.Module): + def __init__(self, config): + super().__init__() + + self.model = EncoderDecoderModel.from_encoder_decoder_pretrained( + config.model_type, config.model_type) + self.lm_head = BertLMPredictionHead(self.model.encoder.config) + + self.task = config.task + self.pt_cfg = self.model.encoder.config + self.vocab_size = self.pt_cfg.vocab_size + self.hidden_size = self.pt_cfg.hidden_size + self.dropout_prob = self.pt_cfg.hidden_dropout_prob + self.dropout = nn.Dropout(self.dropout_prob) + + self.label_smoothing = config.label_smoothing + self.padding_idx = config.pad_token_id + self.classifier = nn.Linear(self.hidden_size, config.num_labels) \ + if config.num_labels is not None else None + + # for eval generation + self.model.config.decoder_start_token_id = config.bos_token_id + self.model.config.eos_token_id = config.eos_token_id + self.model.config.pad_token_id = config.pad_token_id + self.model.config.vocab_size = self.pt_cfg.vocab_size + self.model.config.max_length = config.max_length + self.model.config.min_length = config.min_length + self.model.config.no_repeat_ngram_size = config.no_repeat_ngram_size + self.model.config.length_penalty = config.length_penalty + self.model.config.num_beams = config.num_beams + + def generate(self, **kwargs): + return self.model.generate(**kwargs) + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + start_positions=None, + end_positions=None, + labels=None, + pretrain_task=None, + ): + enc_outputs = self.model.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + ) + + if self.task == 'pretrain': + if pretrain_task == 'mlm': + logits = self.lm_head(enc_outputs.last_hidden_state) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size), + labels.view(-1)) + loss = masked_lm_loss + + elif pretrain_task == 'denoise': + dec_outputs = self.model.decoder.bert( + input_ids=labels, + encoder_hidden_states=enc_outputs.last_hidden_state, + encoder_attention_mask=attention_mask, + ) + logits = self.model.decoder.cls( + dec_outputs.last_hidden_state)[:, :-1, :] + loss_fct = CrossEntropyLoss(ignore_index=self.padding_idx) + loss = loss_fct(logits.contiguous().view(-1, self.vocab_size), + labels[:, 1:].contiguous().view(-1)) + + elif self.task in {'imdb', 'agnews'}: + pooled_output = self.dropout(enc_outputs.pooler_output) + logits = self.classifier(pooled_output) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) + + elif self.task in {'squad', 'newsqa'}: + logits = self.classifier(enc_outputs.last_hidden_state) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + logits = (start_logits, end_logits) + + # sometimes the start/end positions are outside our model + # inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = 
CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + loss = (start_loss + end_loss) / 2 + + elif self.task in {'cnndm', 'msqg'}: + dec_outputs = self.model.decoder.bert( + input_ids=labels, + encoder_hidden_states=enc_outputs.last_hidden_state, + encoder_attention_mask=attention_mask, + ) + dec_hidden_states = dec_outputs.last_hidden_state + logits = self.model.decoder.cls(dec_hidden_states)[:, :-1, :] + + num_tokens = labels[:, 1:].ne(self.padding_idx).sum().item() + label_smoothing = self.label_smoothing if self.training else 0.0 + if label_smoothing > 0: + loss_fct = LabelSmoothingLoss( + label_smoothing, + self.vocab_size, + ignore_index=self.padding_idx, + ).to(logits.device) + loss = loss_fct( + F.log_softmax( + logits.contiguous().view(-1, self.vocab_size), dim=-1), + labels[:, 1:].contiguous().view(-1)) / num_tokens + else: + loss_fct = CrossEntropyLoss(ignore_index=self.padding_idx) + loss = loss_fct(logits.contiguous().view(-1, self.vocab_size), + labels[:, 1:].contiguous().view(-1)) + + return ModelOutput(loss=loss, + logits=logits, + hidden_states=enc_outputs.last_hidden_state) + + +class PFedNLPContrastModel(nn.Module): + def __init__(self, config): + super().__init__() + + self.model = EncoderDecoderModel.from_encoder_decoder_pretrained( + config.model_type, config.model_type) + self.lm_head = BertLMPredictionHead(self.model.encoder.config) + + self.client_id = None + self.task = config.task + self.pt_cfg = self.model.encoder.config + self.vocab_size = self.pt_cfg.vocab_size + self.hidden_size = self.pt_cfg.hidden_size + self.dropout_prob = self.pt_cfg.hidden_dropout_prob + self.dropout = nn.Dropout(self.dropout_prob) + + self.label_smoothing = config.label_smoothing + self.padding_idx = config.pad_token_id + self.classifier = nn.Linear(self.hidden_size, config.num_labels) \ + if config.num_labels is not None else None + + self.contrast_topk = config.contrast_topk + self.contrast_temp = config.contrast_temp + self.train_contrast = config.train_contrast + self.contrast_head = ContrastiveHead(input_dim=self.hidden_size, + inner_dim=self.hidden_size, + out_dim=self.hidden_size, + dropout_prob=self.dropout_prob) + + # for eval generation + self.model.config.decoder_start_token_id = config.bos_token_id + self.model.config.eos_token_id = config.eos_token_id + self.model.config.pad_token_id = config.pad_token_id + self.model.config.vocab_size = self.pt_cfg.vocab_size + self.model.config.max_length = config.max_length + self.model.config.min_length = config.min_length + self.model.config.no_repeat_ngram_size = config.no_repeat_ngram_size + self.model.config.length_penalty = config.length_penalty + self.model.config.num_beams = config.num_beams + + def update_client_id(self, client_id): + self.client_id = client_id + + def generate(self, **kwargs): + return self.model.generate(**kwargs) + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + start_positions=None, + end_positions=None, + labels=None, + pretrain_task=None, + contrast_monitor=None, + in_contrast_prepare=None, + example_indices=None, + ): + if in_contrast_prepare: # return dec_hidden_states & dec_out + self.eval() + with torch.no_grad(): + example_indices = torch.stack([ + k for k in example_indices + if k.item() in contrast_monitor.synth_tokens + ]) + synth_input_ids = torch.stack([ + contrast_monitor.synth_tokens[k.item()] + for k in example_indices + 
]).to(self.model.device) + + enc_hidden = torch.stack([ + contrast_monitor.enc_hidden[k.item()] + for k in example_indices + ]).to(self.model.device) + outputs = self.model.decoder.bert( + input_ids=synth_input_ids, + encoder_hidden_states=enc_hidden, + ) + logits = self.model.decoder.cls(outputs.last_hidden_state) + dec_hidden = self.contrast_head( + outputs.last_hidden_state).mean(1) + + return ModelOutput(logits=logits.argmax(-1), + hidden_states=dec_hidden, + example_indices=example_indices) + + enc_outputs = self.model.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + ) + + regular_loss, contrast_loss = None, None + if self.task == 'pretrain': + if pretrain_task == 'mlm': + logits = self.lm_head(enc_outputs.last_hidden_state) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(logits.view(-1, self.vocab_size), + labels.view(-1)) + loss = masked_lm_loss + + elif pretrain_task == 'denoise': + dec_outputs = self.model.decoder.bert( + input_ids=labels, + encoder_hidden_states=enc_outputs.last_hidden_state, + encoder_attention_mask=attention_mask, + ) + logits = self.model.decoder.cls( + dec_outputs.last_hidden_state)[:, :-1, :] + loss_fct = CrossEntropyLoss(ignore_index=self.padding_idx) + denoise_loss = loss_fct( + logits.contiguous().view(-1, self.vocab_size), + labels[:, 1:].contiguous().view(-1)) + loss = denoise_loss + + else: + # regular loss + if self.task in {'imdb', 'agnews'}: + pooled_output = self.dropout(enc_outputs.pooler_output) + logits = self.classifier(pooled_output) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, logits.size(-1)), + labels.view(-1)) + + elif self.task in {'squad', 'newsqa'}: + logits = self.classifier(enc_outputs.last_hidden_state) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + logits = (start_logits, end_logits) + + # sometimes the start/end positions are outside our model + # inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + loss = (start_loss + end_loss) / 2 + + elif self.task in {'cnndm', 'msqg'}: + dec_outputs = self.model.decoder.bert( + input_ids=labels, + encoder_hidden_states=enc_outputs.last_hidden_state, + encoder_attention_mask=attention_mask, + ) + dec_hidden_states = dec_outputs.last_hidden_state + logits = self.model.decoder.cls(dec_hidden_states)[:, :-1, :] + + num_tokens = labels[:, 1:].ne(self.padding_idx).sum().item() + label_smoothing = self.label_smoothing if self.training \ + else 0.0 + if label_smoothing > 0: + loss_fct = LabelSmoothingLoss( + label_smoothing, + self.vocab_size, + ignore_index=self.padding_idx, + ).to(logits.device) + loss = loss_fct( + F.log_softmax(logits.contiguous().view( + -1, self.vocab_size), + dim=-1), + labels[:, 1:].contiguous().view(-1)) / num_tokens + else: + loss_fct = CrossEntropyLoss(ignore_index=self.padding_idx) + loss = loss_fct( + logits.contiguous().view(-1, self.vocab_size), + labels[:, 1:].contiguous().view(-1)) + regular_loss = loss.clone() + + # contrastive loss + if self.training and self.train_contrast: + example_indices = [ + k for k in example_indices + if k.item() in 
contrast_monitor.synth_tokens + ] + all_group_ids = contrast_monitor.all_group_ids[self.client_id] + topk_group_ids = \ + contrast_monitor.topk_group_ids[self.client_id] + if len(example_indices) > 0 and len(topk_group_ids) > 1: + example_indices = torch.stack(example_indices) + synth_input_ids = torch.stack([ + contrast_monitor.synth_tokens[k.item()] + for k in example_indices + ]).to(self.model.device) + + contrast_enc_hidden = torch.stack([ + contrast_monitor.enc_hidden[k.item()] + for k in example_indices + ]).to(self.model.device) + contrast_outputs = self.model.decoder.bert( + input_ids=synth_input_ids, + encoder_hidden_states=contrast_enc_hidden, + ) + cur_dec_hidden = self.contrast_head( + contrast_outputs.last_hidden_state).mean(1) + + pos_client_ids = [ + x for x in topk_group_ids[1:self.contrast_topk + 1] + ] + all_dec_hiddens = contrast_monitor.dec_hidden + sim_hiddens = [[ + all_dec_hiddens[cid][k.item()] for k in example_indices + ] for cid in pos_client_ids] + sim_hiddens = torch.stack([ + torch.stack(hid) for hid in sim_hiddens + ]).mean(0).to(self.model.device) + sim_matrix = F.cosine_similarity(cur_dec_hidden, + sim_hiddens, + dim=-1) + nominator = sim_matrix / self.contrast_temp + + neg_client_ids = [ + x for x in all_group_ids[::-1][:self.contrast_topk] + if x not in topk_group_ids + ] + if len(neg_client_ids) > 0: + dissim_hiddens = [[ + all_dec_hiddens[cid][k.item()] + for k in example_indices + ] for cid in neg_client_ids] + dissim_hiddens = torch.stack([ + torch.stack(hid) for hid in dissim_hiddens + ]).to(self.model.device) + dissim_matrix = F.cosine_similarity( + cur_dec_hidden.unsqueeze(0), + dissim_hiddens, + dim=-1) + denominator = torch.exp(dissim_matrix / + self.contrast_temp).sum(0) + contrast_loss = -torch.log( + torch.exp(nominator) / denominator).mean() + else: + contrast_loss = -nominator.mean() + loss += contrast_loss + + return ModelOutput(loss=loss, + regular_loss=regular_loss, + contrast_loss=contrast_loss, + logits=logits) + + +def call_fednlp_model(model_config, local_data): + if model_config.type == 'fednlp_model': + model = FedNLPModel(model_config) + return model + + +def call_pfednlp_model(model_config, local_data): + if model_config.type == 'pfednlp_model': + model = PFedNLPModel(model_config) + return model + + +def call_pfednlp_contrast_model(model_config, local_data): + if model_config.type == 'pfednlp_contrast_model': + model = PFedNLPContrastModel(model_config) + return model + + +register_model('fednlp_model', call_fednlp_model) +register_model('pfednlp_model', call_pfednlp_model) +register_model('pfednlp_contrast_model', call_pfednlp_contrast_model)
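
For reference, the snippet below is a minimal sketch of how the new `fednlp_model` pieces introduced in this series fit together when exercised outside the FederatedScope launcher. The `SimpleNamespace` config is only a stand-in for the `cfg.model` node defined in `cfg_model.py` (same field names); the token ids, `num_labels=2`, and the random batch are illustrative assumptions, not values taken from this patch.

    # Sketch only: drives FedNLPModel directly with a hand-built config object.
    from types import SimpleNamespace
    import torch
    from federatedscope.nlp.model.hfl_model import FedNLPModel

    model_cfg = SimpleNamespace(
        model_type='google/bert_uncased_L-2_H-128_A-2',  # default in cfg_model.py
        task='imdb',             # selects the sequence-classification branch of forward()
        pretrain_task='',
        num_labels=2,            # illustrative: binary sentiment labels
        label_smoothing=0.1,
        pad_token_id=0, bos_token_id=1, eos_token_id=2,  # illustrative ids
        max_length=200, min_length=1, no_repeat_ngram_size=3,
        length_penalty=2.0, num_beams=5,
    )
    model = FedNLPModel(model_cfg)

    batch = {
        'input_ids': torch.randint(0, model.vocab_size, (2, 16)),
        'attention_mask': torch.ones(2, 16, dtype=torch.long),
        'labels': torch.tensor([0, 1]),  # class indices for the 'imdb' branch
    }
    out = model(**batch)
    print(out.loss, out.logits.shape)    # scalar cross-entropy loss, logits of shape (2, 2)

Within the framework itself the same model is reached through `register_model('fednlp_model', call_fednlp_model)` once `cfg.model.type` is set to `fednlp_model`. For the generation tasks (`cnndm`, `msqg`) the entry point is unchanged, but `labels` carries the target token ids (shifted internally for teacher forcing) and the loss switches to `LabelSmoothingLoss` whenever `cfg.model.label_smoothing > 0`.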