284 lines
7.6 KiB
Go
284 lines
7.6 KiB
Go
package socks5
|
|
|
|
import (
|
|
"context"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/netip"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
type socksConn struct {
|
|
// Injected fields
|
|
dialer *net.Dialer
|
|
username string
|
|
password string
|
|
clientConn net.Conn
|
|
logger Logger
|
|
}
|
|
|
|
func (c *socksConn) closeClientConn(ctxErr error) {
|
|
err := c.clientConn.Close()
|
|
if err != nil && ctxErr == nil {
|
|
c.logger.Warnf("closing client connection: %s", err)
|
|
}
|
|
}
|
|
|
|
func (c *socksConn) run(ctx context.Context) error {
|
|
authMethod := authNotRequired
|
|
if c.username != "" || c.password != "" {
|
|
authMethod = authUsernamePassword
|
|
}
|
|
|
|
err := verifyFirstNegotiation(c.clientConn, authMethod)
|
|
if err != nil {
|
|
replyMethod := authMethod
|
|
if errors.Is(err, ErrNoMethodIdentifiers) || errors.Is(err, ErrNoValidMethodIdentifier) {
|
|
replyMethod = authNotAcceptable
|
|
}
|
|
_, writeErr := c.clientConn.Write([]byte{socks5Version, byte(replyMethod)})
|
|
if writeErr != nil {
|
|
c.logger.Warnf("failed writing first negotiation reply: %s", writeErr)
|
|
}
|
|
c.closeClientConn(ctx.Err())
|
|
return fmt.Errorf("verifying first negotiation: %w", err)
|
|
}
|
|
|
|
_, err = c.clientConn.Write([]byte{socks5Version, byte(authMethod)})
|
|
if err != nil {
|
|
c.closeClientConn(ctx.Err())
|
|
return fmt.Errorf("writing first negotiation reply: %w", err)
|
|
}
|
|
|
|
switch authMethod {
|
|
case authNotRequired, authNotAcceptable:
|
|
case authGssapi:
|
|
panic("not implemented")
|
|
// TODO
|
|
case authUsernamePassword:
|
|
// See https://datatracker.ietf.org/doc/html/rfc1929#section-2
|
|
err = usernamePasswordSubnegotiate(c.clientConn, c.username, c.password)
|
|
if err != nil {
|
|
// If the server returns a `failure' (STATUS value other than X'00') status,
|
|
// it MUST close the connection.
|
|
c.closeClientConn(ctx.Err())
|
|
return fmt.Errorf("subnegotiating username and password: %w", err)
|
|
}
|
|
default:
|
|
panic(fmt.Sprintf("unimplemented auth method %d", authMethod))
|
|
}
|
|
|
|
err = c.handleRequest(ctx)
|
|
c.closeClientConn(ctx.Err())
|
|
if err != nil {
|
|
return fmt.Errorf("handling request: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ErrCommandNotSupported is wrapped by the error handleRequest returns when
// the client's request carries a command other than CONNECT.
var ErrCommandNotSupported = errors.New("command not supported")
|
|
|
|
func (c *socksConn) handleRequest(ctx context.Context) error {
|
|
const socksVersion = socks5Version
|
|
request, err := decodeRequest(c.clientConn, socksVersion)
|
|
if err != nil {
|
|
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
|
|
return err
|
|
}
|
|
if request.command != connect {
|
|
c.encodeFailedResponse(c.clientConn, socksVersion, commandNotSupported)
|
|
return fmt.Errorf("%w: %s", ErrCommandNotSupported, request.command)
|
|
}
|
|
|
|
destinationAddress := net.JoinHostPort(request.destination, fmt.Sprint(request.port))
|
|
destinationConn, err := c.dialer.DialContext(ctx, "tcp", destinationAddress)
|
|
if err != nil {
|
|
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
|
|
return err
|
|
}
|
|
defer destinationConn.Close()
|
|
|
|
destinationServerAddress := destinationConn.LocalAddr().String()
|
|
destinationAddr, destinationPortStr, err := net.SplitHostPort(destinationServerAddress)
|
|
fmt.Println("===", destinationServerAddress)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
destinationPort, err := strconv.Atoi(destinationPortStr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var bindAddrType addrType
|
|
if ip := net.ParseIP(destinationAddr); ip != nil {
|
|
if ip.To4() != nil {
|
|
bindAddrType = ipv4
|
|
} else {
|
|
bindAddrType = ipv6
|
|
}
|
|
} else {
|
|
bindAddrType = domainName
|
|
}
|
|
|
|
err = c.encodeSuccessResponse(c.clientConn, socksVersion, succeeded, bindAddrType,
|
|
destinationAddr, uint16(destinationPort))
|
|
if err != nil {
|
|
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
|
|
return fmt.Errorf("writing successful %s response: %w", request.command, err)
|
|
}
|
|
|
|
errc := make(chan error)
|
|
go func() {
|
|
_, err := io.Copy(c.clientConn, destinationConn)
|
|
if err != nil {
|
|
err = fmt.Errorf("from backend to client: %w", err)
|
|
}
|
|
errc <- err
|
|
}()
|
|
go func() {
|
|
_, err := io.Copy(destinationConn, c.clientConn)
|
|
if err != nil {
|
|
err = fmt.Errorf("from client to backend: %w", err)
|
|
}
|
|
errc <- err
|
|
}()
|
|
select {
|
|
case err := <-errc:
|
|
return err
|
|
case <-ctx.Done():
|
|
_ = destinationConn.Close()
|
|
_ = c.clientConn.Close()
|
|
return nil
|
|
}
|
|
}
|
|
|
|
var (
|
|
ErrVersionNotSupported = errors.New("version not supported")
|
|
ErrNoMethodIdentifiers = errors.New("no method identifiers")
|
|
)
|
|
|
|
// See https://datatracker.ietf.org/doc/html/rfc1928#section-3
|
|
func verifyFirstNegotiation(reader io.Reader, requiredMethod authMethod) error {
|
|
const headerLength = 2 // version + nMethods bytes
|
|
header := make([]byte, headerLength)
|
|
_, err := io.ReadFull(reader, header[:])
|
|
if err != nil {
|
|
return fmt.Errorf("reading header: %w", err)
|
|
}
|
|
|
|
if header[0] != socks5Version {
|
|
return fmt.Errorf("%w: %d", ErrVersionNotSupported, header[0])
|
|
}
|
|
|
|
nMethods := header[1]
|
|
if nMethods == 0 {
|
|
return fmt.Errorf("%w", ErrNoMethodIdentifiers)
|
|
}
|
|
|
|
methodIdentifiers := make([]byte, nMethods)
|
|
_, err = io.ReadFull(reader, methodIdentifiers)
|
|
if err != nil {
|
|
return fmt.Errorf("reading method identifiers: %w", err)
|
|
}
|
|
for _, methodIdentifier := range methodIdentifiers {
|
|
if methodIdentifier == byte(requiredMethod) {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return makeNoAcceptableMethodError(requiredMethod, methodIdentifiers)
|
|
}
|
|
|
|
// ErrNoValidMethodIdentifier is wrapped by the error reported when none of
// the authentication methods offered by the client matches the method the
// server requires.
var ErrNoValidMethodIdentifier = errors.New("no valid method identifier")
|
|
|
|
func makeNoAcceptableMethodError(requiredAuthMethod authMethod, methodIdentifiers []byte) error {
|
|
methodNames := make([]string, len(methodIdentifiers))
|
|
for i, methodIdentifier := range methodIdentifiers {
|
|
methodNames[i] = fmt.Sprintf("%q", authMethod(methodIdentifier))
|
|
}
|
|
|
|
return fmt.Errorf("%w: none of %s matches %s",
|
|
ErrNoValidMethodIdentifier, strings.Join(methodNames, ", "),
|
|
requiredAuthMethod)
|
|
}
|
|
|
|
// See https://datatracker.ietf.org/doc/html/rfc1928#section-4
|
|
type request struct {
|
|
command cmdType
|
|
destination string
|
|
port uint16
|
|
addressType addrType
|
|
}
|
|
|
|
// ErrRequestSocksVersionMismatch means a request's VER byte differs from the
// expected SOCKS version.
var ErrRequestSocksVersionMismatch = errors.New("request SOCKS version mismatch")

// ErrAddressTypeNotSupported means a request used an ATYP value that is not
// handled.
var ErrAddressTypeNotSupported = errors.New("address type not supported")
|
|
|
|
func decodeRequest(reader io.Reader, expectedVersion byte) (req request, err error) {
|
|
const headerLength = 4
|
|
header := [headerLength]byte{}
|
|
_, err = io.ReadFull(reader, header[:])
|
|
if err != nil {
|
|
return request{}, fmt.Errorf("reading header: %w", err)
|
|
}
|
|
|
|
version := header[0]
|
|
if header[0] != expectedVersion {
|
|
return request{}, fmt.Errorf("%w: expected %d and got %d",
|
|
ErrRequestSocksVersionMismatch, expectedVersion, version)
|
|
}
|
|
|
|
req.command = cmdType(header[1])
|
|
// header[2] is RSV byte
|
|
req.addressType = addrType(header[3])
|
|
|
|
switch req.addressType {
|
|
case ipv4:
|
|
var ip [4]byte
|
|
_, err = io.ReadFull(reader, ip[:])
|
|
if err != nil {
|
|
return request{}, fmt.Errorf("reading IPv4 address: %w", err)
|
|
}
|
|
req.destination = netip.AddrFrom4(ip).String()
|
|
case ipv6:
|
|
var ip [16]byte
|
|
_, err = io.ReadFull(reader, ip[:])
|
|
if err != nil {
|
|
return request{}, fmt.Errorf("reading IPv6 address: %w", err)
|
|
}
|
|
req.destination = netip.AddrFrom16(ip).String()
|
|
case domainName:
|
|
var header [1]byte
|
|
_, err = io.ReadFull(reader, header[:])
|
|
if err != nil {
|
|
return request{}, fmt.Errorf("reading domain name header: %w", err)
|
|
}
|
|
domainName := make([]byte, header[0])
|
|
_, err = io.ReadFull(reader, domainName)
|
|
if err != nil {
|
|
return request{}, fmt.Errorf("reading domain name bytes: %w", err)
|
|
}
|
|
req.destination = string(domainName)
|
|
default:
|
|
return request{}, fmt.Errorf("%w: %d", ErrAddressTypeNotSupported, req.addressType)
|
|
}
|
|
|
|
var portBytes [2]byte
|
|
_, err = io.ReadFull(reader, portBytes[:])
|
|
if err != nil {
|
|
return request{}, fmt.Errorf("reading port: %w", err)
|
|
}
|
|
req.port = binary.BigEndian.Uint16(portBytes[:])
|
|
|
|
return req, nil
|
|
}
|