-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrandomdataset.lua
More file actions
executable file
·39 lines (34 loc) · 1.1 KB
/
randomdataset.lua
File metadata and controls
executable file
·39 lines (34 loc) · 1.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
local tntenv = require 'torchnet.env'
local tnt = require 'torchnet'
local argcheck = require 'argcheck'
--- tnt.RandomDataset: a tnt.ResampleDataset whose sampler ignores the index
-- it is asked for and instead returns a random index into the wrapped dataset.
-- torch.class yields the new class table and its parent class table; tntenv
-- is passed as the namespace table the class is registered in.
local RandomDataset, ResampleDataset =
   torch.class('tnt.RandomDataset', 'tnt.ResampleDataset', tntenv)

-- Sampler handed to ResampleDataset. The requested index (second argument)
-- is deliberately ignored: every call draws a fresh integer in
-- [1, base:size()] via torch.random, which is why no resample() step is
-- needed between epochs (see the argcheck doc below). It captures no state,
-- so one shared function serves every instance.
local function drawRandomIndex(base, _)
   return torch.random(1, base:size())
end

--- Constructor: wraps `dataset`; `size` (optional) is forwarded unchanged to
-- tnt.ResampleDataset as the reported size of the random view.
RandomDataset.__init = argcheck{
doc = [[
<a name="RandomDataset">
#### tnt.RandomDataset(@ARGP)
@ARGT
`tnt.RandomDataset` is built using ResampleDataset.
Use `:random()` to substitute `:shuffle()` and `:random(size)` to
substitute `:shuffle(size,true)`.
Purpose: uses random sampling so `resample()` does not have to be called
after every epoch. Note: keep seed different if using in Parallel.
]],
   {name='self', type='tnt.RandomDataset'},
   {name='dataset', type='tnt.Dataset'},
   {name='size', type='number', opt=true},
   call = function(self, dataset, size)
      -- Named-argument table for the parent argcheck constructor.
      local parentArgs = {
         dataset = dataset,
         sampler = drawRandomIndex,
         size = size,
      }
      ResampleDataset.__init(self, parentArgs)
   end
}
--- Convenience wrapper so any tnt.Dataset can be written `dataset:random([size])`.
-- All arguments are forwarded untouched to the tnt.RandomDataset constructor,
-- so with method-call syntax the receiver becomes its `dataset` argument.
-- The class is resolved at call time, not at definition time.
function tnt.Dataset.random(...)
   local RandomDatasetClass = tnt.RandomDataset
   return RandomDatasetClass(...)
end