155 lines
3.0 KiB
Go
155 lines
3.0 KiB
Go
|
|
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
|
||
|
|
|
||
|
|
package os_test
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bytes"
|
||
|
|
"errors"
|
||
|
|
"io"
|
||
|
|
"math/rand/v2"
|
||
|
|
"net"
|
||
|
|
"os"
|
||
|
|
"runtime"
|
||
|
|
"sync"
|
||
|
|
"testing"
|
||
|
|
|
||
|
|
"golang.org/x/net/nettest"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Exercise sendfile/splice fast paths with a moderately large file.
//
// https://go.dev/issue/70000
|
||
|
|
|
||
|
|
func TestLargeCopyViaNetwork(t *testing.T) {
|
||
|
|
const size = 10 * 1024 * 1024
|
||
|
|
dir := t.TempDir()
|
||
|
|
|
||
|
|
src, err := os.Create(dir + "/src")
|
||
|
|
if err != nil {
|
||
|
|
t.Fatal(err)
|
||
|
|
}
|
||
|
|
defer src.Close()
|
||
|
|
if _, err := io.CopyN(src, newRandReader(), size); err != nil {
|
||
|
|
t.Fatal(err)
|
||
|
|
}
|
||
|
|
if _, err := src.Seek(0, 0); err != nil {
|
||
|
|
t.Fatal(err)
|
||
|
|
}
|
||
|
|
|
||
|
|
dst, err := os.Create(dir + "/dst")
|
||
|
|
if err != nil {
|
||
|
|
t.Fatal(err)
|
||
|
|
}
|
||
|
|
defer dst.Close()
|
||
|
|
|
||
|
|
client, server := createSocketPair(t, "tcp")
|
||
|
|
var wg sync.WaitGroup
|
||
|
|
wg.Add(2)
|
||
|
|
go func() {
|
||
|
|
defer wg.Done()
|
||
|
|
if n, err := io.Copy(dst, server); n != size || err != nil {
|
||
|
|
t.Errorf("copy to destination = %v, %v; want %v, nil", n, err, size)
|
||
|
|
}
|
||
|
|
}()
|
||
|
|
go func() {
|
||
|
|
defer wg.Done()
|
||
|
|
defer client.Close()
|
||
|
|
if n, err := io.Copy(client, src); n != size || err != nil {
|
||
|
|
t.Errorf("copy from source = %v, %v; want %v, nil", n, err, size)
|
||
|
|
}
|
||
|
|
}()
|
||
|
|
wg.Wait()
|
||
|
|
|
||
|
|
if _, err := dst.Seek(0, 0); err != nil {
|
||
|
|
t.Fatal(err)
|
||
|
|
}
|
||
|
|
if err := compareReaders(dst, io.LimitReader(newRandReader(), size)); err != nil {
|
||
|
|
t.Fatal(err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// compareReaders reads a and b to the end in 4096-byte chunks.
//
// It returns nil when both readers yield identical contents. It returns a
// "contents mismatch" error when the contents or the lengths differ. Any read
// error other than end of stream is returned unchanged.
func compareReaders(a, b io.Reader) error {
	bufa := make([]byte, 4096)
	bufb := make([]byte, 4096)
	for {
		na, erra := io.ReadFull(a, bufa)
		if !isEndOfStream(erra) {
			return erra
		}
		nb, errb := io.ReadFull(b, bufb)
		if !isEndOfStream(errb) {
			return errb
		}
		if !bytes.Equal(bufa[:na], bufb[:nb]) {
			return errors.New("contents mismatch")
		}
		// A nil error means the buffer was filled and more data may follow.
		// io.EOF or io.ErrUnexpectedEOF means that reader is exhausted.
		// If only one reader were exhausted, its chunk would be shorter and
		// the equality check above would already have failed.
		if erra != nil && errb != nil {
			return nil
		}
	}
}

// isEndOfStream reports whether err from io.ReadFull means "no error" or
// "reached the end of the input". io.ReadFull returns io.EOF when it read no
// bytes and io.ErrUnexpectedEOF when it read a short final chunk.
func isEndOfStream(err error) bool {
	return err == nil || err == io.EOF || err == io.ErrUnexpectedEOF
}
|
||
|
|
|
||
|
|
type randReader struct {
|
||
|
|
rand *rand.Rand
|
||
|
|
}
|
||
|
|
|
||
|
|
func newRandReader() *randReader {
|
||
|
|
return &randReader{rand.New(rand.NewPCG(0, 0))}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (r *randReader) Read(p []byte) (int, error) {
|
||
|
|
var v uint64
|
||
|
|
var n int
|
||
|
|
for i := range p {
|
||
|
|
if n == 0 {
|
||
|
|
v = r.rand.Uint64()
|
||
|
|
n = 8
|
||
|
|
}
|
||
|
|
p[i] = byte(v & 0xff)
|
||
|
|
v >>= 8
|
||
|
|
n--
|
||
|
|
}
|
||
|
|
return len(p), nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func createSocketPair(t *testing.T, proto string) (client, server net.Conn) {
|
||
|
|
t.Helper()
|
||
|
|
if !nettest.TestableNetwork(proto) {
|
||
|
|
t.Skipf("%s does not support %q", runtime.GOOS, proto)
|
||
|
|
}
|
||
|
|
|
||
|
|
ln, err := nettest.NewLocalListener(proto)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("NewLocalListener error: %v", err)
|
||
|
|
}
|
||
|
|
t.Cleanup(func() {
|
||
|
|
if ln != nil {
|
||
|
|
ln.Close()
|
||
|
|
}
|
||
|
|
if client != nil {
|
||
|
|
client.Close()
|
||
|
|
}
|
||
|
|
if server != nil {
|
||
|
|
server.Close()
|
||
|
|
}
|
||
|
|
})
|
||
|
|
ch := make(chan struct{})
|
||
|
|
go func() {
|
||
|
|
var err error
|
||
|
|
server, err = ln.Accept()
|
||
|
|
if err != nil {
|
||
|
|
t.Errorf("Accept new connection error: %v", err)
|
||
|
|
}
|
||
|
|
ch <- struct{}{}
|
||
|
|
}()
|
||
|
|
client, err = net.Dial(proto, ln.Addr().String())
|
||
|
|
<-ch
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("Dial new connection error: %v", err)
|
||
|
|
}
|
||
|
|
return client, server
|
||
|
|
}
|