Skip to content

Commit cd69092

Browse files
committed
internal/cabi: fix llvm.alloca for callInsrt
1 parent 6de3bdc commit cd69092

2 files changed

Lines changed: 75 additions & 15 deletions

File tree

_demo/cabisret/main.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package main
2+
3+
type array9 struct {
4+
x [9]float32
5+
}
6+
7+
func demo1(a array9) array9 {
8+
a.x[0] += 1
9+
return a
10+
}
11+
12+
func demo2(a array9) array9 {
13+
for i := 0; i < 1024*128; i++ {
14+
a = demo1(a)
15+
}
16+
return a
17+
}
18+
19+
func testDemo() {
20+
ar := array9{x: [9]float32{1, 2, 3, 4, 5, 6, 7, 8, 9}}
21+
for i := 0; i < 1024*128; i++ {
22+
ar = demo1(ar)
23+
}
24+
ar = demo2(ar)
25+
println(ar.x[0], ar.x[1])
26+
}
27+
28+
func testSlice() {
29+
var b []byte
30+
for i := 0; i < 1024*128; i++ {
31+
b = append(b, byte(i))
32+
}
33+
_ = b
34+
}
35+
36+
func main() {
37+
testDemo()
38+
testSlice()
39+
}

internal/cabi/cabi.go

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,15 @@ func (p *Transformer) isCFunc(name string) bool {
5454
return !strings.Contains(name, ".")
5555
}
5656

