160 changes: 111 additions & 49 deletions bae/autograd/function.py
@@ -2,13 +2,83 @@
import numpy as np
import pypose as pp

WHITELISTED_MAPS = (torch._C.TensorBase.__add__,
torch._C.TensorBase.__sub__,
torch._C.TensorBase.__mul__,
torch._C.TensorBase.__div__,
torch._C.TensorBase.add,
torch._C.TensorBase.sub,
torch._C.TensorBase.mul,)
WHITELISTED_MAPS = tuple(
func for func in (
torch._C.TensorBase.__add__,
torch._C.TensorBase.__sub__,
torch._C.TensorBase.__mul__,
getattr(torch._C.TensorBase, "__div__", None),
torch._C.TensorBase.add,
torch._C.TensorBase.sub,
torch._C.TensorBase.mul,
) if func is not None
)

_LTYPE_PRESERVING_FUNCS = {
torch._C.TensorBase.__getitem__,
torch.cat,
*WHITELISTED_MAPS,
torch._C.TensorBase.clone,
torch._C.TensorBase.to,
}


def _iter_tracked_tensors(values):
if isinstance(values, torch.Tensor):
yield values
elif isinstance(values, (tuple, list)):
for value in values:
yield from _iter_tracked_tensors(value)


def _merge_optrace(values):
merged_optrace = {}
for value in _iter_tracked_tensors(values):
if hasattr(value, 'optrace'):
merged_optrace.update(value.optrace)
return merged_optrace


def _attach_index_trace(result, index, tensor):
if not hasattr(result, 'optrace'):
result.optrace = {}
result.optrace[id(result)] = ("index", index, tensor)
return result


def _attach_cat_trace(result, tensors, dim):
merged_optrace = _merge_optrace(tensors)
merged_optrace[id(result)] = ("cat", dim, tuple(tensors))
result.optrace = merged_optrace
return result


def _attach_map_trace(result, func, args):
merged_optrace = _merge_optrace(args)
merged_optrace[id(result)] = ("map", func, args)
result.optrace = merged_optrace
return result


def _find_tracking_source(values, cls):
for value in _iter_tracked_tensors(values):
if isinstance(value, cls):
return value
return None


def _retain_ltype(result, tracking_source, cls, func):
if tracking_source is None or not issubclass(cls, pp.LieTensor):
return result
if func not in _LTYPE_PRESERVING_FUNCS:
return result
if not isinstance(result, torch.Tensor) or isinstance(result, cls):
return result
if result.shape[-1:] != tracking_source.ltype.dimension:
return result
wrapped = torch.Tensor.as_subclass(result, cls)
wrapped.ltype = tracking_source.ltype
return wrapped

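These helpers centralize the trace bookkeeping used by `__torch_function__` below: every tracked tensor carries an `optrace` dict mapping `id(tensor)` to the edge that produced it, and `_merge_optrace` propagates upstream edges into each new result. A minimal sketch of the intended behavior (assuming the package imports as `bae`; the exact `__getitem__` handling lives in the class below):

```python
import torch
from bae.autograd.function import TrackingTensor

x = TrackingTensor(torch.randn(4, 3))
idx = torch.tensor([0, 2])
y = x[idx]      # ("index", idx, x) recorded under id(y) via _attach_index_trace
z = y + y       # ("map", __add__, (y, y)) added; y's edges merged via _merge_optrace

print(z.optrace[id(z)][0])  # "map"
print(z.optrace[id(y)][0])  # "index" -- still reachable for graph traversal
```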
# =============================================================================
# Class: TrackingTensor
@@ -18,49 +88,34 @@
class TrackingTensor(torch.Tensor):
@staticmethod
def __new__(cls, data, *args, **kwargs):
if cls is TrackingTensor and isinstance(data, pp.LieTensor):
return _TrackingLieTensor(data, *args, **kwargs)

if isinstance(data, torch.Tensor):
instance = torch.Tensor._make_subclass(cls, data, *args, **kwargs)
else:
instance = torch.Tensor._make_subclass(cls, torch.as_tensor(data), *args, **kwargs)
return instance
return torch.Tensor._make_subclass(cls, data, *args, **kwargs)
return torch.Tensor._make_subclass(cls, torch.as_tensor(data), *args, **kwargs)


@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
result = super(TrackingTensor, cls).__torch_function__(func, types, args=args, kwargs=kwargs)

if isinstance(result, torch.Tensor) and getattr(args[0], '_active', True):
# print(f"__torch_function__ called with {func}")
result = _retain_ltype(result, _find_tracking_source(args, cls), cls, func)

if isinstance(result, torch.Tensor) and (not args or getattr(args[0], '_active', True)):
if (func == torch._C.TensorBase.__getitem__) and isinstance(args[1], torch.Tensor):
if not hasattr(result, 'optrace'):
result.optrace = {}
index_edge = ("index", args[1], args[0])
result.optrace[id(result)] = index_edge
result = _attach_index_trace(result, args[1], args[0])
elif func == torch.cat:
if kwargs is None:
kwargs = {}
tensors = args[0]
dim = kwargs.get("dim", 0)
if len(args) > 1:
dim = args[1]
if dim != 0:
raise NotImplementedError("TrackingTensor only supports torch.cat(..., dim=0)")

merged_optrace = {}
for t in tensors:
if isinstance(t, torch.Tensor) and hasattr(t, 'optrace'):
merged_optrace.update(t.optrace)

merged_optrace[id(result)] = ("cat", 0, tuple(tensors))
result.optrace = merged_optrace
result = _attach_cat_trace(result, tensors, dim)
elif func in WHITELISTED_MAPS:
merged_optrace = {}
for arg in args:
if isinstance(arg, torch.Tensor) and hasattr(arg, 'optrace'):
merged_optrace.update(arg.optrace)

merged_optrace[id(result)] = ("map", func, args)
result.optrace = merged_optrace
result = _attach_map_trace(result, func, args)
return result

def __getitem__(self, index):
@@ -95,6 +150,25 @@ def _convert_to_index_tensor(self, index):

def tensor(self) -> torch.Tensor:
return torch.Tensor.as_subclass(self, torch.Tensor)


class _TrackingLieTensor(TrackingTensor, pp.LieTensor):
def __init__(self, data=None, *args, **kwargs):
if isinstance(data, pp.LieTensor):
self.ltype = data.ltype

@staticmethod
def __new__(cls, data, *args, **kwargs):
if not isinstance(data, pp.LieTensor):
raise TypeError(f"_TrackingLieTensor expects a LieTensor input, got {type(data)!r}")
instance = torch.Tensor.as_subclass(data, cls)
instance.ltype = data.ltype
return instance

def detach(self):
detached = torch.Tensor.as_subclass(super().detach(), type(self))
detached.ltype = self.ltype
return detached
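
With `__new__` rerouting `pp.LieTensor` inputs to `_TrackingLieTensor`, Lie-group values keep their `ltype` through the ops listed in `_LTYPE_PRESERVING_FUNCS` (via `_retain_ltype`) while still accumulating an optrace. A hedged sketch of the dispatch:

```python
import torch
import pypose as pp
from bae.autograd.function import TrackingTensor

pose = TrackingTensor(pp.randn_SE3(5))   # __new__ dispatches to _TrackingLieTensor
print(type(pose).__name__)               # _TrackingLieTensor
print(pose.ltype)                        # SE3Type, copied from the input

sub = pose[torch.tensor([1, 3])]         # __getitem__ is ltype-preserving,
print(getattr(sub, 'ltype', None))       # so the slice should keep SE3Type
```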
"""
graph design
Node: (tensor_type: [nn.Parameter, tensor, pp.LieTensor])
@@ -127,12 +201,7 @@ def tensor(self) -> torch.Tensor:
# =============================================================================
def index_transform(tensor, index):
result = tensor[index]
if not hasattr(result, 'optrace'):
result.optrace = {}
    # index edge (edge_type, indices, orig_arg)
index_edge = ("index", index, tensor)
result.optrace[id(result)] = index_edge
return result
return _attach_index_trace(result, index, tensor)


# =============================================================================
Expand All @@ -146,14 +215,7 @@ def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
# map edge (edge_type, func, [input_args])
        # ensure final result is a TrackingTensor
merged_optrace = {}
for arg in args:
if isinstance(arg, torch.Tensor) and hasattr(arg, 'optrace'):
merged_optrace.update(arg.optrace)

merged_optrace[id(result)] = ("map", func, args)
result.optrace = merged_optrace
return result
return _attach_map_trace(result, func, args)
return wrapper

# map_transform(vmap(func))
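
For call paths that bypass `__torch_function__`, these transforms record the same edges explicitly; the comment above suggests composing as `map_transform(vmap(func))`. A hypothetical usage (`index_transform` and `map_transform` assumed importable from this module):

```python
import torch
from bae.autograd.function import TrackingTensor, index_transform, map_transform

def point_residual(point, target):
    return point - target

points = TrackingTensor(torch.randn(10, 3))
targets = torch.randn(4, 3)

selected = index_transform(points, torch.tensor([0, 3, 5, 7]))  # records an "index" edge
residual = map_transform(torch.vmap(point_residual))(selected, targets)  # records a "map" edge
```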
87 changes: 54 additions & 33 deletions bae/autograd/graph.py
@@ -1,10 +1,12 @@

from typing import Optional

import pypose as pp
import torch
from torch.func import jacrev

from ..sparse import warp_wrappers as _warp_wrappers # noqa: F401
from ..utils.parameter import trim_parameter_jacobian_values


def _crow_to_row_indices(crow_indices: torch.Tensor) -> torch.Tensor:
@@ -109,6 +111,34 @@ def construct_sbt(jac_from_vmap, num, index: Optional[torch.Tensor], type=torch.
size = (n * block_shape[0], num * block_shape[1]),
device=index.device, dtype=jac_from_vmap.dtype)

def _clear_jactrace(output, params):
seen = set()
stack = [output, *params]
while stack:
tensor = stack.pop()
if not isinstance(tensor, torch.Tensor) or id(tensor) in seen:
continue
seen.add(id(tensor))

if hasattr(tensor, 'jactrace'):
delattr(tensor, 'jactrace')

if not hasattr(tensor, 'optrace') or id(tensor) not in tensor.optrace:
continue

op = tensor.optrace[id(tensor)][0]
if op == 'map':
args = tensor.optrace[id(tensor)][2]
stack.extend(arg for arg in args if isinstance(arg, torch.Tensor))
elif op == 'index':
arg = tensor.optrace[id(tensor)][2]
if isinstance(arg, torch.Tensor):
stack.append(arg)
elif op == 'cat':
args = tensor.optrace[id(tensor)][2]
stack.extend(arg for arg in args if isinstance(arg, torch.Tensor))


def amend_trace(arg, jac_trace: tuple):
if hasattr(arg, 'jactrace'): # convert to sparse_bsr needed for accumulation
if type(arg.jactrace) is tuple and type(jac_trace) is tuple:
@@ -150,7 +180,8 @@ def backward(output_):
if len(argnums) == 0:
warning("No upstream parameters to compute jacobian")
return
jac_blocks = torch.vmap(jacrev(func, argnums=argnums))(*args)
with pp.retain_ltype():
jac_blocks = torch.vmap(jacrev(func, argnums=argnums))(*args)
for jacidx, argidx in enumerate(argnums):
jac_block = jac_blocks[jacidx]
arg = args[argidx]
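
The new `pp.retain_ltype()` guard keeps LieTensor types alive inside `torch.func` transforms, so `jacrev` sees the manifold-typed arguments. A standalone sketch of the same `vmap(jacrev(...))` pattern outside the tracing machinery (the residual here is illustrative):

```python
import torch
import pypose as pp
from torch.func import jacrev

def residual(pose, point):
    return pose @ point                  # SE3 action on a 3D point

poses = pp.randn_SE3(8)
points = torch.randn(8, 3)
with pp.retain_ltype():
    jac = torch.vmap(jacrev(residual, argnums=(0, 1)))(poses, points)
print(jac[0].shape, jac[1].shape)        # per-block d r / d pose, d r / d point
```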
@@ -253,38 +284,28 @@

def jacobian(output, params):
assert output.optrace[id(output)][0] in ('map', 'index', 'cat'), "Unsupported last operation in compute graph"
backward(output)
res = []
for param in params:
if hasattr(param, 'jactrace'):
if getattr(param, 'trim_SE3_grad', False):
if isinstance(param.jactrace, tuple):
values = param.jactrace[1]
elif isinstance(param.jactrace, torch.Tensor) and param.jactrace.layout == torch.sparse_bsr:
values = param.jactrace.values()
else:
values = param.jactrace

if values.shape[-1] == 7:
values = values[..., :6]
else:
values = torch.cat([values[..., :6], values[..., 7:]], dim=-1)

_clear_jactrace(output, params)
try:
backward(output)
res = []
for param in params:
if hasattr(param, 'jactrace'):
if isinstance(param.jactrace, tuple):
values = trim_parameter_jacobian_values(param, param.jactrace[1])
param.jactrace = (param.jactrace[0], values)
elif isinstance(param.jactrace, torch.Tensor) and param.jactrace.layout == torch.sparse_bsr:
param.jactrace = torch.sparse_bsr_tensor(
col_indices=param.jactrace.col_indices(),
crow_indices=param.jactrace.crow_indices(),
values=values,
size=(param.jactrace.shape[0], param.shape[0] * values.shape[-1]),
device=param.device,
)
else:
param.jactrace = values
if type(param.jactrace) is tuple:
param.jactrace = construct_sbt(param.jactrace[1], param.shape[0], param.jactrace[0], type=torch.sparse_bsr)
res.append(param.jactrace)
delattr(param, 'jactrace')
return res
values = trim_parameter_jacobian_values(param, param.jactrace.values())
if values.shape != param.jactrace.values().shape:
param.jactrace = torch.sparse_bsr_tensor(
col_indices=param.jactrace.col_indices(),
crow_indices=param.jactrace.crow_indices(),
values=values,
size=(param.jactrace.shape[0], param.shape[0] * values.shape[-1]),
device=param.device,
)
if type(param.jactrace) is tuple:
param.jactrace = construct_sbt(param.jactrace[1], param.shape[0], param.jactrace[0], type=torch.sparse_bsr)
res.append(param.jactrace)
return res
finally:
_clear_jactrace(output, params)
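
Net effect: `jacobian` now clears stale `jactrace` state up front and again in `finally`, so a failure inside `backward` can no longer leak traces into the next solve. A hypothetical end-to-end call (shapes and the op chain are illustrative, not from the PR):

```python
import torch
from bae.autograd.function import TrackingTensor
from bae.autograd.graph import jacobian

param = TrackingTensor(torch.randn(10, 3)).requires_grad_()
target = torch.randn(4, 3)

pred = param[torch.tensor([0, 3, 5, 7])]   # "index" edge
residual = pred - target                   # "map" edge (whitelisted __sub__)

(J,) = jacobian(residual, [param])         # one sparse-BSR block-Jacobian per param
```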
14 changes: 6 additions & 8 deletions bae/optim/optimizer.py
@@ -1,12 +1,12 @@
from functools import partial
import math
import torch
from pypose.optim import LevenbergMarquardt as ppLM
import pypose as pp
from ..autograd.graph import jacobian
from ..autograd.function import TrackingTensor
from ..sparse.py_ops import diagonal_op_
from ..sparse.spgemm import CuSparse
from ..utils.parameter import parameter_update_shape



@@ -58,16 +58,14 @@ def update_parameter(self, params, step):
numels = []
for param in params:
if param.requires_grad:
if getattr(param, 'trim_SE3_grad', False):
numels.append(math.prod(param.shape[:-1]) * (param.shape[-1] - 1))
else:
numels.append(param.numel())
numels.append(torch.Size(parameter_update_shape(param)).numel())
steps = step.split(numels)
for (param, d) in zip(params, steps):
if param.requires_grad:
step_view = d.view(parameter_update_shape(param))
if getattr(param, 'trim_SE3_grad', False):
param[..., :7] = pp.SE3(param[..., :7]).add_(pp.se3(d.view(param.shape[0], -1)[..., :6]))
param[..., :7] = pp.SE3(param[..., :7]).add_(pp.se3(step_view[..., :6]))
if param.shape[-1] > 7:
param[:, 7:] += d.view(param.shape[0], -1)[:, 6:]
param[:, 7:] += step_view[..., 6:]
else:
param.add_(d.view(param.shape))
param.add_(step_view)
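
The SE3 branch applies the first six step entries as an se3 tangent increment, mirroring pypose's manifold update. A minimal sketch of that retraction in isolation:

```python
import torch
import pypose as pp

pose = pp.randn_SE3(3)                      # [3, 7] storage (translation + quaternion)
delta = 0.01 * torch.randn(3, 6)            # [3, 6] tangent-space step
pose = pp.SE3(pose).add_(pp.se3(delta))     # same update as in update_parameter above
print(pose.shape)                           # torch.Size([3, 7]); still a valid SE3
```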
24 changes: 24 additions & 0 deletions bae/utils/parameter.py
@@ -0,0 +1,24 @@
import pypose as pp
import torch


def parameter_update_shape(param: torch.Tensor) -> torch.Size:
if param.ndim == 0:
return param.shape
if getattr(param, 'trim_SE3_grad', False):
return torch.Size((*param.shape[:-1], param.shape[-1] - 1))
if isinstance(param, pp.LieTensor):
return torch.Size((*param.shape[:-1], param.ltype.manifold[0]))
return param.shape


def trim_parameter_jacobian_values(param: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
if param.ndim == 0 or values.shape[-1] != param.shape[-1]:
return values
if getattr(param, 'trim_SE3_grad', False):
return torch.cat([values[..., :6], values[..., 7:]], dim=-1)
if isinstance(param, pp.LieTensor):
step_dim = int(param.ltype.manifold[0])
if step_dim != param.shape[-1]:
return values[..., :step_dim]
return values
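
A quick illustration of both helpers on an SE3 parameter (shapes assumed; `randn_SE3` is pypose's own constructor):

```python
import torch
import pypose as pp
from bae.utils.parameter import parameter_update_shape, trim_parameter_jacobian_values

pose = pp.randn_SE3(5)                       # storage shape [5, 7]
print(parameter_update_shape(pose))          # torch.Size([5, 6]) via ltype.manifold

jac = torch.randn(12, 5, 7)                  # Jacobian w.r.t. the 7-dim storage
print(trim_parameter_jacobian_values(pose, jac).shape)  # torch.Size([12, 5, 6])
```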