Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions scratch/test_depth_grad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import taichi as ti
from taichi_3d_gaussian_splatting.GaussianPointCloudRasterisation import GaussianPointCloudRasterisation
import torch
import numpy as np
from taichi_3d_gaussian_splatting.Camera import CameraInfo
from taichi_3d_gaussian_splatting.utils import se3_to_quaternion_and_translation_torch

# Short alias for the rasteriser's nested config class.
# NOTE: the identifier is misspelt ("Conifg"); it is kept as-is because it is a
# module-level name.
RasterConifg = GaussianPointCloudRasterisation.GaussianPointCloudRasterisationConfig


def render(pts, pts_feat, c2w, intrin, HW):
    """Rasterise a Gaussian point cloud as seen from one camera pose.

    A new rasteriser is built on every call, then invoked on an input
    assembled from the arguments.

    Args:
        pts: point positions; cast to float32 and used to size the per-point
            id/mask tensors (the script below passes an (N, 3) CUDA tensor).
        pts_feat: per-point feature tensor; moved to CUDA before use.
        c2w: camera pose matrix (the script below passes ``torch.eye(4)``);
            a leading batch axis is added before it is converted to a
            quaternion and a translation.
        intrin: camera intrinsics matrix; moved to the device of ``pts``.
        HW: two-element sequence ``(height, width)`` of the rendered image.

    Returns:
        Whatever the rasteriser returns. The script below reads index 1 as
        the depth map and index -1 as the accumulated alpha -- presumably the
        tuple returned by the rasterisation ``forward``; confirm against the
        library.
    """
    raster_config = RasterConifg(
        near_plane=0.4,
        far_plane=2000.0,
        depth_to_sort_key_scale=10.0,
        rgb_only=False,
    )
    rasterisation = GaussianPointCloudRasterisation(
        config=raster_config,
    )

    height, width = HW[0], HW[1]
    camera_info = CameraInfo(
        camera_intrinsics=intrin.to(pts.device),
        camera_height=height,
        camera_width=width,
        camera_id=0,  # TODO: does the camera_id value matter here?
    )

    # Split the pose into rotation (quaternion) and translation parts.
    pose_batch = c2w[None]
    q_pointcloud_camera, t_pointcloud_camera = \
        se3_to_quaternion_and_translation_torch(pose_batch)

    # All-zero object ids and an all-zero invalid mask, one entry per point
    # (presumably: single object, every point valid).
    num_points = pts.shape[0]
    object_ids = torch.zeros(num_points, dtype=torch.int32, device=pts.device)
    invalid_mask = torch.zeros(num_points, dtype=torch.int8, device=pts.device)

    gaussian_input = GaussianPointCloudRasterisation.GaussianPointCloudRasterisationInput(
        point_cloud=pts.float(),
        point_cloud_features=pts_feat.cuda(),
        point_object_id=object_ids,
        point_invalid_mask=invalid_mask,
        camera_info=camera_info,
        q_pointcloud_camera=q_pointcloud_camera.cuda().contiguous(),
        t_pointcloud_camera=t_pointcloud_camera.cuda().contiguous(),
        # TODO: check this number; originally it was
        # iteration // self.config.increase_color_max_sh_band_interval
        color_max_sh_band=6,
    )
    return rasterisation(gaussian_input)
def plot3d(*data, fn='/d/del.html'):
    """Write a 3D scatter plot of one or more point sets to an HTML file.

    Args:
        *data: point sets, each indexable as ``pts[:, 0..2]``. Torch tensors
            are detached and copied to host memory as NumPy arrays first.
        fn: path of the HTML file to write.

    Returns:
        The plotly ``Figure`` holding one ``Scatter3d`` trace per point set.
    """
    # plotly is imported lazily so the module can be imported without it.
    import plotly.graph_objs as go
    from plotly.offline import plot

    traces = []
    for point_set in data:
        if torch.is_tensor(point_set):
            point_set = point_set.detach().cpu().numpy()
        xs, ys, zs = point_set[:, 0], point_set[:, 1], point_set[:, 2]
        traces.append(
            go.Scatter3d(x=xs, y=ys, z=zs, marker={'size': 1}, mode='markers')
        )

    # 'aspectmode': 'data' is the scene layout option passed through to plotly.
    layout = {'scene': {'aspectmode': 'data'}}
    fig = go.Figure(traces, layout=layout)

    plot(fig, filename=fn, auto_open=False)
    return fig


