Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 17 additions & 21 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
name: CI
on:
pull_request:
branches:
- master
push:
branches:
- master
tags: '*'
tags: ['*']
pull_request:
concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
Expand All @@ -15,33 +18,26 @@ jobs:
fail-fast: false
matrix:
version:
- '1.0' # Replace this with the minimum Julia version that your package supports. E.g. if your package requires Julia 1.5 or higher, change this to '1.5'.
- '1' # Leave this line unchanged. '1' will automatically expand to the latest stable 1.x release of Julia.
os: [ubuntu-latest, windows-latest, macOS-latest]
- '1'
os:
- ubuntu-latest
- windows-latest # Add this line to include the latest Windows system
arch:
- x64
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: actions/cache@v1
env:
cache-name: cache-artifacts
with:
path: ~/.julia/artifacts
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
restore-keys: |
${{ runner.os }}-test-${{ env.cache-name }}-
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v1
with:
file: lcov.info
- uses: codecov/codecov-action@v5
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}

docs:
name: Documentation
runs-on: ubuntu-latest
Expand Down
13 changes: 11 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
name = "ShiftedArrays"
uuid = "1277b4bf-5013-50f5-be3d-901d8477a67a"
repo = "https://github.com/JuliaArrays/ShiftedArrays.jl.git"
version = "2.0.0"

[compat]
julia = "1"
CUDA = "5.2, 5.3, 5.4, 5.5, 5.6, 5.7"
Adapt = "3.7, 4.0, 4.1"

[extensions]
CUDASupportExt = ["CUDA", "Adapt"]

[extras]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[targets]
test = ["Test", "AbstractFFTs"]
test = ["Test", "AbstractFFTs", "Random", "CUDA"]
99 changes: 99 additions & 0 deletions ext/CUDASupportExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
module CUDASupportExt
using CUDA
using Adapt
using ShiftedArrays
using Base

get_base_arr(arr::CuArray) = arr
get_base_arr(arr::Array) = arr
function get_base_arr(arr::AbstractArray)
p = parent(arr)
return (p === arr) ? arr : get_base_arr(parent(arr))

Check warning on line 11 in ext/CUDASupportExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/CUDASupportExt.jl#L7-L11

Added lines #L7 - L11 were not covered by tests
end

# define a number of Union types to not repeat all definitions for each type
AllShiftedTypeCu{N, CD} = Union{CircShiftedArray{<:Any,<:Any,<:CuArray{<:Any,N,CD}},
ShiftedArray{<:Any,<:Any,<:Any,<:CuArray{<:Any,N,CD}}}
AllShiftedTypeCuG{N, CD} = Union{AllShiftedTypeCu{N, CD}, CircShiftedArray{<:Any,<:Any,<:AllShiftedTypeCu{N,CD}},
ShiftedArray{<:Any,<:Any,<:Any,<:AllShiftedTypeCu{N,CD}}}
AllSubArrayTypeCu{N, CD} = Union{SubArray{<:Any, <:Any, <:AllShiftedTypeCuG{N,CD}, <:Any, <:Any},
Base.ReshapedArray{<:Any, <:Any, <:AllShiftedTypeCuG{N,CD}, <:Any},
SubArray{<:Any, <:Any, <:Base.ReshapedArray{<:Any, <:Any, <:AllShiftedTypeCuG{N,CD}, <:Any}, <:Any, <:Any}}
AllShiftedAndViewsCu{N, CD} = Union{AllShiftedTypeCuG{N, CD}, AllSubArrayTypeCu{N, CD}}

Adapt.adapt_structure(to, x::CircShiftedArray{T, N, S}) where {T, N, S} = CircShiftedArray(adapt(to, parent(x)), shifts(x));
Adapt.adapt_structure(to, x::ShiftedArray{T, V, N, S}) where {T, V, N, S} = ShiftedArray(adapt(to, parent(x)), shifts(x), default=ShiftedArrays.default(x));

Check warning on line 25 in ext/CUDASupportExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/CUDASupportExt.jl#L24-L25

Added lines #L24 - L25 were not covered by tests

function Base.Broadcast.BroadcastStyle(::Type{T}) where {N, CD, T<:AllShiftedTypeCu{N, CD}}
CUDA.CuArrayStyle{N,CD}()

Check warning on line 28 in ext/CUDASupportExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/CUDASupportExt.jl#L27-L28

Added lines #L27 - L28 were not covered by tests
end

# Define the BroadcastStyle for SubArray of MutableShiftedArray with CuArray

function Base.Broadcast.BroadcastStyle(::Type{T}) where {N, CD, T<:AllSubArrayTypeCu{N, CD}}
CUDA.CuArrayStyle{N,CD}()

Check warning on line 34 in ext/CUDASupportExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/CUDASupportExt.jl#L33-L34

Added lines #L33 - L34 were not covered by tests
end

function Base.copy(s::AllShiftedAndViewsCu)
res = similar(get_base_arr(s), eltype(s), size(s));
res .= s
return res

