
[ML] Harden pytorch_inference with TorchScript model graph validation#2936

Open
edsavage wants to merge 30 commits into elastic:main from edsavage:feature/harden_pytorch_inference

Conversation

Contributor

@edsavage edsavage commented Mar 2, 2026

Summary

Implements security hardening for pytorch_inference by validating TorchScript model graphs before execution, addressing elastic/ml-team#1770.

Model Graph Validation (C++)

  • CModelGraphValidator: Validates TorchScript model graphs by inlining all method calls (torch::jit::Inline) and recursively inspecting every node (including sub-blocks inside prim::If/prim::Loop).
  • CSupportedOperations: Defines a dual-list security model:
    • FORBIDDEN_OPERATIONS (4 ops): aten::execute_with_args, aten::from_file, prim::CallFunction, prim::CallMethod — rejected immediately with a clear error.
    • ALLOWED_OPERATIONS (82 ops): Exhaustive allowlist of safe tensor/control-flow ops derived from tracing reference models against PyTorch 2.7.1.
  • Maximum node count (MAX_NODE_COUNT = 1,000,000): Guards against resource exhaustion from excessively large graphs.
  • Debug logging: Observed ops are logged at DEBUG level during validation.
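The dual-list check and recursive block traversal described above can be sketched in plain Python — the same shape of logic the C++ validator applies to TorchScript IR. The `Node` type, the small op sets, and the example graph below are illustrative stand-ins, not the real implementation:

```python
from dataclasses import dataclass, field

# Hypothetical op names and a toy Node type stand in for TorchScript IR;
# the real validator walks torch::jit graphs in C++.
FORBIDDEN_OPERATIONS = {"aten::from_file", "aten::execute_with_args",
                        "prim::CallFunction", "prim::CallMethod"}
ALLOWED_OPERATIONS = {"aten::matmul", "aten::softmax", "prim::If", "prim::Loop"}
MAX_NODE_COUNT = 1_000_000

@dataclass
class Node:
    kind: str
    blocks: list = field(default_factory=list)  # sub-blocks, e.g. prim::If branches

def collect_ops(nodes, ops, count=0):
    # Recursively collect op kinds, bailing out once the node budget is spent.
    for node in nodes:
        count += 1
        if count > MAX_NODE_COUNT:
            return ops, count  # fail fast; the caller rejects the graph
        ops.add(node.kind)
        for block in node.blocks:
            ops, count = collect_ops(block, ops, count)
    return ops, count

def validate(nodes):
    ops, count = collect_ops(nodes, set())
    if count > MAX_NODE_COUNT:
        return False, ["graph exceeds MAX_NODE_COUNT"]
    forbidden = sorted(ops & FORBIDDEN_OPERATIONS)
    if forbidden:  # forbidden ops short-circuit with a targeted error
        return False, forbidden
    unrecognised = sorted(ops - ALLOWED_OPERATIONS)
    return not unrecognised, unrecognised

# An op hidden inside a prim::If branch is still surfaced by the recursion:
graph = [Node("aten::matmul"),
         Node("prim::If", blocks=[[Node("aten::from_file")],
                                  [Node("aten::softmax")]])]
```

Note how the forbidden-op check runs before the allowlist scan, mirroring the "rejected immediately" behaviour described above.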

Operation Allowlist Tooling (Python)

  • dev-tools/extract_model_ops/: Self-contained tooling directory with:
    • extract_model_ops.py — generates the C++ allowlist from reference HuggingFace models
    • validate_allowlist.py — integration test verifying no false positives against 24 HuggingFace models + 3 Elasticsearch integration test models
    • reference_models.json / validation_models.json — model configurations
    • es_it_models/ — extracted .pt models from Elasticsearch's PyTorchModelIT, TextExpansionQueryIT, TextEmbeddingQueryIT
    • requirements.txt — pinned to torch==2.7.1 matching the libtorch build version
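validate_allowlist.py is described later in the review as parsing the op sets directly out of CSupportedOperations.cc. A minimal sketch of that parsing step, assuming a simple brace-initialised string-set format (the actual C++ layout may differ):

```python
import re

# Hypothetical fragment of CSupportedOperations.cc; the real initializer
# format may differ, so this regex is only a sketch.
CPP_SOURCE = '''
const TStringSet ALLOWED_OPERATIONS{
    "aten::add", "aten::matmul",
    "prim::Constant", "prim::If"};
const TStringSet FORBIDDEN_OPERATIONS{
    "aten::from_file", "prim::CallMethod"};
'''

def parse_op_set(source, name):
    # Pull the quoted op names out of one brace-initialised set.
    body = re.search(name + r"\{(.*?)\}", source, re.S).group(1)
    return set(re.findall(r'"([^"]+)"', body))

allowed = parse_op_set(CPP_SOURCE, "ALLOWED_OPERATIONS")
forbidden = parse_op_set(CPP_SOURCE, "FORBIDDEN_OPERATIONS")
```

Parsing the C++ source directly keeps the Python tooling and the compiled security boundary from drifting apart.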

CI Integration

  • cmake/run-validation.cmake: Portable CMake script that locates Python 3 (searching python3, python3.12, ..., python), manages a virtual environment, handles DYLD_LIBRARY_PATH/LD_LIBRARY_PATH for libtorch conflicts, and runs the validation. Supports OPTIONAL=TRUE for graceful skip when Python or network is unavailable.
  • Wired into test_all_parallel and precommit with OPTIONAL=TRUE — runs automatically when Python is available, skips with a warning otherwise (e.g. in Docker containers without network).
  • Standalone target validate_pytorch_inference_models available for explicit verification (hard failure mode).

C++ Tests

  • Unit tests (CModelGraphValidatorTest.cc): Tests for allowed/forbidden/unrecognised ops, graph inlining, node count enforcement, and integration tests using torch::jit::Module::define().
  • Malicious model tests: 6 generated .pt fixtures testing detection of aten::from_file, hidden ops in submodules, conditional branches, and mixed scenarios.

Test plan

  • All existing C++ unit tests pass (cmake --build ... -t test)
  • New CModelGraphValidator tests pass (50 pytorch_inference test cases)
  • Malicious model fixtures correctly rejected
  • Python validation passes for all 27 models (24 HuggingFace + 3 ES .pt)
  • OPTIONAL=TRUE gracefully skips when Python unavailable
  • CI passes on all platforms (Linux x86_64, Linux aarch64, macOS aarch64, Windows x86_64)
  • CI passes in both RelWithDebInfo and Debug configurations

Made with Cursor

edsavage added 17 commits March 2, 2026 10:54
Add a static TorchScript graph validation layer that rejects models
containing operations not observed in supported transformer architectures.
This reduces the attack surface by ensuring only known-safe operation
sets are permitted, complementing the existing Sandbox2/seccomp defenses.

New files:
- CSupportedOperations: allowlist of 71 ops from 10 reference architectures
- CModelGraphValidator: recursive graph walker and validation logic
- CModelGraphValidatorTest: 10 unit tests covering pass/fail/edge cases
- extract_model_ops.py: developer tool to regenerate the allowlist

Relates to elastic/ml-team#1770

Made-with: Cursor
…onfig

- Move script to dev-tools/extract_model_ops/ subdirectory
- Extract REFERENCE_MODELS dict to reference_models.json config file
- Add requirements.txt for virtual environment setup
- Add README.md with setup, usage, and configuration instructions
- Update CSupportedOperations path references

Made-with: Cursor
…ript

- Add all 10 elastic/* models from HuggingFace to reference_models.json
- Make extract_model_ops.py resilient to individual model load/trace
  failures (continues to next model instead of crashing)
- Add sentencepiece and protobuf to requirements.txt
- Add .gitignore for .venv directory
- Update CSupportedOperations.cc comment with expanded model list
- Op union remains 71 ops (Elastic models use same base architectures)

Made-with: Cursor
Remove bart and elastic/multilingual-e5-small which cannot be
traced or scripted with the current transformers/torch versions.

Made-with: Cursor
Explain why both a short forbidden list and a broad allowed list are
maintained: targeted error messages, safety net against accidental
allowlist expansion, and defence-in-depth.

Made-with: Cursor
Re-ran extraction with torch 2.7.1 (matching the libtorch version
linked by ml-cpp) -- op set is identical to the 2.10.0 run.  Pin
torch version in requirements.txt and fix the comment.

Made-with: Cursor
Aids debugging when a legitimate model is unexpectedly rejected after
a PyTorch upgrade, and provides an audit trail of what was loaded.

Made-with: Cursor
…Method

Use torch::jit::Inline() to flatten method calls before collecting
operations.  This ensures ops hidden behind prim::CallMethod are
surfaced for validation.  After inlining, prim::CallMethod and
prim::CallFunction should not appear; add them to the forbidden
list so any unresolvable call is explicitly rejected.

Made-with: Cursor
Reject models whose inlined computation graph exceeds 1M nodes.
Typical transformer models have O(10k) nodes; the generous limit
prevents pathologically crafted models from causing excessive memory
or CPU usage during graph traversal.

Made-with: Cursor
Construct scriptable modules with define() and validate them through
the full CModelGraphValidator pipeline.  Covers: a valid module with
allowed ops, a module with unrecognised ops, node count tracking, and
a parent/child module pair that exercises graph inlining.

Made-with: Cursor
Adds validate_allowlist.py alongside extract_model_ops.py in
dev-tools/extract_model_ops/.  The script parses ALLOWED_OPERATIONS
and FORBIDDEN_OPERATIONS directly from CSupportedOperations.cc, then
traces every model in validation_models.json and checks for false
positives.

validation_models.json is a superset of reference_models.json that
also includes task-specific models (NER, sentiment analysis) matching
the bin/pytorch_inference/examples/ test data.

A wrapper script (run_validation.sh) automatically creates the Python
venv and installs dependencies on first run.  A CMake target is
registered for convenient invocation:
  cmake --build <build-dir> -t validate_pytorch_inference_models

Made-with: Cursor
Extend the allowlist validation to cover models directly referenced in
the Elasticsearch repo and its eland import tool: the packaged
multilingual-e5-small, the cross-encoder reranker from the docs, the
sentence-transformers embedding model from eland tests, and the DPR
question encoder. All 24 models pass validation with no false positives.

Made-with: Cursor
Extract the base64-encoded TorchScript models from PyTorchModelIT,
TextExpansionQueryIT, and TextEmbeddingQueryIT in the Elasticsearch
repo and validate them against our operation allowlist. These toy
models use basic ops (aten::ones, aten::rand, aten::hash, prim::Loop,
etc.) that weren't in the transformer-derived allowlist, so add them.
All are safe tensor/control-flow operations with no I/O capability.

The validation script now accepts --pt-dir to validate pre-saved .pt
files alongside HuggingFace models. The CMake target passes the new
es_it_models directory automatically.

Made-with: Cursor
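The extraction step above amounts to base64-decoding the embedded model strings back into .pt archives. A minimal sketch with a fabricated payload (TorchScript .pt files are ordinary zip archives, hence the PK magic bytes):

```python
import base64

# Fabricated stand-in for a base64-encoded model string embedded in an
# Elasticsearch integration test; real payloads are full TorchScript archives.
RAW_MODEL_B64 = base64.b64encode(b"PK\x03\x04fake-torchscript-zip").decode()

def extract_model(b64_text: str) -> bytes:
    # Decode an embedded base64 model string back to raw .pt bytes.
    return base64.b64decode(b64_text)

model_bytes = extract_model(RAW_MODEL_B64)
```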
Create six malicious .pt model fixtures that exercise specific attack
vectors the CModelGraphValidator must detect:

- malicious_file_reader: uses aten::from_file to read arbitrary files
- malicious_mixed_file_reader: hides aten::from_file among allowed ops
- malicious_hidden_in_submodule: buries unrecognised ops 3 levels deep
- malicious_conditional: hides unrecognised ops inside if-branches
- malicious_many_unrecognised: uses sin/cos/tan/exp (unknown arch)
- malicious_file_reader_in_submodule: forbidden op hidden in child module

Each test loads the real .pt file via torch::jit::load and verifies the
validator correctly identifies and rejects it. Includes the Python
generator script for reproducibility.

Made-with: Cursor
Replace the bash wrapper script with cmake/run-validation.cmake that
works across all CI platforms (Linux, macOS, Windows). The CMake script
searches for python3, python3.12, python3.11, python3.10, python3.9,
and python — handling Linux build machines where Python is only
available as python3.12 (via make altinstall) and Windows where the
canonical name is python. It also prepends the venv's torch/lib
directory to the dynamic library search path to avoid conflicts with
any system-installed libtorch.

Made-with: Cursor
Add the Python allowlist validation as a step in test_all_parallel
(used by CI) and precommit (used by developers). Both use OPTIONAL=TRUE
so the validation is gracefully skipped with a warning when Python 3 is
not available or pip cannot install dependencies (e.g. in Docker
containers without network access). The standalone
validate_pytorch_inference_models target remains a hard failure for
explicit use.

Made-with: Cursor

prodsecmachine commented Mar 2, 2026

Snyk checks have passed. No issues have been found so far.

Scanner                 Critical  High  Medium  Low  Total
Open Source Security    0         0     0       0    0 issues
Licenses                0         0     0       0    0 issues


edsavage added 6 commits March 3, 2026 10:37
Replace relative "../Foo.h" includes with <Foo.h> by adding the parent
source directory to the test target's include path. Also remove
unnecessary backslash escapes in extract_model_ops README.

Made-with: Cursor
Deduplicate collect_graph_ops, graph inlining, and HuggingFace model
loading/tracing logic shared between extract_model_ops.py and
validate_allowlist.py into a common module.

Made-with: Cursor
@edsavage edsavage marked this pull request as ready for review March 9, 2026 20:40
@edsavage edsavage requested a review from valeriy42 March 9, 2026 20:40
@valeriy42
Contributor

I have some Python tests that reproduce a simplified attack, stored in #2873. Can you take them and run them against your validation code to ensure that those resource-leakage and similar attacks are not possible with your change?

Add HeapLeakModel and RopExploitModel (from PR elastic#2873) to the malicious
model fixture generator and create corresponding .pt test fixtures.
These reproduce real-world attacks that exploit torch.as_strided to
leak heap addresses and build ROP chains.

Add two new Boost.Test cases in CModelGraphValidatorTest that load
these fixtures and verify the graph validator rejects them due to
unrecognised operations (aten::as_strided, aten::item).

Made-with: Cursor
Add two standalone Python test scripts alongside the existing fixture
generator in bin/pytorch_inference/unittest/testfiles/:

- test_graph_validation_evil_models.py: Pure-Python test that mirrors
  CModelGraphValidator logic (allowlist, forbidden list, recursive block
  traversal, graph inlining) and validates that the sandbox2 attack
  models are rejected.  Useful for fast iteration during allowlist
  development without requiring a C++ build.

- test_pytorch_inference_evil_models.py: End-to-end integration test
  that generates evil models, wraps them in the CBufferedIStreamAdapter
  size-prefixed framing format, and invokes the actual pytorch_inference
  binary to confirm graph validation rejection.

Made-with: Cursor
Move reusable helpers out of test_pytorch_inference_evil_models.py into
pytorch_inference_test_utils.py:

- script_and_save_model(): TorchScript-compile and save any nn.Module
- prepare_restore_file(): wrap a .pt archive with the 4-byte big-endian
  size header that CBufferedIStreamAdapter expects
- find_pytorch_inference(): auto-discover the binary across CMake and
  Gradle build layouts
- run_pytorch_inference(): invoke the binary with correct framing and
  arguments

This makes it straightforward to add new model variants in future test
scripts without duplicating the framing and discovery logic.

Made-with: Cursor
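The framing helper can be sketched in a few lines. This assumes only what the commit states — a 4-byte big-endian size header ahead of the raw .pt bytes — and the payload here is fabricated:

```python
import struct

def prepare_restore_file(model_bytes: bytes) -> bytes:
    # Prefix the serialised model with the 4-byte big-endian length header
    # that CBufferedIStreamAdapter is described as expecting.
    return struct.pack(">I", len(model_bytes)) + model_bytes

framed = prepare_restore_file(b"\x80\x02example")  # 9 payload bytes
```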
Contributor

@valeriy42 valeriy42 left a comment

Good work @edsavage . I left a few comments.

One thing that I cannot completely see through: we now have traced models in the es_it_models/ directory and others in bin/pytorch_inference/testfiles/malicious_models. What is the difference?

Also, the Python test files you extracted from the sandbox2 PR verify the C++ logic by rebuilding it in Python, right? It may be confusing after 6 months, once we have forgotten all about this PR. Is it possible to extract the evil and benign models as .pt files and then use them in malicious_models like all the other models we test our C++ code against?

Comment on lines +34 to +36
//! security boundary self-contained. The list should be regenerated
//! whenever the set of supported architectures changes or when
//! upgrading the PyTorch version.
Contributor

Do you have a test that would fail if some operations in this list change due to a PyTorch upgrade?

Contributor Author

I've added testAllowlistCoversReferenceModels, a C++ test that loads a golden JSON file containing per-architecture op sets extracted from 18 reference HuggingFace models. It asserts that every op is in ALLOWED_OPERATIONS and that none are in FORBIDDEN_OPERATIONS. The golden file is generated by extract_model_ops.py --golden (a new flag) and should be regenerated whenever PyTorch is upgraded. If a new PyTorch version introduces new ops for any supported model, the C++ test fails until the allowlist is updated.
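The shape of that golden-file check can be sketched in Python. The JSON layout and op names here are hypothetical stand-ins for reference_model_ops.json:

```python
import json

# Toy stand-in for the golden file: per-architecture op sets keyed by model.
GOLDEN_JSON = '{"bert": ["aten::matmul", "aten::softmax"], "roberta": ["aten::matmul"]}'
ALLOWED_OPERATIONS = {"aten::matmul", "aten::softmax", "prim::Constant"}
FORBIDDEN_OPERATIONS = {"aten::from_file"}

golden = json.loads(GOLDEN_JSON)
# Any op outside the allowlist (or inside the forbidden list) is a failure
# that should block CI until the allowlist is regenerated.
failures = [(arch, op)
            for arch, ops in golden.items()
            for op in ops
            if op not in ALLOWED_OPERATIONS or op in FORBIDDEN_OPERATIONS]
```

An empty `failures` list means the allowlist still covers every reference architecture.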

Comment on lines +28 to +36
// Generated by dev-tools/extract_model_ops/extract_model_ops.py against PyTorch 2.7.1.
// Reference models: bert-base-uncased, roberta-base, distilbert-base-uncased,
// google/electra-small-discriminator, microsoft/mpnet-base,
// microsoft/deberta-base, facebook/dpr-ctx_encoder-single-nq-base,
// google/mobilebert-uncased, xlm-roberta-base, elastic/bge-m3,
// elastic/distilbert-base-{cased,uncased}-finetuned-conll03-english,
// elastic/eis-elser-v2, elastic/elser-v2, elastic/hugging-face-elser,
// elastic/multilingual-e5-small-optimized, elastic/splade-v3,
// elastic/test-elser-v2.
Contributor

Wow 😲 This is a long list!

Contributor Author

Yeah, it's the union across 10+ architectures plus Elastic-specific and ES integration test models. dev-tools/extract_model_ops/extract_model_ops.py can regenerate it when architectures or PyTorch versions change.

TStringSet& ops,
std::size_t& nodeCount) {
for (const auto* node : block.nodes()) {
++nodeCount;
Contributor

While you are inlining and collecting ops, you need to check against MAX_NODE_COUNT to prevent resource exhaustion.

Contributor

If it's too large, then it's invalid. I would prefer failing fast and not going through other validation steps.

Contributor Author

Done. collectBlockOps now checks nodeCount > MAX_NODE_COUNT after each increment and returns immediately. collectModuleOps also bails out between methods. The validate() entry point rejects oversized graphs before proceeding to op matching.

for (const auto& op : observedOps) {
if (forbiddenOps.contains(op)) {
result.s_IsValid = false;
result.s_ForbiddenOps.push_back(op);
Contributor

Can you exit early when result.s_IsValid is false? We are going to drop this model anyway, so there is no reason to go through all 1M operations.

Contributor Author

Done. I restructured this as a two-pass check: forbidden ops first, then unrecognised. If any forbidden ops are found we skip the unrecognised-op scan entirely.

<< " nodes, all operations match supported architectures.");
} catch (const c10::Error& e) {
LOG_FATAL(<< "Failed to get forward method: " << e.what());
LOG_FATAL(<< "Model graph validation failed: " << e.what());
Contributor

IMO, this should be HANDLE_FATAL to prevent unvalidated models from being evaluated.


=== Enhancements

* Harden pytorch_inference with TorchScript model graph validation. (See {ml-pull}[#2936].)
Contributor

Suggested change
* Harden pytorch_inference with TorchScript model graph validation. (See {ml-pull}[#2936].)
* Harden pytorch_inference with TorchScript model graph validation. (See {ml-pull}2936[#2936].)

Comment on lines +404 to +410
// --- Sandbox2 attack models (PR #2873) ---
//
// These reproduce real-world attack vectors that exploit torch.as_strided
// to read out-of-bounds heap memory, leak libtorch addresses, and build
// ROP chains that call mprotect + shellcode to write arbitrary files.
// The graph validator must reject them because aten::as_strided (and
// several helper ops like aten::item) are not in the allowlist.
Contributor

No need to reference another PR. Does it make sense to add aten::as_strided to the forbidden list?

Contributor

With these tests in place, I wonder if we can remove test_*_evil_models.py as redundant.

Contributor Author

Agreed. I've removed all three Python test scripts. The C++ integration tests with .pt fixtures cover the same attack models. generate_malicious_models.py is kept for fixture regeneration.

# compliance with the Elastic License 2.0 and the foregoing additional
# limitation.
#
"""Generate malicious TorchScript model fixtures for validator integration tests.
Contributor

Can this file be moved to dev-tools/ along other files you created?

- Check MAX_NODE_COUNT during graph traversal to prevent resource
  exhaustion on pathologically large models (bail out immediately
  in collectBlockOps and collectModuleOps).
- Two-pass validation: check forbidden ops first, skip unrecognised
  op scan when forbidden ops are found.
- Add aten::as_strided to FORBIDDEN_OPERATIONS (key enabler of
  heap-leak and ROP chain attacks).
- Change LOG_FATAL to HANDLE_FATAL in the c10::Error catch block
  so an exception during validation terminates the process.
- Fix CHANGELOG asciidoc link syntax.
- Move generate_malicious_models.py to dev-tools/.
- Remove redundant Python test scripts now that C++ integration
  tests cover the same attack models.
- Remove PR cross-references from comments per reviewer request.

Made-with: Cursor
Add a C++ test (testAllowlistCoversReferenceModels) that loads a
golden JSON file containing per-architecture TorchScript op sets
extracted from 18 reference HuggingFace models and verifies every
op is in ALLOWED_OPERATIONS and none are in FORBIDDEN_OPERATIONS.

This catches allowlist regressions in CI without requiring Python
or network access.  When PyTorch is upgraded, regenerate the golden
file with:

  python3 extract_model_ops.py --golden \
    bin/pytorch_inference/unittest/testfiles/reference_model_ops.json

The --golden flag is a new addition to extract_model_ops.py that
outputs per-model op sets as structured JSON.

Made-with: Cursor
@edsavage edsavage requested a review from valeriy42 March 12, 2026 03:26
3 participants