Skip to content

Commit c91f47e

Browse files
author
Henry Linjamäki
committed
Eliminate redundant kernel argument copies
... within hipLaunchKernel() call. Along the way, fix CHIPGraphNodeKernel instances not copying kernel arguments fully (they only copied the pointers to the arguments, not the argument values).
1 parent e847890 commit c91f47e

9 files changed

Lines changed: 88 additions & 47 deletions

File tree

src/CHIPBackend.cc

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ static void queueKernel(chipstar::Queue *Q, chipstar::Kernel *K,
3434
::Backend->createExecItem(GridDim, BlockDim, SharedMemSize, Q);
3535
EI->setKernel(K);
3636

37-
EI->copyArgs(Args);
37+
EI->setArgs(Args);
3838
EI->setupAllArgs();
3939

4040
auto ChipQueue = EI->getQueue();
@@ -497,11 +497,6 @@ void *chipstar::ArgSpillBuffer ::allocate(const SPVFuncInfo::Arg &Arg) {
497497

498498
// ExecItem
499499
//*************************************************************************************
500-
void chipstar::ExecItem::copyArgs(void **Args) {
501-
for (int i = 0; i < getNumArgs(); i++) {
502-
Args_.push_back(Args[i]);
503-
}
504-
}
505500

506501
chipstar::ExecItem::ExecItem(dim3 GridDim, dim3 BlockDim, size_t SharedMem,
507502
hipStream_t ChipQueue)
@@ -1890,7 +1885,7 @@ void chipstar::Queue::launchKernel(chipstar::Kernel *ChipKernel, dim3 NumBlocks,
18901885
chipstar::ExecItem *ExItem =
18911886
::Backend->createExecItem(NumBlocks, DimBlocks, SharedMemBytes, this);
18921887
ExItem->setKernel(ChipKernel);
1893-
ExItem->copyArgs(Args);
1888+
ExItem->setArgs(Args);
18941889
ExItem->setupAllArgs();
18951890
launch(ExItem);
18961891
delete ExItem;

src/CHIPBackend.hh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,12 +1166,12 @@ protected:
11661166

11671167
chipstar::Queue *ChipQueue_;
11681168

1169-
std::vector<void *> Args_;
1169+
void **Args_;
11701170

11711171
std::shared_ptr<chipstar::ArgSpillBuffer> ArgSpillBuffer_;
11721172

11731173
public:
1174-
void copyArgs(void **Args);
1174+
void setArgs(void **Args) { Args_ = Args; }
11751175
void setQueue(chipstar::Queue *Queue) { ChipQueue_ = Queue; }
11761176
std::mutex ExecItemMtx;
11771177
size_t getNumArgs() {
@@ -1183,7 +1183,7 @@ public:
11831183
/**
11841184
* @brief Return argument list.
11851185
*/
1186-
const std::vector<void *> &getArgs() const { return Args_; }
1186+
void **getArgs() const { return Args_; }
11871187

11881188
/**
11891189
* @brief Deleted default constructor

src/CHIPGraph.cc

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,18 +125,22 @@ CHIPGraphNodeKernel::CHIPGraphNodeKernel(const hipKernelNodeParams *TheParams)
125125
Params_.extra = TheParams->extra;
126126
Params_.func = TheParams->func;
127127
Params_.gridDim = TheParams->gridDim;
128-
Params_.kernelParams = TheParams->kernelParams;
129128
Params_.sharedMemBytes = TheParams->sharedMemBytes;
129+
130130
auto Dev = Backend->getActiveDevice();
131131
chipstar::Kernel *ChipKernel = Dev->findKernel(HostPtr(Params_.func));
132132
if (!ChipKernel)
133133
CHIPERR_LOG_AND_THROW("Could not find requested kernel",
134134
hipErrorInvalidDeviceFunction);
135+
136+
copyKernelArgs(ArgList_, ArgData_, TheParams->kernelParams,
137+
*ChipKernel->getFuncInfo());
138+
Params_.kernelParams = ArgList_.data();
139+
135140
ExecItem_ = Backend->createExecItem(Params_.gridDim, Params_.blockDim,
136141
Params_.sharedMemBytes, nullptr);
137142
ExecItem_->setKernel(ChipKernel);
138-
139-
ExecItem_->copyArgs(TheParams->kernelParams);
143+
ExecItem_->setArgs(TheParams->kernelParams);
140144
ExecItem_->setupAllArgs();
141145
}
142146

@@ -149,18 +153,20 @@ CHIPGraphNodeKernel::CHIPGraphNodeKernel(const void *HostFunction, dim3 GridDim,
149153
Params_.extra = nullptr;
150154
Params_.func = const_cast<void *>(HostFunction);
151155
Params_.gridDim = GridDim;
152-
Params_.kernelParams = Args;
153156
Params_.sharedMemBytes = SharedMem;
154157

155158
auto Dev = Backend->getActiveDevice();
156159
chipstar::Kernel *ChipKernel = Dev->findKernel(HostPtr(HostFunction));
157160
if (!ChipKernel)
158161
CHIPERR_LOG_AND_THROW("Could not find requested kernel",
159162
hipErrorInvalidDeviceFunction);
163+
164+
copyKernelArgs(ArgList_, ArgData_, Args, *ChipKernel->getFuncInfo());
165+
Params_.kernelParams = ArgList_.data();
166+
160167
ExecItem_ = Backend->createExecItem(GridDim, BlockDim, SharedMem, nullptr);
161168
ExecItem_->setKernel(ChipKernel);
162-
163-
ExecItem_->copyArgs(Args);
169+
ExecItem_->setArgs(Params_.kernelParams);
164170
ExecItem_->setupAllArgs();
165171
}
166172

src/CHIPGraph.hh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,12 @@ public:
261261

262262
class CHIPGraphNodeKernel : public CHIPGraphNode {
263263
private:
264+
/// A block holding the bytes of the kernel arguments.
265+
std::vector<char> ArgData_;
266+
267+
/// Pointers to the start of the kernel argument data, one for each kernel argument.
268+
std::vector<void *> ArgList_;
269+
264270
hipKernelNodeParams Params_;
265271
chipstar::ExecItem *ExecItem_;
266272

src/SPIRVFuncInfo.cc

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ std::string_view SPVFuncInfo::Arg::getKindAsString() const {
8585
}
8686

8787
/// Client side kernel argument visitor.
88-
void SPVFuncInfo::visitClientArgsImpl(const std::vector<void *> &ClientArgList,
88+
void SPVFuncInfo::visitClientArgsImpl(void **ClientArgList,
8989
ClientArgVisitor Visitor) const {
9090

9191
unsigned ArgListIndex = 0;
@@ -113,12 +113,10 @@ void SPVFuncInfo::visitClientArgsImpl(const std::vector<void *> &ClientArgList,
113113
// Image argument replaced hipTextureObject_t argument.
114114
ArgKind = SPVTypeKind::Pointer;
115115

116-
auto *ArgData =
117-
ClientArgList.empty() ? nullptr : ClientArgList[ArgListIndex];
116+
auto *ArgData = !ClientArgList ? nullptr : ClientArgList[ArgListIndex];
118117

119118
// Clang generated argument list should not have nullptrs in it.
120-
assert((ClientArgList.empty() || ArgData) &&
121-
"nullptr in the argument list");
119+
assert((!ClientArgList || ArgData) && "nullptr in the argument list");
122120

123121
ClientArg CArg{
124122
{{ArgKind, ArgTI.StorageClass, ArgSize}, ArgListIndex, ArgData}};
@@ -128,19 +126,18 @@ void SPVFuncInfo::visitClientArgsImpl(const std::vector<void *> &ClientArgList,
128126
}
129127

130128
/// Visit client-visible kernel arguments
131-
void SPVFuncInfo::visitClientArgs(const std::vector<void *> &ClientArgList,
129+
void SPVFuncInfo::visitClientArgs(void **ClientArgList,
132130
ClientArgVisitor Visitor) const {
133-
assert(ClientArgList.size() == getNumClientArgs());
134131
visitClientArgsImpl(ClientArgList, Visitor);
135132
}
136133

137134
/// Visit client-visible kernel arguments without the argument value
138135
/// (Arg::Data will be nullptr).
139136
void SPVFuncInfo::visitClientArgs(ClientArgVisitor Visitor) const {
140-
visitClientArgsImpl(std::vector<void *>(), Visitor);
137+
visitClientArgsImpl(nullptr, Visitor);
141138
}
142139

143-
void SPVFuncInfo::visitKernelArgsImpl(const std::vector<void *> &ClientArgList,
140+
void SPVFuncInfo::visitKernelArgsImpl(void **ClientArgList,
144141
KernelArgVisitor Visitor) const {
145142
unsigned ArgIndex = 0;
146143
unsigned ArgListIndex = 0;
@@ -156,7 +153,7 @@ void SPVFuncInfo::visitKernelArgsImpl(const std::vector<void *> &ClientArgList,
156153
ArgListIndex--;
157154

158155
const void *ArgData = nullptr;
159-
if (!ClientArgList.empty() && !ArgTI.isWorkgroupPtr()) {
156+
if (ClientArgList && !ArgTI.isWorkgroupPtr()) {
160157
ArgData = ClientArgList[ArgListIndex];
161158

162159
// Clang generated argument list should not have nullptrs in it.
@@ -172,15 +169,14 @@ void SPVFuncInfo::visitKernelArgsImpl(const std::vector<void *> &ClientArgList,
172169
}
173170

174171
// Visit kernel arguments
175-
void SPVFuncInfo::visitKernelArgs(const std::vector<void *> &ClientArgList,
172+
void SPVFuncInfo::visitKernelArgs(void **ClientArgList,
176173
KernelArgVisitor Visitor) const {
177-
assert(ClientArgList.size() == getNumClientArgs());
178174
visitKernelArgsImpl(ClientArgList, Visitor);
179175
}
180176

181177
/// Visit kernel arguments without argument list (Arg::Data will be nullptr)
182178
void SPVFuncInfo::visitKernelArgs(KernelArgVisitor Visitor) const {
183-
visitKernelArgsImpl(std::vector<void *>(), Visitor);
179+
visitKernelArgsImpl(nullptr, Visitor);
184180
}
185181

186182
/// Return HIP user visible kernel argument count.

src/SPIRVFuncInfo.hh

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,9 @@ public:
101101
SPVFuncInfo(const std::vector<SPVArgTypeInfo> &Info) : ArgTypeInfo_(Info) {}
102102

103103
void visitClientArgs(ClientArgVisitor Fn) const;
104-
void visitClientArgs(const std::vector<void *> &ArgList,
105-
ClientArgVisitor Fn) const;
104+
void visitClientArgs(void **ArgList, ClientArgVisitor Fn) const;
106105
void visitKernelArgs(KernelArgVisitor Fn) const;
107-
void visitKernelArgs(const std::vector<void *> &ArgList,
108-
KernelArgVisitor Fn) const;
106+
void visitKernelArgs(void **ArgList, KernelArgVisitor Fn) const;
109107

110108
/// Return visible kernel argument count.
111109
///
@@ -120,10 +118,8 @@ public:
120118
bool hasByRefArgs() const noexcept { return HasByRefArgs_; }
121119

122120
private:
123-
void visitClientArgsImpl(const std::vector<void *> &ArgList,
124-
ClientArgVisitor Fn) const;
125-
void visitKernelArgsImpl(const std::vector<void *> &ArgList,
126-
KernelArgVisitor Fn) const;
121+
void visitClientArgsImpl(void **ArgList, ClientArgVisitor Fn) const;
122+
void visitKernelArgsImpl(void **ArgList, KernelArgVisitor Fn) const;
127123
};
128124

129125
typedef std::map<int32_t, std::shared_ptr<SPVFuncInfo>> SPVFuncInfoMap;

src/Utils.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,3 +259,41 @@ bool startsWith(std::string_view Str, std::string_view WithStr) {
259259
return Str.size() >= WithStr.size() &&
260260
Str.substr(0, WithStr.size()) == WithStr;
261261
}
262+
263+
/// Deep copies kernel arguments pointed to by 'CopyFrom'. Bytes of the
264+
/// argument values are stored in 'ArgData'. 'ArgList[I]' points to
265+
/// the argument value in 'ArgData' for Ith kernel argument.
266+
void copyKernelArgs(std::vector<void *> &ArgList, std::vector<char> &ArgData,
267+
void **CopyFrom, const SPVFuncInfo &FuncInfo) {
268+
269+
ArgList.clear();
270+
ArgData.clear();
271+
272+
std::vector<size_t> Offsets;
273+
size_t CurrOffset = 0;
274+
275+
auto CopyArgData = [&](const SPVFuncInfo::ClientArg &Arg) {
276+
assert((Arg.Kind == SPVTypeKind::POD || Arg.Kind == SPVTypeKind::Pointer) &&
277+
"Unexpected argument kind.");
278+
279+
size_t Size = Arg.Size;
280+
size_t Alignment = roundUpToPowerOfTwo(Size);
281+
assert(Size && Alignment && "Invalid arg size or alignment!");
282+
283+
CurrOffset = roundUp(CurrOffset, Alignment);
284+
logDebug("arg {} tgt offset: {}", Arg.Index, CurrOffset);
285+
Offsets.push_back(CurrOffset);
286+
assert(CurrOffset >= ArgData.size());
287+
288+
ArgData.resize(CurrOffset + Size, 0);
289+
std::memcpy(ArgData.data() + CurrOffset, Arg.Data, Size);
290+
291+
CurrOffset += Size;
292+
};
293+
FuncInfo.visitClientArgs(CopyFrom, CopyArgData);
294+
295+
ArgList.reserve(Offsets.size());
296+
char *BasePtr = ArgData.data();
297+
for (auto Offset : Offsets)
298+
ArgList.push_back(static_cast<void *>(BasePtr + Offset));
299+
}

src/Utils.hh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,4 +144,7 @@ template <class T> struct PointerCmp {
144144
}
145145
};
146146

147+
void copyKernelArgs(std::vector<void *> &ArgList, std::vector<char> &ArgData,
148+
void **CopyFrom, const SPVFuncInfo &FuncInfo);
149+
147150
#endif

tests/runtime/TestArgVisitors.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ int main() {
4949

5050
// Simulate client-side arguments.
5151
int a, b, c, d, e, f;
52-
std::vector<void *> ArgList{&a, &b, &c, &d, &e, &f};
52+
std::vector<void *> ArgListVec{&a, &b, &c, &d, &e, &f};
53+
void **ArgList = static_cast<void **>(ArgListVec.data());
5354

5455
// Test visitors.
5556

@@ -66,7 +67,7 @@ int main() {
6667
ArgIdx = 0;
6768
FI.visitClientArgs(ArgList, [&](const SPVFuncInfo::ClientArg &Arg) {
6869
assert(Arg.Index == ArgIdx++);
69-
assert(Arg.Data == ArgList.at(Arg.Index));
70+
assert(Arg.Data == ArgListVec.at(Arg.Index));
7071
if (Arg.Index == 0)
7172
assert(Arg.Kind == SPVTypeKind::Pointer);
7273
else if (Arg.Index == 1)
@@ -101,28 +102,28 @@ int main() {
101102

102103
if (Arg.Index == 0) {
103104
assert(Arg.Kind == SPVTypeKind::Pointer);
104-
assert(Arg.Data == ArgList.at(0));
105+
assert(Arg.Data == ArgListVec.at(0));
105106
} else if (Arg.Index == 1) {
106107
assert(Arg.Kind == SPVTypeKind::Image);
107-
assert(Arg.Data == ArgList.at(1));
108+
assert(Arg.Data == ArgListVec.at(1));
108109
} else if (Arg.Index == 2) {
109110
assert(Arg.Kind == SPVTypeKind::Sampler);
110-
assert(Arg.Data == ArgList.at(1));
111+
assert(Arg.Data == ArgListVec.at(1));
111112
} else if (Arg.Index == 3) {
112113
assert(Arg.Kind == SPVTypeKind::POD);
113-
assert(Arg.Data == ArgList.at(2));
114+
assert(Arg.Data == ArgListVec.at(2));
114115
} else if (Arg.Index == 4) {
115116
assert(Arg.Kind == SPVTypeKind::Image);
116-
assert(Arg.Data == ArgList.at(3));
117+
assert(Arg.Data == ArgListVec.at(3));
117118
} else if (Arg.Index == 5) {
118119
assert(Arg.Kind == SPVTypeKind::Sampler);
119-
assert(Arg.Data == ArgList.at(3));
120+
assert(Arg.Data == ArgListVec.at(3));
120121
} else if (Arg.Index == 6) {
121122
assert(Arg.Kind == SPVTypeKind::POD);
122-
assert(Arg.Data == ArgList.at(4));
123+
assert(Arg.Data == ArgListVec.at(4));
123124
} else if (Arg.Index == 7) {
124125
assert(Arg.Kind == SPVTypeKind::PODByRef);
125-
assert(Arg.Data == ArgList.at(5));
126+
assert(Arg.Data == ArgListVec.at(5));
126127
} else if (Arg.Index == 8) {
127128
assert(Arg.Kind == SPVTypeKind::Pointer);
128129
assert(Arg.isWorkgroupPtr());

0 commit comments

Comments
 (0)