Check warning on line 40 in ext/CUDASupportExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/CUDASupportExt.jl#L37-L40

Added lines #L37 - L40 were not covered by tests
end

function Base.collect(x::AllShiftedAndViewsCu)
return copy(x) # stay on the GPU

Check warning on line 44 in ext/CUDASupportExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/CUDASupportExt.jl#L43-L44

Added lines #L43 - L44 were not covered by tests
end

function Base.Array(x::AllShiftedAndViewsCu)
return Array(copy(x)) # remove from GPU

Check warning on line 48 in ext/CUDASupportExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/CUDASupportExt.jl#L47-L48

Added lines #L47 - L48 were not covered by tests
end

function Base.:(==)(x::AllShiftedAndViewsCu, y::AbstractArray)
return all(x .== y)

Check warning on line 52 in ext/CUDASupportExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/CUDASupportExt.jl#L51-L52

Added lines #L51 - L52 were not covered by tests
end

function Base.:(==)(y::AbstractArray, x::AllShiftedAndViewsCu)
return all(x .== y)

Check warning on line 56 in ext/CUDASupportExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/CUDASupportExt.jl#L55-L56

Added lines #L55 - L56 were not covered by tests
end

function Base.:(==)(x::AllShiftedAndViewsCu, y::AllShiftedAndViewsCu)
return all(x .== y)

Check warning on line 60 in ext/CUDASupportExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/CUDASupportExt.jl#L59-L60

Added lines #L59 - L60 were not covered by tests
end

function Base.isapprox(x::AllShiftedAndViewsCu, y::AbstractArray; atol=0, rtol=atol>0 ? 0 : sqrt(eps(real(eltype(x)))), va...)
atol = (atol != 0) ? atol : rtol * maximum(abs.(x))
return all(abs.(x .- y) .<= atol)

Check warning on line 65 in ext/CUDASupportExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/CUDASupportExt.jl#L63-L65

Added lines #L63 - L65 were not covered by tests
end

function Base.isapprox(y::AbstractArray, x::AllShiftedAndViewsCu; atol=0, rtol=atol>0 ? 0 : sqrt(eps(real(eltype(x)))), va...)
atol = (atol != 0) ? atol : rtol * maximum(abs.(x))
return all(abs.(x .- y) .<= atol)

Check warning on line 70 in ext/CUDASupportExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/CUDASupportExt.jl#L68-L70

Added lines #L68 - L70 were not covered by tests
end

function Base.isapprox(x::AllShiftedAndViewsCu, y::AllShiftedAndViewsCu; atol=0, rtol=atol>0 ? 0 : sqrt(eps(real(eltype(x)))), va...)
atol = (atol != 0) ? atol : rtol * maximum(abs.(x))
return all(abs.(x .- y) .<= atol)

Check warning on line 75 in ext/CUDASupportExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/CUDASupportExt.jl#L73-L75

Added lines #L73 - L75 were not covered by tests
end

function Base.show(io::IO, mm::MIME"text/plain", cs::AllShiftedAndViewsCu)
CUDA.@allowscalar invoke(Base.show, Tuple{IO, typeof(mm), AbstractArray}, io, mm, cs)

Check warning on line 79 in ext/CUDASupportExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/CUDASupportExt.jl#L78-L79

Added lines #L78 - L79 were not covered by tests
end

# This version is needed to deal with range access of wrapped CuArrays.
# ShiftedVector(cu([1,2,3,4,5]))[2:3]
@inline function Base.getindex(s::AllShiftedTypeCu{N, CD}, x::Vararg{Union{AbstractRange, Int}, N}) where {N, CD}
v = @view s[x...]
res = similar(s.parent, eltype(s), size(v))
res .= v

Check warning on line 87 in ext/CUDASupportExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/CUDASupportExt.jl#L84-L87

Added lines #L84 - L87 were not covered by tests
end

# This specializations are to ensure that true single element accesses generate an error, if allowscalar has not be specified.
@inline function Base.getindex(s::ShiftedArray{A,B,C, <:CuArray{<:Any,N,CD}}, x::Vararg{Int, N}) where {A,B,C, N,CD}
invoke(ShiftedArrays.getindex, Tuple{ShiftedArray{A,B,C,<:AbstractArray}, ntuple((_)->Int, N)...}, s, x...)

Check warning on line 92 in ext/CUDASupportExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/CUDASupportExt.jl#L91-L92

Added lines #L91 - L92 were not covered by tests
end

@inline function Base.getindex(s::CircShiftedArray{A,B,<:CuArray{<:Any,N,CD}}, x::Vararg{Int, N}) where {A,B,N,CD}
invoke(ShiftedArrays.getindex, Tuple{CircShiftedArray{A,B,<:AbstractArray}, ntuple((_)->Int, N)...}, s, x...)

Check warning on line 96 in ext/CUDASupportExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/CUDASupportExt.jl#L95-L96

Added lines #L95 - L96 were not covered by tests
end

end
Loading