Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion orbit/estimators/stan_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,23 @@
import multiprocessing
from abc import abstractmethod
from copy import copy
from sys import platform, version_info

from ..exceptions import EstimatorException
from ..utils.general import update_dict
from ..utils.logger import get_logger
from ..utils.set_cmdstan_path import set_cmdstan_path
from ..utils.stan import get_compiled_stan_model
from ..utils.cmdstanpy_compat import patch_tqdm_progress_hook
from .base_estimator import BaseEstimator

logger = get_logger("orbit")

# Make sure models are using the right cmdstan folder
set_cmdstan_path()

# Apply cmdstanpy compatibility patches
patch_tqdm_progress_hook()
Comment on lines +19 to +20
Copy link

Copilot AI Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The patch_tqdm_progress_hook() function is called unconditionally at module load time (line 20), which means the patch is only applied if TQDM_DISABLE=1 was set before orbit/estimators/stan_estimator.py was first imported. If a user sets TQDM_DISABLE=1 after importing orbit, the patch will not be applied, and the same crash will occur. While this matches the specific scenario described in issue #887, this limitation is not documented, which could lead to user confusion. Consider adding a note to the docstring or a logged warning if TQDM_DISABLE=1 is set but the patch has already been skipped.

Copilot uses AI. Check for mistakes.


class StanEstimator(BaseEstimator):
"""Abstract StanEstimator with shared args for all StanEstimator child classes
Expand Down
116 changes: 116 additions & 0 deletions orbit/utils/cmdstanpy_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""
Compatibility utilities for cmdstanpy integration.

This module contains patches and workarounds for cmdstanpy compatibility issues.
"""

import os
from typing import Dict, List, Optional, Callable

from .logger import get_logger

logger = get_logger("orbit")


def patch_tqdm_progress_hook():
"""
Patch cmdstanpy progress hook to handle TQDM_DISABLE safely.

When TQDM_DISABLE=1 is set, tqdm creates disabled progress bar objects
that don't have the 'postfix' attribute. cmdstanpy assumes this attribute
exists and tries to access it, causing AttributeError.

This patch adds safe access checks to prevent the crash.

See: https://github.com/uber/orbit/issues/887
"""
# Only patch if TQDM_DISABLE is set
if os.environ.get("TQDM_DISABLE") != "1":
Comment on lines +27 to +28
Copy link

Copilot AI Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The guard on line 28 only triggers the patch when TQDM_DISABLE is exactly "1". However, tqdm treats several other values as truthy for disabling progress bars (e.g., "true", "yes", "True"). Users who set TQDM_DISABLE=true or TQDM_DISABLE=yes would still encounter the original crash, as the patch would not be applied. The condition should be broadened to cover all values that tqdm considers as "disabled", for example by checking os.environ.get("TQDM_DISABLE", "0").lower() not in ("0", "false", "", "no"), or by checking whether tqdm.tqdm.is_disabled() returns True.

Suggested change
# Only patch if TQDM_DISABLE is set
if os.environ.get("TQDM_DISABLE") != "1":
# Only patch when tqdm is effectively disabled via TQDM_DISABLE
env_value = os.environ.get("TQDM_DISABLE", "0")
tqdm_disabled = env_value.lower() not in ("0", "false", "", "no")
if not tqdm_disabled:

Copilot uses AI. Check for mistakes.
return

try:
import cmdstanpy.model
import re
Copy link

Copilot AI Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The re module is imported inside the try block at line 33, and is captured by the safe_wrap_sampler_progress_hook closure. While this works correctly, re is a standard library module. Importing it at the top of the file (next to the import os) would be cleaner and more consistent with Python conventions and with the rest of the codebase.

Copilot uses AI. Check for mistakes.

# Store reference to original method to avoid patching multiple times
if hasattr(cmdstanpy.model.CmdStanModel, "_orbit_tqdm_patched"):
return

original_hook = getattr(
cmdstanpy.model.CmdStanModel, "_wrap_sampler_progress_hook", None
)
if original_hook is None:
return

