Merge pull request #373 from visualfc/complex

ssa: complex op and print/panic
This commit is contained in:
xushiwei
2024-06-20 22:05:04 +08:00
committed by GitHub
7 changed files with 254 additions and 10 deletions

20
cl/_testrt/complex/in.go Normal file
View File

@@ -0,0 +1,20 @@
package main
type T complex64
func main() {
a := 1 + 2i
b := 3 + 4i
c := 0 + 0i
println(real(a), imag(a))
println(-a)
println(a + b)
println(a - b)
println(a * b)
println(a / b)
println(a / c)
println(c / c)
println(a == a, a != a)
println(a == b, a != b)
println(complex128(T(a)) == a)
}

129
cl/_testrt/complex/out.ll Normal file
View File

@@ -0,0 +1,129 @@
; ModuleID = 'main'
source_filename = "main"
@"main.init$guard" = global i1 false, align 1
@__llgo_argc = global i32 0, align 4
@__llgo_argv = global ptr null, align 8
define void @main.init() {
_llgo_0:
%0 = load i1, ptr @"main.init$guard", align 1
br i1 %0, label %_llgo_2, label %_llgo_1
_llgo_1: ; preds = %_llgo_0
store i1 true, ptr @"main.init$guard", align 1
br label %_llgo_2
_llgo_2: ; preds = %_llgo_1, %_llgo_0
ret void
}
define i32 @main(i32 %0, ptr %1) {
_llgo_0:
store i32 %0, ptr @__llgo_argc, align 4
store ptr %1, ptr @__llgo_argv, align 8
call void @"github.com/goplus/llgo/internal/runtime.init"()
call void @main.init()
call void @"github.com/goplus/llgo/internal/runtime.PrintFloat"(double 1.000000e+00)
call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 32)
call void @"github.com/goplus/llgo/internal/runtime.PrintFloat"(double 2.000000e+00)
call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10)
%2 = alloca { double, double }, align 8
%3 = getelementptr inbounds { double, double }, ptr %2, i32 0, i32 0
store double -1.000000e+00, ptr %3, align 8
%4 = getelementptr inbounds { double, double }, ptr %2, i32 0, i32 1
store double -2.000000e+00, ptr %4, align 8
%5 = load { double, double }, ptr %2, align 8
call void @"github.com/goplus/llgo/internal/runtime.PrintComplex"({ double, double } %5)
call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10)
%6 = alloca { double, double }, align 8
%7 = getelementptr inbounds { double, double }, ptr %6, i32 0, i32 0
store double 4.000000e+00, ptr %7, align 8
%8 = getelementptr inbounds { double, double }, ptr %6, i32 0, i32 1
store double 6.000000e+00, ptr %8, align 8
%9 = load { double, double }, ptr %6, align 8
call void @"github.com/goplus/llgo/internal/runtime.PrintComplex"({ double, double } %9)
call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10)
%10 = alloca { double, double }, align 8
%11 = getelementptr inbounds { double, double }, ptr %10, i32 0, i32 0
store double -2.000000e+00, ptr %11, align 8
%12 = getelementptr inbounds { double, double }, ptr %10, i32 0, i32 1
store double -2.000000e+00, ptr %12, align 8
%13 = load { double, double }, ptr %10, align 8
call void @"github.com/goplus/llgo/internal/runtime.PrintComplex"({ double, double } %13)
call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10)
%14 = alloca { double, double }, align 8
%15 = getelementptr inbounds { double, double }, ptr %14, i32 0, i32 0
store double -5.000000e+00, ptr %15, align 8
%16 = getelementptr inbounds { double, double }, ptr %14, i32 0, i32 1
store double 1.000000e+01, ptr %16, align 8
%17 = load { double, double }, ptr %14, align 8
call void @"github.com/goplus/llgo/internal/runtime.PrintComplex"({ double, double } %17)
call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10)
%18 = alloca { double, double }, align 8
%19 = getelementptr inbounds { double, double }, ptr %18, i32 0, i32 0
store double 4.400000e-01, ptr %19, align 8
%20 = getelementptr inbounds { double, double }, ptr %18, i32 0, i32 1
store double -8.000000e-02, ptr %20, align 8
%21 = load { double, double }, ptr %18, align 8
call void @"github.com/goplus/llgo/internal/runtime.PrintComplex"({ double, double } %21)
call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10)
%22 = alloca { double, double }, align 8
%23 = getelementptr inbounds { double, double }, ptr %22, i32 0, i32 0
store double 0x7FF0000000000000, ptr %23, align 8
%24 = getelementptr inbounds { double, double }, ptr %22, i32 0, i32 1
store double 0x7FF0000000000000, ptr %24, align 8
%25 = load { double, double }, ptr %22, align 8
call void @"github.com/goplus/llgo/internal/runtime.PrintComplex"({ double, double } %25)
call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10)
%26 = alloca { double, double }, align 8
%27 = getelementptr inbounds { double, double }, ptr %26, i32 0, i32 0
store double 0x7FF8000000000000, ptr %27, align 8
%28 = getelementptr inbounds { double, double }, ptr %26, i32 0, i32 1
store double 0x7FF8000000000000, ptr %28, align 8
%29 = load { double, double }, ptr %26, align 8
call void @"github.com/goplus/llgo/internal/runtime.PrintComplex"({ double, double } %29)
call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10)
call void @"github.com/goplus/llgo/internal/runtime.PrintBool"(i1 true)
call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 32)
call void @"github.com/goplus/llgo/internal/runtime.PrintBool"(i1 false)
call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10)
call void @"github.com/goplus/llgo/internal/runtime.PrintBool"(i1 false)
call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 32)
call void @"github.com/goplus/llgo/internal/runtime.PrintBool"(i1 true)
call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10)
%30 = alloca { float, float }, align 8
%31 = getelementptr inbounds { float, float }, ptr %30, i32 0, i32 0
store float 1.000000e+00, ptr %31, align 4
%32 = getelementptr inbounds { float, float }, ptr %30, i32 0, i32 1
store float 2.000000e+00, ptr %32, align 4
%33 = load { float, float }, ptr %30, align 4
%34 = extractvalue { float, float } %33, 0
%35 = extractvalue { float, float } %33, 1
%36 = fpext float %34 to double
%37 = fpext float %35 to double
%38 = alloca { double, double }, align 8
%39 = getelementptr inbounds { double, double }, ptr %38, i32 0, i32 0
store double %36, ptr %39, align 8
%40 = getelementptr inbounds { double, double }, ptr %38, i32 0, i32 1
store double %37, ptr %40, align 8
%41 = load { double, double }, ptr %38, align 8
%42 = extractvalue { double, double } %41, 0
%43 = extractvalue { double, double } %41, 1
%44 = fcmp oeq double %42, 1.000000e+00
%45 = fcmp oeq double %43, 2.000000e+00
%46 = and i1 %44, %45
call void @"github.com/goplus/llgo/internal/runtime.PrintBool"(i1 %46)
call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10)
ret i32 0
}
declare void @"github.com/goplus/llgo/internal/runtime.init"()
declare void @"github.com/goplus/llgo/internal/runtime.PrintFloat"(double)
declare void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8)
declare void @"github.com/goplus/llgo/internal/runtime.PrintComplex"({ double, double })
declare void @"github.com/goplus/llgo/internal/runtime.PrintBool"(i1)

