diff --git a/data/coco.py b/data/coco.py index 765531761..a1065578a 100644 --- a/data/coco.py +++ b/data/coco.py @@ -84,7 +84,11 @@ class COCODetection(data.Dataset): """ def __init__(self, root, image_set='trainval35k', transform=None, - target_transform=COCOAnnotationTransform(), dataset_name='MS COCO'): + target_transform=None, dataset_name='MS COCO'): + + if target_transform is None: + target_transform = COCOAnnotationTransform() + sys.path.append(osp.join(root, COCO_API)) from pycocotools.coco import COCO self.root = osp.join(root, IMAGES, image_set) diff --git a/data/voc0712.py b/data/voc0712.py index a3e80d037..e8718a92b 100644 --- a/data/voc0712.py +++ b/data/voc0712.py @@ -96,8 +96,12 @@ class VOCDetection(data.Dataset): def __init__(self, root, image_sets=[('2007', 'trainval'), ('2012', 'trainval')], - transform=None, target_transform=VOCAnnotationTransform(), + transform=None, target_transform=None, dataset_name='VOC0712'): + + if target_transform is None: + target_transform = VOCAnnotationTransform() + self.root = root self.image_set = image_sets self.transform = transform