diff --git a/cl/_testrt/complex/in.go b/cl/_testrt/complex/in.go new file mode 100644 index 00000000..1457c2c5 --- /dev/null +++ b/cl/_testrt/complex/in.go @@ -0,0 +1,20 @@ +package main + +type T complex64 + +func main() { + a := 1 + 2i + b := 3 + 4i + c := 0 + 0i + println(real(a), imag(a)) + println(-a) + println(a + b) + println(a - b) + println(a * b) + println(a / b) + println(a / c) + println(c / c) + println(a == a, a != a) + println(a == b, a != b) + println(complex128(T(a)) == a) +} diff --git a/cl/_testrt/complex/out.ll b/cl/_testrt/complex/out.ll new file mode 100644 index 00000000..431168f7 --- /dev/null +++ b/cl/_testrt/complex/out.ll @@ -0,0 +1,129 @@ +; ModuleID = 'main' +source_filename = "main" + +@"main.init$guard" = global i1 false, align 1 +@__llgo_argc = global i32 0, align 4 +@__llgo_argv = global ptr null, align 8 + +define void @main.init() { +_llgo_0: + %0 = load i1, ptr @"main.init$guard", align 1 + br i1 %0, label %_llgo_2, label %_llgo_1 + +_llgo_1: ; preds = %_llgo_0 + store i1 true, ptr @"main.init$guard", align 1 + br label %_llgo_2 + +_llgo_2: ; preds = %_llgo_1, %_llgo_0 + ret void +} + +define i32 @main(i32 %0, ptr %1) { +_llgo_0: + store i32 %0, ptr @__llgo_argc, align 4 + store ptr %1, ptr @__llgo_argv, align 8 + call void @"github.com/goplus/llgo/internal/runtime.init"() + call void @main.init() + call void @"github.com/goplus/llgo/internal/runtime.PrintFloat"(double 1.000000e+00) + call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 32) + call void @"github.com/goplus/llgo/internal/runtime.PrintFloat"(double 2.000000e+00) + call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10) + %2 = alloca { double, double }, align 8 + %3 = getelementptr inbounds { double, double }, ptr %2, i32 0, i32 0 + store double -1.000000e+00, ptr %3, align 8 + %4 = getelementptr inbounds { double, double }, ptr %2, i32 0, i32 1 + store double -2.000000e+00, ptr %4, align 8 + %5 = load { double, double }, ptr %2, align 8 + call void @"github.com/goplus/llgo/internal/runtime.PrintComplex"({ double, double } %5) + call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10) + %6 = alloca { double, double }, align 8 + %7 = getelementptr inbounds { double, double }, ptr %6, i32 0, i32 0 + store double 4.000000e+00, ptr %7, align 8 + %8 = getelementptr inbounds { double, double }, ptr %6, i32 0, i32 1 + store double 6.000000e+00, ptr %8, align 8 + %9 = load { double, double }, ptr %6, align 8 + call void @"github.com/goplus/llgo/internal/runtime.PrintComplex"({ double, double } %9) + call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10) + %10 = alloca { double, double }, align 8 + %11 = getelementptr inbounds { double, double }, ptr %10, i32 0, i32 0 + store double -2.000000e+00, ptr %11, align 8 + %12 = getelementptr inbounds { double, double }, ptr %10, i32 0, i32 1 + store double -2.000000e+00, ptr %12, align 8 + %13 = load { double, double }, ptr %10, align 8 + call void @"github.com/goplus/llgo/internal/runtime.PrintComplex"({ double, double } %13) + call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10) + %14 = alloca { double, double }, align 8 + %15 = getelementptr inbounds { double, double }, ptr %14, i32 0, i32 0 + store double -5.000000e+00, ptr %15, align 8 + %16 = getelementptr inbounds { double, double }, ptr %14, i32 0, i32 1 + store double 1.000000e+01, ptr %16, align 8 + %17 = load { double, double }, ptr %14, align 8 + call void @"github.com/goplus/llgo/internal/runtime.PrintComplex"({ double, double } %17) + call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10) + %18 = alloca { double, double }, align 8 + %19 = getelementptr inbounds { double, double }, ptr %18, i32 0, i32 0 + store double 4.400000e-01, ptr %19, align 8 + %20 = getelementptr inbounds { double, double }, ptr %18, i32 0, i32 1 + store double -8.000000e-02, ptr %20, align 8 + %21 = load { double, double }, ptr %18, align 8 + call void @"github.com/goplus/llgo/internal/runtime.PrintComplex"({ double, double } %21) + call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10) + %22 = alloca { double, double }, align 8 + %23 = getelementptr inbounds { double, double }, ptr %22, i32 0, i32 0 + store double 0x7FF0000000000000, ptr %23, align 8 + %24 = getelementptr inbounds { double, double }, ptr %22, i32 0, i32 1 + store double 0x7FF0000000000000, ptr %24, align 8 + %25 = load { double, double }, ptr %22, align 8 + call void @"github.com/goplus/llgo/internal/runtime.PrintComplex"({ double, double } %25) + call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10) + %26 = alloca { double, double }, align 8 + %27 = getelementptr inbounds { double, double }, ptr %26, i32 0, i32 0 + store double 0x7FF8000000000000, ptr %27, align 8 + %28 = getelementptr inbounds { double, double }, ptr %26, i32 0, i32 1 + store double 0x7FF8000000000000, ptr %28, align 8 + %29 = load { double, double }, ptr %26, align 8 + call void @"github.com/goplus/llgo/internal/runtime.PrintComplex"({ double, double } %29) + call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10) + call void @"github.com/goplus/llgo/internal/runtime.PrintBool"(i1 true) + call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 32) + call void @"github.com/goplus/llgo/internal/runtime.PrintBool"(i1 false) + call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10) + call void @"github.com/goplus/llgo/internal/runtime.PrintBool"(i1 false) + call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 32) + call void @"github.com/goplus/llgo/internal/runtime.PrintBool"(i1 true) + call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10) + %30 = alloca { float, float }, align 8 + %31 = getelementptr inbounds { float, float }, ptr %30, i32 0, i32 0 + store float 1.000000e+00, ptr %31, align 4 + %32 = getelementptr inbounds { float, float }, ptr %30, i32 0, i32 1 + store float 2.000000e+00, ptr %32, align 4 + %33 = load { float, float }, ptr %30, align 4 + %34 = extractvalue { float, float } %33, 0 + %35 = extractvalue { float, float } %33, 1 + %36 = fpext float %34 to double + %37 = fpext float %35 to double + %38 = alloca { double, double }, align 8 + %39 = getelementptr inbounds { double, double }, ptr %38, i32 0, i32 0 + store double %36, ptr %39, align 8 + %40 = getelementptr inbounds { double, double }, ptr %38, i32 0, i32 1 + store double %37, ptr %40, align 8 + %41 = load { double, double }, ptr %38, align 8 + %42 = extractvalue { double, double } %41, 0 + %43 = extractvalue { double, double } %41, 1 + %44 = fcmp oeq double %42, 1.000000e+00 + %45 = fcmp oeq double %43, 2.000000e+00 + %46 = and i1 %44, %45 + call void @"github.com/goplus/llgo/internal/runtime.PrintBool"(i1 %46) + call void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8 10) + ret i32 0 +} + +declare void @"github.com/goplus/llgo/internal/runtime.init"() + +declare void @"github.com/goplus/llgo/internal/runtime.PrintFloat"(double) + +declare void @"github.com/goplus/llgo/internal/runtime.PrintByte"(i8) + +declare void @"github.com/goplus/llgo/internal/runtime.PrintComplex"({ double, double }) + +declare void @"github.com/goplus/llgo/internal/runtime.PrintBool"(i1) diff --git a/internal/runtime/z_print.go b/internal/runtime/z_print.go index cbf9d83d..6c426edc 100644 --- a/internal/runtime/z_print.go +++ b/internal/runtime/z_print.go @@ -47,12 +47,12 @@ func PrintFloat(v float64) { } return } - c.Fprintf(c.Stderr, c.Str("%e"), v) + c.Fprintf(c.Stderr, c.Str("%+e"), v) } -// func PrintComplex(c complex128) { -// print("(", real(c), imag(c), "i)") -// } +func PrintComplex(v complex128) { + print("(", real(v), imag(v), "i)") +} func PrintUint(v uint64) { c.Fprintf(c.Stderr, c.Str("%llu"), v) diff --git a/internal/runtime/z_rt.go b/internal/runtime/z_rt.go index 6653ff3d..eb9bbac9 100644 --- a/internal/runtime/z_rt.go +++ b/internal/runtime/z_rt.go @@ -122,6 +122,10 @@ func TracePanic(v any) { } else { println(*(*float64)(e.data)) } + case abi.Complex64: + println(*(*complex64)(e.data)) + case abi.Complex128: + println(*(*complex128)(e.data)) case abi.String: println(*(*string)(e.data)) default: diff --git a/ssa/abi/abi.go b/ssa/abi/abi.go index 127e36cb..90186b57 100644 --- a/ssa/abi/abi.go +++ b/ssa/abi/abi.go @@ -89,7 +89,7 @@ func DataKindOf(raw types.Type, lvl int, is32Bits bool) (DataKind, types.Type, i return Integer, raw, lvl case kind == types.Float32: return BitCast, raw, lvl - case kind == types.Float64 || kind == types.Complex64: + case kind == types.Float64: if is32Bits { return Indirect, raw, lvl } diff --git a/ssa/expr.go b/ssa/expr.go index ab6ecbfd..62e80e17 100644 --- a/ssa/expr.go +++ b/ssa/expr.go @@ -398,6 +398,52 @@ func (b Builder) BinOp(op token.Token, x, y Expr) Expr { return Expr{b.InlineCall(b.Pkg.rtFunc("StringCat"), x, y).impl, x.Type} } case vkComplex: + xr, xi := b.impl.CreateExtractValue(x.impl, 0, ""), b.impl.CreateExtractValue(x.impl, 1, "") + yr, yi := b.impl.CreateExtractValue(y.impl, 0, ""), b.impl.CreateExtractValue(y.impl, 1, "") + switch op { + case token.ADD: + r := llvm.CreateBinOp(b.impl, llvm.FAdd, xr, yr) + i := llvm.CreateBinOp(b.impl, llvm.FAdd, xi, yi) + return b.aggregateValue(x.Type, r, i) + case token.SUB: + r := llvm.CreateBinOp(b.impl, llvm.FSub, xr, yr) + i := llvm.CreateBinOp(b.impl, llvm.FSub, xi, yi) + return b.aggregateValue(x.Type, r, i) + case token.MUL: + r := llvm.CreateBinOp(b.impl, llvm.FSub, + llvm.CreateBinOp(b.impl, llvm.FMul, xr, yr), + llvm.CreateBinOp(b.impl, llvm.FMul, xi, yi), + ) + i := llvm.CreateBinOp(b.impl, llvm.FAdd, + llvm.CreateBinOp(b.impl, llvm.FMul, xr, yi), + llvm.CreateBinOp(b.impl, llvm.FMul, xi, yr), + ) + return b.aggregateValue(x.Type, r, i) + case token.QUO: + d := llvm.CreateBinOp(b.impl, llvm.FAdd, llvm.CreateBinOp(b.impl, llvm.FMul, yr, yr), llvm.CreateBinOp(b.impl, llvm.FMul, yi, yi)) + zero := llvm.CreateFCmp(b.impl, llvm.FloatOEQ, d, llvm.ConstNull(d.Type())) + r := llvm.CreateSelect(b.impl, zero, + llvm.CreateBinOp(b.impl, llvm.FDiv, xr, d), + llvm.CreateBinOp(b.impl, llvm.FDiv, + llvm.CreateBinOp(b.impl, llvm.FAdd, + llvm.CreateBinOp(b.impl, llvm.FMul, xr, yr), + llvm.CreateBinOp(b.impl, llvm.FMul, xi, yi), + ), + d, + ), + ) + i := llvm.CreateSelect(b.impl, zero, + llvm.CreateBinOp(b.impl, llvm.FDiv, xi, d), + llvm.CreateBinOp(b.impl, llvm.FDiv, + llvm.CreateBinOp(b.impl, llvm.FSub, + llvm.CreateBinOp(b.impl, llvm.FMul, xr, yi), + llvm.CreateBinOp(b.impl, llvm.FMul, xi, yr), + ), + d, + ), + ) + return b.aggregateValue(x.Type, r, i) + } default: idx := mathOpIdx(op, kind) if llop := mathOpToLLVM[idx]; llop != 0 { @@ -453,7 +499,25 @@ func (b Builder) BinOp(op token.Token, x, y Expr) Expr { case vkBool: pred := boolPredOpToLLVM[op-predOpBase] return Expr{llvm.CreateICmp(b.impl, pred, x.impl, y.impl), tret} - case vkString, vkComplex: + case vkComplex: + switch op { + case token.EQL: + xr, xi := b.impl.CreateExtractValue(x.impl, 0, ""), b.impl.CreateExtractValue(x.impl, 1, "") + yr, yi := b.impl.CreateExtractValue(y.impl, 0, ""), b.impl.CreateExtractValue(y.impl, 1, "") + return Expr{llvm.CreateAnd(b.impl, + llvm.CreateFCmp(b.impl, llvm.FloatOEQ, xr, yr), + llvm.CreateFCmp(b.impl, llvm.FloatOEQ, xi, yi), + ), tret} + case token.NEQ: + xr, xi := b.impl.CreateExtractValue(x.impl, 0, ""), b.impl.CreateExtractValue(x.impl, 1, "") + yr, yi := b.impl.CreateExtractValue(y.impl, 0, ""), b.impl.CreateExtractValue(y.impl, 1, "") + return Expr{b.impl.CreateOr( + llvm.CreateFCmp(b.impl, llvm.FloatUNE, xr, yr), + llvm.CreateFCmp(b.impl, llvm.FloatUNE, xi, yi), + "", + ), tret} + } + case vkString: switch op { case token.EQL: return b.InlineCall(b.Pkg.rtFunc("StringEqual"), x, y) @@ -567,6 +631,10 @@ func (b Builder) UnOp(op token.Token, x Expr) (ret Expr) { ret.impl = llvm.CreateNeg(b.impl, x.impl) } else if t.Info()&types.IsFloat != 0 { ret.impl = llvm.CreateFNeg(b.impl, x.impl) + } else if t.Info()&types.IsComplex != 0 { + r := b.impl.CreateExtractValue(x.impl, 0, "") + i := b.impl.CreateExtractValue(x.impl, 1, "") + return b.aggregateValue(x.Type, llvm.CreateFNeg(b.impl, r), llvm.CreateFNeg(b.impl, i)) } else { panic("todo") } @@ -685,6 +753,18 @@ func (b Builder) Convert(t Type, x Expr) (ret Expr) { ret.impl = b.InlineCall(b.Func.Pkg.rtFunc("StringFromRune"), x).impl return } + case types.Complex128: + switch xtyp := x.RawType().Underlying().(type) { + case *types.Basic: + if xtyp.Kind() == types.Complex64 { + r := b.impl.CreateExtractValue(x.impl, 0, "") + i := b.impl.CreateExtractValue(x.impl, 1, "") + r = castFloat(b, r, b.Prog.Float64()) + i = castFloat(b, i, b.Prog.Float64()) + ret.impl = b.aggregateValue(t, r, i).impl + return + } + } } switch xtyp := x.RawType().Underlying().(type) { case *types.Basic: @@ -716,6 +796,16 @@ func (b Builder) Convert(t Type, x Expr) (ret Expr) { } } } + if x.kind == vkComplex && t.kind == vkComplex { + ft := b.Prog.Float64() + if t.raw.Type.Underlying().(*types.Basic).Kind() == types.Complex64 { + ft = b.Prog.Float32() + } + r := b.impl.CreateExtractValue(x.impl, 0, "") + i := b.impl.CreateExtractValue(x.impl, 1, "") + ret.impl = b.Complex(Expr{castFloat(b, r, ft), ft}, Expr{castFloat(b, i, ft), ft}).impl + return + } case *types.Pointer: ret.impl = castPtr(b.impl, x.impl, t.ll) return @@ -1015,8 +1105,9 @@ func (b Builder) PrintEx(ln bool, args ...Expr) (ret Expr) { fn = "PrintEface" case vkIface: fn = "PrintIface" - // case vkComplex: - // fn = "PrintComplex" + case vkComplex: + fn = "PrintComplex" + typ = prog.Complex128() default: panic(fmt.Errorf("illegal types for operand: print %v", arg.RawType())) } diff --git a/ssa/type.go b/ssa/type.go index cc192ad5..16a98dd6 100644 --- a/ssa/type.go +++ b/ssa/type.go @@ -211,7 +211,7 @@ func (p Program) Index(typ Type) Type { func (p Program) Field(typ Type, i int) Type { var fld *types.Var - switch t := typ.raw.Type.(type) { + switch t := typ.raw.Type.Underlying().(type) { case *types.Tuple: fld = t.At(i) case *types.Basic: @@ -223,7 +223,7 @@ func (p Program) Field(typ Type, i int) Type { } panic("Field: basic type doesn't have fields") default: - fld = t.Underlying().(*types.Struct).Field(i) + fld = t.(*types.Struct).Field(i) } return p.rawType(fld.Type()) }