perf: avoid graph break for SiLUT when inferring #4790
Conversation
Pull Request Overview
This PR streamlines the forward method in ActivationFn by replacing the conditional branch and F.silu call with a fully vectorized computation using torch.sigmoid, torch.tanh, and torch.where, improving inference efficiency.
- Unifies SiLU/Tanh computation into a single `torch.where` call (sketched below)
- Eliminates a graph break by removing `if torch.any(mask)`
- Benchmarked speedup of 6–7% on LAMBench inference tasks
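To make the change concrete, here is a minimal sketch of the two patterns. The vectorized body matches the lines quoted in the reviews below; the branching version is a hypothetical reconstruction built around the `if torch.any(mask)` check the PR removes, since the exact original body is not shown on this page.

```python
import torch
import torch.nn.functional as F

# Before (hypothetical reconstruction): torch.any(mask) must be resolved to a
# host-side Python bool, which forces a synchronization and splits the
# compiled graph at the branch.
def silut_branching(x: torch.Tensor, threshold: float, slope: float, const: float):
    mask = x >= threshold
    out = F.silu(x)
    if torch.any(mask):
        out = torch.where(mask, torch.tanh(slope * (x - threshold)) + const, out)
    return out

# After: both branches are always computed and merged elementwise with
# torch.where, so the whole activation stays inside one graph.
def silut_vectorized(x: torch.Tensor, threshold: float, slope: float, const: float):
    sig = torch.sigmoid(x)
    silu = x * sig  # SiLU(x) = x * sigmoid(x)
    tanh = torch.tanh(slope * (x - threshold)) + const
    return torch.where(x >= threshold, tanh, silu)
```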
Comments suppressed due to low confidence (3)
deepmd/pt/utils/utils.py:152
- [nitpick] The variable name `sig` is ambiguous and shadows common abbreviations; consider using `sigmoid` or `sig_x` for clarity.
  `sig = torch.sigmoid(x)`
deepmd/pt/utils/utils.py:154
- [nitpick] The variable `tanh` shadows the `torch.tanh` function; renaming to `tanh_part` or `tanh_x` would improve readability.
  `tanh = torch.tanh(self.slope * (x - self.threshold)) + self.const`
deepmd/pt/utils/utils.py:155
- Add a unit test covering cases where all inputs are below the threshold and where all are above, to ensure the vectorized path matches the original branching behavior exactly (a sketch of such a test follows below).
  `return torch.where(x >= self.threshold, tanh, silu)`
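A minimal sketch of such a test, assuming `SiLUT` is importable from `deepmd.pt.utils.utils`, takes `threshold` in its constructor, and exposes `threshold`, `slope`, and `const` attributes (all inferred from the quoted lines, not verified against the repository):

```python
import torch
from deepmd.pt.utils.utils import SiLUT  # assumed import path

def reference_silut(x, threshold, slope, const):
    # Elementwise rule both implementations should agree on.
    silu = x * torch.sigmoid(x)
    tanh = torch.tanh(slope * (x - threshold)) + const
    return torch.where(x >= threshold, tanh, silu)

def test_all_below_and_all_above_threshold():
    act = SiLUT(threshold=3.0)  # assumed constructor signature
    below = torch.linspace(-5.0, 2.9, 101)  # every element < threshold
    above = torch.linspace(3.0, 10.0, 101)  # every element >= threshold
    for x in (below, above):
        expected = reference_silut(x, act.threshold, act.slope, act.const)
        torch.testing.assert_close(act(x), expected)
```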
Actionable comments posted: 0
🧹 Nitpick comments (1)
deepmd/pt/utils/utils.py (1)
151-155: Consider eliminating code duplication with the `silut_forward` function.

The `SiLUT.forward` method now has identical logic to the `silut_forward` function (lines 24-30). Consider refactoring to call `silut_forward` directly to reduce duplication:

```diff
 def forward(self, x: torch.Tensor) -> torch.Tensor:
-    sig = torch.sigmoid(x)
-    silu = x * sig
-    tanh = torch.tanh(self.slope * (x - self.threshold)) + self.const
-    return torch.where(x >= self.threshold, tanh, silu)
+    return silut_forward(x, self.threshold, self.slope, self.const)
```

However, this might introduce minimal function call overhead, so the current approach may be preferred for performance-critical code.
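For context, the module-level `silut_forward` referenced here presumably mirrors the same four lines; a sketch inferred from the review's claim of identical logic and from the call signature in the suggested diff (the actual definition at lines 24-30 is not quoted on this page):

```python
def silut_forward(
    x: torch.Tensor, threshold: float, slope: float, const: float
) -> torch.Tensor:
    # Same vectorized SiLUT rule as SiLUT.forward: SiLU below the
    # threshold, shifted/scaled tanh at or above it.
    sig = torch.sigmoid(x)
    silu = x * sig
    tanh = torch.tanh(slope * (x - threshold)) + const
    return torch.where(x >= threshold, tanh, silu)
```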
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/pt/utils/utils.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
deepmd/pt/utils/utils.py (1)
deepmd/dpmodel/utils/network.py (2)
sigmoid (355-356), silu (358-359)
⏰ Context from checks skipped due to timeout of 90000ms (29)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test C++ (false)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Analyze (python)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Test C++ (true)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build C++ (cpu, cpu)
🔇 Additional comments (1)
deepmd/pt/utils/utils.py (1)
151-155: Excellent performance optimization that eliminates graph breaks.

The refactored implementation successfully removes the conditional branching that was causing PyTorch computation graph breaks during inference. The benchmark results showing a 6.4-7.1% speedup validate this approach. The mathematical behavior remains identical while achieving better performance.
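As a rough way to check the graph-break claim locally, `torch._dynamo.explain` (PyTorch 2.1+) counts breaks for a callable; a sketch using the two variants from the overview above (the threshold, slope, and const values are placeholders):

```python
import torch
import torch._dynamo

def branchy(x, threshold=3.0, slope=0.1, const=0.9):
    mask = x >= threshold
    out = x * torch.sigmoid(x)
    if torch.any(mask):  # host-side bool -> graph break under torch.compile
        out = torch.where(mask, torch.tanh(slope * (x - threshold)) + const, out)
    return out

def vectorized(x, threshold=3.0, slope=0.1, const=0.9):
    tanh = torch.tanh(slope * (x - threshold)) + const
    return torch.where(x >= threshold, tanh, x * torch.sigmoid(x))

x = torch.randn(1024)
print(torch._dynamo.explain(branchy)(x).graph_break_count)     # expected: >= 1
print(torch._dynamo.explain(vectorized)(x).graph_break_count)  # expected: 0
```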
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##            devel    #4790      +/-   ##
==========================================
- Coverage   84.80%   84.79%   -0.01%
==========================================
  Files         698      698
  Lines       67798    67796       -2
  Branches     3542     3542
==========================================
- Hits        57494    57490       -4
  Misses       9171     9171
- Partials     1133     1135       +2
```
This pull request simplifies and optimizes the implementation of the `forward` method in the `ActivationFn` class within `deepmd/pt/utils/utils.py`. The changes streamline the logic by removing unnecessary condition checks and directly using `torch.where` for computation.

I've evaluated this change using inference efficiency tasks from LAMBench with the DPA 3.1 3M model, on the `catalysts_500.traj` and `inorganic_500.traj` systems.
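The LAMBench numbers themselves are not reproduced on this page; for a quick local sanity check of the activation in isolation, a micro-benchmark sketch (the 6–7% figure above is end-to-end and comes from the PR, not from this snippet; parameter values are placeholders):

```python
import time
import torch

def silut(x, threshold=3.0, slope=0.1, const=0.9):
    # Vectorized SiLUT as merged in this PR.
    tanh = torch.tanh(slope * (x - threshold)) + const
    return torch.where(x >= threshold, tanh, x * torch.sigmoid(x))

def bench(fn, x, iters=100):
    for _ in range(10):  # warm-up
        fn(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(x)
    if x.is_cuda:
        torch.cuda.synchronize()  # count GPU kernels, not just launches
    return (time.perf_counter() - t0) / iters

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1_000_000, device=device)
print(f"vectorized SiLUT: {bench(silut, x):.2e} s/iter")
```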