About ONNX #21

Open
SolicTous opened this issue Oct 30, 2024 · 0 comments

I tried converting the model to ONNX but ran into a lot of issues. I then found your SamOnnxModel(nn.Module), but I do not know how to use it.

As I understand it, the export should look something like this:

    ModelToExport = SamOnnxModel(model=model, return_single_mask=True)

    dummy_image_embeddings = torch.randn(1, 3, 1024, 1024, device='cuda', requires_grad=True)
    dummy_point_coords = torch.randn(1, 1, 2, device='cuda', requires_grad=True)
    dummy_point_labels = torch.randn(1, 1, device='cuda', requires_grad=True)
    dummy_mask_input = torch.randn(1, 1, 1200, 1200, device='cuda', requires_grad=True)  # ??
    dummy_has_mask_input = torch.randn(1, 1024, 1024, 3, device='cuda', requires_grad=True)  # ??
    dummy_orig_im_size = torch.randn(1200, 1200, device='cuda', requires_grad=True)

    inputs = ['image_embeddings', 'point_coords', 'point_labels',
              'mask_input', 'has_mask_input', 'orig_im_size']
    outputs = ['upscaled_masks', 'scores', 'masks']
    torch.onnx.export(ModelToExport,
                      (dummy_image_embeddings, dummy_point_coords, dummy_point_labels,
                       dummy_mask_input, dummy_has_mask_input, dummy_orig_im_size),
                      opt.checkpoint_path.replace('pth','onnx'),
                      export_params=True, do_constant_folding=True,
                      input_names=inputs, output_names=outputs, opset_version=19,
                      verbose=False)

But I do not know how to use mask_input and has_mask_input. These inputs are not included in the base predict call for the torch model.
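
For context on the two inputs in question: in the upstream segment-anything `SamOnnxModel`, `mask_input` appears to be an optional low-resolution mask prompt, and `has_mask_input` a 0/1 flag that switches between that mask's embedding and the learned "no mask" embedding. The sketch below lists the input shapes used by the upstream export script, assuming this repository's `SamOnnxModel` keeps the same interface. The helper names `sam_onnx_input_shapes` and `make_dummy_inputs` are hypothetical, and the defaults (`embed_dim=256`, a 64x64 embedding grid) are the upstream values for 1024x1024 images, so they may differ here.

```python
def sam_onnx_input_shapes(embed_dim=256, embed_size=(64, 64), num_points=1):
    """Shapes of the six SamOnnxModel inputs (assumed from upstream segment-anything)."""
    # The low-resolution mask prompt is 4x the embedding grid in each dimension.
    mask_size = tuple(4 * s for s in embed_size)
    return {
        "image_embeddings": (1, embed_dim, *embed_size),  # encoder output, not the raw image
        "point_coords": (1, num_points, 2),               # (x, y) per point
        "point_labels": (1, num_points),                  # one label per point
        "mask_input": (1, 1, *mask_size),
        "has_mask_input": (1,),                           # single 0/1 flag
        "orig_im_size": (2,),                             # (height, width)
    }


def make_dummy_inputs(shapes, orig_hw=(1200, 1200), use_mask=False):
    """Build dummy tensors for torch.onnx.export; requires PyTorch."""
    import torch  # imported lazily so the shape helper works without torch

    return {
        "image_embeddings": torch.randn(*shapes["image_embeddings"]),
        "point_coords": torch.randn(*shapes["point_coords"]),
        "point_labels": torch.ones(shapes["point_labels"]),
        # With no prior mask, zeros plus has_mask_input = 0 is assumed to be enough.
        "mask_input": (torch.randn(*shapes["mask_input"]) if use_mask
                       else torch.zeros(shapes["mask_input"])),
        "has_mask_input": torch.tensor([1.0 if use_mask else 0.0]),
        "orig_im_size": torch.tensor(orig_hw, dtype=torch.float),
    }


if __name__ == "__main__":
    for name, shape in sam_onnx_input_shapes().items():
        print(name, shape)
```

If this matches the model here, `tuple(make_dummy_inputs(sam_onnx_input_shapes()).values())` could be passed as the `args` of `torch.onnx.export`, since the dictionary keeps the same order as the `inputs` list above. Upstream reads the two sizes from `model.prompt_encoder.embed_dim` and `model.prompt_encoder.image_embedding_size` rather than hard-coding them; whether those attributes exist in this repository is an assumption.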
