From 9fc2ecd553b16b673be9cbfac4c21a076cf16ef5 Mon Sep 17 00:00:00 2001 From: Potter Hsu Date: Tue, 11 Sep 2018 04:37:11 +0800 Subject: [PATCH] Migrate to PyTorch 0.4.1 --- README.md | 16 +++--- bbox.py | 23 ++++---- dataset.py | 16 +++--- eval.py | 2 +- evaluator.py | 33 ++++++------ infer.py | 2 +- model.py | 97 ++++++++++++++-------------------- nms/nms.py | 5 +- nms/src/nms.c | 6 +-- rpn/region_proposal_network.py | 41 +++++++------- train.py | 8 +-- 11 files changed, 119 insertions(+), 130 deletions(-) diff --git a/README.md b/README.md index 6e06432f..0dbb9e94 100644 --- a/README.md +++ b/README.md @@ -26,11 +26,11 @@ An easy implementation of Faster R-CNN in PyTorch. * **25 minutes** every 10000 steps - * **3 hours** for 70000 steps (which leads to mAP=70.29%) + * **3 hours** for 70000 steps (which leads to mAP=xx.xx%) * Inference - * **~9 examples** per second + * **~13 examples** per second ### Trained Model @@ -39,8 +39,8 @@ An easy implementation of Faster R-CNN in PyTorch. ## Requirements * Python 3.6 -* torch 0.3.1 -* torchvision 0.2.0 +* torch 0.4.1 +* torchvision 0.2.1 * tqdm ``` @@ -50,8 +50,8 @@ An easy implementation of Faster R-CNN in PyTorch. ## Setup 1. Download VOC 2007 Dataset - - [Training / Validation](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar) - - [Test](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar) + - [Training / Validation](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar) (5011 images) + - [Test](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar) (4952 images) 1. Extract to data folder, now your folder structure should be like: ``` @@ -83,7 +83,9 @@ An easy implementation of Faster R-CNN in PyTorch. $ python test_nms.py ``` > sm_61 is for GTX-1080-Ti, to see others, visit [here](http://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/) - + + > Try to rebuild module if unit test fails + * result after running `test_nms.py` ![](https://github.com/potterhsu/easy-faster-rcnn.pytorch/blob/master/images/test_nms.png?raw=true) diff --git a/bbox.py b/bbox.py index 1ba5a882..3bae136e 100644 --- a/bbox.py +++ b/bbox.py @@ -1,10 +1,11 @@ -import torch import numpy as np +import torch +from torch import Tensor class BBox(object): - def __init__(self, left: float, top: float, right: float, bottom: float) -> None: + def __init__(self, left: float, top: float, right: float, bottom: float): super().__init__() self.left = left self.top = top @@ -19,7 +20,7 @@ def tolist(self): return [self.left, self.top, self.right, self.bottom] @staticmethod - def to_center_base(bboxes): + def to_center_base(bboxes: Tensor): return torch.stack([ (bboxes[:, 0] + bboxes[:, 2]) / 2, (bboxes[:, 1] + bboxes[:, 3]) / 2, @@ -28,7 +29,7 @@ def to_center_base(bboxes): ], dim=1) @staticmethod - def from_center_base(center_based_bboxes): + def from_center_base(center_based_bboxes: Tensor) -> Tensor: return torch.stack([ center_based_bboxes[:, 0] - center_based_bboxes[:, 2] / 2, center_based_bboxes[:, 1] - center_based_bboxes[:, 3] / 2, @@ -37,7 +38,7 @@ def from_center_base(center_based_bboxes): ], dim=1) @staticmethod - def calc_transformer(src_bboxes, dst_bboxes): + def calc_transformer(src_bboxes: Tensor, dst_bboxes: Tensor) -> Tensor: center_based_src_bboxes = BBox.to_center_base(src_bboxes) center_based_dst_bboxes = BBox.to_center_base(dst_bboxes) transformers = torch.stack([ @@ -49,7 +50,7 @@ def calc_transformer(src_bboxes, dst_bboxes): return transformers 
@staticmethod - def apply_transformer(src_bboxes, transformers): + def apply_transformer(src_bboxes: Tensor, transformers: Tensor) -> Tensor: center_based_src_bboxes = BBox.to_center_base(src_bboxes) center_based_dst_bboxes = torch.stack([ transformers[:, 0] * center_based_src_bboxes[:, 2] + center_based_src_bboxes[:, 0], @@ -61,7 +62,7 @@ def apply_transformer(src_bboxes, transformers): return dst_bboxes @staticmethod - def iou(source, other): + def iou(source: Tensor, other: Tensor) -> Tensor: source = source.repeat(other.shape[0], 1, 1).permute(1, 0, 2) other = other.repeat(source.shape[0], 1, 1) @@ -79,14 +80,14 @@ def iou(source, other): return intersection_area / (source_area + other_area - intersection_area) @staticmethod - def inside(source, other) -> bool: + def inside(source: Tensor, other: Tensor) -> bool: source = source.repeat(other.shape[0], 1, 1).permute(1, 0, 2) other = other.repeat(source.shape[0], 1, 1) return ((source[:, :, 0] >= other[:, :, 0]) * (source[:, :, 1] >= other[:, :, 1]) * (source[:, :, 2] <= other[:, :, 2]) * (source[:, :, 3] <= other[:, :, 3])) @staticmethod - def clip(bboxes, left: float, top: float, right: float, bottom: float): + def clip(bboxes: Tensor, left: float, top: float, right: float, bottom: float) -> Tensor: return torch.stack([ torch.clamp(bboxes[:, 0], min=left, max=right), torch.clamp(bboxes[:, 1], min=top, max=bottom), @@ -95,7 +96,7 @@ def clip(bboxes, left: float, top: float, right: float, bottom: float): ], dim=1) @staticmethod - def generate_anchors(max_x: int, max_y: int, stride: int): + def generate_anchors(max_x: int, max_y: int, stride: int) -> Tensor: center_based_anchor_bboxes = [] # NOTE: it's important to let `anchor_y` be the major index of list (i.e., move horizontally and then vertically) for consistency with 2D convolution @@ -110,7 +111,7 @@ def generate_anchors(max_x: int, max_y: int, stride: int): width = size * np.sqrt(1 / r) center_based_anchor_bboxes.append([center_x, center_y, width, height]) - center_based_anchor_bboxes = torch.FloatTensor(center_based_anchor_bboxes) + center_based_anchor_bboxes = torch.tensor(center_based_anchor_bboxes, dtype=torch.float) anchor_bboxes = BBox.from_center_base(center_based_anchor_bboxes) return anchor_bboxes diff --git a/dataset.py b/dataset.py index 3b250582..fe600c36 100644 --- a/dataset.py +++ b/dataset.py @@ -7,8 +7,8 @@ import PIL import torch.utils.data from PIL import Image, ImageOps +from torch import Tensor from torchvision import transforms -from torch import FloatTensor, LongTensor from bbox import BBox @@ -21,7 +21,7 @@ class Mode(Enum): class Annotation(object): class Object(object): - def __init__(self, name: str, difficult: bool, bbox: BBox) -> None: + def __init__(self, name: str, difficult: bool, bbox: BBox): super().__init__() self.name = name self.difficult = difficult @@ -31,7 +31,7 @@ def __repr__(self) -> str: return 'Object[name={:s}, difficult={!s}, bbox={!s}]'.format( self.name, self.difficult, self.bbox) - def __init__(self, filename: str, objects: List[Object]) -> None: + def __init__(self, filename: str, objects: List[Object]): super().__init__() self.filename = filename self.objects = objects @@ -46,7 +46,7 @@ def __init__(self, filename: str, objects: List[Object]) -> None: LABEL_TO_CATEGORY_DICT = {v: k for k, v in CATEGORY_TO_LABEL_DICT.items()} - def __init__(self, path_to_data_dir: str, mode: Mode) -> None: + def __init__(self, path_to_data_dir: str, mode: Mode): super().__init__() self._mode = mode @@ -89,15 +89,15 @@ def __init__(self, 
path_to_data_dir: str, mode: Mode) -> None: def __len__(self) -> int: return len(self._image_id_to_annotation_dict) - def __getitem__(self, index: int) -> Tuple[str, FloatTensor, float, FloatTensor, LongTensor]: + def __getitem__(self, index: int) -> Tuple[str, Tensor, float, Tensor, Tensor]: image_id = self._image_ids[index] annotation = self._image_id_to_annotation_dict[image_id] bboxes = [obj.bbox.tolist() for obj in annotation.objects if not obj.difficult] labels = [Dataset.CATEGORY_TO_LABEL_DICT[obj.name] for obj in annotation.objects if not obj.difficult] - bboxes = torch.FloatTensor(bboxes) - labels = torch.LongTensor(labels) + bboxes = torch.tensor(bboxes, dtype=torch.float) + labels = torch.tensor(labels, dtype=torch.long) image = Image.open(os.path.join(self._path_to_jpeg_images_dir, annotation.filename)) @@ -112,7 +112,7 @@ def __getitem__(self, index: int) -> Tuple[str, FloatTensor, float, FloatTensor, return image_id, image, scale, bboxes, labels @staticmethod - def preprocess(image: PIL.Image.Image): + def preprocess(image: PIL.Image.Image) -> Tuple[Tensor, float]: # resize according to the rules: # 1. scale shorter edge to 600 # 2. after scaling, if longer edge > 1000, scale longer edge to 1000 diff --git a/eval.py b/eval.py index e0e2e206..9027384a 100644 --- a/eval.py +++ b/eval.py @@ -7,7 +7,7 @@ from model import Model -def _eval(path_to_checkpoint, path_to_data_dir, path_to_results_dir): +def _eval(path_to_checkpoint: str, path_to_data_dir: str, path_to_results_dir: str): dataset = Dataset(path_to_data_dir, Dataset.Mode.TEST) evaluator = Evaluator(dataset, path_to_data_dir, path_to_results_dir) diff --git a/evaluator.py b/evaluator.py index e99c5161..6c4e49e4 100644 --- a/evaluator.py +++ b/evaluator.py @@ -1,5 +1,7 @@ import os +from typing import Dict, List +import torch from torch.utils.data import DataLoader from tqdm import tqdm @@ -9,29 +11,30 @@ class Evaluator(object): - def __init__(self, dataset, path_to_data_dir, path_to_results_dir): + def __init__(self, dataset: Dataset, path_to_data_dir: str, path_to_results_dir: str): super().__init__() self.dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8, pin_memory=True) self._path_to_data_dir = path_to_data_dir self._path_to_results_dir = path_to_results_dir os.makedirs(self._path_to_results_dir, exist_ok=True) - def evaluate(self, model): + def evaluate(self, model: Model) -> Dict[int, float]: all_image_ids, all_pred_bboxes, all_pred_labels, all_pred_probs = [], [], [], [] - for batch_index, (image_id_batch, image_batch, scale_batch, _, _) in enumerate(tqdm(self.dataloader)): - image_id = image_id_batch[0] - image = image_batch[0].cuda() - scale = scale_batch[0] + with torch.no_grad(): + for batch_index, (image_id_batch, image_batch, scale_batch, _, _) in enumerate(tqdm(self.dataloader)): + image_id = image_id_batch[0] + image = image_batch[0].cuda() + scale = scale_batch[0].item() - pred_bboxes, pred_labels, pred_probs = model.detect(image) + pred_bboxes, pred_labels, pred_probs = model.detect(image) - pred_bboxes = [[it / scale for it in bbox] for bbox in pred_bboxes] + pred_bboxes = [[it / scale for it in bbox] for bbox in pred_bboxes] - all_pred_bboxes.extend(pred_bboxes) - all_pred_labels.extend(pred_labels) - all_pred_probs.extend(pred_probs) - all_image_ids.extend([image_id] * len(pred_labels)) + all_pred_bboxes.extend(pred_bboxes) + all_pred_labels.extend(pred_labels) + all_pred_probs.extend(pred_probs) + all_image_ids.extend([image_id] * len(pred_labels)) 
self._write_results(all_image_ids, all_pred_bboxes, all_pred_labels, all_pred_probs) @@ -57,13 +60,13 @@ def evaluate(self, model): return label_to_ap_dict - def _write_results(self, image_ids, bboxes, labels, preds): + def _write_results(self, image_ids: List[str], bboxes: List[List[float]], labels: List[int], probs: List[float]): label_to_txt_files_dict = {} for c in range(1, Model.NUM_CLASSES): label_to_txt_files_dict[c] = open(os.path.join(self._path_to_results_dir, 'comp3_det_test_{:s}.txt'.format(Dataset.LABEL_TO_CATEGORY_DICT[c])), 'w') - for image_id, bbox, label, pred in zip(image_ids, bboxes, labels, preds): - label_to_txt_files_dict[label].write('{:s} {:f} {:f} {:f} {:f} {:f}\n'.format(image_id, pred, + for image_id, bbox, label, prob in zip(image_ids, bboxes, labels, probs): + label_to_txt_files_dict[label].write('{:s} {:f} {:f} {:f} {:f} {:f}\n'.format(image_id, prob, bbox[0], bbox[1], bbox[2], bbox[3])) for _, f in label_to_txt_files_dict.items(): diff --git a/infer.py b/infer.py index 27bb87e0..c141c986 100644 --- a/infer.py +++ b/infer.py @@ -9,7 +9,7 @@ from model import Model -def _infer(path_to_input_image, path_to_output_image, path_to_checkpoint): +def _infer(path_to_input_image: str, path_to_output_image: str, path_to_checkpoint: str): image = transforms.Image.open(path_to_input_image) image_tensor, scale = Dataset.preprocess(image) diff --git a/model.py b/model.py index ebfdef19..16dcab1b 100644 --- a/model.py +++ b/model.py @@ -1,13 +1,10 @@ import os import time -from typing import Union +from typing import Union, Tuple, List -import numpy as np import torch import torchvision.models -from torch import FloatTensor -from torch import nn -from torch.autograd import Variable +from torch import nn, Tensor from torch.nn import functional as F from bbox import BBox @@ -21,30 +18,30 @@ class Model(nn.Module): class ForwardInput: class Train(object): - def __init__(self, image, gt_classes, gt_bboxes) -> None: + def __init__(self, image: Tensor, gt_classes: Tensor, gt_bboxes: Tensor): self.image = image self.gt_classes = gt_classes self.gt_bboxes = gt_bboxes class Eval(object): - def __init__(self, image) -> None: + def __init__(self, image: Tensor): self.image = image class ForwardOutput: class Train(object): - def __init__(self, anchor_objectness_loss, anchor_transformer_loss, proposal_class_loss, proposal_transformer_loss) -> None: + def __init__(self, anchor_objectness_loss: Tensor, anchor_transformer_loss: Tensor, proposal_class_loss: Tensor, proposal_transformer_loss: Tensor): self.anchor_objectness_loss = anchor_objectness_loss self.anchor_transformer_loss = anchor_transformer_loss self.proposal_class_loss = proposal_class_loss self.proposal_transformer_loss = proposal_transformer_loss class Eval(object): - def __init__(self, proposal_bboxes, proposal_classes, proposal_transformers) -> None: + def __init__(self, proposal_bboxes: Tensor, proposal_classes: Tensor, proposal_transformers: Tensor): self.proposal_bboxes = proposal_bboxes self.proposal_classes = proposal_classes self.proposal_transformers = proposal_transformers - def __init__(self) -> None: + def __init__(self): super().__init__() vgg16 = torchvision.models.vgg16(pretrained=True) @@ -59,11 +56,11 @@ def __init__(self) -> None: self.rpn = RegionProposalNetwork() self.head = Model.Head() - self._transformer_normalize_mean = FloatTensor([0., 0., 0., 0.]) - self._transformer_normalize_std = FloatTensor([.1, .1, .2, .2]) + self._transformer_normalize_mean = torch.tensor([0., 0., 0., 0.], dtype=torch.float) + 
self._transformer_normalize_std = torch.tensor([.1, .1, .2, .2], dtype=torch.float) def forward(self, forward_input: Union[ForwardInput.Train, ForwardInput.Eval]) -> Union[ForwardOutput.Train, ForwardOutput.Eval]: - image = Variable(forward_input.image, volatile=not self.training).unsqueeze(dim=0) + image = forward_input.image.unsqueeze(dim=0) image_height, image_width = image.shape[2], image.shape[3] features = self.features(image) @@ -86,26 +83,23 @@ def forward(self, forward_input: Union[ForwardInput.Train, ForwardInput.Eval]) - return forward_output - def sample(self, proposal_bboxes, gt_classes, gt_bboxes): + def sample(self, proposal_bboxes: Tensor, gt_classes: Tensor, gt_bboxes: Tensor): proposal_bboxes = proposal_bboxes.cpu() gt_classes = gt_classes.cpu() gt_bboxes = gt_bboxes.cpu() # find labels for each `proposal_bboxes` - labels = torch.ones(len(proposal_bboxes)).long() * -1 + labels = torch.ones(len(proposal_bboxes), dtype=torch.long) * -1 ious = BBox.iou(proposal_bboxes, gt_bboxes) proposal_max_ious, proposal_assignments = ious.max(dim=1) labels[proposal_max_ious < 0.5] = 0 - if len((proposal_max_ious >= 0.5).nonzero().squeeze()) > 0: - labels[proposal_max_ious >= 0.5] = gt_classes[proposal_assignments[proposal_max_ious >= 0.5]] + labels[proposal_max_ious >= 0.5] = gt_classes[proposal_assignments[proposal_max_ious >= 0.5]] # select 128 samples - fg_indices = (labels > 0).nonzero().squeeze() - bg_indices = (labels == 0).nonzero().squeeze() - if len(fg_indices) > 0: - fg_indices = fg_indices[torch.randperm(len(fg_indices))[:min(len(fg_indices), 32)]] - if len(bg_indices) > 0: - bg_indices = bg_indices[torch.randperm(len(bg_indices))[:128 - len(fg_indices)]] + fg_indices = (labels > 0).nonzero().view(-1) + bg_indices = (labels == 0).nonzero().view(-1) + fg_indices = fg_indices[torch.randperm(len(fg_indices))[:min(len(fg_indices), 32)]] + bg_indices = bg_indices[torch.randperm(len(bg_indices))[:128 - len(fg_indices)]] select_indices = torch.cat([fg_indices, bg_indices]) select_indices = select_indices[torch.randperm(len(select_indices))] @@ -114,48 +108,43 @@ def sample(self, proposal_bboxes, gt_classes, gt_bboxes): gt_proposal_classes = labels[select_indices] gt_proposal_transformers = (gt_proposal_transformers - self._transformer_normalize_mean) / self._transformer_normalize_std - gt_proposal_transformers = Variable(gt_proposal_transformers).cuda() - gt_proposal_classes = Variable(gt_proposal_classes).cuda() + + gt_proposal_transformers = gt_proposal_transformers.cuda() + gt_proposal_classes = gt_proposal_classes.cuda() return proposal_bboxes, gt_proposal_classes, gt_proposal_transformers - def loss(self, proposal_classes, proposal_transformers, gt_proposal_classes, gt_proposal_transformers): + def loss(self, proposal_classes: Tensor, proposal_transformers: Tensor, gt_proposal_classes: Tensor, gt_proposal_transformers: Tensor): cross_entropy = F.cross_entropy(input=proposal_classes, target=gt_proposal_classes) proposal_transformers = proposal_transformers.view(-1, Model.NUM_CLASSES, 4) - proposal_transformers = proposal_transformers[torch.arange(0, len(proposal_transformers)).long().cuda(), gt_proposal_classes] - - fg_indices = np.where(gt_proposal_classes.data.cpu().numpy() > 0)[0] - in_weight = np.zeros((len(proposal_transformers), 4)) - in_weight[fg_indices] = 1 - in_weight = Variable(FloatTensor(in_weight)).cuda() + proposal_transformers = proposal_transformers[torch.arange(end=len(proposal_transformers), dtype=torch.long).cuda(), gt_proposal_classes] - proposal_transformers 
= proposal_transformers * in_weight - gt_proposal_transformers = gt_proposal_transformers * in_weight + fg_indices = gt_proposal_classes.nonzero().view(-1) - # NOTE: The default of `size_average` is `True`, which is divided by N x 4 (number of all elements), here we replaced by N for better performance - smooth_l1_loss = F.smooth_l1_loss(input=proposal_transformers, target=gt_proposal_transformers, size_average=False) + # NOTE: The default of `reduction` is `elementwise_mean`, which is divided by N x 4 (number of all elements), here we replaced by N for better performance + smooth_l1_loss = F.smooth_l1_loss(input=proposal_transformers[fg_indices], target=gt_proposal_transformers[fg_indices], reduction='sum') smooth_l1_loss /= len(gt_proposal_transformers) return cross_entropy, smooth_l1_loss - def save(self, path_to_checkpoints_dir, step): + def save(self, path_to_checkpoints_dir: str, step: int) -> str: path_to_checkpoint = os.path.join(path_to_checkpoints_dir, 'model-{:s}-{:d}.pth'.format(time.strftime('%Y%m%d%H%M'), step)) torch.save(self.state_dict(), path_to_checkpoint) return path_to_checkpoint - def load(self, path_to_checkpoint): + def load(self, path_to_checkpoint: str) -> 'Model': self.load_state_dict(torch.load(path_to_checkpoint)) return self - def detect(self, image): + def detect(self, image: Tensor) -> Tuple[List[List[float]], List[int], List[float]]: forward_input = Model.ForwardInput.Eval(image) forward_output: Model.ForwardOutput.Eval = self.eval().forward(forward_input) proposal_bboxes = forward_output.proposal_bboxes - proposal_classes = forward_output.proposal_classes.data - proposal_transformers = forward_output.proposal_transformers.data + proposal_classes = forward_output.proposal_classes + proposal_transformers = forward_output.proposal_transformers proposal_transformers = proposal_transformers.view(-1, Model.NUM_CLASSES, 4) mean = self._transformer_normalize_mean.repeat(1, Model.NUM_CLASSES, 1).cuda() @@ -172,7 +161,10 @@ def detect(self, image): detection_bboxes[:, :, [0, 2]] = detection_bboxes[:, :, [0, 2]].clamp(min=0, max=image_width) detection_bboxes[:, :, [1, 3]] = detection_bboxes[:, :, [1, 3]].clamp(min=0, max=image_height) - proposal_probs = F.softmax(Variable(proposal_classes), dim=1).data + proposal_probs = F.softmax(proposal_classes, dim=1) + + detection_bboxes = detection_bboxes.cpu() + proposal_probs = proposal_probs.cpu() bboxes = [] labels = [] @@ -182,16 +174,11 @@ def detect(self, image): detection_class_bboxes = detection_bboxes[:, c, :] proposal_class_probs = proposal_probs[:, c] - selected_indices = (proposal_class_probs > 0.05).nonzero().squeeze() - if len(selected_indices) > 0: - detection_class_bboxes = detection_class_bboxes[selected_indices] - proposal_class_probs = proposal_class_probs[selected_indices] - _, sorted_indices = proposal_class_probs.sort(descending=True) detection_class_bboxes = detection_class_bboxes[sorted_indices] proposal_class_probs = proposal_class_probs[sorted_indices] - keep_indices = NMS.suppress(detection_class_bboxes, threshold=0.3) + keep_indices = NMS.suppress(detection_class_bboxes.cuda(), threshold=0.3) detection_class_bboxes = detection_class_bboxes[keep_indices] proposal_class_probs = proposal_class_probs[keep_indices] @@ -211,22 +198,20 @@ def __init__(self): nn.Linear(512 * 7 * 7, 4096), nn.ReLU(), nn.Linear(4096, 4096), - nn.ReLU(), + nn.ReLU() ) self._class = nn.Linear(4096, Model.NUM_CLASSES) self._transformer = nn.Linear(4096, Model.NUM_CLASSES * 4) - def forward(self, features, proposal_bboxes): - 
proposal_bboxes = Variable(proposal_bboxes) - + def forward(self, features: Tensor, proposal_bboxes: Tensor) -> Tuple[Tensor, Tensor]: _, _, feature_map_height, feature_map_width = features.size() pool = [] for proposal_bbox in proposal_bboxes: - start_x = max(min(round(proposal_bbox[0].data[0] / 16), feature_map_width - 1), 0) # [0, feature_map_width) - start_y = max(min(round(proposal_bbox[1].data[0] / 16), feature_map_height - 1), 0) # (0, feature_map_height] - end_x = max(min(round(proposal_bbox[2].data[0] / 16) + 1, feature_map_width), 1) # [0, feature_map_width) - end_y = max(min(round(proposal_bbox[3].data[0] / 16) + 1, feature_map_height), 1) # (0, feature_map_height] + start_x = max(min(round(proposal_bbox[0].item() / 16), feature_map_width - 1), 0) # [0, feature_map_width) + start_y = max(min(round(proposal_bbox[1].item() / 16), feature_map_height - 1), 0) # (0, feature_map_height] + end_x = max(min(round(proposal_bbox[2].item() / 16) + 1, feature_map_width), 1) # [0, feature_map_width) + end_y = max(min(round(proposal_bbox[3].item() / 16) + 1, feature_map_height), 1) # (0, feature_map_height] roi_feature_map = features[..., start_y:end_y, start_x:end_x] pool.append(F.adaptive_max_pool2d(roi_feature_map, 7)) pool = torch.cat(pool, dim=0) # pool has shape (128, 512, 7, 7) diff --git a/nms/nms.py b/nms/nms.py index fd655b17..03c0c315 100644 --- a/nms/nms.py +++ b/nms/nms.py @@ -1,12 +1,13 @@ import torch from nms._ext import nms +from torch import Tensor class NMS(object): @staticmethod - def suppress(sorted_bboxes, threshold: float): - keep_indices = torch.LongTensor().cuda() + def suppress(sorted_bboxes: Tensor, threshold: float) -> Tensor: + keep_indices = torch.tensor([], dtype=torch.long).cuda() nms.suppress(sorted_bboxes.contiguous(), threshold, keep_indices) return keep_indices diff --git a/nms/src/nms.c b/nms/src/nms.c index c4bf8581..6d5b51c2 100644 --- a/nms/src/nms.c +++ b/nms/src/nms.c @@ -4,13 +4,13 @@ extern THCState *state; int suppress(THCudaTensor *bboxes, float threshold, THCudaLongTensor *keepIndices) { - if (!((bboxes->nDimension == 2) && (bboxes->size[1] == 4))) + if (!((THCudaTensor_nDimension(state, bboxes) == 2) && (THCudaTensor_size(state, bboxes, 1) == 4))) return 0; - long numBoxes = bboxes->size[0]; + long numBoxes = THCudaTensor_size(state, bboxes, 0); THLongTensor *keepIndicesTmp = THLongTensor_newWithSize1d(numBoxes); - long numKeepBoxes; + long numKeepBoxes; nms(THCudaTensor_data(state, bboxes), numBoxes, threshold, THLongTensor_data(keepIndicesTmp), &numKeepBoxes); THLongTensor_resize1d(keepIndicesTmp, numKeepBoxes); diff --git a/rpn/region_proposal_network.py b/rpn/region_proposal_network.py index ed9d6715..37541738 100644 --- a/rpn/region_proposal_network.py +++ b/rpn/region_proposal_network.py @@ -1,9 +1,7 @@ from typing import Tuple import torch -from torch import FloatTensor -from torch import nn -from torch.autograd import Variable +from torch import nn, Tensor from torch.nn import functional as F from bbox import BBox @@ -12,7 +10,7 @@ class RegionProposalNetwork(nn.Module): - def __init__(self) -> None: + def __init__(self): super().__init__() self._features = nn.Sequential( @@ -23,7 +21,7 @@ def __init__(self) -> None: self._objectness = nn.Conv2d(in_channels=512, out_channels=18, kernel_size=1) self._transformer = nn.Conv2d(in_channels=512, out_channels=36, kernel_size=1) - def forward(self, features, image_width: int, image_height: int): + def forward(self, features: Tensor, image_width: int, image_height: int) -> Tuple[Tensor, Tensor, 
Tensor, Tensor]: anchor_bboxes = BBox.generate_anchors(max_x=image_width, max_y=image_height, stride=16).cuda() features = self._features(features) @@ -33,13 +31,13 @@ def forward(self, features, image_width: int, image_height: int): objectnesses = objectnesses.permute(0, 2, 3, 1).contiguous().view(-1, 2) transformers = transformers.permute(0, 2, 3, 1).contiguous().view(-1, 4) - proposal_score = objectnesses.data[:, 1] + proposal_score = objectnesses[:, 1] _, sorted_indices = torch.sort(proposal_score, dim=0, descending=True) - sorted_transformers = transformers.data[sorted_indices] + sorted_transformers = transformers[sorted_indices] sorted_anchor_bboxes = anchor_bboxes[sorted_indices] - proposal_bboxes = BBox.apply_transformer(sorted_anchor_bboxes, sorted_transformers) + proposal_bboxes = BBox.apply_transformer(sorted_anchor_bboxes, sorted_transformers.detach()) proposal_bboxes = BBox.clip(proposal_bboxes, 0, 0, image_width, image_height) area_threshold = 16 @@ -54,12 +52,13 @@ def forward(self, features, image_width: int, image_height: int): return anchor_bboxes, objectnesses, transformers, proposal_bboxes - def sample(self, anchor_bboxes, anchor_objectnesses, anchor_transformers, gt_bboxes, image_width: int, image_height: int): + def sample(self, anchor_bboxes: Tensor, anchor_objectnesses: Tensor, anchor_transformers: Tensor, gt_bboxes: Tensor, + image_width: int, image_height: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]: anchor_bboxes = anchor_bboxes.cpu() gt_bboxes = gt_bboxes.cpu() # remove cross-boundary - boundary = FloatTensor(BBox(0, 0, image_width, image_height).tolist()) + boundary = torch.tensor(BBox(0, 0, image_width, image_height).tolist(), dtype=torch.float) inside_indices = BBox.inside(anchor_bboxes, boundary.unsqueeze(dim=0)).squeeze().nonzero().squeeze() anchor_bboxes = anchor_bboxes[inside_indices] @@ -67,7 +66,7 @@ def sample(self, anchor_bboxes, anchor_objectnesses, anchor_transformers, gt_bbo anchor_transformers = anchor_transformers[inside_indices.cuda()] # find labels for each `anchor_bboxes` - labels = torch.ones(len(anchor_bboxes)).long() * -1 + labels = torch.ones(len(anchor_bboxes), dtype=torch.long) * -1 ious = BBox.iou(anchor_bboxes, gt_bboxes) anchor_max_ious, anchor_assignments = ious.max(dim=1) gt_max_ious, gt_assignments = ious.max(dim=0) @@ -77,12 +76,10 @@ def sample(self, anchor_bboxes, anchor_objectnesses, anchor_transformers, gt_bbo labels[anchor_max_ious >= 0.7] = 1 # select 256 samples - fg_indices = (labels == 1).nonzero().squeeze() - bg_indices = (labels == 0).nonzero().squeeze() - if len(fg_indices) > 0: - fg_indices = fg_indices[torch.randperm(len(fg_indices))[:min(len(fg_indices), 128)]] - if len(bg_indices) > 0: - bg_indices = bg_indices[torch.randperm(len(bg_indices))[:256 - len(fg_indices)]] + fg_indices = (labels == 1).nonzero().view(-1) + bg_indices = (labels == 0).nonzero().view(-1) + fg_indices = fg_indices[torch.randperm(len(fg_indices))[:min(len(fg_indices), 128)]] + bg_indices = bg_indices[torch.randperm(len(bg_indices))[:256 - len(fg_indices)]] select_indices = torch.cat([fg_indices, bg_indices]) select_indices = select_indices[torch.randperm(len(select_indices))] @@ -91,19 +88,19 @@ def sample(self, anchor_bboxes, anchor_objectnesses, anchor_transformers, gt_bbo anchor_bboxes = anchor_bboxes[fg_indices] gt_anchor_transformers = BBox.calc_transformer(anchor_bboxes, gt_bboxes) - gt_anchor_objectnesses = Variable(gt_anchor_objectnesses).cuda() - gt_anchor_transformers = Variable(gt_anchor_transformers).cuda() + 
gt_anchor_objectnesses = gt_anchor_objectnesses.cuda() + gt_anchor_transformers = gt_anchor_transformers.cuda() anchor_objectnesses = anchor_objectnesses[select_indices.cuda()] anchor_transformers = anchor_transformers[fg_indices.cuda()] return anchor_objectnesses, anchor_transformers, gt_anchor_objectnesses, gt_anchor_transformers - def loss(self, anchor_objectnesses, anchor_transformers, gt_anchor_objectnesses, gt_anchor_transformers): + def loss(self, anchor_objectnesses: Tensor, anchor_transformers: Tensor, gt_anchor_objectnesses: Tensor, gt_anchor_transformers: Tensor) -> Tuple[Tensor, Tensor]: cross_entropy = F.cross_entropy(input=anchor_objectnesses, target=gt_anchor_objectnesses) - # NOTE: The default of `size_average` is `True`, which is divided by N x 4 (number of all elements), here we replaced by N for better performance - smooth_l1_loss = F.smooth_l1_loss(input=anchor_transformers, target=gt_anchor_transformers, size_average=False) + # NOTE: The default of `reduction` is `elementwise_mean`, which is divided by N x 4 (number of all elements), here we replaced by N for better performance + smooth_l1_loss = F.smooth_l1_loss(input=anchor_transformers, target=gt_anchor_transformers, reduction='sum') smooth_l1_loss /= len(gt_anchor_transformers) return cross_entropy, smooth_l1_loss diff --git a/train.py b/train.py index 3c786436..c8427516 100644 --- a/train.py +++ b/train.py @@ -10,12 +10,12 @@ from model import Model -def _train(path_to_data_dir: str, path_to_checkpoints_dir: str) -> None: +def _train(path_to_data_dir: str, path_to_checkpoints_dir: str): dataset = Dataset(path_to_data_dir, mode=Dataset.Mode.TRAIN) - dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=8) + dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=8, pin_memory=True) model = Model().cuda() - optimizer = optim.SGD([it for it in model.parameters() if it.requires_grad], lr=1e-3, momentum=0.9, weight_decay=0.0005) + optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=0.0005) scheduler = StepLR(optimizer, step_size=50000, gamma=0.1) step = 0 @@ -53,7 +53,7 @@ def _train(path_to_data_dir: str, path_to_checkpoints_dir: str) -> None: if step % num_steps_to_display == 0: steps_per_sec = num_steps_to_display / elapsed_time elapsed_time = 0.0 - print(f'[Step {step}] Loss = {loss.data[0]:.6f}, Learning Rate = {scheduler.get_lr()[0]} ({steps_per_sec:.2f} steps/sec)') + print(f'[Step {step}] Loss = {loss.item():.6f}, Learning Rate = {scheduler.get_lr()[0]} ({steps_per_sec:.2f} steps/sec)') if step % num_steps_to_snapshot == 0: path_to_checkpoint = model.save(path_to_checkpoints_dir, step)
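
The hunks above repeat a small set of 0.3.x-to-0.4.1 idioms: `torch.FloatTensor(...)` / `torch.LongTensor(...)` become `torch.tensor(..., dtype=...)`, `Variable(..., volatile=True)` becomes a `torch.no_grad()` block, and scalar reads move from `.data[0]` to `.item()`. A minimal, self-contained sketch of the pattern (the linear layer and tensors here are placeholders, not code from this repository):

```
import torch

# Tensor construction: dtype is passed explicitly instead of picking a tensor class.
x = torch.tensor([1., 2., 3.], dtype=torch.float)   # was torch.FloatTensor([1., 2., 3.])
labels = torch.tensor([0, 1], dtype=torch.long)     # was torch.LongTensor([0, 1])

model = torch.nn.Linear(3, 2)

# Inference: volatile Variables are gone; disable autograd with no_grad instead.
with torch.no_grad():                                # was Variable(x, volatile=True)
    logits = model(x.unsqueeze(dim=0))

# Scalar access: .item() replaces .data[0].
loss = logits.sum()
print(loss.item())                                   # was loss.data[0]
```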
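
The `BBox.apply_transformer` hunk shows the decode step (`t_x * w_src + x_src`, `exp(t_w) * w_src`, ...); `calc_transformer` is its inverse under the standard Faster R-CNN parameterization, whose body lies outside the hunk context and is assumed here. A small round-trip check written against plain tensors:

```
import torch

def to_center(b):
    # [left, top, right, bottom] -> [cx, cy, w, h], mirroring BBox.to_center_base
    return torch.stack([(b[:, 0] + b[:, 2]) / 2, (b[:, 1] + b[:, 3]) / 2,
                        b[:, 2] - b[:, 0], b[:, 3] - b[:, 1]], dim=1)

src = to_center(torch.tensor([[10., 10., 50., 90.]]))
dst = to_center(torch.tensor([[12., 8., 60., 100.]]))

# Encode (calc_transformer, assumed): offsets relative to source size, log-ratio scales.
t = torch.stack([(dst[:, 0] - src[:, 0]) / src[:, 2],
                 (dst[:, 1] - src[:, 1]) / src[:, 3],
                 torch.log(dst[:, 2] / src[:, 2]),
                 torch.log(dst[:, 3] / src[:, 3])], dim=1)

# Decode (apply_transformer, as shown in the hunk) recovers the destination box exactly.
rec = torch.stack([t[:, 0] * src[:, 2] + src[:, 0],
                   t[:, 1] * src[:, 3] + src[:, 1],
                   torch.exp(t[:, 2]) * src[:, 2],
                   torch.exp(t[:, 3]) * src[:, 3]], dim=1)
assert torch.allclose(rec, dst)
```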
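
`Dataset.preprocess` resizes by the two rules in its comment: scale the shorter edge to 600, then cap the longer edge at 1000. A hypothetical helper that computes only the scale factor (the name `compute_scale` is illustrative and does not exist in the repository):

```
def compute_scale(width: int, height: int, shorter: float = 600.0, longer_cap: float = 1000.0) -> float:
    # Rule 1: scale so the shorter edge becomes 600.
    scale = shorter / min(width, height)
    # Rule 2: if that would push the longer edge past 1000, cap the longer edge instead.
    if scale * max(width, height) > longer_cap:
        scale = longer_cap / max(width, height)
    return scale

print(compute_scale(500, 375))    # 1.6    (375 -> 600, 500 -> 800)
print(compute_scale(1200, 500))   # ~0.833 (longer edge capped at 1000)
```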
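
`Model.Head.forward` now reads proposal coordinates with `.item()` when mapping them onto the stride-16 feature map, then pools each crop to a fixed 7x7 grid with `F.adaptive_max_pool2d`. A sketch of that step for a single proposal (the shapes and the example box are illustrative):

```
import torch
from torch.nn import functional as F

features = torch.randn(1, 512, 38, 50)             # VGG-16 conv feature map at stride 16
proposal = torch.tensor([100., 64., 340., 260.])   # [left, top, right, bottom] in image pixels

feature_map_height, feature_map_width = features.shape[2], features.shape[3]
start_x = max(min(round(proposal[0].item() / 16), feature_map_width - 1), 0)
start_y = max(min(round(proposal[1].item() / 16), feature_map_height - 1), 0)
end_x = max(min(round(proposal[2].item() / 16) + 1, feature_map_width), 1)
end_y = max(min(round(proposal[3].item() / 16) + 1, feature_map_height), 1)

roi_feature_map = features[..., start_y:end_y, start_x:end_x]
pooled = F.adaptive_max_pool2d(roi_feature_map, 7)  # fixed 7x7 output regardless of RoI size
print(pooled.shape)                                  # torch.Size([1, 512, 7, 7])
```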
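
Both smooth L1 losses switch from `size_average=False` to `reduction='sum'` and are then divided by N, so the regression loss is averaged over boxes rather than over all N x 4 elements, as the NOTE comments describe. A quick check of what that normalization does, using random stand-in tensors:

```
import torch
from torch.nn import functional as F

pred = torch.randn(8, 4)     # 8 boxes, 4 transformer values each
target = torch.randn(8, 4)

loss_sum = F.smooth_l1_loss(input=pred, target=target, reduction='sum')
per_box = loss_sum / len(target)          # the patch's normalization: divide by N
per_element = loss_sum / target.numel()   # a per-element mean would divide by N * 4
assert torch.allclose(per_box, per_element * 4)
```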