# `Tinkex.Regularizer.GradientTracker`
[🔗](https://github.com/North-Shore-AI/tinkex/blob/v0.4.0/lib/tinkex/regularizer/gradient_tracker.ex#L1)

Computes gradient norms for regularizers using Nx automatic differentiation.

This module provides L2 gradient norm computation for monitoring which
regularizers dominate the training signal.

## Implementation Notes

Nx provides automatic differentiation through `Nx.Defn.grad/2`. We wrap
regularizer functions to extract just the loss tensor for differentiation.

Unlike PyTorch's `torch.autograd.grad(..., retain_graph=True)`, Nx derives
gradients by transforming the traced expression, so there is no computation
graph to retain between calls.
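
The wrapping step described above can be sketched as follows. This is a minimal illustration, not the module's actual source: the regularizer, its `{loss, metrics}` return shape, the empty `data` list, and the Nx version in `Mix.install` are all assumptions. It uses `Nx.Defn.grad/1`, which returns a new function that evaluates the gradient.

```elixir
# Standalone script; inside a Mix project, drop Mix.install and depend on :nx.
Mix.install([{:nx, "~> 0.9"}])

# Hypothetical regularizer returning {loss, metrics}.
regularizer = fn _data, logprobs ->
  loss = Nx.sum(Nx.multiply(logprobs, logprobs))
  {loss, %{"note" => "metrics are ignored during differentiation"}}
end

# Stand-in for training data; it is closed over, not differentiated.
data = []

# Wrap the regularizer so that only the scalar loss tensor is returned.
loss_only = fn logprobs ->
  {loss, _metrics} = regularizer.(data, logprobs)
  loss
end

# Nx.Defn.grad/1 turns the wrapped function into its gradient function.
grad_fn = Nx.Defn.grad(loss_only)
grad = grad_fn.(Nx.tensor([1.0, 2.0, 3.0]))

# The loss is sum(x^2), so the gradient is 2x.
IO.inspect(Nx.to_flat_list(grad))
```

Each call to `grad_fn` traces `loss_only` afresh, which is why nothing has to be kept alive between gradient computations.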

## Usage

Gradient tracking is an optional feature enabled via the
`:track_grad_norms` option in the regularizer pipeline:

    Pipeline.compute(data, logprobs, base_loss_fn,
      regularizers: regularizers,
      track_grad_norms: true
    )

When enabled, each `RegularizerOutput` includes:
- `grad_norm`: L2 norm of the regularizer's gradient
- `grad_norm_weighted`: `weight * grad_norm`
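
To see which regularizer dominates, the weighted norms can be compared directly. The sketch below uses plain maps as stand-ins for `RegularizerOutput` structs; the `:name` and `:weight` keys and all numbers are made up for illustration.

```elixir
# Plain maps stand in for RegularizerOutput structs (hypothetical values).
outputs = [
  %{name: "l1", weight: 0.1, grad_norm: 4.0},
  %{name: "entropy", weight: 0.5, grad_norm: 1.5},
  %{name: "kl", weight: 0.01, grad_norm: 20.0}
]

ranked =
  outputs
  # grad_norm_weighted is weight * grad_norm, as described above.
  |> Enum.map(fn out -> Map.put(out, :grad_norm_weighted, out.weight * out.grad_norm) end)
  # Largest weighted norm first: the regularizer contributing most to the signal.
  |> Enum.sort_by(& &1.grad_norm_weighted, :desc)

Enum.each(ranked, fn out ->
  IO.puts("#{out.name}: #{out.grad_norm_weighted}")
end)
```

Here `"kl"` has the largest raw gradient norm but the smallest weighted one, which is the reason to track both values.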

# `compute_grad_norm`

```elixir
@spec compute_grad_norm(
  loss_fn :: (Nx.Tensor.t() -> Nx.Tensor.t()),
  logprobs :: Nx.Tensor.t()
) :: float()
```

Compute the L2 norm of the gradient of a loss function with respect to its input.

## Parameters
- `loss_fn` - Function that takes logprobs and returns a scalar loss tensor
- `logprobs` - Nx tensor to differentiate with respect to

## Returns
Float representing the L2 norm: `sqrt(sum(grad^2))`

## Examples

    loss_fn = fn x -> Nx.sum(x) end
    norm = GradientTracker.compute_grad_norm(loss_fn, Nx.tensor([1.0, 2.0, 3.0]))
    # => 1.732... (sqrt(3))
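
For a loss with a non-constant gradient, the formula `sqrt(sum(grad^2))` can be worked through by hand. The module `GradNormSketch` below is a hypothetical stand-in written for this illustration (it is not the Tinkex implementation), and the Nx version is an assumption.

```elixir
# Standalone script; inside a Mix project, drop Mix.install and depend on :nx.
Mix.install([{:nx, "~> 0.9"}])

defmodule GradNormSketch do
  # Illustrative only: differentiate loss_fn at logprobs, then reduce to a float.
  def l2_grad_norm(loss_fn, logprobs) do
    grad = Nx.Defn.grad(loss_fn).(logprobs)

    grad
    |> Nx.multiply(grad)   # grad^2, elementwise
    |> Nx.sum()
    |> Nx.sqrt()
    |> Nx.to_number()
  end
end

# loss = sum(x^2) has gradient 2x, so for [1, 2, 3] the gradient is [2, 4, 6]
# and the norm is sqrt(4 + 16 + 36) = sqrt(56), roughly 7.483.
loss_fn = fn x -> Nx.sum(Nx.multiply(x, x)) end
norm = GradNormSketch.l2_grad_norm(loss_fn, Nx.tensor([1.0, 2.0, 3.0]))
IO.inspect(norm)
```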

# `grad_norm_for_regularizer`

```elixir
@spec grad_norm_for_regularizer(
  Tinkex.Types.RegularizerSpec.t(),
  [Tinkex.Types.Datum.t()],
  Nx.Tensor.t()
) :: float()
```

Compute gradient norm for a regularizer spec.

Wraps the regularizer function to extract just the loss for differentiation.

## Parameters
- `spec` - RegularizerSpec with the regularizer function
- `data` - Training data (passed to regularizer but not differentiated)
- `logprobs` - Nx tensor to differentiate with respect to

## Returns
Float representing the L2 gradient norm. Returns 0.0 if gradient
computation fails (e.g., for non-differentiable operations).

## Examples

    spec = %RegularizerSpec{fn: &my_regularizer/2, weight: 0.1, name: "l1"}
    norm = GradientTracker.grad_norm_for_regularizer(spec, data, logprobs)
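
The fallback to `0.0` could be implemented along these lines. This is a sketch under assumptions: `SafeGradNormSketch` is a hypothetical module, the regularizers are assumed to return `{loss, metrics}`, and rescuing every exception is a guess at the behavior rather than a description of the real code.

```elixir
# Standalone script; inside a Mix project, drop Mix.install and depend on :nx.
Mix.install([{:nx, "~> 0.9"}])

defmodule SafeGradNormSketch do
  # Hypothetical helper; reg_fn takes (data, logprobs) and returns {loss, metrics}.
  def grad_norm(reg_fn, data, logprobs) do
    # data is closed over, so only logprobs is differentiated.
    loss_only = fn lp ->
      {loss, _metrics} = reg_fn.(data, lp)
      loss
    end

    try do
      grad = Nx.Defn.grad(loss_only).(logprobs)
      grad |> Nx.multiply(grad) |> Nx.sum() |> Nx.sqrt() |> Nx.to_number()
    rescue
      # Assumption: any exception raised while tracing or differentiating maps to 0.0.
      _error -> 0.0
    end
  end
end

# An L1 penalty: the gradient is the sign of each element.
l1 = fn _data, lp -> {Nx.sum(Nx.abs(lp)), %{}} end

# A regularizer that always fails, to exercise the fallback path.
broken = fn _data, _lp -> raise ArgumentError, "cannot differentiate" end

logprobs = Nx.tensor([-1.0, 2.0])
ok_norm = SafeGradNormSketch.grad_norm(l1, [], logprobs)
fallback_norm = SafeGradNormSketch.grad_norm(broken, [], logprobs)
IO.inspect({ok_norm, fallback_norm})
```

With a fallback like this, a returned `0.0` cannot be told apart from a genuinely zero gradient, so a norm that stays at exactly zero is worth investigating.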

# `total_grad_norm`

```elixir
@spec total_grad_norm(
  base_loss_fn :: function(),
  regularizers :: [Tinkex.Types.RegularizerSpec.t()],
  data :: [Tinkex.Types.Datum.t()],
  logprobs :: Nx.Tensor.t()
) :: float()
```

Compute gradient norm for the total composed loss.

Composes the base loss and all regularizers (with weights) and computes
the L2 norm of the combined gradient.

## Parameters
- `base_loss_fn` - Base loss function `(data, logprobs) -> {loss, metrics}`
- `regularizers` - List of RegularizerSpec structs
- `data` - Training data
- `logprobs` - Nx tensor to differentiate with respect to

## Returns
Float representing the L2 norm of the total gradient.

## Formula

The total loss is:

    total = base_loss + Σ(weight_i × regularizer_i)

The gradient is computed with respect to `logprobs`.
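
A minimal sketch of this composition follows. Plain maps with `:fn` and `:weight` keys stand in for `RegularizerSpec` structs, and the loss functions, data, and Nx version are all assumptions chosen so the result can be checked by hand.

```elixir
# Standalone script; inside a Mix project, drop Mix.install and depend on :nx.
Mix.install([{:nx, "~> 0.9"}])

# Hypothetical base loss and regularizer, both returning {loss, metrics}.
base_loss_fn = fn _data, lp -> {Nx.negate(Nx.sum(lp)), %{}} end

regularizers = [
  %{fn: fn _data, lp -> {Nx.sum(Nx.multiply(lp, lp)), %{}} end, weight: 0.5}
]

data = []

# total = base_loss + sum(weight_i * regularizer_i), as a function of logprobs only.
total_loss = fn lp ->
  {base, _metrics} = base_loss_fn.(data, lp)

  Enum.reduce(regularizers, base, fn %{fn: reg_fn, weight: weight}, acc ->
    {reg, _metrics} = reg_fn.(data, lp)
    Nx.add(acc, Nx.multiply(weight, reg))
  end)
end

grad = Nx.Defn.grad(total_loss).(Nx.tensor([1.0, 2.0, 3.0]))
norm = grad |> Nx.multiply(grad) |> Nx.sum() |> Nx.sqrt() |> Nx.to_number()

# total = -sum(x) + 0.5 * sum(x^2), so grad = x - 1 = [0, 1, 2]
# and the norm is sqrt(0 + 1 + 4) = sqrt(5), roughly 2.236.
IO.inspect(norm)
```

Because the gradient of a sum is the sum of the gradients, the total norm is at most the base norm plus the weighted regularizer norms, and it can be much smaller when the individual gradients point in opposing directions.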

## Examples

    norm = GradientTracker.total_grad_norm(base_loss_fn, regularizers, data, logprobs)

---

*Consult [api-reference.md](api-reference.md) for complete listing*
