Population Genetics Operators API
Differentiable operators for population genetics analysis including ancestry estimation.
DifferentiableAncestryEstimator
diffbio.operators.population.ancestry_estimation.DifferentiableAncestryEstimator
DifferentiableAncestryEstimator(
    config: AncestryEstimatorConfig, *, rngs: Rngs
)
Bases: TemperatureOperator
Neural ADMIXTURE-style differentiable ancestry estimator.
This operator uses an autoencoder architecture to estimate ancestry proportions from genotype data. The encoder maps genotypes to a latent representation, which is then transformed to ancestry proportions via temperature-controlled softmax. The decoder reconstructs genotypes from ancestry proportions, enabling unsupervised learning.
The model follows the ADMIXTURE generative model:
G_ij = sum_k Q_ik * P_kj
where:
- G is the genotype matrix (individuals x SNPs)
- Q is the ancestry proportion matrix (individuals x K populations)
- P is the population allele frequency matrix (K x SNPs)
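
The mixture formula above can be made concrete with a small worked example. This is a minimal sketch in plain Python, independent of diffbio and JAX; the helper name `mix_frequencies` and all numbers are made up for illustration. The results are mixed allele frequencies in [0, 1]; any scaling to allele counts is outside this sketch.

```python
def mix_frequencies(Q, P):
    """Return G with G[i][j] = sum_k Q[i][k] * P[k][j] (hypothetical helper)."""
    n_pops = len(P)
    n_snps = len(P[0])
    return [
        [sum(q_row[k] * P[k][j] for k in range(n_pops)) for j in range(n_snps)]
        for q_row in Q
    ]


# Two individuals, two populations, three SNPs (illustrative values only).
Q = [
    [1.0, 0.0],  # individual 0: entirely population 0
    [0.5, 0.5],  # individual 1: even mix of both populations
]
P = [
    [0.2, 0.8, 0.5],  # population 0 allele frequencies
    [0.6, 0.4, 0.5],  # population 1 allele frequencies
]

G = mix_frequencies(Q, P)
for row in G:
    print([round(value, 3) for value in row])
```

Individual 0 reproduces the frequencies of population 0, while individual 1 lands halfway between the two populations at every SNP.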
Attributes:
| Name | Type | Description |
|---|---|---|
| config | | Operator configuration. |
| backbone | | Shared encoder MLP, or |
| ancestry_head | | Linear layer for ancestry proportions. |
| population_frequencies | | Learnable population allele frequencies (P matrix). |
Example
from flax import nnx
from diffbio.operators.population import (
    DifferentiableAncestryEstimator,
    AncestryEstimatorConfig,
)

config = AncestryEstimatorConfig(n_snps=1000, n_populations=5)
estimator = DifferentiableAncestryEstimator(config, rngs=nnx.Rngs(42))
# genotype_matrix is a user-supplied array of shape (n_samples, n_snps)
data = {"genotypes": genotype_matrix}
result, _, _ = estimator.apply(data, {}, None)
ancestry = result["ancestry_proportions"]  # (n_samples, K)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | AncestryEstimatorConfig | Operator configuration. | required |
| rngs | Rngs | Flax NNX random number generators. | required |
apply
apply(
    data: dict[str, Any],
    state: dict[str, Any],
    metadata: dict[str, Any] | None,
    random_params: Any = None,
    stats: dict[str, Any] | None = None,
) -> tuple[
    dict[str, Any], dict[str, Any], dict[str, Any] | None
]
Apply ancestry estimation to genotype data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary containing "genotypes": genotype matrix of shape (n_samples, n_snps) with values 0, 1, or 2. | required |
| state | dict[str, Any] | Per-element state (passed through). | required |
| metadata | dict[str, Any] \| None | Optional metadata (passed through). | required |
| random_params | Any | Random parameters for stochastic operations. | None |
| stats | dict[str, Any] \| None | Optional statistics dictionary. | None |
Returns:
| Type | Description |
|---|---|
| tuple[dict[str, Any], dict[str, Any], dict[str, Any] \| None] | Tuple of (transformed_data, state, metadata), where transformed_data contains the keys listed under Output Specifications. |
encode
Encode genotypes to latent representation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| genotypes | ndarray | Genotype matrix of shape (n_samples, n_snps). Values should be 0, 1, or 2, representing allele counts. | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Latent representation of shape (n_samples, hidden_dims[-1]) if hidden layers are configured, otherwise the original genotype matrix. |
compute_ancestry
Compute ancestry proportions from latent representation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| latent | ndarray | Latent representation of shape (n_samples, hidden_dims[-1]). | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Ancestry proportions of shape (n_samples, n_populations). Each row sums to 1 and all values are non-negative. |
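
The row-sum and non-negativity guarantees follow from the temperature-controlled softmax mentioned in the class description. The sketch below illustrates that idea in plain Python. It assumes the common convention softmax(logits / temperature); whether the operator applies the temperature in exactly this way is not stated on this page, and `temperature_softmax` is a hypothetical helper, not part of diffbio.

```python
import math


def temperature_softmax(logits, temperature=1.0):
    """Softmax of logits / temperature (assumed convention, hypothetical helper)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    # Subtract the maximum before exponentiating for numerical stability.
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]  # illustrative per-population scores for one sample
sharp = temperature_softmax(logits, temperature=0.5)
soft = temperature_softmax(logits, temperature=5.0)

print([round(p, 3) for p in sharp])
print([round(p, 3) for p in soft])
```

Under this convention, a lower temperature concentrates the proportions on the highest-scoring population, and a higher temperature spreads them more evenly. In both cases the values are non-negative and sum to 1.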
decode
Decode ancestry proportions to reconstructed genotypes.
Following the ADMIXTURE model, G = Q @ P, where Q is the ancestry proportion matrix and P is the population frequency matrix.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ancestry | ndarray | Ancestry proportions of shape (n_samples, n_populations). | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Reconstructed genotypes of shape (n_samples, n_snps). Values represent expected allele counts (continuous 0-2). |
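
Because the reconstruction is continuous while observed genotypes are 0, 1, or 2, the two can be compared with a reconstruction error. The following sketch uses mean squared error purely as an illustration. This page does not say which loss the operator is trained with, so both the choice of MSE and the helper name `mean_squared_error` are assumptions.

```python
def mean_squared_error(genotypes, reconstructed):
    """Average squared difference over all (sample, SNP) entries (assumed loss)."""
    if len(genotypes) != len(reconstructed):
        raise ValueError("row counts differ")
    total = 0.0
    count = 0
    for g_row, r_row in zip(genotypes, reconstructed):
        if len(g_row) != len(r_row):
            raise ValueError("column counts differ")
        for observed, expected in zip(g_row, r_row):
            total += (observed - expected) ** 2
            count += 1
    return total / count


# Illustrative values only: two samples, three SNPs.
genotypes = [[0, 1, 2], [2, 1, 0]]
reconstructed = [[0.5, 1.0, 1.5], [2.0, 1.0, 0.0]]

loss = mean_squared_error(genotypes, reconstructed)
print(round(loss, 4))
```

The second sample is reconstructed exactly, so only the two off-by-0.5 entries of the first sample contribute to the error.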
AncestryEstimatorConfig
diffbio.operators.population.ancestry_estimation.AncestryEstimatorConfig
dataclass
AncestryEstimatorConfig(
    temperature: float = DEFAULT_TEMPERATURE,
    learnable_temperature: bool = False,
    n_snps: int = 10000,
    n_populations: int = 5,
    hidden_dims: tuple[int, ...] = (128, 64),
    dropout_rate: float = 0.1,
)
Bases: TemperatureConfig
Configuration for DifferentiableAncestryEstimator.
Attributes:
| Name | Type | Description |
|---|---|---|
| n_snps | int | Number of SNP markers in genotype input. |
| n_populations | int | Number of ancestral populations (K). |
| hidden_dims | tuple[int, ...] | Hidden layer dimensions for encoder. |
| dropout_rate | float | Dropout rate for regularization. |
create_ancestry_estimator
diffbio.operators.population.ancestry_estimation.create_ancestry_estimator
create_ancestry_estimator(
    n_snps: int,
    n_populations: int,
    hidden_dims: tuple[int, ...] = (128, 64),
    temperature: float = 1.0,
    dropout_rate: float = 0.1,
    seed: int = 42,
) -> DifferentiableAncestryEstimator
Factory function to create an ancestry estimator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_snps | int | Number of SNP markers. | required |
| n_populations | int | Number of ancestral populations (K). | required |
| hidden_dims | tuple[int, ...] | Hidden layer dimensions for encoder. | (128, 64) |
| temperature | float | Softmax temperature for ancestry proportions. | 1.0 |
| dropout_rate | float | Dropout rate for regularization. | 0.1 |
| seed | int | Random seed for initialization. | 42 |
Returns:
| Type | Description |
|---|---|
| DifferentiableAncestryEstimator | Configured DifferentiableAncestryEstimator instance. |
Usage Examples
Basic Ancestry Estimation
from flax import nnx
from diffbio.operators.population import (
    DifferentiableAncestryEstimator,
    AncestryEstimatorConfig,
    create_ancestry_estimator,
)

# Using config
config = AncestryEstimatorConfig(
    n_snps=10000,
    n_populations=5,
    hidden_dims=(128, 64),
    temperature=1.0,
)
estimator = DifferentiableAncestryEstimator(config, rngs=nnx.Rngs(42))

# Or using factory function
estimator = create_ancestry_estimator(
    n_snps=10000,
    n_populations=5,
)

# Apply ancestry estimation
data = {"genotypes": genotype_matrix}  # (n_samples, n_snps)
result, _, _ = estimator.apply(data, {}, None)

# Get ancestry proportions
ancestry = result["ancestry_proportions"]  # (n_samples, K)
Training Mode
# Enable dropout during training
estimator.train()
# train_dataloader and train_step are user-defined and not part of this API
for batch in train_dataloader:
    loss = train_step(estimator, batch)
# Disable dropout for inference
estimator.eval()
Accessing Components
# Population allele frequencies
pop_freqs = estimator.population_frequencies[...] # (K, n_snps)
# Temperature parameter (read from config; the live Param is `_temperature`)
temperature = estimator.config.temperature
# Encoder layers
encoder = estimator.backbone
Input Specifications
| Key | Shape | Description |
|---|---|---|
| genotypes | (n_samples, n_snps) | Genotype matrix with values 0, 1, or 2 |
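
The input contract above can be checked before calling the operator. The sketch below is a hypothetical pre-flight check in plain Python over nested lists; `validate_genotypes` is not part of diffbio, and how the operator itself treats missing or out-of-range values is not documented on this page.

```python
def validate_genotypes(genotypes, n_snps):
    """Check shape and allowed values; return (n_samples, n_snps) (hypothetical helper)."""
    if not genotypes:
        raise ValueError("genotype matrix is empty")
    for i, row in enumerate(genotypes):
        if len(row) != n_snps:
            raise ValueError(f"row {i} has {len(row)} SNPs, expected {n_snps}")
        for j, value in enumerate(row):
            # Allele counts must be exactly 0, 1, or 2.
            if value not in (0, 1, 2):
                raise ValueError(f"invalid genotype {value!r} at ({i}, {j})")
    return (len(genotypes), n_snps)


# Illustrative matrix: two samples, three SNPs.
shape = validate_genotypes([[0, 1, 2], [2, 2, 0]], n_snps=3)
print(shape)
```

Failing early with a clear message is usually easier to debug than a shape or value error raised deep inside the model.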
Output Specifications
| Key | Shape | Description |
|---|---|---|
| genotypes | (n_samples, n_snps) | Original genotype matrix |
| ancestry_proportions | (n_samples, n_populations) | Estimated ancestry proportions (sum to 1) |
| reconstructed | (n_samples, n_snps) | Reconstructed genotypes |
| latent | (n_samples, hidden_dims[-1]) | Latent representation |
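
The shapes in the table above can be derived from the configuration values, which is handy when writing shape assertions in tests. This is a minimal sketch in plain Python; `expected_output_shapes` is a hypothetical helper, not part of diffbio. The fallback for empty `hidden_dims` relies on the `encode` documentation, which says the original genotype matrix is returned when no hidden layers are configured.

```python
def expected_output_shapes(n_samples, n_snps, n_populations, hidden_dims):
    """Map each documented output key to its expected shape (hypothetical helper)."""
    # With no hidden layers, the latent is the genotype matrix itself.
    latent_dim = hidden_dims[-1] if hidden_dims else n_snps
    return {
        "genotypes": (n_samples, n_snps),
        "ancestry_proportions": (n_samples, n_populations),
        "reconstructed": (n_samples, n_snps),
        "latent": (n_samples, latent_dim),
    }


shapes = expected_output_shapes(
    n_samples=32, n_snps=10000, n_populations=5, hidden_dims=(128, 64)
)
for key, shape in shapes.items():
    print(key, shape)
```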