@@ -79,12 +79,18 @@ def _forecast_quantiles(
         training_quantile_levels = self.config.quantiles

         if set(quantile_levels).issubset(set(training_quantile_levels)):
-            quantiles = predictions[..., [training_quantile_levels.index(q) for q in quantile_levels]]
+            quantile_indices = torch.tensor(
+                [training_quantile_levels.index(q) for q in quantile_levels],
+                dtype=torch.long,
+                device=predictions.device,
+            )
+            quantiles = torch.index_select(predictions, dim=-1, index=quantile_indices)
         else:
             quantiles = self._interpolate_quantiles(predictions, quantile_levels)

         # median as mean
-        mean = predictions[:, :, training_quantile_levels.index(0.5)]
+        median_idx = torch.tensor([training_quantile_levels.index(0.5)], dtype=torch.long, device=predictions.device)
+        mean = torch.index_select(predictions, dim=-1, index=median_idx).squeeze(-1)
         return quantiles, mean

     @torch.inference_mode()
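As an aside, here is a minimal standalone sketch of the selection path this hunk introduces; the tensor shape and quantile lists are invented for illustration. Building an explicit long index tensor on the predictions' device and routing it through `torch.index_select` replaces Python-list fancy indexing, presumably so the lookup is a single tensor op with no host-side index list:

    import torch

    training_quantile_levels = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    quantile_levels = [0.1, 0.5, 0.9]
    predictions = torch.randn(2, 7, len(training_quantile_levels))  # [batch, horizon, num_quantiles]

    quantile_indices = torch.tensor(
        [training_quantile_levels.index(q) for q in quantile_levels],
        dtype=torch.long,
        device=predictions.device,
    )
    quantiles = torch.index_select(predictions, dim=-1, index=quantile_indices)
    assert quantiles.shape == (2, 7, 3)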
@@ -105,24 +111,8 @@ def _forecast_tensor(

         context = context.to(dtype=torch.float32)
         while remaining > 0:
-            if context.shape[-1] > max_context:
-                context = context[..., -max_context:]
-            if context.shape[-1] < min_context:
-                pad = torch.full(
-                    (context.shape[0], min_context - context.shape[-1]),
-                    fill_value=torch.nan,
-                    device=context.device,
-                    dtype=context.dtype,
-                )
-                context = torch.concat((pad, context), dim=1)
-            tokenized_tensor, tokenizer_state = self.tokenizer.context_input_transform(context)
             fut_rollouts = min(remaining, max_accelerated_rollout_steps)
-            with torch.no_grad():
-                prediction, _ = self._forward_model_tokenized(input_token=tokenized_tensor, rollouts=fut_rollouts)
-            prediction = prediction[:, :, -fut_rollouts:, :].to(tokenized_tensor)  # predicted token
-            # [bs, num_quantiles, num_predicted_token, output_patch_size]
-            prediction = self.tokenizer.output_transform(prediction, tokenizer_state)
-            prediction = prediction.flatten(start_dim=2)
+            prediction, fut_rollouts = self._forecast_single_step(context, max_context, min_context, fut_rollouts)

             predictions.append(prediction)
             remaining -= fut_rollouts
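The loop body now delegates the trim/pad/tokenize/forward sequence to `_forecast_single_step`, added in the next hunk. The chunking arithmetic that stays behind is simple enough to illustrate on its own; `chunk_schedule` is a hypothetical helper written only to show the schedule:

    def chunk_schedule(prediction_length: int, max_accelerated_rollout_steps: int) -> list[int]:
        # Consume the horizon in blocks of at most max_accelerated_rollout_steps.
        chunks, remaining = [], prediction_length
        while remaining > 0:
            fut_rollouts = min(remaining, max_accelerated_rollout_steps)
            chunks.append(fut_rollouts)
            remaining -= fut_rollouts
        return chunks

    print(chunk_schedule(10, 4))  # [4, 4, 2]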
@@ -134,6 +124,33 @@ def _forecast_tensor(

         return torch.cat(predictions, dim=-1)[..., :prediction_length].to(dtype=torch.float32)

+    def _forecast_single_step(
+        self,
+        context: torch.Tensor,
+        max_context: int,
+        min_context: int,
+        new_patch_count: int = 1,
+    ) -> tuple[torch.Tensor, int]:
+        if context.shape[-1] > max_context:
+            context = context[..., -max_context:]
+        if context.shape[-1] < min_context:
+            pad = torch.full(
+                (context.shape[0], min_context - context.shape[-1]),
+                fill_value=torch.nan,
+                device=context.device,
+                dtype=context.dtype,
+            )
+            context = torch.concat((pad, context), dim=1)
+
+        tokenized_tensor, tokenizer_state = self.tokenizer.context_input_transform(context)
+        prediction, _ = self._forward_model_tokenized(input_token=tokenized_tensor, rollouts=new_patch_count)
+        prediction = prediction[:, :, -new_patch_count:, :].to(tokenized_tensor)  # predicted token
+        # Shape: [bs, num_quantiles, num_predicted_token, output_patch_size]
+        prediction = self.tokenizer.output_transform(prediction, tokenizer_state)
+        prediction = prediction.flatten(start_dim=2)
+
+        return prediction, new_patch_count
+
     def _forward_model_tokenized(
         self,
         input_token: torch.Tensor,
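A standalone sketch of the trim/pad step that `_forecast_single_step` performs before tokenization; `fit_context` is a hypothetical name and the bounds are invented. Left-padding with NaN is safe because the forward pass later replaces NaNs with `self.config.nan_mask_value` (visible in the next hunk):

    import torch

    def fit_context(context: torch.Tensor, max_context: int, min_context: int) -> torch.Tensor:
        if context.shape[-1] > max_context:
            context = context[..., -max_context:]  # keep only the most recent steps
        if context.shape[-1] < min_context:
            pad = torch.full(
                (context.shape[0], min_context - context.shape[-1]),
                fill_value=torch.nan,
                device=context.device,
                dtype=context.dtype,
            )
            context = torch.concat((pad, context), dim=1)  # left-pad with NaNs
        return context

    short = torch.arange(4, dtype=torch.float32).unsqueeze(0)  # [1, 4]
    print(fit_context(short, max_context=64, min_context=8).shape)  # torch.Size([1, 8])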
@@ -165,21 +182,7 @@ def _forward_model_tokenized(

         input_token = torch.nan_to_num(input_token, nan=self.config.nan_mask_value)

-        hidden_states = self.input_patch_embedding(torch.cat((input_token, input_mask), dim=2))
-
-        for block in self.blocks:
-            hidden_states = block(hidden_states)
-
-        hidden_states = self.out_norm(hidden_states)
-
-        quantile_preds = self.output_patch_embedding(hidden_states)
-        quantile_preds = torch.unflatten(
-            quantile_preds, -1, (len(self.config.quantiles), self.config.output_patch_size)
-        )
-        quantile_preds = torch.transpose(quantile_preds, 1, 2)  # switch quantile and num_token dimension
-        # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
-
-        quantile_preds = self._forward_model(torch.cat((input_token, input_mask), dim=2))
+        quantile_preds, hidden_states = self._forward_model(torch.cat((input_token, input_mask), dim=2))

         quantile_preds = torch.unflatten(
             quantile_preds, -1, (len(self.config.quantiles), self.config.output_patch_size)
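For context, a toy version of the head reshaping kept after the `_forward_model` call (sizes invented): the flat head output is split into quantile and patch dimensions, then the quantile and token dimensions are swapped:

    import torch

    batch_size, num_token = 2, 5
    num_quantiles, output_patch_size = 9, 16

    flat = torch.randn(batch_size, num_token, num_quantiles * output_patch_size)
    preds = torch.unflatten(flat, -1, (num_quantiles, output_patch_size))
    preds = torch.transpose(preds, 1, 2)  # [batch_size, num_quantiles, num_token, output_patch_size]
    print(preds.shape)  # torch.Size([2, 9, 5, 16])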
@@ -196,7 +199,7 @@ def _forward_model(self, input: torch.Tensor):

         hidden_states = self.out_norm(hidden_states)

-        return self.output_patch_embedding(hidden_states)
+        return self.output_patch_embedding(hidden_states), hidden_states

     def _interpolate_quantiles(self, predictions: torch.Tensor, quantile_levels: list[float]):
         training_quantile_levels = self.config.quantiles
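Taken together, the last two hunks change `_forward_model` to return the head output and the pre-head hidden states as a pair, so `_forward_model_tokenized` no longer duplicates the backbone pass. A minimal sketch of that contract, with module names mirroring the diff and toy dimensions invented here:

    import torch
    from torch import nn

    class TinyBackbone(nn.Module):
        def __init__(self, d_in: int = 32, d_model: int = 64, d_out: int = 144):
            super().__init__()
            self.input_patch_embedding = nn.Linear(d_in, d_model)
            self.blocks = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(2))
            self.out_norm = nn.LayerNorm(d_model)
            self.output_patch_embedding = nn.Linear(d_model, d_out)

        def _forward_model(self, input: torch.Tensor):
            hidden_states = self.input_patch_embedding(input)
            for block in self.blocks:
                hidden_states = block(hidden_states)
            hidden_states = self.out_norm(hidden_states)
            # Return both the head output and the features that fed it.
            return self.output_patch_embedding(hidden_states), hidden_states

    quantile_preds, hidden_states = TinyBackbone()._forward_model(torch.randn(2, 5, 32))
    print(quantile_preds.shape, hidden_states.shape)  # [2, 5, 144], [2, 5, 64]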