Code maintenance: Using qdm/dns and qdm12/updated
This commit is contained in:
@@ -1,43 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
)
|
||||
|
||||
func (c *configurator) Start(ctx context.Context, verbosityDetailsLevel uint8) (
|
||||
stdout io.ReadCloser, waitFn func() error, err error) {
|
||||
c.logger.Info("starting unbound")
|
||||
args := []string{"-d", "-c", string(constants.UnboundConf)}
|
||||
if verbosityDetailsLevel > 0 {
|
||||
args = append(args, "-"+strings.Repeat("v", int(verbosityDetailsLevel)))
|
||||
}
|
||||
// Only logs to stderr
|
||||
_, stdout, waitFn, err = c.commander.Start(ctx, "unbound", args...)
|
||||
return stdout, waitFn, err
|
||||
}
|
||||
|
||||
func (c *configurator) Version(ctx context.Context) (version string, err error) {
|
||||
output, err := c.commander.Run(ctx, "unbound", "-V")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unbound version: %w", err)
|
||||
}
|
||||
for _, line := range strings.Split(output, "\n") {
|
||||
if strings.Contains(line, "Version ") {
|
||||
words := strings.Fields(line)
|
||||
const minWords = 2
|
||||
if len(words) < minWords {
|
||||
continue
|
||||
}
|
||||
version = words[1]
|
||||
}
|
||||
}
|
||||
if version == "" {
|
||||
return "", fmt.Errorf("unbound version was not found in %q", output)
|
||||
}
|
||||
return version, nil
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/golibs/command/mock_command"
|
||||
"github.com/qdm12/golibs/logging/mock_logging"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_Start(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("starting unbound")
|
||||
commander := mock_command.NewMockCommander(mockCtrl)
|
||||
commander.EXPECT().Start(context.Background(), "unbound", "-d", "-c", string(constants.UnboundConf), "-vv").
|
||||
Return(nil, nil, nil, nil)
|
||||
c := &configurator{commander: commander, logger: logger}
|
||||
stdout, waitFn, err := c.Start(context.Background(), 2)
|
||||
assert.Nil(t, stdout)
|
||||
assert.Nil(t, waitFn)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Version(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
runOutput string
|
||||
runErr error
|
||||
version string
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
err: fmt.Errorf(`unbound version was not found in ""`),
|
||||
},
|
||||
"2 lines with version": {
|
||||
runOutput: "Version \nVersion 1.0-a hello\n",
|
||||
version: "1.0-a",
|
||||
},
|
||||
"run error": {
|
||||
runErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("unbound version: error"),
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
commander := mock_command.NewMockCommander(mockCtrl)
|
||||
commander.EXPECT().Run(context.Background(), "unbound", "-V").
|
||||
Return(tc.runOutput, tc.runErr)
|
||||
c := &configurator{commander: commander}
|
||||
version, err := c.Version(context.Background())
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.version, version)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,325 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DNS,
|
||||
username string, puid, pgid int) (err error) {
|
||||
c.logger.Info("generating Unbound configuration")
|
||||
lines, warnings := generateUnboundConf(ctx, settings, username, c.client, c.logger)
|
||||
for _, warning := range warnings {
|
||||
c.logger.Warn(warning)
|
||||
}
|
||||
|
||||
const filepath = string(constants.UnboundConf)
|
||||
file, err := c.openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0400)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = file.WriteString(strings.Join(lines, "\n"))
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := file.Chown(puid, pgid); err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MakeUnboundConf generates an Unbound configuration from the user provided settings.
|
||||
func generateUnboundConf(ctx context.Context, settings settings.DNS, username string,
|
||||
client *http.Client, logger logging.Logger) (
|
||||
lines []string, warnings []error) {
|
||||
doIPv6 := "no"
|
||||
if settings.IPv6 {
|
||||
doIPv6 = "yes"
|
||||
}
|
||||
serverSection := map[string]string{
|
||||
// Logging
|
||||
"verbosity": fmt.Sprintf("%d", settings.VerbosityLevel),
|
||||
"val-log-level": fmt.Sprintf("%d", settings.ValidationLogLevel),
|
||||
"use-syslog": "no",
|
||||
// Performance
|
||||
"num-threads": "1",
|
||||
"prefetch": "yes",
|
||||
"prefetch-key": "yes",
|
||||
"key-cache-size": "16m",
|
||||
"key-cache-slabs": "4",
|
||||
"msg-cache-size": "4m",
|
||||
"msg-cache-slabs": "4",
|
||||
"rrset-cache-size": "4m",
|
||||
"rrset-cache-slabs": "4",
|
||||
"cache-min-ttl": "3600",
|
||||
"cache-max-ttl": "9000",
|
||||
// Privacy
|
||||
"rrset-roundrobin": "yes",
|
||||
"hide-identity": "yes",
|
||||
"hide-version": "yes",
|
||||
// Security
|
||||
"tls-cert-bundle": fmt.Sprintf("%q", constants.CACertificates),
|
||||
"root-hints": fmt.Sprintf("%q", constants.RootHints),
|
||||
"trust-anchor-file": fmt.Sprintf("%q", constants.RootKey),
|
||||
"harden-below-nxdomain": "yes",
|
||||
"harden-referral-path": "yes",
|
||||
"harden-algo-downgrade": "yes",
|
||||
// Network
|
||||
"do-ip4": "yes",
|
||||
"do-ip6": doIPv6,
|
||||
"interface": "0.0.0.0",
|
||||
"port": "53",
|
||||
// Other
|
||||
"username": fmt.Sprintf("%q", username),
|
||||
}
|
||||
|
||||
// Block lists
|
||||
hostnamesLines, ipsLines, warnings := buildBlocked(ctx, client,
|
||||
settings.BlockMalicious, settings.BlockAds, settings.BlockSurveillance,
|
||||
settings.AllowedHostnames, settings.PrivateAddresses,
|
||||
)
|
||||
logger.Info("%d hostnames blocked overall", len(hostnamesLines))
|
||||
logger.Info("%d IP addresses blocked overall", len(ipsLines))
|
||||
sort.Slice(hostnamesLines, func(i, j int) bool { // for unit tests really
|
||||
return hostnamesLines[i] < hostnamesLines[j]
|
||||
})
|
||||
sort.Slice(ipsLines, func(i, j int) bool { // for unit tests really
|
||||
return ipsLines[i] < ipsLines[j]
|
||||
})
|
||||
|
||||
// Server
|
||||
lines = append(lines, "server:")
|
||||
serverLines := make([]string, len(serverSection))
|
||||
i := 0
|
||||
for k, v := range serverSection {
|
||||
serverLines[i] = " " + k + ": " + v
|
||||
i++
|
||||
}
|
||||
sort.Slice(serverLines, func(i, j int) bool {
|
||||
return serverLines[i] < serverLines[j]
|
||||
})
|
||||
lines = append(lines, serverLines...)
|
||||
lines = append(lines, hostnamesLines...)
|
||||
lines = append(lines, ipsLines...)
|
||||
|
||||
// Forward zone
|
||||
lines = append(lines, "forward-zone:")
|
||||
forwardZoneSection := map[string]string{
|
||||
"name": "\".\"",
|
||||
"forward-tls-upstream": "yes",
|
||||
}
|
||||
if settings.Caching {
|
||||
forwardZoneSection["forward-no-cache"] = "no"
|
||||
} else {
|
||||
forwardZoneSection["forward-no-cache"] = "yes"
|
||||
}
|
||||
forwardZoneLines := make([]string, len(forwardZoneSection))
|
||||
i = 0
|
||||
for k, v := range forwardZoneSection {
|
||||
forwardZoneLines[i] = " " + k + ": " + v
|
||||
i++
|
||||
}
|
||||
sort.Slice(forwardZoneLines, func(i, j int) bool {
|
||||
return forwardZoneLines[i] < forwardZoneLines[j]
|
||||
})
|
||||
for _, provider := range settings.Providers {
|
||||
providerData := constants.DNSProviderMapping()[provider]
|
||||
for _, IP := range providerData.IPs {
|
||||
forwardZoneLines = append(forwardZoneLines,
|
||||
fmt.Sprintf(" forward-addr: %s@853#%s", IP, providerData.Host))
|
||||
}
|
||||
}
|
||||
lines = append(lines, forwardZoneLines...)
|
||||
return lines, warnings
|
||||
}
|
||||
|
||||
func buildBlocked(ctx context.Context, client *http.Client, blockMalicious, blockAds, blockSurveillance bool,
|
||||
allowedHostnames, privateAddresses []string) (hostnamesLines, ipsLines []string, errs []error) {
|
||||
chHostnames := make(chan []string)
|
||||
chIPs := make(chan []string)
|
||||
chErrors := make(chan []error)
|
||||
go func() {
|
||||
lines, errs := buildBlockedHostnames(ctx, client, blockMalicious, blockAds, blockSurveillance, allowedHostnames)
|
||||
chHostnames <- lines
|
||||
chErrors <- errs
|
||||
}()
|
||||
go func() {
|
||||
lines, errs := buildBlockedIPs(ctx, client, blockMalicious, blockAds, blockSurveillance, privateAddresses)
|
||||
chIPs <- lines
|
||||
chErrors <- errs
|
||||
}()
|
||||
n := 2
|
||||
for n > 0 {
|
||||
select {
|
||||
case lines := <-chHostnames:
|
||||
hostnamesLines = append(hostnamesLines, lines...)
|
||||
case lines := <-chIPs:
|
||||
ipsLines = append(ipsLines, lines...)
|
||||
case routineErrs := <-chErrors:
|
||||
errs = append(errs, routineErrs...)
|
||||
n--
|
||||
}
|
||||
}
|
||||
return hostnamesLines, ipsLines, errs
|
||||
}
|
||||
|
||||
func getList(ctx context.Context, client *http.Client, url string) (results []string, err error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("%w from %s: %s", ErrBadStatusCode, url, response.Status)
|
||||
}
|
||||
|
||||
content, err := ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrCannotReadBody, err)
|
||||
}
|
||||
|
||||
results = strings.Split(string(content), "\n")
|
||||
|
||||
// remove empty lines
|
||||
last := len(results) - 1
|
||||
for i := range results {
|
||||
if len(results[i]) == 0 {
|
||||
results[i] = results[last]
|
||||
last--
|
||||
}
|
||||
}
|
||||
results = results[:last+1]
|
||||
|
||||
if len(results) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func buildBlockedHostnames(ctx context.Context, client *http.Client, blockMalicious, blockAds, blockSurveillance bool,
|
||||
allowedHostnames []string) (lines []string, errs []error) {
|
||||
chResults := make(chan []string)
|
||||
chError := make(chan error)
|
||||
listsLeftToFetch := 0
|
||||
if blockMalicious {
|
||||
listsLeftToFetch++
|
||||
go func() {
|
||||
results, err := getList(ctx, client, string(constants.MaliciousBlockListHostnamesURL))
|
||||
chResults <- results
|
||||
chError <- err
|
||||
}()
|
||||
}
|
||||
if blockAds {
|
||||
listsLeftToFetch++
|
||||
go func() {
|
||||
results, err := getList(ctx, client, string(constants.AdsBlockListHostnamesURL))
|
||||
chResults <- results
|
||||
chError <- err
|
||||
}()
|
||||
}
|
||||
if blockSurveillance {
|
||||
listsLeftToFetch++
|
||||
go func() {
|
||||
results, err := getList(ctx, client, string(constants.SurveillanceBlockListHostnamesURL))
|
||||
chResults <- results
|
||||
chError <- err
|
||||
}()
|
||||
}
|
||||
uniqueResults := make(map[string]struct{})
|
||||
for listsLeftToFetch > 0 {
|
||||
select {
|
||||
case results := <-chResults:
|
||||
for _, result := range results {
|
||||
uniqueResults[result] = struct{}{}
|
||||
}
|
||||
case err := <-chError:
|
||||
listsLeftToFetch--
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, allowedHostname := range allowedHostnames {
|
||||
delete(uniqueResults, allowedHostname)
|
||||
}
|
||||
for result := range uniqueResults {
|
||||
lines = append(lines, " local-zone: \""+result+"\" static")
|
||||
}
|
||||
return lines, errs
|
||||
}
|
||||
|
||||
func buildBlockedIPs(ctx context.Context, client *http.Client, blockMalicious, blockAds, blockSurveillance bool,
|
||||
privateAddresses []string) (lines []string, errs []error) {
|
||||
chResults := make(chan []string)
|
||||
chError := make(chan error)
|
||||
listsLeftToFetch := 0
|
||||
if blockMalicious {
|
||||
listsLeftToFetch++
|
||||
go func() {
|
||||
results, err := getList(ctx, client, string(constants.MaliciousBlockListIPsURL))
|
||||
chResults <- results
|
||||
chError <- err
|
||||
}()
|
||||
}
|
||||
if blockAds {
|
||||
listsLeftToFetch++
|
||||
go func() {
|
||||
results, err := getList(ctx, client, string(constants.AdsBlockListIPsURL))
|
||||
chResults <- results
|
||||
chError <- err
|
||||
}()
|
||||
}
|
||||
if blockSurveillance {
|
||||
listsLeftToFetch++
|
||||
go func() {
|
||||
results, err := getList(ctx, client, string(constants.SurveillanceBlockListIPsURL))
|
||||
chResults <- results
|
||||
chError <- err
|
||||
}()
|
||||
}
|
||||
uniqueResults := make(map[string]struct{})
|
||||
for listsLeftToFetch > 0 {
|
||||
select {
|
||||
case results := <-chResults:
|
||||
for _, result := range results {
|
||||
uniqueResults[result] = struct{}{}
|
||||
}
|
||||
case err := <-chError:
|
||||
listsLeftToFetch--
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, privateAddress := range privateAddresses {
|
||||
uniqueResults[privateAddress] = struct{}{}
|
||||
}
|
||||
for result := range uniqueResults {
|
||||
lines = append(lines, " private-address: "+result)
|
||||
}
|
||||
return lines, errs
|
||||
}
|
||||
@@ -1,702 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
"github.com/qdm12/golibs/logging/mock_logging"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_generateUnboundConf(t *testing.T) {
|
||||
t.Parallel()
|
||||
settings := settings.DNS{
|
||||
Providers: []models.DNSProvider{constants.Cloudflare, constants.Quad9},
|
||||
AllowedHostnames: []string{"a"},
|
||||
PrivateAddresses: []string{"9.9.9.9"},
|
||||
BlockMalicious: true,
|
||||
BlockSurveillance: false,
|
||||
BlockAds: false,
|
||||
VerbosityLevel: 2,
|
||||
ValidationLogLevel: 3,
|
||||
Caching: true,
|
||||
IPv6: true,
|
||||
}
|
||||
mockCtrl := gomock.NewController(t)
|
||||
ctx := context.Background()
|
||||
|
||||
clientCalls := map[models.URL]int{
|
||||
constants.MaliciousBlockListIPsURL: 0,
|
||||
constants.MaliciousBlockListHostnamesURL: 0,
|
||||
}
|
||||
client := &http.Client{
|
||||
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
url := models.URL(r.URL.String())
|
||||
if _, ok := clientCalls[url]; !ok {
|
||||
t.Errorf("unknown URL %q", url)
|
||||
return nil, nil
|
||||
}
|
||||
clientCalls[url]++
|
||||
var body string
|
||||
switch url {
|
||||
case constants.MaliciousBlockListIPsURL:
|
||||
body = "c\nd"
|
||||
case constants.MaliciousBlockListHostnamesURL:
|
||||
body = "b\na\nc"
|
||||
default:
|
||||
t.Errorf("unknown URL %q", url)
|
||||
return nil, nil
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: ioutil.NopCloser(strings.NewReader(body)),
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("%d hostnames blocked overall", 2)
|
||||
logger.EXPECT().Info("%d IP addresses blocked overall", 3)
|
||||
lines, warnings := generateUnboundConf(ctx, settings, "nonrootuser", client, logger)
|
||||
require.Len(t, warnings, 0)
|
||||
for url, count := range clientCalls {
|
||||
assert.Equalf(t, 1, count, "for url %q", url)
|
||||
}
|
||||
const expected = `
|
||||
server:
|
||||
cache-max-ttl: 9000
|
||||
cache-min-ttl: 3600
|
||||
do-ip4: yes
|
||||
do-ip6: yes
|
||||
harden-algo-downgrade: yes
|
||||
harden-below-nxdomain: yes
|
||||
harden-referral-path: yes
|
||||
hide-identity: yes
|
||||
hide-version: yes
|
||||
interface: 0.0.0.0
|
||||
key-cache-size: 16m
|
||||
key-cache-slabs: 4
|
||||
msg-cache-size: 4m
|
||||
msg-cache-slabs: 4
|
||||
num-threads: 1
|
||||
port: 53
|
||||
prefetch-key: yes
|
||||
prefetch: yes
|
||||
root-hints: "/etc/unbound/root.hints"
|
||||
rrset-cache-size: 4m
|
||||
rrset-cache-slabs: 4
|
||||
rrset-roundrobin: yes
|
||||
tls-cert-bundle: "/etc/ssl/certs/ca-certificates.crt"
|
||||
trust-anchor-file: "/etc/unbound/root.key"
|
||||
use-syslog: no
|
||||
username: "nonrootuser"
|
||||
val-log-level: 3
|
||||
verbosity: 2
|
||||
local-zone: "b" static
|
||||
local-zone: "c" static
|
||||
private-address: 9.9.9.9
|
||||
private-address: c
|
||||
private-address: d
|
||||
forward-zone:
|
||||
forward-no-cache: no
|
||||
forward-tls-upstream: yes
|
||||
name: "."
|
||||
forward-addr: 1.1.1.1@853#cloudflare-dns.com
|
||||
forward-addr: 1.0.0.1@853#cloudflare-dns.com
|
||||
forward-addr: 2606:4700:4700::1111@853#cloudflare-dns.com
|
||||
forward-addr: 2606:4700:4700::1001@853#cloudflare-dns.com
|
||||
forward-addr: 9.9.9.9@853#dns.quad9.net
|
||||
forward-addr: 149.112.112.112@853#dns.quad9.net
|
||||
forward-addr: 2620:fe::fe@853#dns.quad9.net
|
||||
forward-addr: 2620:fe::9@853#dns.quad9.net`
|
||||
assert.Equal(t, expected, "\n"+strings.Join(lines, "\n"))
|
||||
}
|
||||
|
||||
func Test_buildBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
type blockParams struct {
|
||||
blocked bool
|
||||
content []byte
|
||||
clientErr error
|
||||
}
|
||||
tests := map[string]struct {
|
||||
malicious blockParams
|
||||
ads blockParams
|
||||
surveillance blockParams
|
||||
allowedHostnames []string
|
||||
privateAddresses []string
|
||||
hostnamesLines []string
|
||||
ipsLines []string
|
||||
errsString []string
|
||||
}{
|
||||
"none blocked": {},
|
||||
"all blocked without lists": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
},
|
||||
},
|
||||
"all blocked with lists": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("malicious"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("ads"),
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("surveillance"),
|
||||
},
|
||||
hostnamesLines: []string{
|
||||
" local-zone: \"ads\" static",
|
||||
" local-zone: \"malicious\" static",
|
||||
" local-zone: \"surveillance\" static"},
|
||||
ipsLines: []string{
|
||||
" private-address: ads",
|
||||
" private-address: malicious",
|
||||
" private-address: surveillance"},
|
||||
},
|
||||
"all blocked with allowed hostnames": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("malicious"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("ads"),
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("surveillance"),
|
||||
},
|
||||
allowedHostnames: []string{"ads"},
|
||||
hostnamesLines: []string{
|
||||
" local-zone: \"malicious\" static",
|
||||
" local-zone: \"surveillance\" static"},
|
||||
ipsLines: []string{
|
||||
" private-address: ads",
|
||||
" private-address: malicious",
|
||||
" private-address: surveillance"},
|
||||
},
|
||||
"all blocked with private addresses": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("malicious"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("ads"),
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("surveillance"),
|
||||
},
|
||||
privateAddresses: []string{"ads", "192.100.1.5"},
|
||||
hostnamesLines: []string{
|
||||
" local-zone: \"ads\" static",
|
||||
" local-zone: \"malicious\" static",
|
||||
" local-zone: \"surveillance\" static"},
|
||||
ipsLines: []string{
|
||||
" private-address: 192.100.1.5",
|
||||
" private-address: ads",
|
||||
" private-address: malicious",
|
||||
" private-address: surveillance"},
|
||||
},
|
||||
"all blocked with lists and one error": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("malicious"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("ads"),
|
||||
clientErr: fmt.Errorf("ads error"),
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("surveillance"),
|
||||
},
|
||||
hostnamesLines: []string{
|
||||
" local-zone: \"malicious\" static",
|
||||
" local-zone: \"surveillance\" static"},
|
||||
ipsLines: []string{
|
||||
" private-address: malicious",
|
||||
" private-address: surveillance"},
|
||||
errsString: []string{
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-ips.updated": ads error`,
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-hostnames.updated": ads error`,
|
||||
},
|
||||
},
|
||||
"all blocked with errors": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
clientErr: fmt.Errorf("malicious"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
clientErr: fmt.Errorf("ads"),
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
clientErr: fmt.Errorf("surveillance"),
|
||||
},
|
||||
errsString: []string{
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/malicious-ips.updated": malicious`,
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/malicious-hostnames.updated": malicious`,
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-ips.updated": ads`,
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/ads-hostnames.updated": ads`,
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-ips.updated": surveillance`,
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-hostnames.updated": surveillance`,
|
||||
},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
clientCalls := map[models.URL]int{}
|
||||
if tc.malicious.blocked {
|
||||
clientCalls[constants.MaliciousBlockListIPsURL] = 0
|
||||
clientCalls[constants.MaliciousBlockListHostnamesURL] = 0
|
||||
}
|
||||
if tc.ads.blocked {
|
||||
clientCalls[constants.AdsBlockListIPsURL] = 0
|
||||
clientCalls[constants.AdsBlockListHostnamesURL] = 0
|
||||
}
|
||||
if tc.surveillance.blocked {
|
||||
clientCalls[constants.SurveillanceBlockListIPsURL] = 0
|
||||
clientCalls[constants.SurveillanceBlockListHostnamesURL] = 0
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
url := models.URL(r.URL.String())
|
||||
if _, ok := clientCalls[url]; !ok {
|
||||
t.Errorf("unknown URL %q", url)
|
||||
return nil, nil
|
||||
}
|
||||
clientCalls[url]++
|
||||
var body []byte
|
||||
var err error
|
||||
switch url {
|
||||
case constants.MaliciousBlockListIPsURL, constants.MaliciousBlockListHostnamesURL:
|
||||
body = tc.malicious.content
|
||||
err = tc.malicious.clientErr
|
||||
case constants.AdsBlockListIPsURL, constants.AdsBlockListHostnamesURL:
|
||||
body = tc.ads.content
|
||||
err = tc.ads.clientErr
|
||||
case constants.SurveillanceBlockListIPsURL, constants.SurveillanceBlockListHostnamesURL:
|
||||
body = tc.surveillance.content
|
||||
err = tc.surveillance.clientErr
|
||||
default: // just in case if the test is badly written
|
||||
t.Errorf("unknown URL %q", url)
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: ioutil.NopCloser(bytes.NewReader(body)),
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
|
||||
hostnamesLines, ipsLines, errs := buildBlocked(ctx, client,
|
||||
tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked,
|
||||
tc.allowedHostnames, tc.privateAddresses)
|
||||
|
||||
var errsString []string
|
||||
for _, err := range errs {
|
||||
errsString = append(errsString, err.Error())
|
||||
}
|
||||
assert.ElementsMatch(t, tc.errsString, errsString)
|
||||
assert.ElementsMatch(t, tc.hostnamesLines, hostnamesLines)
|
||||
assert.ElementsMatch(t, tc.ipsLines, ipsLines)
|
||||
|
||||
for url, count := range clientCalls {
|
||||
assert.Equalf(t, 1, count, "for url %q", url)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getList(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
content []byte
|
||||
status int
|
||||
clientErr error
|
||||
results []string
|
||||
err error
|
||||
}{
|
||||
"no result": {
|
||||
status: http.StatusOK,
|
||||
},
|
||||
"bad status": {
|
||||
status: http.StatusInternalServerError,
|
||||
err: fmt.Errorf("bad HTTP status from irrelevant_url: Internal Server Error"),
|
||||
},
|
||||
"network error": {
|
||||
status: http.StatusOK,
|
||||
clientErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf(`Get "irrelevant_url": error`),
|
||||
},
|
||||
"results": {
|
||||
content: []byte("a\nb\nc\n"),
|
||||
status: http.StatusOK,
|
||||
results: []string{"a", "b", "c"},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
client := &http.Client{
|
||||
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
assert.Equal(t, "irrelevant_url", r.URL.String())
|
||||
if tc.clientErr != nil {
|
||||
return nil, tc.clientErr
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: tc.status,
|
||||
Status: http.StatusText(tc.status),
|
||||
Body: ioutil.NopCloser(bytes.NewReader(tc.content)),
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
|
||||
results, err := getList(ctx, client, "irrelevant_url")
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.results, results)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_buildBlockedHostnames(t *testing.T) {
|
||||
t.Parallel()
|
||||
type blockParams struct {
|
||||
blocked bool
|
||||
content []byte
|
||||
clientErr error
|
||||
}
|
||||
tests := map[string]struct {
|
||||
malicious blockParams
|
||||
ads blockParams
|
||||
surveillance blockParams
|
||||
allowedHostnames []string
|
||||
lines []string
|
||||
errsString []string
|
||||
}{
|
||||
"nothing blocked": {},
|
||||
"only malicious blocked": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_b"),
|
||||
clientErr: nil,
|
||||
},
|
||||
lines: []string{
|
||||
" local-zone: \"site_a\" static",
|
||||
" local-zone: \"site_b\" static"},
|
||||
},
|
||||
"all blocked with some duplicates": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_b"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_c"),
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_c\nsite_a"),
|
||||
},
|
||||
lines: []string{
|
||||
" local-zone: \"site_a\" static",
|
||||
" local-zone: \"site_b\" static",
|
||||
" local-zone: \"site_c\" static"},
|
||||
},
|
||||
"all blocked with one errored": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_b"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_c"),
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
clientErr: fmt.Errorf("surveillance error"),
|
||||
},
|
||||
lines: []string{
|
||||
" local-zone: \"site_a\" static",
|
||||
" local-zone: \"site_b\" static",
|
||||
" local-zone: \"site_c\" static"},
|
||||
errsString: []string{
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-hostnames.updated": surveillance error`,
|
||||
},
|
||||
},
|
||||
"blocked with allowed hostnames": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_b"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_c\nsite_d"),
|
||||
},
|
||||
allowedHostnames: []string{"site_b", "site_c"},
|
||||
lines: []string{
|
||||
" local-zone: \"site_a\" static",
|
||||
" local-zone: \"site_d\" static"},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
clientCalls := map[models.URL]int{}
|
||||
if tc.malicious.blocked {
|
||||
clientCalls[constants.MaliciousBlockListHostnamesURL] = 0
|
||||
}
|
||||
if tc.ads.blocked {
|
||||
clientCalls[constants.AdsBlockListHostnamesURL] = 0
|
||||
}
|
||||
if tc.surveillance.blocked {
|
||||
clientCalls[constants.SurveillanceBlockListHostnamesURL] = 0
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
url := models.URL(r.URL.String())
|
||||
if _, ok := clientCalls[url]; !ok {
|
||||
t.Errorf("unknown URL %q", url)
|
||||
return nil, nil
|
||||
}
|
||||
clientCalls[url]++
|
||||
var body []byte
|
||||
var err error
|
||||
switch url {
|
||||
case constants.MaliciousBlockListHostnamesURL:
|
||||
body = tc.malicious.content
|
||||
err = tc.malicious.clientErr
|
||||
case constants.AdsBlockListHostnamesURL:
|
||||
body = tc.ads.content
|
||||
err = tc.ads.clientErr
|
||||
case constants.SurveillanceBlockListHostnamesURL:
|
||||
body = tc.surveillance.content
|
||||
err = tc.surveillance.clientErr
|
||||
default: // just in case if the test is badly written
|
||||
t.Errorf("unknown URL %q", url)
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: ioutil.NopCloser(bytes.NewReader(body)),
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
|
||||
lines, errs := buildBlockedHostnames(ctx, client,
|
||||
tc.malicious.blocked, tc.ads.blocked,
|
||||
tc.surveillance.blocked, tc.allowedHostnames)
|
||||
|
||||
var errsString []string
|
||||
for _, err := range errs {
|
||||
errsString = append(errsString, err.Error())
|
||||
}
|
||||
assert.ElementsMatch(t, tc.errsString, errsString)
|
||||
assert.ElementsMatch(t, tc.lines, lines)
|
||||
|
||||
for url, count := range clientCalls {
|
||||
assert.Equalf(t, 1, count, "for url %q", url)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_buildBlockedIPs(t *testing.T) {
|
||||
t.Parallel()
|
||||
type blockParams struct {
|
||||
blocked bool
|
||||
content []byte
|
||||
clientErr error
|
||||
}
|
||||
tests := map[string]struct {
|
||||
malicious blockParams
|
||||
ads blockParams
|
||||
surveillance blockParams
|
||||
privateAddresses []string
|
||||
lines []string
|
||||
errsString []string
|
||||
}{
|
||||
"nothing blocked": {},
|
||||
"only malicious blocked": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_b"),
|
||||
clientErr: nil,
|
||||
},
|
||||
lines: []string{
|
||||
" private-address: site_a",
|
||||
" private-address: site_b"},
|
||||
},
|
||||
"all blocked with some duplicates": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_b"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_c"),
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_c\nsite_a"),
|
||||
},
|
||||
lines: []string{
|
||||
" private-address: site_a",
|
||||
" private-address: site_b",
|
||||
" private-address: site_c"},
|
||||
},
|
||||
"all blocked with one errored": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_b"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_c"),
|
||||
},
|
||||
surveillance: blockParams{
|
||||
blocked: true,
|
||||
clientErr: fmt.Errorf("surveillance error"),
|
||||
},
|
||||
lines: []string{
|
||||
" private-address: site_a",
|
||||
" private-address: site_b",
|
||||
" private-address: site_c"},
|
||||
errsString: []string{
|
||||
`Get "https://raw.githubusercontent.com/qdm12/files/master/surveillance-ips.updated": surveillance error`,
|
||||
},
|
||||
},
|
||||
"blocked with private addresses": {
|
||||
malicious: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_a\nsite_b"),
|
||||
},
|
||||
ads: blockParams{
|
||||
blocked: true,
|
||||
content: []byte("site_c"),
|
||||
},
|
||||
privateAddresses: []string{"site_c", "site_d"},
|
||||
lines: []string{
|
||||
" private-address: site_a",
|
||||
" private-address: site_b",
|
||||
" private-address: site_c",
|
||||
" private-address: site_d"},
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
clientCalls := map[models.URL]int{}
|
||||
if tc.malicious.blocked {
|
||||
clientCalls[constants.MaliciousBlockListIPsURL] = 0
|
||||
}
|
||||
if tc.ads.blocked {
|
||||
clientCalls[constants.AdsBlockListIPsURL] = 0
|
||||
}
|
||||
if tc.surveillance.blocked {
|
||||
clientCalls[constants.SurveillanceBlockListIPsURL] = 0
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
url := models.URL(r.URL.String())
|
||||
if _, ok := clientCalls[url]; !ok {
|
||||
t.Errorf("unknown URL %q", url)
|
||||
return nil, nil
|
||||
}
|
||||
clientCalls[url]++
|
||||
var body []byte
|
||||
var err error
|
||||
switch url {
|
||||
case constants.MaliciousBlockListIPsURL:
|
||||
body = tc.malicious.content
|
||||
err = tc.malicious.clientErr
|
||||
case constants.AdsBlockListIPsURL:
|
||||
body = tc.ads.content
|
||||
err = tc.ads.clientErr
|
||||
case constants.SurveillanceBlockListIPsURL:
|
||||
body = tc.surveillance.content
|
||||
err = tc.surveillance.clientErr
|
||||
default: // just in case if the test is badly written
|
||||
t.Errorf("unknown URL %q", url)
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: ioutil.NopCloser(bytes.NewReader(body)),
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
|
||||
lines, errs := buildBlockedIPs(ctx, client,
|
||||
tc.malicious.blocked, tc.ads.blocked,
|
||||
tc.surveillance.blocked, tc.privateAddresses)
|
||||
|
||||
var errsString []string
|
||||
for _, err := range errs {
|
||||
errsString = append(errsString, err.Error())
|
||||
}
|
||||
assert.ElementsMatch(t, tc.errsString, errsString)
|
||||
assert.ElementsMatch(t, tc.lines, lines)
|
||||
|
||||
for url, count := range clientCalls {
|
||||
assert.Equalf(t, 1, count, "for url %q", url)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
type Configurator interface {
|
||||
DownloadRootHints(ctx context.Context, puid, pgid int) error
|
||||
DownloadRootKey(ctx context.Context, puid, pgid int) error
|
||||
MakeUnboundConf(ctx context.Context, settings settings.DNS, username string, puid, pgid int) (err error)
|
||||
UseDNSInternally(IP net.IP)
|
||||
UseDNSSystemWide(ip net.IP, keepNameserver bool) error
|
||||
Start(ctx context.Context, logLevel uint8) (stdout io.ReadCloser, waitFn func() error, err error)
|
||||
WaitForUnbound() (err error)
|
||||
Version(ctx context.Context) (version string, err error)
|
||||
}
|
||||
|
||||
type configurator struct {
|
||||
logger logging.Logger
|
||||
client *http.Client
|
||||
openFile os.OpenFileFunc
|
||||
commander command.Commander
|
||||
lookupIP func(host string) ([]net.IP, error)
|
||||
}
|
||||
|
||||
func NewConfigurator(logger logging.Logger, httpClient *http.Client,
|
||||
openFile os.OpenFileFunc) Configurator {
|
||||
return &configurator{
|
||||
logger: logger.WithPrefix("dns configurator: "),
|
||||
client: httpClient,
|
||||
openFile: openFile,
|
||||
commander: command.NewCommander(),
|
||||
lookupIP: net.LookupIP,
|
||||
}
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package dns
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrBadStatusCode = errors.New("bad HTTP status")
|
||||
ErrCannotReadBody = errors.New("cannot read response body")
|
||||
)
|
||||
@@ -4,9 +4,11 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/dns/pkg/unbound"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
@@ -25,7 +27,8 @@ type Looper interface {
|
||||
|
||||
type looper struct {
|
||||
state state
|
||||
conf Configurator
|
||||
conf unbound.Configurator
|
||||
client *http.Client
|
||||
logger logging.Logger
|
||||
streamMerger command.StreamMerger
|
||||
username string
|
||||
@@ -44,14 +47,16 @@ type looper struct {
|
||||
|
||||
const defaultBackoffTime = 10 * time.Second
|
||||
|
||||
func NewLooper(conf Configurator, settings settings.DNS, logger logging.Logger,
|
||||
streamMerger command.StreamMerger, username string, puid, pgid int) Looper {
|
||||
func NewLooper(conf unbound.Configurator, settings settings.DNS, client *http.Client,
|
||||
logger logging.Logger, streamMerger command.StreamMerger,
|
||||
username string, puid, pgid int) Looper {
|
||||
return &looper{
|
||||
state: state{
|
||||
status: constants.Stopped,
|
||||
settings: settings,
|
||||
},
|
||||
conf: conf,
|
||||
client: client,
|
||||
logger: logger.WithPrefix("dns over tls: "),
|
||||
username: username,
|
||||
puid: puid,
|
||||
@@ -170,7 +175,7 @@ func (l *looper) setupUnbound(ctx context.Context,
|
||||
settings := l.GetSettings()
|
||||
|
||||
unboundCtx, cancel := context.WithCancel(context.Background())
|
||||
stream, waitFn, err := l.conf.Start(unboundCtx, settings.VerbosityDetailsLevel)
|
||||
stream, waitFn, err := l.conf.Start(unboundCtx, settings.Unbound.VerbosityDetailsLevel)
|
||||
if err != nil {
|
||||
cancel()
|
||||
if !previousCrashed {
|
||||
@@ -187,7 +192,7 @@ func (l *looper) setupUnbound(ctx context.Context,
|
||||
l.logger.Error(err)
|
||||
}
|
||||
|
||||
if err := l.conf.WaitForUnbound(); err != nil {
|
||||
if err := l.conf.WaitForUnbound(ctx); err != nil {
|
||||
if !previousCrashed {
|
||||
l.running <- constants.Crashed
|
||||
}
|
||||
@@ -229,8 +234,8 @@ func (l *looper) useUnencryptedDNS(fallback bool) {
|
||||
}
|
||||
|
||||
// Try with any IPv4 address from the providers chosen
|
||||
for _, provider := range settings.Providers {
|
||||
data := constants.DNSProviderMapping()[provider]
|
||||
for _, provider := range settings.Unbound.Providers {
|
||||
data, _ := unbound.GetProviderData(provider)
|
||||
for _, targetIP = range data.IPs {
|
||||
if targetIP.To4() != nil {
|
||||
if fallback {
|
||||
@@ -248,7 +253,7 @@ func (l *looper) useUnencryptedDNS(fallback bool) {
|
||||
}
|
||||
|
||||
// No IPv4 address found
|
||||
l.logger.Error("no ipv4 DNS address found for providers %s", settings.Providers)
|
||||
l.logger.Error("no ipv4 DNS address found for providers %s", settings.Unbound.Providers)
|
||||
}
|
||||
|
||||
func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
|
||||
@@ -310,14 +315,22 @@ func (l *looper) RunRestartTicker(ctx context.Context, wg *sync.WaitGroup) {
|
||||
}
|
||||
|
||||
func (l *looper) updateFiles(ctx context.Context) (err error) {
|
||||
if err := l.conf.DownloadRootHints(ctx, l.puid, l.pgid); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := l.conf.DownloadRootKey(ctx, l.puid, l.pgid); err != nil {
|
||||
if err := l.conf.SetupFiles(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
settings := l.GetSettings()
|
||||
if err := l.conf.MakeUnboundConf(ctx, settings, l.username, l.puid, l.pgid); err != nil {
|
||||
|
||||
hostnameLines, ipLines, errs := l.conf.BuildBlocked(ctx, l.client,
|
||||
settings.BlockMalicious, settings.BlockAds, settings.BlockSurveillance,
|
||||
settings.Unbound.BlockedHostnames, settings.Unbound.BlockedIPs,
|
||||
settings.Unbound.AllowedHostnames)
|
||||
for _, err := range errs {
|
||||
l.logger.Warn(err)
|
||||
}
|
||||
|
||||
if err := l.conf.MakeUnboundConf(
|
||||
settings.Unbound, hostnameLines, ipLines,
|
||||
l.username, l.puid, l.pgid); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/golibs/os"
|
||||
)
|
||||
|
||||
// UseDNSInternally is to change the Go program DNS only.
|
||||
func (c *configurator) UseDNSInternally(ip net.IP) {
|
||||
c.logger.Info("using DNS address %s internally", ip.String())
|
||||
net.DefaultResolver = &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := net.Dialer{}
|
||||
return d.DialContext(ctx, "udp", net.JoinHostPort(ip.String(), "53"))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// UseDNSSystemWide changes the nameserver to use for DNS system wide.
|
||||
func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
|
||||
c.logger.Info("using DNS address %s system wide", ip.String())
|
||||
const filepath = string(constants.ResolvConf)
|
||||
file, err := c.openFile(filepath, os.O_RDWR|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := ioutil.ReadAll(file)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
s := strings.TrimSuffix(string(data), "\n")
|
||||
lines := strings.Split(s, "\n")
|
||||
if len(lines) == 1 && lines[0] == "" {
|
||||
lines = nil
|
||||
}
|
||||
found := false
|
||||
if !keepNameserver { // default
|
||||
for i := range lines {
|
||||
if strings.HasPrefix(lines[i], "nameserver ") {
|
||||
lines[i] = "nameserver " + ip.String()
|
||||
found = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
lines = append(lines, "nameserver "+ip.String())
|
||||
}
|
||||
s = strings.Join(lines, "\n") + "\n"
|
||||
_, err = file.WriteString(s)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
return file.Close()
|
||||
}
|
||||
@@ -1,106 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/golibs/logging/mock_logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
"github.com/qdm12/golibs/os/mock_os"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_UseDNSSystemWide(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
data []byte
|
||||
writtenData string
|
||||
openErr error
|
||||
readErr error
|
||||
writeErr error
|
||||
closeErr error
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
writtenData: "nameserver 127.0.0.1\n",
|
||||
},
|
||||
"open error": {
|
||||
openErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"read error": {
|
||||
readErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"write error": {
|
||||
writtenData: "nameserver 127.0.0.1\n",
|
||||
writeErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"lines without nameserver": {
|
||||
data: []byte("abc\ndef\n"),
|
||||
writtenData: "abc\ndef\nnameserver 127.0.0.1\n",
|
||||
},
|
||||
"lines with nameserver": {
|
||||
data: []byte("abc\nnameserver abc def\ndef\n"),
|
||||
writtenData: "abc\nnameserver 127.0.0.1\ndef\n",
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
file := mock_os.NewMockFile(mockCtrl)
|
||||
if tc.openErr == nil {
|
||||
firstReadCall := file.EXPECT().
|
||||
Read(gomock.AssignableToTypeOf([]byte{})).
|
||||
DoAndReturn(func(b []byte) (int, error) {
|
||||
copy(b, tc.data)
|
||||
return len(tc.data), nil
|
||||
})
|
||||
readErr := tc.readErr
|
||||
if readErr == nil {
|
||||
readErr = io.EOF
|
||||
}
|
||||
finalReadCall := file.EXPECT().
|
||||
Read(gomock.AssignableToTypeOf([]byte{})).
|
||||
Return(0, readErr).After(firstReadCall)
|
||||
if tc.readErr == nil {
|
||||
writeCall := file.EXPECT().WriteString(tc.writtenData).
|
||||
Return(0, tc.writeErr).After(finalReadCall)
|
||||
file.EXPECT().Close().Return(tc.closeErr).After(writeCall)
|
||||
} else {
|
||||
file.EXPECT().Close().Return(tc.closeErr).After(finalReadCall)
|
||||
}
|
||||
}
|
||||
|
||||
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
|
||||
assert.Equal(t, string(constants.ResolvConf), name)
|
||||
assert.Equal(t, os.O_RDWR|os.O_TRUNC, flag)
|
||||
assert.Equal(t, os.FileMode(0644), perm)
|
||||
return file, tc.openErr
|
||||
}
|
||||
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("using DNS address %s system wide", "127.0.0.1")
|
||||
c := &configurator{
|
||||
openFile: openFile,
|
||||
logger: logger,
|
||||
}
|
||||
err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1}, false)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
)
|
||||
|
||||
func (c *configurator) DownloadRootHints(ctx context.Context, puid, pgid int) error {
|
||||
return c.downloadAndSave(ctx, "root hints",
|
||||
string(constants.NamedRootURL), string(constants.RootHints), puid, pgid)
|
||||
}
|
||||
|
||||
func (c *configurator) DownloadRootKey(ctx context.Context, puid, pgid int) error {
|
||||
return c.downloadAndSave(ctx, "root key",
|
||||
string(constants.RootKeyURL), string(constants.RootKey), puid, pgid)
|
||||
}
|
||||
|
||||
func (c *configurator) downloadAndSave(ctx context.Context, logName, url, filepath string, puid, pgid int) error {
|
||||
c.logger.Info("downloading %s from %s", logName, url)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("%w from %s: %s", ErrBadStatusCode, url, response.Status)
|
||||
}
|
||||
|
||||
file, err := c.openFile(filepath, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0400)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(file, response.Body)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
err = file.Chown(puid, pgid)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return file.Close()
|
||||
}
|
||||
@@ -1,195 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/golibs/logging/mock_logging"
|
||||
"github.com/qdm12/golibs/os"
|
||||
"github.com/qdm12/golibs/os/mock_os"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_downloadAndSave(t *testing.T) {
|
||||
t.Parallel()
|
||||
const defaultURL = "https://test.com"
|
||||
tests := map[string]struct {
|
||||
url string // to trigger a new request error
|
||||
content []byte
|
||||
status int
|
||||
clientErr error
|
||||
openErr error
|
||||
writeErr error
|
||||
chownErr error
|
||||
closeErr error
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
url: defaultURL,
|
||||
status: http.StatusOK,
|
||||
},
|
||||
"bad status": {
|
||||
url: defaultURL,
|
||||
status: http.StatusBadRequest,
|
||||
err: fmt.Errorf("bad HTTP status from %s: Bad Request", defaultURL),
|
||||
},
|
||||
"client error": {
|
||||
url: defaultURL,
|
||||
clientErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("Get %q: error", defaultURL),
|
||||
},
|
||||
"open error": {
|
||||
url: defaultURL,
|
||||
status: http.StatusOK,
|
||||
openErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"chown error": {
|
||||
url: defaultURL,
|
||||
status: http.StatusOK,
|
||||
chownErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"close error": {
|
||||
url: defaultURL,
|
||||
status: http.StatusOK,
|
||||
closeErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"data": {
|
||||
url: defaultURL,
|
||||
content: []byte("content"),
|
||||
status: http.StatusOK,
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
ctx := context.Background()
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("downloading %s from %s", "root hints", tc.url)
|
||||
|
||||
client := &http.Client{
|
||||
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
assert.Equal(t, tc.url, r.URL.String())
|
||||
if tc.clientErr != nil {
|
||||
return nil, tc.clientErr
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: tc.status,
|
||||
Status: http.StatusText(tc.status),
|
||||
Body: ioutil.NopCloser(bytes.NewReader(tc.content)),
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
|
||||
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
const filepath = "/test"
|
||||
|
||||
if tc.clientErr == nil && tc.status == http.StatusOK {
|
||||
file := mock_os.NewMockFile(mockCtrl)
|
||||
if tc.openErr == nil {
|
||||
if len(tc.content) > 0 {
|
||||
file.EXPECT().
|
||||
Write(tc.content).
|
||||
Return(len(tc.content), tc.writeErr)
|
||||
}
|
||||
file.EXPECT().
|
||||
Close().
|
||||
Return(tc.closeErr)
|
||||
file.EXPECT().
|
||||
Chown(1000, 1000).
|
||||
Return(tc.chownErr)
|
||||
}
|
||||
|
||||
openFile = func(name string, flag int, perm os.FileMode) (os.File, error) {
|
||||
assert.Equal(t, filepath, name)
|
||||
assert.Equal(t, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, flag)
|
||||
assert.Equal(t, os.FileMode(0400), perm)
|
||||
return file, tc.openErr
|
||||
}
|
||||
}
|
||||
|
||||
c := &configurator{
|
||||
logger: logger,
|
||||
client: client,
|
||||
openFile: openFile,
|
||||
}
|
||||
|
||||
err := c.downloadAndSave(ctx, "root hints",
|
||||
tc.url, filepath,
|
||||
1000, 1000)
|
||||
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_DownloadRootHints(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
ctx := context.Background()
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL))
|
||||
|
||||
client := &http.Client{
|
||||
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
assert.Equal(t, string(constants.NamedRootURL), r.URL.String())
|
||||
return nil, errors.New("test")
|
||||
}),
|
||||
}
|
||||
|
||||
c := &configurator{
|
||||
logger: logger,
|
||||
client: client,
|
||||
}
|
||||
|
||||
err := c.DownloadRootHints(ctx, 1000, 1000)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, `Get "https://raw.githubusercontent.com/qdm12/files/master/named.root.updated": test`, err.Error())
|
||||
}
|
||||
|
||||
func Test_DownloadRootKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
ctx := context.Background()
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("downloading %s from %s", "root key", string(constants.RootKeyURL))
|
||||
|
||||
client := &http.Client{
|
||||
Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
assert.Equal(t, string(constants.RootKeyURL), r.URL.String())
|
||||
return nil, errors.New("test")
|
||||
}),
|
||||
}
|
||||
|
||||
c := &configurator{
|
||||
logger: logger,
|
||||
client: client,
|
||||
}
|
||||
|
||||
err := c.DownloadRootKey(ctx, 1000, 1000)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, `Get "https://raw.githubusercontent.com/qdm12/files/master/root.key.updated": test`, err.Error())
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
package dns
|
||||
|
||||
import "net/http"
|
||||
|
||||
type roundTripFunc func(r *http.Request) (*http.Response, error)
|
||||
|
||||
func (s roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
return s(r)
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (c *configurator) WaitForUnbound() (err error) {
|
||||
const hostToResolve = "github.com"
|
||||
waitDurations := [...]time.Duration{
|
||||
300 * time.Millisecond,
|
||||
100 * time.Millisecond,
|
||||
300 * time.Millisecond,
|
||||
500 * time.Millisecond,
|
||||
time.Second,
|
||||
2 * time.Second,
|
||||
}
|
||||
maxTries := len(waitDurations)
|
||||
for i, waitDuration := range waitDurations {
|
||||
time.Sleep(waitDuration)
|
||||
_, err := c.lookupIP(hostToResolve)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
c.logger.Warn("could not resolve %s (try %d of %d): %s", hostToResolve, i+1, maxTries, err)
|
||||
}
|
||||
return fmt.Errorf("Unbound does not seem to be working after %d tries", maxTries)
|
||||
}
|
||||
Reference in New Issue
Block a user