-
Notifications
You must be signed in to change notification settings - Fork 760
refactor: refactoring cuda code to cute-dsl (part 1) #2428
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 10 commits
6275947
ec57a7c
946ad76
928a2ce
71631f1
0e9dd6c
e21b28f
d4d53d5
ba1a645
fcd5c5d
510b6f6
ef19fb9
6d78f6f
62764d6
11649ab
68ea276
a7690ec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -18,10 +18,19 @@ | |||||||||||||||||||||||||||||||||||||||||||||
| This module provides high-performance GPU kernels implemented using NVIDIA CuTe-DSL. | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| from .utils import is_cute_dsl_available, make_ptr, get_cutlass_dtype, get_num_sm | ||||||||||||||||||||||||||||||||||||||||||||||
| import importlib.util | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # Conditionally import CuTe-DSL kernels | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def is_cute_dsl_available() -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||
| return ( | ||||||||||||||||||||||||||||||||||||||||||||||
| importlib.util.find_spec("cutlass") is not None | ||||||||||||||||||||||||||||||||||||||||||||||
| and importlib.util.find_spec("cutlass.cute") is not None | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||
| import importlib.util | |
| # Conditionally import CuTe-DSL kernels | |
| def is_cute_dsl_available() -> bool: | |
| return ( | |
| importlib.util.find_spec("cutlass") is not None | |
| and importlib.util.find_spec("cutlass.cute") is not None | |
| ) | |
| import functools | |
| import importlib.util | |
| from ..api_logging import flashinfer_api | |
| `@functools.cache` | |
| `@flashinfer_api` | |
| def is_cute_dsl_available() -> bool: | |
| return ( | |
| importlib.util.find_spec("cutlass") is not None | |
| and importlib.util.find_spec("cutlass.cute") is not None | |
| ) |
🤖 Prompt for AI Agents
In `@flashinfer/cute_dsl/__init__.py` around lines 21 - 28, The
is_cute_dsl_available() function should be decorated with `@flashinfer_api` and
`@functools.cache` to enable API logging and cache the module discovery; update
the top imports to import functools (or functools.cache) and import the
flashinfer_api decorator (or from its module) so you can apply `@flashinfer_api`
and `@functools.cache` directly above def is_cute_dsl_available to avoid repeated
find_spec calls and ensure API logging.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Expose quantized norm APIs at the package level.
Line 97-101 exports
rmsnormandfused_add_rmsnorm, but the new quantized variants (rmsnorm_quant,fused_add_rmsnorm_quant) fromflashinfer.normare still missing at the top level. Consider exporting them here soflashinfer.rmsnorm_quantworks consistently.✅ Suggested export additions
As per coding guidelines: Export new operations in flashinfer/init.py to make them available at package level.
🤖 Prompt for AI Agents