Skip to content
This repository was archived by the owner on Nov 1, 2021. It is now read-only.

Commit 22b229a

Browse files
Authored and committed by Bart van Merriënboer
Add overwriting test, fix indexing test
1 parent db06c97 commit 22b229a

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

src/gradcheck.lua

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
-- Autograd
22
local autograd = require 'autograd'
3+
local util = require 'autograd.util'
34

45
-- Perturbation (finite diffs):
56
local perturbation = 1e-6
@@ -12,20 +13,30 @@ local function jacobianFromAutograd(func, inputs, key)
1213
-- Autograd:
1314
local df = autograd(func)
1415
local grads = df(table.unpack(inputs))
15-
local gradsVerify = df(table.unpack(inputs))
1616

1717
-- Find grad:
1818
local g = autograd.util.nestedGet(grads, key)
19+
local g_clone
20+
if torch.isTensor(g) then
21+
g_clone = g:clone()
22+
end
23+
24+
-- Get the grad again
25+
local gradsVerify = df(table.unpack(inputs))
1926
local gVerify = autograd.util.nestedGet(gradsVerify, key)
2027
local err
28+
local overwrite_err = 0
2129
if torch.isTensor(g) then
2230
err = (g - gVerify):abs():max()
31+
overwrite_err = (g - g_clone):abs():max()
2332
else
2433
err = torch.abs(g - gVerify)
2534
end
2635

2736
if err ~= 0 then
2837
error("autograd gradient not deterministic")
38+
elseif overwrite_err ~= 0 then
39+
error("autograd gradient overwritten when called twice")
2940
end
3041

3142
-- Return grads:

test/test.lua

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,8 +1688,9 @@ local tests = {
16881688
end
16891689
tester:assert(gradcheck(f4,{x=torch.randn(10,10),y=torch.randn(3)}), "Incorrect gradient")
16901690
local f5 = function(params)
1691-
params.x[2] = params.y*2.0
1692-
return torch.sum(params.x)
1691+
local xc = torch.clone(params.x)
1692+
xc[2] = params.y * 2.0
1693+
return torch.sum(xc)
16931694
end
16941695
tester:assert(gradcheck(f5,{x=torch.randn(10,10),y=torch.randn(10)}), "Incorrect gradient")
16951696
end,

0 commit comments

Comments (0)