Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
6498cd9
Add fix for evo2 generate/inference
Jun 26, 2025
c8f34b2
Add Farhad's suggested refactor
Jun 26, 2025
422b056
lint
Jun 26, 2025
0e6de40
Remove unused and truism code
Jun 30, 2025
cd4d96b
add inference_context conditional check to use new ops
Jul 1, 2025
181b115
add hyena operator tests
Jul 1, 2025
87c28d2
lint
Jul 1, 2025
ad018c3
remove unused code
Jul 1, 2025
a21f429
add doc strings for flake8
Jul 1, 2025
942c241
Merge branch 'main' into jwilber/fix-evo2-generate
jwilber Jul 2, 2025
8c0a32e
Apply isort and black reformatting
jwilber Jul 2, 2025
1d283f8
remove unnecessary assignments
Jul 3, 2025
9e34b8e
invoke original forward for non-inference calls
Jul 3, 2025
be0c14e
Apply isort and black reformatting
jwilber Jul 3, 2025
2d3692b
Fix reset issue
Jul 3, 2025
e8afda4
Apply isort and black reformatting
jwilber Jul 3, 2025
d3e78f6
Merge branch 'main' into jwilber/fix-evo2-generate
jwilber Jul 3, 2025
9c35b31
add docstring
Jul 4, 2025
1f9ecbd
Apply isort and black reformatting
jwilber Jul 4, 2025
f23f7a8
Merge branch 'main' into jwilber/fix-evo2-generate
jwilber Jul 7, 2025
918f901
Merge branch 'main' into jwilber/fix-evo2-generate
jwilber Jul 7, 2025
109592e
Merge branch 'main' into jwilber/fix-evo2-generate
jwilber Jul 8, 2025
377bf8a
Merge branch 'main' into jwilber/fix-evo2-generate
jwilber Jul 8, 2025
1af0005
Remove test
Jul 8, 2025
ccc2419
Merge branch 'main' into jwilber/fix-evo2-generate
jwilber Jul 9, 2025
c99457d
Add tests for hyena operator
Jul 10, 2025
fc1bac4
Apply isort and black reformatting
jwilber Jul 10, 2025
b466867
lint
Jul 10, 2025
8dec653
simplify context manager in tests
Jul 10, 2025
5d797c2
remove unused import in test
Jul 10, 2025
dab1c53
Add env vars for test
Jul 10, 2025
a3202d8
mark tests gpu only
Jul 10, 2025
df97bf5
Apply isort and black reformatting
jwilber Jul 10, 2025
3eed182
Merge branch 'main' into jwilber/fix-evo2-generate
chtruong814 Jul 11, 2025
d7cf0a9
Merge branch 'main' into jwilber/fix-evo2-generate
chtruong814 Jul 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions nemo/collections/llm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,7 @@ def generate(
text_only: bool = False,
output_path: Optional[AnyPath] = None,
enable_flash_decode: bool = True,
**kwargs,
) -> list[Union["InferenceRequest", str]]:
"""
Generates text using a NeMo LLM model.
Expand Down Expand Up @@ -1116,6 +1117,7 @@ def generate(
output_path (Optional[Union[Path, str]], optional): The path to save the generated text or test dataset
predictions. Defaults to None.
enable_flash_decode (bool, optional): Whether to enable flash decode. Defaults to True.
**kwargs: Additional keyword arguments passed to setup_model_and_tokenizer.

Returns:
list[Union["InferenceRequest", str]]: A list of generated text,
Expand All @@ -1139,6 +1141,7 @@ def generate(
params_dtype=params_dtype,
inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
enable_flash_decode=enable_flash_decode,
**kwargs,
)

max_seq_length = inference_params.num_tokens_to_generate + max(len(mcore_tokenizer.tokenize(p)) for p in inputs)
Expand Down
2 changes: 2 additions & 0 deletions nemo/collections/llm/gpt/model/hyena.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ class Hyena7bARCLongContextConfig(Hyena7bConfig):
due to constraints from large TP size for training."""

ffn_hidden_size: int = 11264
seq_len_interpolation_factor: float = 128