57+
type CallInstr struct {
58+
call llvm.Value
59+
fn llvm.Value
60+
}
61+
5762
func (p *Transformer) TransformModule(path string, m llvm.Module) {
5863
ctx := m.Context()
5964
var fns []llvm.Value
60-
var callInstrs []llvm.Value
65+
var callInstrs []CallInstr
6166
switch p.mode {
6267
case ModeNone:
6368
return
@@ -66,16 +71,22 @@ func (p *Transformer) TransformModule(path string, m llvm.Module) {
6671
for !fn.IsNil() {
6772
if p.isCFunc(fn.Name()) {
6873
p.transformFuncCall(m, fn)
69-
if p.isWrapFunctionType(m.Context(), fn.GlobalValueType()) {
74+
if p.isWrapFunctionType(ctx, fn.GlobalValueType()) {
7075
fns = append(fns, fn)
71-
use := fn.FirstUse()
72-
for !use.IsNil() {
73-
if call := use.User().IsACallInst(); !call.IsNil() && call.CalledValue() == fn {
74-
callInstrs = append(callInstrs, call)
76+
}
77+
}
78+
bb := fn.FirstBasicBlock()
79+
for !bb.IsNil() {
80+
instr := bb.FirstInstruction()
81+
for !instr.IsNil() {
82+
if call := instr.IsACallInst(); !call.IsNil() && p.isCFunc(call.CalledValue().Name()) {
83+
if p.isWrapFunctionType(ctx, call.CalledFunctionType()) {
84+
callInstrs = append(callInstrs, CallInstr{call, fn})
7585
}
76-
use = use.NextUse()
7786
}
87+
instr = llvm.NextInstruction(instr)
7888
}
89+
bb = llvm.NextBasicBlock(bb)
7990
}
8091
fn = llvm.NextFunction(fn)
8192
}
@@ -91,7 +102,7 @@ func (p *Transformer) TransformModule(path string, m llvm.Module) {
91102
for !instr.IsNil() {
92103
if call := instr.IsACallInst(); !call.IsNil() {
93104
if p.isWrapFunctionType(ctx, call.CalledFunctionType()) {
94-
callInstrs = append(callInstrs, call)
105+
callInstrs = append(callInstrs, CallInstr{call, fn})
95106
}
96107
}
97108
instr = llvm.NextInstruction(instr)
@@ -102,7 +113,7 @@ func (p *Transformer) TransformModule(path string, m llvm.Module) {
102113
}
103114
}
104115
for _, call := range callInstrs {
105-
p.transformCallInstr(ctx, call)
116+
p.transformCallInstr(ctx, call.call, call.fn)
106117
}
107118
for _, fn := range fns {
108119
p.transformFunc(m, fn)
@@ -369,6 +380,7 @@ func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llv
369380
fn.Param(i).ReplaceAllUsesWith(nv)
370381
index++
371382
}
383+
372384
if info.Return.Kind >= AttrPointer {
373385
var retInstrs []llvm.Value
374386
bb := nfn.FirstBasicBlock()
@@ -402,7 +414,7 @@ func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llv
402414
}
403415
}
404416

405-
func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value) bool {
417+
func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value, fn llvm.Value) bool {
406418
nfn := call.CalledValue()
407419
info := p.GetFuncInfo(ctx, call.CalledFunctionType())
408420
if !info.HasWrap() {
@@ -411,6 +423,15 @@ func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value) bool
411423
nft, attrs := p.transformFuncType(ctx, &info)
412424
b := ctx.NewBuilder()
413425
b.SetInsertPointBefore(call)
426+
427+
first := fn.EntryBasicBlock().FirstInstruction()
428+
createAlloca := func(t llvm.Type) (ret llvm.Value) {
429+
b.SetInsertPointBefore(first)
430+
ret = llvm.CreateAlloca(b, t)
431+
b.SetInsertPointBefore(call)
432+
return
433+
}
434+
414435
operandCount := len(info.Params)
415436
var nparams []llvm.Value
416437
for i := 0; i < operandCount; i++ {
@@ -422,16 +443,16 @@ func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value) bool
422443
case AttrVoid:
423444
// none
424445
case AttrPointer:
425-
ptr := llvm.CreateAlloca(b, ti.Type)
446+
ptr := createAlloca(ti.Type)
426447
b.CreateStore(param, ptr)
427448
nparams = append(nparams, ptr)
428449
case AttrWidthType:
429-
ptr := llvm.CreateAlloca(b, ti.Type)
450+
ptr := createAlloca(ti.Type)
430451
b.CreateStore(param, ptr)
431452
iptr := b.CreateBitCast(ptr, llvm.PointerType(ti.Type1, 0), "")
432453
nparams = append(nparams, b.CreateLoad(ti.Type1, iptr, ""))
433454
case AttrWidthType2:
434-
ptr := llvm.CreateAlloca(b, ti.Type)
455+
ptr := createAlloca(ti.Type)
435456
b.CreateStore(param, ptr)
436457
typ := llvm.StructType([]llvm.Type{ti.Type1, ti.Type2}, false) // {i8,i64}
437458
iptr := b.CreateBitCast(ptr, llvm.PointerType(typ, 0), "")
@@ -457,14 +478,14 @@ func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value) bool
457478
instr = llvm.CreateCall(b, nft, nfn, nparams)
458479
updateCallAttr(instr)
459480
case AttrPointer:
460-
ret := llvm.CreateAlloca(b, info.Return.Type)
481+
ret := createAlloca(info.Return.Type)
461482
call := llvm.CreateCall(b, nft, nfn, append([]llvm.Value{ret}, nparams...))
462483
updateCallAttr(call)
463484
instr = b.CreateLoad(info.Return.Type, ret, "")
464485
case AttrWidthType, AttrWidthType2:
465486
ret := llvm.CreateCall(b, nft, nfn, nparams)
466487
updateCallAttr(ret)
467-
ptr := llvm.CreateAlloca(b, nft.ReturnType())
488+
ptr := createAlloca(nft.ReturnType())
468489
b.CreateStore(ret, ptr)
469490
pret := b.CreateBitCast(ptr, llvm.PointerType(info.Return.Type, 0), "")
470491
instr = b.CreateLoad(info.Return.Type, pret, "")

0 commit comments

Comments
 (0)