Merge pull request #666 from cpunion/async-functions

Async functions
This commit is contained in:
xushiwei
2024-08-06 20:38:34 +08:00
committed by GitHub
25 changed files with 2019 additions and 486 deletions

27
cl/_testdata/async/in.go Normal file
View File

@@ -0,0 +1,27 @@
package async
import (
"github.com/goplus/llgo/x/async"
)
func GenInts() (co *async.Promise[int]) {
co.Yield(1)
co.Yield(2)
co.Yield(3)
return
}
func WrapGenInts() *async.Promise[int] {
return GenInts()
}
func UseGenInts() int {
co := WrapGenInts()
r := 0
for !co.Done() {
v := co.Value()
r += v
co.Next()
}
return r
}

201
cl/_testdata/async/out.ll Normal file
View File

@@ -0,0 +1,201 @@
; ModuleID = 'async'
source_filename = "async"
%"github.com/goplus/llgo/x/async.Promise[int]" = type { ptr, i64 }
@"async.init$guard" = global i1 false, align 1
define ptr @async.GenInts() presplitcoroutine {
entry:
%id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
%frame.size = call i64 @llvm.coro.size.i64()
%alloc.size = add i64 16, %frame.size
%promise = call ptr @"github.com/goplus/llgo/internal/runtime.AllocZ"(i64 %alloc.size)
%need.dyn.alloc = call i1 @llvm.coro.alloc(token %id)
br i1 %need.dyn.alloc, label %alloc, label %_llgo_5
alloc: ; preds = %entry
%0 = getelementptr ptr, ptr %promise, i64 16
br label %_llgo_5
clean: ; preds = %_llgo_8, %_llgo_7, %_llgo_6, %_llgo_5
%1 = call ptr @llvm.coro.free(token %id, ptr %hdl)
br label %suspend
suspend: ; preds = %_llgo_8, %_llgo_7, %_llgo_6, %_llgo_5, %clean
%2 = call i1 @llvm.coro.end(ptr %hdl, i1 false, token none)
ret ptr %promise
trap: ; preds = %_llgo_8
call void @llvm.trap()
unreachable
_llgo_5: ; preds = %alloc, %entry
%frame = phi ptr [ null, %entry ], [ %0, %alloc ]
%hdl = call ptr @llvm.coro.begin(token %id, ptr %frame)
store ptr %hdl, ptr %promise, align 8
call void @"github.com/goplus/llgo/x/async.(*Promise).setValue[int]"(ptr %promise, i64 1)
%3 = call i8 @llvm.coro.suspend(token %id, i1 false)
switch i8 %3, label %suspend [
i8 0, label %_llgo_6
i8 1, label %clean
]
_llgo_6: ; preds = %_llgo_5
call void @"github.com/goplus/llgo/x/async.(*Promise).setValue[int]"(ptr %promise, i64 2)
%4 = call i8 @llvm.coro.suspend(token %id, i1 false)
switch i8 %4, label %suspend [
i8 0, label %_llgo_7
i8 1, label %clean
]
_llgo_7: ; preds = %_llgo_6
call void @"github.com/goplus/llgo/x/async.(*Promise).setValue[int]"(ptr %promise, i64 3)
%5 = call i8 @llvm.coro.suspend(token %id, i1 false)
switch i8 %5, label %suspend [
i8 0, label %_llgo_8
i8 1, label %clean
]
_llgo_8: ; preds = %_llgo_7
%6 = call i8 @llvm.coro.suspend(token %id, i1 true)
switch i8 %6, label %suspend [
i8 0, label %trap
i8 1, label %clean
]
}
define i64 @async.UseGenInts() {
_llgo_0:
%0 = call ptr @async.WrapGenInts()
br label %_llgo_3
_llgo_1: ; preds = %_llgo_3
%1 = call i64 @"github.com/goplus/llgo/x/async.(*Promise).Value[int]"(ptr %0)
%2 = add i64 %3, %1
call void @"github.com/goplus/llgo/x/async.(*Promise).Next[int]"(ptr %0)
br label %_llgo_3
_llgo_2: ; preds = %_llgo_3
ret i64 %3
_llgo_3: ; preds = %_llgo_1, %_llgo_0
%3 = phi i64 [ 0, %_llgo_0 ], [ %2, %_llgo_1 ]
%4 = call i1 @"github.com/goplus/llgo/x/async.(*Promise).Done[int]"(ptr %0)
br i1 %4, label %_llgo_2, label %_llgo_1
}
define ptr @async.WrapGenInts() presplitcoroutine {
entry:
%id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
%frame.size = call i64 @llvm.coro.size.i64()
%alloc.size = add i64 16, %frame.size
%promise = call ptr @"github.com/goplus/llgo/internal/runtime.AllocZ"(i64 %alloc.size)
%need.dyn.alloc = call i1 @llvm.coro.alloc(token %id)
br i1 %need.dyn.alloc, label %alloc, label %_llgo_5
alloc: ; preds = %entry
%0 = getelementptr ptr, ptr %promise, i64 16
br label %_llgo_5
clean: ; preds = %_llgo_5
%1 = call ptr @llvm.coro.free(token %id, ptr %hdl)
br label %suspend
suspend: ; preds = %_llgo_5, %clean
%2 = call i1 @llvm.coro.end(ptr %hdl, i1 false, token none)
ret ptr %promise
trap: ; preds = %_llgo_5
call void @llvm.trap()
unreachable
_llgo_5: ; preds = %alloc, %entry
%frame = phi ptr [ null, %entry ], [ %0, %alloc ]
%hdl = call ptr @llvm.coro.begin(token %id, ptr %frame)
store ptr %hdl, ptr %promise, align 8
%3 = call ptr @async.GenInts()
%4 = call i8 @llvm.coro.suspend(token %id, i1 true)
switch i8 %4, label %suspend [
i8 0, label %trap
i8 1, label %clean
]
}
define void @async.init() {
_llgo_0:
%0 = load i1, ptr @"async.init$guard", align 1
br i1 %0, label %_llgo_2, label %_llgo_1
_llgo_1: ; preds = %_llgo_0
store i1 true, ptr @"async.init$guard", align 1
br label %_llgo_2
_llgo_2: ; preds = %_llgo_1, %_llgo_0
ret void
}
; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: read)
declare token @llvm.coro.id(i32, ptr readnone, ptr nocapture readonly, ptr)
; Function Attrs: nounwind memory(none)
declare i64 @llvm.coro.size.i64()
declare ptr @"github.com/goplus/llgo/internal/runtime.AllocZ"(i64)
; Function Attrs: nounwind
declare i1 @llvm.coro.alloc(token)
; Function Attrs: nounwind
declare ptr @llvm.coro.begin(token, ptr writeonly)
; Function Attrs: nounwind memory(argmem: read)
declare ptr @llvm.coro.free(token, ptr nocapture readonly)
; Function Attrs: nounwind
declare i1 @llvm.coro.end(ptr, i1, token)
; Function Attrs: cold noreturn nounwind memory(inaccessiblemem: write)
declare void @llvm.trap()
define void @"github.com/goplus/llgo/x/async.(*Promise).setValue[int]"(ptr %0, i64 %1) {
_llgo_0:
%2 = getelementptr inbounds %"github.com/goplus/llgo/x/async.Promise[int]", ptr %0, i32 0, i32 1
store i64 %1, ptr %2, align 4
ret void
}
; Function Attrs: nounwind
declare i8 @llvm.coro.suspend(token, i1)
define i1 @"github.com/goplus/llgo/x/async.(*Promise).Done[int]"(ptr %0) {
_llgo_0:
%1 = getelementptr inbounds %"github.com/goplus/llgo/x/async.Promise[int]", ptr %0, i32 0, i32 0
%2 = load ptr, ptr %1, align 8
%3 = call i1 @llvm.coro.done(ptr %2)
%4 = zext i1 %3 to i64
%5 = trunc i64 %4 to i8
%6 = icmp ne i8 %5, 0
ret i1 %6
}
define i64 @"github.com/goplus/llgo/x/async.(*Promise).Value[int]"(ptr %0) {
_llgo_0:
%1 = getelementptr inbounds %"github.com/goplus/llgo/x/async.Promise[int]", ptr %0, i32 0, i32 1
%2 = load i64, ptr %1, align 4
ret i64 %2
}
define void @"github.com/goplus/llgo/x/async.(*Promise).Next[int]"(ptr %0) {
_llgo_0:
%1 = getelementptr inbounds %"github.com/goplus/llgo/x/async.Promise[int]", ptr %0, i32 0, i32 0
%2 = load ptr, ptr %1, align 8
call void @llvm.coro.resume(ptr %2)
ret void
}
; Function Attrs: nounwind memory(argmem: readwrite)
declare i1 @llvm.coro.done(ptr nocapture readonly)
declare void @llvm.coro.resume(ptr)

