
Visualize trained plenoxels #19

Open
FrederikWarburg opened this issue Jan 25, 2023 · 1 comment


FrederikWarburg commented Jan 25, 2023

Hi,

I would like to visualize one of your trained plenoxels. Ideally, I want to simply load a checkpoint and render views from a spherical path around the central object, without having to download CO3D. However, I find this challenging to do with the current code.

I was able to load your model using your on_load_checkpoint, which dequantizes the checkpoint and loads the model. Now I want to render views from it.

I define an intrinsic matrix:

near, far = 0., 1.
ndc_coeffs = (-1., -1.)
image_sizes = (200, 200)
focal = (100., 100.)
intrinsics = np.array(
    [
        [focal[0], 0.0, image_sizes[0]/2, 0.0],
        [0.0, focal[1], image_sizes[1]/2, 0.0],
        [0.0, 0.0, 1.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ]
)

and use your spherical_poses function to get the extrinsics:

cam_trans = np.diag(np.array([-1, -1, 1, 1], dtype=np.float32))
render_poses = spherical_poses(cam_trans)
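
For reference, this is the kind of camera path I expect spherical_poses to produce. The helper below is my own look-at construction (the names lookat_pose and my_spherical_path are mine, not from the repository), which I only use to sanity-check the extrinsics:

def lookat_pose(cam_pos, target=np.zeros(3), up=np.array([0.0, 0.0, 1.0])):
    # camera-to-world pose that looks from cam_pos towards target
    forward = target - cam_pos
    forward = forward / np.linalg.norm(forward)
    right = np.cross(forward, up)
    right = right / np.linalg.norm(right)
    true_up = np.cross(right, forward)
    pose = np.eye(4, dtype=np.float32)
    pose[:3, 0] = right
    pose[:3, 1] = true_up
    pose[:3, 2] = -forward  # camera looks down its local -z axis
    pose[:3, 3] = cam_pos
    return pose

def my_spherical_path(n_poses=30, radius=1.0, elevation_deg=30.0):
    # n_poses cameras on a circle at fixed elevation, all looking at the origin
    phi = np.deg2rad(elevation_deg)
    poses = []
    for theta in np.linspace(0.0, 2.0 * np.pi, n_poses, endpoint=False):
        cam_pos = radius * np.array(
            [np.cos(theta) * np.cos(phi), np.sin(theta) * np.cos(phi), np.sin(phi)],
            dtype=np.float32,
        )
        poses.append(lookat_pose(cam_pos))
    return np.stack(poses)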

I then try to create the rays for the first pose using several of your functions:

extrinsics_idx = render_poses[:1]  # only the first pose for now
N_render = len(extrinsics_idx)
intrinsics_idx = np.stack(
    [intrinsics for _ in range(N_render)]
)
image_sizes_idx = np.stack(
    [image_sizes for _ in range(N_render)]
)

rays_o, rays_d = batchified_get_rays(
    intrinsics_idx, 
    extrinsics_idx, 
    image_sizes_idx,
    True,
)

rays_d = torch.tensor(rays_d, dtype=torch.float32)
rays_o = torch.tensor(rays_o, dtype=torch.float32)
rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)

rays = torch.stack(
    convert_to_ndc(rays_o, rays_d, ndc_coeffs), dim=1
)

rays_o = rays[:,0,:].contiguous()
rays_d = rays[:,1,:].contiguous()

rays_o = rays_o.to("cuda")
rays_d = rays_d.to("cuda")
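
For comparison, this is the pixel-to-ray convention I am assuming (standard pinhole model, camera-to-world extrinsics, camera looking down -z). my_get_rays below is my own reference, not the repository's batchified_get_rays, so a convention mismatch here could already explain a broken render:

def my_get_rays(intrinsic, pose, image_size):
    # reference pinhole ray generation; pose is a 4x4 (or 3x4) camera-to-world matrix
    W, H = image_size
    fx, fy = intrinsic[0, 0], intrinsic[1, 1]
    cx, cy = intrinsic[0, 2], intrinsic[1, 2]
    i, j = np.meshgrid(np.arange(W), np.arange(H), indexing="xy")
    # per-pixel directions in camera space
    dirs = np.stack(
        [(i - cx) / fx, -(j - cy) / fy, -np.ones_like(i, dtype=np.float64)], axis=-1
    )
    # rotate into world space and broadcast the camera origin to every pixel
    rays_d = dirs @ pose[:3, :3].T
    rays_o = np.broadcast_to(pose[:3, 3], rays_d.shape)
    return rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)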

and then try to render:

rays = Rays(rays_o, rays_d)
grid = grid.to(device="cuda")
depth = grid.volume_render_depth(rays, 1e-5)
target = torch.zeros_like(rays_o)

rgb, mask = grid.volume_render_fused(rays, target)

but when I visualize the rendering, it looks like I did something wrong:

depth = depth.reshape(200, 200)
rgb = rgb.reshape(200, 200, 3)

plt.imshow(depth.cpu().numpy())
plt.show()

plt.imshow(rgb.cpu().numpy())
plt.show()
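
Once a single frame looks right, I would like to render the whole spherical path and stack the frames into a turntable GIF, roughly like this (reusing my_get_rays from above and the same render call as before; imageio is just my choice of writer):

import imageio

frames = []
for pose in render_poses:
    ro, rd = my_get_rays(intrinsics, pose, image_sizes)
    ro = torch.tensor(ro, dtype=torch.float32, device="cuda")
    rd = torch.tensor(rd, dtype=torch.float32, device="cuda")
    rd = rd / torch.norm(rd, dim=-1, keepdim=True)
    rgb, _ = grid.volume_render_fused(Rays(ro, rd), torch.zeros_like(ro))
    frame = rgb.reshape(*image_sizes, 3).clamp(0, 1).detach().cpu().numpy()
    frames.append((frame * 255).astype(np.uint8))

imageio.mimsave("turntable.gif", frames)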

Could you please help me? It would be very useful to check that the model loading is correct and to see how good the reconstructions are.


FrederikWarburg commented Jan 25, 2023

For completeness, I'll also share my code for loading the model. I moved some of the functions into a Jupyter notebook to keep the evaluation code disentangled from the training code.

def dequantize_data(data, data_min, data_scale, quant_bit=8, logarithmic_quant=False):

    if quant_bit == 8 or quant_bit == 16:
        # 8/16-bit values are stored directly; just undo the affine quantization
        data_tensor = data.type(torch.FloatTensor) * data_scale + data_min
    elif quant_bit == 4:
        # two 4-bit values are packed per byte: high nibble first, then low nibble
        data_blank = torch.zeros(len(data) * 2, *data.shape[1:], device=data.device)
        data_blank[0::2] = data // 16
        data_blank[1::2] = data % 16
        if torch.all(data_blank[-1] == 0):
            data_blank = data_blank[:-1]
        data_tensor = data_blank.type(torch.FloatTensor) * data_scale + data_min
    elif quant_bit == 2:
        # four 2-bit values are packed per byte, from the most to the least significant bit pair
        data_blank = torch.zeros(len(data) * 4, *data.shape[1:], device=data.device)
        data_blank[0::4] = data // 64
        data_blank[1::4] = data % 64 // 16
        data_blank[2::4] = data % 16 // 4
        data_blank[3::4] = data % 4
        # drop up to four trailing all-zero rows that were added as packing padding
        for _ in range(4):
            if torch.all(data_blank[-1] == 0):
                data_blank = data_blank[:-1]
        data_tensor = data_blank.type(torch.FloatTensor) * data_scale + data_min

    if logarithmic_quant:
        data_tensor = torch.exp(-data_tensor)

    return data_tensor
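
To convince myself that dequantize_data is used correctly, I round-trip a random tensor through the 8-bit path; the quantization step below is my own guess at the inverse, not the repository's training code:

x = torch.rand(1000, 27)

data_min = x.min()
data_scale = (x.max() - x.min()) / 255.0
quantized = ((x - data_min) / data_scale).round().to(torch.uint8)

restored = dequantize_data(quantized, data_min, data_scale, quant_bit=8)
print((restored - x).abs().max())  # should be at most about data_scale / 2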

def load_checkpoint(grid, checkpoint, quantize=True, quantize_density=False):

    state_dict = checkpoint["state_dict"]

    grid.reso_idx = checkpoint["reso_idx"]

    # drop the tensors created by the SparseGrid constructor so they can be
    # replaced with the (dequantized) checkpoint data below
    del grid.basis_data
    del grid.density_data
    del grid.sh_data
    del grid.links

    grid.register_parameter(
        "basis_data", nn.Parameter(state_dict["model.basis_data"])
    )

    if "model.background_data_min" in checkpoint.keys():
        del grid.background_data
        bgd_data = state_dict["model.background_data"]
        if quantize:
            bgd_min = checkpoint["model.background_data_min"]
            bgd_scale = checkpoint["model.background_data_scale"]
            bgd_data = dequantize_data(bgd_data, bgd_min, bgd_scale)

        grid.register_parameter("background_data", nn.Parameter(bgd_data))
        checkpoint["state_dict"]["model.background_data"] = bgd_data

    density_data = state_dict["model.density_data"]
    if quantize_density:
        density_min = checkpoint["model.density_data_min"]
        density_scale = checkpoint["model.density_data_scale"]
        density_data = dequantize_data(density_data, density_min, density_scale)

    grid.register_parameter("density_data", nn.Parameter(density_data))
    checkpoint["state_dict"]["model.density_data"] = density_data

    sh_data = state_dict["model.sh_data"]
    if quantize:
        sh_data_min = checkpoint["model.sh_data_min"]
        sh_data_scale = checkpoint["model.sh_data_scale"]
        sh_data = dequantize_data(sh_data, sh_data_min, sh_data_scale)

    grid.register_parameter("sh_data", nn.Parameter(sh_data))
    checkpoint["state_dict"]["model.sh_data"] = sh_data

    # rebuild the dense link grid from the sparse voxel indices stored in the checkpoint
    reso_list = [[128, 128, 128], [256, 256, 256]]
    reso = reso_list[checkpoint["reso_idx"]]

    links = torch.zeros(reso, dtype=torch.int32) - 1
    links_sparse = state_dict["model.links_idx"]
    # unflatten the flat voxel index into (x, y, z) coordinates
    links_idx = torch.stack(
        [
            links_sparse // (reso[1] * reso[2]),
            links_sparse % (reso[1] * reso[2]) // reso[2],
            links_sparse % reso[2],
        ]
    ).long()
    links[links_idx[0], links_idx[1], links_idx[2]] = torch.arange(
        len(links_idx[0]), dtype=torch.int32
    )
    checkpoint["state_dict"].pop("model.links_idx")
    checkpoint["state_dict"]["model.links"] = links
    grid.register_buffer("links", links)

    # strip the "model." prefix from the state-dict keys before loading into the grid
    state_dict = checkpoint["state_dict"]
    state_dict = {k.replace('model.', ''): v for k, v in state_dict.items()}
    grid.load_state_dict(state_dict)

    return grid
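
As a side check, I verified that the flat-to-(x, y, z) unflattening used for model.links_idx is self-consistent with a small round trip of my own:

reso = [128, 128, 128]
xyz = torch.tensor([[5, 7, 9], [100, 0, 127]])
flat = xyz[:, 0] * reso[1] * reso[2] + xyz[:, 1] * reso[2] + xyz[:, 2]
back = torch.stack(
    [
        flat // (reso[1] * reso[2]),
        flat % (reso[1] * reso[2]) // reso[2],
        flat % reso[2],
    ],
    dim=-1,
)
assert torch.equal(back, xyz)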

Having defined these, I load a checkpoint:

grid = SparseGrid(background_nlayers=28, background_reso = 512)
ckpt = torch.load('../../data/co3d/PeRFception-v1-1/00/plenoxel_co3d_30_1091_3400/last.ckpt', map_location='cpu')
grid = load_checkpoint(grid, ckpt)
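
and do a quick sanity check that the loaded grid looks reasonable before trying to render:

print("links:        ", tuple(grid.links.shape), grid.links.dtype)
print("density_data: ", tuple(grid.density_data.shape))
print("sh_data:      ", tuple(grid.sh_data.shape))
print("occupied voxels:", int((grid.links >= 0).sum()))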
