diff --git a/src/inversion_ideas/recipes.py b/src/inversion_ideas/recipes.py index 73fc6cc..b4db2b4 100644 --- a/src/inversion_ideas/recipes.py +++ b/src/inversion_ideas/recipes.py @@ -307,9 +307,14 @@ def create_tikhonov_regularization( ----- TODO """ - # TODO: raise errors: - # if dims == 2 and alpha_z is passed - # if dims == 1 and alpha_y or alpha_z are passed + ndims = mesh.dim + if ndims == 2 and alpha_z is not None: + msg = f"Cannot pass 'alpha_z' when mesh has {ndims} dimensions." + raise TypeError(msg) + if ndims == 1 and (alpha_y is not None or alpha_z is not None): + msg = "Cannot pass 'alpha_y' nor 'alpha_z' when mesh has 1 dimension." + raise TypeError(msg) + smallness = Smallness( mesh, active_cells=active_cells, @@ -326,16 +331,24 @@ def create_tikhonov_regularization( if reference_model_in_flatness: kwargs["reference_model"] = reference_model - flatness_x = Flatness(mesh, **kwargs, direction="x") - if alpha_x is not None: - flatness_x = alpha_x * flatness_x - - flatness_y = Flatness(mesh, **kwargs, direction="y") - if alpha_y is not None: - flatness_y = alpha_y * flatness_y - - flatness_z = Flatness(mesh, **kwargs, direction="z") - if alpha_z is not None: - flatness_z = alpha_z * flatness_z - - return (smallness + flatness_x + flatness_y + flatness_z).flatten() + match ndims: + case 3: + directions = ("x", "y", "z") + alphas = (alpha_x, alpha_y, alpha_z) + case 2: + directions = ("x", "y") + alphas = (alpha_x, alpha_y) + case 1: + directions = ("x",) + alphas = (alpha_x,) + case _: + raise ValueError() + + regularization = smallness + for direction, alpha in zip(directions, alphas, strict=True): + phi = Flatness(mesh, **kwargs, direction=direction) + if alpha is not None: + phi = alpha * phi + regularization = regularization + phi + + return regularization.flatten()