diff --git a/internal/build/build.go b/internal/build/build.go index 1a130122..d2a77cbb 100644 --- a/internal/build/build.go +++ b/internal/build/build.go @@ -28,7 +28,6 @@ import ( "path/filepath" "runtime" "strings" - "unsafe" "golang.org/x/tools/go/ssa" @@ -107,11 +106,13 @@ func Do(args []string, conf *Config) { prog := llssa.NewProgram(nil) sizes := prog.TypeSizes + // dedup := packages.NewDeduper() + dedup := (*packages.Deduper)(nil) if patterns == nil { patterns = []string{"."} } - initial, err := packages.LoadEx(sizes, cfg, patterns...) + initial, err := packages.LoadEx(dedup, sizes, cfg, patterns...) check(err) mode := conf.Mode @@ -133,7 +134,7 @@ func Do(args []string, conf *Config) { load := func() []*packages.Package { if rt == nil { var err error - rt, err = packages.LoadEx(sizes, cfg, llssa.PkgRuntime, llssa.PkgPython) + rt, err = packages.LoadEx(dedup, sizes, cfg, llssa.PkgRuntime, llssa.PkgPython) check(err) } return rt @@ -149,7 +150,7 @@ func Do(args []string, conf *Config) { }) imp := func(pkgPath string) *packages.Package { - if ret, e := packages.LoadEx(sizes, cfg, pkgPath); e == nil { + if ret, e := packages.LoadEx(dedup, sizes, cfg, pkgPath); e == nil { return ret[0] } return nil @@ -443,17 +444,6 @@ func allPkgs(imp importer, initial []*packages.Package, mode ssa.BuilderMode) (p return } -type ssaProgram struct { - Fset *token.FileSet - imported map[string]*ssa.Package - packages map[*types.Package]*ssa.Package // TODO(xsw): ensure offset of packages -} - -func setPkgSSA(prog *ssa.Program, pkg *types.Package, pkgSSA *ssa.Package) { - s := (*ssaProgram)(unsafe.Pointer(prog)) - s.packages[pkg] = pkgSSA -} - func createAltSSAPkg(prog *ssa.Program, alt *packages.Package) *ssa.Package { altPath := alt.Types.Path() altSSA := prog.ImportedPackage(altPath) @@ -461,11 +451,8 @@ func createAltSSAPkg(prog *ssa.Program, alt *packages.Package) *ssa.Package { packages.Visit([]*packages.Package{alt}, nil, func(p *packages.Package) { pkgTypes := p.Types if pkgTypes != nil && !p.IllTyped { - pkgSSA := prog.ImportedPackage(pkgTypes.Path()) - if pkgSSA == nil { + if prog.ImportedPackage(pkgTypes.Path()) == nil { prog.CreatePackage(pkgTypes, p.Syntax, p.TypesInfo, true) - } else { - setPkgSSA(prog, pkgTypes, pkgSSA) } } }) diff --git a/internal/build/clean.go b/internal/build/clean.go index 1e5a2e26..60ec5f46 100644 --- a/internal/build/clean.go +++ b/internal/build/clean.go @@ -42,7 +42,7 @@ func Clean(args []string, conf *Config) { if patterns == nil { patterns = []string{"."} } - initial, err := packages.LoadEx(nil, cfg, patterns...) + initial, err := packages.LoadEx(nil, nil, cfg, patterns...) check(err) cleanPkgs(initial, verbose) diff --git a/internal/llgen/llgenf.go b/internal/llgen/llgenf.go index 843c245a..916831dc 100644 --- a/internal/llgen/llgenf.go +++ b/internal/llgen/llgenf.go @@ -43,7 +43,7 @@ func initRtAndPy(prog llssa.Program, cfg *packages.Config) { load := func() []*packages.Package { if pkgRtAndPy == nil { var err error - pkgRtAndPy, err = packages.LoadEx(prog.TypeSizes, cfg, llssa.PkgRuntime, llssa.PkgPython) + pkgRtAndPy, err = packages.LoadEx(nil, prog.TypeSizes, cfg, llssa.PkgRuntime, llssa.PkgPython) check(err) } return pkgRtAndPy @@ -65,7 +65,7 @@ func GenFrom(fileOrPkg string) string { cfg := &packages.Config{ Mode: loadSyntax | packages.NeedDeps, } - initial, err := packages.LoadEx(prog.TypeSizes, cfg, fileOrPkg) + initial, err := packages.LoadEx(nil, prog.TypeSizes, cfg, fileOrPkg) check(err) _, pkgs := ssautil.AllPackages(initial, ssa.SanityCheckFunctions) diff --git a/internal/packages/load.go b/internal/packages/load.go index 18d49d63..22e4fbb6 100644 --- a/internal/packages/load.go +++ b/internal/packages/load.go @@ -17,12 +17,14 @@ package packages import ( + "errors" "fmt" "go/types" "runtime" "sync" "unsafe" + "golang.org/x/sync/errgroup" "golang.org/x/tools/go/packages" ) @@ -57,6 +59,12 @@ const ( // Calls to Load do not modify this struct. type Config = packages.Config +func setGoListOverlayFile(cfg *Config, val string) { + // TODO(xsw): suppose that the field is at the end of the struct + ptr := uintptr(unsafe.Pointer(cfg)) + (unsafe.Sizeof(*cfg) - unsafe.Sizeof(val)) + *(*string)(unsafe.Pointer(ptr)) = val +} + // A Package describes a loaded Go package. type Package = packages.Package @@ -64,7 +72,7 @@ type Package = packages.Package type loader struct { pkgs map[string]unsafe.Pointer Config - sizes types.Sizes // non-nil if needed by mode + sizes types.Sizes // TODO(xsw): ensure offset of sizes parseCache map[string]unsafe.Pointer parseCacheMu sync.Mutex exportMu sync.Mutex // enforces mutual exclusion of exportdata operations @@ -78,12 +86,131 @@ type loader struct { requestedMode LoadMode } +// Deduper wraps a DriverResponse, deduplicating its contents. +type Deduper struct { + seenRoots map[string]bool + seenPackages map[string]*Package + dr *packages.DriverResponse // TODO(xsw): ensure offset of dr +} + +//go:linkname NewDeduper golang.org/x/tools/go/packages.newDeduper +func NewDeduper() *Deduper + +//go:linkname addAll golang.org/x/tools/go/packages.(*responseDeduper).addAll +func addAll(r *Deduper, dr *packages.DriverResponse) + +func mergeResponsesEx(dedup *Deduper, responses ...*packages.DriverResponse) *packages.DriverResponse { + if len(responses) == 0 { + return nil + } + if dedup == nil { + dedup = NewDeduper() + } + response := dedup + response.dr.NotHandled = false + response.dr.Compiler = responses[0].Compiler + response.dr.Arch = responses[0].Arch + response.dr.GoVersion = responses[0].GoVersion + for _, v := range responses { + addAll(response, v) + } + return response.dr +} + +// driver is the type for functions that query the build system for the +// packages named by the patterns. +type driver func(cfg *Config, patterns ...string) (*packages.DriverResponse, error) + +func callDriverOnChunksEx(dedup *Deduper, driver driver, cfg *Config, chunks [][]string) (*packages.DriverResponse, error) { + if len(chunks) == 0 { + return driver(cfg) + } + responses := make([]*packages.DriverResponse, len(chunks)) + errNotHandled := errors.New("driver returned NotHandled") + var g errgroup.Group + for i, chunk := range chunks { + i := i + chunk := chunk + g.Go(func() (err error) { + responses[i], err = driver(cfg, chunk...) + if responses[i] != nil && responses[i].NotHandled { + err = errNotHandled + } + return err + }) + } + if err := g.Wait(); err != nil { + if errors.Is(err, errNotHandled) { + return &packages.DriverResponse{NotHandled: true}, nil + } + return nil, err + } + return mergeResponsesEx(dedup, responses...), nil +} + +//go:linkname splitIntoChunks golang.org/x/tools/go/packages.splitIntoChunks +func splitIntoChunks(patterns []string, argMax int) ([][]string, error) + +//go:linkname findExternalDriver golang.org/x/tools/go/packages.findExternalDriver +func findExternalDriver(cfg *Config) driver + +//go:linkname goListDriver golang.org/x/tools/go/packages.goListDriver +func goListDriver(cfg *Config, patterns ...string) (_ *packages.DriverResponse, err error) + +//go:linkname writeOverlays golang.org/x/tools/internal/gocommand.WriteOverlays +func writeOverlays(overlay map[string][]byte) (filename string, cleanup func(), err error) + +func defaultDriverEx(dedup *Deduper, cfg *Config, patterns ...string) (*packages.DriverResponse, bool, error) { + const ( + // windowsArgMax specifies the maximum command line length for + // the Windows' CreateProcess function. + windowsArgMax = 32767 + // maxEnvSize is a very rough estimation of the maximum environment + // size of a user. + maxEnvSize = 16384 + // safeArgMax specifies the maximum safe command line length to use + // by the underlying driver excl. the environment. We choose the Windows' + // ARG_MAX as the starting point because it's one of the lowest ARG_MAX + // constants out of the different supported platforms, + // e.g., https://www.in-ulm.de/~mascheck/various/argmax/#results. + safeArgMax = windowsArgMax - maxEnvSize + ) + chunks, err := splitIntoChunks(patterns, safeArgMax) + if err != nil { + return nil, false, err + } + + if driver := findExternalDriver(cfg); driver != nil { + response, err := callDriverOnChunksEx(dedup, driver, cfg, chunks) + if err != nil { + return nil, false, err + } else if !response.NotHandled { + return response, true, nil + } + // (fall through) + } + + // go list fallback + // + // Write overlays once, as there are many calls + // to 'go list' (one per chunk plus others too). + overlay, cleanupOverlay, err := writeOverlays(cfg.Overlay) + if err != nil { + return nil, false, err + } + defer cleanupOverlay() + setGoListOverlayFile(cfg, overlay) + + response, err := callDriverOnChunksEx(dedup, goListDriver, cfg, chunks) + if err != nil { + return nil, false, err + } + return response, false, err +} + //go:linkname newLoader golang.org/x/tools/go/packages.newLoader func newLoader(cfg *Config) *loader -//go:linkname defaultDriver golang.org/x/tools/go/packages.defaultDriver -func defaultDriver(cfg *Config, patterns ...string) (*packages.DriverResponse, bool, error) - //go:linkname refine golang.org/x/tools/go/packages.(*loader).refine func refine(ld *loader, response *packages.DriverResponse) ([]*Package, error) @@ -101,9 +228,9 @@ func refine(ld *loader, response *packages.DriverResponse) ([]*Package, error) // return an error. Clients may need to handle such errors before // proceeding with further analysis. The PrintErrors function is // provided for convenient display of all errors. -func LoadEx(sizes func(types.Sizes) types.Sizes, cfg *Config, patterns ...string) ([]*Package, error) { +func LoadEx(dedup *Deduper, sizes func(types.Sizes) types.Sizes, cfg *Config, patterns ...string) ([]*Package, error) { ld := newLoader(cfg) - response, external, err := defaultDriver(&ld.Config, patterns...) + response, external, err := defaultDriverEx(dedup, &ld.Config, patterns...) if err != nil { return nil, err }