From 1cd08b0eaec5fc014d1b37935db9ad61117183b1 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 23 Apr 2025 23:48:12 +0530 Subject: [PATCH] Specialize `one` for the `SizedArray` test helper --- test/abstractarray.jl | 16 ++++++++++++++++ test/testhelpers/SizedArrays.jl | 5 +++++ 2 files changed, 21 insertions(+) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index b882778e4b152..01e13f17460b5 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -2191,6 +2191,22 @@ end @test one(Mat([1 2; 3 4])) == Mat([1 0; 0 1]) @test one(Mat([1 2; 3 4])) isa Mat + + @testset "SizedArray" begin + S = [1 2; 3 4] + A = SizedArrays.SizedArray{(2,2)}(S) + @test one(A) == one(typeof(A)) + @test oneunit(A) == oneunit(typeof(A)) + M = fill(A, 2, 2) + O = one(M) + for I in CartesianIndices(M) + if I[1] == I[2] + @test O[I] == one(S) + else + @test O[I] == zero(S) + end + end + end end @testset "copyto! with non-AbstractArray src" begin diff --git a/test/testhelpers/SizedArrays.jl b/test/testhelpers/SizedArrays.jl index 961784b89ab68..bd0272d78987d 100644 --- a/test/testhelpers/SizedArrays.jl +++ b/test/testhelpers/SizedArrays.jl @@ -54,6 +54,11 @@ Base.axes(a::SizedArray) = map(SOneTo, size(a)) Base.getindex(A::SizedArray, i...) = getindex(A.data, i...) Base.setindex!(A::SizedArray, v, i...) = setindex!(A.data, v, i...) Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T), size(T))) +function Base.one(::Type{SizedMatrix{SZ,T,A}}) where {SZ,T,A} + allequal(SZ) || throw(DimensionMismatch("multiplicative identity defined only for square matrices")) + D = diagm(fill(one(T), SZ[1])) + SizedArray{SZ}(convert(A, D)) +end Base.parent(S::SizedArray) = S.data +(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data) ==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data