Skip to content

feature_extractor for Resnet #790

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Closed
fortunto2 opened this issue Nov 19, 2020 · 2 comments
Closed

feature_extractor for Resnet #790

fortunto2 opened this issue Nov 19, 2020 · 2 comments
Assignees

Comments

@fortunto2
Copy link

fortunto2 commented Nov 19, 2020

Hi,
Can you please help me create a feature_extractor handler for ResNet152 (or another model)?

model = models.resnet152(pretrained=True)
# Strip the last child module (presumably the final classification layer)
# so the network returns features instead of class scores.
feature_extractor = torch.nn.Sequential(*list(model.children())[:-1])
# Switch to inference mode: in training mode BatchNorm normalizes with the
# current batch's statistics, so an image's features would depend on the
# other images in the batch.
feature_extractor.eval()

I tried this change: I replaced the plain model import with
`from model import ResNetFeatures`:

import io
import logging
import numpy as np
import os
import torch
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms


logger = logging.getLogger(__name__)


class BatchImageClassifier(object):
    """
    Batched TorchServe handler that turns a list of encoded images into
    per-image model outputs (e.g. feature vectors from a truncated ResNet).

    Despite the class name, ``postprocess`` does not map outputs to class
    labels: it returns each image's raw model output flattened into a plain
    Python list of floats, one entry per image in the batch.
    """

    def __init__(self):
        self.model = None
        self.mapping = None
        self.device = None
        self.initialized = False

    def initialize(self, context):
        """
        Load the model (and the optional label mapping) described by ``context``.

        First tries to load ``serializedFile`` as TorchScript. If that fails,
        it falls back to eager mode. The module named by ``modelFile`` is
        imported, and its model class is instantiated with no arguments. The
        state dict from ``serializedFile`` is then loaded into it. The model
        class is ``ResNetFeatures`` when the module defines one; otherwise it
        is the only class defined in that module.

        Raises:
            RuntimeError: if the serialized model file or the model
                definition file is missing.
            ValueError: if the module has no ``ResNetFeatures`` and does not
                define exactly one class.
        """
        import importlib
        import inspect
        import json

        self.manifest = context.manifest
        properties = context.system_properties
        model_dir = properties.get("model_dir")
        gpu_id = properties.get("gpu_id")
        # Only build a CUDA device when a GPU id was actually assigned;
        # "cuda:None" is not a valid device string.
        if torch.cuda.is_available() and gpu_id is not None:
            self.device = torch.device("cuda:" + str(gpu_id))
        else:
            self.device = torch.device("cpu")

        # Read model serialize/pt file
        serialized_file = self.manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt file")

        try:
            logger.info('Loading torchscript model to device {}'.format(self.device))
            # map_location lets weights saved on a GPU load on this worker's device.
            self.model = torch.jit.load(model_pt_path, map_location=self.device)
        except Exception:
            # Not loadable as TorchScript: fall back to an eager-mode state_dict.
            logger.debug('TorchScript load failed; falling back to eager mode', exc_info=True)
            model_file = self.manifest['model']['modelFile']
            model_def_path = os.path.join(model_dir, model_file)
            if not os.path.isfile(model_def_path):
                raise RuntimeError("Missing the model.py file")

            # Import the module named in the manifest instead of hard-coding
            # `model.ResNetFeatures`. The hard-coded import fails with
            # ImportError when the definition file names its class differently.
            module = importlib.import_module(model_file.split(".")[0])
            model_class = getattr(module, "ResNetFeatures", None)
            if model_class is None:
                # Only classes defined in the module itself, not ones it imports.
                candidates = [
                    member for _, member in inspect.getmembers(module, inspect.isclass)
                    if member.__module__ == module.__name__
                ]
                if len(candidates) != 1:
                    raise ValueError(
                        "Expected only one class as model definition. {}".format(candidates))
                model_class = candidates[0]

            state_dict = torch.load(model_pt_path, map_location=self.device)
            self.model = model_class()
            self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        logger.debug('Model file {0} loaded successfully'.format(model_pt_path))

        # Read the mapping file, index to object name
        mapping_file_path = os.path.join(model_dir, "index_to_name.json")
        if os.path.isfile(mapping_file_path):
            with open(mapping_file_path) as f:
                self.mapping = json.load(f)
        else:
            logger.warning('Missing the index_to_name.json file. Inference output will not include class name.')

        self.initialized = True

    def preprocess(self, request):
        """
        Decode, resize, center-crop and normalize every image in ``request``.

        Each entry is a dict whose encoded image bytes are read from "data"
        or, failing that, from "body". Returns one (N, 3, 224, 224) float
        tensor on ``self.device``, or None if ``request`` yields no images.
        """
        # Build the pipeline once per batch instead of once per image.
        my_preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        batch = []
        for data in request:
            image = data.get("data")
            if image is None:
                image = data.get("body")
            # convert("RGB"): grayscale/RGBA/palette images would otherwise
            # break the 3-channel Normalize above.
            input_image = Image.open(io.BytesIO(image)).convert("RGB")
            batch.append(my_preprocess(input_image).unsqueeze(0))

        if not batch:
            return None
        # A single cat avoids re-copying the growing batch for every image.
        return torch.cat(batch, 0).to(self.device)

    def inference(self, img):
        """Run the model on a preprocessed batch without building an autograd graph."""
        with torch.no_grad():
            return self.model(img)

    def postprocess(self, inference_output):
        """
        Convert the batched model output into one flat list of floats per image.

        A raw tensor is not JSON-serializable, and the response needs one
        entry per request in the batch. Each row is therefore flattened, so
        an (N, C, H, W) output becomes N lists of C*H*W floats.
        """
        logger.debug('Inference output shape: %s', tuple(inference_output.shape))
        return inference_output.flatten(start_dim=1).tolist()


