diff --git a/ssa/expr.go b/ssa/expr.go index d1b158dd..391f66e3 100644 --- a/ssa/expr.go +++ b/ssa/expr.go @@ -383,7 +383,7 @@ var boolPredOpToLLVM = []llvm.IntPredicate{ token.NEQ - predOpBase: llvm.IntNE, } -// EQL NEQ LSS LEQ GTR GEQ == != < <= < >= +// EQL NEQ LSS LEQ GTR GEQ == != < <= > >= func isPredOp(op token.Token) bool { return op >= predOpBase && op <= predOpLast } @@ -392,7 +392,7 @@ func isPredOp(op token.Token) bool { // op can be: // ADD SUB MUL QUO REM + - * / % // AND OR XOR SHL SHR AND_NOT & | ^ << >> &^ -// EQL NEQ LSS LEQ GTR GEQ == != < <= < >= +// EQL NEQ LSS LEQ GTR GEQ == != < <= > >= func (b Builder) BinOp(op token.Token, x, y Expr) Expr { if debugInstr { log.Printf("BinOp %d, %v, %v\n", op, x.impl, y.impl) @@ -490,7 +490,7 @@ func (b Builder) BinOp(op token.Token, x, y Expr) Expr { llop := logicOpToLLVM[op-logicOpBase] return Expr{llvm.CreateBinOp(b.impl, llop, x.impl, y.impl), x.Type} } - case isPredOp(op): // op: == != < <= < >= + case isPredOp(op): // op: == != < <= > >= prog := b.Prog tret := prog.Bool() kind := x.kind @@ -1014,6 +1014,22 @@ func (b Builder) Do(da DoAction, fn Expr, args ...Expr) (ret Expr) { return } +// compareSelect performs a series of comparisons and selections based on the +// given comparison op. It's used to implement operations like min and max. +// +// The function iterates through the provided expressions, comparing each with +// the current result using the specified comparison op. It selects the +// appropriate value based on the comparison. +func (b Builder) compareSelect(op token.Token, x Expr, y ...Expr) Expr { + ret := x + for _, v := range y { + cond := b.BinOp(op, ret, v) + sel := llvm.CreateSelect(b.impl, cond.impl, ret.impl, v.impl) + ret = Expr{sel, ret.Type} + } + return ret +} + // A Builtin represents a specific use of a built-in function, e.g. len. // // Builtins are immutable values. Builtins do not have addresses. @@ -1124,6 +1140,14 @@ func (b Builder) BuiltinCall(fn string, args ...Expr) (ret Expr) { b.Call(b.Pkg.rtFunc("MapClear"), t, m) return } + case "min": + if len(args) > 0 { + return b.compareSelect(token.LSS, args[0], args[1:]...) + } + case "max": + if len(args) > 0 { + return b.compareSelect(token.GTR, args[0], args[1:]...) + } } panic("todo: " + fn) } diff --git a/ssa/ssa_test.go b/ssa/ssa_test.go index 81cd739f..4946869d 100644 --- a/ssa/ssa_test.go +++ b/ssa/ssa_test.go @@ -521,3 +521,34 @@ func TestBasicType(t *testing.T) { } } } + +func TestCompareSelect(t *testing.T) { + prog := NewProgram(nil) + pkg := prog.NewPackage("bar", "foo/bar") + + params := types.NewTuple( + types.NewVar(0, nil, "a", types.Typ[types.Int]), + types.NewVar(0, nil, "b", types.Typ[types.Int]), + types.NewVar(0, nil, "c", types.Typ[types.Int]), + ) + rets := types.NewTuple(types.NewVar(0, nil, "", types.Typ[types.Int])) + sig := types.NewSignatureType(nil, nil, nil, params, rets, false) + fn := pkg.NewFunc("fn", sig, InGo) + + b := fn.MakeBody(1) + result := b.compareSelect(token.GTR, fn.Param(0), fn.Param(1), fn.Param(2)) + b.Return(result) + + assertPkg(t, pkg, `; ModuleID = 'foo/bar' +source_filename = "foo/bar" + +define i64 @fn(i64 %0, i64 %1, i64 %2) { +_llgo_0: + %3 = icmp sgt i64 %0, %1 + %4 = select i1 %3, i64 %0, i64 %1 + %5 = icmp sgt i64 %4, %2 + %6 = select i1 %5, i64 %4, i64 %2 + ret i64 %6 +} +`) +}