@@ -92,6 +92,9 @@ class SPIRVEmitIntrinsics
9292 void insertPtrCastOrAssignTypeInstr (Instruction *I, IRBuilder<> &B);
9393 void processGlobalValue (GlobalVariable &GV, IRBuilder<> &B);
9494 void processParamTypes (Function *F, IRBuilder<> &B);
95+ Type *deduceFunParamType (Function *F, unsigned OpIdx);
96+ Type *deduceFunParamType (Function *F, unsigned OpIdx,
97+ std::unordered_set<Function *> &FVisited);
9598
9699public:
97100 static char ID;
@@ -169,6 +172,10 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
169172static Type *deduceElementTypeHelper (Value *I,
170173 std::unordered_set<Value *> &Visited,
171174 DenseMap<Value *, Type *> &DeducedElTys) {
175+ // allow to pass nullptr as an argument
176+ if (!I)
177+ return nullptr ;
178+
172179 // maybe already known
173180 auto It = DeducedElTys.find (I);
174181 if (It != DeducedElTys.end ())
@@ -182,15 +189,20 @@ static Type *deduceElementTypeHelper(Value *I,
182189 // fallback value in case when we fail to deduce a type
183190 Type *Ty = nullptr ;
184191 // look for known basic patterns of type inference
185- if (auto *Ref = dyn_cast<AllocaInst>(I))
192+ if (auto *Ref = dyn_cast<AllocaInst>(I)) {
186193 Ty = Ref->getAllocatedType ();
187- else if (auto *Ref = dyn_cast<GetElementPtrInst>(I))
194+ } else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
188195 Ty = Ref->getResultElementType ();
189- else if (auto *Ref = dyn_cast<GlobalValue>(I))
196+ } else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
190197 Ty = Ref->getValueType ();
191- else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I))
198+ } else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
192199 Ty = deduceElementTypeHelper (Ref->getPointerOperand (), Visited,
193200 DeducedElTys);
201+ } else if (auto *Ref = dyn_cast<BitCastInst>(I)) {
202+ if (Type *Src = Ref->getSrcTy (), *Dest = Ref->getDestTy ();
203+ isPointerTy (Src) && isPointerTy (Dest))
204+ Ty = deduceElementTypeHelper (Ref->getOperand (0 ), Visited, DeducedElTys);
205+ }
194206
195207 // remember the found relationship
196208 if (Ty)
@@ -795,61 +807,80 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
795807 }
796808}
797809
798- void SPIRVEmitIntrinsics::processParamTypes (Function *F, IRBuilder<> &B) {
799- DenseMap<unsigned , Argument *> Args;
800- unsigned i = 0 ;
801- for (Argument &Arg : F->args ()) {
802- if (isUntypedPointerTy (Arg.getType ()) &&
803- DeducedElTys.find (&Arg) == DeducedElTys.end () &&
804- !HasPointeeTypeAttr (&Arg))
805- Args[i] = &Arg;
806- i++;
807- }
808- if (Args.size () == 0 )
809- return ;
810+ Type *SPIRVEmitIntrinsics::deduceFunParamType (Function *F, unsigned OpIdx) {
811+ std::unordered_set<Function *> FVisited;
812+ return deduceFunParamType (F, OpIdx, FVisited);
813+ }
814+
815+ Type *SPIRVEmitIntrinsics::deduceFunParamType (
816+ Function *F, unsigned OpIdx, std::unordered_set<Function *> &FVisited) {
817+ // maybe a cycle
818+ if (FVisited.find (F) != FVisited.end ())
819+ return nullptr ;
820+ FVisited.insert (F);
810821
811- // Args contains opaque pointers without element type definition
812- B.SetInsertPointPastAllocas (F);
813822 std::unordered_set<Value *> Visited;
823+ SmallVector<std::pair<Function *, unsigned >> Lookup;
824+ // search in function's call sites
814825 for (User *U : F->users ()) {
815826 CallInst *CI = dyn_cast<CallInst>(U);
816- if (!CI)
827+ if (!CI || OpIdx >= CI-> arg_size () )
817828 continue ;
818- for (unsigned OpIdx = 0 ; OpIdx < CI->arg_size () && Args.size () > 0 ;
819- OpIdx++) {
820- auto It = Args.find (OpIdx);
821- Argument *Arg = It == Args.end () ? nullptr : It->second ;
822- if (!Arg)
823- continue ;
824- Value *OpArg = CI->getArgOperand (OpIdx);
825- if (!isPointerTy (OpArg->getType ()))
829+ Value *OpArg = CI->getArgOperand (OpIdx);
830+ if (!isPointerTy (OpArg->getType ()))
831+ continue ;
832+ // maybe we already know operand's element type
833+ if (auto It = DeducedElTys.find (OpArg); It != DeducedElTys.end ())
834+ return It->second ;
835+ // search in actual parameter's users
836+ for (User *OpU : OpArg->users ()) {
837+ Instruction *Inst = dyn_cast<Instruction>(OpU);
838+ if (!Inst || Inst == CI)
826839 continue ;
827- // maybe we already know the operand's element type
828- auto DeducedIt = DeducedElTys.find (OpArg);
829- Type *ElemTy =
830- DeducedIt == DeducedElTys.end () ? nullptr : DeducedIt->second ;
831- if (!ElemTy) {
832- for (User *OpU : OpArg->users ()) {
833- if (Instruction *Inst = dyn_cast<Instruction>(OpU)) {
834- Visited.clear ();
835- ElemTy = deduceElementTypeHelper (Inst, Visited, DeducedElTys);
836- if (ElemTy)
837- break ;
838- }
839- }
840+ Visited.clear ();
841+ if (Type *Ty = deduceElementTypeHelper (Inst, Visited, DeducedElTys))
842+ return Ty;
843+ }
844+ // check if it's a formal parameter of the outer function
845+ if (!CI->getParent () || !CI->getParent ()->getParent ())
846+ continue ;
847+ Function *OuterF = CI->getParent ()->getParent ();
848+ if (FVisited.find (OuterF) != FVisited.end ())
849+ continue ;
850+ for (unsigned i = 0 ; i < OuterF->arg_size (); ++i) {
851+ if (OuterF->getArg (i) == OpArg) {
852+ Lookup.push_back (std::make_pair (OuterF, i));
853+ break ;
840854 }
841- if (ElemTy) {
842- unsigned AddressSpace = getPointerAddressSpace (Arg->getType ());
855+ }
856+ }
857+
858+ // search in function parameters
859+ for (auto &Pair : Lookup) {
860+ if (Type *Ty = deduceFunParamType (Pair.first , Pair.second , FVisited))
861+ return Ty;
862+ }
863+
864+ return nullptr ;
865+ }
866+
867+ void SPIRVEmitIntrinsics::processParamTypes (Function *F, IRBuilder<> &B) {
868+ B.SetInsertPointPastAllocas (F);
869+ DenseMap<Argument *, Type *> Args;
870+ for (unsigned OpIdx = 0 ; OpIdx < F->arg_size (); ++OpIdx) {
871+ Argument *Arg = F->getArg (OpIdx);
872+ if (isUntypedPointerTy (Arg->getType ()) &&
873+ DeducedElTys.find (Arg) == DeducedElTys.end () &&
874+ !HasPointeeTypeAttr (Arg)) {
875+ if (Type *ElemTy = deduceFunParamType (F, OpIdx)) {
843876 CallInst *AssignPtrTyCI = buildIntrWithMD (
844877 Intrinsic::spv_assign_ptr_type, {Arg->getType ()},
845- Constant::getNullValue (ElemTy), Arg, {B.getInt32 (AddressSpace)}, B);
878+ Constant::getNullValue (ElemTy), Arg,
879+ {B.getInt32 (getPointerAddressSpace (Arg->getType ()))}, B);
846880 DeducedElTys[AssignPtrTyCI] = ElemTy;
847881 DeducedElTys[Arg] = ElemTy;
848- Args.erase (It);
849882 }
850883 }
851- if (Args.size () == 0 )
852- break ;
853884 }
854885}
855886
0 commit comments