diff --git a/internal/socks5/constants.go b/internal/socks5/constants.go new file mode 100644 index 00000000..bb185ad6 --- /dev/null +++ b/internal/socks5/constants.go @@ -0,0 +1,86 @@ +package socks5 + +import "fmt" + +// See https://datatracker.ietf.org/doc/html/rfc1928#section-3 +type authMethod byte + +const ( + authNotRequired authMethod = 0 + authGssapi authMethod = 1 + authUsernamePassword authMethod = 2 + authNotAcceptable authMethod = 255 +) + +func (a authMethod) String() string { + switch a { + case authNotRequired: + return "no authentication required" + case authGssapi: + return "GSSAPI" + case authUsernamePassword: + return "username/password" + case authNotAcceptable: + return "no acceptable methods" + default: + return fmt.Sprintf("unknown method (%d)", a) + } +} + +// Subnegotiation version +// See https://datatracker.ietf.org/doc/html/rfc1929#section-2 +const ( + authUsernamePasswordSubNegotiation1 byte = 1 +) + +// SOCKS versions. +const ( + socks5Version byte = 5 +) + +// See https://datatracker.ietf.org/doc/html/rfc1928#section-4 +type cmdType byte + +const ( + connect cmdType = 1 + bind cmdType = 2 + udpAssociate cmdType = 3 +) + +func (c cmdType) String() string { + switch c { + case connect: + return "connect" + case bind: + return "bind" + case udpAssociate: + return "UDP associate" + default: + return fmt.Sprintf("unknown command (%d)", c) + } +} + +// See https://datatracker.ietf.org/doc/html/rfc1928#section-4 and +// https://datatracker.ietf.org/doc/html/rfc1928#section-5 +type addrType byte + +const ( + ipv4 addrType = 1 + domainName addrType = 3 + ipv6 addrType = 4 +) + +// See https://datatracker.ietf.org/doc/html/rfc1928#section-6 +type replyCode byte + +const ( + succeeded replyCode = iota + generalServerFailure + connectionNotAllowedByRuleset + networkUnreachable + hostUnreachable + connectionRefused + ttlExpired + commandNotSupported + addressTypeNotSupported +) diff --git a/internal/socks5/interfaces.go b/internal/socks5/interfaces.go new file mode 100644 index 00000000..a9951848 --- /dev/null +++ b/internal/socks5/interfaces.go @@ -0,0 +1,6 @@ +package socks5 + +type Logger interface { + Infof(format string, a ...interface{}) + Warnf(format string, a ...interface{}) +} diff --git a/internal/socks5/response.go b/internal/socks5/response.go new file mode 100644 index 00000000..1a6a1e69 --- /dev/null +++ b/internal/socks5/response.go @@ -0,0 +1,103 @@ +package socks5 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/netip" +) + +// See https://datatracker.ietf.org/doc/html/rfc1928#section-6 +func (c *socksConn) encodeFailedResponse(writer io.Writer, socksVersion byte, reply replyCode) { + _, err := writer.Write([]byte{ + socksVersion, + byte(reply), + 0, // RSV byte + // TODO do we need to set the bind addr type to 0?? + }) + if err != nil { + c.logger.Warnf("failed writing failed response: %s", err) + } +} + +// See https://datatracker.ietf.org/doc/html/rfc1928#section-6 +func (c *socksConn) encodeSuccessResponse(writer io.Writer, socksVersion byte, + reply replyCode, bindAddrType addrType, bindAddress string, + bindPort uint16) (err error) { + bindData, err := encodeBindData(bindAddrType, bindAddress, bindPort) + if err != nil { // TODO encode with below block if this changes + return err + } + + const initialPacketLength = 3 + capacity := initialPacketLength + len(bindData) + packet := make([]byte, initialPacketLength, capacity) + packet[0] = socksVersion + packet[1] = byte(reply) + packet[2] = 0 // RSV byte + packet = append(packet, bindData...) + + _, err = writer.Write(packet) + if err != nil { + c.logger.Warnf("failed writing success response: %s", err) + } + return nil +} + +var ( + ErrIPVersionUnexpected = errors.New("ip version is unexpected") + ErrDomainNameTooLong = errors.New("domain name is too long") +) + +func encodeBindData(addrType addrType, address string, port uint16) ( + data []byte, err error) { + capacity := bindDataLength(addrType, address) + data = make([]byte, 0, capacity) + + data = append(data, byte(addrType)) + switch addrType { + case ipv4, ipv6: + ip, err := netip.ParseAddr(address) + if err != nil { + return nil, fmt.Errorf("parsing IP address: %w", err) + } + + switch { + case addrType == ipv4 && !ip.Is4(): + return nil, fmt.Errorf("%w: expected IPv4 for %s", ErrIPVersionUnexpected, ip) + case addrType == ipv6 && !ip.Is6(): + return nil, fmt.Errorf("%w: expected IPv6 for %s", ErrIPVersionUnexpected, ip) + } + data = append(data, ip.AsSlice()...) + case domainName: + const maxDomainNameLength = 255 + if len(address) > maxDomainNameLength { + return nil, fmt.Errorf("%w: %s", ErrDomainNameTooLong, address) + } + data = append(data, byte(len(address))) + data = append(data, []byte(address)...) + default: + panic(fmt.Sprintf("unsupported address type %d", addrType)) + } + data = binary.BigEndian.AppendUint16(data, port) + return data, nil +} + +func bindDataLength(addrType addrType, address string) (maxLength int) { + maxLength++ // address type + switch addrType { + case ipv4: + maxLength += net.IPv4len + case domainName: + maxLength++ // domain name length + maxLength += len([]byte(address)) + case ipv6: + maxLength += net.IPv6len + default: + panic("unsupported address type: " + fmt.Sprint(addrType)) + } + maxLength += 2 // port + return maxLength +} diff --git a/internal/socks5/server.go b/internal/socks5/server.go new file mode 100644 index 00000000..5a9f3abe --- /dev/null +++ b/internal/socks5/server.go @@ -0,0 +1,105 @@ +package socks5 + +import ( + "context" + "fmt" + "net" + "sync" + "sync/atomic" +) + +type Server struct { + username string + password string + address string + logger Logger + + // internal fields + listener net.Listener + listening atomic.Bool + socksConnCtx context.Context //nolint:containedctx + socksConnCancel context.CancelFunc + done <-chan struct{} + stopping atomic.Bool +} + +func New(settings Settings) *Server { + return &Server{ + username: settings.Username, + password: settings.Password, + address: settings.Address, + logger: settings.Logger, + } +} + +func (s *Server) Start(_ context.Context) (runErr <-chan error, err error) { + s.listener, err = net.Listen("tcp", s.address) + if err != nil { + return nil, fmt.Errorf("listening on %s: %w", s.address, err) + } + s.listening.Store(true) + + s.socksConnCtx, s.socksConnCancel = context.WithCancel(context.Background()) + + ready := make(chan struct{}) + runErrCh := make(chan error) + runErr = runErrCh + done := make(chan struct{}) + s.done = done + go s.runServer(ready, runErrCh, done) + <-ready + return runErr, nil +} + +func (s *Server) runServer(ready chan<- struct{}, + runErrCh chan<- error, done chan<- struct{}) { + close(ready) + defer close(done) + wg := new(sync.WaitGroup) + defer wg.Wait() + + dialer := &net.Dialer{} + for { + connection, err := s.listener.Accept() + if err != nil { + if !s.stopping.Load() { + _ = s.Stop() + runErrCh <- fmt.Errorf("accepting connection: %w", err) + } + return + } + wg.Add(1) + go func(ctx context.Context, connection net.Conn, + dialer *net.Dialer, wg *sync.WaitGroup) { + defer wg.Done() + socksConn := &socksConn{ + dialer: dialer, + username: s.username, + password: s.password, + clientConn: connection, + logger: s.logger, + } + err := socksConn.run(ctx) + if err != nil { + s.logger.Infof("running socks connection: %s", err) + } + }(s.socksConnCtx, connection, dialer, wg) + } +} + +func (s *Server) Stop() (err error) { + s.stopping.Store(true) + s.listening.Store(false) + err = s.listener.Close() + s.socksConnCancel() // stop ongoing socks connections + <-s.done // wait for run goroutine to finish + s.stopping.Store(false) + return err +} + +func (s *Server) listeningAddress() net.Addr { + if s.listening.Load() { + return s.listener.Addr() + } + return nil +} diff --git a/internal/socks5/settings.go b/internal/socks5/settings.go new file mode 100644 index 00000000..77b7c9ed --- /dev/null +++ b/internal/socks5/settings.go @@ -0,0 +1,8 @@ +package socks5 + +type Settings struct { + Username string + Password string + Address string + Logger Logger +} diff --git a/internal/socks5/socks5.go b/internal/socks5/socks5.go new file mode 100644 index 00000000..9edbce0f --- /dev/null +++ b/internal/socks5/socks5.go @@ -0,0 +1,283 @@ +package socks5 + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/netip" + "strconv" + "strings" +) + +type socksConn struct { + // Injected fields + dialer *net.Dialer + username string + password string + clientConn net.Conn + logger Logger +} + +func (c *socksConn) closeClientConn(ctxErr error) { + err := c.clientConn.Close() + if err != nil && ctxErr == nil { + c.logger.Warnf("closing client connection: %s", err) + } +} + +func (c *socksConn) run(ctx context.Context) error { + authMethod := authNotRequired + if c.username != "" || c.password != "" { + authMethod = authUsernamePassword + } + + err := verifyFirstNegotiation(c.clientConn, authMethod) + if err != nil { + replyMethod := authMethod + if errors.Is(err, ErrNoMethodIdentifiers) || errors.Is(err, ErrNoValidMethodIdentifier) { + replyMethod = authNotAcceptable + } + _, writeErr := c.clientConn.Write([]byte{socks5Version, byte(replyMethod)}) + if writeErr != nil { + c.logger.Warnf("failed writing first negotiation reply: %s", writeErr) + } + c.closeClientConn(ctx.Err()) + return fmt.Errorf("verifying first negotiation: %w", err) + } + + _, err = c.clientConn.Write([]byte{socks5Version, byte(authMethod)}) + if err != nil { + c.closeClientConn(ctx.Err()) + return fmt.Errorf("writing first negotiation reply: %w", err) + } + + switch authMethod { + case authNotRequired, authNotAcceptable: + case authGssapi: + panic("not implemented") + // TODO + case authUsernamePassword: + // See https://datatracker.ietf.org/doc/html/rfc1929#section-2 + err = usernamePasswordSubnegotiate(c.clientConn, c.username, c.password) + if err != nil { + // If the server returns a `failure' (STATUS value other than X'00') status, + // it MUST close the connection. + c.closeClientConn(ctx.Err()) + return fmt.Errorf("subnegotiating username and password: %w", err) + } + default: + panic(fmt.Sprintf("unimplemented auth method %d", authMethod)) + } + + err = c.handleRequest(ctx) + c.closeClientConn(ctx.Err()) + if err != nil { + return fmt.Errorf("handling request: %w", err) + } + return nil +} + +var ( + ErrCommandNotSupported = errors.New("command not supported") +) + +func (c *socksConn) handleRequest(ctx context.Context) error { + const socksVersion = socks5Version + request, err := decodeRequest(c.clientConn, socksVersion) + if err != nil { + c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure) + return err + } + if request.command != connect { + c.encodeFailedResponse(c.clientConn, socksVersion, commandNotSupported) + return fmt.Errorf("%w: %s", ErrCommandNotSupported, request.command) + } + + destinationAddress := net.JoinHostPort(request.destination, fmt.Sprint(request.port)) + destinationConn, err := c.dialer.DialContext(ctx, "tcp", destinationAddress) + if err != nil { + c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure) + return err + } + defer destinationConn.Close() + + destinationServerAddress := destinationConn.LocalAddr().String() + destinationAddr, destinationPortStr, err := net.SplitHostPort(destinationServerAddress) + fmt.Println("===", destinationServerAddress) + if err != nil { + return err + } + destinationPort, err := strconv.Atoi(destinationPortStr) + if err != nil { + return err + } + + var bindAddrType addrType + if ip := net.ParseIP(destinationAddr); ip != nil { + if ip.To4() != nil { + bindAddrType = ipv4 + } else { + bindAddrType = ipv6 + } + } else { + bindAddrType = domainName + } + + err = c.encodeSuccessResponse(c.clientConn, socksVersion, succeeded, bindAddrType, + destinationAddr, uint16(destinationPort)) + if err != nil { + c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure) + return fmt.Errorf("writing successful %s response: %w", request.command, err) + } + + errc := make(chan error) + go func() { + _, err := io.Copy(c.clientConn, destinationConn) + if err != nil { + err = fmt.Errorf("from backend to client: %w", err) + } + errc <- err + }() + go func() { + _, err := io.Copy(destinationConn, c.clientConn) + if err != nil { + err = fmt.Errorf("from client to backend: %w", err) + } + errc <- err + }() + select { + case err := <-errc: + return err + case <-ctx.Done(): + _ = destinationConn.Close() + _ = c.clientConn.Close() + return nil + } +} + +var ( + ErrVersionNotSupported = errors.New("version not supported") + ErrNoMethodIdentifiers = errors.New("no method identifiers") +) + +// See https://datatracker.ietf.org/doc/html/rfc1928#section-3 +func verifyFirstNegotiation(reader io.Reader, requiredMethod authMethod) error { + const headerLength = 2 // version + nMethods bytes + header := make([]byte, headerLength) + _, err := io.ReadFull(reader, header[:]) + if err != nil { + return fmt.Errorf("reading header: %w", err) + } + + if header[0] != socks5Version { + return fmt.Errorf("%w: %d", ErrVersionNotSupported, header[0]) + } + + nMethods := header[1] + if nMethods == 0 { + return fmt.Errorf("%w", ErrNoMethodIdentifiers) + } + + methodIdentifiers := make([]byte, nMethods) + _, err = io.ReadFull(reader, methodIdentifiers) + if err != nil { + return fmt.Errorf("reading method identifiers: %w", err) + } + for _, methodIdentifier := range methodIdentifiers { + if methodIdentifier == byte(requiredMethod) { + return nil + } + } + + return makeNoAcceptableMethodError(requiredMethod, methodIdentifiers) +} + +var ( + ErrNoValidMethodIdentifier = errors.New("no valid method identifier") +) + +func makeNoAcceptableMethodError(requiredAuthMethod authMethod, methodIdentifiers []byte) error { + methodNames := make([]string, len(methodIdentifiers)) + for i, methodIdentifier := range methodIdentifiers { + methodNames[i] = fmt.Sprintf("%q", authMethod(methodIdentifier)) + } + + return fmt.Errorf("%w: none of %s matches %s", + ErrNoValidMethodIdentifier, strings.Join(methodNames, ", "), + requiredAuthMethod) +} + +// See https://datatracker.ietf.org/doc/html/rfc1928#section-4 +type request struct { + command cmdType + destination string + port uint16 + addressType addrType +} + +var ( + ErrRequestSocksVersionMismatch = errors.New("request SOCKS version mismatch") + ErrAddressTypeNotSupported = errors.New("address type not supported") +) + +func decodeRequest(reader io.Reader, expectedVersion byte) (req request, err error) { + const headerLength = 4 + header := [headerLength]byte{} + _, err = io.ReadFull(reader, header[:]) + if err != nil { + return request{}, fmt.Errorf("reading header: %w", err) + } + + version := header[0] + if header[0] != expectedVersion { + return request{}, fmt.Errorf("%w: expected %d and got %d", + ErrRequestSocksVersionMismatch, expectedVersion, version) + } + + req.command = cmdType(header[1]) + // header[2] is RSV byte + req.addressType = addrType(header[3]) + + switch req.addressType { + case ipv4: + var ip [4]byte + _, err = io.ReadFull(reader, ip[:]) + if err != nil { + return request{}, fmt.Errorf("reading IPv4 address: %w", err) + } + req.destination = netip.AddrFrom4(ip).String() + case ipv6: + var ip [16]byte + _, err = io.ReadFull(reader, ip[:]) + if err != nil { + return request{}, fmt.Errorf("reading IPv6 address: %w", err) + } + req.destination = netip.AddrFrom16(ip).String() + case domainName: + var header [1]byte + _, err = io.ReadFull(reader, header[:]) + if err != nil { + return request{}, fmt.Errorf("reading domain name header: %w", err) + } + domainName := make([]byte, header[0]) + _, err = io.ReadFull(reader, domainName) + if err != nil { + return request{}, fmt.Errorf("reading domain name bytes: %w", err) + } + req.destination = string(domainName) + default: + return request{}, fmt.Errorf("%w: %d", ErrAddressTypeNotSupported, req.addressType) + } + + var portBytes [2]byte + _, err = io.ReadFull(reader, portBytes[:]) + if err != nil { + return request{}, fmt.Errorf("reading port: %w", err) + } + req.port = binary.BigEndian.Uint16(portBytes[:]) + + return req, nil +} diff --git a/internal/socks5/socks5_test.go b/internal/socks5/socks5_test.go new file mode 100644 index 00000000..2f947d6a --- /dev/null +++ b/internal/socks5/socks5_test.go @@ -0,0 +1,175 @@ +package socks5 + +import ( + "context" + "fmt" + "io" + "net" + "testing" + "time" + + "github.com/qdm12/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/net/proxy" +) + +func Test(t *testing.T) { + server := New(Settings{ + Username: "test", + Password: "test", + Address: ":8000", + Logger: log.New(), + }) + + runErr, startErr := server.Start(context.Background()) + require.NoError(t, startErr) + + select { + case err := <-runErr: + require.NoError(t, err) + default: + } + + t.Log("SlEEPING") + time.Sleep(15 * time.Second) + t.Log("Done sleeping") + + err := server.Stop() + require.NoError(t, err) +} + +func backendServer(listener net.Listener) { + conn, err := listener.Accept() + if err != nil { + panic(err) + } + conn.Write([]byte("Test")) + conn.Close() + listener.Close() +} + +func TestRead(t *testing.T) { + // backend server which we'll use SOCKS5 to connect to + listener, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + backendServerPort := listener.Addr().(*net.TCPAddr).Port + go backendServer(listener) + + // SOCKS5 server + server := New(Settings{ + Address: ":0", + }) + _, err = server.Start(context.Background()) + require.NoError(t, err) + t.Cleanup(func() { + err = server.Stop() + assert.NoError(t, err) + }) + socks5Port := server.listeningAddress().(*net.TCPAddr).Port + + addr := fmt.Sprintf("localhost:%d", socks5Port) + socksDialer, err := proxy.SOCKS5("tcp", addr, nil, proxy.Direct) + if err != nil { + t.Fatal(err) + } + + addr = fmt.Sprintf("localhost:%d", backendServerPort) + conn, err := socksDialer.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, 4) + _, err = io.ReadFull(conn, buf) + if err != nil { + t.Fatal(err) + } + if string(buf) != "Test" { + t.Fatalf("got: %q want: Test", buf) + } + + err = conn.Close() + if err != nil { + t.Fatal(err) + } +} + +func TestReadPassword(t *testing.T) { + // backend server which we'll use SOCKS5 to connect to + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + backendServerPort := ln.Addr().(*net.TCPAddr).Port + go backendServer(ln) + + auth := &proxy.Auth{User: "foo", Password: "bar"} + + server := Server{ + logger: log.New(), + username: auth.User, + password: auth.Password, + address: ":0", + } + _, err = server.Start(context.Background()) + require.NoError(t, err) + + t.Cleanup(func() { + err = server.Stop() + assert.NoError(t, err) + }) + + addr := fmt.Sprintf("localhost:%d", server.listeningAddress().(*net.TCPAddr).Port) + + if d, err := proxy.SOCKS5("tcp", addr, nil, proxy.Direct); err != nil { + t.Fatal(err) + } else { + if _, err := d.Dial("tcp", addr); err == nil { + t.Fatal("expected no-auth dial error") + } + } + + badPwd := &proxy.Auth{User: "foo", Password: "not right"} + if d, err := proxy.SOCKS5("tcp", addr, badPwd, proxy.Direct); err != nil { + t.Fatal(err) + } else { + if _, err := d.Dial("tcp", addr); err == nil { + t.Fatal("expected bad password dial error") + } + } + + badUsr := &proxy.Auth{User: "not right", Password: "bar"} + if d, err := proxy.SOCKS5("tcp", addr, badUsr, proxy.Direct); err != nil { + t.Fatal(err) + } else { + if _, err := d.Dial("tcp", addr); err == nil { + t.Fatal("expected bad username dial error") + } + } + + socksDialer, err := proxy.SOCKS5("tcp", addr, auth, proxy.Direct) + if err != nil { + t.Fatal(err) + } + + addr = fmt.Sprintf("localhost:%d", backendServerPort) + conn, err := socksDialer.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + + buf := make([]byte, 4) + if _, err := io.ReadFull(conn, buf); err != nil { + t.Fatal(err) + } + if string(buf) != "Test" { + t.Fatalf("got: %q want: Test", buf) + } + + if err := conn.Close(); err != nil { + t.Fatal(err) + } +} diff --git a/internal/socks5/usernamepassword.go b/internal/socks5/usernamepassword.go new file mode 100644 index 00000000..225fabff --- /dev/null +++ b/internal/socks5/usernamepassword.go @@ -0,0 +1,69 @@ +package socks5 + +import ( + "errors" + "fmt" + "io" +) + +var ( + ErrSubnegotiationVersionNotSupported = errors.New("subnegotiation version not supported") + ErrUsernameNotValid = errors.New("username not valid") + ErrPasswordNotValid = errors.New("password not valid") +) + +// See https://datatracker.ietf.org/doc/html/rfc1929#section-2 +func usernamePasswordSubnegotiate(conn io.ReadWriter, username, password string) (err error) { + status := byte(1) + const defaultVersion = byte(1) + + const headerLength = 2 + var header [headerLength]byte + _, err = io.ReadFull(conn, header[:]) + if err != nil { + _, _ = conn.Write([]byte{defaultVersion, status}) + return fmt.Errorf("reading header: %w", err) + } + + if header[0] != authUsernamePasswordSubNegotiation1 { + _, _ = conn.Write([]byte{defaultVersion, status}) + return fmt.Errorf("%w: %d", ErrSubnegotiationVersionNotSupported, header[0]) + } + version := header[0] + + usernameBytes := make([]byte, header[1]) + _, err = io.ReadFull(conn, usernameBytes) + if err != nil { + _, _ = conn.Write([]byte{version, status}) + return fmt.Errorf("reading username bytes: %w", err) + } else if username != string(usernameBytes) { + _, _ = conn.Write([]byte{version, status}) + return fmt.Errorf("%w: %s", ErrUsernameNotValid, string(usernameBytes)) + } + + const passwordHeaderLength = 1 + passwordHeader := make([]byte, passwordHeaderLength) + _, err = io.ReadFull(conn, passwordHeader[:]) + if err != nil { + _, _ = conn.Write([]byte{version, status}) + return fmt.Errorf("reading password length: %w", err) + } + + passwordBytes := make([]byte, passwordHeader[0]) + _, err = io.ReadFull(conn, passwordBytes) + if err != nil { + _, _ = conn.Write([]byte{version, status}) + return fmt.Errorf("reading password bytes: %w", err) + } else if password != string(passwordBytes) { + _, _ = conn.Write([]byte{version, status}) + return fmt.Errorf("%w: %s", ErrPasswordNotValid, string(passwordBytes)) + } + + status = 0 + _, err = conn.Write([]byte{version, status}) + if err != nil { + return fmt.Errorf("writing success status: %w", err) + } + + return nil +}