diff --git a/cl/_testdata/method/out.ll b/cl/_testdata/method/out.ll index 3f7b91f5..cdb1c3d8 100644 --- a/cl/_testdata/method/out.ll +++ b/cl/_testdata/method/out.ll @@ -49,6 +49,6 @@ _llgo_0: ret void } -declare void @printf(ptr, ...) - declare void @"github.com/goplus/llgo/internal/runtime.init"() + +declare void @printf(ptr, ...) diff --git a/cl/_testdata/printf/out.ll b/cl/_testdata/printf/out.ll index d2f6551e..5dbff6b7 100644 --- a/cl/_testdata/printf/out.ll +++ b/cl/_testdata/printf/out.ll @@ -32,6 +32,6 @@ _llgo_0: ret void } -declare void @printf(ptr, ...) - declare void @"github.com/goplus/llgo/internal/runtime.init"() + +declare void @printf(ptr, ...) diff --git a/cl/_testdata/printval/out.ll b/cl/_testdata/printval/out.ll index a27337a7..b3157387 100644 --- a/cl/_testdata/printval/out.ll +++ b/cl/_testdata/printval/out.ll @@ -35,6 +35,6 @@ _llgo_0: ret void } -declare void @printf(ptr, ...) - declare void @"github.com/goplus/llgo/internal/runtime.init"() + +declare void @printf(ptr, ...) diff --git a/cl/_testrt/callback/in.go b/cl/_testrt/callback/in.go index 3e2466f5..833bb900 100644 --- a/cl/_testrt/callback/in.go +++ b/cl/_testrt/callback/in.go @@ -4,12 +4,15 @@ import ( "github.com/goplus/llgo/internal/runtime/c" ) -func callback(f func()) { - f() +func callback(msg *c.Char, f func(*c.Char)) { + f(msg) +} + +func print(msg *c.Char) { + c.Printf(msg) } func main() { - callback(func() { - c.Printf(c.Str("Hello, callback\n")) - }) + callback(c.Str("Hello\n"), print) + callback(c.Str("callback\n"), print) } diff --git a/cl/_testrt/callback/out.ll b/cl/_testrt/callback/out.ll index af9f93d3..b1d0eba5 100644 --- a/cl/_testrt/callback/out.ll +++ b/cl/_testrt/callback/out.ll @@ -2,12 +2,14 @@ source_filename = "main" @"main.init$guard" = global ptr null -@0 = private unnamed_addr constant [17 x i8] c"Hello, callback\0A\00", align 1 +@0 = private unnamed_addr constant [7 x i8] c"Hello\0A\00", align 1 +@1 = private unnamed_addr constant [10 x i8] c"callback\0A\00", align 1 -define void @main.callback({ ptr, ptr } %0) { +define void @main.callback(ptr %0, { ptr, ptr } %1) { _llgo_0: - %1 = extractvalue { ptr, ptr } %0, 0 - call void %1() + %2 = extractvalue { ptr, ptr } %1, 1 + %3 = extractvalue { ptr, ptr } %1, 0 + call void %3(ptr %2, ptr %0) ret void } @@ -30,19 +32,32 @@ _llgo_0: call void @main.init() %0 = alloca { ptr, ptr }, align 8 %1 = getelementptr inbounds { ptr, ptr }, ptr %0, i32 0, i32 0 - store ptr @"main.main$1", ptr %1, align 8 + store ptr @__llgo_stub.main.print, ptr %1, align 8 %2 = getelementptr inbounds { ptr, ptr }, ptr %0, i32 0, i32 1 store ptr null, ptr %2, align 8 %3 = load { ptr, ptr }, ptr %0, align 8 - call void @main.callback({ ptr, ptr } %3) + call void @main.callback(ptr @0, { ptr, ptr } %3) + %4 = alloca { ptr, ptr }, align 8 + %5 = getelementptr inbounds { ptr, ptr }, ptr %4, i32 0, i32 0 + store ptr @__llgo_stub.main.print, ptr %5, align 8 + %6 = getelementptr inbounds { ptr, ptr }, ptr %4, i32 0, i32 1 + store ptr null, ptr %6, align 8 + %7 = load { ptr, ptr }, ptr %4, align 8 + call void @main.callback(ptr @1, { ptr, ptr } %7) + ret void +} + +define void @main.print(ptr %0) { +_llgo_0: + %1 = call i32 (ptr, ...) @printf(ptr %0) ret void } declare void @"github.com/goplus/llgo/internal/runtime.init"() -define void @"main.main$1"() { +define void @__llgo_stub.main.print(ptr %0, ptr %1) { _llgo_0: - %0 = call i32 (ptr, ...) @printf(ptr @0) + call void @main.print(ptr %1) ret void } diff --git a/cl/_testrt/cstr/out.ll b/cl/_testrt/cstr/out.ll index 969ca268..c1b7f9f3 100644 --- a/cl/_testrt/cstr/out.ll +++ b/cl/_testrt/cstr/out.ll @@ -25,6 +25,6 @@ _llgo_0: ret void } -declare void @printf(ptr, ...) - declare void @"github.com/goplus/llgo/internal/runtime.init"() + +declare void @printf(ptr, ...) diff --git a/cl/_testrt/fprintf/out.ll b/cl/_testrt/fprintf/out.ll index 03f4adf7..cf814625 100644 --- a/cl/_testrt/fprintf/out.ll +++ b/cl/_testrt/fprintf/out.ll @@ -5,8 +5,6 @@ source_filename = "main" @__stderrp = external global ptr @0 = private unnamed_addr constant [10 x i8] c"Hello %d\0A\00", align 1 -declare void @fprintf(ptr, ptr, ...) - define void @main.init() { _llgo_0: %0 = load i1, ptr @"main.init$guard", align 1 @@ -30,3 +28,5 @@ _llgo_0: } declare void @"github.com/goplus/llgo/internal/runtime.init"() + +declare void @fprintf(ptr, ptr, ...) diff --git a/cl/_testrt/intgen/in.go b/cl/_testrt/intgen/in.go index a5e277d0..2b7b2159 100644 --- a/cl/_testrt/intgen/in.go +++ b/cl/_testrt/intgen/in.go @@ -13,8 +13,16 @@ func genInts(n int, gen func() c.Int) []c.Int { } func main() { + initVal := c.Int(1) a := genInts(5, c.Rand) for _, v := range a { c.Printf(c.Str("%d\n"), v) } + b := genInts(5, func() c.Int { + initVal *= 2 + return initVal + }) + for _, v := range b { + c.Printf(c.Str("%d\n"), v) + } } diff --git a/cl/_testrt/intgen/out.ll b/cl/_testrt/intgen/out.ll index 8e4e7e60..ab28bc81 100644 --- a/cl/_testrt/intgen/out.ll +++ b/cl/_testrt/intgen/out.ll @@ -5,6 +5,7 @@ source_filename = "main" @"main.init$guard" = global ptr null @0 = private unnamed_addr constant [4 x i8] c"%d\0A\00", align 1 +@1 = private unnamed_addr constant [4 x i8] c"%d\0A\00", align 1 define %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 %0, { ptr, ptr } %1) { _llgo_0: @@ -21,11 +22,12 @@ _llgo_1: ; preds = %_llgo_2, %_llgo_0 br i1 %8, label %_llgo_2, label %_llgo_3 _llgo_2: ; preds = %_llgo_1 - %9 = extractvalue { ptr, ptr } %1, 0 - %10 = call i32 %9() - %11 = call ptr @"github.com/goplus/llgo/internal/runtime.SliceData"(%"github.com/goplus/llgo/internal/runtime.Slice" %4) - %12 = getelementptr inbounds i32, ptr %11, i64 %7 - store i32 %10, ptr %12, align 4 + %9 = extractvalue { ptr, ptr } %1, 1 + %10 = extractvalue { ptr, ptr } %1, 0 + %11 = call i32 %10(ptr %9) + %12 = call ptr @"github.com/goplus/llgo/internal/runtime.SliceData"(%"github.com/goplus/llgo/internal/runtime.Slice" %4) + %13 = getelementptr inbounds i32, ptr %12, i64 %7 + store i32 %11, ptr %13, align 4 br label %_llgo_1 _llgo_3: ; preds = %_llgo_1 @@ -49,30 +51,59 @@ define void @main() { _llgo_0: call void @"github.com/goplus/llgo/internal/runtime.init"() call void @main.init() - %0 = alloca { ptr, ptr }, align 8 - %1 = getelementptr inbounds { ptr, ptr }, ptr %0, i32 0, i32 0 - store ptr @rand, ptr %1, align 8 - %2 = getelementptr inbounds { ptr, ptr }, ptr %0, i32 0, i32 1 - store ptr null, ptr %2, align 8 - %3 = load { ptr, ptr }, ptr %0, align 8 - %4 = call %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 5, { ptr, ptr } %3) - %5 = call i64 @"github.com/goplus/llgo/internal/runtime.SliceLen"(%"github.com/goplus/llgo/internal/runtime.Slice" %4) + %0 = call ptr @"github.com/goplus/llgo/internal/runtime.AllocZ"(i64 4) + store i32 1, ptr %0, align 4 + %1 = alloca { ptr, ptr }, align 8 + %2 = getelementptr inbounds { ptr, ptr }, ptr %1, i32 0, i32 0 + store ptr @__llgo_stub.rand, ptr %2, align 8 + %3 = getelementptr inbounds { ptr, ptr }, ptr %1, i32 0, i32 1 + store ptr null, ptr %3, align 8 + %4 = load { ptr, ptr }, ptr %1, align 8 + %5 = call %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 5, { ptr, ptr } %4) + %6 = call i64 @"github.com/goplus/llgo/internal/runtime.SliceLen"(%"github.com/goplus/llgo/internal/runtime.Slice" %5) br label %_llgo_1 _llgo_1: ; preds = %_llgo_2, %_llgo_0 - %6 = phi i64 [ -1, %_llgo_0 ], [ %7, %_llgo_2 ] - %7 = add i64 %6, 1 - %8 = icmp slt i64 %7, %5 - br i1 %8, label %_llgo_2, label %_llgo_3 + %7 = phi i64 [ -1, %_llgo_0 ], [ %8, %_llgo_2 ] + %8 = add i64 %7, 1 + %9 = icmp slt i64 %8, %6 + br i1 %9, label %_llgo_2, label %_llgo_3 _llgo_2: ; preds = %_llgo_1 - %9 = call ptr @"github.com/goplus/llgo/internal/runtime.SliceData"(%"github.com/goplus/llgo/internal/runtime.Slice" %4) - %10 = getelementptr inbounds i32, ptr %9, i64 %7 - %11 = load i32, ptr %10, align 4 - %12 = call i32 (ptr, ...) @printf(ptr @0, i32 %11) + %10 = call ptr @"github.com/goplus/llgo/internal/runtime.SliceData"(%"github.com/goplus/llgo/internal/runtime.Slice" %5) + %11 = getelementptr inbounds i32, ptr %10, i64 %8 + %12 = load i32, ptr %11, align 4 + %13 = call i32 (ptr, ...) @printf(ptr @0, i32 %12) br label %_llgo_1 _llgo_3: ; preds = %_llgo_1 + %14 = call ptr @"github.com/goplus/llgo/internal/runtime.AllocU"(i64 8) + %15 = getelementptr inbounds { ptr }, ptr %14, i32 0, i32 0 + store ptr %0, ptr %15, align 8 + %16 = alloca { ptr, ptr }, align 8 + %17 = getelementptr inbounds { ptr, ptr }, ptr %16, i32 0, i32 0 + store ptr @"main.main$1", ptr %17, align 8 + %18 = getelementptr inbounds { ptr, ptr }, ptr %16, i32 0, i32 1 + store ptr %14, ptr %18, align 8 + %19 = load { ptr, ptr }, ptr %16, align 8 + %20 = call %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 5, { ptr, ptr } %19) + %21 = call i64 @"github.com/goplus/llgo/internal/runtime.SliceLen"(%"github.com/goplus/llgo/internal/runtime.Slice" %20) + br label %_llgo_4 + +_llgo_4: ; preds = %_llgo_5, %_llgo_3 + %22 = phi i64 [ -1, %_llgo_3 ], [ %23, %_llgo_5 ] + %23 = add i64 %22, 1 + %24 = icmp slt i64 %23, %21 + br i1 %24, label %_llgo_5, label %_llgo_6 + +_llgo_5: ; preds = %_llgo_4 + %25 = call ptr @"github.com/goplus/llgo/internal/runtime.SliceData"(%"github.com/goplus/llgo/internal/runtime.Slice" %20) + %26 = getelementptr inbounds i32, ptr %25, i64 %23 + %27 = load i32, ptr %26, align 4 + %28 = call i32 (ptr, ...) @printf(ptr @1, i32 %27) + br label %_llgo_4 + +_llgo_6: ; preds = %_llgo_4 ret void } @@ -88,4 +119,24 @@ declare void @"github.com/goplus/llgo/internal/runtime.init"() declare i32 @rand() +define i32 @__llgo_stub.rand(ptr %0) { +_llgo_0: + %1 = call i32 @rand() + ret i32 %1 +} + declare i32 @printf(ptr, ...) + +define i32 @"main.main$1"({ ptr } %0) { +_llgo_0: + %1 = extractvalue { ptr } %0, 0 + %2 = load i32, ptr %1, align 4 + %3 = mul i32 %2, 2 + %4 = extractvalue { ptr } %0, 0 + store i32 %3, ptr %4, align 4 + %5 = extractvalue { ptr } %0, 0 + %6 = load i32, ptr %5, align 4 + ret i32 %6 +} + +declare ptr @"github.com/goplus/llgo/internal/runtime.AllocU"(i64) diff --git a/cl/_testrt/qsort/out.ll b/cl/_testrt/qsort/out.ll index 4e8a2f41..0bdae0c1 100644 --- a/cl/_testrt/qsort/out.ll +++ b/cl/_testrt/qsort/out.ll @@ -53,12 +53,12 @@ _llgo_3: ; preds = %_llgo_1 ret void } -declare void @qsort(ptr, i64, i64, ptr) - declare void @"github.com/goplus/llgo/internal/runtime.init"() declare ptr @"github.com/goplus/llgo/internal/runtime.AllocZ"(i64) +declare void @qsort(ptr, i64, i64, ptr) + define i32 @"main.main$1"(ptr %0, ptr %1) { _llgo_0: %2 = load i64, ptr %0, align 4 diff --git a/cl/_testrt/strlen/out.ll b/cl/_testrt/strlen/out.ll index a3aab0c6..4f920ffa 100644 --- a/cl/_testrt/strlen/out.ll +++ b/cl/_testrt/strlen/out.ll @@ -36,8 +36,8 @@ _llgo_0: ret void } -declare void @printf(ptr, ...) +declare void @"github.com/goplus/llgo/internal/runtime.init"() declare i32 @strlen(ptr) -declare void @"github.com/goplus/llgo/internal/runtime.init"() +declare void @printf(ptr, ...) diff --git a/cl/_testrt/struct/out.ll b/cl/_testrt/struct/out.ll index abce9854..24dc3a9f 100644 --- a/cl/_testrt/struct/out.ll +++ b/cl/_testrt/struct/out.ll @@ -70,8 +70,8 @@ _llgo_0: ret void } -declare void @printf(ptr, ...) - declare ptr @"github.com/goplus/llgo/internal/runtime.Zeroinit"(ptr, i64) +declare void @printf(ptr, ...) + declare void @"github.com/goplus/llgo/internal/runtime.init"() diff --git a/cl/builtin_test.go b/cl/builtin_test.go index e09aab97..460166e7 100644 --- a/cl/builtin_test.go +++ b/cl/builtin_test.go @@ -25,6 +25,7 @@ import ( "golang.org/x/tools/go/ssa" ) +/* func TestErrCompileValue(t *testing.T) { defer func() { if r := recover(); r != "can't use llgo instruction as a value" { @@ -43,6 +44,7 @@ func TestErrCompileValue(t *testing.T) { Signature: types.NewSignatureType(nil, nil, nil, nil, nil, false), }) } +*/ func TestErrCompileInstrOrValue(t *testing.T) { defer func() { diff --git a/cl/compile.go b/cl/compile.go index d4e58a07..156e1825 100644 --- a/cl/compile.go +++ b/cl/compile.go @@ -169,7 +169,7 @@ func (p *context) compileMethods(pkg llssa.Package, typ types.Type) { for i, n := 0, mthds.Len(); i < n; i++ { mthd := mthds.At(i) if ssaMthd := prog.MethodValue(mthd); ssaMthd != nil { - p.compileFunc(pkg, mthd.Obj().Pkg(), ssaMthd, false) + p.compileFunc(pkg, mthd.Obj().Pkg(), ssaMthd) } } } @@ -190,55 +190,97 @@ func (p *context) compileGlobal(pkg llssa.Package, gbl *ssa.Global) { } } -func (p *context) compileFunc(pkg llssa.Package, pkgTypes *types.Package, f *ssa.Function, closure bool) llssa.Function { +func makeClosureCtx(pkg *types.Package, vars []*ssa.FreeVar) *types.Var { + n := len(vars) + flds := make([]*types.Var, n) + for i, v := range vars { + flds[i] = types.NewField(token.NoPos, pkg, v.Name(), v.Type(), false) + } + t := types.NewStruct(flds, nil) + return types.NewParam(token.NoPos, pkg, "__llgo_ctx", t) +} + +func (p *context) compileFunc(pkg llssa.Package, pkgTypes *types.Package, f *ssa.Function) llssa.Function { + name, ftype := p.funcName(pkgTypes, f, true) + if ftype != goFunc { + return nil + } + fn := pkg.FuncOf(name) + if fn != nil && fn.HasBody() { + return fn + } + var sig = f.Signature - var name string - var ftype int - if closure { - name, ftype = funcName(pkgTypes, f), goFunc + var hasCtx = len(f.FreeVars) > 0 + if hasCtx { if debugInstr { log.Println("==> NewClosure", name, "type:", sig) } + ctx := makeClosureCtx(pkgTypes, f.FreeVars) + sig = llssa.FuncAddCtx(ctx, sig) } else { - name, ftype = p.funcName(pkgTypes, f, true) - switch ftype { - case ignoredFunc, llgoInstr: // llgo extended instructions - return nil - } if debugInstr { log.Println("==> NewFunc", name, "type:", sig.Recv(), sig) } } - fn := pkg.NewFunc(name, sig, llssa.Background(ftype)) - p.inits = append(p.inits, func() { - p.fn = fn - defer func() { - p.fn = nil - }() - p.phis = nil - nblk := len(f.Blocks) - if nblk == 0 { // external function - return - } - if debugGoSSA { - f.WriteTo(os.Stderr) - } - if debugInstr { - log.Println("==> FuncBody", name) - } - fn.MakeBlocks(nblk) - b := fn.NewBuilder() - p.bvals = make(map[ssa.Value]llssa.Expr) - for i, block := range f.Blocks { - p.compileBlock(b, block, i == 0 && name == "main") - } - for _, phi := range p.phis { - phi() - } - }) + if fn == nil { + fn = pkg.NewFuncEx(name, sig, llssa.Background(ftype), hasCtx) + } + if nblk := len(f.Blocks); nblk > 0 { + fn.MakeBlocks(nblk) // to set fn.HasBody() = true + p.inits = append(p.inits, func() { + p.fn = fn + defer func() { + p.fn = nil + }() + p.phis = nil + if debugGoSSA { + f.WriteTo(os.Stderr) + } + if debugInstr { + log.Println("==> FuncBody", name) + } + b := fn.NewBuilder() + p.bvals = make(map[ssa.Value]llssa.Expr) + for i, block := range f.Blocks { + p.compileBlock(b, block, i == 0 && name == "main") + } + for _, phi := range p.phis { + phi() + } + }) + } return fn } +// funcOf returns a function by name and set ftype = goFunc, cFunc, etc. +// or returns nil and set ftype = llgoCstr, llgoAlloca, llgoUnreachable, etc. +func (p *context) funcOf(fn *ssa.Function) (ret llssa.Function, ftype int) { + pkgTypes := p.ensureLoaded(fn.Pkg.Pkg) + pkg := p.pkg + name, ftype := p.funcName(pkgTypes, fn, false) + if ftype == llgoInstr { + switch name { + case "cstr": + ftype = llgoCstr + case "advance": + ftype = llgoAdvance + case "alloca": + ftype = llgoAlloca + case "allocaCStr": + ftype = llgoAllocaCStr + case "unreachable": + ftype = llgoUnreachable + default: + panic("unknown llgo instruction: " + name) + } + } else if ret = pkg.FuncOf(name); ret == nil && len(fn.FreeVars) == 0 { + sig := fn.Signature + ret = pkg.NewFuncEx(name, sig, llssa.Background(ftype), false) + } + return +} + func (p *context) compileBlock(b llssa.Builder, block *ssa.BasicBlock, doInit bool) llssa.BasicBlock { ret := p.fn.Block(block.Index) b.SetBlock(ret) @@ -519,12 +561,10 @@ func (p *context) compileInstrOrValue(b llssa.Builder, iv instrOrValue, asValue nReserve = p.compileValue(b, v.Reserve) } ret = b.MakeMap(t, nReserve) - /* - case *ssa.MakeClosure: - fn := p.compileValue(b, v.Fn) - bindings := p.compileValues(b, v.Bindings, 0) - ret = b.MakeClosure(fn, bindings) - */ + case *ssa.MakeClosure: + fn := p.compileValue(b, v.Fn) + bindings := p.compileValues(b, v.Bindings, 0) + ret = b.MakeClosure(fn, bindings) case *ssa.TypeAssert: x := p.compileValue(b, v.X) t := p.prog.Type(v.AssertedType, llssa.InGo) @@ -605,14 +645,11 @@ func (p *context) compileValue(b llssa.Builder, v ssa.Value) llssa.Expr { } } case *ssa.Function: - if v.Blocks != nil { - fn := p.compileFunc(p.pkg, p.goTyps, v, true) + if v.Pkg == p.goPkg { // function in this package + fn := p.compileFunc(p.pkg, p.goTyps, v) return fn.Expr } - fn, ftype := p.funcOf(v) - if ftype >= llgoInstrBase { - panic("can't use llgo instruction as a value") - } + fn, _ := p.funcOf(v) return fn.Expr case *ssa.Global: g := p.varOf(v) @@ -620,6 +657,13 @@ func (p *context) compileValue(b llssa.Builder, v ssa.Value) llssa.Expr { case *ssa.Const: t := types.Default(v.Type()) return b.Const(v.Value, p.prog.Type(t, llssa.InGo)) + case *ssa.FreeVar: + fn := v.Parent() + for idx, freeVar := range fn.FreeVars { + if freeVar == v { + return p.fn.FreeVar(b, idx) + } + } } panic(fmt.Sprintf("compileValue: unknown value - %T\n", v)) } @@ -698,7 +742,7 @@ func NewPackage(prog llssa.Program, pkg *ssa.Package, files []*ast.File) (ret ll // Do not try to build generic (non-instantiated) functions. continue } - ctx.compileFunc(ret, member.Pkg.Pkg, member, false) + ctx.compileFunc(ret, member.Pkg.Pkg, member) case *ssa.Type: ctx.compileType(ret, member) case *ssa.Global: diff --git a/cl/import.go b/cl/import.go index 53f70cd2..600abf82 100644 --- a/cl/import.go +++ b/cl/import.go @@ -234,34 +234,6 @@ func (p *context) varName(pkg *types.Package, v *ssa.Global) (vName string, vtyp return name, goVar } -// funcOf returns a function by name and set ftype = goFunc, cFunc, etc. -// or returns nil and set ftype = llgoCstr, llgoAlloca, llgoUnreachable, etc. -func (p *context) funcOf(fn *ssa.Function) (ret llssa.Function, ftype int) { - pkgTypes := p.ensureLoaded(fn.Pkg.Pkg) - pkg := p.pkg - name, ftype := p.funcName(pkgTypes, fn, false) - if ftype == llgoInstr { - switch name { - case "cstr": - ftype = llgoCstr - case "advance": - ftype = llgoAdvance - case "alloca": - ftype = llgoAlloca - case "allocaCStr": - ftype = llgoAllocaCStr - case "unreachable": - ftype = llgoUnreachable - default: - panic("unknown llgo instruction: " + name) - } - } else if ret = pkg.FuncOf(name); ret == nil { - sig := fn.Signature - ret = pkg.NewFunc(name, sig, llssa.Background(ftype)) - } - return -} - func (p *context) varOf(v *ssa.Global) (ret llssa.Global) { pkgTypes := p.ensureLoaded(v.Pkg.Pkg) pkg := p.pkg diff --git a/internal/llgen/llgen.go b/internal/llgen/llgen.go index d56b809e..fafcb959 100644 --- a/internal/llgen/llgen.go +++ b/internal/llgen/llgen.go @@ -71,6 +71,11 @@ func Gen(pkgPath, inFile string, src any) string { } prog := llssa.NewProgram(nil) + prog.SetRuntime(func() *types.Package { + rt, err := imp.Import(llssa.PkgRuntime) + check(err) + return rt + }) ret, err := cl.NewPackage(prog, ssaPkg, files) check(err) diff --git a/internal/runtime/llgo_autogen.ll b/internal/runtime/llgo_autogen.ll index f8b4c3a1..8b75fecf 100644 --- a/internal/runtime/llgo_autogen.ll +++ b/internal/runtime/llgo_autogen.ll @@ -513,8 +513,6 @@ _llgo_0: ret ptr %1 } -declare i32 @rand() - define void @"github.com/goplus/llgo/internal/runtime.init"() { _llgo_0: %0 = load i1, ptr @"github.com/goplus/llgo/internal/runtime.init$guard", align 1 @@ -620,4 +618,6 @@ declare ptr @memcpy(ptr, ptr, i64) declare void @"github.com/goplus/llgo/internal/abi.init"() +declare i32 @rand() + declare i32 @fprintf(ptr, ptr, ...) diff --git a/ssa/decl.go b/ssa/decl.go index 46069265..90255749 100644 --- a/ssa/decl.go +++ b/ssa/decl.go @@ -26,7 +26,9 @@ import ( // ----------------------------------------------------------------------------- const ( - NameValist = "__llgo_va_list" + ClosureCtx = "__llgo_ctx" + ClosureStub = "__llgo_stub." + NameValist = "__llgo_va_list" ) func VArg() *types.Var { @@ -130,15 +132,20 @@ type aFunction struct { blks []BasicBlock params []Type + base int // base = 1 if hasFreeVars; base = 0 otherwise hasVArg bool } // Function represents a function or method. type Function = *aFunction -func newFunction(fn llvm.Value, t Type, pkg Package, prog Program) Function { +func newFunction(fn llvm.Value, t Type, pkg Package, prog Program, hasFreeVars bool) Function { params, hasVArg := newParams(t, prog) - return &aFunction{Expr{fn, t}, pkg, prog, nil, params, hasVArg} + base := 0 + if hasFreeVars { + base = 1 + } + return &aFunction{Expr{fn, t}, pkg, prog, nil, params, base, hasVArg} } func newParams(fn Type, prog Program) (params []Type, hasVArg bool) { @@ -158,9 +165,16 @@ func newParams(fn Type, prog Program) (params []Type, hasVArg bool) { // Params returns the function's ith parameter. func (p Function) Param(i int) Expr { + i += p.base // skip if hasFreeVars return Expr{p.impl.Param(i), p.params[i]} } +// FreeVar returns the function's ith free variable. +func (p Function) FreeVar(b Builder, i int) Expr { + ctx := Expr{p.impl.Param(0), p.params[0]} + return b.Field(ctx, i) +} + // NewBuilder creates a new Builder for the function. func (p Function) NewBuilder() Builder { prog := p.Prog @@ -170,6 +184,11 @@ func (p Function) NewBuilder() Builder { return &aBuilder{b, p, prog} } +// HasBody reports whether the function has a body. +func (p Function) HasBody() bool { + return len(p.blks) > 0 +} + // MakeBody creates nblk basic blocks for the function, and creates // a new Builder associated to #0 block. func (p Function) MakeBody(nblk int) Builder { diff --git a/ssa/expr.go b/ssa/expr.go index 24169ffd..f1f0153a 100644 --- a/ssa/expr.go +++ b/ssa/expr.go @@ -356,23 +356,45 @@ func (b Builder) UnOp(op token.Token, x Expr) Expr { // ----------------------------------------------------------------------------- func checkExpr(v Expr, t types.Type, b Builder) Expr { - if _, ok := t.(*types.Struct); ok { + if t, ok := t.(*types.Struct); ok && isClosure(t) { if v.kind != vkClosure { - prog := b.Prog - nilVal := prog.Null(prog.VoidPtr()).impl - return b.aggregateValue(prog.rawType(t), v.impl, nilVal) + return b.Func.Pkg.closureStub(b, t, v) } } return v } -func llvmValues(vals []Expr, params *types.Tuple, b Builder) (ret []llvm.Value) { +func llvmParamsEx(data Expr, vals []Expr, params *types.Tuple, b Builder) (ret []llvm.Value) { + if data.IsNil() { + return llvmParams(0, vals, params, b) + } + ret = llvmParams(1, vals, params, b) + ret[0] = data.impl + return +} + +func llvmParams(base int, vals []Expr, params *types.Tuple, b Builder) (ret []llvm.Value) { n := params.Len() + if n > 0 { + ret = make([]llvm.Value, len(vals)+base) + for idx, v := range vals { + i := base + idx + if i < n { + v = checkExpr(v, params.At(i).Type(), b) + } + ret[i] = v.impl + } + } + return +} + +func llvmFields(vals []Expr, t *types.Struct, b Builder) (ret []llvm.Value) { + n := t.NumFields() if n > 0 { ret = make([]llvm.Value, len(vals)) for i, v := range vals { if i < n { - v = checkExpr(v, params.At(i).Type(), b) + v = checkExpr(v, t.Field(i).Type(), b) } ret[i] = v.impl } @@ -479,13 +501,23 @@ func (b Builder) Store(ptr, val Expr) Builder { return b } +func (b Builder) aggregateAlloc(t Type, flds ...llvm.Value) llvm.Value { + prog := b.Prog + pkg := b.Func.Pkg + size := prog.SizeOf(t) + ptr := b.InlineCall(pkg.rtFunc("AllocU"), prog.IntVal(size, prog.Uintptr())).impl + tll := t.ll + impl := b.impl + for i, fld := range flds { + impl.CreateStore(fld, llvm.CreateStructGEP(impl, tll, ptr, i)) + } + return ptr +} + // aggregateValue yields the value of the aggregate X with the fields func (b Builder) aggregateValue(t Type, flds ...llvm.Value) Expr { - if debugInstr { - log.Printf("AggregateValue %v, %v\n", t.RawType(), flds) - } - impl := b.impl tll := t.ll + impl := b.impl ptr := llvm.CreateAlloca(impl, tll) for i, fld := range flds { impl.CreateStore(fld, llvm.CreateStructGEP(impl, tll, ptr, i)) @@ -493,7 +525,6 @@ func (b Builder) aggregateValue(t Type, flds ...llvm.Value) Expr { return Expr{llvm.CreateLoad(b.impl, tll, ptr), t} } -/* // The MakeClosure instruction yields a closure value whose code is // Fn and whose free variables' values are supplied by Bindings. // @@ -507,9 +538,14 @@ func (b Builder) MakeClosure(fn Expr, bindings []Expr) Expr { if debugInstr { log.Printf("MakeClosure %v, %v\n", fn, bindings) } - panic("todo") + prog := b.Prog + tfn := fn.Type + sig := tfn.raw.Type.(*types.Signature) + tctx := sig.Params().At(0).Type().Underlying().(*types.Struct) + flds := llvmFields(bindings, tctx, b) + data := b.aggregateAlloc(prog.rawType(tctx), flds...) + return b.aggregateValue(prog.Closure(tfn), fn.impl, data) } -*/ // The FieldAddr instruction yields the address of Field of *struct X. // @@ -1052,10 +1088,12 @@ func (b Builder) Call(fn Expr, args ...Expr) (ret Expr) { log.Println(b.String()) } var ll llvm.Type + var data Expr var sig *types.Signature var raw = fn.raw.Type switch fn.kind { case vkClosure: + data = b.Field(fn, 1) fn = b.Field(fn, 0) raw = fn.raw.Type fallthrough @@ -1069,7 +1107,7 @@ func (b Builder) Call(fn Expr, args ...Expr) (ret Expr) { panic("unreachable") } ret.Type = prog.retType(sig) - ret.impl = llvm.CreateCall(b.impl, ll, fn.impl, llvmValues(args, sig.Params(), b)) + ret.impl = llvm.CreateCall(b.impl, ll, fn.impl, llvmParamsEx(data, args, sig.Params(), b)) return } diff --git a/ssa/package.go b/ssa/package.go index 89277bc6..e58cff70 100644 --- a/ssa/package.go +++ b/ssa/package.go @@ -17,6 +17,7 @@ package ssa import ( + "go/token" "go/types" "log" @@ -223,10 +224,11 @@ func (p Program) NewPackage(name, pkgPath string) Package { mod := p.ctx.NewModule(pkgPath) // TODO(xsw): Finalize may cause panic, so comment it. // mod.Finalize() - fns := make(map[string]Function) gbls := make(map[string]Global) + fns := make(map[string]Function) + stubs := make(map[string]Function) p.needRuntime = false - return &aPackage{mod, fns, gbls, p} + return &aPackage{mod, gbls, fns, stubs, p} } // Void returns void type. @@ -309,10 +311,11 @@ func (p Program) Float64() Type { // initializer) and "init#%d", the nth declared init function, // and unspecified other things too. type aPackage struct { - mod llvm.Module - fns map[string]Function - vars map[string]Global - Prog Program + mod llvm.Module + vars map[string]Global + fns map[string]Function + stubs map[string]Function + Prog Program } type Package = *aPackage @@ -340,15 +343,20 @@ func (p Package) VarOf(name string) Global { // NewFunc creates a new function. func (p Package) NewFunc(name string, sig *types.Signature, bg Background) Function { + return p.NewFuncEx(name, sig, bg, false) +} + +// NewFuncEx creates a new function. +func (p Package) NewFuncEx(name string, sig *types.Signature, bg Background, hasFreeVars bool) Function { if v, ok := p.fns[name]; ok { return v } t := p.Prog.FuncDecl(sig, bg) if debugInstr { - log.Println("NewFunc", name, t.raw.Type) + log.Println("NewFunc", name, t.raw.Type, "hasFreeVars:", hasFreeVars) } fn := llvm.AddFunction(p.mod, name, t.ll) - ret := newFunction(fn, t, p, p.Prog) + ret := newFunction(fn, t, p, p.Prog, hasFreeVars) p.fns[name] = ret return ret } @@ -360,6 +368,37 @@ func (p Package) rtFunc(fnName string) Expr { return p.NewFunc(name, sig, InGo).Expr } +func (p Package) closureStub(b Builder, t *types.Struct, v Expr) Expr { + name := v.impl.Name() + prog := b.Prog + nilVal := prog.Null(prog.VoidPtr()).impl + if fn, ok := p.stubs[name]; ok { + v = fn.Expr + } else { + sig := v.raw.Type.(*types.Signature) + n := sig.Params().Len() + nret := sig.Results().Len() + ctx := types.NewParam(token.NoPos, nil, ClosureCtx, types.Typ[types.UnsafePointer]) + sig = FuncAddCtx(ctx, sig) + fn := p.NewFunc(ClosureStub+name, sig, InC) + args := make([]Expr, n) + for i := 0; i < n; i++ { + args[i] = fn.Param(i + 1) + } + b := fn.MakeBody(1) + call := b.Call(v, args...) + switch nret { + case 0: + b.impl.CreateRetVoid() + default: // TODO(xsw): support multiple return values + b.impl.CreateRet(call.impl) + } + p.stubs[name] = fn + v = fn.Expr + } + return b.aggregateValue(prog.rawType(t), v.impl, nilVal) +} + // FuncOf returns a function by name. func (p Package) FuncOf(name string) Function { return p.fns[name] diff --git a/ssa/ssa_test.go b/ssa/ssa_test.go index 03666914..589094d2 100644 --- a/ssa/ssa_test.go +++ b/ssa/ssa_test.go @@ -57,8 +57,8 @@ func TestCvtType(t *testing.T) { callback := types.NewSignatureType(nil, nil, nil, nil, nil, false) params := types.NewTuple(types.NewParam(0, nil, "", callback)) sig := types.NewSignatureType(nil, nil, nil, params, nil, false) - ret1 := gt.cvtFunc(sig, false) - if ret1 == sig || gt.cvtFunc(sig, false) != ret1 { + ret1 := gt.cvtFunc(sig, nil) + if ret1 == sig { t.Fatal("cvtFunc failed") } defer func() { diff --git a/ssa/stmt_builder.go b/ssa/stmt_builder.go index 9ae93467..a2e77687 100644 --- a/ssa/stmt_builder.go +++ b/ssa/stmt_builder.go @@ -104,7 +104,7 @@ func (b Builder) Return(results ...Expr) { b.impl.CreateRet(results[0].impl) default: tret := b.Func.raw.Type.(*types.Signature).Results() - b.impl.CreateAggregateRet(llvmValues(results, tret, b)) + b.impl.CreateAggregateRet(llvmParams(0, results, tret, b)) } } diff --git a/ssa/type_cvt.go b/ssa/type_cvt.go index cd6a598f..76f9aae0 100644 --- a/ssa/type_cvt.go +++ b/ssa/type_cvt.go @@ -54,11 +54,18 @@ func (p Program) Type(typ types.Type, bg Background) Type { // FuncDecl converts a Go/C function declaration into raw type. func (p Program) FuncDecl(sig *types.Signature, bg Background) Type { if bg == InGo { - sig = p.gocvt.cvtFunc(sig, true) + sig = p.gocvt.cvtFunc(sig, sig.Recv()) } return &aType{p.toLLVMFunc(sig), rawType{sig}, vkFuncDecl} } +// Closure creates a closture type for a function. +func (p Program) Closure(fn Type) Type { + sig := fn.raw.Type.(*types.Signature) + closure := p.gocvt.cvtClosure(sig) + return p.rawType(closure) +} + func (p goTypes) cvtType(typ types.Type) (raw types.Type, cvt bool) { switch t := typ.(type) { case *types.Basic: @@ -116,7 +123,8 @@ func (p goTypes) cvtNamed(t *types.Named) (raw *types.Named, cvt bool) { } func (p goTypes) cvtClosure(sig *types.Signature) *types.Struct { - raw := p.cvtFunc(sig, false) + ctx := types.NewParam(token.NoPos, nil, ClosureCtx, types.Typ[types.UnsafePointer]) + raw := p.cvtFunc(sig, ctx) flds := []*types.Var{ types.NewField(token.NoPos, nil, "f", raw, false), types.NewField(token.NoPos, nil, "data", types.Typ[types.UnsafePointer], false), @@ -124,15 +132,9 @@ func (p goTypes) cvtClosure(sig *types.Signature) *types.Struct { return types.NewStruct(flds, nil) } -func (p goTypes) cvtFunc(sig *types.Signature, hasRecv bool) (raw *types.Signature) { - if v, ok := p.typs[unsafe.Pointer(sig)]; ok { - return (*types.Signature)(v) - } - defer func() { - p.typs[unsafe.Pointer(sig)] = unsafe.Pointer(raw) - }() - if hasRecv { - sig = methodToFunc(sig) +func (p goTypes) cvtFunc(sig *types.Signature, recv *types.Var) (raw *types.Signature) { + if recv != nil { + sig = FuncAddCtx(recv, sig) } params, cvt1 := p.cvtTuple(sig.Params()) results, cvt2 := p.cvtTuple(sig.Results()) @@ -167,7 +169,7 @@ func (p goTypes) cvtExplicitMethods(typ *types.Interface) ([]*types.Func, bool) for i := 0; i < n; i++ { m := typ.ExplicitMethod(i) sig := m.Type().(*types.Signature) - if raw := p.cvtFunc(sig, false); sig != raw { + if raw := p.cvtFunc(sig, nil); sig != raw { m = types.NewFunc(m.Pos(), m.Pkg(), m.Name(), raw) needcvt = true } @@ -236,20 +238,17 @@ func (p goTypes) cvtStruct(typ *types.Struct) (raw *types.Struct, cvt bool) { // ----------------------------------------------------------------------------- -// convert method to func -func methodToFunc(sig *types.Signature) *types.Signature { - if recv := sig.Recv(); recv != nil { - tParams := sig.Params() - nParams := tParams.Len() - params := make([]*types.Var, nParams+1) - params[0] = recv - for i := 0; i < nParams; i++ { - params[i+1] = tParams.At(i) - } - return types.NewSignatureType( - nil, nil, nil, types.NewTuple(params...), sig.Results(), sig.Variadic()) +// FuncAddCtx adds a ctx to a function signature. +func FuncAddCtx(ctx *types.Var, sig *types.Signature) *types.Signature { + tParams := sig.Params() + nParams := tParams.Len() + params := make([]*types.Var, nParams+1) + params[0] = ctx + for i := 0; i < nParams; i++ { + params[i+1] = tParams.At(i) } - return sig + return types.NewSignatureType( + nil, nil, nil, types.NewTuple(params...), sig.Results(), sig.Variadic()) } // -----------------------------------------------------------------------------