Soft Operations
The diffbio.core.soft_ops module provides smooth, differentiable relaxations of discrete, piecewise-linear, and sharp operations. These relaxations enable end-to-end gradient-based optimization through operations that are normally non-differentiable: comparisons, sorting, indexing, rounding, and logical gates all become continuous functions with well-defined gradients.
Acknowledgment
The soft operations in this module are based on the algorithms and implementations from SoftJAX (Paulus et al., 2026; arXiv:2603.08824), adapted for the DiffBio/JAX/Flax NNX ecosystem.
All soft operations are JIT-compatible and support jax.grad and jax.vmap. Most accept a softness parameter controlling the width of the transition region (higher = smoother) and a mode parameter selecting the smoothness family:
| Mode | Description |
|---|---|
| `"hard"` | Exact (non-differentiable) version matching JAX |
| `"smooth"` | C-infinity smooth via logistic sigmoid |
| `"c0"` | Continuous (C0) via piecewise linear/quadratic |
| `"c1"` | Once differentiable (C1) via cubic Hermite polynomial |
| `"c2"` | Twice differentiable (C2) via quintic Hermite polynomial |
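To make the role of softness concrete, the following is a minimal pure-Python sketch of the "hard" and "smooth" modes for a unit step. It assumes the smooth mode scales its input as x / softness before applying the logistic sigmoid; that scaling, and the helper names `hard_step` and `smooth_step`, are illustrative assumptions rather than the module's implementation.

```python
import math


def hard_step(x: float) -> float:
    """Exact step: 0 below zero, 1 above zero, 0.5 exactly at zero."""
    if x < 0:
        return 0.0
    if x > 0:
        return 1.0
    return 0.5


def smooth_step(x: float, softness: float = 0.1) -> float:
    """Logistic relaxation of the step (assumed scaling: x / softness)."""
    z = x / softness
    # Numerically stable logistic: never exponentiate a large positive number.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


# A larger softness widens the transition, so the same input sits closer to 0.5.
for softness in (0.01, 0.1, 1.0):
    print(softness, smooth_step(0.05, softness))
```

Under this assumption, shrinking softness moves the smooth output toward the hard step, while growing it flattens the curve toward 0.5.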
Name shadowing
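Several names (abs, all, any, round, max, min) shadow Python builtins.
Use a qualified import (for example, `from diffbio.core import soft_ops`, then call `soft_ops.abs(x)`) or alias the name on import (for example, `from diffbio.core.soft_ops import abs as soft_abs`).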
Types
SoftBool
diffbio.core.soft_ops._types.SoftBool (module attribute)
Soft boolean: probability in [0, 1].
SoftIndex
diffbio.core.soft_ops._types.SoftIndex (module attribute)
Soft index: probabilities summing to 1 along the last axis.
Autograd-Safe Math
NaN-free alternatives to standard math functions. These use the double-where trick so that the forward pass computes correct values even at domain boundaries, while the backward pass produces finite (zero) gradients instead of NaN or Inf.
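The double-where pattern can be sketched without an autodiff library by carrying the derivative by hand. The example below is a plain-Python illustration for sqrt, not the module's code: the helper name `safe_sqrt` and the placeholder value 1.0 are assumptions. In JAX the same structure would be written with two nested `jnp.where` calls, and the derivative would come from `jax.grad` instead of being computed manually.

```python
import math


def safe_sqrt(x: float) -> tuple[float, float]:
    """Return (sqrt(x), d/dx sqrt(x)), with value 0 and gradient 0 for x <= 0."""
    in_domain = x > 0
    # Inner "where": replace unsafe inputs with a harmless placeholder so that
    # neither the value nor the derivative is ever evaluated at x <= 0.
    x_safe = x if in_domain else 1.0
    value = math.sqrt(x_safe)
    grad = 0.5 / math.sqrt(x_safe)
    # Outer "where": discard the placeholder results outside the domain.
    if in_domain:
        return value, grad
    return 0.0, 0.0


print(safe_sqrt(4.0))
print(safe_sqrt(0.0))
```

Without the inner substitution, the derivative 0.5 / sqrt(x) would be evaluated at x = 0, which is the source of the Inf or NaN gradients the section describes.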
sqrt
diffbio.core.soft_ops.autograd_safe.sqrt
Autograd-safe square root.
Returns sqrt(x) for x > 0 and 0 otherwise, without producing NaN gradients at x = 0.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
Returns:
| Type | Description |
|---|---|
| `Array` | Elementwise square root, safe for autodiff. |
arcsin
diffbio.core.soft_ops.autograd_safe.arcsin
Autograd-safe arcsine.
Returns arcsin(x) for |x| < 1 and +/-pi/2 at the boundary, without producing NaN gradients at x = +/-1.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array with values in [-1, 1]. | required |
Returns:
| Type | Description |
|---|---|
| `Array` | Elementwise arcsine, safe for autodiff. |
arccos
diffbio.core.soft_ops.autograd_safe.arccos
Autograd-safe arccosine.
Returns arccos(x) for |x| < 1, 0 at x = 1, and pi at x = -1, without producing NaN gradients at the boundary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array with values in [-1, 1]. | required |
Returns:
| Type | Description |
|---|---|
| `Array` | Elementwise arccosine, safe for autodiff. |
div
diffbio.core.soft_ops.autograd_safe.div
Autograd-safe division.
Returns x / y when y != 0 and 0 otherwise, without producing NaN gradients at y = 0.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Numerator array. | required |
| `y` | `Array` | Denominator array. | required |
Returns:
| Type | Description |
|---|---|
| `Array` | Elementwise safe division. |
log
diffbio.core.soft_ops.autograd_safe.log
Autograd-safe natural logarithm.
Returns log(x) for x > 0 and 0 otherwise, without producing NaN gradients at x = 0.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
Returns:
| Type | Description |
|---|---|
| `Array` | Elementwise natural logarithm, safe for autodiff. |
norm
diffbio.core.soft_ops.autograd_safe.norm
Autograd-safe L2 norm.
Computes sqrt(sum(x**2)) using the autograd-safe sqrt, avoiding NaN gradients when the norm is zero.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
| `axis` | `int \| None` | Axis or axes along which to compute the norm. | `None` |
| `keepdims` | `bool` | If True, retains reduced axes with size 1. | `False` |
Returns:
| Type | Description |
|---|---|
| `Array` | L2 norm along the given axis, safe for autodiff. |
Elementwise
Differentiable relaxations of elementwise non-smooth functions using sigmoidal smoothing with configurable smoothness modes.
sigmoidal
diffbio.core.soft_ops.elementwise.sigmoidal
sigmoidal(
    x: Array,
    softness: float | Array = 0.1,
    mode: SigmoidalMode = "smooth",
) -> SoftBool
Sigmoidal S-curve function mapping R -> (0, 1).
Foundation for all other elementwise operations. Maps input values through an S-shaped curve centered at 0, approaching 0 for large negative values and 1 for large positive values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
| `softness` | `float \| Array` | Width of transition region (> 0). Higher = smoother. | `0.1` |
| `mode` | `SigmoidalMode` | Smoothness family: | `'smooth'` |
Returns:
| Type | Description |
|---|---|
| `SoftBool` | SoftBool array with values in [0, 1]. |
softrelu
diffbio.core.soft_ops.elementwise.softrelu
softrelu(
    x: Array,
    softness: float | Array = 0.1,
    mode: SigmoidalMode = "smooth",
    gated: bool = False,
) -> Array
Family of soft relaxations to ReLU.
Two variants:
- Non-gated (default): Antiderivative of sigmoidal. Smooth analog of max(0, x).
- Gated: x * sigmoidal(x). SiLU-style gating.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
| `softness` | `float \| Array` | Width of transition region (> 0). | `0.1` |
| `mode` | `SigmoidalMode` | Smoothness family (see :func: | `'smooth'` |
| `gated` | `bool` | If True, use gated version | `False` |
Returns:
| Type | Description |
|---|---|
| `Array` | Soft ReLU output, same shape as |
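The two softrelu variants described above can be sketched in plain Python for the logistic ("smooth") case. This assumes the sigmoid is applied to x / softness, in which case its antiderivative is the scaled softplus softness * log(1 + exp(x / softness)); the helper names below are illustrative and the other smoothness modes are not covered.

```python
import math


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def softplus(z: float) -> float:
    """Stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def softrelu(x: float, softness: float = 0.1, gated: bool = False) -> float:
    z = x / softness  # assumed scaling
    if gated:
        # SiLU-style gating: the input multiplied by its own soft gate.
        return x * sigmoid(z)
    # Antiderivative of sigmoid(x / softness) with respect to x.
    return softness * softplus(z)


print(softrelu(0.3), softrelu(0.3, gated=True))
```

The non-gated variant is always non-negative and equals softness * log(2) at zero, whereas the gated variant passes through zero and dips slightly below it for small negative inputs.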
heaviside
diffbio.core.soft_ops.elementwise.heaviside
heaviside(
    x: Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
) -> SoftBool
Soft Heaviside step function.
Returns 0 for x < 0, 1 for x > 0, and 0.5 at x = 0 (hard mode).
Soft modes use sigmoidal for a smooth transition.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
| `softness` | `float \| Array` | Width of transition (> 0). | `0.1` |
| `mode` | `Mode` | | `'smooth'` |
Returns:
| Type | Description |
|---|---|
| `SoftBool` | SoftBool in [0, 1]. |
round
diffbio.core.soft_ops.elementwise.round
round(
    x: Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    neighbor_radius: int = 5,
) -> Array
Soft rounding.
Hard mode returns jnp.round(x). Soft modes use a weighted sum of nearby integers, with weights from sigmoidal.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
| `softness` | `float \| Array` | Width of transition (> 0). | `0.1` |
| `mode` | `Mode` | | `'smooth'` |
| `neighbor_radius` | `int` | Number of integer neighbors to consider. | `5` |
Returns:
| Type | Description |
|---|---|
| `Array` | Soft-rounded values. |
sign
diffbio.core.soft_ops.elementwise.sign
Soft sign function.
Maps to [-1, 1]. Hard mode returns jnp.sign(x).
Soft modes use 2 * sigmoidal(x) - 1.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
| `softness` | `float \| Array` | Width of transition (> 0). | `0.1` |
| `mode` | `Mode` | | `'smooth'` |
Returns:
| Type | Description |
|---|---|
| `Array` | Values in [-1, 1]. |
abs
diffbio.core.soft_ops.elementwise.abs
Soft absolute value.
Hard mode returns jnp.abs(x). Soft modes use x * sign(x, softness, mode).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
| `softness` | `float \| Array` | Width of transition (> 0). | `0.1` |
| `mode` | `Mode` | | `'smooth'` |
Returns:
| Type | Description |
|---|---|
| `Array` | Non-negative values (approximately). |
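The sign and abs relaxations compose directly from the soft step. A minimal sketch for the logistic case follows, using the identity 2 * sigmoid(z) - 1 = tanh(z / 2) for numerical stability. The x / softness scaling and the names `soft_sign` and `soft_abs` are assumptions for illustration.

```python
import math


def soft_sign(x: float, softness: float = 0.1) -> float:
    """2 * sigmoid(x / softness) - 1, written as tanh for stability."""
    return math.tanh(x / (2.0 * softness))


def soft_abs(x: float, softness: float = 0.1) -> float:
    """x * soft_sign(x): smooth at zero, close to |x| away from zero."""
    return x * soft_sign(x, softness)


for x in (-2.0, -0.05, 0.0, 0.05, 2.0):
    print(x, soft_sign(x), soft_abs(x))
```

Near zero the soft absolute value underestimates |x|, which is the price paid for removing the kink at the origin.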
relu
diffbio.core.soft_ops.elementwise.relu
relu(
    x: Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    gated: bool = False,
) -> Array
Soft ReLU.
Hard mode returns jax.nn.relu(x). Soft modes delegate to softrelu.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
| `softness` | `float \| Array` | Width of transition (> 0). | `0.1` |
| `mode` | `Mode` | | `'smooth'` |
| `gated` | `bool` | If True, use gated variant. | `False` |
Returns:
| Type | Description |
|---|---|
| `Array` | Soft ReLU output. |
clip
diffbio.core.soft_ops.elementwise.clip
clip(
    x: Array,
    a: float | Array,
    b: float | Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    gated: bool = False,
) -> Array
Soft clipping to [a, b].
Hard mode returns jnp.clip(x, a, b). Soft modes use two softrelu calls: a + softrelu(x - a) - softrelu(x - b).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
| `a` | `float \| Array` | Lower bound. | required |
| `b` | `float \| Array` | Upper bound. | required |
| `softness` | `float \| Array` | Width of transition (> 0). | `0.1` |
| `mode` | `Mode` | | `'smooth'` |
| `gated` | `bool` | If True, use gated softrelu variant. | `False` |
Returns:
| Type | Description |
|---|---|
| `Array` | Clipped values approximately in [a, b]. |
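The clip formula a + softrelu(x - a) - softrelu(x - b) can be checked with a small standalone sketch. It uses a non-gated logistic softrelu (scaled softplus) as an assumed stand-in for the module's function; the helper names are illustrative.

```python
import math


def softrelu(x: float, softness: float = 0.1) -> float:
    """Non-gated soft ReLU: softness * log(1 + exp(x / softness)), computed stably."""
    z = x / softness
    return softness * (max(z, 0.0) + math.log1p(math.exp(-abs(z))))


def soft_clip(x: float, a: float, b: float, softness: float = 0.1) -> float:
    # The first ramp switches on at a; the second cancels the growth beyond b.
    return a + softrelu(x - a, softness) - softrelu(x - b, softness)


for x in (-10.0, 0.0, 0.5, 1.0, 10.0):
    print(x, soft_clip(x, 0.0, 1.0))
```

Far from the bounds the result matches the hard clip; at the bounds themselves the output is pulled slightly inside the interval.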
Comparison
Differentiable relaxations of elementwise comparison operations, returning SoftBool values in [0, 1]. Each function uses sigmoidal as the underlying smooth step function.
greater
diffbio.core.soft_ops.comparison.greater
greater(
    x: Array,
    y: float | Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    epsilon: float = 1e-10,
) -> SoftBool
Soft x > y.
Uses sigmoidal on x - y - epsilon so the output approaches 0 at equality as softness -> 0.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | First input array. | required |
| `y` | `float \| Array` | Second input array (broadcastable with x). | required |
| `softness` | `float \| Array` | Width of transition (> 0). | `0.1` |
| `mode` | `Mode` | | `'smooth'` |
| `epsilon` | `float` | Small offset for strict inequality at the limit. | `1e-10` |
Returns:
| Type | Description |
|---|---|
| `SoftBool` | SoftBool in [0, 1]. |
greater_equal
diffbio.core.soft_ops.comparison.greater_equal
greater_equal(
    x: Array,
    y: float | Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    epsilon: float = 1e-10,
) -> SoftBool
Soft x >= y.
Uses sigmoidal on x - y + epsilon so the output approaches 1 at equality as softness -> 0.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | First input array. | required |
| `y` | `float \| Array` | Second input array. | required |
| `softness` | `float \| Array` | Width of transition (> 0). | `0.1` |
| `mode` | `Mode` | | `'smooth'` |
| `epsilon` | `float` | Small offset for non-strict inequality at the limit. | `1e-10` |
Returns:
| Type | Description |
|---|---|
| `SoftBool` | SoftBool in [0, 1]. |
less
diffbio.core.soft_ops.comparison.less
less(
    x: Array,
    y: float | Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    epsilon: float = 1e-10,
) -> SoftBool
Soft x < y. Complement of greater_equal.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | First input array. | required |
| `y` | `float \| Array` | Second input array. | required |
| `softness` | `float \| Array` | Width of transition (> 0). | `0.1` |
| `mode` | `Mode` | | `'smooth'` |
| `epsilon` | `float` | Small offset. | `1e-10` |
Returns:
| Type | Description |
|---|---|
| `SoftBool` | SoftBool in [0, 1]. |
less_equal
diffbio.core.soft_ops.comparison.less_equal
less_equal(
    x: Array,
    y: float | Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    epsilon: float = 1e-10,
) -> SoftBool
Soft x <= y. Complement of greater.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | First input array. | required |
| `y` | `float \| Array` | Second input array. | required |
| `softness` | `float \| Array` | Width of transition (> 0). | `0.1` |
| `mode` | `Mode` | | `'smooth'` |
| `epsilon` | `float` | Small offset. | `1e-10` |
Returns:
| Type | Description |
|---|---|
| `SoftBool` | SoftBool in [0, 1]. |
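The four ordering comparisons above reduce to one soft step plus a sign convention for epsilon. The sketch below illustrates this in plain Python for the logistic case, assuming x / softness scaling; the function names mirror the documented ones but this is not the module's code.

```python
import math


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def greater(x, y, softness=0.1, epsilon=1e-10):
    # Shifting by -epsilon makes exact ties fall on the "false" side as softness -> 0.
    return sigmoid((x - y - epsilon) / softness)


def greater_equal(x, y, softness=0.1, epsilon=1e-10):
    # Shifting by +epsilon makes exact ties fall on the "true" side as softness -> 0.
    return sigmoid((x - y + epsilon) / softness)


def less(x, y, softness=0.1, epsilon=1e-10):
    return 1.0 - greater_equal(x, y, softness, epsilon)


def less_equal(x, y, softness=0.1, epsilon=1e-10):
    return 1.0 - greater(x, y, softness, epsilon)


# At a tie the default softness gives roughly 0.5; a tiny softness separates the cases.
print(greater(1.0, 1.0), greater(1.0, 1.0, softness=1e-12))
print(greater_equal(1.0, 1.0), greater_equal(1.0, 1.0, softness=1e-12))
```

With the default softness of 0.1, the epsilon offset is negligible and ties evaluate to about 0.5 for all four functions; the offset only matters in the sharp limit.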
equal
diffbio.core.soft_ops.comparison.equal
equal(
    x: Array,
    y: Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    epsilon: float = 1e-10,
) -> SoftBool
Soft x == y.
Implemented as soft abs(x - y) <= 0, scaled to [0, 1].
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | First input array. | required |
| `y` | `Array` | Second input array. | required |
| `softness` | `float \| Array` | Width of transition (> 0). | `0.1` |
| `mode` | `Mode` | | `'smooth'` |
| `epsilon` | `float` | Small offset. | `1e-10` |
Returns:
| Type | Description |
|---|---|
| `SoftBool` | SoftBool in [0, 1]. |
not_equal
diffbio.core.soft_ops.comparison.not_equal
not_equal(
    x: Array,
    y: Array,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    epsilon: float = 1e-10,
) -> SoftBool
Soft x != y.
Implemented as soft abs(x - y) > 0, scaled to [0, 1].
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | First input array. | required |
| `y` | `Array` | Second input array. | required |
| `softness` | `float \| Array` | Width of transition (> 0). | `0.1` |
| `mode` | `Mode` | | `'smooth'` |
| `epsilon` | `float` | Small offset. | `1e-10` |
Returns:
| Type | Description |
|---|---|
| `SoftBool` | SoftBool in [0, 1]. |
isclose
diffbio.core.soft_ops.comparison.isclose
isclose(
    x: Array,
    y: Array,
    softness: float | Array = 0.1,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    mode: Mode = "smooth",
    epsilon: float = 1e-10,
) -> SoftBool
Soft approximate equality.
Implements soft abs(x - y) <= atol + rtol * abs(y).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | First input array. | required |
| `y` | `Array` | Second input array. | required |
| `softness` | `float \| Array` | Width of transition (> 0). | `0.1` |
| `rtol` | `float` | Relative tolerance. | `1e-05` |
| `atol` | `float` | Absolute tolerance. | `1e-08` |
| `mode` | `Mode` | | `'smooth'` |
| `epsilon` | `float` | Small offset. | `1e-10` |
Returns:
| Type | Description |
|---|---|
| `SoftBool` | SoftBool in [0, 1]. |
Logical
Differentiable fuzzy logic operations on SoftBool values. These operate purely on probability values in [0, 1] and do not take a softness parameter.
Fuzzy logic semantics:
- NOT: 1 - x
- AND (product): prod(x) or geometric mean
- OR: 1 - AND(NOT(x))
- XOR: AND(x, NOT(y)) OR AND(NOT(x), y)
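The semantics listed above translate almost line for line into code. The following plain-Python sketch works on lists of probabilities rather than arrays; the function names are illustrative, and clamping to epsilon before the logarithm in the geometric-mean branch is an assumption about how the stability parameter is used.

```python
import math


def soft_not(p: float) -> float:
    return 1.0 - p


def soft_all(ps, use_geometric_mean: bool = False, epsilon: float = 1e-10) -> float:
    """Soft AND over a list of probabilities."""
    if use_geometric_mean:
        # Average in log space; the clamp avoids log(0).
        logs = [math.log(max(p, epsilon)) for p in ps]
        return math.exp(sum(logs) / len(logs))
    return math.prod(ps)


def soft_any(ps, use_geometric_mean: bool = False) -> float:
    """Soft OR via De Morgan: 1 - AND(NOT(x))."""
    return 1.0 - soft_all([soft_not(p) for p in ps], use_geometric_mean)


def soft_xor(x: float, y: float) -> float:
    """(x AND NOT y) OR (NOT x AND y)."""
    return soft_any([soft_all([x, soft_not(y)]), soft_all([soft_not(x), y])])


print(soft_all([0.9, 0.8]), soft_any([0.9, 0.8]), soft_xor(0.9, 0.8))
```

The product form shrinks toward zero as more inputs are combined, which is one reason a geometric-mean option is useful for long reductions.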
logical_not
diffbio.core.soft_ops.logical.logical_not
all
diffbio.core.soft_ops.logical.all
all(
    x: SoftBool,
    axis: int = -1,
    epsilon: float = 1e-10,
    use_geometric_mean: bool = False,
) -> SoftBool
Soft logical AND reduction along axis.
Uses product (default) or geometric mean to combine probabilities.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `SoftBool` | SoftBool input in [0, 1]. | required |
| `axis` | `int` | Axis along which to reduce. | `-1` |
| `epsilon` | `float` | Minimum value for numerical stability in log. | `1e-10` |
| `use_geometric_mean` | `bool` | If True, use geometric mean instead of product. | `False` |
Returns:
| Type | Description |
|---|---|
| `SoftBool` | Reduced SoftBool. |
any
diffbio.core.soft_ops.logical.any
Soft logical OR reduction along axis.
Implemented as 1 - all(1 - x).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `SoftBool` | SoftBool input in [0, 1]. | required |
| `axis` | `int` | Axis along which to reduce. | `-1` |
| `use_geometric_mean` | `bool` | If True, use geometric mean in the inner AND. | `False` |
Returns:
| Type | Description |
|---|---|
| `SoftBool` | Reduced SoftBool. |
logical_and
diffbio.core.soft_ops.logical.logical_and
Soft logical AND between two SoftBools.
Stacks inputs and applies all along the stack axis.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `SoftBool` | First SoftBool. | required |
| `y` | `SoftBool` | Second SoftBool. | required |
| `use_geometric_mean` | `bool` | If True, use geometric mean. | `False` |
Returns:
| Type | Description |
|---|---|
| `SoftBool` | SoftBool in [0, 1]. |
logical_or
diffbio.core.soft_ops.logical.logical_or
Soft logical OR between two SoftBools.
Stacks inputs and applies any along the stack axis.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `SoftBool` | First SoftBool. | required |
| `y` | `SoftBool` | Second SoftBool. | required |
| `use_geometric_mean` | `bool` | If True, use geometric mean in inner AND. | `False` |
Returns:
| Type | Description |
|---|---|
| `SoftBool` | SoftBool in [0, 1]. |
logical_xor
diffbio.core.soft_ops.logical.logical_xor
Soft logical XOR between two SoftBools.
Implemented as (x AND NOT y) OR (NOT x AND y).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `SoftBool` | First SoftBool. | required |
| `y` | `SoftBool` | Second SoftBool. | required |
| `use_geometric_mean` | `bool` | If True, use geometric mean in AND/OR. | `False` |
Returns:
| Type | Description |
|---|---|
| `SoftBool` | SoftBool in [0, 1]. |
Selection
Differentiable relaxations of array selection and indexing operations. These use SoftBool conditions and SoftIndex probability distributions in place of discrete boolean masks and integer indices.
where
diffbio.core.soft_ops.selection.where
where(condition: SoftBool, x: Array, y: Array) -> Array
Soft where: x * condition + y * (1 - condition).
Unlike jnp.where, this smoothly interpolates between x and y based on the continuous condition value.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `condition` | `SoftBool` | SoftBool in [0, 1], same shape as x and y. | required |
| `x` | `Array` | Values selected when condition is 1. | required |
| `y` | `Array` | Values selected when condition is 0. | required |
Returns:
| Type | Description |
|---|---|
| `Array` | Interpolated array. |
take_along_axis
diffbio.core.soft_ops.selection.take_along_axis
take_along_axis(
    x: Array, soft_index: SoftIndex, axis: int | None = -1
) -> Array
Soft take_along_axis via weighted dot product.
soft_index must have one more dimension than x: the extra (last) dimension contains the probability distribution over the elements along axis.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array of shape | required |
| `soft_index` | `SoftIndex` | SoftIndex of shape | required |
| `axis` | `int \| None` | Axis in | `-1` |
Returns:
| Type | Description |
|---|---|
| `Array` | Array of shape |
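Both soft where and the soft take operations are linear blends: a SoftBool mixes two values, and a SoftIndex mixes all elements along an axis through a weighted dot product. The sketch below shows the one-dimensional case on Python lists; the helper names `soft_where` and `soft_take_1d` are illustrative and no batching or axis handling is attempted.

```python
def soft_where(condition: float, x: float, y: float) -> float:
    """Blend x and y; condition = 1 selects x, condition = 0 selects y."""
    return x * condition + y * (1.0 - condition)


def soft_take_1d(values, soft_index) -> float:
    """Weighted dot product of values with a probability vector over positions."""
    return sum(w * v for w, v in zip(soft_index, values))


values = [10.0, 20.0, 30.0]

# A one-hot SoftIndex reproduces ordinary indexing.
print(soft_take_1d(values, [0.0, 1.0, 0.0]))
# A spread-out SoftIndex returns a mixture of neighboring elements.
print(soft_take_1d(values, [0.5, 0.5, 0.0]))
# A SoftBool of 0.25 leans toward the second argument.
print(soft_where(0.25, 4.0, 8.0))
```

Because every output is a linear function of the weights, gradients flow to the condition or index distribution as well as to the selected values.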
take
diffbio.core.soft_ops.selection.take
take(
    x: Array, soft_index: SoftIndex, axis: int | None = None
) -> Array
Soft take via weighted dot product.
Unlike take_along_axis, soft_index is a 2-D matrix of shape (k, [n]) applied uniformly across batch dimensions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array of shape | required |
| `soft_index` | `SoftIndex` | SoftIndex of shape | required |
| `axis` | `int \| None` | Axis to select from. If None, | `None` |
Returns:
| Type | Description |
|---|---|
| `Array` | Array of shape |
choose
diffbio.core.soft_ops.selection.choose
choose(soft_index: SoftIndex, choices: Array) -> Array
Soft choose among multiple arrays.
Softly selects among choices using soft_index weights.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `soft_index` | `SoftIndex` | SoftIndex of shape | required |
| `choices` | `Array` | Array of shape | required |
Returns:
| Type | Description |
|---|---|
| `Array` | Weighted combination of choices. |
dynamic_index_in_dim
diffbio.core.soft_ops.selection.dynamic_index_in_dim
dynamic_index_in_dim(
    x: Array,
    soft_index: SoftIndex,
    axis: int = 0,
    keepdims: bool = True,
) -> Array
Soft dynamic indexing along a dimension.
Selects a single element (weighted combination) along axis using the probability distribution soft_index.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array of shape | required |
| `soft_index` | `SoftIndex` | SoftIndex of shape | required |
| `axis` | `int` | Axis to index. | `0` |
| `keepdims` | `bool` | If True, retains the indexed dimension as size 1. | `True` |
Returns:
| Type | Description |
|---|---|
| `Array` | Indexed array. |
dynamic_slice_in_dim
diffbio.core.soft_ops.selection.dynamic_slice_in_dim
dynamic_slice_in_dim(
    x: Array,
    soft_start_index: SoftIndex,
    slice_size: int,
    axis: int = 0,
) -> Array
Soft dynamic slicing along a dimension.
Extracts a soft slice of slice_size elements starting at the position defined by soft_start_index.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array of shape | required |
| `soft_start_index` | `SoftIndex` | SoftIndex of shape | required |
| `slice_size` | `int` | Number of elements to extract. | required |
| `axis` | `int` | Axis to slice. | `0` |
Returns:
| Type | Description |
|---|---|
| `Array` | Array of shape |
dynamic_slice
diffbio.core.soft_ops.selection.dynamic_slice
dynamic_slice(
    x: Array,
    soft_start_indices: Sequence[SoftIndex],
    slice_sizes: Sequence[int],
) -> Array
Soft dynamic slicing across multiple dimensions.
Applies dynamic_slice_in_dim sequentially along each axis.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array of shape | required |
| `soft_start_indices` | `Sequence[SoftIndex]` | One SoftIndex per dimension. | required |
| `slice_sizes` | `Sequence[int]` | One slice length per dimension. | required |
Returns:
| Type | Description |
|---|---|
| `Array` | Array of shape |
Sorting
Differentiable relaxations of discrete ordering operations including argmax/argmin, argsort, sort, rank, and top-k. Multiple algorithmic backends are available:
| Method | Complexity | Default for |
|---|---|---|
| `"softsort"` | O(n log n) | argmax, argmin |
| `"neuralsort"` | O(n^2) | argsort, sort |
| `"sorting_network"` | O(n log^2 n) | -- |
| `"ot"` | varies | -- |
| `"fast_soft_sort"` | O(n log n) | -- |
| `"smooth_sort"` | O(n log n) | -- |
The ot, fast_soft_sort, and smooth_sort methods require the soft-ops-advanced optional dependency group.
argmax
diffbio.core.soft_ops.sorting.argmax
argmax(
    x: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: ArgMethod = "softsort",
    standardize: bool = True,
    ot_kwargs: dict | None = None,
) -> SoftIndex
Soft argmax returning a SoftIndex (probability distribution).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
| `axis` | `int \| None` | Axis along which to compute argmax. None flattens first. | `None` |
| `keepdims` | `bool` | If True, keep the reduced dimension as singleton. | `False` |
| `softness` | `float \| Array` | Controls sharpness (> 0). | `0.1` |
| `mode` | `Mode` | Smoothness mode. | `'smooth'` |
| `method` | `ArgMethod` | Algorithm: | `'softsort'` |
| `standardize` | `bool` | If True, standardize input for numerical stability. | `True` |
| `ot_kwargs` | `dict \| None` | Extra kwargs for OT method. | `None` |
Returns:
| Type | Description |
|---|---|
| `SoftIndex` | SoftIndex of shape |
max
diffbio.core.soft_ops.sorting.max
max(
    x: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: SortMethod = "softsort",
    standardize: bool = True,
    ot_kwargs: dict | None = None,
    gated_grad: bool = True,
) -> Array
Soft max via argmax + take_along_axis.
For the sorting_network method, uses sort and takes the first element.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
| `axis` | `int \| None` | Axis along which to compute max. | `None` |
| `keepdims` | `bool` | If True, keep reduced dimension. | `False` |
| `softness` | `float \| Array` | Controls sharpness (> 0). | `0.1` |
| `mode` | `Mode` | Smoothness mode. | `'smooth'` |
| `method` | `SortMethod` | Algorithm (see :func: | `'softsort'` |
| `standardize` | `bool` | If True, standardize input. | `True` |
| `ot_kwargs` | `dict \| None` | Extra kwargs for OT method. | `None` |
| `gated_grad` | `bool` | If False, stop gradient through soft index. | `True` |
Returns:
| Type | Description |
|---|---|
| `Array` | Soft maximum value(s). |
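The "argmax, then weighted take" structure of the soft max can be illustrated with a plain softmax standing in for the sorting backends. This is a deliberately swapped-in technique, not the softsort or neuralsort algorithm. Standardizing by mean and standard deviation and dividing by softness are assumptions about what the standardize and softness parameters do.

```python
import math


def soft_argmax(x, softness: float = 0.1, standardize: bool = True):
    """Return a probability vector concentrated on the largest entry of x."""
    if standardize:
        mean = sum(x) / len(x)
        std = math.sqrt(sum((v - mean) ** 2 for v in x) / len(x)) or 1.0
        x = [(v - mean) / std for v in x]
    top = max(x)
    # Subtracting the maximum keeps every exponent <= 0 (no overflow).
    exps = [math.exp((v - top) / softness) for v in x]
    total = sum(exps)
    return [e / total for e in exps]


def soft_max(x, softness: float = 0.1) -> float:
    """Soft maximum: dot product of the values with their soft argmax."""
    probs = soft_argmax(x, softness)
    return sum(p * v for p, v in zip(probs, x))


data = [1.0, 3.0, 2.0]
print(soft_argmax(data))
print(soft_max(data), soft_max(data, softness=100.0))
```

With a small softness the weights are nearly one-hot and the result approaches the true maximum; with a very large softness they approach uniform and the result drifts toward the mean.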
argmin
diffbio.core.soft_ops.sorting.argmin
argmin(
    x: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: ArgMethod = "softsort",
    standardize: bool = True,
    ot_kwargs: dict | None = None,
) -> SoftIndex
Soft argmin: argmax on -x.
min
diffbio.core.soft_ops.sorting.min
min(
x: Array,
axis: int | None = None,
keepdims: bool = False,
softness: float | Array = 0.1,
mode: Mode = "smooth",
method: SortMethod = "softsort",
standardize: bool = True,
ot_kwargs: dict | None = None,
gated_grad: bool = True,
) -> Array
Soft min: -max(-x).
argsort
diffbio.core.soft_ops.sorting.argsort
argsort(
    x: Array,
    axis: int | None = None,
    descending: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: ArgMethod = "neuralsort",
    standardize: bool = True,
    ot_kwargs: dict | None = None,
) -> SoftIndex
Soft argsort returning a soft permutation matrix.
Output shape is (..., n, ..., [n]) where the last dimension is the probability distribution over original elements.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
| `axis` | `int \| None` | Axis along which to argsort. None flattens first. | `None` |
| `descending` | `bool` | If True, sort descending. | `False` |
| `softness` | `float \| Array` | Controls sharpness (> 0). | `0.1` |
| `mode` | `Mode` | Smoothness mode. | `'smooth'` |
| `method` | `ArgMethod` | Algorithm. | `'neuralsort'` |
| `standardize` | `bool` | If True, standardize input. | `True` |
| `ot_kwargs` | `dict \| None` | Extra kwargs for OT method. | `None` |
Returns:
| Type | Description |
|---|---|
| `SoftIndex` | SoftIndex permutation matrix. |
sort
diffbio.core.soft_ops.sorting.sort
sort(
    x: Array,
    axis: int | None = None,
    descending: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: SortMethod = "neuralsort",
    standardize: bool = True,
    ot_kwargs: dict | None = None,
    gated_grad: bool = True,
) -> Array
Soft sort returning sorted values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
| `axis` | `int \| None` | Axis along which to sort. None flattens first. | `None` |
| `descending` | `bool` | If True, sort descending. | `False` |
| `softness` | `float \| Array` | Controls sharpness (> 0). | `0.1` |
| `mode` | `Mode` | Smoothness mode. | `'smooth'` |
| `method` | `SortMethod` | Algorithm. | `'neuralsort'` |
| `standardize` | `bool` | If True, standardize input. | `True` |
| `ot_kwargs` | `dict \| None` | Extra kwargs for OT method. | `None` |
| `gated_grad` | `bool` | If False, stop gradient through soft index. | `True` |
Returns:
| Type | Description |
|---|---|
| `Array` | Soft-sorted values. |
rank
diffbio.core.soft_ops.sorting.rank
rank(
    x: Array,
    axis: int | None = None,
    descending: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: RankMethod = "softsort",
    standardize: bool = True,
) -> Array
Soft fractional ranking.
Returns continuous ranks in [1, n] where 1 is the smallest.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
| `axis` | `int \| None` | Axis along which to rank. | `None` |
| `descending` | `bool` | If True, rank 1 = largest. | `False` |
| `softness` | `float \| Array` | Controls sharpness (> 0). | `0.1` |
| `mode` | `Mode` | Smoothness mode. | `'smooth'` |
| `method` | `RankMethod` | | `'softsort'` |
| `standardize` | `bool` | If True, standardize input. | `True` |
Returns:
| Type | Description |
|---|---|
| `Array` | Continuous ranks. |
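One common way to obtain continuous ranks in [1, n] is to count, softly, how many other elements each entry exceeds. The sketch below uses that pairwise-sigmoid construction as an illustration only; it is an assumption and not necessarily what any of the module's rank methods compute. Standardization and the descending option are omitted.

```python
import math


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def soft_rank(x, softness: float = 0.1):
    """Rank of x[i] = 1 + soft count of the other elements that x[i] exceeds."""
    ranks = []
    for i, xi in enumerate(x):
        beaten = sum(
            sigmoid((xi - xj) / softness) for j, xj in enumerate(x) if j != i
        )
        ranks.append(1.0 + beaten)
    return ranks


print(soft_rank([10.0, 30.0, 20.0]))
print(soft_rank([5.0, 5.0]))
```

Well-separated values receive ranks close to integers, while ties share a fractional rank, which matches the "fractional ranking" behavior described above.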
top_k
diffbio.core.soft_ops.sorting.top_k
top_k(
    x: Array,
    k: int,
    axis: int = -1,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: SortMethod = "neuralsort",
    standardize: bool = True,
    ot_kwargs: dict | None = None,
    gated_grad: bool = True,
) -> tuple[Array, SoftIndex | None]
Soft top-k selection.
Returns the k largest values and their soft indices.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
| `k` | `int` | Number of top elements. | required |
| `axis` | `int` | Axis along which to select. Default -1 (last axis). | `-1` |
| `softness` | `float \| Array` | Controls sharpness (> 0). | `0.1` |
| `mode` | `Mode` | Smoothness mode. | `'smooth'` |
| `method` | `SortMethod` | Sorting algorithm. Default | `'neuralsort'` |
| `standardize` | `bool` | If True, standardize input. | `True` |
| `ot_kwargs` | `dict \| None` | Extra keyword arguments for OT-based methods. | `None` |
| `gated_grad` | `bool` | If False, stop gradient through soft index. | `True` |
Returns:
| Type | Description |
|---|---|
| `tuple[Array, SoftIndex \| None]` | Tuple of (values, soft_indices) where values has shape … soft_indices is None for methods that only return values (fast_soft_sort, sorting_network). |
Quantile
Differentiable relaxations of quantile-based statistics. Quantiles are computed via soft argsort or soft sort with interpolation following the same methods as jax.numpy.quantile.
argquantile
diffbio.core.soft_ops.quantile.argquantile
argquantile(
    x: Array,
    q: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: ArgMethod = "neuralsort",
    quantile_method: Literal[
        "linear", "lower", "higher", "nearest", "midpoint"
    ] = "linear",
    standardize: bool = True,
) -> SoftIndex
Soft argquantile returning SoftIndex.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
| `q` | `Array` | Quantile(s) in [0, 1]. Scalar or 1-D array. | required |
| `axis` | `int \| None` | Axis along which to compute. None flattens. | `None` |
| `keepdims` | `bool` | If True, keep reduced dimension. | `False` |
| `softness` | `float \| Array` | Controls sharpness (> 0). | `0.1` |
| `mode` | `Mode` | Smoothness mode. | `'smooth'` |
| `method` | `ArgMethod` | Algorithm. | `'neuralsort'` |
| `quantile_method` | `Literal['linear', 'lower', 'higher', 'nearest', 'midpoint']` | Interpolation method. | `'linear'` |
| `standardize` | `bool` | If True, standardize input. | `True` |
Returns:
| Type | Description |
|---|---|
| `SoftIndex` | SoftIndex probability distribution over quantile position(s). |
quantile
diffbio.core.soft_ops.quantile.quantile
quantile(
    x: Array,
    q: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: ArgMethod = "neuralsort",
    quantile_method: Literal[
        "linear", "lower", "higher", "nearest", "midpoint"
    ] = "linear",
    standardize: bool = True,
    gated_grad: bool = True,
) -> Array
Soft quantile returning value.
Implemented as argquantile + take_along_axis for most methods.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | Input array. | required |
| `q` | `Array` | Quantile(s) in [0, 1]. | required |
| `axis` | `int \| None` | Axis along which to compute. | `None` |
| `keepdims` | `bool` | If True, keep reduced dimension. | `False` |
| `softness` | `float \| Array` | Controls sharpness (> 0). | `0.1` |
| `mode` | `Mode` | Smoothness mode. | `'smooth'` |
| `method` | `ArgMethod` | Algorithm. | `'neuralsort'` |
| `quantile_method` | `Literal['linear', 'lower', 'higher', 'nearest', 'midpoint']` | Interpolation method. | `'linear'` |
| `standardize` | `bool` | If True, standardize input. | `True` |
| `gated_grad` | `bool` | If False, stop gradient through soft index. | `True` |
Returns:
| Type | Description |
|---|---|
| `Array` | Quantile value(s). |
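For the "linear" quantile_method, a quantile is a weighted combination of at most two adjacent entries of the sorted data. The sketch below computes those interpolation weights and applies them to an exactly sorted list; in the soft setting the sorted values (or the permutation that produces them) would come from a soft sort or soft argsort instead. The helper names are illustrative, not part of the module.

```python
import math


def linear_quantile_weights(n: int, q: float):
    """Weights over the n sorted positions for quantile q in [0, 1]."""
    pos = q * (n - 1)           # fractional position in the sorted order
    lo = math.floor(pos)
    hi = min(lo + 1, n - 1)     # clamp so q = 1 stays in range
    frac = pos - lo
    weights = [0.0] * n
    weights[lo] += 1.0 - frac
    weights[hi] += frac
    return weights


def quantile_from_sorted(sorted_values, q: float) -> float:
    weights = linear_quantile_weights(len(sorted_values), q)
    return sum(w * v for w, v in zip(weights, sorted_values))


data = sorted([4.0, 1.0, 3.0, 2.0])
print(linear_quantile_weights(len(data), 0.5))
print(quantile_from_sorted(data, 0.5))
```

The weight vector sums to 1, so it has the same form as a SoftIndex over sorted positions.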
argmedian
diffbio.core.soft_ops.quantile.argmedian
argmedian(
    x: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: ArgMethod = "neuralsort",
    standardize: bool = True,
) -> SoftIndex
Soft argmedian: argquantile with q=0.5.
median
diffbio.core.soft_ops.quantile.median
median(
    x: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: ArgMethod = "neuralsort",
    standardize: bool = True,
    gated_grad: bool = True,
) -> Array
Soft median: quantile with q=0.5.
argpercentile
diffbio.core.soft_ops.quantile.argpercentile
argpercentile(
    x: Array,
    p: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: ArgMethod = "neuralsort",
    standardize: bool = True,
) -> SoftIndex
Soft argpercentile: argquantile with q = p / 100.
percentile
diffbio.core.soft_ops.quantile.percentile
percentile(
    x: Array,
    p: Array,
    axis: int | None = None,
    keepdims: bool = False,
    softness: float | Array = 0.1,
    mode: Mode = "smooth",
    method: ArgMethod = "neuralsort",
    standardize: bool = True,
    gated_grad: bool = True,
) -> Array
Soft percentile: quantile with q = p / 100.
Straight-Through Estimators
Straight-through estimators use the hard (exact, non-differentiable) function for the forward pass but route gradients through the soft (differentiable) version during backpropagation. The trick stop_gradient(hard - soft) + soft ensures the forward output is exact while gradients flow through the smooth relaxation.
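The identity stop_gradient(hard - soft) + soft can be traced by hand with (value, gradient) pairs. The sketch below does this in plain Python for a Heaviside step, tracking the derivative manually instead of using an autodiff library. The pair representation, the x / softness scaling, and all helper names are illustrative assumptions.

```python
import math


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def heaviside_hard(x: float):
    """(value, gradient) of the exact step; its gradient is zero almost everywhere."""
    value = 1.0 if x > 0 else (0.0 if x < 0 else 0.5)
    return value, 0.0


def heaviside_soft(x: float, softness: float = 0.1):
    """(value, gradient) of the logistic relaxation."""
    p = sigmoid(x / softness)
    return p, p * (1.0 - p) / softness


def stop_gradient(pair):
    """Keep the value, drop the gradient."""
    value, _ = pair
    return value, 0.0


def heaviside_st(x: float, softness: float = 0.1):
    hard = heaviside_hard(x)
    soft = heaviside_soft(x, softness)
    # stop_gradient(hard - soft) + soft, applied to values and gradients separately.
    diff = stop_gradient((hard[0] - soft[0], hard[1] - soft[1]))
    return diff[0] + soft[0], diff[1] + soft[1]


value, grad = heaviside_st(0.1)
print(value, grad)
```

The forward value equals the hard step (up to floating-point rounding), while the gradient is exactly that of the soft relaxation, which is the behavior described above.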
st
diffbio.core.soft_ops.straight_through.st
Decorator creating a straight-through estimator from a soft_ops function.
The decorated function is called twice: once with mode="hard" (forward pass) and once with the specified mode (backward pass). The trick stop_gradient(hard - soft) + soft ensures the forward output is hard but gradients flow through the soft version.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `fn` | `Callable` | Function with a | required |
Returns:
| Type | Description |
|---|---|
| `Callable` | Wrapped straight-through estimator function. |
grad_replace
diffbio.core.soft_ops.straight_through.grad_replace
Decorator for custom forward/backward computation split.
The decorated function is called twice: once with forward=True (output used for forward pass) and once with forward=False (output used for gradient computation).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `fn` | `Callable` | Function accepting a | required |
Returns:
| Type | Description |
|---|---|
| `Callable` | Wrapped function using hard forward, soft backward. |
Pre-built _st Variants
The module provides 27 pre-built straight-through variants, one for each soft operation that accepts a mode parameter. Each variant uses the hard function for the forward pass and the corresponding soft function for the backward pass.
| Elementwise | Comparison | Sorting | Quantile |
|---|---|---|---|
| `abs_st` | `equal_st` | `argmax_st` | `argmedian_st` |
| `clip_st` | `greater_st` | `argmin_st` | `argpercentile_st` |
| `heaviside_st` | `greater_equal_st` | `argsort_st` | `argquantile_st` |
| `relu_st` | `isclose_st` | `max_st` | `median_st` |
| `round_st` | `less_st` | `min_st` | `percentile_st` |
| `sign_st` | `less_equal_st` | `rank_st` | `quantile_st` |
| | `not_equal_st` | `sort_st` | |
| | | `top_k_st` | |
Each _st variant accepts the same arguments as its base function; for example, relu_st takes the same x, softness, mode, and gated arguments as relu.