@@ -22,13 +22,9 @@ using namespace sparse_tensor;
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
-// TODO: reuse StorageLayout::foreachField?
-
-// TODO: we need COO AoS and SoA
-
 // Convert type range to new types range, with sparse tensors externalized.
-void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
-               SmallVectorImpl<Type> *extraTypes = nullptr) {
+static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
+                      SmallVectorImpl<Type> *extraTypes = nullptr) {
   for (auto type : types) {
     // All "dense" data passes through unmodified.
     if (!getSparseTensorEncoding(type)) {
@@ -42,29 +38,30 @@ void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
     convTypes.push_back(vtp);
     if (extraTypes)
       extraTypes->push_back(vtp);
-    // Convert the external representations of the pos/crd arrays.
-    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
-      const auto lt = stt.getLvlType(lvl);
-      if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
-        auto ptp = RankedTensorType::get(shape, stt.getPosType());
-        auto ctp = RankedTensorType::get(shape, stt.getCrdType());
-        convTypes.push_back(ptp);
-        convTypes.push_back(ctp);
-        if (extraTypes) {
-          extraTypes->push_back(ptp);
-          extraTypes->push_back(ctp);
-        }
-      } else {
-        assert(isDenseLT(lt)); // TODO: handle other cases
+
+    // Convert the external representation of the position/coordinate array.
+    foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
+                                               Type t, FieldIndex,
+                                               SparseTensorFieldKind kind,
+                                               Level, LevelType) {
+      if (kind == SparseTensorFieldKind::CrdMemRef ||
+          kind == SparseTensorFieldKind::PosMemRef) {
+        ShapedType st = t.cast<ShapedType>();
+        auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
+        convTypes.push_back(rtp);
+        if (extraTypes)
+          extraTypes->push_back(rtp);
       }
-    }
+      return true;
+    });
   }
 }
 
 // Convert input and output values to [dis]assemble ops for sparse tensors.
-void convVals(OpBuilder &builder, Location loc, TypeRange types,
-              ValueRange fromVals, ValueRange extraVals,
-              SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn) {
+static void convVals(OpBuilder &builder, Location loc, TypeRange types,
+                     ValueRange fromVals, ValueRange extraVals,
+                     SmallVectorImpl<Value> &toVals, unsigned extra,
+                     bool isIn) {
   unsigned idx = 0;
   for (auto type : types) {
     // All "dense" data passes through unmodified.
@@ -85,29 +82,28 @@ void convVals(OpBuilder &builder, Location loc, TypeRange types,
     if (!isIn) {
       inputs.push_back(extraVals[extra++]);
       retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
-      cntTypes.push_back(builder.getIndexType());
+      cntTypes.push_back(builder.getIndexType()); // nnz
     }
+
     // Collect the external representations of the pos/crd arrays.
-    for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
-      const auto lt = stt.getLvlType(lvl);
-      if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
+    foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
+                                                     SparseTensorFieldKind kind,
+                                                     Level, LevelType) {
+      if (kind == SparseTensorFieldKind::CrdMemRef ||
+          kind == SparseTensorFieldKind::PosMemRef) {
         if (isIn) {
           inputs.push_back(fromVals[idx++]);
-          inputs.push_back(fromVals[idx++]);
         } else {
-          Type pTp = stt.getPosType();
-          Type cTp = stt.getCrdType();
-          inputs.push_back(extraVals[extra++]);
+          ShapedType st = t.cast<ShapedType>();
+          auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
           inputs.push_back(extraVals[extra++]);
-          retTypes.push_back(RankedTensorType::get(shape, pTp));
-          retTypes.push_back(RankedTensorType::get(shape, cTp));
-          cntTypes.push_back(pTp);
-          cntTypes.push_back(cTp);
+          retTypes.push_back(rtp);
+          cntTypes.push_back(rtp.getElementType());
         }
-      } else {
-        assert(isDenseLT(lt)); // TODO: handle other cases
       }
-    }
+      return true;
+    });
+
     if (isIn) {
       // Assemble multiple inputs into a single sparse tensor.
       auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
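
Side note (not part of this commit): routing the per-field work through foreachFieldAndTypeInSparseTensor keeps the assembler in sync with the storage layout instead of hard-coding the per-level cases, which is what the removed TODOs about StorageLayout::foreachField and COO AoS/SoA were asking for. The sketch below shows roughly how convTypes could be composed to build an externalized function signature; the wrapper name getExternalizedFuncType and the include paths are illustrative assumptions, not code from the patch.

// Illustrative sketch only: derive an externalized signature with convTypes.
// Helper name and includes are assumptions made for this example.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinTypes.h"

static mlir::FunctionType getExternalizedFuncType(mlir::func::FuncOp funcOp) {
  using namespace mlir;
  SmallVector<Type> inputTypes, outputTypes, extraTypes;
  // Each sparse argument expands into its values tensor plus one tensor per
  // position/coordinate field visited by foreachFieldAndTypeInSparseTensor.
  convTypes(funcOp.getArgumentTypes(), inputTypes);
  // Sparse results expand the same way; their external buffers are collected
  // in extraTypes so they can be appended as trailing inputs.
  convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes);
  inputTypes.append(extraTypes.begin(), extraTypes.end());
  return FunctionType::get(funcOp.getContext(), inputTypes, outputTypes);
}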