Files
go-legacy-win7/src/os/copy_test.go

155 lines
3.0 KiB
Go
Raw Normal View History

2024-11-09 19:25:36 +11:00
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package os_test
import (
"bytes"
"errors"
"io"
"math/rand/v2"
"net"
"os"
"runtime"
"sync"
"testing"
"golang.org/x/net/nettest"
)
// Exercise sendfile/splice fast paths with a moderately large file.
//
// The test writes size pseudo-random bytes to a temporary file, streams that
// file through a local TCP connection (source file -> client, server ->
// destination file) using io.Copy on both sides, and finally checks that the
// destination file matches the regenerated input stream.
//
// https://go.dev/issue/70000
func TestLargeCopyViaNetwork(t *testing.T) {
	const size = 10 * 1024 * 1024
	dir := t.TempDir()

	src, err := os.Create(dir + "/src")
	if err != nil {
		t.Fatal(err)
	}
	defer src.Close()
	if _, err := io.CopyN(src, newRandReader(), size); err != nil {
		t.Fatal(err)
	}
	// Rewind so the network copy reads the file from its first byte.
	if _, err := src.Seek(0, io.SeekStart); err != nil {
		t.Fatal(err)
	}

	dst, err := os.Create(dir + "/dst")
	if err != nil {
		t.Fatal(err)
	}
	defer dst.Close()

	client, server := createSocketPair(t, "tcp")

	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		// Close the server end as soon as this copy returns, including on
		// error. Otherwise a failed read would leave the writer below blocked
		// forever on a full socket buffer and wg.Wait would never return.
		// createSocketPair's cleanup closing it a second time is harmless.
		defer server.Close()
		if n, err := io.Copy(dst, server); n != size || err != nil {
			t.Errorf("copy to destination = %v, %v; want %v, nil", n, err, size)
		}
	}()
	go func() {
		defer wg.Done()
		// Closing the client signals EOF to the reading goroutine above.
		defer client.Close()
		if n, err := io.Copy(client, src); n != size || err != nil {
			t.Errorf("copy from source = %v, %v; want %v, nil", n, err, size)
		}
	}()
	wg.Wait()

	// Rewind the destination and compare it with a fresh copy of the input.
	if _, err := dst.Seek(0, io.SeekStart); err != nil {
		t.Fatal(err)
	}
	if err := compareReaders(dst, io.LimitReader(newRandReader(), size)); err != nil {
		t.Fatal(err)
	}
}
// compareReaders reads a and b to completion in 4 KiB chunks and reports
// whether they produce identical byte streams.
//
// It returns nil when both streams have the same length and contents, an
// error with the text "contents mismatch" when they differ in either bytes or
// length, and otherwise the first read error that is not an end-of-stream
// indication.
func compareReaders(a, b io.Reader) error {
	bufa := make([]byte, 4096)
	bufb := make([]byte, 4096)
	for {
		// io.ReadFull reports io.EOF when nothing was read and
		// io.ErrUnexpectedEOF when only a partial final chunk was read; both
		// simply mean the stream is exhausted.
		na, erra := io.ReadFull(a, bufa)
		if erra != nil && erra != io.EOF && erra != io.ErrUnexpectedEOF {
			return erra
		}
		nb, errb := io.ReadFull(b, bufb)
		if errb != nil && errb != io.EOF && errb != io.ErrUnexpectedEOF {
			return errb
		}
		if !bytes.Equal(bufa[:na], bufb[:nb]) {
			return errors.New("contents mismatch")
		}
		// ReadFull returns fewer than len(buf) bytes only at end of stream,
		// and the chunks just compared equal (so na == nb): both streams
		// ended here.
		if na < len(bufa) {
			return nil
		}
	}
}
type randReader struct {
rand *rand.Rand
}
func newRandReader() *randReader {
return &randReader{rand.New(rand.NewPCG(0, 0))}
}
func (r *randReader) Read(p []byte) (int, error) {
var v uint64
var n int
for i := range p {
if n == 0 {
v = r.rand.Uint64()
n = 8
}
p[i] = byte(v & 0xff)
v >>= 8
n--
}
return len(p), nil
}
// createSocketPair returns the two ends of a freshly established connection
// over a local listener for proto (e.g. "tcp").
//
// It skips the test when nettest reports that proto is not testable here, and
// fails it if creating the listener, dialing, or accepting fails. The
// listener and both connections are closed by a t.Cleanup hook when the test
// finishes.
func createSocketPair(t *testing.T, proto string) (client, server net.Conn) {
	t.Helper()
	if !nettest.TestableNetwork(proto) {
		t.Skipf("%s does not support %q", runtime.GOOS, proto)
	}
	ln, err := nettest.NewLocalListener(proto)
	if err != nil {
		t.Fatalf("NewLocalListener error: %v", err)
	}
	t.Cleanup(func() {
		if ln != nil {
			ln.Close()
		}
		if client != nil {
			client.Close()
		}
		if server != nil {
			server.Close()
		}
	})

	// The accepting goroutine hands its result back instead of reporting
	// failures itself, so that the decision to stop the test (t.Fatal) is
	// made on the test goroutine. The buffer lets the goroutine exit even if
	// nobody receives.
	type acceptResult struct {
		conn net.Conn
		err  error
	}
	ch := make(chan acceptResult, 1)
	go func() {
		conn, err := ln.Accept()
		ch <- acceptResult{conn, err}
	}()

	client, err = net.Dial(proto, ln.Addr().String())
	if err != nil {
		// Nothing will connect, so Accept would block forever; closing the
		// listener unblocks it and lets the goroutine finish before we fail.
		ln.Close()
		if res := <-ch; res.conn != nil {
			server = res.conn // closed by the cleanup registered above
		}
		t.Fatalf("Dial new connection error: %v", err)
	}
	res := <-ch
	if res.err != nil {
		t.Fatalf("Accept new connection error: %v", res.err)
	}
	server = res.conn
	return client, server
}