monai/inferers/utils.py (14 additions, 8 deletions)

@@ -243,14 +243,19 @@ def sliding_window_inference(
             for idx in slice_range
         ]
         if sw_batch_size > 1:
-            win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
+            win_data = torch.cat([inputs[ensure_tuple(win_slice)] for win_slice in unravel_slice]).to(sw_device)
             if condition is not None:
-                win_condition = torch.cat([condition[win_slice] for win_slice in unravel_slice]).to(sw_device)
+                win_condition = torch.cat([condition[ensure_tuple(win_slice)] for win_slice in unravel_slice]).to(
+                    sw_device
+                )
                 kwargs["condition"] = win_condition
         else:
-            win_data = inputs[unravel_slice[0]].to(sw_device)
+            s0 = unravel_slice[0]
+            s0_idx = ensure_tuple(s0)
+
+            win_data = inputs[s0_idx].to(sw_device)
             if condition is not None:
-                win_condition = condition[unravel_slice[0]].to(sw_device)
+                win_condition = condition[s0_idx].to(sw_device)
                 kwargs["condition"] = win_condition
 
         if with_coord:
@@ -277,7 +282,7 @@ def sliding_window_inference(
                 offset = s[buffer_dim + 2].start - c_start
                 s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim])
                 s[0] = slice(0, 1)
-                sw_device_buffer[0][s] += p * w_t
+                sw_device_buffer[0][ensure_tuple(s)] += p * w_t
             b_i += len(unravel_slice)
             if b_i < b_slices[b_s][0]:
                 continue
@@ -308,10 +313,11 @@ def sliding_window_inference(
                 o_slice[buffer_dim + 2] = slice(c_start, c_end)
                 img_b = b_s // n_per_batch  # image batch index
                 o_slice[0] = slice(img_b, img_b + 1)
+                o_slice_idx = ensure_tuple(o_slice)
                 if non_blocking:
-                    output_image_list[0][o_slice].copy_(sw_device_buffer[0], non_blocking=non_blocking)
+                    output_image_list[0][o_slice_idx].copy_(sw_device_buffer[0], non_blocking=non_blocking)
                 else:
-                    output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device)
+                    output_image_list[0][o_slice_idx] += sw_device_buffer[0].to(device=device)
             else:
                 sw_device_buffer[ss] *= w_t
                 sw_device_buffer[ss] = sw_device_buffer[ss].to(device)
@@ -387,7 +393,7 @@ def _compute_coords(coords, z_scale, out, patch):
                 idx_zm[axis] = slice(
                     int(original_idx[axis].start * z_scale[axis - 2]), int(original_idx[axis].stop * z_scale[axis - 2])
                 )
-        out[idx_zm] += p
+        out[ensure_tuple(idx_zm)] += p
 
 
 def _get_scan_interval(
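Note on the change above: NumPy, and PyTorch following it, define multidimensional indexing in terms of tuples; indexing with a list of slices goes through a long-deprecated compatibility path that recent releases warn on or reject. Since the `unravel_slice` entries are built as lists, they are wrapped with `ensure_tuple` before use. A minimal sketch of the distinction, using plain `tuple()` where the diff uses MONAI's `ensure_tuple` (assumed to be already imported in this module; for a list input the two produce the same tuple):

import torch

t = torch.arange(24.0).reshape(1, 2, 3, 4)
# an index container built as a list, like the `unravel_slice` entries above
idx = [slice(0, 1), slice(None), slice(1, 3), slice(0, 2)]

# tuple indexing is the supported spelling, equivalent to t[0:1, :, 1:3, 0:2]
window = t[tuple(idx)]
print(window.shape)  # torch.Size([1, 2, 2, 2])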
tests/inferers/test_sliding_window_inference.py (58 additions, 0 deletions)

@@ -20,6 +20,7 @@
 
 from monai.data.utils import list_data_collate
 from monai.inferers import SlidingWindowInferer, SlidingWindowInfererAdapt, sliding_window_inference
+from monai.inferers.utils import _compute_coords
 from monai.utils import optional_import
 from tests.test_utils import TEST_TORCH_AND_META_TENSORS, skip_if_no_cuda, test_is_quick
 
@@ -704,6 +705,63 @@ def compute_dict(data, condition):
             for rr, _ in zip(result_dict, expected_dict):
                 np.testing.assert_allclose(result_dict[rr].cpu().numpy(), expected_dict[rr], rtol=1e-4)
 
+    @parameterized.expand([(1,), (4,)])
+    def test_conditioned_branches_and_buffered_parity(self, sw_batch_size):
+        """Check that conditioned inference gives identical results with and
+        without buffering, for both the batched and single-window code paths.
+
+        Args:
+            sw_batch_size (int): Sliding-window batch size.
+        """
+        inputs = torch.arange(1 * 1 * 10 * 8, dtype=torch.float).reshape(1, 1, 10, 8)
+        condition = inputs + 100.0
+        roi_shape = (4, 4)
+
+        def compute(data, condition):
+            """Return ``data + condition`` after checking device placement and alignment."""
+            self.assertEqual(data.device.type, "cpu")
+            self.assertEqual(condition.device.type, "cpu")
+            torch.testing.assert_close(condition - data, torch.full_like(data, 100.0))
+            return data + condition
+
+        # Non-buffered flow.
+        result_non_buffered = sliding_window_inference(
+            inputs, roi_shape, sw_batch_size, compute, overlap=0.5, mode="constant", condition=condition
+        )
+        # Buffered flow; should match the non-buffered output.
+        result_buffered = sliding_window_inference(
+            inputs,
+            roi_shape,
+            sw_batch_size,
+            compute,
+            overlap=0.5,
+            mode="constant",
+            condition=condition,
+            buffer_steps=2,
+            buffer_dim=0,
+        )
+
+        expected = inputs + condition
+        torch.testing.assert_close(result_non_buffered, expected)
+        torch.testing.assert_close(result_buffered, expected)
+        torch.testing.assert_close(result_buffered, result_non_buffered)
+
+
+class TestSlidingWindowUtils(unittest.TestCase):
+    """Tests for low-level sliding-window utility helpers."""
+
+    def test_compute_coords_accepts_list_indices(self):
+        """``_compute_coords`` should accept list-based indices and scale them by ``z_scale``."""
+        out = torch.zeros((1, 1, 12, 12), dtype=torch.float)
+        patch = torch.arange(16, dtype=torch.float).reshape(1, 1, 4, 4)
+        coords = [[slice(0, 1), slice(None), slice(1, 3), slice(2, 4)]]
+
+        _compute_coords(coords=coords, z_scale=[2.0, 2.0], out=out, patch=patch)
+
+        expected = torch.zeros_like(out)
+        expected[0, 0, 2:6, 4:8] = patch[0, 0]
+        torch.testing.assert_close(out, expected)
+
 
 if __name__ == "__main__":
     unittest.main()
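The new tests pin the conditioned behavior at the public API level. A usage-level sketch of the call pattern they exercise (the shapes and the trivial predictor are illustrative, not part of the change; `condition` must match the input's shape so it can be sliced by the same windows):

import torch
from monai.inferers import sliding_window_inference

def predictor(x, condition):
    # the inferer windows `condition` exactly like the input and forwards
    # the aligned patches through the `condition` keyword
    return x + condition

image = torch.rand(1, 1, 64, 64)
cond = torch.rand(1, 1, 64, 64)

out = sliding_window_inference(image, (32, 32), 4, predictor, condition=cond)
out_buf = sliding_window_inference(image, (32, 32), 4, predictor, condition=cond, buffer_steps=2, buffer_dim=0)
torch.testing.assert_close(out, out_buf)  # buffered and non-buffered paths agree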