Skip to content

Commit 849665a

Browse files
Alex-PLACETCopilot
andauthored
Fix vstack for mixed fixed-shape inputs (#2897)
# Checklist - [x] The title and commit message(s) are descriptive. - [x] Small commits made to fix your PR have been squashed to avoid history pollution. - [x] Tests have been added for new features or bug fixes. - [x] API of new functions and classes are documented. # Description Fix xt::vstack so it works with xtensor_fixed inputs when stacking a mix of 1-D and 2-D fixed-shape expressions. Before this change, vstack could try to build a runtime shape for fixed-shape inputs, which fails for cases like stacking xshape<1, 2> with xshape<2, 2>. The fixed-shape path now computes the stacked result shape at compile time instead. --------- Co-authored-by: Alexis Placet <2400067+Alex-PLACET@users.noreply.github.com> Co-authored-by: Copilot <copilot@github.com>
1 parent 0f06096 commit 849665a

3 files changed

Lines changed: 68 additions & 5 deletions

File tree

include/xtensor/generators/xbuilder.hpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,32 @@ namespace xt
910910

911911
namespace detail
912912
{
913+
template <class S>
914+
struct vstack_fixed_shape_impl;
915+
916+
template <std::size_t N>
917+
struct vstack_fixed_shape_impl<fixed_shape<N>>
918+
{
919+
using type = fixed_shape<1, N>;
920+
};
921+
922+
template <std::size_t I, std::size_t... J>
923+
struct vstack_fixed_shape_impl<fixed_shape<I, J...>>
924+
{
925+
using type = fixed_shape<I, J...>;
926+
};
927+
928+
template <class... CT>
929+
struct vstack_fixed_shape
930+
{
931+
using type = concat_fixed_shape_t<
932+
0,
933+
typename vstack_fixed_shape_impl<typename std::decay_t<CT>::shape_type>::type...>;
934+
};
935+
936+
template <class... CT>
937+
using vstack_fixed_shape_t = typename vstack_fixed_shape<CT...>::type;
938+
913939
template <class S, class... CT>
914940
inline auto vstack_shape(std::tuple<CT...>& t, const S& shape)
915941
{
@@ -948,6 +974,21 @@ namespace xt
948974
return detail::make_xgenerator(detail::vstack_impl<CT...>(std::move(t), size_t(0)), new_shape);
949975
}
950976

977+
/**
978+
* @brief Stack fixed-shape xexpressions in sequence vertically (row wise).
979+
* This overload preserves the result shape at compile time by treating
980+
* 1-D fixed shapes as ``(1, N)`` row vectors before concatenation.
981+
*
982+
* @param t \ref xtuple of fixed-shape xexpressions to stack
983+
* @return xgenerator evaluating to stacked elements with a fixed compile-time shape
984+
*/
985+
template <fixed_shape_container_concept... CT>
986+
inline auto vstack(std::tuple<CT...>&& t)
987+
{
988+
using shape_type = detail::vstack_fixed_shape_t<CT...>;
989+
return detail::make_xgenerator(detail::vstack_impl<CT...>(std::move(t), size_t(0)), shape_type{});
990+
}
991+
951992
namespace detail
952993
{
953994

include/xtensor/views/index_mapper.hpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ namespace xt
193193
* @throws Assertion failure if `i != 0` for integral slices.
194194
* @throws Assertion failure if `i >= slice.size()` for non-integral slices.
195195
*/
196-
template <size_t I, std::integral Index>
196+
template <access_t ACCESS, size_t I, std::integral Index>
197197
size_t map_ith_index(const view_type& view, const Index i) const;
198198

199199
/**
@@ -490,16 +490,16 @@ namespace xt
490490
{
491491
if constexpr (ACCESS == access_t::SAFE)
492492
{
493-
return container.at(map_ith_index<Is>(view, indices[Is])...);
493+
return container.at(map_ith_index<ACCESS, Is>(view, indices[Is])...);
494494
}
495495
else
496496
{
497-
return container(map_ith_index<Is>(view, indices[Is])...);
497+
return container(map_ith_index<ACCESS, Is>(view, indices[Is])...);
498498
}
499499
}
500500

501501
template <class UnderlyingContainer, class... Slices>
502-
template <size_t I, std::integral Index>
502+
template <access_t ACCESS, size_t I, std::integral Index>
503503
auto
504504
index_mapper<xt::xview<UnderlyingContainer, Slices...>>::map_ith_index(const view_type& view, const Index i) const
505505
-> size_t
@@ -518,10 +518,17 @@ namespace xt
518518
assert(i == 0);
519519
return size_t(slice);
520520
}
521+
else if constexpr (xt::detail::is_xall_slice<std::decay_t<current_slice>>::value)
522+
{
523+
return size_t(i);
524+
}
521525
else
522526
{
523527
using slice_size_type = typename current_slice::size_type;
524-
assert(i < slice.size());
528+
if constexpr (ACCESS == access_t::UNSAFE)
529+
{
530+
assert(static_cast<slice_size_type>(i) < slice.size());
531+
}
525532
return size_t(slice(static_cast<slice_size_type>(i)));
526533
}
527534
}

test/test_xbuilder.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,21 @@ namespace xt
424424
ASSERT_TRUE(arange(8) == w1);
425425
ASSERT_TRUE(w1 == w2);
426426
}
427+
428+
TEST(xbuilder, vstack_fixed)
429+
{
430+
xtensor_fixed<float, fixed_shape<1, 2>> a = {{1.f, 2.f}};
431+
xtensor_fixed<float, fixed_shape<2, 2>> b = {{3.f, 4.f}, {5.f, 6.f}};
432+
433+
auto c = vstack(xtuple(a, b));
434+
435+
using expected_shape_t = fixed_shape<3, 2>;
436+
ASSERT_EQ(expected_shape_t{}, c.shape());
437+
EXPECT_EQ(1.f, c(0, 0));
438+
EXPECT_EQ(2.f, c(0, 1));
439+
EXPECT_EQ(3.f, c(1, 0));
440+
EXPECT_EQ(6.f, c(2, 1));
441+
}
427442
#endif
428443

429444
TEST(xbuilder, access)

0 commit comments

Comments
 (0)