Population Genetics Operators¤

DiffBio provides differentiable operators for population genetics analysis including ancestry estimation, enabling gradient-based optimization of population structure models.

Overview¤

Population genetics operators enable end-to-end optimization of:

  • DifferentiableAncestryEstimator: Neural ADMIXTURE-style ancestry estimation

DifferentiableAncestryEstimator¤

Autoencoder-based ancestry estimation that learns to decompose individual genotypes into proportions from K ancestral populations.

Quick Start¤

from flax import nnx
from diffbio.operators.population import (
    DifferentiableAncestryEstimator,
    AncestryEstimatorConfig,
    create_ancestry_estimator,
)

# Configure estimator
config = AncestryEstimatorConfig(
    n_snps=10000,           # Number of SNP markers
    n_populations=5,        # Number of ancestral populations (K)
    hidden_dims=(128, 64),  # Encoder hidden layers
    temperature=1.0,        # Softmax temperature
    dropout_rate=0.1,       # Regularization
)

# Create operator
rngs = nnx.Rngs(42)
estimator = DifferentiableAncestryEstimator(config, rngs=rngs)

# Apply ancestry estimation
data = {
    "genotypes": genotype_matrix,  # (n_samples, n_snps), values 0/1/2
}
result, state, metadata = estimator.apply(data, {}, None)

# Get ancestry proportions
ancestry = result["ancestry_proportions"]  # (n_samples, K)
reconstructed = result["reconstructed"]     # Reconstructed genotypes
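
The Quick Start assumes that `genotype_matrix` already exists. For experimentation, a synthetic one can be simulated from the admixture model itself. The sketch below is not part of DiffBio: `simulate_genotypes` and its parameters are hypothetical names, and the Dirichlet and uniform priors are arbitrary choices made for illustration.

```python
import numpy as np


def simulate_genotypes(n_samples, n_snps, n_populations, seed=0):
    """Simulate 0/1/2 genotypes from an admixture model (illustrative helper)."""
    rng = np.random.default_rng(seed)
    # Q: each row holds one individual's ancestry proportions and sums to 1
    q = rng.dirichlet(alpha=np.full(n_populations, 0.5), size=n_samples)
    # P: allele frequency of every SNP in every ancestral population
    p = rng.uniform(0.05, 0.95, size=(n_populations, n_snps))
    # Per-individual allele frequency, then two independent allele draws per SNP
    allele_freq = q @ p
    genotypes = rng.binomial(2, allele_freq)
    return genotypes.astype(np.float32), q, p


genotype_matrix, true_q, true_p = simulate_genotypes(200, 1000, 3)
print(genotype_matrix.shape)  # (200, 1000)
```

The returned array can be wrapped with `jnp.asarray` before it is passed to the estimator, and `true_q` gives a ground truth to compare the estimated proportions against.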

Configuration¤

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| n_snps | int | 10000 | Number of SNP markers in input |
| n_populations | int | 5 | Number of ancestral populations (K) |
| hidden_dims | tuple[int, ...] | (128, 64) | Encoder hidden layer dimensions |
| temperature | float | 1.0 | Temperature for softmax proportions |
| dropout_rate | float | 0.1 | Dropout rate for regularization |

ADMIXTURE Model¤

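The operator is built on the matrix-factorization form of the classic ADMIXTURE model, in which each genotype is approximated by a mixture of population allele frequencies: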

\[G_{ij} = \sum_k Q_{ik} \cdot P_{kj}\]

Where:

  • \(G\) = genotype matrix (individuals × SNPs)
  • \(Q\) = ancestry proportion matrix (individuals × K)
  • \(P\) = population allele frequency matrix (K × SNPs)

The network learns both factors jointly: Q is produced per individual by the encoder, while P is stored as learnable parameters shared across all individuals.
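
As a small worked example with made-up numbers (K = 2 populations, one individual, two SNPs), an individual with 70% ancestry from population 1 and 30% from population 2 gives:

\[
Q = \begin{pmatrix} 0.7 & 0.3 \end{pmatrix}, \quad
P = \begin{pmatrix} 0.1 & 0.9 \\ 0.5 & 0.5 \end{pmatrix}, \quad
QP = \begin{pmatrix} 0.7 \cdot 0.1 + 0.3 \cdot 0.5 & 0.7 \cdot 0.9 + 0.3 \cdot 0.5 \end{pmatrix} = \begin{pmatrix} 0.22 & 0.78 \end{pmatrix}
\]

Each reconstructed entry is a weighted average of the two populations' allele frequencies at that SNP, with the weights taken from the individual's row of \(Q\).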

Architecture¤

graph LR
    A["Genotypes<br/>(n, s)"] --> B["Encoder<br/>(MLP)"]
    B --> C["Latent<br/>(n, h)"]
    C --> D["Softmax (τ)"]
    D --> Q["Ancestry Q<br/>(n, K)"]
    Q --> MUL["Q @ P"]
    P["Population Frequencies P<br/>(K, s)"] --> MUL
    MUL --> R["Reconstructed<br/>(n, s)"]

    style A fill:#d1fae5,stroke:#059669,color:#064e3b
    style B fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style C fill:#ede9fe,stroke:#7c3aed,color:#4c1d95
    style D fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style Q fill:#dbeafe,stroke:#2563eb,color:#1e3a5f
    style P fill:#dbeafe,stroke:#2563eb,color:#1e3a5f
    style MUL fill:#e0e7ff,stroke:#4338ca,color:#312e81
    style R fill:#d1fae5,stroke:#059669,color:#064e3b

The encoder maps genotypes to a latent representation, from which the ancestry proportions Q are computed with a temperature-controlled softmax. The decoder then reconstructs genotypes as the product Q @ P, following the ADMIXTURE model.
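
The data flow in the diagram can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not DiffBio's implementation: the ReLU activation, the linear projection from the latent layer to K logits, the sigmoid that keeps P in (0, 1), and all names (`forward`, `params`, `p_logits`) are hypothetical. How the reconstruction is scaled against 0/1/2 genotypes is left out.

```python
import numpy as np


def softmax(logits, temperature=1.0):
    """Row-wise softmax with temperature scaling."""
    z = logits / temperature
    z = z - z.max(axis=-1, keepdims=True)  # subtract the row max for stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def forward(genotypes, params, temperature=1.0):
    """Encoder MLP -> softmax -> Q @ P, mirroring the diagram."""
    h = genotypes
    for w, b in params["encoder"]:
        h = np.maximum(h @ w + b, 0.0)  # ReLU hidden layers (assumed activation)
    logits = h @ params["w_out"] + params["b_out"]  # (n, K), assumed projection
    q = softmax(logits, temperature)  # ancestry proportions, rows sum to 1
    p = 1.0 / (1.0 + np.exp(-params["p_logits"]))  # assumed sigmoid constraint on P
    return q, q @ p


rng = np.random.default_rng(0)
n, s, k = 4, 20, 3
dims = (s, 16, 8)  # input size followed by two hidden layer sizes
params = {
    "encoder": [
        (rng.normal(0, 0.1, (dims[i], dims[i + 1])), np.zeros(dims[i + 1]))
        for i in range(len(dims) - 1)
    ],
    "w_out": rng.normal(0, 0.1, (dims[-1], k)),
    "b_out": np.zeros(k),
    "p_logits": rng.normal(0, 1.0, (k, s)),
}
genotypes = rng.integers(0, 3, size=(n, s)).astype(np.float32)
q, reconstructed = forward(genotypes, params)
print(q.shape, reconstructed.shape)  # (4, 3) (4, 20)
```

Because every row of `q` sums to 1 and every entry of `p` lies in (0, 1), each reconstructed value is a convex combination of allele frequencies and also stays in (0, 1).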

Temperature Control¤

The temperature parameter controls the sharpness of ancestry assignments:

| Temperature | Effect |
| --- | --- |
| High (5.0+) | Softer, more uniform proportions |
| Medium (1.0) | Balanced |
| Low (0.1) | Sharper, more confident assignments |

# High temperature (exploration)
config_soft = AncestryEstimatorConfig(temperature=5.0)

# Low temperature (exploitation)
config_sharp = AncestryEstimatorConfig(temperature=0.1)
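
The effect of the temperature can be checked numerically on a single made-up logit vector. The helper below is a generic temperature-scaled softmax written for illustration (`softmax_with_temperature` is not a DiffBio function); it divides the logits by the temperature before normalizing.

```python
import numpy as np


def softmax_with_temperature(logits, temperature):
    """Softmax over the last axis after dividing the logits by the temperature."""
    z = np.asarray(logits, dtype=float) / temperature
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


logits = np.array([2.0, 1.0, 0.0])  # made-up logits for K = 3
for t in (5.0, 1.0, 0.1):
    print(t, np.round(softmax_with_temperature(logits, t), 3))
```

For these logits the largest proportion grows from roughly 0.40 at a temperature of 5.0 to nearly 1.0 at a temperature of 0.1, while the ranking of the populations stays the same.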

Training¤

import jax.numpy as jnp
import optax
from flax import nnx

from diffbio.operators.population import create_ancestry_estimator

estimator = create_ancestry_estimator(n_snps=10000, n_populations=5)
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(nnx.state(estimator, nnx.Param))

def loss_fn(model, genotypes):
    """Reconstruction loss for unsupervised training."""
    result, _, _ = model.apply({"genotypes": genotypes}, {}, None)

    # Reconstruction loss
    recon_loss = jnp.mean((result["reconstructed"] - genotypes) ** 2)

    # Optional: Entropy regularization for sparse ancestry
    proportions = result["ancestry_proportions"]
    entropy = -jnp.mean(jnp.sum(proportions * jnp.log(proportions + 1e-10), axis=-1))

    # Entropy is added so that minimizing the loss lowers it (sparser proportions)
    return recon_loss + 0.01 * entropy

# nnx.jit (rather than jax.jit) propagates the in-place model update to the caller
@nnx.jit
def train_step(model, opt_state, genotypes):
    # nnx.value_and_grad differentiates with respect to the model's nnx.Param state
    loss, grads = nnx.value_and_grad(loss_fn)(model, genotypes)
    params = nnx.state(model, nnx.Param)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    nnx.update(model, optax.apply_updates(params, updates))
    return loss, opt_state

# Training loop
estimator.train()
for epoch in range(100):
    loss, opt_state = train_step(estimator, opt_state, train_genotypes)

estimator.eval()
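
The entropy term in `loss_fn` measures how evenly each individual's ancestry is spread over the K populations. The NumPy sketch below applies the same formula to the two extreme cases; the name `mean_entropy` is illustrative and not part of DiffBio.

```python
import numpy as np


def mean_entropy(proportions, eps=1e-10):
    """Mean Shannon entropy of the rows, using the same formula as loss_fn."""
    return -np.mean(np.sum(proportions * np.log(proportions + eps), axis=-1))


uniform = np.full((4, 5), 0.2)     # fully admixed individuals, K = 5
one_hot = np.eye(5)[[0, 1, 2, 3]]  # single-ancestry individuals
print(mean_entropy(uniform))  # close to log(5), about 1.609
print(mean_entropy(one_hot))  # close to 0
```

Because the entropy is largest for uniform proportions, adding it to a loss with a positive weight favors sparse assignments, while subtracting it favors admixed ones.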

Supervised Training¤

With known population labels:

import jax
import jax.numpy as jnp

def supervised_loss(model, genotypes, true_labels):
    """Cross-entropy loss for supervised training."""
    result, _, _ = model.apply({"genotypes": genotypes}, {}, None)
    proportions = result["ancestry_proportions"]  # (n_samples, K)

    # One-hot encode true labels; K is read from the output shape
    true_onehot = jax.nn.one_hot(true_labels, num_classes=proportions.shape[-1])

    # Cross-entropy
    log_probs = jnp.log(proportions + 1e-10)
    ce_loss = -jnp.mean(jnp.sum(true_onehot * log_probs, axis=-1))

    return ce_loss

Inference¤

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

estimator.eval()

result, _, _ = estimator.apply({"genotypes": test_genotypes}, {}, None)

# Get ancestry proportions
ancestry = result["ancestry_proportions"]  # (n_samples, K)

# Find primary ancestry
primary_ancestry = jnp.argmax(ancestry, axis=-1)

# Visualize admixture as a stacked bar chart
ancestry_np = np.asarray(ancestry)  # convert to NumPy for matplotlib
n_samples, K = ancestry_np.shape
x = np.arange(n_samples)

plt.figure(figsize=(12, 3))
bottom = np.zeros(n_samples)
for k in range(K):
    plt.bar(x, ancestry_np[:, k], bottom=bottom, label=f"Pop {k + 1}")
    bottom = bottom + ancestry_np[:, k]  # stack the next population on top
plt.ylabel("Ancestry Proportion")
plt.xlabel("Individual")
plt.legend()
plt.show()
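
Stacked admixture plots are usually easier to read when individuals are grouped by their primary ancestry. A minimal sketch of such an ordering follows; `admixture_plot_order` is a hypothetical helper, not a DiffBio function.

```python
import numpy as np


def admixture_plot_order(ancestry):
    """Return indices that group individuals by primary ancestry (illustrative helper)."""
    ancestry = np.asarray(ancestry)
    primary = ancestry.argmax(axis=1)
    strength = ancestry.max(axis=1)
    # lexsort treats the LAST key as the primary sort key:
    # group by primary ancestry, then by decreasing dominant proportion
    return np.lexsort((-strength, primary))


example = np.array([
    [0.2, 0.8],
    [0.9, 0.1],
    [0.4, 0.6],
    [0.6, 0.4],
])
order = admixture_plot_order(example)
print(order)  # [1 3 0 2]
```

Indexing the proportions with `order` (for example `np.asarray(ancestry)[order]`) before plotting places individuals with the same dominant population next to each other.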

Accessing Learned Parameters¤

# Population allele frequencies (P matrix)
pop_freqs = estimator.population_frequencies[...]  # (K, n_snps)

# All trainable parameters, including the encoder weights
encoder_params = nnx.state(estimator, nnx.Param)

# Temperature
temperature = estimator.temperature[...]

Use Cases¤

| Application | Operator | Description |
| --- | --- | --- |
| Ancestry inference | DifferentiableAncestryEstimator | Estimate ancestry proportions |
| Population structure | DifferentiableAncestryEstimator | Discover population structure |
| Admixture analysis | DifferentiableAncestryEstimator | Model genetic admixture |

References¤

  1. Alexander, D.H. et al. (2009). "Fast model-based estimation of ancestry in unrelated individuals." Genome Research.

  2. Dominguez Mantes, A. et al. (2023). "Neural ADMIXTURE for rapid genomic clustering." Nature Computational Science.

  3. Pritchard, J.K. et al. (2000). "Inference of Population Structure Using Multilocus Genotype Data." Genetics.

Next Steps¤