Merge pull request #614 from aofei/min-max

ssa: add support for `min` and `max` built-in functions
This commit is contained in:
xushiwei
2024-07-30 19:59:33 +08:00
committed by GitHub
2 changed files with 58 additions and 3 deletions

View File

@@ -383,7 +383,7 @@ var boolPredOpToLLVM = []llvm.IntPredicate{
token.NEQ - predOpBase: llvm.IntNE, token.NEQ - predOpBase: llvm.IntNE,
} }
// EQL NEQ LSS LEQ GTR GEQ == != < <= < >= // EQL NEQ LSS LEQ GTR GEQ == != < <= > >=
func isPredOp(op token.Token) bool { func isPredOp(op token.Token) bool {
return op >= predOpBase && op <= predOpLast return op >= predOpBase && op <= predOpLast
} }
@@ -392,7 +392,7 @@ func isPredOp(op token.Token) bool {
// op can be: // op can be:
// ADD SUB MUL QUO REM + - * / % // ADD SUB MUL QUO REM + - * / %
// AND OR XOR SHL SHR AND_NOT & | ^ << >> &^ // AND OR XOR SHL SHR AND_NOT & | ^ << >> &^
// EQL NEQ LSS LEQ GTR GEQ == != < <= < >= // EQL NEQ LSS LEQ GTR GEQ == != < <= > >=
func (b Builder) BinOp(op token.Token, x, y Expr) Expr { func (b Builder) BinOp(op token.Token, x, y Expr) Expr {
if debugInstr { if debugInstr {
log.Printf("BinOp %d, %v, %v\n", op, x.impl, y.impl) log.Printf("BinOp %d, %v, %v\n", op, x.impl, y.impl)
@@ -490,7 +490,7 @@ func (b Builder) BinOp(op token.Token, x, y Expr) Expr {
llop := logicOpToLLVM[op-logicOpBase] llop := logicOpToLLVM[op-logicOpBase]
return Expr{llvm.CreateBinOp(b.impl, llop, x.impl, y.impl), x.Type} return Expr{llvm.CreateBinOp(b.impl, llop, x.impl, y.impl), x.Type}
} }
case isPredOp(op): // op: == != < <= < >= case isPredOp(op): // op: == != < <= > >=
prog := b.Prog prog := b.Prog
tret := prog.Bool() tret := prog.Bool()
kind := x.kind kind := x.kind
@@ -1014,6 +1014,22 @@ func (b Builder) Do(da DoAction, fn Expr, args ...Expr) (ret Expr) {
return return
} }
// compareSelect performs a series of comparisons and selections based on the
// given comparison op. It's used to implement operations like min and max.
//
// The function iterates through the provided expressions, comparing each with
// the current result using the specified comparison op. It selects the
// appropriate value based on the comparison.
func (b Builder) compareSelect(op token.Token, x Expr, y ...Expr) Expr {
ret := x
for _, v := range y {
cond := b.BinOp(op, ret, v)
sel := llvm.CreateSelect(b.impl, cond.impl, ret.impl, v.impl)
ret = Expr{sel, ret.Type}
}
return ret
}
// A Builtin represents a specific use of a built-in function, e.g. len. // A Builtin represents a specific use of a built-in function, e.g. len.
// //
// Builtins are immutable values. Builtins do not have addresses. // Builtins are immutable values. Builtins do not have addresses.
@@ -1124,6 +1140,14 @@ func (b Builder) BuiltinCall(fn string, args ...Expr) (ret Expr) {
b.Call(b.Pkg.rtFunc("MapClear"), t, m) b.Call(b.Pkg.rtFunc("MapClear"), t, m)
return return
} }
case "min":
if len(args) > 0 {
return b.compareSelect(token.LSS, args[0], args[1:]...)
}
case "max":
if len(args) > 0 {
return b.compareSelect(token.GTR, args[0], args[1:]...)
}
} }
panic("todo: " + fn) panic("todo: " + fn)
} }

View File

@@ -521,3 +521,34 @@ func TestBasicType(t *testing.T) {
} }
} }
} }
func TestCompareSelect(t *testing.T) {
prog := NewProgram(nil)
pkg := prog.NewPackage("bar", "foo/bar")
params := types.NewTuple(
types.NewVar(0, nil, "a", types.Typ[types.Int]),
types.NewVar(0, nil, "b", types.Typ[types.Int]),
types.NewVar(0, nil, "c", types.Typ[types.Int]),
)
rets := types.NewTuple(types.NewVar(0, nil, "", types.Typ[types.Int]))
sig := types.NewSignatureType(nil, nil, nil, params, rets, false)
fn := pkg.NewFunc("fn", sig, InGo)
b := fn.MakeBody(1)
result := b.compareSelect(token.GTR, fn.Param(0), fn.Param(1), fn.Param(2))
b.Return(result)
assertPkg(t, pkg, `; ModuleID = 'foo/bar'
source_filename = "foo/bar"
define i64 @fn(i64 %0, i64 %1, i64 %2) {
_llgo_0:
%3 = icmp sgt i64 %0, %1
%4 = select i1 %3, i64 %0, i64 %1
%5 = icmp sgt i64 %4, %2
%6 = select i1 %5, i64 %4, i64 %2
ret i64 %6
}
`)
}