feat(wireguard): WIREGUARD_ALLOWED_IPS variable (#1291)
This commit is contained in:
@@ -26,6 +26,10 @@ type Settings struct {
|
||||
// Addresses assigned to the client.
|
||||
// Note IPv6 addresses are ignored if IPv6 is not supported.
|
||||
Addresses []netip.Prefix
|
||||
// AllowedIPs is the IP networks to be routed through
|
||||
// the Wireguard interface.
|
||||
// Note IPv6 addresses are ignored if IPv6 is not supported.
|
||||
AllowedIPs []netip.Prefix
|
||||
// FirewallMark to be used in routing tables and IP rules.
|
||||
// It defaults to 51820 if left to 0.
|
||||
FirewallMark int
|
||||
@@ -68,6 +72,13 @@ func (s *Settings) SetDefaults() {
|
||||
s.IPv6 = &ipv6
|
||||
}
|
||||
|
||||
if len(s.AllowedIPs) == 0 {
|
||||
s.AllowedIPs = append(s.AllowedIPs, allIPv4())
|
||||
if *s.IPv6 {
|
||||
s.AllowedIPs = append(s.AllowedIPs, allIPv6())
|
||||
}
|
||||
}
|
||||
|
||||
if s.Implementation == "" {
|
||||
const defaultImplementation = "auto"
|
||||
s.Implementation = defaultImplementation
|
||||
@@ -75,19 +86,22 @@ func (s *Settings) SetDefaults() {
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInterfaceNameInvalid = errors.New("invalid interface name")
|
||||
ErrPrivateKeyMissing = errors.New("private key is missing")
|
||||
ErrPrivateKeyInvalid = errors.New("cannot parse private key")
|
||||
ErrPublicKeyMissing = errors.New("public key is missing")
|
||||
ErrPublicKeyInvalid = errors.New("cannot parse public key")
|
||||
ErrPreSharedKeyInvalid = errors.New("cannot parse pre-shared key")
|
||||
ErrEndpointAddrMissing = errors.New("endpoint address is missing")
|
||||
ErrEndpointPortMissing = errors.New("endpoint port is missing")
|
||||
ErrAddressMissing = errors.New("interface address is missing")
|
||||
ErrAddressNotValid = errors.New("interface address is not valid")
|
||||
ErrFirewallMarkMissing = errors.New("firewall mark is missing")
|
||||
ErrMTUMissing = errors.New("MTU is missing")
|
||||
ErrImplementationInvalid = errors.New("invalid implementation")
|
||||
ErrInterfaceNameInvalid = errors.New("invalid interface name")
|
||||
ErrPrivateKeyMissing = errors.New("private key is missing")
|
||||
ErrPrivateKeyInvalid = errors.New("cannot parse private key")
|
||||
ErrPublicKeyMissing = errors.New("public key is missing")
|
||||
ErrPublicKeyInvalid = errors.New("cannot parse public key")
|
||||
ErrPreSharedKeyInvalid = errors.New("cannot parse pre-shared key")
|
||||
ErrEndpointAddrMissing = errors.New("endpoint address is missing")
|
||||
ErrEndpointPortMissing = errors.New("endpoint port is missing")
|
||||
ErrAddressMissing = errors.New("interface address is missing")
|
||||
ErrAddressNotValid = errors.New("interface address is not valid")
|
||||
ErrAllowedIPsMissing = errors.New("allowed IPs are missing")
|
||||
ErrAllowedIPNotValid = errors.New("allowed IP is not valid")
|
||||
ErrAllowedIPv6NotSupported = errors.New("allowed IPv6 address not supported")
|
||||
ErrFirewallMarkMissing = errors.New("firewall mark is missing")
|
||||
ErrMTUMissing = errors.New("MTU is missing")
|
||||
ErrImplementationInvalid = errors.New("invalid implementation")
|
||||
)
|
||||
|
||||
var interfaceNameRegexp = regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
|
||||
@@ -132,6 +146,20 @@ func (s *Settings) Check() (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
if len(s.AllowedIPs) == 0 {
|
||||
return fmt.Errorf("%w", ErrAllowedIPsMissing)
|
||||
}
|
||||
for i, allowedIP := range s.AllowedIPs {
|
||||
switch {
|
||||
case !allowedIP.IsValid():
|
||||
return fmt.Errorf("%w: for allowed IP %d of %d",
|
||||
ErrAllowedIPNotValid, i+1, len(s.AllowedIPs))
|
||||
case allowedIP.Addr().Is6() && !*s.IPv6:
|
||||
return fmt.Errorf("%w: for allowed IP %s",
|
||||
ErrAllowedIPv6NotSupported, allowedIP)
|
||||
}
|
||||
}
|
||||
|
||||
if s.FirewallMark == 0 {
|
||||
return fmt.Errorf("%w", ErrFirewallMarkMissing)
|
||||
}
|
||||
@@ -247,5 +275,16 @@ func (s Settings) ToLines(settings ToLinesSettings) (lines []string) {
|
||||
}
|
||||
}
|
||||
|
||||
if len(s.AllowedIPs) > 0 {
|
||||
lines = append(lines, fieldPrefix+"Allowed IPs:")
|
||||
for i, allowedIP := range s.AllowedIPs {
|
||||
prefix := fieldPrefix
|
||||
if i == len(s.AllowedIPs)-1 {
|
||||
prefix = lastFieldPrefix
|
||||
}
|
||||
lines = append(lines, indent+prefix+allowedIP.String())
|
||||
}
|
||||
}
|
||||
|
||||
return lines
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user