-
Notifications
You must be signed in to change notification settings - Fork 144
Fix TQDM_DISABLE=1 compatibility issue with cmdstanpy #888
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,116 @@ | ||||||||||||||
| """ | ||||||||||||||
| Compatibility utilities for cmdstanpy integration. | ||||||||||||||
|
|
||||||||||||||
| This module contains patches and workarounds for cmdstanpy compatibility issues. | ||||||||||||||
| """ | ||||||||||||||
|
|
||||||||||||||
| import os | ||||||||||||||
| from typing import Dict, List, Optional, Callable | ||||||||||||||
|
|
||||||||||||||
| from .logger import get_logger | ||||||||||||||
|
|
||||||||||||||
| logger = get_logger("orbit") | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| def patch_tqdm_progress_hook(): | ||||||||||||||
| """ | ||||||||||||||
| Patch cmdstanpy progress hook to handle TQDM_DISABLE safely. | ||||||||||||||
|
|
||||||||||||||
| When TQDM_DISABLE=1 is set, tqdm creates disabled progress bar objects | ||||||||||||||
| that don't have the 'postfix' attribute. cmdstanpy assumes this attribute | ||||||||||||||
| exists and tries to access it, causing AttributeError. | ||||||||||||||
|
|
||||||||||||||
| This patch adds safe access checks to prevent the crash. | ||||||||||||||
|
|
||||||||||||||
| See: https://github.com/uber/orbit/issues/887 | ||||||||||||||
| """ | ||||||||||||||
| # Only patch if TQDM_DISABLE is set | ||||||||||||||
| if os.environ.get("TQDM_DISABLE") != "1": | ||||||||||||||
|
Comment on lines
+27
to
+28
|
||||||||||||||
| # Only patch if TQDM_DISABLE is set | |
| if os.environ.get("TQDM_DISABLE") != "1": | |
| # Only patch when tqdm is effectively disabled via TQDM_DISABLE | |
| env_value = os.environ.get("TQDM_DISABLE", "0") | |
| tqdm_disabled = env_value.lower() not in ("0", "false", "", "no") | |
| if not tqdm_disabled: |
Copilot
AI
Mar 3, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The re module is imported inside the try block at line 33, and is captured by the safe_wrap_sampler_progress_hook closure. While this works correctly, re is a standard library module. Importing it at the top of the file (next to the import os) would be cleaner and more consistent with Python conventions and with the rest of the codebase.
Copilot
AI
Mar 3, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The @staticmethod decorator is applied to a function defined inside another function (line 45). This is an unusual pattern that creates a staticmethod descriptor object as a local variable. Assigning this descriptor to the class attribute at line 109 works correctly through Python's descriptor protocol, but this pattern can be confusing for maintainers who may not be familiar with it. A clearer equivalent would be to define the inner function without the @staticmethod decorator, and then use staticmethod(safe_wrap_sampler_progress_hook) when assigning to the class attribute on line 109. This is functionally identical but makes the intent clearer.
Copilot
AI
Mar 3, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When TQDM_DISABLE=1, the purpose is to disable all progress output. The safe_wrap_sampler_progress_hook function still creates tqdm progress bar instances (lines 55–64) even in this case — they are silently disabled by tqdm, but creating them adds unnecessary overhead. A simpler and more direct solution would be to return a no-op callable from safe_wrap_sampler_progress_hook when TQDM_DISABLE=1 is set, rather than creating progress bar objects with safe guards. The no-op callable would satisfy cmdstanpy's requirement for a callable while completely skipping progress bar logic:
def progress_hook(line: str, idx: int) -> None:
pass
return progress_hookThis makes the intent clearer and avoids any future tqdm API compatibility issues.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| import os | ||
| import pytest | ||
| from unittest.mock import patch, MagicMock | ||
|
|
||
| from orbit.utils.cmdstanpy_compat import patch_tqdm_progress_hook | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "env_value", | ||
| [None, "0", "false", "true", ""], | ||
| ) | ||
| def test_patch_tqdm_progress_hook_no_patch_scenarios(env_value): | ||
| """Test that patch is not applied when TQDM_DISABLE is not '1'.""" | ||
| env_dict = {"TQDM_DISABLE": env_value} if env_value is not None else {} | ||
|
|
||
| with patch.dict(os.environ, env_dict, clear=False): | ||
| if env_value is None and "TQDM_DISABLE" in os.environ: | ||
| del os.environ["TQDM_DISABLE"] | ||
|
|
||
| # Should return early without doing anything | ||
| patch_tqdm_progress_hook() | ||
| # Test passes if no exception raised | ||
|
|
||
|
|
||
| def test_patch_tqdm_progress_hook_applies_patch(): | ||
| """Test that patch is applied when TQDM_DISABLE=1.""" | ||
| with patch.dict(os.environ, {"TQDM_DISABLE": "1"}): | ||
| with patch("cmdstanpy.model") as mock_cmdstanpy_model: | ||
| mock_model = MagicMock() | ||
| mock_model._wrap_sampler_progress_hook = MagicMock() | ||
| # Ensure not already patched | ||
| del mock_model._orbit_tqdm_patched | ||
| mock_cmdstanpy_model.CmdStanModel = mock_model | ||
|
|
||
| patch_tqdm_progress_hook() | ||
|
|
||
| assert mock_model._orbit_tqdm_patched is True | ||
|
|
||
|
|
||
| def test_patch_tqdm_progress_hook_no_double_patch(): | ||
| """Test that patch is not applied multiple times.""" | ||
| with patch.dict(os.environ, {"TQDM_DISABLE": "1"}): | ||
| with patch("cmdstanpy.model") as mock_cmdstanpy_model: | ||
| mock_model = MagicMock() | ||
| mock_model._orbit_tqdm_patched = True # Already patched | ||
| original_hook = MagicMock() | ||
| mock_model._wrap_sampler_progress_hook = original_hook | ||
| mock_cmdstanpy_model.CmdStanModel = mock_model | ||
|
|
||
| patch_tqdm_progress_hook() | ||
|
|
||
| # Original hook should remain unchanged | ||
| assert mock_model._wrap_sampler_progress_hook is original_hook | ||
|
|
||
|
|
||
| def test_patch_tqdm_progress_hook_handles_missing_method(): | ||
| """Test graceful handling when original method doesn't exist.""" | ||
| with patch.dict(os.environ, {"TQDM_DISABLE": "1"}): | ||
| with patch("cmdstanpy.model") as mock_cmdstanpy_model: | ||
| # Simple object without the method | ||
| mock_model = type("MockModel", (), {})() | ||
| mock_cmdstanpy_model.CmdStanModel = mock_model | ||
|
|
||
| # Should not raise exception | ||
| patch_tqdm_progress_hook() | ||
|
|
||
| # Should not set patched flag | ||
| assert not hasattr(mock_model, "_orbit_tqdm_patched") | ||
|
|
||
|
|
||
| def test_patch_tqdm_progress_hook_handles_import_error(): | ||
| """Test graceful handling of import errors.""" | ||
| with patch.dict(os.environ, {"TQDM_DISABLE": "1"}): | ||
| with patch("cmdstanpy.model", side_effect=ImportError("No module")): | ||
| # Should not raise exception | ||
| patch_tqdm_progress_hook() | ||
|
|
||
|
|
||
| def test_integration_with_actual_cmdstanpy(): | ||
| """Integration test with actual cmdstanpy if available.""" | ||
| cmdstanpy = pytest.importorskip("cmdstanpy") | ||
|
|
||
| with patch.dict(os.environ, {"TQDM_DISABLE": "1"}): | ||
| # Store original state | ||
| original_method = getattr( | ||
| cmdstanpy.model.CmdStanModel, "_wrap_sampler_progress_hook", None | ||
| ) | ||
|
|
||
| try: | ||
| patch_tqdm_progress_hook() | ||
|
|
||
| # Verify patch was applied | ||
| assert hasattr(cmdstanpy.model.CmdStanModel, "_orbit_tqdm_patched") | ||
| assert cmdstanpy.model.CmdStanModel._orbit_tqdm_patched is True | ||
|
|
||
| finally: | ||
| # Clean up | ||
| if hasattr(cmdstanpy.model.CmdStanModel, "_orbit_tqdm_patched"): | ||
| delattr(cmdstanpy.model.CmdStanModel, "_orbit_tqdm_patched") | ||
| if original_method is not None: | ||
| cmdstanpy.model.CmdStanModel._wrap_sampler_progress_hook = ( | ||
| original_method | ||
| ) | ||
|
Comment on lines
+96
to
+103
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
patch_tqdm_progress_hook()function is called unconditionally at module load time (line 20), which means the patch is only applied ifTQDM_DISABLE=1was set beforeorbit/estimators/stan_estimator.pywas first imported. If a user setsTQDM_DISABLE=1after importing orbit, the patch will not be applied, and the same crash will occur. While this matches the specific scenario described in issue #887, this limitation is not documented, which could lead to user confusion. Consider adding a note to the docstring or a logged warning ifTQDM_DISABLE=1is set but the patch has already been skipped.