diff --git a/internal/netlink/types.go b/internal/netlink/types.go index 3633beda..d556068a 100644 --- a/internal/netlink/types.go +++ b/internal/netlink/types.go @@ -32,6 +32,11 @@ type Route struct { Type int } +func (r Route) String() string { + return fmt.Sprintf("{link %d, dst %s, src %s, gw %s, priority %d, family %d, table %d, type %d}", + r.LinkIndex, r.Dst, r.Src, r.Gw, r.Priority, r.Family, r.Table, r.Type) +} + type Rule struct { Priority int Family int diff --git a/internal/vpn/interfaces.go b/internal/vpn/interfaces.go index 68103690..0e9c50a3 100644 --- a/internal/vpn/interfaces.go +++ b/internal/vpn/interfaces.go @@ -67,6 +67,7 @@ type NetLinker interface { type Router interface { RouteList(family int) (routes []netlink.Route, err error) RouteAdd(route netlink.Route) error + RouteReplace(route netlink.Route) error } type Ruler interface { diff --git a/internal/vpn/run.go b/internal/vpn/run.go index a0cc0274..8d8f1e4f 100644 --- a/internal/vpn/run.go +++ b/internal/vpn/run.go @@ -38,7 +38,7 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { l.openvpnConf, providerConf, settings, l.ipv6Supported, l.starter, subLogger) } else { // Wireguard vpnInterface = settings.Wireguard.Interface - vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.fw, + vpnRunner, serverName, canPortForward, err = setupWireguard(ctx, l.netLinker, l.routing, l.fw, providerConf, settings, l.ipv6Supported, subLogger) } if err != nil { diff --git a/internal/vpn/wireguard.go b/internal/vpn/wireguard.go index 7f5c4246..cda44901 100644 --- a/internal/vpn/wireguard.go +++ b/internal/vpn/wireguard.go @@ -13,7 +13,7 @@ import ( // setupWireguard sets Wireguard up using the configurators and settings given. // It returns a serverName for port forwarding (PIA) and an error if it fails. -func setupWireguard(ctx context.Context, netlinker NetLinker, +func setupWireguard(ctx context.Context, netlinker NetLinker, routing Routing, fw Firewall, providerConf provider.Provider, settings settings.VPN, ipv6Supported bool, logger wireguard.Logger) ( wireguarder *wireguard.Wireguard, serverName string, canPortForward bool, err error, @@ -29,7 +29,7 @@ func setupWireguard(ctx context.Context, netlinker NetLinker, logger.Debug("Wireguard client private key: " + gosettings.ObfuscateKey(wireguardSettings.PrivateKey)) logger.Debug("Wireguard pre-shared key: " + gosettings.ObfuscateKey(wireguardSettings.PreSharedKey)) - wireguarder, err = wireguard.New(wireguardSettings, netlinker, logger) + wireguarder, err = wireguard.New(wireguardSettings, netlinker, routing, logger) if err != nil { return nil, "", false, fmt.Errorf("creating Wireguard: %w", err) } diff --git a/internal/wireguard/constructor.go b/internal/wireguard/constructor.go index 2b54498b..717841d8 100644 --- a/internal/wireguard/constructor.go +++ b/internal/wireguard/constructor.go @@ -4,10 +4,11 @@ type Wireguard struct { logger Logger settings Settings netlink NetLinker + routing Routing } func New(settings Settings, netlink NetLinker, - logger Logger, + routing Routing, logger Logger, ) (w *Wireguard, err error) { settings.SetDefaults() if err := settings.Check(); err != nil { @@ -18,5 +19,6 @@ func New(settings Settings, netlink NetLinker, logger: logger, settings: settings, netlink: netlink, + routing: routing, }, nil } diff --git a/internal/wireguard/interfaces.go b/internal/wireguard/interfaces.go new file mode 100644 index 00000000..b80ee0e0 --- /dev/null +++ b/internal/wireguard/interfaces.go @@ -0,0 +1,7 @@ +package wireguard + +import "net/netip" + +type Routing interface { + VPNLocalGatewayIP(vpnInterface string) (gateway netip.Addr, err error) +} diff --git a/internal/wireguard/netlinker.go b/internal/wireguard/netlinker.go index 6b077016..7433e9ef 100644 --- a/internal/wireguard/netlinker.go +++ b/internal/wireguard/netlinker.go @@ -1,6 +1,8 @@ package wireguard -import "github.com/qdm12/gluetun/internal/netlink" +import ( + "github.com/qdm12/gluetun/internal/netlink" +) //go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker @@ -15,6 +17,7 @@ type NetLinker interface { type Router interface { RouteList(family int) (routes []netlink.Route, err error) RouteAdd(route netlink.Route) error + RouteReplace(route netlink.Route) error } type Ruler interface { diff --git a/internal/wireguard/route.go b/internal/wireguard/route.go index 84178d03..b45d400c 100644 --- a/internal/wireguard/route.go +++ b/internal/wireguard/route.go @@ -1,6 +1,7 @@ package wireguard import ( + "errors" "fmt" "net/netip" "strings" @@ -29,6 +30,10 @@ func (w *Wireguard) addRoutes(link netlink.Link, destinations []netip.Prefix, return nil } +var ( + ErrDefaultRouteNotFound = errors.New("default route not found") +) + func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix, firewallMark uint32, ) (err error) { @@ -45,5 +50,39 @@ func (w *Wireguard) addRoute(link netlink.Link, dst netip.Prefix, link.Name, dst, firewallMark, err) } + vpnGatewayIP, err := w.routing.VPNLocalGatewayIP(link.Name) + if err != nil { + return fmt.Errorf("getting VPN gateway IP: %w", err) + } + + routes, err := w.netlink.RouteList(netlink.FamilyV4) + if err != nil { + return fmt.Errorf("listing routes: %w", err) + } + + var defaultRoute netlink.Route + var defaultRouteFound bool + for _, route = range routes { + if !route.Dst.IsValid() || route.Dst.Addr().IsUnspecified() { + defaultRoute = route + defaultRouteFound = true + break + } + } + + if !defaultRouteFound { + return fmt.Errorf("%w: in %d routes", ErrDefaultRouteNotFound, len(routes)) + } + + // Equivalent replacement to: + // ip route replace default via dev tun0 + defaultRoute.Gw = vpnGatewayIP + defaultRoute.LinkIndex = link.Index + + err = w.netlink.RouteReplace(defaultRoute) + if err != nil { + return fmt.Errorf("replacing default route: %w", err) + } + return err }