From 3605eeeef79f80db5e40623b920d0ea0318a2bcb Mon Sep 17 00:00:00 2001 From: Li Jie Date: Thu, 11 Sep 2025 14:07:58 +0800 Subject: [PATCH] export c header file for build library --- cl/import.go | 17 +- internal/build/build.go | 49 +- internal/header/header.go | 642 ++++++++++++++++++++++++++ internal/header/header_test.go | 791 +++++++++++++++++++++++++++++++++ ssa/package.go | 11 + 5 files changed, 1465 insertions(+), 45 deletions(-) create mode 100644 internal/header/header.go create mode 100644 internal/header/header_test.go diff --git a/cl/import.go b/cl/import.go index 16aee17a..ea472819 100644 --- a/cl/import.go +++ b/cl/import.go @@ -183,7 +183,9 @@ func (p *context) initFiles(pkgPath string, files []*ast.File, cPkg bool) { if !p.initLinknameByDoc(decl.Doc, fullName, inPkgName, false) && cPkg { // package C (https://github.com/goplus/llgo/issues/1165) if decl.Recv == nil && token.IsExported(inPkgName) { - p.prog.SetLinkname(fullName, strings.TrimPrefix(inPkgName, "X")) + exportName := strings.TrimPrefix(inPkgName, "X") + p.prog.SetLinkname(fullName, exportName) + p.pkg.SetExport(fullName, exportName) } } case *ast.GenDecl: @@ -301,19 +303,19 @@ func (p *context) initLinkname(line string, f func(inPkgName string) (fullName s directive = "//go:" ) if strings.HasPrefix(line, linkname) { - p.initLink(line, len(linkname), f) + p.initLink(line, len(linkname), false, f) return hasLinkname } else if strings.HasPrefix(line, llgolink2) { - p.initLink(line, len(llgolink2), f) + p.initLink(line, len(llgolink2), false, f) return hasLinkname } else if strings.HasPrefix(line, llgolink) { - p.initLink(line, len(llgolink), f) + p.initLink(line, len(llgolink), false, f) return hasLinkname } else if strings.HasPrefix(line, export) { // rewrite //export FuncName to //export FuncName FuncName funcName := strings.TrimSpace(line[len(export):]) line = line + " " + funcName - p.initLink(line, len(export), f) + p.initLink(line, len(export), true, f) return hasLinkname } else if strings.HasPrefix(line, directive) { // skip unknown annotation but continue to parse the next annotation @@ -322,13 +324,16 @@ func (p *context) initLinkname(line string, f func(inPkgName string) (fullName s return noDirective } -func (p *context) initLink(line string, prefix int, f func(inPkgName string) (fullName string, isVar, ok bool)) { +func (p *context) initLink(line string, prefix int, export bool, f func(inPkgName string) (fullName string, isVar, ok bool)) { text := strings.TrimSpace(line[prefix:]) if idx := strings.IndexByte(text, ' '); idx > 0 { inPkgName := text[:idx] if fullName, _, ok := f(inPkgName); ok { link := strings.TrimLeft(text[idx+1:], " ") p.prog.SetLinkname(fullName, link) + if export { + p.pkg.SetExport(fullName, link) + } } else { fmt.Fprintln(os.Stderr, "==>", line) fmt.Fprintf(os.Stderr, "llgo: linkname %s not found and ignored\n", inPkgName) diff --git a/internal/build/build.go b/internal/build/build.go index b20af7f6..8676838c 100644 --- a/internal/build/build.go +++ b/internal/build/build.go @@ -43,6 +43,7 @@ import ( "github.com/goplus/llgo/internal/env" "github.com/goplus/llgo/internal/firmware" "github.com/goplus/llgo/internal/flash" + "github.com/goplus/llgo/internal/header" "github.com/goplus/llgo/internal/mockable" "github.com/goplus/llgo/internal/monitor" "github.com/goplus/llgo/internal/packages" @@ -375,7 +376,15 @@ func Do(args []string, conf *Config) ([]Package, error) { // Generate C headers for c-archive and c-shared modes before linking if ctx.buildConf.BuildMode == BuildModeCArchive || ctx.buildConf.BuildMode == BuildModeCShared { - headerErr := generateCHeader(ctx, pkg, outFmts.Out, verbose) + libname := strings.TrimSuffix(filepath.Base(outFmts.Out), conf.AppExt) + headerPath := filepath.Join(filepath.Dir(outFmts.Out), libname) + ".h" + pkgs := make([]llssa.Package, 0, len(allPkgs)) + for _, p := range allPkgs { + if p.LPkg != nil { + pkgs = append(pkgs, p.LPkg) + } + } + headerErr := header.GenHeaderFile(prog, pkgs, libname, headerPath, verbose) if headerErr != nil { return nil, headerErr } @@ -818,44 +827,6 @@ func linkMainPkg(ctx *context, pkg *packages.Package, pkgs []*aPackage, global l return nil } -// TODO(lijie): export C header from function list of the pkg -func generateCHeader(ctx *context, pkg *packages.Package, outputPath string, verbose bool) error { - // Determine header file path - headerPath := strings.TrimSuffix(outputPath, filepath.Ext(outputPath)) + ".h" - - // Generate header content - headerContent := fmt.Sprintf(`/* Code generated by llgo; DO NOT EDIT. */ - -#ifndef __%s_H_ -#define __%s_H_ - -#ifdef __cplusplus -extern "C" { -#endif - -#ifdef __cplusplus -} -#endif - -#endif /* __%s_H_ */ -`, - strings.ToUpper(strings.ReplaceAll(pkg.Name, "-", "_")), - strings.ToUpper(strings.ReplaceAll(pkg.Name, "-", "_")), - strings.ToUpper(strings.ReplaceAll(pkg.Name, "-", "_"))) - - // Write header file - err := os.WriteFile(headerPath, []byte(headerContent), 0644) - if err != nil { - return fmt.Errorf("failed to write header file %s: %w", headerPath, err) - } - - if verbose { - fmt.Fprintf(os.Stderr, "Generated C header: %s\n", headerPath) - } - - return nil -} - func linkObjFiles(ctx *context, app string, objFiles, linkArgs []string, verbose bool) error { // Handle c-archive mode differently - use ar tool instead of linker if ctx.buildConf.BuildMode == BuildModeCArchive { diff --git a/internal/header/header.go b/internal/header/header.go new file mode 100644 index 00000000..a6c17370 --- /dev/null +++ b/internal/header/header.go @@ -0,0 +1,642 @@ +/* + * Copyright (c) 2024 The GoPlus Authors (goplus.org). All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package header + +import ( + "bytes" + "fmt" + "go/types" + "io" + "os" + "sort" + "strings" + + "github.com/goplus/llgo/ssa" +) + +// cheaderWriter handles C header generation with type definition management +type cheaderWriter struct { + p ssa.Program + typeBuf *bytes.Buffer // buffer for type definitions + funcBuf *bytes.Buffer // buffer for function declarations + declaredTypes map[string]bool // track declared types to avoid duplicates +} + +// newCHeaderWriter creates a new C header writer +func newCHeaderWriter(p ssa.Program) *cheaderWriter { + return &cheaderWriter{ + p: p, + typeBuf: &bytes.Buffer{}, + funcBuf: &bytes.Buffer{}, + declaredTypes: make(map[string]bool), + } +} + +// writeTypedef writes a C typedef for the given Go type if not already declared +func (hw *cheaderWriter) writeTypedef(t types.Type) error { + return hw.writeTypedefRecursive(t, make(map[string]bool)) +} + +// writeTypedefRecursive writes typedefs recursively, handling dependencies +func (hw *cheaderWriter) writeTypedefRecursive(t types.Type, visiting map[string]bool) error { + // Handle container types that only need element processing + switch typ := t.(type) { + case *types.Array: + return hw.writeTypedefRecursive(typ.Elem(), visiting) + case *types.Slice: + return hw.writeTypedefRecursive(typ.Elem(), visiting) + case *types.Map: + if err := hw.writeTypedefRecursive(typ.Key(), visiting); err != nil { + return err + } + return hw.writeTypedefRecursive(typ.Elem(), visiting) + case *types.Chan: + return hw.writeTypedefRecursive(typ.Elem(), visiting) + } + + cType := hw.goCTypeName(t) + if cType == "" || hw.declaredTypes[cType] { + return nil + } + + // Prevent infinite recursion for self-referential types + if visiting[cType] { + return nil + } + visiting[cType] = true + defer delete(visiting, cType) + + // Process dependent types for complex types + if err := hw.processDependentTypes(t, visiting); err != nil { + return err + } + + // Then write the typedef for this type + typedef := hw.generateTypedef(t) + if typedef != "" { + fmt.Fprintln(hw.typeBuf, typedef) + // Add empty line after each type definition + fmt.Fprintln(hw.typeBuf) + hw.declaredTypes[cType] = true + } + return nil +} + +// processDependentTypes processes dependent types for composite types +func (hw *cheaderWriter) processDependentTypes(t types.Type, visiting map[string]bool) error { + switch typ := t.(type) { + case *types.Pointer: + return hw.writeTypedefRecursive(typ.Elem(), visiting) + case *types.Struct: + // For anonymous structs, handle field dependencies + for i := 0; i < typ.NumFields(); i++ { + field := typ.Field(i) + if err := hw.writeTypedefRecursive(field.Type(), visiting); err != nil { + return err + } + } + case *types.Named: + // For named types, handle the underlying type dependencies + underlying := typ.Underlying() + if structType, ok := underlying.(*types.Struct); ok { + // For named struct types, handle field dependencies directly + for i := 0; i < structType.NumFields(); i++ { + field := structType.Field(i) + if err := hw.writeTypedefRecursive(field.Type(), visiting); err != nil { + return err + } + } + } else { + // For other named types, handle the underlying type + return hw.writeTypedefRecursive(underlying, visiting) + } + case *types.Signature: + return hw.processSignatureTypes(typ, visiting) + } + return nil +} + +// processSignatureTypes processes function signature parameter and result types +func (hw *cheaderWriter) processSignatureTypes(sig *types.Signature, visiting map[string]bool) error { + // Handle function parameters + if sig.Params() != nil { + for i := 0; i < sig.Params().Len(); i++ { + param := sig.Params().At(i) + if err := hw.writeTypedefRecursive(param.Type(), visiting); err != nil { + return err + } + } + } + // Handle function results + if sig.Results() != nil { + for i := 0; i < sig.Results().Len(); i++ { + result := sig.Results().At(i) + if err := hw.writeTypedefRecursive(result.Type(), visiting); err != nil { + return err + } + } + } + return nil + return nil +} + +// goCTypeName returns the C type name for a Go type +func (hw *cheaderWriter) goCTypeName(t types.Type) string { + switch typ := t.(type) { + case *types.Basic: + switch typ.Kind() { + case types.Invalid: + return "" + case types.Bool: + return "_Bool" + case types.Int8: + return "int8_t" + case types.Uint8: + return "uint8_t" + case types.Int16: + return "int16_t" + case types.Uint16: + return "uint16_t" + case types.Int32: + return "int32_t" + case types.Uint32: + return "uint32_t" + case types.Int64: + return "int64_t" + case types.Uint64: + return "uint64_t" + case types.Int: + return "intptr_t" + case types.Uint: + return "uintptr_t" + case types.Uintptr: + return "uintptr_t" + case types.Float32: + return "float" + case types.Float64: + return "double" + case types.Complex64: + return "GoComplex64" + case types.Complex128: + return "GoComplex128" + case types.String: + return "GoString" + case types.UnsafePointer: + return "void*" + } + case *types.Pointer: + elemType := hw.goCTypeName(typ.Elem()) + if elemType == "" { + return "void*" + } + return elemType + "*" + case *types.Slice: + return "GoSlice" + case *types.Array: + // For arrays, we return just the element type + // The array size will be handled in field generation + return hw.goCTypeName(typ.Elem()) + case *types.Map: + return "GoMap" + case *types.Chan: + return "GoChan" + case *types.Interface: + return "GoInterface" + case *types.Struct: + // For anonymous structs, generate a descriptive name + var fields []string + for i := 0; i < typ.NumFields(); i++ { + field := typ.Field(i) + fieldType := hw.goCTypeName(field.Type()) + fields = append(fields, fmt.Sprintf("%s_%s", fieldType, field.Name())) + } + return fmt.Sprintf("struct_%s", strings.Join(fields, "_")) + case *types.Named: + // For named types, always use the named type + pkg := typ.Obj().Pkg() + return fmt.Sprintf("%s_%s", pkg.Name(), typ.Obj().Name()) + case *types.Signature: + // Function types are represented as function pointers in C + // For simplicity, we use void* to represent function pointers + return "void*" + } + panic(fmt.Errorf("unsupported type: %v", t)) +} + +// generateTypedef generates C typedef declaration for complex types +func (hw *cheaderWriter) generateTypedef(t types.Type) string { + switch typ := t.(type) { + case *types.Struct: + // Only generate typedef for anonymous structs + return hw.generateStructTypedef(typ) + case *types.Named: + underlying := typ.Underlying() + if structType, ok := underlying.(*types.Struct); ok { + // For named struct types, generate the typedef directly + return hw.generateNamedStructTypedef(typ, structType) + } + // For other named types, create a typedef to the underlying type + underlyingCType := hw.goCTypeName(underlying) + if underlyingCType != "" { + cTypeName := hw.goCTypeName(typ) + return fmt.Sprintf("typedef %s %s;", underlyingCType, cTypeName) + } + } + return "" +} + +// generateReturnType generates C return type, converting arrays to struct wrappers +func (hw *cheaderWriter) generateReturnType(retType types.Type) string { + switch typ := retType.(type) { + case *types.Array: + // For array return values, generate a struct wrapper + return hw.ensureArrayStruct(typ) + default: + // For non-array types, use regular type conversion + return hw.goCTypeName(retType) + } +} + +// ensureArrayStruct generates array struct name and ensures its typedef is declared +func (hw *cheaderWriter) ensureArrayStruct(arr *types.Array) string { + // Generate struct name + var dimensions []int64 + baseType := types.Type(arr) + + // Traverse all array dimensions + for { + if a, ok := baseType.(*types.Array); ok { + dimensions = append(dimensions, a.Len()) + baseType = a.Elem() + } else { + break + } + } + + // Get base element type + elemType := hw.goCTypeName(baseType) + + // Generate struct name: Array_int32_t_4 for [4]int32, Array_int32_t_3_4 for [3][4]int32 + var name strings.Builder + name.WriteString("Array_") + name.WriteString(strings.ReplaceAll(elemType, "*", "_ptr")) + for _, dim := range dimensions { + name.WriteString(fmt.Sprintf("_%d", dim)) + } + + structName := name.String() + + // Ensure typedef is declared + if !hw.declaredTypes[structName] { + hw.declaredTypes[structName] = true + // Generate field declaration for the array + fieldDecl := hw.generateFieldDeclaration(arr, "data") + // Write the typedef + typedef := fmt.Sprintf("typedef struct {\n%s\n} %s;", fieldDecl, structName) + fmt.Fprintf(hw.typeBuf, "%s\n\n", typedef) + } + + return structName +} + +// generateFieldDeclaration generates C field declaration with correct array syntax +func (hw *cheaderWriter) generateFieldDeclaration(fieldType types.Type, fieldName string) string { + switch fieldType.(type) { + case *types.Array: + // Handle multidimensional arrays by collecting all dimensions + var dimensions []int64 + baseType := fieldType + + // Traverse all array dimensions + for { + if arr, ok := baseType.(*types.Array); ok { + dimensions = append(dimensions, arr.Len()) + baseType = arr.Elem() + } else { + break + } + } + + // Get base element type + elemType := hw.goCTypeName(baseType) + + // Build array dimensions string [d1][d2][d3]... + var dimStr strings.Builder + for _, dim := range dimensions { + dimStr.WriteString(fmt.Sprintf("[%d]", dim)) + } + + return fmt.Sprintf(" %s %s%s;", elemType, fieldName, dimStr.String()) + default: + cType := hw.goCTypeName(fieldType) + return fmt.Sprintf(" %s %s;", cType, fieldName) + } +} + +// generateStructTypedef generates typedef for anonymous struct +func (hw *cheaderWriter) generateStructTypedef(s *types.Struct) string { + // Generate descriptive type name inline + var nameFields []string + var declFields []string + + for i := 0; i < s.NumFields(); i++ { + field := s.Field(i) + fieldType := hw.goCTypeName(field.Type()) + nameFields = append(nameFields, fmt.Sprintf("%s_%s", fieldType, field.Name())) + declFields = append(declFields, hw.generateFieldDeclaration(field.Type(), field.Name())) + } + + typeName := fmt.Sprintf("struct_%s", strings.Join(nameFields, "_")) + return fmt.Sprintf("typedef struct {\n%s\n} %s;", strings.Join(declFields, "\n"), typeName) +} + +// generateNamedStructTypedef generates typedef for named struct +func (hw *cheaderWriter) generateNamedStructTypedef(named *types.Named, s *types.Struct) string { + typeName := hw.goCTypeName(named) + + // Check if this is a self-referential struct + needsForwardDecl := hw.needsForwardDeclaration(s, typeName) + var result string + + if needsForwardDecl { + // Add forward declaration + result = fmt.Sprintf("typedef struct %s %s;\n", typeName, typeName) + } + + var fields []string + for i := 0; i < s.NumFields(); i++ { + field := s.Field(i) + fields = append(fields, hw.generateFieldDeclaration(field.Type(), field.Name())) + } + + if needsForwardDecl { + // Use struct tag in definition + result += fmt.Sprintf("struct %s {\n%s\n};", typeName, strings.Join(fields, "\n")) + } else { + result = fmt.Sprintf("typedef struct {\n%s\n} %s;", strings.Join(fields, "\n"), typeName) + } + + return result +} + +// needsForwardDeclaration checks if a struct needs forward declaration due to self-reference +func (hw *cheaderWriter) needsForwardDeclaration(s *types.Struct, typeName string) bool { + for i := 0; i < s.NumFields(); i++ { + field := s.Field(i) + if hw.typeReferencesSelf(field.Type(), typeName) { + return true + } + } + return false +} + +// typeReferencesSelf checks if a type references the given type name +func (hw *cheaderWriter) typeReferencesSelf(t types.Type, selfTypeName string) bool { + switch typ := t.(type) { + case *types.Pointer: + elemTypeName := hw.goCTypeName(typ.Elem()) + return elemTypeName == selfTypeName + case *types.Slice: + elemTypeName := hw.goCTypeName(typ.Elem()) + return elemTypeName == selfTypeName + case *types.Array: + elemTypeName := hw.goCTypeName(typ.Elem()) + return elemTypeName == selfTypeName + case *types.Named: + return hw.goCTypeName(typ) == selfTypeName + } + return false +} + +// writeFunctionDecl writes C function declaration for exported Go function +// fullName: the C function name to display in header +// linkName: the actual Go function name for linking +func (hw *cheaderWriter) writeFunctionDecl(fullName, linkName string, fn ssa.Function) error { + if fn.IsNil() { + return nil + } + + // Get Go signature from LLVM function type + goType := fn.Type.RawType() + sig, ok := goType.(*types.Signature) + if !ok { + return fmt.Errorf("function %s does not have signature type", fullName) + } + + // Generate return type + var returnType string + if sig.Results().Len() == 0 { + returnType = "void" + } else if sig.Results().Len() == 1 { + retType := sig.Results().At(0).Type() + if err := hw.writeTypedef(retType); err != nil { + return err + } + returnType = hw.generateReturnType(retType) + } else { + return fmt.Errorf("function %s has more than one result", fullName) + } + + // Generate parameters + var params []string + for i := 0; i < sig.Params().Len(); i++ { + param := sig.Params().At(i) + paramType := param.Type() + + if err := hw.writeTypedef(paramType); err != nil { + return err + } + + paramName := param.Name() + if paramName == "" { + paramName = fmt.Sprintf("param%d", i) + } + + // Use generateFieldDeclaration logic for consistent parameter syntax + paramDecl := hw.generateFieldDeclaration(paramType, paramName) + // Remove the leading spaces and semicolon to get just the declaration + paramDecl = strings.TrimSpace(paramDecl) + paramDecl = strings.TrimSuffix(paramDecl, ";") + params = append(params, paramDecl) + } + + paramStr := strings.Join(params, ", ") + if paramStr == "" { + paramStr = "void" + } + // Write function declaration with return type on separate line for normal functions + fmt.Fprintln(hw.funcBuf, returnType) + // Generate function declaration using cross-platform macro when names differ + var funcDecl string + if fullName != linkName { + funcDecl = fmt.Sprintf("%s(%s) GO_SYMBOL_RENAME(\"%s\")", fullName, paramStr, linkName) + } else { + funcDecl = fmt.Sprintf("%s(%s);", fullName, paramStr) + } + + fmt.Fprintln(hw.funcBuf, funcDecl) + // Add empty line after each function declaration + fmt.Fprintln(hw.funcBuf) + + return nil +} + +// writeCommonIncludes writes common C header includes and Go runtime type definitions +func (hw *cheaderWriter) writeCommonIncludes() error { + includes := ` +// Platform-specific symbol renaming macro +#ifdef __APPLE__ + #define GO_SYMBOL_RENAME(go_name) __asm("_" go_name); +#else + #define GO_SYMBOL_RENAME(go_name) __asm(go_name); +#endif + +// Go runtime types +typedef struct { const char *p; intptr_t n; } GoString; +typedef struct { void *data; intptr_t len; intptr_t cap; } GoSlice; +typedef struct { void *data; } GoMap; +typedef struct { void *data; } GoChan; +typedef struct { void *data; void *type; } GoInterface; +typedef struct { float real; float imag; } GoComplex64; +typedef struct { double real; double imag; } GoComplex128; + +` + + if _, err := hw.typeBuf.WriteString(includes); err != nil { + return err + } + return nil +} + +// writeTo writes all generated content to the output writer +func (hw *cheaderWriter) writeTo(w io.Writer) error { + // Write type definitions first + if hw.typeBuf.Len() > 0 { + if _, err := hw.typeBuf.WriteTo(w); err != nil { + return err + } + } + + // Then write function declarations + if hw.funcBuf.Len() > 0 { + if _, err := hw.funcBuf.WriteTo(w); err != nil { + return err + } + } + + return nil +} + +func genHeader(p ssa.Program, pkgs []ssa.Package, w io.Writer) error { + hw := newCHeaderWriter(p) + + // Write common header includes and type definitions + if err := hw.writeCommonIncludes(); err != nil { + return err + } + + // Mark predefined Go types as declared + hw.declaredTypes["GoString"] = true + hw.declaredTypes["GoSlice"] = true + hw.declaredTypes["GoMap"] = true + hw.declaredTypes["GoChan"] = true + hw.declaredTypes["GoInterface"] = true + hw.declaredTypes["GoComplex64"] = true + hw.declaredTypes["GoComplex128"] = true + + // Process all exported functions + for _, pkg := range pkgs { + exports := pkg.ExportFuncs() + // Sort functions for testing + exportNames := make([]string, 0, len(exports)) + for name := range exports { + exportNames = append(exportNames, name) + } + sort.Strings(exportNames) + + for _, name := range exportNames { // name is goName + link := exports[name] // link is cName + fn := pkg.FuncOf(link) + if fn == nil { + continue + } + + // Write function declaration with proper C types + if err := hw.writeFunctionDecl(link, link, fn); err != nil { + return fmt.Errorf("failed to write declaration for function %s: %w", name, err) + } + } + + initFnName := pkg.Path() + ".init" + initFn := pkg.FuncOf(initFnName) + if initFn != nil { + // Generate C-compatible function name (replace . and / with _) + cInitFnName := strings.ReplaceAll(strings.ReplaceAll(initFnName, ".", "_"), "/", "_") + if err := hw.writeFunctionDecl(cInitFnName, initFnName, initFn); err != nil { + return fmt.Errorf("failed to write declaration for function %s: %w", initFnName, err) + } + } + } + + // Write all content to output in the correct order + return hw.writeTo(w) +} + +func GenHeaderFile(p ssa.Program, pkgs []ssa.Package, libName, headerPath string, verbose bool) error { + // Write header file + w, err := os.Create(headerPath) + if err != nil { + return fmt.Errorf("failed to write header file %s: %w", headerPath, err) + } + defer w.Close() + + if verbose { + fmt.Fprintf(os.Stderr, "Generated C header: %s\n", headerPath) + } + + headerIdent := strings.ToUpper(strings.ReplaceAll(libName, "-", "_")) + headerContent := fmt.Sprintf(`/* Code generated by llgo; DO NOT EDIT. */ + +#ifndef __%s_H_ +#define __%s_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif +`, headerIdent, headerIdent) + + w.Write([]byte(headerContent)) + + if err = genHeader(p, pkgs, w); err != nil { + return fmt.Errorf("failed to generate header content for %s: %w", libName, err) + } + + footerContent := fmt.Sprintf(` + +#ifdef __cplusplus +} +#endif + +#endif /* __%s_H_ */ +`, headerIdent) + + _, err = w.Write([]byte(footerContent)) + return err +} diff --git a/internal/header/header_test.go b/internal/header/header_test.go new file mode 100644 index 00000000..b22a74d9 --- /dev/null +++ b/internal/header/header_test.go @@ -0,0 +1,791 @@ +//go:build !llgo +// +build !llgo + +package header + +import ( + "bytes" + "go/token" + "go/types" + "os" + "strings" + "testing" + + "github.com/goplus/gogen/packages" + "github.com/goplus/llgo/ssa" + "github.com/goplus/llvm" +) + +func init() { + llvm.InitializeAllTargets() + llvm.InitializeAllTargetMCs() + llvm.InitializeAllTargetInfos() + llvm.InitializeAllAsmParsers() + llvm.InitializeAllAsmPrinters() +} + +func TestGenCHeaderExport(t *testing.T) { + prog := ssa.NewProgram(nil) + prog.SetRuntime(func() *types.Package { + fset := token.NewFileSet() + imp := packages.NewImporter(fset) + pkg, _ := imp.Import(ssa.PkgRuntime) + return pkg + }) + + // Define main package and the 'Foo' type within it + mainPkgPath := "github.com/goplus/llgo/test_buildmode/main" + mainTypesPkg := types.NewPackage(mainPkgPath, "main") + fooFields := []*types.Var{ + types.NewField(token.NoPos, mainTypesPkg, "a", types.Typ[types.Int], false), + types.NewField(token.NoPos, mainTypesPkg, "b", types.Typ[types.Float64], false), + } + fooStruct := types.NewStruct(fooFields, nil) + fooTypeName := types.NewTypeName(token.NoPos, mainTypesPkg, "Foo", nil) + fooNamed := types.NewNamed(fooTypeName, fooStruct, nil) + mainTypesPkg.Scope().Insert(fooTypeName) + + // Create SSA package for main + mainPkg := prog.NewPackage("main", mainPkgPath) + + // Define exported functions in mainPkg + mainPkg.NewFunc("HelloWorld", types.NewSignatureType(nil, nil, nil, nil, nil, false), ssa.InGo) + useFooPtrParams := types.NewTuple(types.NewVar(token.NoPos, nil, "f", types.NewPointer(fooNamed))) + useFooPtrResults := types.NewTuple(types.NewVar(token.NoPos, nil, "", fooNamed)) + useFooPtrSig := types.NewSignatureType(nil, nil, nil, useFooPtrParams, useFooPtrResults, false) + mainPkg.NewFunc("UseFooPtr", useFooPtrSig, ssa.InGo) + useFooParams := types.NewTuple(types.NewVar(token.NoPos, nil, "f", fooNamed)) + useFooSig := types.NewSignatureType(nil, nil, nil, useFooParams, useFooPtrResults, false) + mainPkg.NewFunc("UseFoo", useFooSig, ssa.InGo) + + // Set exports for main + mainPkg.SetExport("HelloWorld", "HelloWorld") + mainPkg.SetExport("UseFooPtr", "UseFooPtr") + mainPkg.SetExport("UseFoo", "UseFoo") + + // Create package C + cPkgPath := "github.com/goplus/llgo/test_buildmode/bar" + cPkg := prog.NewPackage("C", cPkgPath) + addParams := types.NewTuple( + types.NewVar(token.NoPos, nil, "a", types.Typ[types.Int]), + types.NewVar(token.NoPos, nil, "b", types.Typ[types.Int])) + addResults := types.NewTuple(types.NewVar(token.NoPos, nil, "", types.Typ[types.Int])) + addSig := types.NewSignatureType(nil, nil, nil, addParams, addResults, false) + cPkg.NewFunc("Add", addSig, ssa.InGo) + cPkg.NewFunc("Sub", addSig, ssa.InGo) + cPkg.SetExport("XAdd", "Add") + cPkg.SetExport("XSub", "Sub") + + // Generate header + libname := "testbuild" + headerPath := os.TempDir() + "/testbuild.h" + err := GenHeaderFile(prog, []ssa.Package{mainPkg, cPkg}, libname, headerPath, true) + if err != nil { + t.Fatal(err) + } + data, err := os.ReadFile(headerPath) + if err != nil { + t.Fatal(err) + } + + required := []string{ + "/* Code generated by llgo; DO NOT EDIT. */", + "#ifndef __TESTBUILD_H_", + "#include ", + "typedef struct { const char *p; intptr_t n; } GoString;", + "typedef struct {\n intptr_t a;\n double b;\n} main_Foo;", + "void\nHelloWorld(void);", + "main_Foo\nUseFooPtr(main_Foo* f);", + "main_Foo\nUseFoo(main_Foo f);", + "intptr_t\nAdd(intptr_t a, intptr_t b);", + "intptr_t\nSub(intptr_t a, intptr_t b);", + "#endif /* __TESTBUILD_H_ */", + } + + got := string(data) + + for _, sub := range required { + if !strings.Contains(got, sub) { + t.Fatalf("Generated content: %s\n", got) + t.Fatalf("Generated header is missing expected content:\n%s", sub) + } + } +} + +func TestCheaderWriterTypes(t *testing.T) { + prog := ssa.NewProgram(nil) + hw := newCHeaderWriter(prog) + + // Test complex integration scenarios only - basic types are covered by TestGoCTypeName + tests := []struct { + name string + goType types.Type + expected string + }{ + { + name: "named struct", + goType: func() types.Type { + pkg := types.NewPackage("main", "main") + s := types.NewStruct([]*types.Var{types.NewField(0, nil, "f1", types.Typ[types.Int], false)}, nil) + return types.NewNamed(types.NewTypeName(0, pkg, "MyStruct", nil), s, nil) + }(), + expected: "typedef struct {\n intptr_t f1;\n} main_MyStruct;\n\n", + }, + { + name: "struct with array field", + goType: func() types.Type { + arrayType := types.NewArray(types.Typ[types.Float64], 10) + return types.NewStruct([]*types.Var{ + types.NewField(0, nil, "Values", arrayType, false), + }, nil) + }(), + expected: "typedef struct {\n double Values[10];\n} struct_double_Values;\n\n", + }, + { + name: "struct with multidimensional array", + goType: func() types.Type { + // Create a 2D array: [4][3]int32 + innerArrayType := types.NewArray(types.Typ[types.Int32], 3) + outerArrayType := types.NewArray(innerArrayType, 4) + return types.NewStruct([]*types.Var{ + types.NewField(0, nil, "Matrix", outerArrayType, false), + }, nil) + }(), + expected: "typedef struct {\n int32_t Matrix[4][3];\n} struct_int32_t_Matrix;\n\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hw.typeBuf.Reset() + hw.declaredTypes = make(map[string]bool) // Reset declared types for each run + + // Mark predefined Go types as declared (same as in genHeader) + hw.declaredTypes["GoString"] = true + hw.declaredTypes["GoSlice"] = true + hw.declaredTypes["GoMap"] = true + hw.declaredTypes["GoChan"] = true + hw.declaredTypes["GoInterface"] = true + hw.declaredTypes["GoComplex64"] = true + hw.declaredTypes["GoComplex128"] = true + + if err := hw.writeTypedef(tt.goType); err != nil { + t.Fatalf("writeTypedef() error = %v", err) + } + got := hw.typeBuf.String() + if got != tt.expected { + t.Errorf("writeTypedef() got = %q, want %q", got, tt.expected) + } + }) + } +} + +// Test for goCTypeName function to cover all basic types +func TestGoCTypeName(t *testing.T) { + prog := ssa.NewProgram(nil) + hw := newCHeaderWriter(prog) + + tests := []struct { + name string + goType types.Type + expected string + }{ + {name: "bool", goType: types.Typ[types.Bool], expected: "_Bool"}, + {name: "int8", goType: types.Typ[types.Int8], expected: "int8_t"}, + {name: "uint8", goType: types.Typ[types.Uint8], expected: "uint8_t"}, + {name: "int16", goType: types.Typ[types.Int16], expected: "int16_t"}, + {name: "uint16", goType: types.Typ[types.Uint16], expected: "uint16_t"}, + {name: "int32", goType: types.Typ[types.Int32], expected: "int32_t"}, + {name: "uint32", goType: types.Typ[types.Uint32], expected: "uint32_t"}, + {name: "int64", goType: types.Typ[types.Int64], expected: "int64_t"}, + {name: "uint64", goType: types.Typ[types.Uint64], expected: "uint64_t"}, + {name: "int", goType: types.Typ[types.Int], expected: "intptr_t"}, + {name: "uint", goType: types.Typ[types.Uint], expected: "uintptr_t"}, + {name: "uintptr", goType: types.Typ[types.Uintptr], expected: "uintptr_t"}, + {name: "float32", goType: types.Typ[types.Float32], expected: "float"}, + {name: "float64", goType: types.Typ[types.Float64], expected: "double"}, + {name: "complex64", goType: types.Typ[types.Complex64], expected: "GoComplex64"}, + {name: "complex128", goType: types.Typ[types.Complex128], expected: "GoComplex128"}, + {name: "string", goType: types.Typ[types.String], expected: "GoString"}, + {name: "unsafe pointer", goType: types.Typ[types.UnsafePointer], expected: "void*"}, + {name: "slice", goType: types.NewSlice(types.Typ[types.Int]), expected: "GoSlice"}, + {name: "map", goType: types.NewMap(types.Typ[types.String], types.Typ[types.Int]), expected: "GoMap"}, + {name: "chan", goType: types.NewChan(types.SendRecv, types.Typ[types.Int]), expected: "GoChan"}, + {name: "interface", goType: types.NewInterfaceType(nil, nil), expected: "GoInterface"}, + { + name: "array", + goType: types.NewArray(types.Typ[types.Int], 5), + expected: "intptr_t", + }, + { + name: "pointer to int", + goType: types.NewPointer(types.Typ[types.Int]), + expected: "intptr_t*", + }, + { + name: "pointer to unknown type", + goType: types.NewPointer(types.Typ[types.Invalid]), + expected: "void*", + }, + { + name: "array of unknown type", + goType: types.NewArray(types.Typ[types.Invalid], 3), + expected: "", + }, + { + name: "signature type", + goType: types.NewSignature(nil, nil, nil, false), + expected: "void*", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := hw.goCTypeName(tt.goType) + if got != tt.expected { + t.Errorf("goCTypeName() = %q, want %q", got, tt.expected) + } + }) + } +} + +// Test typeReferencesSelf function +func TestTypeReferencesSelf(t *testing.T) { + prog := ssa.NewProgram(nil) + hw := newCHeaderWriter(prog) + + pkg := types.NewPackage("test", "test") + + // Create a named type for testing + nodeTypeName := types.NewTypeName(0, pkg, "Node", nil) + nodeStruct := types.NewStruct([]*types.Var{ + types.NewField(0, nil, "data", types.Typ[types.Int], false), + }, nil) + namedNode := types.NewNamed(nodeTypeName, nodeStruct, nil) + + tests := []struct { + name string + typ types.Type + selfTypeName string + expected bool + }{ + { + name: "pointer to self", + typ: types.NewPointer(namedNode), + selfTypeName: "test_Node", + expected: true, + }, + { + name: "slice of self", + typ: types.NewSlice(namedNode), + selfTypeName: "test_Node", + expected: true, + }, + { + name: "array of self", + typ: types.NewArray(namedNode, 5), + selfTypeName: "test_Node", + expected: true, + }, + { + name: "named type self", + typ: namedNode, + selfTypeName: "test_Node", + expected: true, + }, + { + name: "basic type not self", + typ: types.Typ[types.Int], + selfTypeName: "test_Node", + expected: false, + }, + { + name: "different named type", + typ: namedNode, + selfTypeName: "other_Type", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := hw.typeReferencesSelf(tt.typ, tt.selfTypeName) + if got != tt.expected { + t.Errorf("typeReferencesSelf() = %v, want %v", got, tt.expected) + } + }) + } +} + +// Test array struct generation functions +func TestArrayStructGeneration(t *testing.T) { + prog := ssa.NewProgram(nil) + hw := newCHeaderWriter(prog) + + // Test ensureArrayStruct + arrayType := types.NewArray(types.Typ[types.Int32], 5) + name := hw.ensureArrayStruct(arrayType) + expectedName := "Array_int32_t_5" + if name != expectedName { + t.Errorf("ensureArrayStruct() = %q, want %q", name, expectedName) + } + + // Test that typedef was generated + output := hw.typeBuf.String() + if !strings.Contains(output, "typedef struct") { + t.Error("ensureArrayStruct should generate typedef") + } + if !strings.Contains(output, "int32_t data[5]") { + t.Error("ensureArrayStruct should generate correct array field") + } + + // Test duplicate prevention + hw.typeBuf.Reset() + name2 := hw.ensureArrayStruct(arrayType) // Call again + if name2 != name { + t.Errorf("ensureArrayStruct should return same name for same type") + } + duplicateOutput := hw.typeBuf.String() + if duplicateOutput != "" { + t.Error("ensureArrayStruct should not generate duplicate typedef") + } +} + +// Test generateReturnType function +func TestGenerateReturnType(t *testing.T) { + prog := ssa.NewProgram(nil) + hw := newCHeaderWriter(prog) + + // Test basic type + basicRet := hw.generateReturnType(types.Typ[types.Int32]) + if basicRet != "int32_t" { + t.Errorf("generateReturnType(int32) = %q, want %q", basicRet, "int32_t") + } + + // Test array type (should generate struct wrapper) + arrayType := types.NewArray(types.Typ[types.Float64], 3) + arrayRet := hw.generateReturnType(arrayType) + expectedArrayRet := "Array_double_3" + if arrayRet != expectedArrayRet { + t.Errorf("generateReturnType(array) = %q, want %q", arrayRet, expectedArrayRet) + } +} + +// Test generateTypedef function +func TestGenerateTypedef(t *testing.T) { + prog := ssa.NewProgram(nil) + hw := newCHeaderWriter(prog) + + pkg := types.NewPackage("test", "test") + + // Test named struct + structType := types.NewStruct([]*types.Var{ + types.NewField(0, nil, "value", types.Typ[types.Int], false), + }, nil) + namedType := types.NewNamed(types.NewTypeName(0, pkg, "TestStruct", nil), structType, nil) + + typedef := hw.generateTypedef(namedType) + if !strings.Contains(typedef, "typedef struct") { + t.Error("generateTypedef should generate typedef for named struct") + } + + // Test named basic type + namedInt := types.NewNamed(types.NewTypeName(0, pkg, "MyInt", nil), types.Typ[types.Int], nil) + typedef2 := hw.generateTypedef(namedInt) + if !strings.Contains(typedef2, "typedef intptr_t test_MyInt") { + t.Error("generateTypedef should generate typedef for named basic type") + } +} + +// Test complex nested structures and dependencies +func TestComplexNestedStructures(t *testing.T) { + prog := ssa.NewProgram(nil) + hw := newCHeaderWriter(prog) + + // Create a complex nested structure + pkg := types.NewPackage("test", "test") + + // Inner struct + innerStruct := types.NewStruct([]*types.Var{ + types.NewField(0, nil, "value", types.Typ[types.Int], false), + }, nil) + + // Named inner struct + innerTypeName := types.NewTypeName(0, pkg, "InnerStruct", nil) + namedInner := types.NewNamed(innerTypeName, innerStruct, nil) + + // Outer struct with inner struct field + outerStruct := types.NewStruct([]*types.Var{ + types.NewField(0, nil, "inner", namedInner, false), + types.NewField(0, nil, "ptr", types.NewPointer(namedInner), false), + types.NewField(0, nil, "slice", types.NewSlice(namedInner), false), + }, nil) + + outerTypeName := types.NewTypeName(0, pkg, "OuterStruct", nil) + namedOuter := types.NewNamed(outerTypeName, outerStruct, nil) + + // Test writeTypedef for complex structure + err := hw.writeTypedef(namedOuter) + if err != nil { + t.Fatalf("writeTypedef() error = %v", err) + } + + output := hw.typeBuf.String() + + // Should contain both inner and outer struct definitions + if !strings.Contains(output, "test_InnerStruct") { + t.Error("Expected inner struct typedef") + } + if !strings.Contains(output, "test_OuterStruct") { + t.Error("Expected outer struct typedef") + } +} + +// Test goCTypeName with more type cases + +// Test processDependentTypes for error paths and edge cases +func TestProcessDependentTypesEdgeCases(t *testing.T) { + prog := ssa.NewProgram(nil) + hw := newCHeaderWriter(prog) + + // Test signature type dependency (function parameters and results) + params := types.NewTuple(types.NewVar(0, nil, "x", types.Typ[types.Int])) + results := types.NewTuple(types.NewVar(0, nil, "", types.Typ[types.String])) + sigType := types.NewSignatureType(nil, nil, nil, params, results, false) + + err := hw.processDependentTypes(sigType, make(map[string]bool)) + if err != nil { + t.Errorf("processDependentTypes(signature) error = %v", err) + } + + // Test processSignatureTypes directly + err = hw.processSignatureTypes(sigType, make(map[string]bool)) + if err != nil { + t.Errorf("processSignatureTypes error = %v", err) + } + + // Test Map type - this should trigger the Map case in writeTypedefRecursive + mapType := types.NewMap(types.Typ[types.String], types.Typ[types.Int]) + err = hw.writeTypedefRecursive(mapType, make(map[string]bool)) + if err != nil { + t.Errorf("writeTypedefRecursive(map) error = %v", err) + } + + // Test Chan type - this should trigger the Chan case in writeTypedefRecursive + chanType := types.NewChan(types.SendRecv, types.Typ[types.Bool]) + err = hw.writeTypedefRecursive(chanType, make(map[string]bool)) + if err != nil { + t.Errorf("writeTypedefRecursive(chan) error = %v", err) + } + + // Test Map with complex types to trigger both key and value processing + struct1 := types.NewStruct([]*types.Var{ + types.NewField(0, nil, "key", types.Typ[types.String], false), + }, nil) + struct2 := types.NewStruct([]*types.Var{ + types.NewField(0, nil, "value", types.Typ[types.Int], false), + }, nil) + complexMapType := types.NewMap(struct1, struct2) + err = hw.writeTypedefRecursive(complexMapType, make(map[string]bool)) + if err != nil { + t.Errorf("writeTypedefRecursive(complex map) error = %v", err) + } + + // Test function signature with no parameters (edge case) + noParamsSig := types.NewSignatureType(nil, nil, nil, nil, results, false) + err = hw.processSignatureTypes(noParamsSig, make(map[string]bool)) + if err != nil { + t.Errorf("processSignatureTypes(no params) error = %v", err) + } + + // Test function signature with no results (edge case) + noResultsSig := types.NewSignatureType(nil, nil, nil, params, nil, false) + err = hw.processSignatureTypes(noResultsSig, make(map[string]bool)) + if err != nil { + t.Errorf("processSignatureTypes(no results) error = %v", err) + } +} + +// Test generateNamedStructTypedef with forward declaration +func TestGenerateNamedStructTypedefWithForwardDecl(t *testing.T) { + prog := ssa.NewProgram(nil) + hw := newCHeaderWriter(prog) + + pkg := types.NewPackage("test", "test") + + // Create a self-referential struct that needs forward declaration + nodeName := types.NewTypeName(0, pkg, "Node", nil) + nodeNamed := types.NewNamed(nodeName, nil, nil) + + // Create fields including a pointer to itself + fields := []*types.Var{ + types.NewField(0, nil, "value", types.Typ[types.Int], false), + types.NewField(0, nil, "next", types.NewPointer(nodeNamed), false), + } + nodeStruct := types.NewStruct(fields, nil) + nodeNamed.SetUnderlying(nodeStruct) + + // Test generateNamedStructTypedef + result := hw.generateNamedStructTypedef(nodeNamed, nodeStruct) + + // Should contain forward declaration (with package prefix) + if !strings.Contains(result, "typedef struct test_Node test_Node;") { + t.Errorf("Expected forward declaration in result: %s", result) + } + + // Should contain the actual struct definition + if !strings.Contains(result, "struct test_Node {") { + t.Errorf("Expected struct definition in result: %s", result) + } +} + +// Test self-referential structures to ensure no infinite recursion +func TestSelfReferentialStructure(t *testing.T) { + prog := ssa.NewProgram(nil) + hw := newCHeaderWriter(prog) + + pkg := types.NewPackage("test", "test") + + // Create a self-referential struct: Node with a pointer to itself + nodeTypeName := types.NewTypeName(0, pkg, "Node", nil) + nodeStruct := types.NewStruct([]*types.Var{ + types.NewField(0, nil, "data", types.Typ[types.Int], false), + }, nil) + namedNode := types.NewNamed(nodeTypeName, nodeStruct, nil) + + // Add a self-referential field after creating the named type + nodeStructWithPtr := types.NewStruct([]*types.Var{ + types.NewField(0, nil, "data", types.Typ[types.Int], false), + types.NewField(0, nil, "next", types.NewPointer(namedNode), false), + }, nil) + + // Create a new named type with the updated struct + nodeTypeNameFinal := types.NewTypeName(0, pkg, "SelfRefNode", nil) + namedNodeFinal := types.NewNamed(nodeTypeNameFinal, nodeStructWithPtr, nil) + + // This should not cause infinite recursion + err := hw.writeTypedef(namedNodeFinal) + if err != nil { + t.Fatalf("writeTypedef() error = %v", err) + } + + output := hw.typeBuf.String() + if !strings.Contains(output, "test_SelfRefNode") { + t.Error("Expected self-referential struct typedef") + } +} + +// Test function signature dependencies +func TestFunctionSignatureDependencies(t *testing.T) { + prog := ssa.NewProgram(nil) + hw := newCHeaderWriter(prog) + + pkg := types.NewPackage("test", "test") + + // Create struct type for function parameters + paramStruct := types.NewStruct([]*types.Var{ + types.NewField(0, nil, "id", types.Typ[types.Int], false), + }, nil) + paramTypeName := types.NewTypeName(0, pkg, "ParamStruct", nil) + namedParam := types.NewNamed(paramTypeName, paramStruct, nil) + + // Create function signature with struct parameters and return values + params := types.NewTuple( + types.NewVar(0, nil, "input", namedParam), + types.NewVar(0, nil, "count", types.Typ[types.Int]), + ) + results := types.NewTuple( + types.NewVar(0, nil, "output", namedParam), + ) + + funcSig := types.NewSignatureType(nil, nil, nil, params, results, false) + + // Test that function signature dependencies are processed + err := hw.processDependentTypes(funcSig, make(map[string]bool)) + if err != nil { + t.Fatalf("processDependentTypes() error = %v", err) + } + + // Test named basic type alias (should trigger the "else" branch in processDependentTypes) + namedInt := types.NewNamed(types.NewTypeName(0, pkg, "MyInt", nil), types.Typ[types.Int], nil) + err = hw.processDependentTypes(namedInt, make(map[string]bool)) + if err != nil { + t.Errorf("processDependentTypes(named int) error = %v", err) + } + + // Test duplicate type prevention - mark type as already declared + hw2 := newCHeaderWriter(prog) + hw2.declaredTypes["test_DuplicateType"] = true + duplicateType := types.NewNamed( + types.NewTypeName(0, pkg, "DuplicateType", nil), + types.Typ[types.Int], + nil, + ) + err = hw2.writeTypedefRecursive(duplicateType, make(map[string]bool)) + if err != nil { + t.Errorf("writeTypedefRecursive(duplicate) error = %v", err) + } + if strings.Contains(hw2.typeBuf.String(), "DuplicateType") { + t.Error("Should not generate typedef for already declared type") + } + + // Test visiting map to prevent infinite recursion + visiting := make(map[string]bool) + visiting["test_MyInt"] = true // Mark as visiting + err = hw.writeTypedefRecursive(namedInt, visiting) + if err != nil { + t.Errorf("writeTypedefRecursive(already visiting) error = %v", err) + } + + // Test invalid type (should trigger cType == "" path) + invalidType := types.Typ[types.Invalid] + err = hw.writeTypedefRecursive(invalidType, make(map[string]bool)) + if err != nil { + t.Errorf("writeTypedefRecursive(invalid) error = %v", err) + } +} + +// Test error conditions and edge cases +func TestEdgeCasesAndErrorConditions(t *testing.T) { + prog := ssa.NewProgram(nil) + hw := newCHeaderWriter(prog) + + // Test with properly initialized but empty function - skip nil test since it causes panic + // This would require creating a complete Function object which is complex + + // Test writeCommonIncludes + hw.typeBuf.Reset() + err := hw.writeCommonIncludes() + if err != nil { + t.Fatalf("writeCommonIncludes() error = %v", err) + } + + output := hw.typeBuf.String() + expectedIncludes := []string{ + "GoString", + "GoSlice", + "GoMap", + "GoChan", + "GoInterface", + } + + for _, expected := range expectedIncludes { + if !strings.Contains(output, expected) { + t.Errorf("Expected %s in common includes", expected) + } + } +} + +// Test writeTo function +func TestWriteTo(t *testing.T) { + prog := ssa.NewProgram(nil) + hw := newCHeaderWriter(prog) + + // Add some content to both buffers + hw.typeBuf.WriteString("typedef struct { int x; } TestStruct;\n") + hw.funcBuf.WriteString("void TestFunction(void);\n") + + var output bytes.Buffer + err := hw.writeTo(&output) + if err != nil { + t.Fatalf("writeTo() error = %v", err) + } + + got := output.String() + if !strings.Contains(got, "TestStruct") { + t.Error("writeTo() should write type definitions") + } + if !strings.Contains(got, "TestFunction") { + t.Error("writeTo() should write function declarations") + } +} + +// Test genHeader function +func TestGenHeader(t *testing.T) { + prog := ssa.NewProgram(nil) + + // Create a mock package + pkg := prog.NewPackage("", "testpkg") + + var output bytes.Buffer + err := genHeader(prog, []ssa.Package{pkg}, &output) + if err != nil { + t.Fatalf("genHeader() error = %v", err) + } + + got := output.String() + if !strings.Contains(got, "GoString") { + t.Error("genHeader() should include Go runtime types") + } +} + +// Test GenCHeader function +func TestGenCHeader(t *testing.T) { + prog := ssa.NewProgram(nil) + + // Create a mock package + pkg := prog.NewPackage("", "testpkg") + + // Create a temp file for output + tmpfile, err := os.CreateTemp("", "test_header_*.h") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + defer os.Remove(tmpfile.Name()) + defer tmpfile.Close() + + err = GenHeaderFile(prog, []ssa.Package{pkg}, "testlib", tmpfile.Name(), false) + if err != nil { + t.Fatalf("GenCHeader() error = %v", err) + } + + // Read the file and verify content + content, err := os.ReadFile(tmpfile.Name()) + if err != nil { + t.Fatalf("Failed to read generated file: %v", err) + } + + got := string(content) + if !strings.Contains(got, "#ifndef") { + t.Error("GenCHeader() should generate header guards") + } + if !strings.Contains(got, "GoString") { + t.Error("GenCHeader() should include Go runtime types") + } +} + +// Test genHeader with init function coverage +func TestGenHeaderWithInitFunction(t *testing.T) { + prog := ssa.NewProgram(nil) + + // Create a package + pkgPath := "github.com/test/mypackage" + pkg := prog.NewPackage("", pkgPath) + + // Create an init function signature: func() + initSig := types.NewSignature(nil, types.NewTuple(), types.NewTuple(), false) + + // Create the init function with the expected name format + initFnName := pkgPath + ".init" + _ = pkg.NewFunc(initFnName, initSig, ssa.InGo) + + // Test genHeader which should now detect the init function + var output bytes.Buffer + err := genHeader(prog, []ssa.Package{pkg}, &output) + if err != nil { + t.Fatalf("genHeader() error = %v", err) + } + + got := output.String() + + // Should contain Go runtime types + if !strings.Contains(got, "GoString") { + t.Error("genHeader() should include Go runtime types") + } + + // Should contain the init function declaration with C-compatible name + expectedInitName := "github_com_test_mypackage_init" + if !strings.Contains(got, expectedInitName) { + t.Errorf("genHeader() should include init function declaration with name %s, got: %s", expectedInitName, got) + } +} diff --git a/ssa/package.go b/ssa/package.go index f60c3545..e90426d5 100644 --- a/ssa/package.go +++ b/ssa/package.go @@ -435,6 +435,7 @@ func (p Program) NewPackage(name, pkgPath string) Package { pyobjs: pyobjs, pymods: pymods, strs: strs, chkabi: chkabi, Prog: p, di: nil, cu: nil, glbDbgVars: glbDbgVars, + export: make(map[string]string), } ret.abi.Init(pkgPath) return ret @@ -693,6 +694,8 @@ type aPackage struct { NeedRuntime bool NeedPyInit bool + + export map[string]string // pkgPath.nameInPkg => exportname } type Package = *aPackage @@ -701,6 +704,14 @@ func (p Package) Module() llvm.Module { return p.mod } +func (p Package) SetExport(name, export string) { + p.export[name] = export +} + +func (p Package) ExportFuncs() map[string]string { + return p.export +} + func (p Package) rtFunc(fnName string) Expr { p.NeedRuntime = true fn := p.Prog.runtime().Scope().Lookup(fnName).(*types.Func)