View File

@@ -47,12 +47,12 @@ func PrintFloat(v float64) {
} }
return return
} }
c.Fprintf(c.Stderr, c.Str("%e"), v) c.Fprintf(c.Stderr, c.Str("%+e"), v)
} }
// func PrintComplex(c complex128) { func PrintComplex(v complex128) {
// print("(", real(c), imag(c), "i)") print("(", real(v), imag(v), "i)")
// } }
func PrintUint(v uint64) { func PrintUint(v uint64) {
c.Fprintf(c.Stderr, c.Str("%llu"), v) c.Fprintf(c.Stderr, c.Str("%llu"), v)

View File

@@ -122,6 +122,10 @@ func TracePanic(v any) {
} else { } else {
println(*(*float64)(e.data)) println(*(*float64)(e.data))
} }
case abi.Complex64:
println(*(*complex64)(e.data))
case abi.Complex128:
println(*(*complex128)(e.data))
case abi.String: case abi.String:
println(*(*string)(e.data)) println(*(*string)(e.data))
default: default:

View File

@@ -89,7 +89,7 @@ func DataKindOf(raw types.Type, lvl int, is32Bits bool) (DataKind, types.Type, i
return Integer, raw, lvl return Integer, raw, lvl
case kind == types.Float32: case kind == types.Float32:
return BitCast, raw, lvl return BitCast, raw, lvl
case kind == types.Float64 || kind == types.Complex64: case kind == types.Float64:
if is32Bits { if is32Bits {
return Indirect, raw, lvl return Indirect, raw, lvl
} }

View File

@@ -398,6 +398,52 @@ func (b Builder) BinOp(op token.Token, x, y Expr) Expr {
return Expr{b.InlineCall(b.Pkg.rtFunc("StringCat"), x, y).impl, x.Type} return Expr{b.InlineCall(b.Pkg.rtFunc("StringCat"), x, y).impl, x.Type}
} }
case vkComplex: case vkComplex:
xr, xi := b.impl.CreateExtractValue(x.impl, 0, ""), b.impl.CreateExtractValue(x.impl, 1, "")
yr, yi := b.impl.CreateExtractValue(y.impl, 0, ""), b.impl.CreateExtractValue(y.impl, 1, "")
switch op {
case token.ADD:
r := llvm.CreateBinOp(b.impl, llvm.FAdd, xr, yr)
i := llvm.CreateBinOp(b.impl, llvm.FAdd, xi, yi)
return b.aggregateValue(x.Type, r, i)
case token.SUB:
r := llvm.CreateBinOp(b.impl, llvm.FSub, xr, yr)
i := llvm.CreateBinOp(b.impl, llvm.FSub, xi, yi)
return b.aggregateValue(x.Type, r, i)
case token.MUL:
r := llvm.CreateBinOp(b.impl, llvm.FSub,
llvm.CreateBinOp(b.impl, llvm.FMul, xr, yr),
llvm.CreateBinOp(b.impl, llvm.FMul, xi, yi),
)
i := llvm.CreateBinOp(b.impl, llvm.FAdd,
llvm.CreateBinOp(b.impl, llvm.FMul, xr, yi),
llvm.CreateBinOp(b.impl, llvm.FMul, xi, yr),
)
return b.aggregateValue(x.Type, r, i)
case token.QUO:
d := llvm.CreateBinOp(b.impl, llvm.FAdd, llvm.CreateBinOp(b.impl, llvm.FMul, yr, yr), llvm.CreateBinOp(b.impl, llvm.FMul, yi, yi))
zero := llvm.CreateFCmp(b.impl, llvm.FloatOEQ, d, llvm.ConstNull(d.Type()))
r := llvm.CreateSelect(b.impl, zero,
llvm.CreateBinOp(b.impl, llvm.FDiv, xr, d),
llvm.CreateBinOp(b.impl, llvm.FDiv,
llvm.CreateBinOp(b.impl, llvm.FAdd,
llvm.CreateBinOp(b.impl, llvm.FMul, xr, yr),
llvm.CreateBinOp(b.impl, llvm.FMul, xi, yi),
),
d,
),
)
i := llvm.CreateSelect(b.impl, zero,
llvm.CreateBinOp(b.impl, llvm.FDiv, xi, d),
llvm.CreateBinOp(b.impl, llvm.FDiv,
llvm.CreateBinOp(b.impl, llvm.FSub,
llvm.CreateBinOp(b.impl, llvm.FMul, xr, yi),
llvm.CreateBinOp(b.impl, llvm.FMul, xi, yr),
),
d,
),
)
return b.aggregateValue(x.Type, r, i)
}
default: default:
idx := mathOpIdx(op, kind) idx := mathOpIdx(op, kind)
if llop := mathOpToLLVM[idx]; llop != 0 { if llop := mathOpToLLVM[idx]; llop != 0 {
@@ -453,7 +499,25 @@ func (b Builder) BinOp(op token.Token, x, y Expr) Expr {
case vkBool: case vkBool:
pred := boolPredOpToLLVM[op-predOpBase] pred := boolPredOpToLLVM[op-predOpBase]
return Expr{llvm.CreateICmp(b.impl, pred, x.impl, y.impl), tret} return Expr{llvm.CreateICmp(b.impl, pred, x.impl, y.impl), tret}
case vkString, vkComplex: case vkComplex:
switch op {
case token.EQL:
xr, xi := b.impl.CreateExtractValue(x.impl, 0, ""), b.impl.CreateExtractValue(x.impl, 1, "")
yr, yi := b.impl.CreateExtractValue(y.impl, 0, ""), b.impl.CreateExtractValue(y.impl, 1, "")
return Expr{llvm.CreateAnd(b.impl,
llvm.CreateFCmp(b.impl, llvm.FloatOEQ, xr, yr),
llvm.CreateFCmp(b.impl, llvm.FloatOEQ, xi, yi),
), tret}
case token.NEQ:
xr, xi := b.impl.CreateExtractValue(x.impl, 0, ""), b.impl.CreateExtractValue(x.impl, 1, "")
yr, yi := b.impl.CreateExtractValue(y.impl, 0, ""), b.impl.CreateExtractValue(y.impl, 1, "")
return Expr{b.impl.CreateOr(
llvm.CreateFCmp(b.impl, llvm.FloatUNE, xr, yr),
llvm.CreateFCmp(b.impl, llvm.FloatUNE, xi, yi),
"",
), tret}
}
case vkString:
switch op { switch op {
case token.EQL: case token.EQL:
return b.InlineCall(b.Pkg.rtFunc("StringEqual"), x, y) return b.InlineCall(b.Pkg.rtFunc("StringEqual"), x, y)
@@ -567,6 +631,10 @@ func (b Builder) UnOp(op token.Token, x Expr) (ret Expr) {
ret.impl = llvm.CreateNeg(b.impl, x.impl) ret.impl = llvm.CreateNeg(b.impl, x.impl)
} else if t.Info()&types.IsFloat != 0 { } else if t.Info()&types.IsFloat != 0 {
ret.impl = llvm.CreateFNeg(b.impl, x.impl) ret.impl = llvm.CreateFNeg(b.impl, x.impl)
} else if t.Info()&types.IsComplex != 0 {
r := b.impl.CreateExtractValue(x.impl, 0, "")
i := b.impl.CreateExtractValue(x.impl, 1, "")
return b.aggregateValue(x.Type, llvm.CreateFNeg(b.impl, r), llvm.CreateFNeg(b.impl, i))
} else { } else {
panic("todo") panic("todo")
} }
@@ -685,6 +753,18 @@ func (b Builder) Convert(t Type, x Expr) (ret Expr) {
ret.impl = b.InlineCall(b.Func.Pkg.rtFunc("StringFromRune"), x).impl ret.impl = b.InlineCall(b.Func.Pkg.rtFunc("StringFromRune"), x).impl
return return
} }
case types.Complex128:
switch xtyp := x.RawType().Underlying().(type) {
case *types.Basic:
if xtyp.Kind() == types.Complex64 {
r := b.impl.CreateExtractValue(x.impl, 0, "")
i := b.impl.CreateExtractValue(x.impl, 1, "")
r = castFloat(b, r, b.Prog.Float64())
i = castFloat(b, i, b.Prog.Float64())
ret.impl = b.aggregateValue(t, r, i).impl
return
}
}
} }
switch xtyp := x.RawType().Underlying().(type) { switch xtyp := x.RawType().Underlying().(type) {
case *types.Basic: case *types.Basic:
@@ -716,6 +796,16 @@ func (b Builder) Convert(t Type, x Expr) (ret Expr) {
} }
} }
} }
if x.kind == vkComplex && t.kind == vkComplex {
ft := b.Prog.Float64()
if t.raw.Type.Underlying().(*types.Basic).Kind() == types.Complex64 {
ft = b.Prog.Float32()
}
r := b.impl.CreateExtractValue(x.impl, 0, "")
i := b.impl.CreateExtractValue(x.impl, 1, "")
ret.impl = b.Complex(Expr{castFloat(b, r, ft), ft}, Expr{castFloat(b, i, ft), ft}).impl
return
}
case *types.Pointer: case *types.Pointer:
ret.impl = castPtr(b.impl, x.impl, t.ll) ret.impl = castPtr(b.impl, x.impl, t.ll)
return return
@@ -1015,8 +1105,9 @@ func (b Builder) PrintEx(ln bool, args ...Expr) (ret Expr) {
fn = "PrintEface" fn = "PrintEface"
case vkIface: case vkIface:
fn = "PrintIface" fn = "PrintIface"
// case vkComplex: case vkComplex:
// fn = "PrintComplex" fn = "PrintComplex"
typ = prog.Complex128()
default: default:
panic(fmt.Errorf("illegal types for operand: print %v", arg.RawType())) panic(fmt.Errorf("illegal types for operand: print %v", arg.RawType()))
} }

View File

@@ -211,7 +211,7 @@ func (p Program) Index(typ Type) Type {
func (p Program) Field(typ Type, i int) Type { func (p Program) Field(typ Type, i int) Type {
var fld *types.Var var fld *types.Var
switch t := typ.raw.Type.(type) { switch t := typ.raw.Type.Underlying().(type) {
case *types.Tuple: case *types.Tuple:
fld = t.At(i) fld = t.At(i)
case *types.Basic: case *types.Basic:
@@ -223,7 +223,7 @@ func (p Program) Field(typ Type, i int) Type {
} }
panic("Field: basic type doesn't have fields") panic("Field: basic type doesn't have fields")
default: default:
fld = t.Underlying().(*types.Struct).Field(i) fld = t.(*types.Struct).Field(i)
} }
return p.rawType(fld.Type()) return p.rawType(fld.Type())
} }