@staticmethod
def safe_wrap_sampler_progress_hook(
Comment on lines +45 to +46
Copy link

Copilot AI Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The @staticmethod decorator is applied to a function defined inside another function (line 45). This is an unusual pattern that creates a staticmethod descriptor object as a local variable. Assigning this descriptor to the class attribute at line 109 works correctly through Python's descriptor protocol, but this pattern can be confusing for maintainers who may not be familiar with it. A clearer equivalent would be to define the inner function without the @staticmethod decorator, and then use staticmethod(safe_wrap_sampler_progress_hook) when assigning to the class attribute on line 109. This is functionally identical but makes the intent clearer.

Copilot uses AI. Check for mistakes.
chain_ids: List[int],
total: int,
) -> Optional[Callable[[str, int], None]]:
"""Safe version that handles disabled tqdm progress bars."""
try:
from tqdm import tqdm

pat = re.compile(r"Chain \[(\d*)\] (Iteration.*)")
pbars: Dict[int, tqdm] = {
chain_id: tqdm(
total=total,
bar_format="{desc} |{bar}| {elapsed} {postfix[0][value]}",
postfix=[{"value": "Status"}],
desc=f"chain {chain_id}",
colour="yellow",
)
for chain_id in chain_ids
}

def progress_hook(line: str, idx: int) -> None:
if line == "Done":
for pbar in pbars.values():
# safe postfix access
if hasattr(pbar, "postfix") and pbar.postfix:
try:
pbar.postfix[0]["value"] = "Sampling completed"
except (AttributeError, KeyError, IndexError):
pass
pbar.update(total - pbar.n)
pbar.close()
else:
match = pat.match(line)
if match:
idx = int(match.group(1))
mline = match.group(2).strip()
elif line.startswith("Iteration"):
mline = line
idx = chain_ids[idx]
else:
return

if idx in pbars:
if "Sampling" in mline and hasattr(pbars[idx], "colour"):
pbars[idx].colour = "blue"
pbars[idx].update(1)

# safe postfix access
if hasattr(pbars[idx], "postfix") and pbars[idx].postfix:
try:
pbars[idx].postfix[0]["value"] = mline
except (AttributeError, KeyError, IndexError):
pass

return progress_hook

except Exception as e:
logger.warning(
f"Progress bar setup failed: {e}. Disabling progress bars."
)
return None
Comment on lines +51 to +106
Copy link

Copilot AI Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When TQDM_DISABLE=1, the purpose is to disable all progress output. The safe_wrap_sampler_progress_hook function still creates tqdm progress bar instances (lines 55–64) even in this case — they are silently disabled by tqdm, but creating them adds unnecessary overhead. A simpler and more direct solution would be to return a no-op callable from safe_wrap_sampler_progress_hook when TQDM_DISABLE=1 is set, rather than creating progress bar objects with safe guards. The no-op callable would satisfy cmdstanpy's requirement for a callable while completely skipping progress bar logic:

def progress_hook(line: str, idx: int) -> None:
    pass
return progress_hook

This makes the intent clearer and avoids any future tqdm API compatibility issues.

Copilot uses AI. Check for mistakes.

# apply the patch
cmdstanpy.model.CmdStanModel._wrap_sampler_progress_hook = (
safe_wrap_sampler_progress_hook
)
cmdstanpy.model.CmdStanModel._orbit_tqdm_patched = True
logger.debug("cmdstanpy progress hook patched for TQDM_DISABLE compatibility")

except Exception as e:
logger.warning(f"Failed to patch cmdstanpy progress hook: {e}")
103 changes: 103 additions & 0 deletions tests/orbit/utils/test_cmdstanpy_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import os
import pytest
from unittest.mock import patch, MagicMock

from orbit.utils.cmdstanpy_compat import patch_tqdm_progress_hook


@pytest.mark.parametrize(
"env_value",
[None, "0", "false", "true", ""],
)
def test_patch_tqdm_progress_hook_no_patch_scenarios(env_value):
"""Test that patch is not applied when TQDM_DISABLE is not '1'."""
env_dict = {"TQDM_DISABLE": env_value} if env_value is not None else {}