_service = BatchImageClassifier()


def handle(data, context):
    """
    Module-level TorchServe entry point backed by the shared ``_service``.

    Initializes the handler lazily on the first call, then pipes ``data``
    through preprocess -> inference -> postprocess. When ``data`` is None
    (after initialization), returns None without running the model.
    """
    if not _service.initialized:
        _service.initialize(context)

    if data is None:
        return None

    batch = _service.preprocess(data)
    output = _service.inference(batch)
    return _service.postprocess(output)

Or should this go in model.py? (adapted from pytorch/vision#2200 (comment))

from torchvision.models.resnet import ResNet, Bottleneck


class ResNet152ImageClassifier(ResNet):
    """
    torchvision ``ResNet`` built from ``Bottleneck`` blocks with the
    ``[3, 8, 36, 3]`` layer configuration, whose forward pass stops after
    ``layer4``.

    Despite the name, ``forward`` returns the ``layer4`` feature map, not
    class scores. Any submodules ``ResNet`` defines beyond ``layer4``
    (presumably ``avgpool`` and ``fc``) are never applied.
    """

    def __init__(self):
        super().__init__(Bottleneck, [3, 8, 36, 3])

    def forward(self, x):
        stages = (
            # stem
            self.conv1, self.bn1, self.relu, self.maxpool,
            # residual stages
            self.layer1, self.layer2, self.layer3, self.layer4,
        )
        for stage in stages:
            x = stage(x)
        return x

I get this error:

2020-11-19 22:20:24,271 [DEBUG] W-9000-resnet-152-features_1.0 org.pytorch.serve.wlm.WorkerThread - Backend worker monitoring thread interrupted or backend worker process died.
java.lang.InterruptedException
	at java.base/java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.reportInterruptAfterWait(AbstractQueuedSynchronizer.java:2056)
	at java.base/java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2133)
	at java.base/java.util.concurrent.ArrayBlockingQueue.poll(ArrayBlockingQueue.java:432)
	at org.pytorch.serve.wlm.WorkerThread.run(WorkerThread.java:129)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at java.base/java.lang.Thread.run(Thread.java:834)
2020-11-19 22:20:24,273 [WARN ] W-9000-resnet-152-features_1.0 org.pytorch.serve.wlm.BatchAggregator - Load model failed: resnet-152-features, error: Worker died.
2020-11-19 22:20:24,273 [DEBUG] W-9000-resnet-152-features_1.0 org.pytorch.serve.wlm.WorkerThread - W-9000-resnet-152-features_1.0 State change WORKER_STARTED -> WORKER_STOPPED
2020-11-19 22:20:24,273 [WARN ] W-9000-resnet-152-features_1.0 org.pytorch.serve.wlm.WorkerLifeCycle - terminateIOStreams() threadName=W-9000-resnet-152-features_1.0-stderr

@dhaniram-kshirsagar
Copy link
Contributor

You don't have to refer to your model class explicitly. Here is how to do that:
Look at this https://github.com/pytorch/serve/blob/master/ts/torch_handler/base_handler.py#L48

Taken from above file

def _load_pickled_model(self, model_dir, model_file, model_pt_path):
    """
    Build an eager-mode model from ``model_file`` and load its weights.

    ``model_file`` is imported as a module under the name that precedes its
    first ".". Exactly one class must be reported for that module by
    ``list_classes_from_module``. That class is instantiated with no
    arguments and receives the state dict read from ``model_pt_path``.

    Raises:
        RuntimeError: if ``model_file`` does not exist under ``model_dir``.
        ValueError: if the module does not yield exactly one class.
    """
    if not os.path.isfile(os.path.join(model_dir, model_file)):
        raise RuntimeError("Missing the model.py file")

    module_name, _, _ = model_file.partition(".")
    candidates = list_classes_from_module(importlib.import_module(module_name))
    if len(candidates) != 1:
        raise ValueError("Expected only one class as model definition. {}".format(
            candidates))

    (model_class,) = candidates
    weights = torch.load(model_pt_path, map_location=self.map_location)
    model = model_class()
    model.load_state_dict(weights)
    return model

Once the model is loaded (i.e. after line https://github.com/pytorch/serve/blob/master/ts/torch_handler/base_handler.py#L58), add your model-specific code in your custom handler.

What you have written is correct, except that you don't have to create an instance of your model class explicitly:

# Drop the model's last child module (presumably the final classifier) to
# turn the already-loaded model into a feature extractor.
model = torch.nn.Sequential(*list(model.children())[:-1])

@dhaniram-kshirsagar dhaniram-kshirsagar self-assigned this Nov 19, 2020
@fortunto2
Copy link
Author

fortunto2 commented Nov 19, 2020

Cool @dhaniram-kshirsagar , thank you!

I wrote it like this:

    def handle(self, data, context):
        """
        Return one flat feature vector per input in ``data``.

        Runs ``self.model`` with its last child module removed over the batch
        produced by ``self.preprocess``. Returns a list containing one list of
        floats per batch row, so the number of responses matches the number
        of inputs. For a single input this is ``[[f0, f1, ...]]``.
        """
        self.context = context

        input_data = self.preprocess(data)

        # Extract features: reuse every child module except the last one
        # (presumably the final classification layer).
        feature_extractor = torch.nn.Sequential(*list(self.model.children())[:-1])
        with torch.no_grad():  # inference only; no autograd graph needed
            features = feature_extractor(input_data)

        # Flatten each sample separately rather than merging the whole batch
        # into one list. tolist() converts in a single call instead of a
        # Python-level .item() per element.
        return features.flatten(start_dim=1).tolist()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants