Skip to content

Commit 19a0eb7

Browse files
Fix TCN model input dimension mismatch (#1520)
* transpose dimension 1 and 2 to match nn.Conv1d input * 1.update TCN benchmarks; 2.Emphasize updating the benchmark table; * replace specific version with main --------- Co-authored-by: lijinhui <[email protected]>
1 parent 3704772 commit 19a0eb7

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

examples/benchmarks/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
2626

2727
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
2828
|------------------------------------------|-------------------------------------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
29-
| TCN(Shaojie Bai, et al.) | Alpha158 | 0.0275±0.00 | 0.2157±0.01 | 0.0411±0.00 | 0.3379±0.01 | 0.0190±0.02 | 0.2887±0.27 | -0.1202±0.03 |
29+
| TCN(Shaojie Bai, et al.) | Alpha158 | 0.0279±0.00 | 0.2181±0.01 | 0.0421±0.00 | 0.3429±0.01 | 0.0262±0.02 | 0.4133±0.25 | -0.1090±0.03 |
3030
| TabNet(Sercan O. Arik, et al.) | Alpha158 | 0.0204±0.01 | 0.1554±0.07 | 0.0333±0.00 | 0.2552±0.05 | 0.0227±0.04 | 0.3676±0.54 | -0.1089±0.08 |
3131
| Transformer(Ashish Vaswani, et al.) | Alpha158 | 0.0264±0.00 | 0.2053±0.02 | 0.0407±0.00 | 0.3273±0.02 | 0.0273±0.02 | 0.3970±0.26 | -0.1101±0.02 |
3232
| GRU(Kyunghyun Cho, et al.) | Alpha158(with selected 20 features) | 0.0315±0.00 | 0.2450±0.04 | 0.0428±0.00 | 0.3440±0.03 | 0.0344±0.02 | 0.5160±0.25 | -0.1017±0.02 |
@@ -134,7 +134,7 @@ If you want to contribute your new models, you can follow the steps below.
134134
- `README.md`: a brief introduction to your models
135135
- `workflow_config_<model name>_<dataset>.yaml`: a configuration which can read by `qrun`. You are encouraged to run your model in all datasets.
136136
3. You can integrate your model as a module [in this folder](https://github.com/microsoft/qlib/tree/main/qlib/contrib/model).
137-
4. Please updated your results in the benchmark tables, e.g. [Alpha360](#alpha158-dataset), [Alpha158](#alpha158-dataset)(the values of each metric are the mean and std calculated based on 20 runs with different random seeds, if you don't have enough computational resource, you can ask for help in the PR).
137+
4. Please update your results in the above **Benchmark Tables**, e.g. [Alpha360](#alpha158-dataset), [Alpha158](#alpha158-dataset)(the values of each metric are the mean and std calculated based on **20 Runs** with different random seeds. You can accomplish the above operations through the automated [script](https://github.com/microsoft/qlib/blob/main/examples/run_all_model.py#LL286C22-L286C22) provided by Qlib, and get the final result in the .md file. if you don't have enough computational resource, you can ask for help in the PR).
138138
5. Update the info in the index page in the [news list](https://github.com/microsoft/qlib#newspaper-whats-new----sparkling_heart) and [model list](https://github.com/microsoft/qlib#quant-model-paper-zoo).
139139

140140
Finally, you can send PR for review. ([here is an example](https://github.com/microsoft/qlib/pull/1040))

qlib/contrib/model/pytorch_tcn_ts.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,8 @@ def train_epoch(self, data_loader):
168168
self.TCN_model.train()
169169

170170
for data in data_loader:
171-
feature = data[:, :, 0:-1].to(self.device)
171+
data = torch.transpose(data, 1, 2)
172+
feature = data[:, 0:-1, :].to(self.device)
172173
label = data[:, -1, -1].to(self.device)
173174

174175
pred = self.TCN_model(feature.float())
@@ -187,8 +188,8 @@ def test_epoch(self, data_loader):
187188
losses = []
188189

189190
for data in data_loader:
190-
191-
feature = data[:, :, 0:-1].to(self.device)
191+
data = torch.transpose(data, 1, 2)
192+
feature = data[:, 0:-1, :].to(self.device)
192193
# feature[torch.isnan(feature)] = 0
193194
label = data[:, -1, -1].to(self.device)
194195

0 commit comments

Comments
 (0)