with patch.dict(os.environ, env_dict, clear=False):
if env_value is None and "TQDM_DISABLE" in os.environ:
del os.environ["TQDM_DISABLE"]

# Should return early without doing anything
patch_tqdm_progress_hook()
# Test passes if no exception raised


def test_patch_tqdm_progress_hook_applies_patch():
"""Test that patch is applied when TQDM_DISABLE=1."""
with patch.dict(os.environ, {"TQDM_DISABLE": "1"}):
with patch("cmdstanpy.model") as mock_cmdstanpy_model:
mock_model = MagicMock()
mock_model._wrap_sampler_progress_hook = MagicMock()
# Ensure not already patched
del mock_model._orbit_tqdm_patched
mock_cmdstanpy_model.CmdStanModel = mock_model

patch_tqdm_progress_hook()

assert mock_model._orbit_tqdm_patched is True


def test_patch_tqdm_progress_hook_no_double_patch():
"""Test that patch is not applied multiple times."""
with patch.dict(os.environ, {"TQDM_DISABLE": "1"}):
with patch("cmdstanpy.model") as mock_cmdstanpy_model:
mock_model = MagicMock()
mock_model._orbit_tqdm_patched = True # Already patched
original_hook = MagicMock()
mock_model._wrap_sampler_progress_hook = original_hook
mock_cmdstanpy_model.CmdStanModel = mock_model

patch_tqdm_progress_hook()

# Original hook should remain unchanged
assert mock_model._wrap_sampler_progress_hook is original_hook


def test_patch_tqdm_progress_hook_handles_missing_method():
"""Test graceful handling when original method doesn't exist."""
with patch.dict(os.environ, {"TQDM_DISABLE": "1"}):
with patch("cmdstanpy.model") as mock_cmdstanpy_model:
# Simple object without the method
mock_model = type("MockModel", (), {})()
mock_cmdstanpy_model.CmdStanModel = mock_model

# Should not raise exception
patch_tqdm_progress_hook()

# Should not set patched flag
assert not hasattr(mock_model, "_orbit_tqdm_patched")


def test_patch_tqdm_progress_hook_handles_import_error():
"""Test graceful handling of import errors."""
with patch.dict(os.environ, {"TQDM_DISABLE": "1"}):
with patch("cmdstanpy.model", side_effect=ImportError("No module")):
# Should not raise exception
patch_tqdm_progress_hook()


def test_integration_with_actual_cmdstanpy():
"""Integration test with actual cmdstanpy if available."""
cmdstanpy = pytest.importorskip("cmdstanpy")

with patch.dict(os.environ, {"TQDM_DISABLE": "1"}):
# Store original state
original_method = getattr(
cmdstanpy.model.CmdStanModel, "_wrap_sampler_progress_hook", None
)

try:
patch_tqdm_progress_hook()

# Verify patch was applied
assert hasattr(cmdstanpy.model.CmdStanModel, "_orbit_tqdm_patched")
assert cmdstanpy.model.CmdStanModel._orbit_tqdm_patched is True

finally:
# Clean up
if hasattr(cmdstanpy.model.CmdStanModel, "_orbit_tqdm_patched"):
delattr(cmdstanpy.model.CmdStanModel, "_orbit_tqdm_patched")
if original_method is not None:
cmdstanpy.model.CmdStanModel._wrap_sampler_progress_hook = (
original_method
)
Comment on lines +96 to +103
Copy link

Copilot AI Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test_integration_with_actual_cmdstanpy test has a cleanup issue: the finally block (lines 97–103) unconditionally removes _orbit_tqdm_patched and restores the original method. If patch_tqdm_progress_hook() was already called during the import of stan_estimator.py (because TQDM_DISABLE=1 was set at import time), this cleanup will remove the module-level patch that was applied at import time. The method will be restored while _orbit_tqdm_patched is deleted, leaving the class in a partially inconsistent state for subsequent test runs. The cleanup should check whether the patch was applied by this test's invocation specifically, or only clean up if the test itself applied the patch (e.g., by checking whether the flag was absent before the call).

Copilot uses AI. Check for mistakes.