fix: return mu (not mu+std) in VarAutoEncoder reparameterize at eval time (Fixes #8413)#8921
Conversation
📝 WalkthroughWalkthroughTwo Estimated code review effort🎯 2 (Simple) | ⏱️ ~5 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (2)
monai/networks/nets/varautoencoder.py (1)
144-149: 💤 Low valueFix is correct. The reparameterization trick now properly returns
μin eval mode andμ + ε·σin training.Consider adding a docstring per coding guidelines:
📝 Optional docstring addition
def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: + """Apply the reparameterization trick for VAE latent sampling. + + Args: + mu: Mean of the latent distribution. + logvar: Log variance of the latent distribution. + + Returns: + Sampled latent vector (stochastic in training, deterministic in eval). + """ if self.training: # reparameterization trick only during training std = torch.exp(0.5 * logvar) return mu + torch.randn_like(std) * std return mu🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/networks/nets/varautoencoder.py` around lines 144 - 149, The reparameterize method in the VAE implementation is missing a docstring that documents its behavior. Add a docstring to the reparameterize method that explains the purpose of the method, describes the input parameters (mu and logvar), specifies the return type and value, and documents the different behavior between training mode (which applies the reparameterization trick with sampling) and evaluation mode (which returns mu directly).Source: Coding guidelines
monai/networks/nets/fullyconnectednet.py (1)
174-179: 💤 Low valueFix is correct and consistent with
VarAutoEncoder. Same reparameterization logic applied correctly.Consider adding a docstring per coding guidelines:
📝 Optional docstring addition
def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: + """Apply the reparameterization trick for VAE latent sampling. + + Args: + mu: Mean of the latent distribution. + logvar: Log variance of the latent distribution. + + Returns: + Sampled latent vector (stochastic in training, deterministic in eval). + """ if self.training: # reparameterization trick only during training std = torch.exp(0.5 * logvar) return mu + torch.randn_like(std) * std return mu🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/networks/nets/fullyconnectednet.py` around lines 174 - 179, The reparameterize method is missing a docstring per coding guidelines. Add a docstring to the reparameterize method that explains its purpose, describes the reparameterization trick applied during training (which samples from the latent distribution using the mean mu and standard deviation from logvar), and clarifies that during inference it simply returns the mean mu. Include documentation for the parameters mu and logvar, and the return value.Source: Coding guidelines
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@monai/networks/nets/fullyconnectednet.py`:
- Around line 174-179: The reparameterize method is missing a docstring per
coding guidelines. Add a docstring to the reparameterize method that explains
its purpose, describes the reparameterization trick applied during training
(which samples from the latent distribution using the mean mu and standard
deviation from logvar), and clarifies that during inference it simply returns
the mean mu. Include documentation for the parameters mu and logvar, and the
return value.
In `@monai/networks/nets/varautoencoder.py`:
- Around line 144-149: The reparameterize method in the VAE implementation is
missing a docstring that documents its behavior. Add a docstring to the
reparameterize method that explains the purpose of the method, describes the
input parameters (mu and logvar), specifies the return type and value, and
documents the different behavior between training mode (which applies the
reparameterization trick with sampling) and evaluation mode (which returns mu
directly).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 83003db9-4000-46f5-a46f-b1d61cce0d53
📒 Files selected for processing (2)
monai/networks/nets/fullyconnectednet.pymonai/networks/nets/varautoencoder.py
In VarAutoEncoder and VarFullyConnectedNet, the reparameterize method computed and always returned . At training time, was correctly replaced with random noise scaled by the standard deviation. But at eval time, the raw standard deviation was still added to the mean — giving instead of just . The reparameterization trick (Kingma & Welling, 2014) defines: - Training: z = μ + ε · σ (ε ~ N(0,1)) - Inference: z = μ (no stochastic component) This fix restructures the method to return directly when not training, avoiding the unnecessary computation and eliminating the incorrect result at inference. Fixes Project-MONAI#8413 Signed-off-by: Raphael Malikian <rtmalikian@gmail.com>
f8354f3 to
f74108c
Compare
Fixes #8413
Problem
In
VarAutoEncoder.reparameterize()andVarFullyConnectedNet.reparameterize(), the method always returnsstd + mu:At training time,
stdis correctly replaced withε · σ(random noise scaled by standard deviation). But at eval time, the raw standard deviationσ = exp(0.5 · logvar)is still added to the mean, producingμ + σinstead of justμ.The reparameterization trick (Kingma & Welling, 2014) specifies:
z = μ + ε · σwhereε ~ N(0,1)z = μ(no stochastic component)Adding the standard deviation to the mean at inference time introduces systematic bias — the latent code is shifted away from the learned mean by an amount proportional to the variance.
Solution
Restructured the
reparameterizemethod to returnmudirectly when not in training mode, avoiding the unnecessaryexp(0.5 * logvar)computation:Applied to both:
monai/networks/nets/varautoencoder.py—VarAutoEncoder.reparameterizemonai/networks/nets/fullyconnectednet.py—VarFullyConnectedNet.reparameterizeNote:
spade_network.pyhas a different implementation that always samples noise (intentional for GAN-style use), so it was left unchanged.Verification
Manual verification:
Existing tests pass (7/7 + 3/3):
About the Author: Raphael Malikian — Clinical AI Solutions Architect. I specialise in building and fixing AI/ML systems for healthcare, including vector databases, RAG pipelines, and clinical NLP. If you need help with your project or think I can add value to your organisation, feel free to reach out — I'd love to connect.
📧 rtmalikian@gmail.com
🔗 GitHub: https://github.com/rtmalikian
🔗 LinkedIn: http://www.linkedin.com/in/raphael-t-malikian-mbbs-bsc-hons-71075436a
Disclosure: This code was developed with assistance from mimo-2.5-pro (Xiaomi) via Hermes Agent (Nous Research). All changes were reviewed, tested against the actual codebase, and verified for correctness.
Changelog
Files Changed
monai/networks/nets/varautoencoder.py— Fixedreparameterize()to branch onself.trainingin bothVarAutoEncoderandVarFullyConnectedNetVerification
VarAutoEncoder.reparameterize()returns μ in eval mode, μ + ε·σ in train modeVarFullyConnectedNet.reparameterize()follows same pattern