if __name__ == '__main__':
    # Imported once up front (the original imported it twice, mid-script).
    import tqdm

    # ------------------------------------------------------------------
    # Experiment 1: optimise point positions from a depth-only loss.
    # Only `pts` is handed to the optimiser; the loss pulls rendered depth
    # towards 3 on every pixel whose rendered depth is positive.
    # ------------------------------------------------------------------
    print('############ test depth grad ##############')
    ti.init(arch=ti.cuda, device_memory_GB=0.1)
    pts = torch.randn(10000, 3, device='cuda')
    pts.requires_grad_()
    # 56 feature columns per point; the exact layout is defined by the
    # rasteriser (columns 0:4 are presumably the rotation quaternion -- confirm).
    pts_feat = torch.zeros(pts.shape[0], 56, device=pts.device)
    pts_feat[:, 0:4] = torch.rand_like(pts_feat[:, 0:4])
    pts_feat[:, 4:7] = torch.randn(pts.shape[0], 3, device=pts.device) * 0.3 - 1  # size
    pts_feat[:, 7] = 0.  # set high alpha
    pts_feat.requires_grad_(True)

    c2w = torch.eye(4, device='cuda')
    HW = [1080 // 64 * 16, 1920 // 64 * 16]
    # torch.tensor with an explicit dtype/device replaces the legacy
    # torch.Tensor(...).cuda() constructor; the values are unchanged.
    intrin = torch.tensor(
        [[100, 0, 200], [0, 100, 200], [0, 0, 1]],
        dtype=torch.float32, device='cuda')
    optimizer = torch.optim.Adam([pts], lr=0.01)
    for ii in tqdm.trange(1000):
        optimizer.zero_grad()
        res = render(pts, pts_feat, c2w, intrin, HW)
        depth = res[1]
        mask = depth > 0
        loss = (depth - 3).abs()[mask].mean()
        loss.backward()
        optimizer.step()
        if ii % 200 == 0:
            plot3d(pts)
            print(loss)

    # ------------------------------------------------------------------
    # Experiment 2: optimise only the alpha feature from a depth loss.
    # Only `pts_feat` is handed to the optimiser, and every gradient column
    # except column 7 (alpha) is zeroed before each step.
    # ------------------------------------------------------------------
    print('############ test alpha grad from depth ##############')
    # Taichi is initialised again for the second experiment, as in the
    # original script.
    ti.init(arch=ti.cuda, device_memory_GB=0.1)
    pts = torch.randn(6000, 3, device='cuda')
    pts[:, 2] += 3
    pts.requires_grad_()
    pts_feat = torch.zeros(pts.shape[0], 56, device=pts.device)
    pts_feat[:, 0:4] = torch.rand_like(pts_feat[:, 0:4])
    pts_feat[:, 4:7] = torch.randn(pts.shape[0], 3, device=pts.device) * 0.3 - 1  # size
    # Note: this is alpha before the sigmoid; it must start small, this is
    # important.
    pts_feat[:, 7] = -5
    pts_feat.requires_grad_(True)

    c2w = torch.eye(4, device='cuda')
    HW = [1080 // 64 // 1 * 16, 1920 // 64 // 1 * 16]
    intrin = torch.tensor(
        [[50, 0, 200], [0, 50, 200], [0, 0, 1]],
        dtype=torch.float32, device='cuda')
    optimizer = torch.optim.Adam([pts_feat], lr=0.01)
    for ii in tqdm.trange(1000):
        optimizer.zero_grad()
        res = render(pts, pts_feat, c2w, intrin, HW)
        mask = res[-1] > 0.5  # region with accumulated alpha > 0.5
        loss = (res[1] - 3).abs()[mask].mean()
        loss.backward()
        # Only keep the gradient of alpha (column 7). The gradient tensor is
        # edited directly instead of through the discouraged `.data` attribute.
        pts_feat.grad[:, :7] = 0
        pts_feat.grad[:, 8:] = 0
        optimizer.step()
        if ii % 200 == 0:
            print(loss.item())
            # Evaluate the sigmoid once and split the points into
            # "opaque" (> 0.5) and "transparent" (<= 0.5) groups.
            alpha = pts_feat[:, 7].sigmoid()
            fig = plot3d(pts[alpha > 0.5], pts[alpha <= 0.5],
                         fn='/d/del.html')

43 changes: 33 additions & 10 deletions taichi_3d_gaussian_splatting/GaussianPointCloudRasterisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,9 @@ def gaussian_point_rasterisation_backward(
rasterized_image_grad: ti.types.ndarray(ti.f32, ndim=3), # (H, W, 3)
enable_depth_grad: ti.template(),
rasterized_depth_grad: ti.types.ndarray(ti.f32, ndim=2), # (H, W)
accumulated_alpha_grad: ti.types.ndarray(ti.f32, ndim=2), # (H, W)
pixel_accumulated_alpha: ti.types.ndarray(ti.f32, ndim=2), # (H, W)
rasterized_depth: ti.types.ndarray(ti.f32, ndim=2), # (H, W)
# (H, W)
pixel_offset_of_last_effective_point: ti.types.ndarray(ti.i32, ndim=2),
grad_pointcloud: ti.types.ndarray(ti.f32, ndim=2), # (N, 3)
Expand Down Expand Up @@ -522,14 +524,18 @@ def gaussian_point_rasterisation_backward(
pixel_u = tile_u * 16 + pixel_offset_u_in_tile
pixel_v = tile_v * 16 + pixel_offset_v_in_tile
last_effective_point = pixel_offset_of_last_effective_point[pixel_v, pixel_u]
org_accumulated_alpha: ti.f32 = pixel_accumulated_alpha[pixel_v, pixel_u]
accumulated_alpha: ti.f32 = pixel_accumulated_alpha[pixel_v, pixel_u]
accumulated_alpha_grad_value: ti.f32 = accumulated_alpha_grad[pixel_v, pixel_u]
d_pixel: ti.f32 = rasterized_depth[pixel_v, pixel_u]
T_i = 1.0 - accumulated_alpha # T_i = \prod_{j=1}^{i-1} (1 - a_j)
# \frac{dC}{da_i} = c_i T(i) - \frac{1}{1 - a_i} \sum_{j=i+1}^{n} c_j a_j T(j)
# let w_i = \sum_{j=i+1}^{n} c_j a_j T(j)
# we have w_n = 0, w_{i-1} = w_i + c_i a_i T(i)
# \frac{dC}{da_i} = c_i T(i) - \frac{1}{1 - a_i} w_i
w_i = ti.math.vec3(0.0, 0.0, 0.0)
depth_w_i = 0.0
acc_alpha_w_i = 0.0

pixel_rgb_grad = ti.math.vec3(
rasterized_image_grad[pixel_v, pixel_u, 0], rasterized_image_grad[pixel_v, pixel_u, 1], rasterized_image_grad[pixel_v, pixel_u, 2])
Expand Down Expand Up @@ -602,8 +608,9 @@ def gaussian_point_rasterisation_backward(
tile_point_color[1, idx_point_offset_with_sort_key_in_block],
tile_point_color[2, idx_point_offset_with_sort_key_in_block]])

T_i = T_i / (1. - alpha)
accumulated_alpha = 1. - T_i
# accumulated_alpha_i = 1. - T_i #alpha after passing current point
T_i = T_i / (1. - alpha) # Transmittance before passing current point
accumulated_alpha = 1. - T_i #accumulated alha before passing current point

# print(
# f"({pixel_v}, {pixel_u}, {point_offset}, {point_offset - start_offset}), accumulated_alpha: {accumulated_alpha}")
Expand All @@ -619,10 +626,18 @@ def gaussian_point_rasterisation_backward(
alpha_grad: ti.f32 = alpha_grad_from_rgb.sum()
if enable_depth_grad:
depth_i = tile_point_depth[idx_point_offset_with_sort_key_in_block]
alpha_grad_from_depth = (depth_i * T_i - depth_w_i / (1. - alpha)) \
* pixel_depth_grad
d_depth_d_alpha = (
T_i * (depth_i - d_pixel)
+ 1.0 / (1.0 - alpha) * (depth_w_i - acc_alpha_w_i * d_pixel)
) / (org_accumulated_alpha + 0.00001)
alpha_grad_from_depth = d_depth_d_alpha * pixel_depth_grad
alpha_grad_from_accumulated_alpha = (
T_i - 1.0 / (1.0 - alpha) * acc_alpha_w_i
) * accumulated_alpha_grad_value
depth_w_i += depth_i * alpha * T_i
acc_alpha_w_i += alpha * T_i
alpha_grad += alpha_grad_from_depth
alpha_grad+= alpha_grad_from_accumulated_alpha

point_alpha_after_activation_grad = alpha_grad * gaussian_alpha
gaussian_point_3d_alpha_grad = point_alpha_after_activation_grad * \
Expand Down Expand Up @@ -650,7 +665,7 @@ def gaussian_point_rasterisation_backward(
ti.atomic_add(in_camera_grad_uv_cov_buffer[point_offset, 2],
point_uv_cov_grad[1, 1])
if enable_depth_grad:
point_depth_grad = alpha * T_i * pixel_depth_grad
point_depth_grad = alpha * T_i * pixel_depth_grad / (org_accumulated_alpha+0.00001)
ti.atomic_add(in_camera_grad_depth_buffer[point_offset], point_depth_grad)

for i in ti.static(range(3)):
Expand Down Expand Up @@ -763,7 +778,7 @@ class GaussianPointCloudRasterisationConfig(YAMLWizard):
grad_s_factor = 0.5
grad_q_factor = 1.
grad_alpha_factor = 20.
enable_depth_grad = False
enable_depth_grad = True

@dataclass
class GaussianPointCloudRasterisationInput:
Expand Down Expand Up @@ -956,6 +971,7 @@ def forward(ctx,

# Step 5: render
if point_in_camera_sort_key.shape[0] > 0:
# import ipdb;ipdb.set_trace()
gaussian_point_rasterisation(
camera_height=camera_info.camera_height,
camera_width=camera_info.camera_width,
Expand All @@ -982,6 +998,7 @@ def forward(ctx,
tile_points_start,
tile_points_end,
pixel_accumulated_alpha,
rasterized_depth,
pixel_offset_of_last_effective_point,
num_overlap_tiles,
point_object_id,
Expand All @@ -998,10 +1015,12 @@ def forward(ctx,
ctx.camera_info = camera_info
ctx.color_max_sh_band = color_max_sh_band
# rasterized_image.requires_grad_(True)
return rasterized_image, rasterized_depth, pixel_valid_point_count
return rasterized_image, rasterized_depth, pixel_valid_point_count, pixel_accumulated_alpha


@staticmethod
def backward(ctx, grad_rasterized_image, grad_rasterized_depth, grad_pixel_valid_point_count):
def backward(ctx, grad_rasterized_image, grad_rasterized_depth,
grad_pixel_valid_point_count, grad_pixel_accumulated_alpha):
grad_pointcloud = grad_pointcloud_features = grad_q_pointcloud_camera = grad_t_pointcloud_camera = None
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
pointcloud, \
Expand All @@ -1011,6 +1030,7 @@ def backward(ctx, grad_rasterized_image, grad_rasterized_depth, grad_pixel_valid
tile_points_start, \
tile_points_end, \
pixel_accumulated_alpha, \
rasterized_depth, \
pixel_offset_of_last_effective_point, \
num_overlap_tiles, \
point_object_id, \
Expand Down Expand Up @@ -1061,7 +1081,7 @@ def backward(ctx, grad_rasterized_image, grad_rasterized_depth, grad_pixel_valid
point_object_id=point_object_id,
q_camera_pointcloud=q_camera_pointcloud,
t_camera_pointcloud=t_camera_pointcloud,
t_pointcloud_camera=t_pointcloud_camera,
t_pointcloud_camera=t_pointcloud_camera.contiguous(),
pointcloud=pointcloud,
pointcloud_features=pointcloud_features,
tile_points_start=tile_points_start,
Expand All @@ -1071,7 +1091,9 @@ def backward(ctx, grad_rasterized_image, grad_rasterized_depth, grad_pixel_valid
rasterized_image_grad=grad_rasterized_image,
enable_depth_grad=enable_depth_grad,
rasterized_depth_grad=grad_rasterized_depth,
accumulated_alpha_grad=grad_pixel_accumulated_alpha,
pixel_accumulated_alpha=pixel_accumulated_alpha,
rasterized_depth=rasterized_depth,
pixel_offset_of_last_effective_point=pixel_offset_of_last_effective_point,
grad_pointcloud=grad_pointcloud,
grad_pointcloud_features=grad_pointcloud_features,
Expand Down Expand Up @@ -1111,9 +1133,10 @@ def backward(ctx, grad_rasterized_image, grad_rasterized_depth, grad_pixel_valid


if backward_valid_point_hook is not None:
point_id_in_camera_list=point_id_in_camera_list.contiguous().long()
backward_valid_point_hook_input = GaussianPointCloudRasterisation.BackwardValidPointHookInput(
point_id_in_camera_list=point_id_in_camera_list,
grad_point_in_camera=grad_pointcloud[point_id_in_camera_list],
grad_point_in_camera=grad_pointcloud[point_id_in_camera_list.long()],
grad_pointfeatures_in_camera=grad_pointcloud_features[
point_id_in_camera_list],
grad_viewspace=grad_viewspace[point_id_in_camera_list],
Expand Down