Merge pull request #1231 from visualfc/cabi_alloca

internal/cabi: fix llvm.alloca for callInsrt
This commit is contained in:
xushiwei
2025-08-24 08:38:02 +08:00
committed by GitHub
2 changed files with 77 additions and 17 deletions

39
_demo/cabisret/main.go Normal file
View File

@@ -0,0 +1,39 @@
package main
type array9 struct {
x [9]float32
}
func demo1(a array9) array9 {
a.x[0] += 1
return a
}
func demo2(a array9) array9 {
for i := 0; i < 1024*128; i++ {
a = demo1(a)
}
return a
}
func testDemo() {
ar := array9{x: [9]float32{1, 2, 3, 4, 5, 6, 7, 8, 9}}
for i := 0; i < 1024*128; i++ {
ar = demo1(ar)
}
ar = demo2(ar)
println(ar.x[0], ar.x[1])
}
func testSlice() {
var b []byte
for i := 0; i < 1024*128; i++ {
b = append(b, byte(i))
}
_ = b
}
func main() {
testDemo()
testSlice()
}

View File

@@ -54,10 +54,15 @@ func (p *Transformer) isCFunc(name string) bool {
return !strings.Contains(name, ".") return !strings.Contains(name, ".")
} }
type CallInstr struct {
call llvm.Value
fn llvm.Value
}
func (p *Transformer) TransformModule(path string, m llvm.Module) { func (p *Transformer) TransformModule(path string, m llvm.Module) {
ctx := m.Context() ctx := m.Context()
var fns []llvm.Value var fns []llvm.Value
var callInstrs []llvm.Value var callInstrs []CallInstr
switch p.mode { switch p.mode {
case ModeNone: case ModeNone:
return return
@@ -66,16 +71,22 @@ func (p *Transformer) TransformModule(path string, m llvm.Module) {
for !fn.IsNil() { for !fn.IsNil() {
if p.isCFunc(fn.Name()) { if p.isCFunc(fn.Name()) {
p.transformFuncCall(m, fn) p.transformFuncCall(m, fn)
if p.isWrapFunctionType(m.Context(), fn.GlobalValueType()) { if p.isWrapFunctionType(ctx, fn.GlobalValueType()) {
fns = append(fns, fn) fns = append(fns, fn)
use := fn.FirstUse()
for !use.IsNil() {
if call := use.User().IsACallInst(); !call.IsNil() && call.CalledValue() == fn {
callInstrs = append(callInstrs, call)
}
use = use.NextUse()
} }
} }
bb := fn.FirstBasicBlock()
for !bb.IsNil() {
instr := bb.FirstInstruction()
for !instr.IsNil() {
if call := instr.IsACallInst(); !call.IsNil() && p.isCFunc(call.CalledValue().Name()) {
if p.isWrapFunctionType(ctx, call.CalledFunctionType()) {
callInstrs = append(callInstrs, CallInstr{call, fn})
}
}
instr = llvm.NextInstruction(instr)
}
bb = llvm.NextBasicBlock(bb)
} }
fn = llvm.NextFunction(fn) fn = llvm.NextFunction(fn)
} }
@@ -91,7 +102,7 @@ func (p *Transformer) TransformModule(path string, m llvm.Module) {
for !instr.IsNil() { for !instr.IsNil() {
if call := instr.IsACallInst(); !call.IsNil() { if call := instr.IsACallInst(); !call.IsNil() {
if p.isWrapFunctionType(ctx, call.CalledFunctionType()) { if p.isWrapFunctionType(ctx, call.CalledFunctionType()) {
callInstrs = append(callInstrs, call) callInstrs = append(callInstrs, CallInstr{call, fn})
} }
} }
instr = llvm.NextInstruction(instr) instr = llvm.NextInstruction(instr)
@@ -102,7 +113,7 @@ func (p *Transformer) TransformModule(path string, m llvm.Module) {
} }
} }
for _, call := range callInstrs { for _, call := range callInstrs {
p.transformCallInstr(ctx, call) p.transformCallInstr(ctx, call.call, call.fn)
} }
for _, fn := range fns { for _, fn := range fns {
p.transformFunc(m, fn) p.transformFunc(m, fn)
@@ -369,6 +380,7 @@ func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llv
fn.Param(i).ReplaceAllUsesWith(nv) fn.Param(i).ReplaceAllUsesWith(nv)
index++ index++
} }
if info.Return.Kind >= AttrPointer { if info.Return.Kind >= AttrPointer {
var retInstrs []llvm.Value var retInstrs []llvm.Value
bb := nfn.FirstBasicBlock() bb := nfn.FirstBasicBlock()
@@ -402,7 +414,7 @@ func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llv
} }
} }
func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value) bool { func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value, fn llvm.Value) bool {
nfn := call.CalledValue() nfn := call.CalledValue()
info := p.GetFuncInfo(ctx, call.CalledFunctionType()) info := p.GetFuncInfo(ctx, call.CalledFunctionType())
if !info.HasWrap() { if !info.HasWrap() {
@@ -411,6 +423,15 @@ func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value) bool
nft, attrs := p.transformFuncType(ctx, &info) nft, attrs := p.transformFuncType(ctx, &info)
b := ctx.NewBuilder() b := ctx.NewBuilder()
b.SetInsertPointBefore(call) b.SetInsertPointBefore(call)
first := fn.EntryBasicBlock().FirstInstruction()
createAlloca := func(t llvm.Type) (ret llvm.Value) {
b.SetInsertPointBefore(first)
ret = llvm.CreateAlloca(b, t)
b.SetInsertPointBefore(call)
return
}
operandCount := len(info.Params) operandCount := len(info.Params)
var nparams []llvm.Value var nparams []llvm.Value
for i := 0; i < operandCount; i++ { for i := 0; i < operandCount; i++ {
@@ -422,16 +443,16 @@ func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value) bool
case AttrVoid: case AttrVoid:
// none // none
case AttrPointer: case AttrPointer:
ptr := llvm.CreateAlloca(b, ti.Type) ptr := createAlloca(ti.Type)
b.CreateStore(param, ptr) b.CreateStore(param, ptr)
nparams = append(nparams, ptr) nparams = append(nparams, ptr)
case AttrWidthType: case AttrWidthType:
ptr := llvm.CreateAlloca(b, ti.Type) ptr := createAlloca(ti.Type)
b.CreateStore(param, ptr) b.CreateStore(param, ptr)
iptr := b.CreateBitCast(ptr, llvm.PointerType(ti.Type1, 0), "") iptr := b.CreateBitCast(ptr, llvm.PointerType(ti.Type1, 0), "")
nparams = append(nparams, b.CreateLoad(ti.Type1, iptr, "")) nparams = append(nparams, b.CreateLoad(ti.Type1, iptr, ""))
case AttrWidthType2: case AttrWidthType2:
ptr := llvm.CreateAlloca(b, ti.Type) ptr := createAlloca(ti.Type)
b.CreateStore(param, ptr) b.CreateStore(param, ptr)
typ := llvm.StructType([]llvm.Type{ti.Type1, ti.Type2}, false) // {i8,i64} typ := llvm.StructType([]llvm.Type{ti.Type1, ti.Type2}, false) // {i8,i64}
iptr := b.CreateBitCast(ptr, llvm.PointerType(typ, 0), "") iptr := b.CreateBitCast(ptr, llvm.PointerType(typ, 0), "")
@@ -457,14 +478,14 @@ func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value) bool
instr = llvm.CreateCall(b, nft, nfn, nparams) instr = llvm.CreateCall(b, nft, nfn, nparams)
updateCallAttr(instr) updateCallAttr(instr)
case AttrPointer: case AttrPointer:
ret := llvm.CreateAlloca(b, info.Return.Type) ret := createAlloca(info.Return.Type)
call := llvm.CreateCall(b, nft, nfn, append([]llvm.Value{ret}, nparams...)) call := llvm.CreateCall(b, nft, nfn, append([]llvm.Value{ret}, nparams...))
updateCallAttr(call) updateCallAttr(call)
instr = b.CreateLoad(info.Return.Type, ret, "") instr = b.CreateLoad(info.Return.Type, ret, "")
case AttrWidthType, AttrWidthType2: case AttrWidthType, AttrWidthType2:
ret := llvm.CreateCall(b, nft, nfn, nparams) ret := llvm.CreateCall(b, nft, nfn, nparams)
updateCallAttr(ret) updateCallAttr(ret)
ptr := llvm.CreateAlloca(b, nft.ReturnType()) ptr := createAlloca(nft.ReturnType())
b.CreateStore(ret, ptr) b.CreateStore(ret, ptr)
pret := b.CreateBitCast(ptr, llvm.PointerType(info.Return.Type, 0), "") pret := b.CreateBitCast(ptr, llvm.PointerType(info.Return.Type, 0), "")
instr = b.CreateLoad(info.Return.Type, pret, "") instr = b.CreateLoad(info.Return.Type, pret, "")