fix: return mu (not mu+std) in VarAutoEncoder reparameterize at eval time (Fixes #8413)#8921

Open
rtmalikian wants to merge 1 commit into Project-MONAI:dev from rtmalikian:fix/issue-8413-varautoencoder-reparameterize

@rtmalikian rtmalikian commented Jun 18, 2026

Fixes #8413

Problem

In VarAutoEncoder.reparameterize() and VarFullyConnectedNet.reparameterize(), the method always returns std + mu:

std = torch.exp(0.5 * logvar)
if self.training:
    std = torch.randn_like(std).mul(std)
return std.add_(mu)  # BUG: at eval time, this returns mu + σ instead of mu

At training time, std is correctly replaced with ε · σ (random noise scaled by standard deviation). But at eval time, the raw standard deviation σ = exp(0.5 · logvar) is still added to the mean, producing μ + σ instead of just μ.
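
To make the eval-time behaviour concrete, here is a minimal sketch that copies the pre-fix logic into a free function. The name old_reparameterize, the explicit training flag (standing in for self.training), and the input values are assumptions chosen for illustration; they are not part of MONAI.

```python
import torch


def old_reparameterize(mu: torch.Tensor, logvar: torch.Tensor, training: bool) -> torch.Tensor:
    """Stand-alone copy of the pre-fix logic; `training` stands in for self.training."""
    std = torch.exp(0.5 * logvar)
    if training:
        std = torch.randn_like(std).mul(std)
    # In eval mode `std` is still the raw sigma here, so sigma gets added to mu.
    return std.add_(mu)


mu = torch.tensor([[1.0, 2.0]])
logvar = torch.zeros_like(mu)  # assumed values: sigma = exp(0.5 * 0) = 1

z_eval = old_reparameterize(mu, logvar, training=False)
print(z_eval)       # tensor([[2., 3.]])
print(z_eval - mu)  # tensor([[1., 1.]]), i.e. exactly sigma
```

With logvar set to zero, every latent coordinate is offset by 1 even though no sampling takes place.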

The reparameterization trick (Kingma & Welling, 2014) defines the training-time sample; at inference, the usual convention is to take the mean:

  • Training: z = μ + ε · σ where ε ~ N(0,1)
  • Inference: z = μ (no stochastic component)

Adding the standard deviation to the mean at inference time introduces systematic bias: every latent coordinate is shifted away from the learned mean by exactly σ = exp(0.5 · logvar), the predicted standard deviation.
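
The three cases can be written side by side. This is a short restatement of the formulas above, using ⊙ for the element-wise product, and not a new derivation:

```latex
\begin{aligned}
\text{training:}\quad & z = \mu + \varepsilon \odot \sigma,\ \ \varepsilon \sim \mathcal{N}(0, I)
  &&\Rightarrow\ \mathbb{E}[z] = \mu \\
\text{eval (before fix):}\quad & z = \mu + \sigma,\ \ \sigma = \exp\!\left(\tfrac{1}{2}\log\sigma^{2}\right)
  &&\Rightarrow\ z - \mu = \sigma > 0 \\
\text{eval (after fix):}\quad & z = \mu
  &&\Rightarrow\ z - \mu = 0
\end{aligned}
```

Because σ is an exponential, it is strictly positive, so the pre-fix offset always points in the same direction and never averages out.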

Solution

Restructured the reparameterize method to return mu directly when not in training mode, avoiding the unnecessary exp(0.5 * logvar) computation:

def reparameterize(self, mu, logvar):
    if self.training:
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std
    return mu
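
As a rough usage sketch, the same method can be placed on a throwaway module to show how train() and eval() switch the behaviour. LatentSampler is a hypothetical stand-in, not a MONAI class, and the input values are arbitrary.

```python
import torch
from torch import nn


class LatentSampler(nn.Module):
    """Hypothetical stand-in that carries only the fixed reparameterize method."""

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        if self.training:
            std = torch.exp(0.5 * logvar)
            return mu + torch.randn_like(std) * std
        return mu


sampler = LatentSampler()
mu = torch.tensor([[1.0, 2.0]])
logvar = torch.tensor([[0.5, -0.5]])  # arbitrary illustrative values

sampler.eval()  # sets self.training = False
z_eval = sampler.reparameterize(mu, logvar)

sampler.train()  # sets self.training = True
torch.manual_seed(0)
z_train = sampler.reparameterize(mu, logvar)

print(torch.equal(z_eval, mu))  # True
print(z_eval is mu)             # True: eval mode hands back the same tensor object
```

One consequence worth keeping in mind: in eval mode the returned tensor is mu itself rather than a copy, so an in-place operation on the result would also modify mu.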

Applied to both:

  • monai/networks/nets/varautoencoder.py: VarAutoEncoder.reparameterize
  • monai/networks/nets/fullyconnectednet.py: VarFullyConnectedNet.reparameterize

Note: spade_network.py has a different implementation that always samples noise (intentional for GAN-style use), so it was left unchanged.

Verification

Manual verification:

Eval mode: mu=tensor([[1., 2.]]), z=tensor([[1., 2.]])
PASS: In eval mode, reparameterize returns mu

Train mode: z1=tensor([[-0.6354,  2.5485]]), z2=tensor([[0.2493, 2.8168]])
PASS: In train mode, reparameterize returns mu + noise * std

Old buggy result would have been: tensor([[2.2840, 2.7788]])
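
The logvar used for this check is not shown in the output. As a sanity check, it can be backed out from the two tensors quoted above, assuming the old eval path computed mu + exp(0.5 * logvar). The variable names below are illustrative.

```python
import math

mu = [1.0, 2.0]           # eval-mode mu quoted above
z_old = [2.2840, 2.7788]  # quoted "old buggy result"

# Old eval path: z_old = mu + sigma, so sigma = z_old - mu and logvar = 2 * ln(sigma).
sigma = [z - m for z, m in zip(z_old, mu)]
logvar = [2.0 * math.log(s) for s in sigma]

print([round(s, 4) for s in sigma])     # [1.284, 0.7788]
print([round(lv, 2) for lv in logvar])  # [0.5, -0.5]
```

The quoted numbers are therefore consistent with logvar ≈ [0.5, -0.5], since exp(0.25) ≈ 1.2840 and exp(-0.25) ≈ 0.7788.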

Existing tests pass (7/7 + 3/3):

tests/networks/nets/test_varautoencoder.py::TestVarAutoEncoder::test_shape_0 PASSED
tests/networks/nets/test_varautoencoder.py::TestVarAutoEncoder::test_shape_1 PASSED
tests/networks/nets/test_varautoencoder.py::TestVarAutoEncoder::test_shape_2 PASSED
tests/networks/nets/test_varautoencoder.py::TestVarAutoEncoder::test_shape_3 PASSED
tests/networks/nets/test_varautoencoder.py::TestVarAutoEncoder::test_shape_4 PASSED
tests/networks/nets/test_varautoencoder.py::TestVarAutoEncoder::test_shape_5 PASSED
tests/networks/nets/test_varautoencoder.py::TestVarAutoEncoder::test_script PASSED

tests/networks/nets/test_fullyconnectednet.py::TestFullyConnectedNet::test_fc_shape_0 PASSED
tests/networks/nets/test_fullyconnectednet.py::TestFullyConnectedNet::test_fc_shape_1 PASSED
tests/networks/nets/test_fullyconnectednet.py::TestFullyConnectedNet::test_vfc_shape_0 PASSED

About the Author: Raphael Malikian — Clinical AI Solutions Architect. I specialise in building and fixing AI/ML systems for healthcare, including vector databases, RAG pipelines, and clinical NLP. If you need help with your project or think I can add value to your organisation, feel free to reach out — I'd love to connect.

📧 rtmalikian@gmail.com
🔗 GitHub: https://github.com/rtmalikian
🔗 LinkedIn: http://www.linkedin.com/in/raphael-t-malikian-mbbs-bsc-hons-71075436a


Disclosure: This code was developed with assistance from mimo-2.5-pro (Xiaomi) via Hermes Agent (Nous Research). All changes were reviewed, tested against the actual codebase, and verified for correctness.

Changelog

| Date | Change | Author |
|---|---|---|
| 2026-06-18 | Initial fix: return μ (not μ+std) in eval mode | rtmalikian |
| 2026-06-18 | Added DCO sign-off to commit | rtmalikian |
| 2026-06-18 | Updated PR documentation with changelog | rtmalikian |

Files Changed

  • monai/networks/nets/varautoencoder.py: VarAutoEncoder.reparameterize() now branches on self.training
  • monai/networks/nets/fullyconnectednet.py: VarFullyConnectedNet.reparameterize() now branches on self.training

Verification

  • VarAutoEncoder.reparameterize() returns μ in eval mode, μ + ε·σ in train mode
  • VarFullyConnectedNet.reparameterize() follows same pattern
  • ✅ DCO sign-off present on all commits

coderabbitai Bot commented Jun 18, 2026

Walkthrough

Two reparameterize methods — in VarFullyConnectedNet and VarAutoEncoder — now branch on self.training. During training, std = exp(0.5 * logvar) is computed and noise is applied via mu + randn_like(std) * std. In eval mode, std computation is skipped and mu is returned directly.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~5 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title accurately summarizes the main fix: correcting reparameterize to return mu instead of mu+std at eval time. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | PR description comprehensively covers the problem, solution, and verification with clear examples and test results. |

@coderabbitai coderabbitai Bot left a comment

🧹 Nitpick comments (2)
monai/networks/nets/varautoencoder.py (1)

144-149: 💤 Low value

Fix is correct. The reparameterization trick now properly returns μ in eval mode and μ + ε·σ in training.

Consider adding a docstring per coding guidelines:

📝 Optional docstring addition
     def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+        """Apply the reparameterization trick for VAE latent sampling.
+
+        Args:
+            mu: Mean of the latent distribution.
+            logvar: Log variance of the latent distribution.
+
+        Returns:
+            Sampled latent vector (stochastic in training, deterministic in eval).
+        """
         if self.training:  # reparameterization trick only during training
             std = torch.exp(0.5 * logvar)
             return mu + torch.randn_like(std) * std

         return mu
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/networks/nets/varautoencoder.py` around lines 144 - 149, The
reparameterize method in the VAE implementation is missing a docstring that
documents its behavior. Add a docstring to the reparameterize method that
explains the purpose of the method, describes the input parameters (mu and
logvar), specifies the return type and value, and documents the different
behavior between training mode (which applies the reparameterization trick with
sampling) and evaluation mode (which returns mu directly).

Source: Coding guidelines

monai/networks/nets/fullyconnectednet.py (1)

174-179: 💤 Low value

Fix is correct and consistent with VarAutoEncoder. Same reparameterization logic applied correctly.

Consider adding a docstring per coding guidelines:

📝 Optional docstring addition
     def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
+        """Apply the reparameterization trick for VAE latent sampling.
+
+        Args:
+            mu: Mean of the latent distribution.
+            logvar: Log variance of the latent distribution.
+
+        Returns:
+            Sampled latent vector (stochastic in training, deterministic in eval).
+        """
         if self.training:  # reparameterization trick only during training
             std = torch.exp(0.5 * logvar)
             return mu + torch.randn_like(std) * std

         return mu
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/networks/nets/fullyconnectednet.py` around lines 174 - 179, The
reparameterize method is missing a docstring per coding guidelines. Add a
docstring to the reparameterize method that explains its purpose, describes the
reparameterization trick applied during training (which samples from the latent
distribution using the mean mu and standard deviation from logvar), and
clarifies that during inference it simply returns the mean mu. Include
documentation for the parameters mu and logvar, and the return value.

Source: Coding guidelines

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 83003db9-4000-46f5-a46f-b1d61cce0d53

📥 Commits

Reviewing files that changed from the base of the PR and between 15f5073 and f8354f3.

📒 Files selected for processing (2)
  • monai/networks/nets/fullyconnectednet.py
  • monai/networks/nets/varautoencoder.py

In VarAutoEncoder and VarFullyConnectedNet, the reparameterize method
computed std = exp(0.5 * logvar) and always returned std + mu.
At training time, std was correctly replaced with random noise scaled
by the standard deviation. But at eval time, the raw standard deviation
was still added to the mean, giving mu + std instead of just mu.

The reparameterization trick (Kingma & Welling, 2014) defines:
  - Training:  z = μ + ε · σ  (ε ~ N(0,1))
  - Inference: z = μ           (no stochastic component)

This fix restructures the method to return mu directly when not
training, avoiding the unnecessary exp(0.5 * logvar) computation
and eliminating the incorrect mu + std result at inference.

Fixes Project-MONAI#8413

Signed-off-by: Raphael Malikian <rtmalikian@gmail.com>
rtmalikian force-pushed the fix/issue-8413-varautoencoder-reparameterize branch from f8354f3 to f74108c on June 19, 2026 05:56
Development

Successfully merging this pull request may close these issues.

Error in VarAutoEncoder class

1 participant