[ML] Harden pytorch_inference with TorchScript model graph validation #2936
edsavage wants to merge 30 commits into elastic:main from
Conversation
Add a static TorchScript graph validation layer that rejects models containing operations not observed in supported transformer architectures. This reduces the attack surface by ensuring only known-safe operation sets are permitted, complementing the existing Sandbox2/seccomp defenses. New files: - CSupportedOperations: allowlist of 71 ops from 10 reference architectures - CModelGraphValidator: recursive graph walker and validation logic - CModelGraphValidatorTest: 10 unit tests covering pass/fail/edge cases - extract_model_ops.py: developer tool to regenerate the allowlist Relates to elastic/ml-team#1770 Made-with: Cursor
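For illustration, the dual-list check described above can be sketched in pure Python. This is a toy sketch, not the C++ implementation: the op sets below are tiny hypothetical subsets, and `classify_ops` is an invented name.

```python
# Toy sketch of the dual-list validation model from CSupportedOperations /
# CModelGraphValidator; the real lists hold 4 forbidden and ~71 allowed ops.
FORBIDDEN_OPS = {"aten::from_file", "prim::CallMethod"}        # illustrative subset
ALLOWED_OPS = {"aten::add", "aten::matmul", "aten::softmax"}   # illustrative subset

def classify_ops(observed):
    """Check forbidden ops first; if any are present, fail fast and skip
    the unrecognised-op scan entirely."""
    forbidden = sorted(op for op in observed if op in FORBIDDEN_OPS)
    if forbidden:
        return {"valid": False, "forbidden": forbidden, "unrecognised": []}
    unrecognised = sorted(op for op in observed if op not in ALLOWED_OPS)
    return {"valid": not unrecognised, "forbidden": [], "unrecognised": unrecognised}
```

A model whose observed ops all fall in the allowlist validates; anything on the forbidden list produces a targeted rejection, and anything outside both lists is rejected as unrecognised.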
…onfig - Move script to dev-tools/extract_model_ops/ subdirectory - Extract REFERENCE_MODELS dict to reference_models.json config file - Add requirements.txt for virtual environment setup - Add README.md with setup, usage, and configuration instructions - Update CSupportedOperations path references Made-with: Cursor
…ript - Add all 10 elastic/* models from HuggingFace to reference_models.json - Make extract_model_ops.py resilient to individual model load/trace failures (continues to next model instead of crashing) - Add sentencepiece and protobuf to requirements.txt - Add .gitignore for .venv directory - Update CSupportedOperations.cc comment with expanded model list - Op union remains 71 ops (Elastic models use same base architectures) Made-with: Cursor
Remove bart and elastic/multilingual-e5-small which cannot be traced or scripted with the current transformers/torch versions. Made-with: Cursor
Explain why both a short forbidden list and a broad allowed list are maintained: targeted error messages, safety net against accidental allowlist expansion, and defence-in-depth. Made-with: Cursor
Re-ran extraction with torch 2.7.1 (matching the libtorch version linked by ml-cpp) -- op set is identical to the 2.10.0 run. Pin torch version in requirements.txt and fix the comment. Made-with: Cursor
Aids debugging when a legitimate model is unexpectedly rejected after a PyTorch upgrade, and provides an audit trail of what was loaded. Made-with: Cursor
…Method Use torch::jit::Inline() to flatten method calls before collecting operations. This ensures ops hidden behind prim::CallMethod are surfaced for validation. After inlining, prim::CallMethod and prim::CallFunction should not appear; add them to the forbidden list so any unresolvable call is explicitly rejected. Made-with: Cursor
Reject models whose inlined computation graph exceeds 1M nodes. Typical transformer models have O(10k) nodes; the generous limit prevents pathologically crafted models from causing excessive memory or CPU usage during graph traversal. Made-with: Cursor
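The traversal-with-limit idea can be sketched in pure Python. This is a toy model, not TorchScript: nested blocks (as in prim::If / prim::Loop) are represented as `(op_name, [sub_blocks])` tuples, and the function name is illustrative.

```python
# Toy sketch of recursive block traversal with a node-count guard.
# MAX_NODE_COUNT mirrors the validator's 1M limit; the test lowers it
# to demonstrate the fail-fast path.
MAX_NODE_COUNT = 1_000_000

def collect_block_ops(block, ops, count, limit=MAX_NODE_COUNT):
    """Walk one block; bail out as soon as the node budget is exceeded."""
    for op, sub_blocks in block:
        count += 1
        if count > limit:
            return count, False          # fail fast, skip remaining nodes
        ops.add(op)
        for sub in sub_blocks:           # recurse into if/loop bodies
            count, ok = collect_block_ops(sub, ops, count, limit)
            if not ok:
                return count, False
    return count, True
```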
Construct scriptable modules with define() and validate them through the full CModelGraphValidator pipeline. Covers: a valid module with allowed ops, a module with unrecognised ops, node count tracking, and a parent/child module pair that exercises graph inlining. Made-with: Cursor
Adds validate_allowlist.py alongside extract_model_ops.py in dev-tools/extract_model_ops/. The script parses ALLOWED_OPERATIONS and FORBIDDEN_OPERATIONS directly from CSupportedOperations.cc, then traces every model in validation_models.json and checks for false positives. validation_models.json is a superset of reference_models.json that also includes task-specific models (NER, sentiment analysis) matching the bin/pytorch_inference/examples/ test data. A wrapper script (run_validation.sh) automatically creates the Python venv and installs dependencies on first run. A CMake target is registered for convenient invocation: cmake --build <build-dir> -t validate_pytorch_inference_models Made-with: Cursor
Extend the allowlist validation to cover models directly referenced in the Elasticsearch repo and its eland import tool: the packaged multilingual-e5-small, the cross-encoder reranker from the docs, the sentence-transformers embedding model from eland tests, and the DPR question encoder. All 24 models pass validation with no false positives. Made-with: Cursor
Extract the base64-encoded TorchScript models from PyTorchModelIT, TextExpansionQueryIT, and TextEmbeddingQueryIT in the Elasticsearch repo and validate them against our operation allowlist. These toy models use basic ops (aten::ones, aten::rand, aten::hash, prim::Loop, etc.) that weren't in the transformer-derived allowlist, so add them. All are safe tensor/control-flow operations with no I/O capability. The validation script now accepts --pt-dir to validate pre-saved .pt files alongside HuggingFace models. The CMake target passes the new es_it_models directory automatically. Made-with: Cursor
Create six malicious .pt model fixtures that exercise specific attack vectors the CModelGraphValidator must detect: - malicious_file_reader: uses aten::from_file to read arbitrary files - malicious_mixed_file_reader: hides aten::from_file among allowed ops - malicious_hidden_in_submodule: buries unrecognised ops 3 levels deep - malicious_conditional: hides unrecognised ops inside if-branches - malicious_many_unrecognised: uses sin/cos/tan/exp (unknown arch) - malicious_file_reader_in_submodule: forbidden op hidden in child module Each test loads the real .pt file via torch::jit::load and verifies the validator correctly identifies and rejects it. Includes the Python generator script for reproducibility. Made-with: Cursor
Replace the bash wrapper script with cmake/run-validation.cmake that works across all CI platforms (Linux, macOS, Windows). The CMake script searches for python3, python3.12, python3.11, python3.10, python3.9, and python — handling Linux build machines where Python is only available as python3.12 (via make altinstall) and Windows where the canonical name is python. It also prepends the venv's torch/lib directory to the dynamic library search path to avoid conflicts with any system-installed libtorch. Made-with: Cursor
Add the Python allowlist validation as a step in test_all_parallel (used by CI) and precommit (used by developers). Both use OPTIONAL=TRUE so the validation is gracefully skipped with a warning when Python 3 is not available or pip cannot install dependencies (e.g. in Docker containers without network access). The standalone validate_pytorch_inference_models target remains a hard failure for explicit use. Made-with: Cursor
✅ Snyk checks have passed. No issues have been found so far.
…gram Made-with: Cursor
Replace relative "../Foo.h" includes with <Foo.h> by adding the parent source directory to the test target's include path. Also remove unnecessary backslash escapes in extract_model_ops README. Made-with: Cursor
…sion Made-with: Cursor
Deduplicate collect_graph_ops, graph inlining, and HuggingFace model loading/tracing logic shared between extract_model_ops.py and validate_allowlist.py into a common module. Made-with: Cursor
I have some Python tests that reproduce a simplified attack, stored in #2873. Can you take them and run them against your validation code to ensure that those resource-leakage and similar attacks are not possible with your change?
Add HeapLeakModel and RopExploitModel (from PR elastic#2873) to the malicious model fixture generator and create corresponding .pt test fixtures. These reproduce real-world attacks that exploit torch.as_strided to leak heap addresses and build ROP chains. Add two new Boost.Test cases in CModelGraphValidatorTest that load these fixtures and verify the graph validator rejects them due to unrecognised operations (aten::as_strided, aten::item). Made-with: Cursor
Add two standalone Python test scripts alongside the existing fixture generator in bin/pytorch_inference/unittest/testfiles/: - test_graph_validation_evil_models.py: Pure-Python test that mirrors CModelGraphValidator logic (allowlist, forbidden list, recursive block traversal, graph inlining) and validates that the sandbox2 attack models are rejected. Useful for fast iteration during allowlist development without requiring a C++ build. - test_pytorch_inference_evil_models.py: End-to-end integration test that generates evil models, wraps them in the CBufferedIStreamAdapter size-prefixed framing format, and invokes the actual pytorch_inference binary to confirm graph validation rejection. Made-with: Cursor
Move reusable helpers out of test_pytorch_inference_evil_models.py into pytorch_inference_test_utils.py: - script_and_save_model(): TorchScript-compile and save any nn.Module - prepare_restore_file(): wrap a .pt archive with the 4-byte big-endian size header that CBufferedIStreamAdapter expects - find_pytorch_inference(): auto-discover the binary across CMake and Gradle build layouts - run_pytorch_inference(): invoke the binary with correct framing and arguments This makes it straightforward to add new model variants in future test scripts without duplicating the framing and discovery logic. Made-with: Cursor
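The framing format described above can be sketched in a few lines of Python: a 4-byte big-endian length header followed by the .pt payload, matching what the commit message says CBufferedIStreamAdapter expects. Function names here are illustrative, not the actual helpers from the PR.

```python
# Minimal sketch of the size-prefixed framing: 4-byte big-endian length
# header, then the raw model payload.
import struct

def frame_model(payload: bytes) -> bytes:
    """Prepend the 4-byte big-endian size header to a model payload."""
    return struct.pack(">I", len(payload)) + payload

def unframe_model(data: bytes) -> bytes:
    """Inverse: read the header, return exactly that many payload bytes."""
    (size,) = struct.unpack(">I", data[:4])
    return data[4 : 4 + size]
```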
valeriy42
left a comment
Good work @edsavage. I left a few comments.
One thing that I cannot completely see through: we now have two traced models in the es_it_models/ directory and some in bin/pytorch_inference/testfiles/malicious_models. What is the difference?
Also, the python test files you extracted from the sandbox2 PR verify the C++ logic by rebuilding it in Python, right? That may be confusing in 6 months, once we've forgotten all about this PR. Is it possible to extract the evil and benign models as .pt files and then use them in malicious_models like all the other models we test our C++ code against?
//! security boundary self-contained. The list should be regenerated
//! whenever the set of supported architectures changes or when
//! upgrading the PyTorch version.
Do you have a test that would fail if some operations in this list change due to a PyTorch upgrade?
I've added testAllowlistCoversReferenceModels: a C++ test that loads a golden JSON file containing per-architecture op sets extracted from 18 reference HuggingFace models. It asserts every op is in ALLOWED_OPERATIONS and none are in FORBIDDEN_OPERATIONS. The golden file is generated by extract_model_ops.py --golden (a new flag) and should be regenerated whenever PyTorch is upgraded. If a new PyTorch version introduces new ops for any supported model, the C++ test fails until the allowlist is updated.
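The shape of that golden-file check can be sketched in Python (a toy stand-in: the dict below is hypothetical data, not the real reference_model_ops.json, and `check_golden` is an invented name):

```python
# Sketch of the golden-file coverage check: every op traced from every
# reference model must be in the allowlist and absent from the forbidden list.
import json

def check_golden(golden_json, allowed, forbidden):
    """Return (model, op) pairs that violate the allowlist; empty list == pass."""
    violations = []
    for model, ops in json.loads(golden_json).items():
        for op in ops:
            if op in forbidden or op not in allowed:
                violations.append((model, op))
    return violations
```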
// Generated by dev-tools/extract_model_ops/extract_model_ops.py against PyTorch 2.7.1.
// Reference models: bert-base-uncased, roberta-base, distilbert-base-uncased,
// google/electra-small-discriminator, microsoft/mpnet-base,
// microsoft/deberta-base, facebook/dpr-ctx_encoder-single-nq-base,
// google/mobilebert-uncased, xlm-roberta-base, elastic/bge-m3,
// elastic/distilbert-base-{cased,uncased}-finetuned-conll03-english,
// elastic/eis-elser-v2, elastic/elser-v2, elastic/hugging-face-elser,
// elastic/multilingual-e5-small-optimized, elastic/splade-v3,
// elastic/test-elser-v2.
Wow 😲 This is a long list!
Yeah, it's the union across 10+ architectures plus Elastic-specific and ES integration test models. dev-tools/extract_model_ops/extract_model_ops.py can regenerate it when architectures or PyTorch versions change.
TStringSet& ops,
std::size_t& nodeCount) {
    for (const auto* node : block.nodes()) {
        ++nodeCount;
While you are inlining and collecting ops, you need to check against MAX_NODE_COUNT to prevent resource exhaustion.
If it's too large, then it's invalid. I would prefer failing fast and not going through the other validation steps.
Done. collectBlockOps now checks nodeCount > MAX_NODE_COUNT after each increment and returns immediately. collectModuleOps also bails out between methods. The validate() entry point rejects oversized graphs before proceeding to op matching.
for (const auto& op : observedOps) {
    if (forbiddenOps.contains(op)) {
        result.s_IsValid = false;
        result.s_ForbiddenOps.push_back(op);
Can you exit early when result.s_IsValid is false? We are going to drop this model anyway, so there is no reason to go through all 1M operations.
Done. I restructured this as a two-pass check: forbidden ops first, then unrecognised. If any forbidden ops are found we skip the unrecognised-op scan entirely.
bin/pytorch_inference/Main.cc
Outdated
<< " nodes, all operations match supported architectures.");
} catch (const c10::Error& e) {
    LOG_FATAL(<< "Failed to get forward method: " << e.what());
    LOG_FATAL(<< "Model graph validation failed: " << e.what());
IMO, this should be HANDLE_FATAL to prevent unvalidated models from being evaluated.
docs/CHANGELOG.asciidoc
Outdated
=== Enhancements

* Harden pytorch_inference with TorchScript model graph validation. (See {ml-pull}[#2936].)
Suggested change:
- * Harden pytorch_inference with TorchScript model graph validation. (See {ml-pull}[#2936].)
+ * Harden pytorch_inference with TorchScript model graph validation. (See {ml-pull}2936[#2936].)
// --- Sandbox2 attack models (PR #2873) ---
//
// These reproduce real-world attack vectors that exploit torch.as_strided
// to read out-of-bounds heap memory, leak libtorch addresses, and build
// ROP chains that call mprotect + shellcode to write arbitrary files.
// The graph validator must reject them because aten::as_strided (and
// several helper ops like aten::item) are not in the allowlist.
No need to reference another PR. Does it make sense to add aten::as_strided to the forbidden list?
With these tests in place, I wonder if we can remove test_*_evil_models.py as redundant.
Agreed. I've removed all three Python test scripts. The C++ integration tests with .pt fixtures cover the same attack models. generate_malicious_models.py is kept for fixture regeneration.
# compliance with the Elastic License 2.0 and the foregoing additional
# limitation.
#
"""Generate malicious TorchScript model fixtures for validator integration tests.
Can this file be moved to dev-tools/ alongside the other files you created?
- Check MAX_NODE_COUNT during graph traversal to prevent resource exhaustion on pathologically large models (bail out immediately in collectBlockOps and collectModuleOps). - Two-pass validation: check forbidden ops first, skip unrecognised op scan when forbidden ops are found. - Add aten::as_strided to FORBIDDEN_OPERATIONS (key enabler of heap-leak and ROP chain attacks). - Change LOG_FATAL to HANDLE_FATAL in the c10::Error catch block so an exception during validation terminates the process. - Fix CHANGELOG asciidoc link syntax. - Move generate_malicious_models.py to dev-tools/. - Remove redundant Python test scripts now that C++ integration tests cover the same attack models. - Remove PR cross-references from comments per reviewer request. Made-with: Cursor
Add a C++ test (testAllowlistCoversReferenceModels) that loads a
golden JSON file containing per-architecture TorchScript op sets
extracted from 18 reference HuggingFace models and verifies every
op is in ALLOWED_OPERATIONS and none are in FORBIDDEN_OPERATIONS.
This catches allowlist regressions in CI without requiring Python
or network access. When PyTorch is upgraded, regenerate the golden
file with:
python3 extract_model_ops.py --golden \
bin/pytorch_inference/unittest/testfiles/reference_model_ops.json
The --golden flag is a new addition to extract_model_ops.py that
outputs per-model op sets as structured JSON.
Made-with: Cursor
Summary
Implements security hardening for pytorch_inference by validating TorchScript model graphs before execution, addressing elastic/ml-team#1770.

Model Graph Validation (C++)
- CModelGraphValidator: validates TorchScript model graphs by inlining all method calls (torch::jit::Inline) and recursively inspecting every node (including sub-blocks inside prim::If/prim::Loop).
- CSupportedOperations: defines a dual-list security model:
  - FORBIDDEN_OPERATIONS (4 ops): aten::execute_with_args, aten::from_file, prim::CallFunction, prim::CallMethod — rejected immediately with a clear error.
  - ALLOWED_OPERATIONS (82 ops): exhaustive allowlist of safe tensor/control-flow ops derived from tracing reference models against PyTorch 2.7.1.
- Graph size limit (MAX_NODE_COUNT = 1,000,000): guards against resource exhaustion from excessively large graphs.
- Observed operations are logged at DEBUG level during validation.

Operation Allowlist Tooling (Python)
- dev-tools/extract_model_ops/: self-contained tooling directory with:
  - extract_model_ops.py — generates the C++ allowlist from reference HuggingFace models
  - validate_allowlist.py — integration test verifying no false positives against 24 HuggingFace models + 3 Elasticsearch integration test models
  - reference_models.json / validation_models.json — model configurations
  - es_it_models/ — extracted .pt models from Elasticsearch's PyTorchModelIT, TextExpansionQueryIT, TextEmbeddingQueryIT
  - requirements.txt — pinned to torch==2.7.1, matching the libtorch build version

CI Integration
- cmake/run-validation.cmake: portable CMake script that locates Python 3 (searching python3, python3.12, ..., python), manages a virtual environment, handles DYLD_LIBRARY_PATH/LD_LIBRARY_PATH for libtorch conflicts, and runs the validation. Supports OPTIONAL=TRUE for a graceful skip when Python or the network is unavailable.
- Added to test_all_parallel and precommit with OPTIONAL=TRUE — runs automatically when Python is available, skips with a warning otherwise (e.g. in Docker containers without network access).
- The standalone validate_pytorch_inference_models target remains available for explicit verification (hard failure mode).

C++ Tests
- Unit tests (CModelGraphValidatorTest.cc): cover allowed/forbidden/unrecognised ops, graph inlining, node count enforcement, plus integration tests using torch::jit::Module::define().
- Malicious .pt fixtures test detection of aten::from_file, hidden ops in submodules, conditional branches, and mixed scenarios.

Test plan
- Full C++ test suite passes (cmake --build ... -t test)
- CModelGraphValidator tests pass (50 pytorch_inference test cases)
- Malicious model fixtures (.pt) are rejected
- OPTIONAL=TRUE gracefully skips when Python is unavailable

Made with Cursor