@@ -13,6 +13,7 @@ def rasterize_gaussians(
1313 rotations ,
1414 cov3Ds_precomp ,
1515 raster_settings ,
16+ rasterizer_state
1617):
1718 return _RasterizeGaussians .apply (
1819 means3D ,
@@ -24,6 +25,7 @@ def rasterize_gaussians(
2425 rotations ,
2526 cov3Ds_precomp ,
2627 raster_settings ,
28+ rasterizer_state
2729 )
2830
2931class _RasterizeGaussians (torch .autograd .Function ):
@@ -39,10 +41,9 @@ def forward(
3941 rotations ,
4042 cov3Ds_precomp ,
4143 raster_settings ,
44+ rasterizer_state
4245 ):
4346
44- rasterizer_state = _C .create_rasterizer_state ()
45-
4647 # Restructure arguments the way that the C++ lib expects them
4748 args = (
4849 raster_settings .bg ,
@@ -115,10 +116,9 @@ def backward(ctx, grad_out_color, _):
115116 grad_rotations ,
116117 grad_cov3Ds_precomp ,
117118 None ,
119+ None ,
118120 )
119121
120- _C .delete_rasterizer_state (rasterizer_state )
121-
122122 return grads
123123
124124class GaussianRasterizationSettings (NamedTuple ):
@@ -134,10 +134,17 @@ class GaussianRasterizationSettings(NamedTuple):
134134 campos : torch .Tensor
135135 prefiltered : bool
136136
137+ def createRasterizerState ():
138+ return _C .create_rasterizer_state ()
139+
140+ def deleteRasterizerState (state ):
141+ return _C .delete_rasterize_state (state )
142+
137143class GaussianRasterizer (nn .Module ):
138- def __init__ (self , raster_settings ):
144+ def __init__ (self , raster_settings , rasterizer_state ):
139145 super ().__init__ ()
140146 self .raster_settings = raster_settings
147+ self .rasterizer_state = rasterizer_state
141148
142149 def markVisible (self , positions ):
143150 # Mark visible points (based on frustum culling for camera) with a boolean
@@ -151,8 +158,8 @@ def markVisible(self, positions):
151158 return visible
152159
153160 def forward (self , means3D , means2D , opacities , shs = None , colors_precomp = None , scales = None , rotations = None , cov3D_precomp = None ):
154-
155161 raster_settings = self .raster_settings
162+ rasterize_state = self .rasterizer_state
156163
157164 if (shs is None and colors_precomp is None ) or (shs is not None and colors_precomp is not None ):
158165 raise Exception ('Please provide excatly one of either SHs or precomputed colors!' )
@@ -183,5 +190,6 @@ def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None
183190 rotations ,
184191 cov3D_precomp ,
185192 raster_settings ,
193+ rasterize_state
186194 )
187195
0 commit comments