Autoencoder Clustering Losses

Three loss functions designed for Phase-2 DCEC-style fine-tuning of variational autoencoders. The goal is to preserve reconstruction quality (PSNR) while encouraging the latent space to form well-separated clusters.

Background

Training a VAE with MSE alone produces good reconstruction but no latent structure. Adding standard KL regularisation (β-VAE) can destroy reconstruction quality when β is too large.

The DCEC approach solves this with a two-phase strategy:

Phase  Loss                                       Goal
1      MSE (existing VariationalAutoencoderLoss)  maximise PSNR
2      ClusteringAwareVAELoss                     maintain PSNR + improve clustering

References

  • Xie et al., "Unsupervised Deep Embedding for Clustering Analysis", ICML 2016 — DEC soft-assignment loss.
  • Guo et al., "Deep Convolutional Embedded Clustering", IJCAI 2017 — DCEC: joint reconstruction + clustering.
  • Wen et al., "A Discriminative Feature Learning Approach for Deep Face Recognition", ECCV 2016 — Center loss.

DECSoftAssignmentLoss

Minimises the KL divergence between soft cluster assignments Q and a sharpened target distribution P. Pushing Q toward P forces the encoder to make confident, unambiguous cluster assignments.

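Q_{ij} = \frac{\left(1 + \lVert z_i - c_j \rVert^2 / \alpha\right)^{-(\alpha+1)/2}}{\sum_{j'} \left(1 + \lVert z_i - c_{j'} \rVert^2 / \alpha\right)^{-(\alpha+1)/2}}
(Student-t kernel, α=1 default)

P_{ij} = \frac{Q_{ij}^2 / \sum_i Q_{ij}}{\sum_{j'} \left( Q_{ij'}^2 / \sum_i Q_{ij'} \right)}
(sharpened, normalised)

L_{\mathrm{DEC}} = \mathrm{KL}(P \,\|\, Q) = \sum_i \sum_j P_{ij} \log \frac{P_{ij}}{Q_{ij}}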
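
To make the three formulas above concrete, here is a minimal sketch in plain Python (standard library only, so it runs without PyTorch). It is not the code of DECSoftAssignmentLoss: the function names are hypothetical, and summing the KL over samples rather than averaging it is an assumption taken from the formula as written.

```python
import math

def soft_assignments(z, centers, alpha=1.0):
    """Student-t soft assignments Q, shape (N, K); every row sums to 1."""
    q = []
    for z_i in z:
        row = []
        for c_j in centers:
            sq_dist = sum((a - b) ** 2 for a, b in zip(z_i, c_j))
            row.append((1.0 + sq_dist / alpha) ** (-(alpha + 1.0) / 2.0))
        total = sum(row)
        q.append([v / total for v in row])
    return q

def target_distribution(q):
    """Sharpened targets P: square Q, divide by the soft cluster size, renormalise rows."""
    n_clusters = len(q[0])
    cluster_size = [sum(row[j] for row in q) for j in range(n_clusters)]
    p = []
    for row in q:
        weights = [row[j] ** 2 / cluster_size[j] for j in range(n_clusters)]
        total = sum(weights)
        p.append([w / total for w in weights])
    return p

def dec_loss(p, q):
    """KL(P || Q), summed over samples and clusters (assumption: no batch averaging)."""
    return sum(
        p_ij * math.log(p_ij / q_ij)
        for p_row, q_row in zip(p, q)
        for p_ij, q_ij in zip(p_row, q_row)
        if p_ij > 0
    )

# Toy example: two embeddings near the first center, one on the second.
z = [[0.0, 0.0], [0.2, 0.1], [3.0, 3.0]]
centers = [[0.0, 0.0], [3.0, 3.0]]
q = soft_assignments(z, centers)
p = target_distribution(q)
loss = dec_loss(p, q)
```

In this toy case each row of P puts more mass on the dominant cluster than the matching row of Q does, which is the sharpening effect the loss relies on.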

YAML

loss:
  _target_: ...DECSoftAssignmentLoss
  n_clusters: 8
  embedding_dim: 128
  alpha: 1.0

Warm-start (mandatory)

Initialise cluster centers with K-Means before training:

loss_fn.initialize_centers(kmeans_centroids) # tensor (K, D)

CenterLoss

Minimises the mean squared distance from each embedding to its nearest cluster center. Encourages compact, geometrically tight clusters.

L_center = λ · (1/2) · Σ_i ‖z_i − c_{ŷ_i}‖²

where ŷ_i = argmin_j ‖z_i − c_j‖.
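
The formula can be checked by hand with a small plain-Python sketch. It follows the expression exactly as written (a sum over samples, scaled by λ/2); the function name is hypothetical and this is not the CenterLoss implementation, which may normalise by batch size.

```python
def center_loss(z, centers, lambda_center=0.01):
    """lambda * 1/2 * sum of squared distances to the nearest center."""
    total = 0.0
    for z_i in z:
        # Squared distance from this embedding to every center.
        sq_dists = [sum((a - b) ** 2 for a, b in zip(z_i, c_j)) for c_j in centers]
        # The nearest center is the one with the smallest squared distance.
        total += min(sq_dists)
    return lambda_center * 0.5 * total

# Each embedding sits exactly 1 unit from its nearest center.
z = [[0.0, 1.0], [4.0, 4.0]]
centers = [[0.0, 0.0], [4.0, 5.0]]
value = center_loss(z, centers, lambda_center=1.0)
```

With λ = 1 the two unit distances give 0.5 · (1 + 1) = 1.0; with the default λ = 0.01 the same batch contributes 0.01.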

YAML

loss:
  _target_: ...CenterLoss
  n_clusters: 8
  embedding_dim: 128
  lambda_center: 0.01

ClusteringAwareVAELoss (recommended entry point)

Composite loss that adds the two clustering terms (DEC and center) to the VAE reconstruction and KL terms:

L = L_reconstruction + β·L_KL + γ·L_DEC + δ·L_center

The composite loss owns a single cluster_centers parameter, which both clustering sub-losses share.

YAML

loss:
  _target_: pytorch_segmentation_models_trainer.custom_losses.autoencoder_clustering_losses.ClusteringAwareVAELoss
  n_clusters: 8
  embedding_dim: 128
  gamma: 0.1                 # DEC weight
  delta: 0.01                # center weight
  vae_latent: mu             # "mu" or "z"
  latent_reduction: adaptive_avg_pool
  reconstruction_loss: mse
  reconstruction_weight: 1.0
  beta: 0.0001               # small KL — keeps PSNR intact

Phase-2 training protocol (automatic)

Add ClusterCentersWarmStartCallback to your callback list. It runs once before the first epoch and handles the warm-start automatically.

callbacks:
  - _target_: pytorch_segmentation_models_trainer.custom_callbacks.ClusterCentersWarmStartCallback
    max_samples: 4096
    image_key: image
    vae_latent: mu

The callback iterates trainer.train_dataloader, collects up to max_samples embeddings, fits K-Means, and writes the centroids into loss_function.cluster_centers before training starts. It is a no-op if loss_function is not a ClusteringAwareVAELoss.
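
The fitting step described above can be illustrated with a short sketch. This is not the callback's code: it swaps in a plain Lloyd's-iteration K-Means written with the standard library, and the function name and the n_iters and seed parameters are hypothetical. Only the cap on the number of embeddings (max_samples) comes from the configuration above.

```python
import random

def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

def kmeans_centroids(embeddings, n_clusters, max_samples=4096, n_iters=20, seed=0):
    """Fit n_clusters centroids on at most max_samples embeddings (Lloyd's algorithm)."""
    rng = random.Random(seed)
    samples = embeddings[:max_samples]  # cap the number of embeddings used
    # Initialise centroids from randomly chosen samples.
    centroids = [list(p) for p in rng.sample(samples, n_clusters)]
    for _ in range(n_iters):
        # Assignment step: group each sample with its nearest centroid.
        groups = [[] for _ in range(n_clusters)]
        for point in samples:
            nearest = min(
                range(n_clusters),
                key=lambda j: squared_distance(point, centroids[j]),
            )
            groups[nearest].append(point)
        # Update step: move each centroid to the mean of its group.
        for j, group in enumerate(groups):
            if group:  # an empty cluster keeps its previous centroid
                dim = len(group[0])
                centroids[j] = [sum(p[d] for p in group) / len(group) for d in range(dim)]
    return centroids

# Two well-separated blobs of three points each.
data = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]]
centroids = kmeans_centroids(data, n_clusters=2)
```

The resulting (K, D) list of centroids is the kind of value that would then be written into cluster_centers.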

Phase-2 training protocol (manual)

import torch

# 1. Collect embeddings from the pre-trained encoder (no gradient)
model.eval()  # disable dropout / batch-norm updates while collecting
embeddings = []
with torch.no_grad():
    for batch in dataloader:
        out = model(batch["image"])
        embeddings.append(out.mu)
embeddings = torch.cat(embeddings, dim=0)

# 2. Warm-start centers
loss_fn.initialize_centers_from_embeddings(embeddings)

# 3. Train Phase 2 with ClusteringAwareVAELoss
trainer.fit(model)

Monitoring

Watch these metrics during Phase 2:

Metric                    Direction  Healthy sign
latent_calinski_harabasz  ↑          clusters separating
latent_davies_bouldin     ↓          clusters tightening
latent_silhouette         ↑          less ambiguous assignment
PSNR                      → stable   reconstruction preserved

If PSNR drops by more than 2 dB, reduce gamma or delta.
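
As a rough illustration of that check, the sketch below computes PSNR from a mean squared error and compares it with a baseline. It assumes images scaled to [0, 1] (peak value 1.0) and that the drop is measured against the PSNR reached at the end of Phase 1; both function names are hypothetical.

```python
import math

def psnr(mse, max_value=1.0):
    """PSNR in dB for a given mean squared error and peak signal value."""
    if mse == 0:
        return math.inf
    return 10.0 * math.log10(max_value ** 2 / mse)

def psnr_dropped(baseline_db, current_db, tolerance_db=2.0):
    """True when PSNR has fallen more than tolerance_db below the baseline."""
    return (baseline_db - current_db) > tolerance_db

baseline = psnr(0.001)  # Phase-1 reconstruction error (assumed value)
current = psnr(0.002)   # Phase-2 reconstruction error (assumed value)
needs_smaller_weights = psnr_dropped(baseline, current)
```

Doubling the MSE costs about 3 dB, so in this example the check fires and gamma or delta would be lowered.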


Loss output keys

ClusteringAwareVAELoss.forward returns all keys from VariationalAutoencoderLoss plus:

Key                   Description
dec_loss              raw DEC KL divergence
weighted_dec_loss     gamma × dec_loss
center_loss           raw center loss
weighted_center_loss  delta × center_loss
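
The relationship between the raw and weighted keys listed above can be sketched as follows. The helper function is hypothetical and only mirrors the arithmetic described here; it does not reproduce ClusteringAwareVAELoss.forward or the keys inherited from VariationalAutoencoderLoss.

```python
def clustering_loss_terms(dec_loss, center_loss, gamma=0.1, delta=0.01):
    """Build the four clustering-related entries from the raw loss values."""
    return {
        "dec_loss": dec_loss,
        "weighted_dec_loss": gamma * dec_loss,
        "center_loss": center_loss,
        "weighted_center_loss": delta * center_loss,
    }

terms = clustering_loss_terms(dec_loss=0.5, center_loss=2.0)
# Share of the total loss contributed by the two clustering terms.
clustering_contribution = terms["weighted_dec_loss"] + terms["weighted_center_loss"]
```

Logging the raw and weighted values side by side makes it easier to see whether gamma or delta is the one to reduce when PSNR starts to fall.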