-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathrun_anonymization.py
36 lines (29 loc) · 1.17 KB
/
run_anonymization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import logging
from pathlib import Path
from argparse import ArgumentParser
import torch
from anonymization.pipelines.sttts_pipeline import STTTSPipeline
from utils import parse_yaml, get_datasets
PIPELINES = {
'sttts': STTTSPipeline
}
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--config', default='anon_config.yaml')
parser.add_argument('--gpu_ids', default='0')
parser.add_argument('--force_compute', default=False, type=bool)
args = parser.parse_args()
config = parse_yaml(Path('configs', args.config))
datasets = get_datasets(config)
gpus = args.gpu_ids.split(',')
devices = []
if torch.cuda.is_available():
for gpu in gpus:
devices.append(torch.device(f'cuda:{gpu}'))
else:
devices.append(torch.device('cpu'))
with torch.no_grad():
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s- %(levelname)s - %(message)s')
logging.info(f'Running pipeline: {config["pipeline"]}')
pipeline = PIPELINES[config['pipeline']](config=config, force_compute=args.force_compute, devices=devices)
pipeline.run_anonymization_pipeline(datasets)