Skip to content

Commit f8a56f2

Browse files
committed
fix issue with minist dataloader
1 parent e3caa2c commit f8a56f2

File tree

5 files changed

+9
-9
lines changed

5 files changed

+9
-9
lines changed

examples/mnist/SOM_LM-SNNs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@
7777

7878
# Determines number of workers to use
7979
if n_workers == -1:
80-
n_workers = torch.cuda.is_available() * 4 * torch.cuda.device_count()
80+
n_workers = 0 # torch.cuda.is_available() * 4 * torch.cuda.device_count()
8181

8282
n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
8383
start_intensity = intensity

examples/mnist/batch_eth_mnist.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
parser.add_argument("--test", dest="train", action="store_false")
5050
parser.add_argument("--plot", dest="plot", action="store_true")
5151
parser.add_argument("--gpu", dest="gpu", action="store_true")
52-
parser.set_defaults(plot=True, gpu=False)
52+
parser.set_defaults(plot=True, gpu=True)
5353

5454
args = parser.parse_args()
5555

@@ -90,7 +90,7 @@
9090

9191
# Determines number of workers to use
9292
if n_workers == -1:
93-
n_workers = gpu * 4 * torch.cuda.device_count()
93+
n_workers = 0 # gpu * 1 * torch.cuda.device_count()
9494

9595
n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
9696
start_intensity = intensity
@@ -116,7 +116,7 @@
116116
dataset = MNIST(
117117
PoissonEncoder(time=time, dt=dt),
118118
None,
119-
root=os.path.join(ROOT_DIR, "data", "MNIST"),
119+
"../../data/MNIST",
120120
download=True,
121121
transform=transforms.Compose(
122122
[transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]

examples/mnist/conv_mnist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
parser.add_argument("--test", dest="train", action="store_false")
4242
parser.add_argument("--plot", dest="plot", action="store_true")
4343
parser.add_argument("--gpu", dest="gpu", action="store_true")
44-
parser.set_defaults(plot=True, gpu=False, train=True)
44+
parser.set_defaults(plot=True, gpu=True, train=True)
4545

4646
args = parser.parse_args()
4747

@@ -164,7 +164,7 @@
164164
start = t()
165165

166166
train_dataloader = torch.utils.data.DataLoader(
167-
train_dataset, batch_size=1, shuffle=True, num_workers=4, pin_memory=gpu
167+
train_dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=gpu
168168
)
169169

170170
for step, batch in enumerate(tqdm(train_dataloader)):

examples/mnist/eth_mnist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
parser.add_argument("--test", dest="train", action="store_false")
4949
parser.add_argument("--plot", dest="plot", action="store_true")
5050
parser.add_argument("--gpu", dest="gpu", action="store_true")
51-
parser.set_defaults(plot=True, gpu=False)
51+
parser.set_defaults(plot=True, gpu=True)
5252

5353
args = parser.parse_args()
5454

@@ -85,7 +85,7 @@
8585

8686
# Determines number of workers to use
8787
if n_workers == -1:
88-
n_workers = gpu * 4 * torch.cuda.device_count()
88+
n_workers = 0 # gpu * 4 * torch.cuda.device_count()
8989

9090
if not train:
9191
update_interval = n_test

examples/mnist/supervised_mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
parser.add_argument("--plot", dest="plot", action="store_true")
4747
parser.add_argument("--gpu", dest="gpu", action="store_true")
4848
parser.add_argument("--device_id", type=int, default=0)
49-
parser.set_defaults(plot=True, gpu=False, train=True)
49+
parser.set_defaults(plot=True, gpu=True, train=True)
5050

5151
args = parser.parse_args()
5252

0 commit comments

Comments
 (0)