Population Genetics Operators API¤

Differentiable operators for population genetics analysis, including ancestry estimation.

DifferentiableAncestryEstimator¤

diffbio.operators.population.ancestry_estimation.DifferentiableAncestryEstimator ¤

DifferentiableAncestryEstimator(
    config: AncestryEstimatorConfig, *, rngs: Rngs
)

Bases: TemperatureOperator

Neural ADMIXTURE-style differentiable ancestry estimator.

This operator uses an autoencoder architecture to estimate ancestry proportions from genotype data. The encoder maps genotypes to a latent representation, which is then transformed to ancestry proportions via temperature-controlled softmax. The decoder reconstructs genotypes from ancestry proportions, enabling unsupervised learning.

The model follows the ADMIXTURE generative model:

G_ij = sum_k Q_ik * P_kj

where:
  • G is the genotype matrix (individuals x SNPs)
  • Q is the ancestry proportion matrix (individuals x K populations)
  • P is the population allele frequency matrix (K x SNPs)
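As a minimal NumPy sketch of this generative model (not the library's implementation; the sizes, random values, and variable names are made up for illustration), each row of G is a convex combination of the population rows of P, weighted by that individual's ancestry proportions:

```python
import numpy as np

rng = np.random.default_rng(0)
n_individuals, n_populations, n_snps = 4, 3, 6

# Q: ancestry proportions, one row per individual; each row sums to 1.
Q = rng.dirichlet(np.ones(n_populations), size=n_individuals)

# P: per-population allele frequencies in [0, 1].
P = rng.uniform(0.0, 1.0, size=(n_populations, n_snps))

# G_ij = sum_k Q_ik * P_kj, written as a matrix product.
G = Q @ P

print(G.shape)  # (4, 6)
```

Because every row of Q lies on the probability simplex, each entry of G falls between the smallest and largest population frequency at that SNP.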

Attributes:

Name Type Description
config

Operator configuration.

backbone

Shared encoder MLP, or None when hidden_dims is empty.

ancestry_head

Linear layer for ancestry proportions.

population_frequencies

Learnable population allele frequencies (P matrix).

Example
from flax import nnx
from diffbio.operators.population import (
    DifferentiableAncestryEstimator,
    AncestryEstimatorConfig,
)
config = AncestryEstimatorConfig(n_snps=1000, n_populations=5)
estimator = DifferentiableAncestryEstimator(config, rngs=nnx.Rngs(42))
data = {"genotypes": genotype_matrix}  # (n_samples, n_snps)
result, _, _ = estimator.apply(data, {}, None)
ancestry = result["ancestry_proportions"]  # (n_samples, K)

Parameters:

Name Type Description Default
config AncestryEstimatorConfig

Operator configuration.

required
rngs Rngs

Flax NNX random number generators.

required

apply ¤

apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Any], dict[str, Any], dict[str, Any] | None
]

Apply ancestry estimation to genotype data.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary containing:
  • "genotypes": Genotype matrix of shape (n_samples, n_snps) with values 0, 1, or 2.

required
state dict[str, Any]

Per-element state (passed through).

required
metadata dict[str, Any] | None

Optional metadata (passed through).

required
random_params Any

Random parameters for stochastic operations.

None
stats dict[str, Any] | None

Optional statistics dictionary.

None

Returns:

Type Description
tuple[dict[str, Any], dict[str, Any], dict[str, Any] | None]

Tuple of (transformed_data, state, metadata), where transformed_data contains:
  • "genotypes": Original genotype matrix.
  • "ancestry_proportions": Estimated ancestry (n_samples, K).
  • "reconstructed": Reconstructed genotypes (n_samples, n_snps).
  • "latent": Latent representation.
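Since the returned dictionary carries both the original and the reconstructed genotypes, a reconstruction error can be computed directly from it. The sketch below assumes a mean-squared-error objective, which this page does not confirm; `reconstruction_mse` is a hypothetical helper, and the dictionary is a hand-written stand-in for the output of `apply`.

```python
import numpy as np

def reconstruction_mse(result):
    """Mean squared error between input and reconstructed genotypes.

    Hypothetical helper: MSE is an assumed objective, not one stated by the API.
    """
    genotypes = np.asarray(result["genotypes"], dtype=float)
    reconstructed = np.asarray(result["reconstructed"], dtype=float)
    return float(np.mean((genotypes - reconstructed) ** 2))

# Stand-in for the first element of the tuple returned by apply().
result = {
    "genotypes": np.array([[0, 1, 2], [2, 1, 0]]),
    "reconstructed": np.array([[0.5, 1.0, 1.5], [2.0, 1.0, 0.0]]),
}
mse = reconstruction_mse(result)
```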

encode ¤

encode(genotypes: ndarray) -> ndarray

Encode genotypes to latent representation.

Parameters:

Name Type Description Default
genotypes ndarray

Genotype matrix of shape (n_samples, n_snps). Values should be 0, 1, or 2 representing allele counts.

required

Returns:

Type Description
ndarray

Latent representation of shape (n_samples, hidden_dims[-1]) when hidden layers are configured, otherwise the original genotype matrix.

compute_ancestry ¤

compute_ancestry(latent: ndarray) -> ndarray

Compute ancestry proportions from latent representation.

Parameters:

Name Type Description Default
latent ndarray

Latent representation of shape (n_samples, hidden_dims[-1]), or the genotype matrix itself of shape (n_samples, n_snps) when hidden_dims is empty.

required

Returns:

Type Description
ndarray

Ancestry proportions of shape (n_samples, n_populations). Each row sums to 1 and all values are non-negative.
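To illustrate how a temperature-controlled softmax yields rows that are non-negative and sum to 1, here is a minimal NumPy sketch. It assumes the common form in which logits are divided by the temperature before the softmax; `temperature_softmax` is a hypothetical helper, not the operator's actual code.

```python
import numpy as np

def temperature_softmax(logits, temperature=1.0):
    """Row-wise softmax of logits / temperature (hypothetical helper)."""
    scaled = np.asarray(logits, dtype=float) / temperature
    # Subtract the row maximum for numerical stability.
    scaled = scaled - scaled.max(axis=-1, keepdims=True)
    weights = np.exp(scaled)
    return weights / weights.sum(axis=-1, keepdims=True)

logits = np.array([[2.0, 1.0, 0.0]])
soft = temperature_softmax(logits, temperature=5.0)   # closer to uniform
sharp = temperature_softmax(logits, temperature=0.1)  # closer to one-hot
```

Under this assumption, lower temperatures push each individual toward a single population, while higher temperatures spread the proportions more evenly.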

decode ¤

decode(ancestry: ndarray) -> ndarray

Decode ancestry proportions to reconstructed genotypes.

Following the ADMIXTURE model, G = Q @ P, where Q is the ancestry proportion matrix and P is the population allele frequency matrix.

Parameters:

Name Type Description Default
ancestry ndarray

Ancestry proportions of shape (n_samples, n_populations).

required

Returns:

Type Description
ndarray

Reconstructed genotypes of shape (n_samples, n_snps). Values represent expected allele counts (continuous 0-2).

AncestryEstimatorConfig¤

diffbio.operators.population.ancestry_estimation.AncestryEstimatorConfig dataclass ¤

AncestryEstimatorConfig(
    temperature: float = DEFAULT_TEMPERATURE,
    learnable_temperature: bool = False,
    n_snps: int = 10000,
    n_populations: int = 5,
    hidden_dims: tuple[int, ...] = (128, 64),
    dropout_rate: float = 0.1,
)

Bases: TemperatureConfig

Configuration for DifferentiableAncestryEstimator.

Attributes:

Name Type Description
n_snps int

Number of SNP markers in genotype input.

n_populations int

Number of ancestral populations (K).

hidden_dims tuple[int, ...]

Hidden layer dimensions for encoder.

dropout_rate float

Dropout rate for regularization.

temperature class-attribute instance-attribute ¤

temperature: float = DEFAULT_TEMPERATURE

learnable_temperature class-attribute instance-attribute ¤

learnable_temperature: bool = False

create_ancestry_estimator¤

diffbio.operators.population.ancestry_estimation.create_ancestry_estimator ¤

create_ancestry_estimator(
    n_snps: int,
    n_populations: int,
    hidden_dims: tuple[int, ...] = (128, 64),
    temperature: float = 1.0,
    dropout_rate: float = 0.1,
    seed: int = 42,
) -> DifferentiableAncestryEstimator

Factory function to create an ancestry estimator.

Parameters:

Name Type Description Default
n_snps int

Number of SNP markers.

required
n_populations int

Number of ancestral populations (K).

required
hidden_dims tuple[int, ...]

Hidden layer dimensions for encoder.

(128, 64)
temperature float

Softmax temperature for ancestry proportions.

1.0
dropout_rate float

Dropout rate for regularization.

0.1
seed int

Random seed for initialization.

42

Returns:

Type Description
DifferentiableAncestryEstimator

Configured DifferentiableAncestryEstimator instance.

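Example
from diffbio.operators.population import create_ancestry_estimator

estimator = create_ancestry_estimator(
    n_snps=10000,
    n_populations=5,
)
# data: genotype matrix of shape (n_samples, n_snps)
result, _, _ = estimator.apply({"genotypes": data}, {}, None)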

Usage Examples¤

Basic Ancestry Estimation¤

from flax import nnx
from diffbio.operators.population import (
    DifferentiableAncestryEstimator,
    AncestryEstimatorConfig,
    create_ancestry_estimator,
)

# Using config
config = AncestryEstimatorConfig(
    n_snps=10000,
    n_populations=5,
    hidden_dims=(128, 64),
    temperature=1.0,
)
estimator = DifferentiableAncestryEstimator(config, rngs=nnx.Rngs(42))

# Or using factory function
estimator = create_ancestry_estimator(
    n_snps=10000,
    n_populations=5,
)

# Apply ancestry estimation
data = {"genotypes": genotype_matrix}  # (n_samples, n_snps)
result, _, _ = estimator.apply(data, {}, None)

# Get ancestry proportions
ancestry = result["ancestry_proportions"]  # (n_samples, K)

Training Mode¤

# Enable dropout during training
estimator.train()

# train_dataloader and train_step are user-supplied placeholders, not defined on this page
for batch in train_dataloader:
    loss = train_step(estimator, batch)

# Disable dropout for inference
estimator.eval()

Accessing Components¤

# Population allele frequencies
pop_freqs = estimator.population_frequencies[...]  # (K, n_snps)

# Temperature parameter (read from config; the live Param is `_temperature`)
temperature = estimator.config.temperature

# Encoder layers
encoder = estimator.backbone

Input Specifications¤

| Key | Shape | Description |
| --- | --- | --- |
| genotypes | (n_samples, n_snps) | Genotype matrix with values 0, 1, or 2 |
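A small pre-flight check can catch malformed input before it reaches the estimator. The sketch below is one possible way to enforce the specification above; `validate_genotypes` is a hypothetical helper, and the float32 cast is an assumption about what is convenient for a neural model rather than a documented requirement.

```python
import numpy as np

def validate_genotypes(genotypes, n_snps):
    """Check a genotype matrix against the documented input spec (hypothetical helper)."""
    genotypes = np.asarray(genotypes)
    if genotypes.ndim != 2:
        raise ValueError(f"expected a 2-D matrix, got {genotypes.ndim} dimension(s)")
    if genotypes.shape[1] != n_snps:
        raise ValueError(f"expected {n_snps} SNP columns, got {genotypes.shape[1]}")
    if not np.isin(genotypes, (0, 1, 2)).all():
        raise ValueError("genotype values must be 0, 1, or 2")
    # Assumption: cast to float32 for use as neural-network input.
    return genotypes.astype(np.float32)

genotype_matrix = validate_genotypes([[0, 1, 2], [2, 2, 0]], n_snps=3)
```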

Output Specifications¤

| Key | Shape | Description |
| --- | --- | --- |
| genotypes | (n_samples, n_snps) | Original genotype matrix |
| ancestry_proportions | (n_samples, n_populations) | Estimated ancestry proportions (each row sums to 1) |
| reconstructed | (n_samples, n_snps) | Reconstructed genotypes |
| latent | (n_samples, hidden_dims[-1]) | Latent representation |
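As an illustration of working with the `ancestry_proportions` output, the following NumPy sketch summarises a small hand-written matrix. The values are invented for the example, and the summaries shown (dominant population per individual, sample-wide mean ancestry) are common post-processing steps rather than features of this API.

```python
import numpy as np

# Stand-in for result["ancestry_proportions"], shape (n_samples, n_populations).
ancestry = np.array([
    [0.70, 0.20, 0.10],
    [0.05, 0.15, 0.80],
    [0.34, 0.33, 0.33],
])

# Index of the largest ancestry component for each individual.
dominant_population = ancestry.argmax(axis=1)

# Mean ancestry across the sample, one value per population.
mean_ancestry = ancestry.mean(axis=0)
```

The third individual shows why a hard assignment can mislead: its dominant population wins by a margin of only 0.01, so the full proportion vector is usually the more informative quantity.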