feature_extractor for Resnet #790
Comments
You don't have to refer to your model class explicitly. The base handler already resolves it for you, as shown in this snippet taken from the file above:

```python
def _load_pickled_model(self, model_dir, model_file, model_pt_path):
    model_def_path = os.path.join(model_dir, model_file)
    if not os.path.isfile(model_def_path):
        raise RuntimeError("Missing the model.py file")

    # Import model.py and pick the single class defined in it.
    module = importlib.import_module(model_file.split(".")[0])
    model_class_definitions = list_classes_from_module(module)
    if len(model_class_definitions) != 1:
        raise ValueError("Expected only one class as model definition. {}".format(
            model_class_definitions))

    # Instantiate that class and load the saved weights into it.
    model_class = model_class_definitions[0]
    state_dict = torch.load(model_pt_path, map_location=self.map_location)
    model = model_class()
    model.load_state_dict(state_dict)
    return model
```

Once the model is loaded, i.e. after https://github.com/pytorch/serve/blob/master/ts/torch_handler/base_handler.py#L58, you have to add your specific piece of code in your custom handler. What you have written is correct, except that you don't have to explicitly create an instance of your model class:

```python
model = torch.nn.Sequential(*list(model.children())[:-1])
```
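To illustrate why the model class does not need to be named explicitly, here is a minimal sketch of how a helper like `list_classes_from_module` could discover it. This is an assumption about the general approach, not a copy of the TorchServe implementation; `pick_single_model_class` and the `demo_model` module are hypothetical names used only for this example.

```python
import inspect
import types


def list_classes_from_module(module):
    """Return classes defined in `module` itself, skipping imported ones.

    Sketch only: the real TorchServe helper may differ in details.
    """
    return [
        obj
        for _, obj in inspect.getmembers(module, inspect.isclass)
        if obj.__module__ == module.__name__
    ]


def pick_single_model_class(module):
    """Hypothetical helper mirroring the 'exactly one class' check."""
    classes = list_classes_from_module(module)
    if len(classes) != 1:
        raise ValueError(
            "Expected only one class as model definition. {}".format(classes)
        )
    return classes[0]


# Hypothetical stand-in for a model.py file: one imported class, one local class.
demo_module = types.ModuleType("demo_model")
exec(
    "from collections import OrderedDict\n"
    "class ResNetFeatures:\n"
    "    pass\n",
    demo_module.__dict__,
)

model_class = pick_single_model_class(demo_module)
print(model_class.__name__)
```

Because imported classes are filtered out by their `__module__`, a model.py can import whatever it needs as long as it defines exactly one class of its own.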
Cool @dhaniram-kshirsagar, thank you! I wrote it like this:

```python
def handle(self, data, context):
    self.context = context
    input_data = self.preprocess(data)
    # extract the feature vector by dropping the last child module
    feature_extractor = torch.nn.Sequential(*list(self.model.children())[:-1])
    features = feature_extractor(input_data)
    # flatten the tensor into a list
    features = [element.item() for element in features.flatten()]
    return [features]
```
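The truncation trick in the handler above can be checked outside TorchServe. The following is a minimal sketch under stated assumptions: `TinyNet` is a hypothetical stand-in for ResNet (same pattern of a pooling layer followed by a final `fc` layer), and the batch size and input shape are arbitrary. It also flattens per input rather than across the whole batch, which is a variation on the handler above and matters if more than one request arrives in a batch.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for ResNet: conv -> relu -> pool -> fc."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = self.pool(self.relu(self.conv(x)))
        return self.fc(torch.flatten(x, 1))


model = TinyNet().eval()

# Drop the last child (the classifier) and chain the remaining modules.
feature_extractor = nn.Sequential(*list(model.children())[:-1])

batch = torch.randn(2, 3, 16, 16)  # two fake RGB images
with torch.no_grad():  # inference only, no gradients needed
    features = feature_extractor(batch)  # shape: (2, 8, 1, 1)

# One flat list of floats per input image.
per_image = features.flatten(start_dim=1).tolist()
print(features.shape, len(per_image), len(per_image[0]))
```

Note that `nn.Sequential` over `children()` only reproduces the forward pass when the model applies its children in definition order; any functional step inside `forward` (such as the flatten before `fc`) is not part of the children list, which is why the output here keeps its trailing `1 x 1` spatial dimensions.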
Hi,
can you please help me create a feature_extractor handler for ResNet152 (or another model)?
I tried this change: changing the import of only `model`
to
`from model import ResNetFeatures`
Or should this go in model.py? (from pytorch/vision#2200 (comment))
get