Skip to content

Commit e9b3f9a

Browse files
authored
Merge branch 'master' into patch-1
2 parents 7482f1b + ad79205 commit e9b3f9a

3 files changed

Lines changed: 61 additions & 4 deletions

File tree

Manifest.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,9 @@ version = "0.1.4"
8181

8282
[[GPUCompiler]]
8383
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"]
84-
git-tree-sha1 = "550bb5127b9b6cf04bb86d72ac37a81a11a204d6"
84+
git-tree-sha1 = "11b2d77f29a85f3649c273a38f6618121c6b1c51"
8585
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
86-
version = "0.20.1"
86+
version = "0.20.2"
8787

8888
[[InteractiveUtils]]
8989
deps = ["Markdown"]

src/random.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,22 @@
11
using Random
22

3-
Random.rand!(A::oneWrappedArray) = Random.rand!(GPUArrays.default_rng(oneArray), A)
4-
Random.randn!(A::oneWrappedArray) = Random.randn!(GPUArrays.default_rng(oneArray), A)
3+
gpuarrays_rng() = GPUArrays.default_rng(oneArray)
4+
5+
# GPUArrays in-place
6+
Random.rand!(A::oneWrappedArray) = Random.rand!(gpuarrays_rng(), A)
7+
Random.randn!(A::oneWrappedArray) = Random.randn!(gpuarrays_rng(), A)
8+
9+
# GPUArrays out-of-place
10+
rand(T::Type, dims::Dims) = Random.rand!(oneArray{T}(undef, dims...))
11+
randn(T::Type, dims::Dims; kwargs...) = Random.randn!(oneArray{T}(undef, dims...); kwargs...)
12+
13+
# support all dimension specifications
14+
rand(T::Type, dim1::Integer, dims::Integer...) = Random.rand!(oneArray{T}(undef, dim1, dims...))
15+
randn(T::Type, dim1::Integer, dims::Integer...; kwargs...) = Random.randn!(oneArray{T}(undef, dim1, dims...); kwargs...)
16+
17+
# untyped out-of-place
18+
rand(dim1::Integer, dims::Integer...) = Random.rand!(oneArray{Float32}(undef, dim1, dims...))
19+
randn(dim1::Integer, dims::Integer...; kwargs...) = Random.randn!(oneArray{Float32}(undef, dim1, dims...); kwargs...)
20+
21+
# seeding
22+
seed!(seed=Base.rand(UInt64)) = Random.seed!(gpuarrays_rng(), seed)

test/random.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using Random
2+
3+
@testset "rand" begin
4+
5+
# in-place
6+
for (f,T) in ((rand!,Float16),
7+
(rand!,Float32),
8+
(randn!,Float16),
9+
(randn!,Float32)),
10+
d in (2, (2,2), (2,2,2), 3, (3,3), (3,3,3))
11+
A = oneArray{T}(undef, d)
12+
fill!(A, T(0))
13+
f(A)
14+
@test !iszero(collect(A))
15+
end
16+
17+
# out-of-place, with implicit type
18+
for (f,T) in ((oneAPI.rand,Float32), (oneAPI.randn,Float32)),
19+
args in ((2,), (2, 2), (3,), (3, 3))
20+
A = f(args...)
21+
@test eltype(A) == T
22+
end
23+
24+
# out-of-place, with type specified
25+
for (f,T) in ((oneAPI.rand,Float32), (oneAPI.randn,Float32),
26+
(rand,Float32), (randn,Float32)),
27+
args in ((T, 2), (T, 2, 2), (T, (2, 2)), (T, 3), (T, 3, 3), (T, (3, 3)))
28+
A = f(args...)
29+
@test eltype(A) == T
30+
end
31+
32+
## seeding
33+
oneAPI.seed!(1)
34+
a = oneAPI.rand(Int32, 1)
35+
oneAPI.seed!(1)
36+
b = oneAPI.rand(Int32, 1)
37+
@test iszero(collect(a) - collect(b))
38+
39+
end # testset

0 commit comments

Comments
 (0)