diff --git a/README.md b/README.md
index 5ec1ef6..ed9064a 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,13 @@ We conduct a comprehensive analysis for the problem of existing spiking neuron m
 Based on this, we propose a bounded spiking neuron to build the discontinuous density field.
 
 ## Usage
+### Download Data
+
+Download data for two example datasets, `lego` and `fern`:
+```
+bash download_example_data.sh
+```
+
 #### Data Convention
 
 The Blender data is organized as follows:
@@ -46,5 +53,7 @@ pip install -r requirements.txt
 - **Training Blender**
 
 ```shell
-python nerf_vth2.py --config ./config/xx.txt
+python nerf_vth2.py --config ./configs/{DATASET}.txt
 ```
+
+Replace `{DATASET}` with one of the provided configs: `trex`, `horns`, `flower`, `fortress`, `lego`, etc.
diff --git a/download_example_data.sh b/download_example_data.sh
new file mode 100644
index 0000000..1b552af
--- /dev/null
+++ b/download_example_data.sh
@@ -0,0 +1,6 @@
+wget http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz
+mkdir -p data
+cd data
+wget http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/nerf_example_data.zip
+unzip nerf_example_data.zip
+cd ..
diff --git a/load_LINEMOD.py b/load_LINEMOD.py
new file mode 100644
index 0000000..388fdbb
--- /dev/null
+++ b/load_LINEMOD.py
@@ -0,0 +1,95 @@
+import os
+import torch
+import numpy as np
+import imageio
+import json
+import torch.nn.functional as F
+import cv2
+
+
+trans_t = lambda t : torch.Tensor([
+    [1,0,0,0],
+    [0,1,0,0],
+    [0,0,1,t],
+    [0,0,0,1]]).float()
+
+rot_phi = lambda phi : torch.Tensor([
+    [1,0,0,0],
+    [0,np.cos(phi),-np.sin(phi),0],
+    [0,np.sin(phi), np.cos(phi),0],
+    [0,0,0,1]]).float()
+
+rot_theta = lambda th : torch.Tensor([
+    [np.cos(th),0,-np.sin(th),0],
+    [0,1,0,0],
+    [np.sin(th),0, np.cos(th),0],
+    [0,0,0,1]]).float()
+
+
+def pose_spherical(theta, phi, radius):
+    c2w = trans_t(radius)
+    c2w = rot_phi(phi/180.*np.pi) @ c2w
+    c2w = rot_theta(theta/180.*np.pi) @ c2w
+    c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
+    return c2w
+
+
+def load_LINEMOD_data(basedir, half_res=False, testskip=1):
+    splits = ['train', 'val', 'test']
+    metas = {}
+    for s in splits:
+        with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
+            metas[s] = json.load(fp)
+
+    all_imgs = []
+    all_poses = []
+    counts = [0]
+    for s in splits:
+        meta = metas[s]
+        imgs = []
+        poses = []
+        if s=='train' or testskip==0:
+            skip = 1
+        else:
+            skip = testskip
+
+        for idx_test, frame in enumerate(meta['frames'][::skip]):
+            fname = frame['file_path']
+            if s == 'test':
+                print(f"{idx_test}th test frame: {fname}")
+            imgs.append(imageio.imread(fname))
+            poses.append(np.array(frame['transform_matrix']))
+        imgs = (np.array(imgs) / 255.).astype(np.float32)  # keep all 4 channels (RGBA)
+        poses = np.array(poses).astype(np.float32)
+        counts.append(counts[-1] + imgs.shape[0])
+        all_imgs.append(imgs)
+        all_poses.append(poses)
+
+    i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
+
+    imgs = np.concatenate(all_imgs, 0)
+    poses = np.concatenate(all_poses, 0)
+
+    H, W = imgs[0].shape[:2]
+    focal = float(meta['frames'][0]['intrinsic_matrix'][0][0])
+    K = meta['frames'][0]['intrinsic_matrix']
+    print(f"Focal: {focal}")
+
+    render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)
+
+    if half_res:
+        H = H//2
+        W = W//2
+        focal = focal/2.
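+        # NOTE: only H, W and focal are halved here; the raw intrinsic matrix K
+        # returned below keeps its full-resolution values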
+
+        imgs_half_res = np.zeros((imgs.shape[0], H, W, 3))
+        for i, img in enumerate(imgs):
+            imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
+        imgs = imgs_half_res
+        # imgs = tf.image.resize_area(imgs, [400, 400]).numpy()
+
+    near = np.floor(min(metas['train']['near'], metas['test']['near']))
+    far = np.ceil(max(metas['train']['far'], metas['test']['far']))
+    return imgs, poses, render_poses, [H, W, focal], K, i_split, near, far
+
diff --git a/load_deepvoxels.py b/load_deepvoxels.py
new file mode 100644
index 0000000..deb2a9c
--- /dev/null
+++ b/load_deepvoxels.py
@@ -0,0 +1,110 @@
+import os
+import numpy as np
+import imageio
+
+
+def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8):
+
+
+    def parse_intrinsics(filepath, trgt_sidelength, invert_y=False):
+        # Get camera intrinsics
+        with open(filepath, 'r') as file:
+            f, cx, cy = list(map(float, file.readline().split()))[:3]
+            grid_barycenter = np.array(list(map(float, file.readline().split())))
+            near_plane = float(file.readline())
+            scale = float(file.readline())
+            height, width = map(float, file.readline().split())
+
+            try:
+                world2cam_poses = int(file.readline())
+            except ValueError:
+                world2cam_poses = None
+
+        if world2cam_poses is None:
+            world2cam_poses = False
+
+        world2cam_poses = bool(world2cam_poses)
+
+        print(cx,cy,f,height,width)
+
+        cx = cx / width * trgt_sidelength
+        cy = cy / height * trgt_sidelength
+        f = trgt_sidelength / height * f
+
+        fx = f
+        if invert_y:
+            fy = -f
+        else:
+            fy = f
+
+        # Build the intrinsic matrices
+        full_intrinsic = np.array([[fx, 0., cx, 0.],
+                                   [0., fy, cy, 0],
+                                   [0., 0, 1, 0],
+                                   [0, 0, 0, 1]])
+
+        return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses
+
+
+    def load_pose(filename):
+        assert os.path.isfile(filename)
+        nums = open(filename).read().split()
+        return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32)
+
+
+    H = 512
+    W = 512
+    deepvoxels_base = '{}/train/{}/'.format(basedir, scene)
+
+    full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H)
+    print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses)
+    focal = full_intrinsic[0,0]
+    print(H, W, focal)
+
+
+    def dir2poses(posedir):
+        poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0)
+        transf = np.array([
+            [1,0,0,0],
+            [0,-1,0,0],
+            [0,0,-1,0],
+            [0,0,0,1.],
+        ])
+        poses = poses @ transf
+        poses = poses[:,:3,:4].astype(np.float32)
+        return poses
+
+    posedir = os.path.join(deepvoxels_base, 'pose')
+    poses = dir2poses(posedir)
+    testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene))
+    testposes = testposes[::testskip]
+    valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene))
+    valposes = valposes[::testskip]
+
+    imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')]
+    imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. for f in imgfiles], 0).astype(np.float32)
+
+
+    testimgd = '{}/test/{}/rgb'.format(basedir, scene)
+    imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')]
+    testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32)
+
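+    # the validation split below mirrors the test-split handling, including the testskip stride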
+    valimgd = '{}/validation/{}/rgb'.format(basedir, scene)
+    imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')]
+    valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32)
+
+    all_imgs = [imgs, valimgs, testimgs]
+    counts = [0] + [x.shape[0] for x in all_imgs]
+    counts = np.cumsum(counts)
+    i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)]
+
+    imgs = np.concatenate(all_imgs, 0)
+    poses = np.concatenate([poses, valposes, testposes], 0)
+
+    render_poses = testposes
+
+    print(poses.shape, imgs.shape)
+
+    return imgs, poses, render_poses, [H,W,focal], i_split
diff --git a/load_llff.py b/load_llff.py
new file mode 100644
index 0000000..98b7916
--- /dev/null
+++ b/load_llff.py
@@ -0,0 +1,319 @@
+import numpy as np
+import os, imageio
+
+
+########## Slightly modified version of LLFF data loading code
+##########  see https://github.com/Fyusion/LLFF for original
+
+def _minify(basedir, factors=[], resolutions=[]):
+    needtoload = False
+    for r in factors:
+        imgdir = os.path.join(basedir, 'images_{}'.format(r))
+        if not os.path.exists(imgdir):
+            needtoload = True
+    for r in resolutions:
+        imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0]))
+        if not os.path.exists(imgdir):
+            needtoload = True
+    if not needtoload:
+        return
+
+    from shutil import copy
+    from subprocess import check_output
+
+    imgdir = os.path.join(basedir, 'images')
+    imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))]
+    imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])]
+    imgdir_orig = imgdir
+
+    wd = os.getcwd()
+
+    for r in factors + resolutions:
+        if isinstance(r, int):
+            name = 'images_{}'.format(r)
+            resizearg = '{}%'.format(100./r)
+        else:
+            name = 'images_{}x{}'.format(r[1], r[0])
+            resizearg = '{}x{}'.format(r[1], r[0])
+        imgdir = os.path.join(basedir, name)
+        if os.path.exists(imgdir):
+            continue
+
+        print('Minifying', r, basedir)
+
+        os.makedirs(imgdir)
+        check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True)
+
+        ext = imgs[0].split('.')[-1]
+        args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)])
+        print(args)
+        os.chdir(imgdir)
+        check_output(args, shell=True)
+        os.chdir(wd)
+
+        if ext != 'png':
+            check_output('rm {}/*.{}'.format(imgdir, ext), shell=True)
+            print('Removed duplicates')
+        print('Done')
+
+
+
+
+def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True):
+
+    poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy'))
+    poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0])
+    bds = poses_arr[:, -2:].transpose([1,0])
+
+    img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \
+            if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0]
+    sh = imageio.imread(img0).shape
+
+    sfx = ''
+
+    if factor is not None:
+        sfx = '_{}'.format(factor)
+        _minify(basedir, factors=[factor])
+    elif height is not None:
+        factor = sh[0] / float(height)
+        width = int(sh[1] / factor)
+        _minify(basedir, resolutions=[[height, width]])
+        sfx = '_{}x{}'.format(width, height)
+    elif width is not None:
+        factor = sh[1] / float(width)
+        height = int(sh[0] / factor)
+        _minify(basedir, resolutions=[[height, width]])
+        sfx = '_{}x{}'.format(width, height)
+    else:
+        factor = 1
+
+    imgdir = os.path.join(basedir, 'images' + sfx)
+    if not os.path.exists(imgdir):
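+        # _minify above should have created this directory; abort if it is missing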
+        print(imgdir, 'does not exist, returning')
+        return
+
+    imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')]
+    if poses.shape[-1] != len(imgfiles):
+        print('Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]))
+        return
+
+    sh = imageio.imread(imgfiles[0]).shape
+    poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
+    poses[2, 4, :] = poses[2, 4, :] * 1./factor
+
+    if not load_imgs:
+        return poses, bds
+
+    def imread(f):
+        if f.endswith('png'):
+            return imageio.imread(f, ignoregamma=True)
+        else:
+            return imageio.imread(f)
+
+    imgs = [imread(f)[...,:3]/255. for f in imgfiles]
+    imgs = np.stack(imgs, -1)
+
+    print('Loaded image data', imgs.shape, poses[:,-1,0])
+    return poses, bds, imgs
+
+
+
+
+def normalize(x):
+    return x / np.linalg.norm(x)
+
+def viewmatrix(z, up, pos):
+    vec2 = normalize(z)
+    vec1_avg = up
+    vec0 = normalize(np.cross(vec1_avg, vec2))
+    vec1 = normalize(np.cross(vec2, vec0))
+    m = np.stack([vec0, vec1, vec2, pos], 1)
+    return m
+
+def ptstocam(pts, c2w):
+    tt = np.matmul(c2w[:3,:3].T, (pts-c2w[:3,3])[...,np.newaxis])[...,0]
+    return tt
+
+def poses_avg(poses):
+
+    hwf = poses[0, :3, -1:]
+
+    center = poses[:, :3, 3].mean(0)
+    vec2 = normalize(poses[:, :3, 2].sum(0))
+    up = poses[:, :3, 1].sum(0)
+    c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
+
+    return c2w
+
+
+
+def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
+    render_poses = []
+    rads = np.array(list(rads) + [1.])
+    hwf = c2w[:,4:5]
+
+    for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:
+        c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads)
+        z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.])))
+        render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
+    return render_poses
+
+
+
+def recenter_poses(poses):
+
+    poses_ = poses+0
+    bottom = np.reshape([0,0,0,1.], [1,4])
+    c2w = poses_avg(poses)
+    c2w = np.concatenate([c2w[:3,:4], bottom], -2)
+    bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1])
+    poses = np.concatenate([poses[:,:3,:4], bottom], -2)
+
+    poses = np.linalg.inv(c2w) @ poses
+    poses_[:,:3,:4] = poses[:,:3,:4]
+    poses = poses_
+    return poses
+
+
+#####################
+
+
+def spherify_poses(poses, bds):
+
+    p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1)
+
+    rays_d = poses[:,:3,2:3]
+    rays_o = poses[:,:3,3:4]
+
+    def min_line_dist(rays_o, rays_d):
+        A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1])
+        b_i = -A_i @ rays_o
+        pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0))
+        return pt_mindist
+
+    pt_mindist = min_line_dist(rays_o, rays_d)
+
+    center = pt_mindist
+    up = (poses[:,:3,3] - center).mean(0)
+
+    vec0 = normalize(up)
+    vec1 = normalize(np.cross([.1,.2,.3], vec0))
+    vec2 = normalize(np.cross(vec0, vec1))
+    pos = center
+    c2w = np.stack([vec1, vec2, vec0, pos], 1)
+
+    poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4])
+
+    rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1)))
+
+    sc = 1./rad
+    poses_reset[:,:3,3] *= sc
+    bds *= sc
+    rad *= sc
+
+    centroid = np.mean(poses_reset[:,:3,3], 0)
+    zh = centroid[2]
+    radcircle = np.sqrt(rad**2-zh**2)
+    new_poses = []
+
+    for th in np.linspace(0.,2.*np.pi, 120):
+
+        camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
+        up = np.array([0,0,-1.])
+
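+        # build an orthonormal camera frame from the view direction and the up vector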
+        vec2 = normalize(camorigin)
+        vec0 = normalize(np.cross(vec2, up))
+        vec1 = normalize(np.cross(vec2, vec0))
+        pos = camorigin
+        p = np.stack([vec0, vec1, vec2, pos], 1)
+
+        new_poses.append(p)
+
+    new_poses = np.stack(new_poses, 0)
+
+    new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1)
+    poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1)
+
+    return poses_reset, new_poses, bds
+
+
+def load_llff_data(basedir, factor=8, recenter=True, bd_factor=.75, spherify=False, path_zflat=False):
+
+    poses, bds, imgs = _load_data(basedir, factor=factor)  # factor=8 downsamples original imgs by 8x
+    print('Loaded', basedir, bds.min(), bds.max())
+
+    # Correct rotation matrix ordering and move variable dim to axis 0
+    poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
+    poses = np.moveaxis(poses, -1, 0).astype(np.float32)
+    imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
+    images = imgs
+    bds = np.moveaxis(bds, -1, 0).astype(np.float32)
+
+    # Rescale if bd_factor is provided
+    sc = 1. if bd_factor is None else 1./(bds.min() * bd_factor)
+    poses[:,:3,3] *= sc
+    bds *= sc
+
+    if recenter:
+        poses = recenter_poses(poses)
+
+    if spherify:
+        poses, render_poses, bds = spherify_poses(poses, bds)
+
+    else:
+
+        c2w = poses_avg(poses)
+        print('recentered', c2w.shape)
+        print(c2w[:3,:4])
+
+        ## Get spiral
+        # Get average pose
+        up = normalize(poses[:, :3, 1].sum(0))
+
+        # Find a reasonable "focus depth" for this dataset
+        close_depth, inf_depth = bds.min()*.9, bds.max()*5.
+        dt = .75
+        mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth))
+        focal = mean_dz
+
+        # Get radii for spiral path
+        shrink_factor = .8
+        zdelta = close_depth * .2
+        tt = poses[:,:3,3]  # ptstocam(poses[:3,3,:].T, c2w).T
+        rads = np.percentile(np.abs(tt), 90, 0)
+        c2w_path = c2w
+        N_views = 120
+        N_rots = 2
+        if path_zflat:
+            # zloc = np.percentile(tt, 10, 0)[2]
+            zloc = -close_depth * .1
+            c2w_path[:3,3] = c2w_path[:3,3] + zloc * c2w_path[:3,2]
+            rads[2] = 0.
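+            # the path is now planar, so one rotation with half the views suffices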
+            N_rots = 1
+            N_views/=2
+
+        # Generate poses for spiral path
+        render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views)
+
+    render_poses = np.array(render_poses).astype(np.float32)
+
+    c2w = poses_avg(poses)
+    print('Data:')
+    print(poses.shape, images.shape, bds.shape)
+
+    dists = np.sum(np.square(c2w[:3,3] - poses[:,:3,3]), -1)
+    i_test = np.argmin(dists)
+    print('HOLDOUT view is', i_test)
+
+    images = images.astype(np.float32)
+    poses = poses.astype(np.float32)
+
+    return images, poses, bds, render_poses, i_test
+
diff --git a/nerf_vth2.py b/nerf_vth2.py
index 7a5ce0a..397d215 100644
--- a/nerf_vth2.py
+++ b/nerf_vth2.py
@@ -1,21 +1,15 @@
-import math
 import os, sys
 import numpy as np
 import imageio
-import json
 import random
 import time
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils.data import Dataset, DataLoader
-from tqdm import tqdm, trange
-from spikingjelly.clock_driven import ann2snn
-import matplotlib.pyplot as plt
-import subprocess
+from torch.utils.data import Dataset
+from tqdm import tqdm
 from run_nerf_helpers_snn import *
 from datetime import datetime
-from dataset import Real_Dataset
 from load_llff import load_llff_data
 from load_deepvoxels import load_dv_data
 from load_blender import load_blender_data
@@ -278,7 +272,7 @@ def create_nerf(args):
         model.init()
         model = nn.DataParallel(model).to(device)
-    grad_vars = list(model.named_parameters())
+    grad_vars = list(model.parameters())
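+    # parameters() (not named_parameters()), so the optimizer receives Parameter
+    # objects rather than (name, tensor) pairs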
 
     model_fine = None
     if args.N_importance > 0:
@@ -288,7 +282,7 @@ def create_nerf(args):
             model_fine.init()
             model_fine = nn.DataParallel(model_fine).to(device)
-        grad_vars += list(model_fine.named_parameters())
+        grad_vars += list(model_fine.parameters())
 
     network_query_fn = lambda inputs, viewdirs, network_fn: run_network(inputs, viewdirs, network_fn,
                                                                         embed_fn=embed_fn,
diff --git a/run_nerf_helpers_snn.py b/run_nerf_helpers_snn.py
index 855598a..98cc0ed 100644
--- a/run_nerf_helpers_snn.py
+++ b/run_nerf_helpers_snn.py
@@ -2,81 +2,11 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from spikingjelly.clock_driven.ann2snn import modules
-from spikingjelly.clock_driven import neuron
-from spikingjelly.clock_driven import functional, surrogate
-from spikingjelly.clock_driven.neuron import BaseNode, IFNode, LIFNode
 
 # Misc
 img2mse = lambda x, y: torch.mean((x - y) ** 2)
 mse2psnr = lambda x: -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
 to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)
 
-class Embedder2:
-    def __init__(self, **kwargs):
-        self.kwargs = kwargs
-        self.create_embedding_fn()
-
-    def create_embedding_fn(self):
-        embed_fns = []
-        d = self.kwargs['input_dims']
-        out_dim = 0
-        if self.kwargs['include_input']:
-            embed_fns.append(lambda x: x)
-            out_dim += d
-
-        max_freq = self.kwargs['max_freq_log2']
-        N_freqs = self.kwargs['num_freqs']
-
-        if self.kwargs['log_sampling']:
-            freq_bands = 2.**torch.linspace(0., max_freq, N_freqs)
-        else:
-            freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
-
-        for freq in freq_bands:
-            for p_fn in self.kwargs['periodic_fns']:
-                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
-                out_dim += d
-
-        self.embed_fns = embed_fns
-        self.out_dim = out_dim
-
-    def embed(self, inputs):
-        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
-
-def get_embedder2(multires, input_dims=3):
-    embed_kwargs = {
-        'include_input': True,
-        'input_dims': input_dims,
-        'max_freq_log2': multires-1,
-        'num_freqs': multires,
-        'log_sampling': True,
-        'periodic_fns': [torch.sin, torch.cos],
-    }
-
-    embedder_obj = Embedder2(**embed_kwargs)
-    def embed(x, eo=embedder_obj): return eo.embed(x)
-    return embed, embedder_obj.out_dim
-
-class EdgePreservingSmoothnessLoss(nn.Module):
-    def __init__(self, opt=0):
-        super().__init__()
-        # self.opt = opt
-        self.patch_size = 4
-        self.gamma = 0.1
-        self.loss = lambda x: torch.mean(torch.abs(x))
-        self.bilateral_filter = lambda x: torch.exp(-torch.abs(x).sum(-1) / self.gamma)
-
-    def forward(self, inputs, weights):
-        w1 = self.bilateral_filter(weights[:,:,:-1] - weights[:,:,1:])
-        w2 = self.bilateral_filter(weights[:,:-1,:] - weights[:,1:,:])
-        w3 = self.bilateral_filter(weights[:,:-1,:-1] - weights[:,1:,1:])
-        w4 = self.bilateral_filter(weights[:,1:,:-1] - weights[:,:-1,1:])
-
-        L1 = self.loss(w1 * (inputs[:,:,:-1] - inputs[:,:,1:]))
-        L2 = self.loss(w2 * (inputs[:,:-1,:] - inputs[:,1:,:]))
-        L3 = self.loss(w3 * (inputs[:,:-1,:-1] - inputs[:,1:,1:]))
-        L4 = self.loss(w4 * (inputs[:,1:,:-1] - inputs[:,:-1,1:]))
-        return (L1 + L2 + L3 + L4) / 4
 
 # Positional encoding (section 5.1)
 class Embedder:
     def __init__(self, **kwargs):
@@ -131,23 +61,6 @@ def get_embedder(multires, i=0):
     return embed, embedder_obj.out_dim
 
 
-def get_embedder_ex(multires, input_dims, i=0):
-    if i == -1:
-        return nn.Identity(), 3
-
-    embed_kwargs = {
-        'include_input': True,
-        'input_dims': input_dims,
-        'max_freq_log2': multires - 1,
-        'num_freqs': multires,
-        'log_sampling': True,
-        'periodic_fns': [torch.sin, torch.cos],
-    }
-
-    embedder_obj = Embedder(**embed_kwargs)
-    embed = lambda x, eo=embedder_obj: eo.embed(x)
-    return embed, embedder_obj.out_dim
-
 # Model
 class NeRF(nn.Module):
     def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
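
As a quick end-to-end check of the loaders added in this diff, a minimal sketch (a sketch only — the `fern` path assumes the LLFF example data unpacked by `download_example_data.sh` into `./data/nerf_llff_data/`):

```python
# Smoke test for the LLFF loader added in this diff.
# Assumes `bash download_example_data.sh` has populated ./data/nerf_llff_data/fern.
from load_llff import load_llff_data

images, poses, bds, render_poses, i_test = load_llff_data(
    './data/nerf_llff_data/fern', factor=8, recenter=True, bd_factor=.75)

print('images:', images.shape)    # (N, H, W, 3), float32 in [0, 1]
print('poses:', poses.shape)      # (N, 3, 5): rotation | translation | [H, W, focal] column
print('bounds:', bds.shape)       # (N, 2): near/far depth per view
print('holdout view:', i_test)    # index of the view closest to the average pose
```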