
Commit c431947

danilyef, Nikita, superbock, and martinloretzzz authored
Speed up forecasting and making code more readable (#28)
* check tests with attention
* check on jupyterlab_docker branch
* chore: skip test_jupyterlab_running if Docker is not available
* chore: run all tests inside Docker container
* freezing dependencies up to next major version
* check tests with attention
* chore: run all tests inside Docker container
* clamping logsigmoid, duplication del and more readable
* fix white-spaces
* Added suggestions for PR and replaced pure Python by torch
* added self.num_heads instead of NH for ONNX compatibility. Otherwise, ONNX conversion throws an error
* Resolving conflicts with main
* Resolve the conflict p2
* jupyterlab_docker is removed
* correct github actions
* Add comment to clamp

---------

Co-authored-by: Nikita <[email protected]>
Co-authored-by: Sebastian Böck <[email protected]>
Co-authored-by: martinloretzzz <[email protected]>
1 parent ed79a50 commit c431947

4 files changed

Lines changed: 47 additions & 41 deletions


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -30,4 +30,5 @@ share/python-wheels/
 
 .vscode
 .venv
+.ruff_cache
 torch_compile_debug

src/tirex/models/slstm/cell.py

Lines changed: 3 additions & 3 deletions
@@ -100,7 +100,7 @@ def _impl_cuda(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
 
     def _get_input(self, x: torch.Tensor) -> torch.Tensor:
         assert x.shape[-1] == self.config.embedding_dim * self.config.num_gates, (
-            f"Input size mismatch: Expected input size {self.config.embedding_dim * self.config.num_gates}, but got {input.size(-1)}."
+            f"Input size mismatch: Expected input size {self.config.embedding_dim * self.config.num_gates}, but got {x.size(-1)}."
         )
         return x.view(x.shape[0], x.shape[1], self.config.num_gates, self.config.num_heads, -1).permute(1, 0, 2, 3, 4)

@@ -128,7 +128,7 @@ def slstm_forward(
     states: torch.Tensor,  # [4, B, H] only the first is used for recurrence!
     R: torch.Tensor,  # [K, R*H, H] - K num_heads
     b: torch.Tensor,  # [T*H]
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     num_gates = 4
     num_heads = R.shape[0]
     S, B, _ = x.shape

@@ -167,7 +167,7 @@ def slstm_forward_pointwise(
     iraw, fraw, zraw, oraw = torch.unbind(raw.view(raw.shape[0], 4, -1), dim=1)
 
     # Equations reference the xlstm paper on page 4: https://arxiv.org/pdf/2405.04517
-    logfplusm = m + F.logsigmoid(fraw)  # eq 15
+    logfplusm = m + F.logsigmoid(torch.clamp(fraw, max=15))  # eq 15  # Clamp to avoid subnormals
     mnew = torch.where(torch.all(n == 0.0), iraw, torch.max(iraw, logfplusm))  # eq 15
     ogate = torch.sigmoid(oraw)  # eq 14
     igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw))  # eq 16
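
A note on the clamp introduced above: F.logsigmoid(x) behaves like -exp(-x) for large x, so a very confident forget gate produces vanishingly small negative values that can underflow into subnormal floats (which are slow on many CPUs), with the exact threshold depending on the compute dtype. A minimal sketch of the effect, not part of the diff; the tensor values are illustrative:

import torch
import torch.nn.functional as F

# Illustrative forget-gate pre-activations; the large values mimic a very confident gate.
fraw = torch.tensor([0.5, 15.0, 60.0, 120.0])

# Unclamped: the result decays toward zero and can end up in the subnormal range.
unclamped = F.logsigmoid(fraw)

# Clamped as in the commit: the result never drops below logsigmoid(15) ~ -3.1e-7,
# a normal float, while the gate changes by less than ~3e-7 in log space.
clamped = F.logsigmoid(torch.clamp(fraw, max=15))

print(unclamped)  # roughly tensor([-4.7e-01, -3.1e-07, -8.8e-27, -0.0e+00])
print(clamped)    # roughly tensor([-4.7e-01, -3.1e-07, -3.1e-07, -3.1e-07])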

src/tirex/models/slstm/layer.py

Lines changed: 5 additions & 3 deletions
@@ -20,7 +20,7 @@ def __init__(self, config: sLSTMBlockConfig, backend: str):
         self.ogate = LinearHeadwiseExpand(in_features, num_heads)
 
         self.slstm_cell = sLSTMCell(self.config, backend)
-        self.group_norm = MultiHeadLayerNorm(ndim=in_features)
+        self.group_norm = MultiHeadLayerNorm(ndim=in_features, num_heads=num_heads)
 
     def forward(self, x: torch.Tensor, slstm_state: torch.Tensor | None = None) -> torch.Tensor:
         x_g = torch.cat((self.fgate(x), self.igate(x), self.zgate(x), self.ogate(x)), dim=-1)

@@ -50,18 +50,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class MultiHeadLayerNorm(nn.Module):
-    def __init__(self, ndim: int):
+    def __init__(self, ndim: int, num_heads: int):
         super().__init__()
         self.weight = nn.Parameter(torch.zeros(ndim))
+        self.num_heads = num_heads
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         assert input.dim() == 4, "Input must be 4D tensor (B, NH, S, DH)"
         B, NH, S, DH = input.shape
 
+        assert NH == self.num_heads
         gn_in_1 = input.transpose(1, 2)  # (B, S, NH, DH)
         gn_in_2 = gn_in_1.reshape(B * S, NH * DH)  # (B * S, NH * DH)
         residual_weight = 1.0 + self.weight
-        out = F.group_norm(gn_in_2, num_groups=NH, weight=residual_weight)
+        out = F.group_norm(gn_in_2, num_groups=self.num_heads, weight=residual_weight)
         # (B * S), (NH * DH) -> (B, S, NH, DH) -> (B, NH, S, DH)
         out = out.view(B, S, NH, DH).transpose(1, 2)
         return out
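
Why the num_heads argument matters: storing the head count at construction time lets F.group_norm use a constant num_groups instead of a value read from the runtime shape, which is what the ONNX-compatibility bullet in the commit message refers to. A self-contained sketch of the updated module plus an illustrative usage, reconstructed from the hunk above (the sizes below are made up):

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadLayerNorm(nn.Module):
    """Group norm over (NH * DH) features, one group per head, as in the hunk above."""

    def __init__(self, ndim: int, num_heads: int):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(ndim))
        self.num_heads = num_heads

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert input.dim() == 4, "Input must be 4D tensor (B, NH, S, DH)"
        B, NH, S, DH = input.shape
        assert NH == self.num_heads

        gn_in = input.transpose(1, 2).reshape(B * S, NH * DH)  # (B * S, NH * DH)
        residual_weight = 1.0 + self.weight
        # num_groups is the stored constant, not the runtime shape NH.
        out = F.group_norm(gn_in, num_groups=self.num_heads, weight=residual_weight)
        return out.view(B, S, NH, DH).transpose(1, 2)  # back to (B, NH, S, DH)


# Illustrative sizes: 4 heads of dimension 128, so ndim = 4 * 128.
norm = MultiHeadLayerNorm(ndim=512, num_heads=4)
x = torch.randn(2, 4, 16, 128)  # (B, NH, S, DH)
assert norm(x).shape == x.shape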

src/tirex/models/tirex.py

Lines changed: 38 additions & 35 deletions
@@ -79,12 +79,18 @@ def _forecast_quantiles(
         training_quantile_levels = self.config.quantiles
 
         if set(quantile_levels).issubset(set(training_quantile_levels)):
-            quantiles = predictions[..., [training_quantile_levels.index(q) for q in quantile_levels]]
+            quantile_indices = torch.tensor(
+                [training_quantile_levels.index(q) for q in quantile_levels],
+                dtype=torch.long,
+                device=predictions.device,
+            )
+            quantiles = torch.index_select(predictions, dim=-1, index=quantile_indices)
         else:
             quantiles = self._interpolate_quantiles(predictions, quantile_levels)
 
         # median as mean
-        mean = predictions[:, :, training_quantile_levels.index(0.5)]
+        median_idx = torch.tensor([training_quantile_levels.index(0.5)], dtype=torch.long, device=predictions.device)
+        mean = torch.index_select(predictions, dim=-1, index=median_idx).squeeze(-1)
         return quantiles, mean
 
     @torch.inference_mode()
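
Context for the change above: the removed line is pure-Python fancy indexing, which per the commit message was replaced by torch ops; building the index as a tensor and using torch.index_select keeps the selection as a graph op (e.g. for ONNX export). A minimal sketch of the same pattern with illustrative quantile levels, not the model's actual configuration:

import torch

# Illustrative values; the real levels come from self.config.quantiles.
training_quantile_levels = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
quantile_levels = [0.1, 0.5, 0.9]

# [batch, horizon, num_trained_quantiles]
predictions = torch.randn(2, 24, len(training_quantile_levels))

# Build the index once as a tensor on the right device...
quantile_indices = torch.tensor(
    [training_quantile_levels.index(q) for q in quantile_levels],
    dtype=torch.long,
    device=predictions.device,
)
# ...and select along the last dimension with a tensor op.
quantiles = torch.index_select(predictions, dim=-1, index=quantile_indices)  # [2, 24, 3]

# The median column doubles as the point forecast ("mean").
median_idx = torch.tensor([training_quantile_levels.index(0.5)], dtype=torch.long, device=predictions.device)
mean = torch.index_select(predictions, dim=-1, index=median_idx).squeeze(-1)  # [2, 24]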
@@ -105,24 +111,8 @@ def _forecast_tensor(
 
         context = context.to(dtype=torch.float32)
         while remaining > 0:
-            if context.shape[-1] > max_context:
-                context = context[..., -max_context:]
-            if context.shape[-1] < min_context:
-                pad = torch.full(
-                    (context.shape[0], min_context - context.shape[-1]),
-                    fill_value=torch.nan,
-                    device=context.device,
-                    dtype=context.dtype,
-                )
-                context = torch.concat((pad, context), dim=1)
-            tokenized_tensor, tokenizer_state = self.tokenizer.context_input_transform(context)
             fut_rollouts = min(remaining, max_accelerated_rollout_steps)
-            with torch.no_grad():
-                prediction, _ = self._forward_model_tokenized(input_token=tokenized_tensor, rollouts=fut_rollouts)
-            prediction = prediction[:, :, -fut_rollouts:, :].to(tokenized_tensor)  # predicted token
-            # [bs, num_quantiles, num_predicted_token, output_patch_size]
-            prediction = self.tokenizer.output_transform(prediction, tokenizer_state)
-            prediction = prediction.flatten(start_dim=2)
+            prediction, fut_rollouts = self._forecast_single_step(context, max_context, min_context, fut_rollouts)
 
             predictions.append(prediction)
             remaining -= fut_rollouts
@@ -134,6 +124,33 @@ def _forecast_tensor(
 
         return torch.cat(predictions, dim=-1)[..., :prediction_length].to(dtype=torch.float32)
 
+    def _forecast_single_step(
+        self,
+        context: torch.Tensor,
+        max_context: int,
+        min_context: int,
+        new_patch_count: int = 1,
+    ) -> tuple[torch.Tensor, int]:
+        if context.shape[-1] > max_context:
+            context = context[..., -max_context:]
+        if context.shape[-1] < min_context:
+            pad = torch.full(
+                (context.shape[0], min_context - context.shape[-1]),
+                fill_value=torch.nan,
+                device=context.device,
+                dtype=context.dtype,
+            )
+            context = torch.concat((pad, context), dim=1)
+
+        tokenized_tensor, tokenizer_state = self.tokenizer.context_input_transform(context)
+        prediction, _ = self._forward_model_tokenized(input_token=tokenized_tensor, rollouts=new_patch_count)
+        prediction = prediction[:, :, -new_patch_count:, :].to(tokenized_tensor)  # predicted token
+        # Shape: [bs, num_quantiles, num_predicted_token, output_patch_size]
+        prediction = self.tokenizer.output_transform(prediction, tokenizer_state)
+        prediction = prediction.flatten(start_dim=2)
+
+        return prediction, new_patch_count
+
     def _forward_model_tokenized(
         self,
         input_token: torch.Tensor,
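
The refactor above moves the per-step trimming, padding, tokenization and prediction into _forecast_single_step, leaving _forecast_tensor with a plain chunked rollout loop. A stand-alone sketch of that control flow with a stubbed one-step function; names, the patch size, and the feedback step are illustrative, not the library's API:

import torch

PATCH = 8  # illustrative output patch size


def forecast_single_step_stub(context, max_context, min_context, new_patch_count):
    # Stand-in for the real helper: trim long contexts, left-pad short ones
    # with NaN, then "predict" new_patch_count patches of PATCH values each.
    if context.shape[-1] > max_context:
        context = context[..., -max_context:]
    if context.shape[-1] < min_context:
        pad = torch.full((context.shape[0], min_context - context.shape[-1]), torch.nan,
                         dtype=context.dtype)
        context = torch.cat((pad, context), dim=1)
    return torch.zeros(context.shape[0], new_patch_count * PATCH), new_patch_count


def rollout(context, prediction_length, max_rollout_steps=4, max_context=512, min_context=64):
    predictions = []
    remaining = -(-prediction_length // PATCH)  # patches still to produce
    while remaining > 0:
        steps = min(remaining, max_rollout_steps)
        chunk, steps = forecast_single_step_stub(context, max_context, min_context, steps)
        predictions.append(chunk)
        remaining -= steps
        # Feed the predicted values back so the next chunk conditions on them
        # (illustrative; the exact feedback in the real code is outside the shown hunks).
        context = torch.cat((context, chunk), dim=-1)
    # Concatenate the chunks and cut to the exact horizon, as in _forecast_tensor.
    return torch.cat(predictions, dim=-1)[..., :prediction_length]


print(rollout(torch.randn(2, 100), prediction_length=50).shape)  # torch.Size([2, 50])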
@@ -165,21 +182,7 @@ def _forward_model_tokenized(
 
         input_token = torch.nan_to_num(input_token, nan=self.config.nan_mask_value)
 
-        hidden_states = self.input_patch_embedding(torch.cat((input_token, input_mask), dim=2))
-
-        for block in self.blocks:
-            hidden_states = block(hidden_states)
-
-        hidden_states = self.out_norm(hidden_states)
-
-        quantile_preds = self.output_patch_embedding(hidden_states)
-        quantile_preds = torch.unflatten(
-            quantile_preds, -1, (len(self.config.quantiles), self.config.output_patch_size)
-        )
-        quantile_preds = torch.transpose(quantile_preds, 1, 2)  # switch quantile and num_token_dimension
-        # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
-
-        quantile_preds = self._forward_model(torch.cat((input_token, input_mask), dim=2))
+        quantile_preds, hidden_states = self._forward_model(torch.cat((input_token, input_mask), dim=2))
 
         quantile_preds = torch.unflatten(
             quantile_preds, -1, (len(self.config.quantiles), self.config.output_patch_size)

@@ -196,7 +199,7 @@ def _forward_model(self, input: torch.Tensor):
 
         hidden_states = self.out_norm(hidden_states)
 
-        return self.output_patch_embedding(hidden_states)
+        return self.output_patch_embedding(hidden_states), hidden_states
 
     def _interpolate_quantiles(self, predictions: torch.Tensor, quantile_levels: list[float]):
         training_quantile_levels = self.config.quantiles
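
Net effect of the last two hunks: _forward_model now returns both the quantile head output and the hidden states it came from, and _forward_model_tokenized keeps only the reshaping of the quantile predictions. A minimal sketch of that reshaping with illustrative sizes:

import torch

# Illustrative sizes, not the model's real configuration.
batch, num_token, num_quantiles, output_patch_size = 2, 12, 9, 32

# What _forward_model now returns: the quantile head output plus the hidden
# states it was computed from (the second return value is new in this commit).
quantile_preds = torch.randn(batch, num_token, num_quantiles * output_patch_size)
hidden_states = torch.randn(batch, num_token, 512)

# The caller's post-processing, as in the context lines of the hunk above:
quantile_preds = torch.unflatten(quantile_preds, -1, (num_quantiles, output_patch_size))
quantile_preds = torch.transpose(quantile_preds, 1, 2)  # swap quantile and token dims

print(quantile_preds.shape)  # torch.Size([2, 9, 12, 32]) = [bs, num_quantiles, num_token, output_patch_size]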

0 commit comments