119
cl/async.go Normal file
View File

@@ -0,0 +1,119 @@
/*
* Copyright (c) 2024 The GoPlus Authors (goplus.org). All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package cl
import (
"go/constant"
"go/types"
"strings"
llssa "github.com/goplus/llgo/ssa"
"golang.org/x/tools/go/ssa"
)
// TODO(lijie): need more generics, shouldn't limit to async.Promise
func promiseType(ty types.Type) (types.Type, bool) {
// ty is a generic type, so we need to check the package path and type name
if ptrTy, ok := ty.(*types.Pointer); ok {
ty = ptrTy.Elem()
if ty, ok := ty.(*types.Named); ok {
if ty.Obj().Pkg() == nil {
return nil, false
}
if ty.Obj().Pkg().Path() == "github.com/goplus/llgo/x/async" && ty.Obj().Name() == "Promise" {
return ty, true
}
}
}
return nil, false
}
// check function return async.Promise[T]
// TODO(lijie): make it generic
func isAsyncFunc(sig *types.Signature) bool {
r := sig.Results()
if r.Len() != 1 {
return false
}
ty := r.At(0).Type()
_, ok := promiseType(ty)
return ok
}
func (p *context) coAwait(b llssa.Builder, args []ssa.Value) llssa.Expr {
if !isAsyncFunc(b.Func.RawType().(*types.Signature)) {
panic("coAwait(promise *T) T: invalid context")
}
if len(args) == 1 {
// promise := p.compileValue(b, args[0])
b.Unreachable()
// return b.CoroutineAwait(promise)
}
panic("coAwait(promise *T) T: invalid arguments")
}
func (p *context) coSuspend(b llssa.Builder, final llssa.Expr) {
b.CoSuspend(b.AsyncToken(), final, nil)
}
func (p *context) coDone(b llssa.Builder, args []ssa.Value) llssa.Expr {
if len(args) != 1 {
panic("coDone(promise *T): invalid arguments")
}
hdl := p.compileValue(b, args[0])
return b.CoDone(hdl)
}
func (p *context) coResume(b llssa.Builder, args []ssa.Value) {
if len(args) == 1 {
hdl := p.compileValue(b, args[0])
b.CoResume(hdl)
}
}
func (p *context) getSetValueFunc(fn *ssa.Function) llssa.Function {
typ := fn.Signature.Recv().Type()
mthds := p.goProg.MethodSets.MethodSet(typ)
for i := 0; i < mthds.Len(); i++ {
m := mthds.At(i)
if ssaMthd := p.goProg.MethodValue(m); ssaMthd != nil {
if ssaMthd.Name() == "setValue" || strings.HasPrefix(ssaMthd.Name(), "setValue[") {
setValueFn, _, _ := p.compileFunction(ssaMthd)
return setValueFn
}
}
}
panic("method setValue not found on type " + typ.String())
}
func (p *context) coReturn(b llssa.Builder, fn *ssa.Function, args []ssa.Value) {
setValueFn := p.getSetValueFunc(fn)
value := p.compileValue(b, args[1])
b.CoReturn(setValueFn, value)
}
func (p *context) coYield(b llssa.Builder, fn *ssa.Function, args []ssa.Value) {
setValueFn := p.getSetValueFunc(fn)
value := p.compileValue(b, args[1])
// TODO(lijie): find whether the co.Yield/co.Return is the last instruction
final := b.Const(constant.MakeBool(false), b.Prog.Bool())
b.CoYield(setValueFn, value, final)
}
func (p *context) coRun(b llssa.Builder, args []ssa.Value) {
panic("coRun(): not implemented")
}

View File

@@ -218,6 +218,7 @@ func (p *context) compileFuncDecl(pkg llssa.Package, f *ssa.Function) (llssa.Fun
log.Println("==> NewFunc", name, "type:", sig.Recv(), sig, "ftype:", ftype)
}
}
async := isAsyncFunc(f.Signature)
if fn == nil {
if name == "main" {
argc := types.NewParam(token.NoPos, pkgTypes, "", types.Typ[types.Int32])
@@ -227,13 +228,24 @@ func (p *context) compileFuncDecl(pkg llssa.Package, f *ssa.Function) (llssa.Fun
results := types.NewTuple(ret)
sig = types.NewSignatureType(nil, nil, nil, params, results, false)
}
fn = pkg.NewFuncEx(name, sig, llssa.Background(ftype), hasCtx)
fn = pkg.NewFuncEx(name, sig, llssa.Background(ftype), hasCtx, async)
}
nBlkOff := 0
if nblk := len(f.Blocks); nblk > 0 {
fn.MakeBlocks(nblk) // to set fn.HasBody() = true
var entryBlk, allocBlk, cleanBlk, suspdBlk, trapBlk, beginBlk llssa.BasicBlock
if async {
nBlkOff = 5
entryBlk = fn.MakeBlock("entry")
allocBlk = fn.MakeBlock("alloc")
cleanBlk = fn.MakeBlock("clean")
suspdBlk = fn.MakeBlock("suspend")
trapBlk = fn.MakeBlock("trap")
}
fn.MakeBlocks(nblk) // to set fn.HasBody() = true
beginBlk = fn.Block(nBlkOff)
if f.Recover != nil { // set recover block
fn.SetRecover(fn.Block(f.Recover.Index))
// TODO(lijie): fix this for async function because of the block offset increase
fn.SetRecover(fn.Block(f.Recover.Index + nBlkOff))
}
p.inits = append(p.inits, func() {
p.fn = fn
@@ -249,6 +261,10 @@ func (p *context) compileFuncDecl(pkg llssa.Package, f *ssa.Function) (llssa.Fun
log.Println("==> FuncBody", name)
}
b := fn.NewBuilder()
b.SetBlockOffset(nBlkOff)
if async {
b.BeginAsync(fn, entryBlk, allocBlk, cleanBlk, suspdBlk, trapBlk, beginBlk)
}
p.bvals = make(map[ssa.Value]llssa.Expr)
off := make([]int, len(f.Blocks))
for i, block := range f.Blocks {
@@ -284,7 +300,7 @@ func (p *context) compileBlock(b llssa.Builder, block *ssa.BasicBlock, n int, do
var pkg = p.pkg
var fn = p.fn
var instrs = block.Instrs[n:]
var ret = fn.Block(block.Index)
var ret = fn.Block(block.Index + b.BlockOffset())
b.SetBlock(ret)
if doModInit {
if pyModInit = p.pyMod != ""; pyModInit {
@@ -322,7 +338,7 @@ func (p *context) compileBlock(b llssa.Builder, block *ssa.BasicBlock, n int, do
modPtr := pkg.PyNewModVar(modName, true).Expr
mod := b.Load(modPtr)
cond := b.BinOp(token.NEQ, mod, prog.Nil(mod.Type))
newBlk := fn.MakeBlock()
newBlk := fn.MakeBlock("")
b.If(cond, jumpTo, newBlk)
b.SetBlockEx(newBlk, llssa.AtEnd, false)
b.Store(modPtr, b.PyImportMod(modPath))
@@ -654,7 +670,11 @@ func (p *context) compileInstr(b llssa.Builder, instr ssa.Instruction) {
results = make([]llssa.Expr, 1)
results[0] = p.prog.IntVal(0, p.prog.CInt())
}
b.Return(results...)
if b.Async() {
b.EndAsync()
} else {
b.Return(results...)
}
case *ssa.If:
fn := p.fn
cond := p.compileValue(b, v.Cond)

View File

@@ -53,7 +53,7 @@ func TestFromTestrt(t *testing.T) {
}
func TestFromTestdata(t *testing.T) {
cltest.FromDir(t, "", "./_testdata", false)
cltest.FromDir(t, "", "./_testdata", true)
}
func TestFromTestpymath(t *testing.T) {
@@ -124,3 +124,114 @@ _llgo_2: ; preds = %_llgo_1, %_llgo_0
}
`)
}
func TestAsyncFunc(t *testing.T) {
testCompile(t, `package foo
import "github.com/goplus/llgo/x/async"
func GenInts() (co *async.Promise[int]) {
co.Yield(1)
co.Yield(2)
return
}
`, `; ModuleID = 'foo'
source_filename = "foo"
@"foo.init$guard" = global i1 false, align 1
define ptr @foo.GenInts() presplitcoroutine {
entry:
%id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
%frame.size = call i64 @llvm.coro.size.i64()
%alloc.size = add i64 16, %frame.size
%promise = call ptr @"github.com/goplus/llgo/internal/runtime.AllocZ"(i64 %alloc.size)
%need.dyn.alloc = call i1 @llvm.coro.alloc(token %id)
br i1 %need.dyn.alloc, label %alloc, label %_llgo_5
alloc: ; preds = %entry
%0 = getelementptr ptr, ptr %promise, i64 16
br label %_llgo_5
clean: ; preds = %_llgo_7, %_llgo_6, %_llgo_5
%1 = call ptr @llvm.coro.free(token %id, ptr %hdl)
br label %suspend
suspend: ; preds = %_llgo_7, %_llgo_6, %_llgo_5, %clean
%2 = call i1 @llvm.coro.end(ptr %hdl, i1 false, token none)
ret ptr %promise
trap: ; preds = %_llgo_7
call void @llvm.trap()
unreachable
_llgo_5: ; preds = %alloc, %entry
%frame = phi ptr [ null, %entry ], [ %0, %alloc ]
%hdl = call ptr @llvm.coro.begin(token %id, ptr %frame)
store ptr %hdl, ptr %promise, align 8
call void @"github.com/goplus/llgo/x/async.(*Promise).setValue[int]"(ptr %promise, i64 1)
%3 = call i8 @llvm.coro.suspend(token %id, i1 false)
switch i8 %3, label %suspend [
i8 0, label %_llgo_6
i8 1, label %clean
]
_llgo_6: ; preds = %_llgo_5
call void @"github.com/goplus/llgo/x/async.(*Promise).setValue[int]"(ptr %promise, i64 2)
%4 = call i8 @llvm.coro.suspend(token %id, i1 false)
switch i8 %4, label %suspend [
i8 0, label %_llgo_7
i8 1, label %clean
]
_llgo_7: ; preds = %_llgo_6
%5 = call i8 @llvm.coro.suspend(token %id, i1 true)
switch i8 %5, label %suspend [
i8 0, label %trap
i8 1, label %clean
]
}
define void @foo.init() {
_llgo_0:
%0 = load i1, ptr @"foo.init$guard", align 1
br i1 %0, label %_llgo_2, label %_llgo_1
_llgo_1: ; preds = %_llgo_0
store i1 true, ptr @"foo.init$guard", align 1
br label %_llgo_2
_llgo_2: ; preds = %_llgo_1, %_llgo_0
ret void
}
; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: read)
declare token @llvm.coro.id(i32, ptr readnone, ptr nocapture readonly, ptr)
; Function Attrs: nounwind memory(none)
declare i64 @llvm.coro.size.i64()
declare ptr @"github.com/goplus/llgo/internal/runtime.AllocZ"(i64)
; Function Attrs: nounwind
declare i1 @llvm.coro.alloc(token)
; Function Attrs: nounwind
declare ptr @llvm.coro.begin(token, ptr writeonly)
; Function Attrs: nounwind memory(argmem: read)
declare ptr @llvm.coro.free(token, ptr nocapture readonly)
; Function Attrs: nounwind
declare i1 @llvm.coro.end(ptr, i1, token)
; Function Attrs: cold noreturn nounwind memory(inaccessiblemem: write)
declare void @llvm.trap()
declare void @"github.com/goplus/llgo/x/async.(*Promise).setValue[int]"(ptr, i64)
; Function Attrs: nounwind
declare i8 @llvm.coro.suspend(token, i1)
`)
}

View File

@@ -412,7 +412,16 @@ const (
llgoAtomicUMax = llgoAtomicOpBase + llssa.OpUMax
llgoAtomicUMin = llgoAtomicOpBase + llssa.OpUMin
llgoAtomicOpLast = llgoAtomicOpBase + int(llssa.OpUMin)
llgoCoBase = llgoInstrBase + 0x30
llgoCoAwait = llgoCoBase + 0
llgoCoSuspend = llgoCoBase + 1
llgoCoDone = llgoCoBase + 2
llgoCoResume = llgoCoBase + 3
llgoCoReturn = llgoCoBase + 4
llgoCoYield = llgoCoBase + 5
llgoCoRun = llgoCoBase + 6
llgoAtomicOpLast = llgoCoRun
)
func (p *context) funcName(fn *ssa.Function, ignore bool) (*types.Package, string, int) {

View File

@@ -237,6 +237,14 @@ var llgoInstrs = map[string]int{
"atomicMin": int(llgoAtomicMin),
"atomicUMax": int(llgoAtomicUMax),
"atomicUMin": int(llgoAtomicUMin),
"coAwait": int(llgoCoAwait),
"coResume": int(llgoCoResume),
"coSuspend": int(llgoCoSuspend),
"coDone": int(llgoCoDone),
"coReturn": int(llgoCoReturn),
"coYield": int(llgoCoYield),
"coRun": int(llgoCoRun),
}
// funcOf returns a function by name and set ftype = goFunc, cFunc, etc.
@@ -265,7 +273,8 @@ func (p *context) funcOf(fn *ssa.Function) (aFn llssa.Function, pyFn llssa.PyObj
return nil, nil, ignoredFunc
}
sig := fn.Signature
aFn = pkg.NewFuncEx(name, sig, llssa.Background(ftype), false)
async := isAsyncFunc(sig)
aFn = pkg.NewFuncEx(name, sig, llssa.Background(ftype), false, async)
}
}
return
@@ -390,6 +399,20 @@ func (p *context) call(b llssa.Builder, act llssa.DoAction, call *ssa.CallCommon
ret = p.funcAddr(b, args)
case llgoUnreachable: // func unreachable()
b.Unreachable()
case llgoCoAwait:
ret = p.coAwait(b, args)
case llgoCoSuspend:
p.coSuspend(b, p.prog.BoolVal(false))
case llgoCoDone:
return p.coDone(b, args)
case llgoCoResume:
p.coResume(b, args)
case llgoCoReturn:
p.coReturn(b, cv, args)
case llgoCoYield:
p.coYield(b, cv, args)
case llgoCoRun:
p.coRun(b, args)
default:
if ftype >= llgoAtomicOpBase && ftype <= llgoAtomicOpLast {
ret = p.atomic(b, llssa.AtomicOp(ftype-llgoAtomicOpBase), args)