@dataclass
Expand All @@ -475,6 +476,7 @@ class Hyena40bARCLongContextConfig(Hyena40bConfig):
due to constraints from large TP size for training."""

ffn_hidden_size: int = 22528
seq_len_interpolation_factor: float = 128


@io.model_importer(HyenaModel, "pytorch")
Expand Down
191 changes: 191 additions & 0 deletions nemo/collections/llm/gpt/model/megatron/hyena/engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024 Arc Institute. All rights reserved.
# Copyright (c) 2024 Michael Poli. All rights reserved.
# Copyright (c) 2024 Stanford University. All rights reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F
from einops import rearrange


def adjust_filter_shape_for_broadcast(u, h):
    """Reshape filter ``h`` so it broadcasts against the input ``u``.

    The filter arrives as ``[1, D, L]`` or ``[D, 1, L]`` and is first
    collapsed to ``[D, L]``; singleton dims are then re-inserted so the
    result lines up with the rank of ``u`` (``[B, D, L]`` or
    ``[B, D1, D2, L]``).
    """
    h = h.squeeze()  # canonical [D, L] view

    # u carries a batch dim the filter lacks -> prepend one.
    if u.dim() > h.dim():
        h = h.unsqueeze(0)

    # 4-D input ([B, D1, D2, L]) -> add a singleton dim after batch too.
    if u.dim() > 3:
        h = h.unsqueeze(1)
    return h


def fftconv_func(*, u, k, D):
    """FFT-based causal convolution of ``u`` with filter ``k`` plus a skip term.

    Computes ``y = (u * k)[..., :L] + u * D`` where ``*`` is linear (causal)
    convolution, realized as a circular convolution of size ``2 * L`` so no
    wrap-around reaches the first ``L`` samples.

    Args:
        u: input of shape ``[B, D, L]`` (or ``[B, D1, D2, L]``).
        k: time-domain filter whose spectrum is broadcast-aligned with ``u``.
        D: per-channel skip ("residual") coefficients.
    """
    seqlen = u.shape[-1]
    fft_size = 2 * seqlen

    # Normalized filter spectrum, reshaped to broadcast against u's rank
    # (inlined shape adjustment: squeeze to [D, L'], then re-add dims).
    k_f = torch.fft.rfft(k, n=fft_size) / fft_size
    k_f = k_f.squeeze()
    if u.dim() > k_f.dim():
        k_f = k_f.unsqueeze(0)
    if u.dim() > 3:
        k_f = k_f.unsqueeze(1)

    u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size)
    y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen]

    return y + u * D.unsqueeze(-1)


def parallel_fir(
    *,
    u,  # [B, L, D] input sequence
    weight,  # FIR taps, [D, 1, fir_length]
    bias,  # optional per-channel bias
    L,  # NOTE(review): immediately shadowed by u.shape[1]; has no effect — confirm with callers
    gated_bias,  # if True, bias multiplies u instead of being added directly
    fir_length,  # filter length; >= 128 selects the FFT path
    compute_state,  # if True, also return the cached input tail for decoding
):
    """Apply a depthwise FIR (causal) filter over the whole sequence.

    Long filters (``fir_length >= 128``) use an FFT convolution; short ones
    use a grouped ``conv1d``. Optionally returns the last
    ``fir_length - 1`` input samples as the streaming FIR state.

    NOTE(review): on the FFT path ``bias`` is passed to ``fftconv_func`` as
    the skip term ``D`` and then added again below — verify the double
    application is intended; ``bias=None`` would also raise in there.

    Returns:
        Tuple of the filtered sequence ``[B, D, L]`` and the FIR state
        (``None`` unless ``compute_state``).
    """
    # Sequence length comes from the input itself, not the L argument.
    L = u.shape[1]
    u = u.permute(0, 2, 1)  # b l d -> b d l

    if fir_length >= 128:
        # Long filter: frequency-domain convolution in fp32 under autocast.
        with torch.autocast("cuda"):
            z = fftconv_func(
                u=u.to(torch.float32),
                k=weight[:, :, :L].to(torch.float32),
                D=bias,
            ).to(dtype=u.dtype)
    else:
        # Short filter: depthwise causal conv1d, truncated back to length L.
        z = F.conv1d(
            u.to(torch.float32),
            weight.to(torch.float32),
            bias=None,
            stride=1,
            padding=fir_length - 1,
            groups=u.shape[1],  # depthwise: one group per channel D
        )[..., :L]

    z = z.to(u.dtype)

    if bias is not None:
        z = z + (bias[None, :, None] * u if gated_bias else bias[None, :, None])

    # Cache the trailing fir_length - 1 inputs as the streaming FIR state.
    fir_state = u[..., -fir_length + 1 :] if compute_state else None
    return z, fir_state


def parallel_iir(*, z_pre, h, D, L, poles, t, hidden_size, compute_state):
    """Long (IIR-derived) convolution over the full sequence via FFT.

    Splits the projected features ``z_pre`` (three ``hidden_size`` chunks
    along dim 1) into gates and value, convolves ``x1 * v`` with the long
    filter ``h`` in the frequency domain, applies the skip term ``D`` and
    the output gate, and optionally prefills the modal IIR state.

    Returns:
        Tuple of the gated output permuted to ``[B, L, D]`` and the IIR
        state (``None`` unless ``compute_state`` is set).
    """
    fft_size = 2 * L
    # Channel-wise split into the two gates and the value stream.
    gate_in, gate_out, value = z_pre.split(3 * [hidden_size], dim=1)
    x1v = gate_in * value

    # Frequency-domain causal convolution of x1v with h (fp32 for accuracy).
    H = torch.fft.rfft(h.to(dtype=torch.float32), n=fft_size) / fft_size
    X_s = torch.fft.fft(x1v.to(dtype=torch.float32), n=fft_size)
    X = X_s[..., : H.shape[-1]]
    if z_pre.dim() > 3:
        H = H.unsqueeze(1)  # extra singleton dim for 4-D inputs
    y = torch.fft.irfft(X * H, n=fft_size, norm="forward")[..., :L]
    y = y.to(dtype=x1v.dtype)

    # Skip connection, then output gating.
    y = (y + x1v * D.unsqueeze(-1)) * gate_out

    iir_state = (
        prefill_via_modal_fft(x1v=x1v, X_s=X_s, L=L, t=t, poles=poles)
        if compute_state
        else None
    )

    return y.permute(0, 2, 1), iir_state


def step_fir(*, u, fir_state, weight, bias=None, gated_bias=False, flip_filter=False):
    """Advance a FIR filter in the architecture by one decoding step.

    FIR filters are truncated convolutions in Hyena with an explicit or
    hybrid time-domain parametrization: short filters in Hyena featurizers,
    short and medium filters in Hyena operators.

    Note:
        ``fir_state`` holds the last ``filter_length - 1`` inputs
        ``u_{L-2}, u_{L-1}, ...``; ``weight`` is assumed to be shaped
        ``[d, 1, short_filter_len]``.

    Returns:
        Tuple of the filtered sample (cast back to ``u``'s original dtype)
        and the updated state (NOTE(review): state stays fp32 — confirm
        callers expect that).
    """
    taps = weight.squeeze()  # [d, filter_length]

    cache_size = fir_state.shape[-1]
    filter_length = taps.shape[-1]

    # Keep only the taps the current cache can support (+1 for the new input).
    if flip_filter:
        taps = taps.flip(-1)[..., -cache_size - 1 :].unsqueeze(0)
    else:
        taps = taps[..., : cache_size + 1].unsqueeze(0)

    # Do the arithmetic in fp32; restore the input dtype on the way out.
    input_dtype = u.dtype
    taps = taps.to(torch.float32)
    u = u.to(torch.float32)
    fir_state = fir_state.to(torch.float32)
    bias = None if bias is None else bias.to(torch.float32)

    # Current-sample tap plus a dot product against the cached history.
    h0, h = taps[..., -1], taps[..., :-1]
    y = h0 * u + torch.sum(fir_state * h, dim=-1)

    if bias is not None:
        y = y + (bias * u if gated_bias else bias)

    # Grow the cache until it is full, then shift left and append in place.
    if cache_size < filter_length - 1:
        fir_state = torch.cat([fir_state, u[..., None]], dim=-1)
    else:
        fir_state = torch.roll(fir_state, -1, dims=2)
        fir_state[..., -1] = u

    return y.to(input_dtype), fir_state


def step_iir(*, x2, x1, v, D, residues, poles, iir_state):
    """Advance the modal-form IIR (long) filter by one decoding step.

    ``poles`` contains log-poles with a dummy trailing seqlen dim; the
    state recurrence is ``s <- exp(log_poles) * s + x1 * v`` and the output
    is the residue-weighted state sum plus the skip term ``D``, gated by
    ``x2``.

    Returns:
        Tuple of the gated output sample and the updated IIR state.
    """
    x1v = x1 * v

    # poles arg contains log_poles; drop the dummy seqlen dim and add a
    # dummy batch dim so everything broadcasts against the state.
    decay = torch.exp(poles)[..., 0][None]
    res = residues[None]  # add dummy batch dim

    iir_state = decay * iir_state + x1v[..., None]

    weighted = torch.sum(res * iir_state, dim=-1)
    y = x2 * (weighted + D * x1v)
    return y, iir_state


def prefill_via_modal_fft(*, x1v, L, poles, t, X_s):
    """Prefill the modal IIR state with a single FFT over the prompt.

    When the model has a long convolution derived from a recurrence in
    modal form and prefill_style is "fft", the filter splits into poles
    and residues; the already-computed input spectrum ``X_s`` is reused to
    evaluate every modal state at position ``L - 1`` in one circular
    convolution.
    """
    batch = x1v.shape[0]
    fft_size = 2 * L

    # Spectrum of the modal impulse responses exp(log_poles * t),
    # replicated across the batch: [B, D, state_dim, 2 * L].
    modal_spectrum = torch.fft.fft((poles.to(torch.float32) * t).exp(), n=fft_size)
    modal_spectrum = modal_spectrum.repeat(batch, 1, 1, 1)

    # Circular convolution with the input; sample the last prompt position.
    state = torch.fft.ifft(X_s[..., None, :] * modal_spectrum, n=fft_size)
    return state[..., L - 1].to(dtype=torch.float32)  # real part only
50 changes: 43 additions & 7 deletions nemo/collections/llm/gpt/model/megatron/hyena/hyena_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@
from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import (
B2BCausalConv1dModule,
ParallelCausalDepthwiseConv1d,
ParallelCausalDepthwiseConv1dWithState,
ParallelHyenaOperator,
ParallelShortHyenaOperator,
divide,
)


logger = logging.getLogger(__name__)

try:
Expand Down Expand Up @@ -160,7 +160,13 @@ def __init__(

hyena_proj_groups = self.proj_groups if not self.grouped_attention else 1
grouped_proj_size = self.hidden_size_per_partition // hyena_proj_groups
self.hyena_proj_conv = ParallelCausalDepthwiseConv1d(

hyena_proj_conv_class = ParallelCausalDepthwiseConv1dWithState


short_conv_class = ParallelCausalDepthwiseConv1dWithState

self.hyena_proj_conv = hyena_proj_conv_class(
self.hidden_size_per_partition + 2 * grouped_proj_size,
self.transformer_config,
self.hyena_config,
Expand All @@ -179,7 +185,7 @@ def __init__(
self.transformer_config,
self.hyena_config,
self.transformer_config.init_method,
short_conv_class=ParallelCausalDepthwiseConv1d,
short_conv_class=short_conv_class,
use_fast_causal_conv=self.fast_conv_mixer,
use_conv_bias=self.transformer_config.use_short_conv_bias,
)
Expand Down Expand Up @@ -280,18 +286,48 @@ def forward(self, x, layer_past=None, inference_context=None, _hyena_use_cp=True
_proj_use_cp = True
else:
_proj_use_cp = False
features, _ = self._maybe_use_fp8(self.dense_projection, x)
# Handle padding for FP8 if enabled
if self.transformer_config.vortex_style_fp8:

def pad_to_multiple(x, multiple=16):
"""Pad tensor to make sequence length divisible by multiple."""
seq_len = x.size(0)
if seq_len % multiple == 0:
return x

pad_len = multiple - (seq_len % multiple)
pad_tensor = torch.zeros(pad_len, *x.shape[1:], device=x.device, dtype=x.dtype)
return torch.cat([x, pad_tensor], dim=0)

# Direct padding without rearrange
L = x.shape[0]
x = pad_to_multiple(x)
features, _ = self._maybe_use_fp8(self.dense_projection, x)

# Slice back to original sequence length if padding was added

if features.shape[0] > L:
features = features[:L, :, :]
else:
features, _ = self.dense_projection(x)
features = rearrange(features, "l b d -> b d l").contiguous()

if self.use_b2b_causal_conv1d and self.operator_type in ["hyena_short_conv", "hyena_medium_conv"]:
if (
self.use_b2b_causal_conv1d
and self.operator_type in ["hyena_short_conv", "hyena_medium_conv"]
and inference_context is not None
):
# todo: support inference_context for b2b_kernel
# Use the B2BCausalConv1dModule wrapper with the existing weights from the original model
z = self.b2b_kernel(features, _use_cp=_proj_use_cp)
else:
features = self.hyena_proj_conv(features, _use_cp=_proj_use_cp) # [B, D, L]
features = self.hyena_proj_conv(
features, _use_cp=_proj_use_cp, inference_context=inference_context
) # [B, D, L]
x1, x2, v = rearrange(features, "b (g dg p) l -> b (g dg) p l", p=3, g=self.num_groups_per_tp_rank).unbind(
dim=2
)
z = self.mixer(x1, x2, v, _hyena_use_cp=_proj_use_cp)
z = self.mixer(x1, x2, v, _hyena_use_cp=_proj_use_cp, inference_context=inference_context)

z = rearrange(z, "b d l -> l b d").contiguous()
y, bias = self.dense(z)
Expand Down
Loading
Loading