
Soft Operations

The diffbio.core.soft_ops module provides smooth, differentiable relaxations of discrete, piecewise-linear, and sharp operations. These relaxations enable end-to-end gradient-based optimization through operations that are normally non-differentiable -- comparisons, sorting, indexing, rounding, and logical gates all become continuous functions with well-defined gradients.

Acknowledgment

The soft operations in this module are based on the algorithms and implementations from SoftJAX (Paulus et al., 2026; arXiv:2603.08824), adapted for the DiffBio/JAX/Flax NNX ecosystem.

All soft operations are JIT-compatible and support jax.grad and jax.vmap. Most accept a softness parameter controlling the width of the transition region (higher = smoother) and a mode parameter selecting the smoothness family:

Mode Description
"hard" Exact (non-differentiable) version matching JAX
"smooth" C-infinity smooth via logistic sigmoid
"c0" Continuous (C0) via piecewise linear/quadratic
"c1" Once differentiable (C1) via cubic Hermite polynomial
"c2" Twice differentiable (C2) via quintic Hermite polynomial

Name shadowing

Several names (abs, all, any, round, max, min) shadow Python builtins. Use qualified imports:

import jax.numpy as jnp
from diffbio.core import soft_ops
x = jnp.array([0.3, 1.2, -0.5])
soft_ops.max(x, softness=0.1)

or alias on import:

from diffbio.core.soft_ops import max as soft_max

Types

SoftBool

diffbio.core.soft_ops._types.SoftBool module-attribute

SoftBool = Float[Array, '...']

Soft boolean: probability in [0, 1].

SoftIndex

diffbio.core.soft_ops._types.SoftIndex module-attribute

SoftIndex = Float[Array, '...']

Soft index: probabilities summing to 1 along the last axis.


Autograd-Safe Math

NaN-free alternatives to standard math functions. These use the double-where trick so that the forward pass computes correct values even at domain boundaries, while the backward pass produces finite (zero) gradients instead of NaN or Inf.
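The double-where trick can be sketched in plain JAX as follows. This is written from the description above, not taken from diffbio's source, and safe_sqrt and single_where_sqrt are hypothetical names. A single jnp.where is not enough: the masked-out branch still receives a zero cotangent, and 0 times the infinite derivative of sqrt at 0 evaluates to NaN.

```python
import jax
import jax.numpy as jnp


def safe_sqrt(x):
    """Hypothetical double-where square root (sketch, not diffbio's code)."""
    positive = x > 0
    # Inner where: feed sqrt a harmless value (1.0) wherever x <= 0, so its
    # derivative is finite on every branch that autodiff visits.
    x_safe = jnp.where(positive, x, 1.0)
    # Outer where: keep sqrt(x) where it is defined, return 0 elsewhere.
    return jnp.where(positive, jnp.sqrt(x_safe), 0.0)


def single_where_sqrt(x):
    # Naive version: the forward value is right, but the gradient at 0 is NaN.
    return jnp.where(x > 0, jnp.sqrt(x), 0.0)


print(jax.grad(single_where_sqrt)(0.0))  # nan
print(jax.grad(safe_sqrt)(0.0))          # 0.0
print(jax.grad(safe_sqrt)(4.0))          # 0.25
```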

sqrt

diffbio.core.soft_ops.autograd_safe.sqrt

sqrt(x: Array) -> Array

Autograd-safe square root.

Returns sqrt(x) for x > 0 and 0 otherwise, without producing NaN gradients at x = 0.

Parameters:

Name Type Description Default
x Array

Input array.

required

Returns:

Type Description
Array

Elementwise square root, safe for autodiff.

arcsin

diffbio.core.soft_ops.autograd_safe.arcsin

arcsin(x: Array) -> Array

Autograd-safe arcsine.

Returns arcsin(x) for |x| < 1 and +/-pi/2 at the boundary, without producing NaN gradients at x = +/-1.

Parameters:

Name Type Description Default
x Array

Input array with values in [-1, 1].

required

Returns:

Type Description
Array

Elementwise arcsine, safe for autodiff.

arccos

diffbio.core.soft_ops.autograd_safe.arccos

arccos(x: Array) -> Array

Autograd-safe arccosine.

Returns arccos(x) for |x| < 1, 0 at x = 1, and pi at x = -1, without producing NaN gradients at the boundary.

Parameters:

Name Type Description Default
x Array

Input array with values in [-1, 1].

required

Returns:

Type Description
Array

Elementwise arccosine, safe for autodiff.

div

diffbio.core.soft_ops.autograd_safe.div

div(x: Array, y: Array) -> Array

Autograd-safe division.

Returns x / y when y != 0 and 0 otherwise, without producing NaN gradients at y = 0.

Parameters:

Name Type Description Default
x Array

Numerator array.

required
y Array

Denominator array.

required

Returns:

Type Description
Array

Elementwise safe division.

log

diffbio.core.soft_ops.autograd_safe.log

log(x: Array) -> Array

Autograd-safe natural logarithm.

Returns log(x) for x > 0 and 0 otherwise, without producing NaN gradients at x = 0.

Parameters:

Name Type Description Default
x Array

Input array.

required

Returns:

Type Description
Array

Elementwise natural logarithm, safe for autodiff.

norm

diffbio.core.soft_ops.autograd_safe.norm

norm(
    x: Array,
    axis: int | None = None,
    keepdims: bool = False,
) -> Array

Autograd-safe L2 norm.

Computes sqrt(sum(x**2)) using the autograd-safe sqrt, avoiding NaN gradients when the norm is zero.

Parameters:

Name Type Description Default
x Array

Input array.

required
axis int | None

Axis or axes along which to compute the norm.

None
keepdims bool

If True, retains reduced axes with size 1.

False

Returns:

Type Description
Array

L2 norm along the given axis, safe for autodiff.
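Building on the same idea, an autograd-safe norm only needs the safe square root applied to the sum of squares. The sketch below reuses a double-where sqrt. safe_norm is a hypothetical name, and the code illustrates the documented formula rather than diffbio's implementation.

```python
import jax
import jax.numpy as jnp


def safe_sqrt(x):
    # Double-where square root: value 0 and gradient 0 wherever x <= 0.
    positive = x > 0
    return jnp.where(positive, jnp.sqrt(jnp.where(positive, x, 1.0)), 0.0)


def safe_norm(x, axis=None, keepdims=False):
    """Hypothetical autograd-safe L2 norm: safe_sqrt(sum(x**2))."""
    return safe_sqrt(jnp.sum(x**2, axis=axis, keepdims=keepdims))


v = jnp.array([3.0, 4.0])
print(safe_norm(v))                       # 5.0
print(jax.grad(safe_norm)(v))             # x / ||x|| = [0.6, 0.8]
print(jax.grad(safe_norm)(jnp.zeros(3)))  # all zeros instead of NaN
```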


Elementwise

Differentiable relaxations of elementwise non-smooth functions using sigmoidal smoothing with configurable smoothness modes.

sigmoidal

diffbio.core.soft_ops.elementwise.sigmoidal

sigmoidal(
    x: Array,
    softness: float | Array = 0.1,
    mode: SigmoidalMode = "smooth",
) -> SoftBool

Sigmoidal S-curve function mapping R -> (0, 1).

Foundation for all other elementwise operations. Maps input values through an S-shaped curve centered at 0, approaching 0 for large negative values and 1 for large positive values.

Parameters:

Name Type Description Default
x Array

Input array.

required
softness float | Array

Width of transition region (> 0). Higher = smoother.

0.1
mode SigmoidalMode

Smoothness family: "smooth" (logistic sigmoid), "c0" (piecewise linear), "c1" (cubic Hermite), "c2" (quintic Hermite).

'smooth'

Returns:

Type Description
SoftBool

SoftBool array with values in [0, 1].
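The four sigmoidal families can be sketched as shown below. The exact scaling diffbio uses is not documented here, so this sketch assumes the curve is a function of z = x / softness and that the polynomial families saturate exactly at z = -1 and z = +1. sigmoidal_sketch is a hypothetical name. The polynomials are the standard linear ramp, cubic smoothstep, and quintic smootherstep, which match the C0, C1, and C2 descriptions in the mode table.

```python
import jax
import jax.numpy as jnp


def sigmoidal_sketch(x, softness=0.1, mode="smooth"):
    """Hypothetical sketch of the four sigmoidal families (not diffbio's code)."""
    z = x / softness
    if mode == "smooth":
        return jax.nn.sigmoid(z)  # logistic: C-infinity, never exactly 0 or 1
    # Polynomial families: map z in [-1, 1] to t in [0, 1] and saturate
    # outside that interval (assumed support width).
    t = jnp.clip((z + 1.0) / 2.0, 0.0, 1.0)
    if mode == "c0":
        return t  # piecewise linear ramp: continuous, kinks at t = 0 and t = 1
    if mode == "c1":
        return t * t * (3.0 - 2.0 * t)  # cubic Hermite: zero slope at both ends
    if mode == "c2":
        # Quintic Hermite: zero slope and zero curvature at both ends.
        return t**3 * (t * (6.0 * t - 15.0) + 10.0)
    raise ValueError(f"unknown mode: {mode}")


xs = jnp.linspace(-0.3, 0.3, 7)
for m in ("smooth", "c0", "c1", "c2"):
    print(m, sigmoidal_sketch(xs, softness=0.1, mode=m))
```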

softrelu

diffbio.core.soft_ops.elementwise.softrelu

softrelu(
    x: Array,
    softness: float | Array = 0.1,
    mode: SigmoidalMode = "smooth",
    gated: bool = False,
) -> Array

Family of soft relaxations to ReLU.

Two variants:

  • Non-gated (default): antiderivative of sigmoidal; a smooth analog of max(0, x).
  • Gated: x * sigmoidal(x); SiLU-style gating.

Parameters:

Name Type Description Default
x Array

Input array.

required
softness float | Array

Width of transition region (> 0).

0.1
mode SigmoidalMode

Smoothness family (see sigmoidal).

'smooth'
gated bool

If True, use gated version x * sigmoidal(x).

False

Returns:

Type Description
Array

Soft ReLU output, same shape as x.
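For the logistic ("smooth") family, the antiderivative of sigmoid(x / softness) is softness * softplus(x / softness). This gives a compact sketch of both variants. The function name is hypothetical, and only the smooth family is covered.

```python
import jax
import jax.numpy as jnp


def softrelu_sketch(x, softness=0.1, gated=False):
    """Hypothetical logistic-mode soft ReLU (sketch, not diffbio's code)."""
    if gated:
        # SiLU-style gating: x * sigmoid(x / softness).
        return x * jax.nn.sigmoid(x / softness)
    # Antiderivative of sigmoid(x / softness) with respect to x:
    # d/dx [softness * softplus(x / softness)] = sigmoid(x / softness).
    return softness * jax.nn.softplus(x / softness)


x = jnp.array([-1.0, 0.0, 1.0])
print(softrelu_sketch(x))              # close to relu(x) away from 0
print(softrelu_sketch(x, gated=True))
print(jax.grad(softrelu_sketch)(0.0))  # sigmoid(0) = 0.5
```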

heaviside

diffbio.core.soft_ops.elementwise.heaviside

heaviside(
    x: Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
) -> SoftBool

Soft Heaviside step function.

Returns 0 for x < 0, 1 for x > 0, and 0.5 at x = 0 (hard mode). Soft modes use sigmoidal for a smooth transition.

Parameters:

Name Type Description Default
x Array

Input array.

required
softness float | Array

Width of transition (> 0).

0.1
mode Mode

"hard" or sigmoidal mode.

'smooth'

Returns:

Type Description
SoftBool

SoftBool in [0, 1].

round

diffbio.core.soft_ops.elementwise.round

round(
    x: Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    neighbor_radius: int = 5,
) -> Array

Soft rounding.

Hard mode returns jnp.round(x). Soft modes use a weighted sum of nearby integers, with weights from sigmoidal.

Parameters:

Name Type Description Default
x Array

Input array.

required
softness float | Array

Width of transition (> 0).

0.1
mode Mode

"hard" or sigmoidal mode.

'smooth'
neighbor_radius int

Number of integer neighbors to consider.

5

Returns:

Type Description
Array

Soft-rounded values.
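The description above leaves the exact weighting open. This sketch assumes one reading: each candidate integer k is weighted by its soft membership in the rounding bin [k - 0.5, k + 0.5), computed as the difference of two logistic sigmoidal steps. The name soft_round_sketch is hypothetical. Unlike jnp.round, which rounds halves to even, this relaxation returns 2.5 at x = 2.5, the midpoint of the transition.

```python
import jax
import jax.numpy as jnp


def soft_round_sketch(x, softness=0.1, neighbor_radius=5):
    """Hypothetical logistic-mode soft rounding (sketch, not diffbio's code)."""
    x = jnp.asarray(x)
    # Candidate integers around x; floor has zero gradient, so it only picks the window.
    k = jnp.floor(x)[..., None] + jnp.arange(-neighbor_radius, neighbor_radius + 1)
    xe = x[..., None]
    # Soft membership of x in the rounding bin [k - 0.5, k + 0.5).
    w = jax.nn.sigmoid((xe - (k - 0.5)) / softness) - jax.nn.sigmoid(
        (xe - (k + 0.5)) / softness
    )
    # Weighted average of the candidates (normalized in case the window truncates mass).
    return jnp.sum(w * k, axis=-1) / jnp.sum(w, axis=-1)


x = jnp.array([2.2, -1.7, 2.5])
print(jnp.round(x))                          # 2, -2, 2 (round half to even)
print(soft_round_sketch(x, softness=0.02))   # close to 2, -2, 2.5
print(jax.grad(soft_round_sketch)(2.5))      # nonzero slope at the bin edge
```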

sign

diffbio.core.soft_ops.elementwise.sign

sign(
    x: Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
) -> Array

Soft sign function.

Maps to [-1, 1]. Hard mode returns jnp.sign(x). Soft modes use 2 * sigmoidal(x) - 1.

Parameters:

Name Type Description Default
x Array

Input array.

required
softness float | Array

Width of transition (> 0).

0.1
mode Mode

"hard" or sigmoidal mode.

'smooth'

Returns:

Type Description
Array

Values in [-1, 1].

abs

diffbio.core.soft_ops.elementwise.abs

abs(
    x: Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
) -> Array

Soft absolute value.

Hard mode returns jnp.abs(x). Soft modes use x * sign(x, softness, mode).

Parameters:

Name Type Description Default
x Array

Input array.

required
softness float | Array

Width of transition (> 0).

0.1
mode Mode

"hard" or sigmoidal mode.

'smooth'

Returns:

Type Description
Array

Non-negative values (approximately).
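Soft sign and soft abs follow directly from the formulas above. A logistic-mode sketch (hypothetical names, not diffbio's code) shows two useful properties. First, 2 * sigmoid(z) - 1 equals tanh(z / 2). Second, x * sign(x) is never negative and has zero slope at the origin instead of a kink.

```python
import jax
import jax.numpy as jnp


def soft_sign_sketch(x, softness=0.1):
    # 2 * sigmoid(x / softness) - 1, which equals tanh(x / (2 * softness)).
    return 2.0 * jax.nn.sigmoid(x / softness) - 1.0


def soft_abs_sketch(x, softness=0.1):
    # x * soft_sign(x): smooth at 0, and never negative because
    # x and soft_sign(x) always share a sign.
    return x * soft_sign_sketch(x, softness)


xs = jnp.linspace(-1.0, 1.0, 9)
print(soft_abs_sketch(xs))
print(jax.grad(soft_abs_sketch)(0.0))  # 0.0: no kink at the origin
```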

relu

diffbio.core.soft_ops.elementwise.relu

relu(
    x: Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    gated: bool = False,
) -> Array

Soft ReLU.

Hard mode returns jax.nn.relu(x). Soft modes delegate to softrelu.

Parameters:

Name Type Description Default
x Array

Input array.

required
softness float | Array

Width of transition (> 0).

0.1
mode Mode

"hard" or sigmoidal mode.

'smooth'
gated bool

If True, use gated variant.

False

Returns:

Type Description
Array

Soft ReLU output.

clip

diffbio.core.soft_ops.elementwise.clip

clip(
    x: Array,
    a: float | Array,
    b: float | Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    gated: bool = False,
) -> Array

Soft clipping to [a, b].

Hard mode returns jnp.clip(x, a, b). Soft modes use two softrelu calls: a + softrelu(x - a) - softrelu(x - b).

Parameters:

Name Type Description Default
x Array

Input array.

required
a float | Array

Lower bound.

required
b float | Array

Upper bound.

required
softness float | Array

Width of transition (> 0).

0.1
mode Mode

"hard" or sigmoidal mode.

'smooth'
gated bool

If True, use gated softrelu variant.

False

Returns:

Type Description
Array

Clipped values approximately in [a, b].


Comparison

Differentiable relaxations of elementwise comparison operations, returning SoftBool values in [0, 1]. Each function uses sigmoidal as the underlying smooth step function.

greater

diffbio.core.soft_ops.comparison.greater

greater(
    x: Array,
    y: float | Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    epsilon: float = 1e-10,
) -> SoftBool

Soft x > y.

Uses sigmoidal on x - y - epsilon so the output approaches 0 at equality as softness -> 0.

Parameters:

Name Type Description Default
x Array

First input array.

required
y float | Array

Second input array (broadcastable with x).

required
softness float | Array

Width of transition (> 0).

0.1
mode Mode

"hard" or sigmoidal mode.

'smooth'
epsilon float

Small offset for strict inequality at the limit.

1e-10

Returns:

Type Description
SoftBool

SoftBool in [0, 1].
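A common use of soft comparisons is a differentiable count, such as the number of entries above a threshold. The sketch below assumes the logistic family with input scaled as (x - y - epsilon) / softness. The helper names are hypothetical. A hard count has zero gradient with respect to the threshold, while the soft count does not.

```python
import jax
import jax.numpy as jnp


def soft_greater_sketch(x, y, softness=0.1, epsilon=1e-10):
    # Logistic sigmoidal on x - y - epsilon; the tiny offset makes ties lean
    # toward 0 (strict inequality) as softness -> 0.
    return jax.nn.sigmoid((x - y - epsilon) / softness)


def soft_count_above(x, threshold, softness=0.1):
    # Differentiable stand-in for jnp.sum(x > threshold).
    return jnp.sum(soft_greater_sketch(x, threshold, softness))


x = jnp.array([0.2, 0.9, 1.4, 2.0])
print(jnp.sum(x > 1.0))                               # hard count: 2
print(soft_count_above(x, 1.0))                       # soft count, near 2
print(jax.grad(soft_count_above, argnums=1)(x, 1.0))  # negative: a higher threshold lowers the count
```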

greater_equal

diffbio.core.soft_ops.comparison.greater_equal

greater_equal(
    x: Array,
    y: float | Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    epsilon: float = 1e-10,
) -> SoftBool

Soft x >= y.

Uses sigmoidal on x - y + epsilon so the output approaches 1 at equality as softness -> 0.

Parameters:

Name Type Description Default
x Array

First input array.

required
y float | Array

Second input array.

required
softness float | Array

Width of transition (> 0).

0.1
mode Mode

"hard" or sigmoidal mode.

'smooth'
epsilon float

Small offset for non-strict inequality at the limit.

1e-10

Returns:

Type Description
SoftBool

SoftBool in [0, 1].

less

diffbio.core.soft_ops.comparison.less

less(
    x: Array,
    y: float | Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    epsilon: float = 1e-10,
) -> SoftBool

Soft x < y. Complement of greater_equal, i.e., 1 - greater_equal(x, y).

Parameters:

Name Type Description Default
x Array

First input array.

required
y float | Array

Second input array.

required
softness float | Array

Width of transition (> 0).

0.1
mode Mode

"hard" or sigmoidal mode.

'smooth'
epsilon float

Small offset.

1e-10

Returns:

Type Description
SoftBool

SoftBool in [0, 1].

less_equal

diffbio.core.soft_ops.comparison.less_equal

less_equal(
    x: Array,
    y: float | Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    epsilon: float = 1e-10,
) -> SoftBool

Soft x <= y. Complement of greater, i.e., 1 - greater(x, y).

Parameters:

Name Type Description Default
x Array

First input array.

required
y float | Array

Second input array.

required
softness float | Array

Width of transition (> 0).

0.1
mode Mode

"hard" or sigmoidal mode.

'smooth'
epsilon float

Small offset.

1e-10

Returns:

Type Description
SoftBool

SoftBool in [0, 1].

equal

diffbio.core.soft_ops.comparison.equal

equal(
    x: Array,
    y: Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    epsilon: float = 1e-10,
) -> SoftBool

Soft x == y.

Implemented as soft abs(x - y) <= 0, scaled to [0, 1].

Parameters:

Name Type Description Default
x Array

First input array.

required
y Array

Second input array.

required
softness float | Array

Width of transition (> 0).

0.1
mode Mode

"hard" or sigmoidal mode.

'smooth'
epsilon float

Small offset.

1e-10

Returns:

Type Description
SoftBool

SoftBool in [0, 1].

not_equal

diffbio.core.soft_ops.comparison.not_equal

not_equal(
    x: Array,
    y: Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    epsilon: float = 1e-10,
) -> SoftBool

Soft x != y.

Implemented as soft abs(x - y) > 0, scaled to [0, 1].

Parameters:

Name Type Description Default
x Array

First input array.

required
y Array

Second input array.

required
softness float | Array

Width of transition (> 0).

0.1
mode Mode

"hard" or sigmoidal mode.

'smooth'
epsilon float

Small offset.

1e-10

Returns:

Type Description
SoftBool

SoftBool in [0, 1].

isclose

diffbio.core.soft_ops.comparison.isclose

isclose(
    x: Array,
    y: Array,
    softness: float | Array = 0.1,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    mode: Mode = "smooth",
    epsilon: float = 1e-10,
) -> SoftBool

Soft approximate equality.

Implements soft abs(x - y) <= atol + rtol * abs(y).

Parameters:

Name Type Description Default
x Array

First input array.

required
y Array

Second input array.

required
softness float | Array

Width of transition (> 0).

0.1
rtol float

Relative tolerance.

1e-05
atol float

Absolute tolerance.

1e-08
mode Mode

"hard" or sigmoidal mode.

'smooth'
epsilon float

Small offset.

1e-10

Returns:

Type Description
SoftBool

SoftBool in [0, 1].


Logical

Differentiable fuzzy logic operations on SoftBool values. These operate purely on probability values in [0, 1] and do not take a softness parameter.

Fuzzy logic semantics:

  • NOT: 1 - x
  • AND (product): prod(x) or geometric mean
  • OR: 1 - AND(NOT(x))
  • XOR: AND(x, NOT(y)) OR AND(NOT(x), y)
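The fuzzy-logic rules above map onto a few array reductions. This sketch (hypothetical names, not diffbio's code) mirrors the documented constructions: AND as a product or geometric mean, OR via De Morgan, and XOR from stacked ANDs and an OR. How diffbio applies epsilon is an assumption here; the sketch uses it only to guard the logarithm.

```python
import jax.numpy as jnp


def soft_all(p, axis=-1, epsilon=1e-10, use_geometric_mean=False):
    """Hypothetical product / geometric-mean AND (sketch, not diffbio's code)."""
    if use_geometric_mean:
        # Geometric mean computed in log space; epsilon guards log(0).
        return jnp.exp(jnp.mean(jnp.log(jnp.maximum(p, epsilon)), axis=axis))
    return jnp.prod(p, axis=axis)


def soft_any(p, axis=-1, use_geometric_mean=False):
    # De Morgan: OR = NOT(AND(NOT p)).
    return 1.0 - soft_all(1.0 - p, axis=axis, use_geometric_mean=use_geometric_mean)


def soft_xor(x, y):
    # (x AND NOT y) OR (NOT x AND y), built from the stacked reductions above.
    left = soft_all(jnp.stack([x, 1.0 - y]), axis=0)
    right = soft_all(jnp.stack([1.0 - x, y]), axis=0)
    return soft_any(jnp.stack([left, right]), axis=0)


p = jnp.array([0.9, 0.5])
print(soft_all(p))                               # 0.9 * 0.5 = 0.45
print(soft_all(p, use_geometric_mean=True))      # sqrt(0.45), about 0.671
print(soft_any(p))                               # 1 - 0.1 * 0.5 = 0.95
print(soft_xor(jnp.array(1.0), jnp.array(0.0)))  # 1.0 on crisp inputs
```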

logical_not

diffbio.core.soft_ops.logical.logical_not

logical_not(x: SoftBool) -> SoftBool

Soft logical NOT: 1 - x.

Parameters:

Name Type Description Default
x SoftBool

SoftBool input in [0, 1].

required

Returns:

Type Description
SoftBool

Complement in [0, 1].

all

diffbio.core.soft_ops.logical.all

all(
    x: SoftBool,
    axis: int = -1,
    epsilon: float = 1e-10,
    use_geometric_mean: bool = False,
) -> SoftBool

Soft logical AND reduction along axis.

Uses product (default) or geometric mean to combine probabilities.

Parameters:

Name Type Description Default
x SoftBool

SoftBool input in [0, 1].

required
axis int

Axis along which to reduce.

-1
epsilon float

Minimum value for numerical stability in log.

1e-10
use_geometric_mean bool

If True, use geometric mean instead of product.

False

Returns:

Type Description
SoftBool

Reduced SoftBool.

any

diffbio.core.soft_ops.logical.any

any(
    x: SoftBool,
    axis: int = -1,
    use_geometric_mean: bool = False,
) -> SoftBool

Soft logical OR reduction along axis.

Implemented as 1 - all(1 - x).

Parameters:

Name Type Description Default
x SoftBool

SoftBool input in [0, 1].

required
axis int

Axis along which to reduce.

-1
use_geometric_mean bool

If True, use geometric mean in the inner AND.

False

Returns:

Type Description
SoftBool

Reduced SoftBool.

logical_and

diffbio.core.soft_ops.logical.logical_and

logical_and(
    x: SoftBool,
    y: SoftBool,
    use_geometric_mean: bool = False,
) -> SoftBool

Soft logical AND between two SoftBools.

Stacks inputs and applies all along the stack axis.

Parameters:

Name Type Description Default
x SoftBool

First SoftBool.

required
y SoftBool

Second SoftBool.

required
use_geometric_mean bool

If True, use geometric mean.

False

Returns:

Type Description
SoftBool

SoftBool in [0, 1].

logical_or

diffbio.core.soft_ops.logical.logical_or

logical_or(
    x: SoftBool,
    y: SoftBool,
    use_geometric_mean: bool = False,
) -> SoftBool

Soft logical OR between two SoftBools.

Stacks inputs and applies any along the stack axis.

Parameters:

Name Type Description Default
x SoftBool

First SoftBool.

required
y SoftBool

Second SoftBool.

required
use_geometric_mean bool

If True, use geometric mean in inner AND.

False

Returns:

Type Description
SoftBool

SoftBool in [0, 1].

logical_xor

diffbio.core.soft_ops.logical.logical_xor

logical_xor(
    x: SoftBool,
    y: SoftBool,
    use_geometric_mean: bool = False,
) -> SoftBool

Soft logical XOR between two SoftBools.

Implemented as (x AND NOT y) OR (NOT x AND y).

Parameters:

Name Type Description Default
x SoftBool

First SoftBool.

required
y SoftBool

Second SoftBool.

required
use_geometric_mean bool

If True, use geometric mean in AND/OR.

False

Returns:

Type Description
SoftBool

SoftBool in [0, 1].


Selection

Differentiable relaxations of array selection and indexing operations. These use SoftBool conditions and SoftIndex probability distributions in place of discrete boolean masks and integer indices.

where

diffbio.core.soft_ops.selection.where

where(condition: SoftBool, x: Array, y: Array) -> Array

Soft where: x * condition + y * (1 - condition).

Unlike jnp.where, this smoothly interpolates between x and y based on the continuous condition value.

Parameters:

Name Type Description Default
condition SoftBool

SoftBool in [0, 1], same shape as x and y.

required
x Array

Values selected when condition is 1.

required
y Array

Values selected when condition is 0.

required

Returns:

Type Description
Array

Interpolated array.

take_along_axis

diffbio.core.soft_ops.selection.take_along_axis

take_along_axis(
    x: Array, soft_index: SoftIndex, axis: int | None = -1
) -> Array

Soft take_along_axis via weighted dot product.

soft_index must have one more dimension than x: the extra (last) dimension contains the probability distribution over the elements along axis.

Parameters:

Name Type Description Default
x Array

Input array of shape (..., n, ...).

required
soft_index SoftIndex

SoftIndex of shape (..., k, ..., [n]) where [n] is the probability distribution dimension.

required
axis int | None

Axis in x to select from. If None, x is flattened.

-1

Returns:

Type Description
Array

Array of shape (..., k, ...).
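For axis=-1, the weighted dot product amounts to an expectation of x under each row of the soft index. Here is a minimal sketch with a hypothetical name. A one-hot soft index reproduces hard indexing exactly, and a spread-out one interpolates:

```python
import jax
import jax.numpy as jnp


def soft_take_along_last_axis(x, soft_index):
    """Hypothetical sketch for axis=-1: x has shape (..., n) and soft_index
    has shape (..., k, n); the result has shape (..., k)."""
    # Each output is an expectation of x under one row of soft_index.
    return jnp.einsum("...n,...kn->...k", x, soft_index)


x = jnp.array([10.0, 20.0, 30.0])
hard = jax.nn.one_hot(jnp.array([2, 0]), 3)  # shape (2, 3): picks x[2], x[0]
soft = jnp.array([[0.5, 0.5, 0.0]])          # halfway between x[0] and x[1]
print(soft_take_along_last_axis(x, hard))    # 30, 10
print(soft_take_along_last_axis(x, soft))    # 15
```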

take

diffbio.core.soft_ops.selection.take

take(
    x: Array, soft_index: SoftIndex, axis: int | None = None
) -> Array

Soft take via weighted dot product.

Unlike take_along_axis, soft_index is a 2-D matrix of shape (k, [n]) applied uniformly across batch dimensions.

Parameters:

Name Type Description Default
x Array

Input array of shape (..., n, ...).

required
soft_index SoftIndex

SoftIndex of shape (k, [n]).

required
axis int | None

Axis to select from. If None, x is flattened.

None

Returns:

Type Description
Array

Array of shape (..., k, ...).

choose

diffbio.core.soft_ops.selection.choose

choose(soft_index: SoftIndex, choices: Array) -> Array

Soft choose among multiple arrays.

Softly selects among choices using soft_index weights.

Parameters:

Name Type Description Default
soft_index SoftIndex

SoftIndex of shape (..., [n]).

required
choices Array

Array of shape (n, ...).

required

Returns:

Type Description
Array

Weighted combination of choices.

dynamic_index_in_dim

diffbio.core.soft_ops.selection.dynamic_index_in_dim

dynamic_index_in_dim(
    x: Array,
    soft_index: SoftIndex,
    axis: int = 0,
    keepdims: bool = True,
) -> Array

Soft dynamic indexing along a dimension.

Selects a single element (weighted combination) along axis using the probability distribution soft_index.

Parameters:

Name Type Description Default
x Array

Input array of shape (..., n, ...).

required
soft_index SoftIndex

SoftIndex of shape ([n],).

required
axis int

Axis to index.

0
keepdims bool

If True, retains the indexed dimension as size 1.

True

Returns:

Type Description
Array

Indexed array.

dynamic_slice_in_dim

diffbio.core.soft_ops.selection.dynamic_slice_in_dim

dynamic_slice_in_dim(
    x: Array,
    soft_start_index: SoftIndex,
    slice_size: int,
    axis: int = 0,
) -> Array

Soft dynamic slicing along a dimension.

Extracts a soft slice of slice_size elements starting at the position defined by soft_start_index.

Parameters:

Name Type Description Default
x Array

Input array of shape (..., n, ...).

required
soft_start_index SoftIndex

SoftIndex of shape ([n],).

required
slice_size int

Number of elements to extract.

required
axis int

Axis to slice.

0

Returns:

Type Description
Array

Array of shape (..., slice_size, ...).
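One way to realize a soft slice is to form every candidate window and mix them with the start-index distribution. In this hypothetical 1-D sketch, out-of-range starts are clamped the way jax.lax.dynamic_slice clamps them. Whether diffbio handles boundaries the same way is an assumption.

```python
import jax
import jax.numpy as jnp


def soft_dynamic_slice_1d(x, soft_start, slice_size):
    """Hypothetical 1-D sketch: mix every candidate window by soft_start."""
    n = x.shape[0]
    # Clamp starts so each window stays in bounds, mirroring how
    # jax.lax.dynamic_slice clamps out-of-range start indices.
    starts = jnp.minimum(jnp.arange(n), n - slice_size)
    windows = x[starts[:, None] + jnp.arange(slice_size)[None, :]]  # (n, slice_size)
    return soft_start @ windows  # (slice_size,)


x = jnp.arange(5.0)                              # 0, 1, 2, 3, 4
start_at_1 = jax.nn.one_hot(jnp.array(1), 5)
print(soft_dynamic_slice_1d(x, start_at_1, 2))   # 1, 2 (same as x[1:3])
blend = jnp.array([0.5, 0.0, 0.5, 0.0, 0.0])     # half at 0, half at 2
print(soft_dynamic_slice_1d(x, blend, 2))        # 1, 2 (average of [0, 1] and [2, 3])
```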

dynamic_slice

diffbio.core.soft_ops.selection.dynamic_slice

dynamic_slice(
    x: Array,
    soft_start_indices: Sequence[SoftIndex],
    slice_sizes: Sequence[int],
) -> Array

Soft dynamic slicing across multiple dimensions.

Applies dynamic_slice_in_dim sequentially along each axis.

Parameters:

Name Type Description Default
x Array

Input array of shape (n_1, n_2, ..., n_k).

required
soft_start_indices Sequence[SoftIndex]

One SoftIndex per dimension.

required
slice_sizes Sequence[int]

One slice length per dimension.

required

Returns:

Type Description
Array

Array of shape (l_1, l_2, ..., l_k).


Sorting

Differentiable relaxations of discrete ordering operations including argmax/argmin, argsort, sort, rank, and top-k. Multiple algorithmic backends are available:

Method Complexity Default for
"softsort" O(n log n) argmax, argmin
"neuralsort" O(n^2) argsort, sort
"sorting_network" O(n log^2 n) --
"ot" varies --
"fast_soft_sort" O(n log n) --
"smooth_sort" O(n log n) --

The ot, fast_soft_sort, and smooth_sort methods require the soft-ops-advanced optional dependency group.

argmax

diffbio.core.soft_ops.sorting.argmax

argmax(
    x: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: ArgMethod = "softsort",
    standardize: bool = True,
    ot_kwargs: dict | None = None,
) -> SoftIndex

Soft argmax returning a SoftIndex (probability distribution).

Parameters:

Name Type Description Default
x Array

Input array.

required
axis int | None

Axis along which to compute argmax. None flattens first.

None
keepdims bool

If True, keep the reduced dimension as singleton.

False
softness float | Array

Controls sharpness (> 0).

0.1
mode Mode

Smoothness mode.

'smooth'
method ArgMethod

Algorithm: "softsort", "neuralsort", "sorting_network", or "ot".

'softsort'
standardize bool

If True, standardize input for numerical stability.

True
ot_kwargs dict | None

Extra kwargs for OT method.

None

Returns:

Type Description
SoftIndex

SoftIndex of shape (..., {1}, ..., [n]).

max

diffbio.core.soft_ops.sorting.max

max(
    x: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: SortMethod = "softsort",
    standardize: bool = True,
    ot_kwargs: dict | None = None,
    gated_grad: bool = True,
) -> Array

Soft max via argmax + take_along_axis.

For sorting_network method, uses sort + take first element.

Parameters:

Name Type Description Default
x Array

Input array.

required
axis int | None

Axis along which to compute max.

None
keepdims bool

If True, keep reduced dimension.

False
softness float | Array

Controls sharpness (> 0).

0.1
mode Mode

Smoothness mode.

'smooth'
method SortMethod

Algorithm (see argmax and sort).

'softsort'
standardize bool

If True, standardize input.

True
ot_kwargs dict | None

Extra kwargs for OT method.

None
gated_grad bool

If False, stop gradient through soft index.

True

Returns:

Type Description
Array

Soft maximum value(s).
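For argmax, only the first row of the SoftSort matrix is needed. Because max(z) - z_j is never negative, that row equals softmax(z / softness). The sketch below uses hypothetical names and assumes the last axis and a simple mean/std standardization. It shows how max then follows as an expectation under the soft index, the argmax + take_along_axis route described above.

```python
import jax
import jax.numpy as jnp


def soft_argmax_sketch(x, softness=0.1, standardize=True):
    """Hypothetical SoftSort-style soft argmax over the last axis."""
    z = x
    if standardize:
        # Standardizing makes softness act on a scale-free input.
        z = (x - jnp.mean(x, -1, keepdims=True)) / (jnp.std(x, -1, keepdims=True) + 1e-8)
    # First row of the SoftSort matrix: softmax(-|max(z) - z_j| / softness).
    z_max = jnp.max(z, -1, keepdims=True)
    return jax.nn.softmax(-jnp.abs(z_max - z) / softness, axis=-1)


def soft_max_sketch(x, softness=0.1, standardize=True):
    # Soft max = expected value of x under the soft argmax distribution.
    p = soft_argmax_sketch(x, softness, standardize)
    return jnp.sum(p * x, axis=-1)


x = jnp.array([1.0, 3.0, 2.0])
print(soft_argmax_sketch(x))         # nearly one-hot at index 1
print(soft_max_sketch(x))            # close to 3.0
print(jax.grad(soft_max_sketch)(x))  # gradient concentrated on the largest entry
```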

argmin

diffbio.core.soft_ops.sorting.argmin

argmin(
    x: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: ArgMethod = "softsort",
    standardize: bool = True,
    ot_kwargs: dict | None = None,
) -> SoftIndex

Soft argmin: argmax applied to -x. Parameters are as for argmax.

min

diffbio.core.soft_ops.sorting.min

min(
    x: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: SortMethod = "softsort",
    standardize: bool = True,
    ot_kwargs: dict | None = None,
    gated_grad: bool = True,
) -> Array

Soft min: -max(-x).

argsort

diffbio.core.soft_ops.sorting.argsort

argsort(
    x: Array,
    axis: int | None = None,
    descending: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: ArgMethod = "neuralsort",
    standardize: bool = True,
    ot_kwargs: dict | None = None,
) -> SoftIndex

Soft argsort returning a soft permutation matrix.

Output shape is (..., n, ..., [n]) where the last dimension is the probability distribution over original elements.

Parameters:

Name Type Description Default
x Array

Input array.

required
axis int | None

Axis along which to argsort. None flattens first.

None
descending bool

If True, sort descending.

False
softness float | Array

Controls sharpness (> 0).

0.1
mode Mode

Smoothness mode.

'smooth'
method ArgMethod

Algorithm.

'neuralsort'
standardize bool

If True, standardize input.

True
ot_kwargs dict | None

Extra kwargs for OT method.

None

Returns:

Type Description
SoftIndex

SoftIndex permutation matrix.

sort

diffbio.core.soft_ops.sorting.sort

sort(
    x: Array,
    axis: int | None = None,
    descending: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: SortMethod = "neuralsort",
    standardize: bool = True,
    ot_kwargs: dict | None = None,
    gated_grad: bool = True,
) -> Array

Soft sort returning sorted values.

Parameters:

Name Type Description Default
x Array

Input array.

required
axis int | None

Axis along which to sort. None flattens first.

None
descending bool

If True, sort descending.

False
softness float | Array

Controls sharpness (> 0).

0.1
mode Mode

Smoothness mode.

'smooth'
method SortMethod

Algorithm.

'neuralsort'
standardize bool

If True, standardize input.

True
ot_kwargs dict | None

Extra kwargs for OT method.

None
gated_grad bool

If False, stop gradient through soft index.

True

Returns:

Type Description
Array

Soft-sorted values.
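The default "neuralsort" method takes its name from the NeuralSort relaxation (Grover et al., ICLR 2019). In that relaxation, row i of the relaxed permutation matrix is softmax(((n + 1 - 2i) s - A_s 1) / softness), where A_s is the matrix of pairwise absolute differences; that row selects the i-th largest element. Below is a minimal sketch of that formula without standardization. The names are hypothetical, and diffbio's implementation may differ in details.

```python
import jax
import jax.numpy as jnp


def neuralsort_sketch(s, softness=0.1, descending=False):
    """Hypothetical NeuralSort relaxation over the last axis (no standardization).

    Returns a soft permutation matrix of shape (..., n, [n]): row r is a
    probability distribution over the original elements.
    """
    n = s.shape[-1]
    # A_s[j, k] = |s_j - s_k|; its row sums are A_s @ 1.
    row_sums = jnp.sum(jnp.abs(s[..., :, None] - s[..., None, :]), axis=-1)
    # Row i (1-based) scores element j by (n + 1 - 2i) * s_j - (A_s @ 1)_j.
    scale = (n + 1 - 2 * jnp.arange(1, n + 1)).astype(s.dtype)
    logits = scale[:, None] * s[..., None, :] - row_sums[..., None, :]
    p = jax.nn.softmax(logits / softness, axis=-1)  # rows: largest first
    return p if descending else p[..., ::-1, :]


def soft_sort_sketch(s, softness=0.1, descending=False):
    # Sorted values = soft permutation applied to the inputs.
    p = neuralsort_sketch(s, softness, descending)
    return jnp.einsum("...rn,...n->...r", p, s)


s = jnp.array([1.0, 3.0, 2.0])
print(soft_sort_sketch(s))                   # close to 1, 2, 3
print(soft_sort_sketch(s, descending=True))  # close to 3, 2, 1
```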

rank

diffbio.core.soft_ops.sorting.rank

rank(
    x: Array,
    axis: int | None = None,
    descending: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: RankMethod = "softsort",
    standardize: bool = True,
) -> Array

Soft fractional ranking.

Returns continuous ranks in [1, n] where 1 is the smallest.

Parameters:

Name Type Description Default
x Array

Input array.

required
axis int | None

Axis along which to rank.

None
descending bool

If True, rank 1 = largest.

False
softness float | Array

Controls sharpness (> 0).

0.1
mode Mode

Smoothness mode.

'smooth'
method RankMethod

"softsort" or "neuralsort".

'softsort'
standardize bool

If True, standardize input.

True

Returns:

Type Description
Array

Continuous ranks.
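Soft ranks can also be illustrated with a simpler relaxation than the documented "softsort" and "neuralsort" methods: smoothly count how many elements each entry exceeds. This pairwise-sigmoid sketch is a swapped-in technique for illustration, not diffbio's rank implementation, and it uses O(n^2) memory.

```python
import jax
import jax.numpy as jnp


def pairwise_soft_rank(x, softness=0.1, descending=False):
    """Pairwise-sigmoid soft rank (a simpler relaxation, not softsort/neuralsort).

    rank_i = 0.5 + sum_j sigmoid((x_i - x_j) / softness); the j = i term adds
    0.5, so a clearly smallest element gets a rank close to 1.
    """
    diff = x[..., :, None] - x[..., None, :]
    if descending:
        diff = -diff  # rank 1 = largest
    return 0.5 + jnp.sum(jax.nn.sigmoid(diff / softness), axis=-1)


x = jnp.array([0.3, -1.0, 2.0])
r = pairwise_soft_rank(x)
print(r)        # close to 2, 1, 3
print(r.sum())  # n * (n + 1) / 2 = 6 up to rounding, since sigmoid(d) + sigmoid(-d) = 1
```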

top_k

diffbio.core.soft_ops.sorting.top_k

top_k(
    x: Array,
    k: int,
    axis: int = -1,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: SortMethod = "neuralsort",
    standardize: bool = True,
    ot_kwargs: dict | None = None,
    gated_grad: bool = True,
) -> tuple[Array, SoftIndex | None]

Soft top-k selection.

Returns the k largest values and their soft indices.

Parameters:

Name Type Description Default
x Array

Input array.

required
k int

Number of top elements.

required
axis int

Axis along which to select. Default -1 (last axis).

-1
softness float | Array

Controls sharpness (> 0).

0.1
mode Mode

Smoothness mode.

'smooth'
method SortMethod

Sorting algorithm. Default "neuralsort".

'neuralsort'
standardize bool

If True, standardize input.

True
ot_kwargs dict | None

Extra keyword arguments for OT-based methods.

None
gated_grad bool

If False, stop gradient through soft index.

True

Returns:

Type Description
tuple[Array, SoftIndex | None]

Tuple of (values, soft_indices), where values has shape (..., k, ...) and soft_indices has shape (..., k, ..., [n]). soft_indices may be None for methods that only return values (fast_soft_sort, sorting_network).


Quantile

Differentiable relaxations of quantile-based statistics. Quantiles are computed via soft argsort or soft sort with interpolation following the same methods as jax.numpy.quantile.

argquantile

diffbio.core.soft_ops.quantile.argquantile

argquantile(
    x: Array,
    q: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: ArgMethod = "neuralsort",
    quantile_method: Literal[
        "linear", "lower", "higher", "nearest", "midpoint"
    ] = "linear",
    standardize: bool = True,
) -> SoftIndex

Soft argquantile returning SoftIndex.

Parameters:

Name Type Description Default
x Array

Input array.

required
q Array

Quantile(s) in [0, 1]. Scalar or 1-D array.

required
axis int | None

Axis along which to compute. None flattens.

None
keepdims bool

If True, keep reduced dimension.

False
softness float | Array

Controls sharpness (> 0).

0.1
mode Mode

Smoothness mode.

'smooth'
method ArgMethod

Algorithm.

'neuralsort'
quantile_method Literal['linear', 'lower', 'higher', 'nearest', 'midpoint']

Interpolation method.

'linear'
standardize bool

If True, standardize input.

True

Returns:

Type Description
SoftIndex

SoftIndex probability distribution over quantile position(s).

quantile

diffbio.core.soft_ops.quantile.quantile

quantile(
    x: Array,
    q: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: ArgMethod = "neuralsort",
    quantile_method: Literal[
        "linear", "lower", "higher", "nearest", "midpoint"
    ] = "linear",
    standardize: bool = True,
    gated_grad: bool = True,
) -> Array

Soft quantile returning value.

Implemented as argquantile + take_along_axis for most methods.

Parameters:

Name Type Description Default
x Array

Input array.

required
q Array

Quantile(s) in [0, 1].

required
axis int | None

Axis along which to compute.

None
keepdims bool

If True, keep reduced dimension.

False
softness float | Array

Controls sharpness (> 0).

0.1
mode Mode

Smoothness mode.

'smooth'
method ArgMethod

Algorithm.

'neuralsort'
quantile_method Literal['linear', 'lower', 'higher', 'nearest', 'midpoint']

Interpolation method.

'linear'
standardize bool

If True, standardize input.

True
gated_grad bool

If False, stop gradient through soft index.

True

Returns:

Type Description
Array

Quantile value(s).
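The "linear" quantile method is a fixed mix of two adjacent sorted positions. Applying those mixing weights to the rows of a soft permutation matrix therefore yields a SoftIndex over the original elements (argquantile). A dot product with x then yields the value (quantile). This hypothetical sketch checks the construction against jnp.quantile using a hard permutation. In practice the permutation would come from a soft argsort.

```python
import jax
import jax.numpy as jnp


def linear_quantile_weights(n, q):
    """Weights over sorted positions for the 'linear' quantile method."""
    pos = q * (n - 1)
    lo = jnp.floor(pos).astype(jnp.int32)
    hi = jnp.minimum(lo + 1, n - 1)
    frac = pos - lo
    return (1.0 - frac) * jax.nn.one_hot(lo, n) + frac * jax.nn.one_hot(hi, n)


def argquantile_from_permutation(perm, q):
    # perm: (n, [n]) soft permutation, row r = distribution over the original
    # elements for the r-th smallest. Mixing rows gives a SoftIndex over elements.
    return linear_quantile_weights(perm.shape[-1], q) @ perm


x = jnp.array([4.0, 1.0, 3.0, 2.0])
perm = jax.nn.one_hot(jnp.argsort(x), 4)       # hard permutation, for checking
idx = argquantile_from_permutation(perm, 0.3)  # SoftIndex over the 4 inputs
print(idx @ x, jnp.quantile(x, 0.3))           # both approximately 1.9
```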

argmedian

diffbio.core.soft_ops.quantile.argmedian

argmedian(
    x: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: ArgMethod = "neuralsort",
    standardize: bool = True,
) -> SoftIndex

Soft argmedian: argquantile with q=0.5.

median

diffbio.core.soft_ops.quantile.median

median(
    x: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: ArgMethod = "neuralsort",
    standardize: bool = True,
    gated_grad: bool = True,
) -> Array

Soft median: quantile with q=0.5.

argpercentile

diffbio.core.soft_ops.quantile.argpercentile

argpercentile(
    x: Array,
    p: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: ArgMethod = "neuralsort",
    standardize: bool = True,
) -> SoftIndex

Soft argpercentile: argquantile with q = p / 100.

percentile

diffbio.core.soft_ops.quantile.percentile

percentile(
    x: Array,
    p: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: ArgMethod = "neuralsort",
    standardize: bool = True,
    gated_grad: bool = True,
) -> Array

Soft percentile: quantile with q = p / 100.


Straight-Through Estimators

Straight-through estimators use the hard (exact, non-differentiable) function for the forward pass but route gradients through the soft (differentiable) version during backpropagation. The trick stop_gradient(hard - soft) + soft ensures the forward output is exact while gradients flow through the smooth relaxation.

st

diffbio.core.soft_ops.straight_through.st

st(fn: Callable) -> Callable

Decorator creating a straight-through estimator from a soft_ops function.

The decorated function is called twice: once with mode="hard" (forward pass) and once with the specified mode (backward pass). The trick stop_gradient(hard - soft) + soft ensures the forward output is hard but gradients flow through the soft version.

Parameters:

Name Type Description Default
fn Callable

Function whose mode parameter accepts "hard" (e.g., any elementwise, comparison, sorting, or quantile function that supports hard mode).

required

Returns:

Type Description
Callable

Wrapped straight-through estimator function.
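The st decorator pattern can be sketched as follows. heaviside_toy and st_sketch are hypothetical stand-ins, not diffbio functions. The forward value matches the hard function, while jax.grad returns the soft function's gradient.

```python
import jax
import jax.numpy as jnp


def heaviside_toy(x, softness=0.1, mode="smooth"):
    # Hypothetical stand-in for a soft_ops function with a mode parameter.
    if mode == "hard":
        return jnp.heaviside(x, 0.5)
    return jax.nn.sigmoid(x / softness)


def st_sketch(fn):
    """Sketch of a straight-through decorator: hard forward, soft gradient."""
    def wrapped(*args, mode="smooth", **kwargs):
        hard = fn(*args, mode="hard", **kwargs)
        soft = fn(*args, mode=mode, **kwargs)
        # Value equals hard (up to rounding); gradient is that of soft.
        return soft + jax.lax.stop_gradient(hard - soft)
    return wrapped


heaviside_st_toy = st_sketch(heaviside_toy)
print(heaviside_st_toy(jnp.array(0.3)))  # 1.0, the hard forward value
print(jax.grad(heaviside_st_toy)(0.3))   # equals the soft gradient below
print(jax.grad(heaviside_toy)(0.3))
```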

grad_replace

diffbio.core.soft_ops.straight_through.grad_replace

grad_replace(fn: Callable) -> Callable

Decorator for custom forward/backward computation split.

The decorated function is called twice: once with forward=True (output used for forward pass) and once with forward=False (output used for gradient computation).

Parameters:

Name Type Description Default
fn Callable

Function accepting a forward: bool keyword argument.

required

Returns:

Type Description
Callable

Wrapped function whose output value comes from the forward=True call and whose gradients come from the forward=False call.
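grad_replace generalizes this to any pair of computations. In the hypothetical sketch below, the forward pass clips to [0, 1] while the backward pass treats the operation as the identity, a classic straight-through clipping estimator. All names are illustrative only.

```python
import jax
import jax.numpy as jnp


def grad_replace_sketch(fn):
    """Sketch: value from fn(forward=True), gradient from fn(forward=False)."""
    def wrapped(*args, **kwargs):
        fwd = fn(*args, forward=True, **kwargs)
        bwd = fn(*args, forward=False, **kwargs)
        return bwd + jax.lax.stop_gradient(fwd - bwd)
    return wrapped


@grad_replace_sketch
def clip_identity_grad(x, forward=True):
    # Hypothetical example: clip in the forward pass, identity gradient.
    return jnp.clip(x, 0.0, 1.0) if forward else x


print(clip_identity_grad(jnp.array(2.0)))  # 1.0 (clipped value)
print(jax.grad(clip_identity_grad)(2.0))   # 1.0 (plain jnp.clip would give 0.0)
```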

Pre-built _st Variants

The module provides 27 pre-built straight-through variants, one for each soft operation whose mode parameter accepts "hard" (sigmoidal and softrelu, which take only sigmoidal modes, have no _st variant). Each variant uses the hard function for the forward pass and the corresponding soft function for the backward pass.

  • Elementwise: abs_st, clip_st, heaviside_st, relu_st, round_st, sign_st
  • Comparison: equal_st, greater_st, greater_equal_st, isclose_st, less_st, less_equal_st, not_equal_st
  • Sorting: argmax_st, argmin_st, argsort_st, max_st, min_st, rank_st, sort_st, top_k_st
  • Quantile: argmedian_st, argpercentile_st, argquantile_st, median_st, percentile_st, quantile_st

Each _st variant accepts the same arguments as its base function. For example:

import jax.numpy as jnp
from diffbio.core.soft_ops import relu_st, sort_st

x = jnp.array([0.5, -1.2, 2.0, 0.1])

# Forward uses jax.nn.relu; backward uses soft relu
y = relu_st(x, softness=0.1, mode="smooth")

# Forward uses jnp.sort; backward uses soft sort
y_sorted = sort_st(x, axis=-1, softness=0.1, mode="smooth")