Upgrade dependencies
- Use of context for custom http client - Remove unused nodeid for logger - Upgrade shadowsocks dependency
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
@@ -13,9 +14,9 @@ import (
|
||||
"github.com/qdm12/golibs/network"
|
||||
)
|
||||
|
||||
func (c *configurator) MakeUnboundConf(settings settings.DNS, uid, gid int) (err error) {
|
||||
func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DNS, uid, gid int) (err error) {
|
||||
c.logger.Info("generating Unbound configuration")
|
||||
lines, warnings := generateUnboundConf(settings, c.client, c.logger)
|
||||
lines, warnings := generateUnboundConf(ctx, settings, c.client, c.logger)
|
||||
for _, warning := range warnings {
|
||||
c.logger.Warn(warning)
|
||||
}
|
||||
@@ -27,7 +28,7 @@ func (c *configurator) MakeUnboundConf(settings settings.DNS, uid, gid int) (err
|
||||
}
|
||||
|
||||
// MakeUnboundConf generates an Unbound configuration from the user provided settings
|
||||
func generateUnboundConf(settings settings.DNS, client network.Client, logger logging.Logger) (lines []string, warnings []error) {
|
||||
func generateUnboundConf(ctx context.Context, settings settings.DNS, client network.Client, logger logging.Logger) (lines []string, warnings []error) {
|
||||
doIPv6 := "no"
|
||||
if settings.IPv6 {
|
||||
doIPv6 = "yes"
|
||||
@@ -70,7 +71,7 @@ func generateUnboundConf(settings settings.DNS, client network.Client, logger lo
|
||||
}
|
||||
|
||||
// Block lists
|
||||
hostnamesLines, ipsLines, warnings := buildBlocked(client,
|
||||
hostnamesLines, ipsLines, warnings := buildBlocked(ctx, client,
|
||||
settings.BlockMalicious, settings.BlockAds, settings.BlockSurveillance,
|
||||
settings.AllowedHostnames, settings.PrivateAddresses,
|
||||
)
|
||||
@@ -129,18 +130,18 @@ func generateUnboundConf(settings settings.DNS, client network.Client, logger lo
|
||||
return lines, warnings
|
||||
}
|
||||
|
||||
func buildBlocked(client network.Client, blockMalicious, blockAds, blockSurveillance bool,
|
||||
func buildBlocked(ctx context.Context, client network.Client, blockMalicious, blockAds, blockSurveillance bool,
|
||||
allowedHostnames, privateAddresses []string) (hostnamesLines, ipsLines []string, errs []error) {
|
||||
chHostnames := make(chan []string)
|
||||
chIPs := make(chan []string)
|
||||
chErrors := make(chan []error)
|
||||
go func() {
|
||||
lines, errs := buildBlockedHostnames(client, blockMalicious, blockAds, blockSurveillance, allowedHostnames)
|
||||
lines, errs := buildBlockedHostnames(ctx, client, blockMalicious, blockAds, blockSurveillance, allowedHostnames)
|
||||
chHostnames <- lines
|
||||
chErrors <- errs
|
||||
}()
|
||||
go func() {
|
||||
lines, errs := buildBlockedIPs(client, blockMalicious, blockAds, blockSurveillance, privateAddresses)
|
||||
lines, errs := buildBlockedIPs(ctx, client, blockMalicious, blockAds, blockSurveillance, privateAddresses)
|
||||
chIPs <- lines
|
||||
chErrors <- errs
|
||||
}()
|
||||
@@ -159,8 +160,8 @@ func buildBlocked(client network.Client, blockMalicious, blockAds, blockSurveill
|
||||
return hostnamesLines, ipsLines, errs
|
||||
}
|
||||
|
||||
func getList(client network.Client, url string) (results []string, err error) {
|
||||
content, status, err := client.GetContent(url)
|
||||
func getList(ctx context.Context, client network.Client, url string) (results []string, err error) {
|
||||
content, status, err := client.Get(ctx, url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if status != http.StatusOK {
|
||||
@@ -184,7 +185,7 @@ func getList(client network.Client, url string) (results []string, err error) {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func buildBlockedHostnames(client network.Client, blockMalicious, blockAds, blockSurveillance bool,
|
||||
func buildBlockedHostnames(ctx context.Context, client network.Client, blockMalicious, blockAds, blockSurveillance bool,
|
||||
allowedHostnames []string) (lines []string, errs []error) {
|
||||
chResults := make(chan []string)
|
||||
chError := make(chan error)
|
||||
@@ -192,7 +193,7 @@ func buildBlockedHostnames(client network.Client, blockMalicious, blockAds, bloc
|
||||
if blockMalicious {
|
||||
listsLeftToFetch++
|
||||
go func() {
|
||||
results, err := getList(client, string(constants.MaliciousBlockListHostnamesURL))
|
||||
results, err := getList(ctx, client, string(constants.MaliciousBlockListHostnamesURL))
|
||||
chResults <- results
|
||||
chError <- err
|
||||
}()
|
||||
@@ -200,7 +201,7 @@ func buildBlockedHostnames(client network.Client, blockMalicious, blockAds, bloc
|
||||
if blockAds {
|
||||
listsLeftToFetch++
|
||||
go func() {
|
||||
results, err := getList(client, string(constants.AdsBlockListHostnamesURL))
|
||||
results, err := getList(ctx, client, string(constants.AdsBlockListHostnamesURL))
|
||||
chResults <- results
|
||||
chError <- err
|
||||
}()
|
||||
@@ -208,7 +209,7 @@ func buildBlockedHostnames(client network.Client, blockMalicious, blockAds, bloc
|
||||
if blockSurveillance {
|
||||
listsLeftToFetch++
|
||||
go func() {
|
||||
results, err := getList(client, string(constants.SurveillanceBlockListHostnamesURL))
|
||||
results, err := getList(ctx, client, string(constants.SurveillanceBlockListHostnamesURL))
|
||||
chResults <- results
|
||||
chError <- err
|
||||
}()
|
||||
@@ -236,7 +237,7 @@ func buildBlockedHostnames(client network.Client, blockMalicious, blockAds, bloc
|
||||
return lines, errs
|
||||
}
|
||||
|
||||
func buildBlockedIPs(client network.Client, blockMalicious, blockAds, blockSurveillance bool,
|
||||
func buildBlockedIPs(ctx context.Context, client network.Client, blockMalicious, blockAds, blockSurveillance bool,
|
||||
privateAddresses []string) (lines []string, errs []error) {
|
||||
chResults := make(chan []string)
|
||||
chError := make(chan error)
|
||||
@@ -244,7 +245,7 @@ func buildBlockedIPs(client network.Client, blockMalicious, blockAds, blockSurve
|
||||
if blockMalicious {
|
||||
listsLeftToFetch++
|
||||
go func() {
|
||||
results, err := getList(client, string(constants.MaliciousBlockListIPsURL))
|
||||
results, err := getList(ctx, client, string(constants.MaliciousBlockListIPsURL))
|
||||
chResults <- results
|
||||
chError <- err
|
||||
}()
|
||||
@@ -252,7 +253,7 @@ func buildBlockedIPs(client network.Client, blockMalicious, blockAds, blockSurve
|
||||
if blockAds {
|
||||
listsLeftToFetch++
|
||||
go func() {
|
||||
results, err := getList(client, string(constants.AdsBlockListIPsURL))
|
||||
results, err := getList(ctx, client, string(constants.AdsBlockListIPsURL))
|
||||
chResults <- results
|
||||
chError <- err
|
||||
}()
|
||||
@@ -260,7 +261,7 @@ func buildBlockedIPs(client network.Client, blockMalicious, blockAds, blockSurve
|
||||
if blockSurveillance {
|
||||
listsLeftToFetch++
|
||||
go func() {
|
||||
results, err := getList(client, string(constants.SurveillanceBlockListIPsURL))
|
||||
results, err := getList(ctx, client, string(constants.SurveillanceBlockListIPsURL))
|
||||
chResults <- results
|
||||
chError <- err
|
||||
}()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -31,15 +32,16 @@ func Test_generateUnboundConf(t *testing.T) {
|
||||
}
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
ctx := context.Background()
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
client.EXPECT().GetContent(string(constants.MaliciousBlockListHostnamesURL)).
|
||||
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListHostnamesURL)).
|
||||
Return([]byte("b\na\nc"), 200, nil).Times(1)
|
||||
client.EXPECT().GetContent(string(constants.MaliciousBlockListIPsURL)).
|
||||
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListIPsURL)).
|
||||
Return([]byte("c\nd\n"), 200, nil).Times(1)
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("%d hostnames blocked overall", 2).Times(1)
|
||||
logger.EXPECT().Info("%d IP addresses blocked overall", 3).Times(1)
|
||||
lines, warnings := generateUnboundConf(settings, client, logger)
|
||||
lines, warnings := generateUnboundConf(ctx, settings, client, logger)
|
||||
require.Len(t, warnings, 0)
|
||||
expected := `
|
||||
server:
|
||||
@@ -232,26 +234,28 @@ func Test_buildBlocked(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
ctx := context.Background()
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
if tc.malicious.blocked {
|
||||
client.EXPECT().GetContent(string(constants.MaliciousBlockListHostnamesURL)).
|
||||
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListHostnamesURL)).
|
||||
Return(tc.malicious.content, 200, tc.malicious.clientErr).Times(1)
|
||||
client.EXPECT().GetContent(string(constants.MaliciousBlockListIPsURL)).
|
||||
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListIPsURL)).
|
||||
Return(tc.malicious.content, 200, tc.malicious.clientErr).Times(1)
|
||||
}
|
||||
if tc.ads.blocked {
|
||||
client.EXPECT().GetContent(string(constants.AdsBlockListHostnamesURL)).
|
||||
client.EXPECT().Get(ctx, string(constants.AdsBlockListHostnamesURL)).
|
||||
Return(tc.ads.content, 200, tc.ads.clientErr).Times(1)
|
||||
client.EXPECT().GetContent(string(constants.AdsBlockListIPsURL)).
|
||||
client.EXPECT().Get(ctx, string(constants.AdsBlockListIPsURL)).
|
||||
Return(tc.ads.content, 200, tc.ads.clientErr).Times(1)
|
||||
}
|
||||
if tc.surveillance.blocked {
|
||||
client.EXPECT().GetContent(string(constants.SurveillanceBlockListHostnamesURL)).
|
||||
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListHostnamesURL)).
|
||||
Return(tc.surveillance.content, 200, tc.surveillance.clientErr).Times(1)
|
||||
client.EXPECT().GetContent(string(constants.SurveillanceBlockListIPsURL)).
|
||||
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListIPsURL)).
|
||||
Return(tc.surveillance.content, 200, tc.surveillance.clientErr).Times(1)
|
||||
}
|
||||
hostnamesLines, ipsLines, errs := buildBlocked(client, tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked,
|
||||
hostnamesLines, ipsLines, errs := buildBlocked(ctx, client,
|
||||
tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked,
|
||||
tc.allowedHostnames, tc.privateAddresses)
|
||||
var errsString []string
|
||||
for _, err := range errs {
|
||||
@@ -284,10 +288,11 @@ func Test_getList(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
ctx := context.Background()
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
client.EXPECT().GetContent("irrelevant_url").
|
||||
client.EXPECT().Get(ctx, "irrelevant_url").
|
||||
Return(tc.content, tc.status, tc.clientErr).Times(1)
|
||||
results, err := getList(client, "irrelevant_url")
|
||||
results, err := getList(ctx, client, "irrelevant_url")
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
@@ -388,21 +393,23 @@ func Test_buildBlockedHostnames(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
ctx := context.Background()
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
if tc.malicious.blocked {
|
||||
client.EXPECT().GetContent(string(constants.MaliciousBlockListHostnamesURL)).
|
||||
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListHostnamesURL)).
|
||||
Return(tc.malicious.content, 200, tc.malicious.clientErr).Times(1)
|
||||
}
|
||||
if tc.ads.blocked {
|
||||
client.EXPECT().GetContent(string(constants.AdsBlockListHostnamesURL)).
|
||||
client.EXPECT().Get(ctx, string(constants.AdsBlockListHostnamesURL)).
|
||||
Return(tc.ads.content, 200, tc.ads.clientErr).Times(1)
|
||||
}
|
||||
if tc.surveillance.blocked {
|
||||
client.EXPECT().GetContent(string(constants.SurveillanceBlockListHostnamesURL)).
|
||||
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListHostnamesURL)).
|
||||
Return(tc.surveillance.content, 200, tc.surveillance.clientErr).Times(1)
|
||||
}
|
||||
lines, errs := buildBlockedHostnames(client,
|
||||
tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked, tc.allowedHostnames)
|
||||
lines, errs := buildBlockedHostnames(ctx, client,
|
||||
tc.malicious.blocked, tc.ads.blocked,
|
||||
tc.surveillance.blocked, tc.allowedHostnames)
|
||||
var errsString []string
|
||||
for _, err := range errs {
|
||||
errsString = append(errsString, err.Error())
|
||||
@@ -504,21 +511,23 @@ func Test_buildBlockedIPs(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
ctx := context.Background()
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
if tc.malicious.blocked {
|
||||
client.EXPECT().GetContent(string(constants.MaliciousBlockListIPsURL)).
|
||||
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListIPsURL)).
|
||||
Return(tc.malicious.content, 200, tc.malicious.clientErr).Times(1)
|
||||
}
|
||||
if tc.ads.blocked {
|
||||
client.EXPECT().GetContent(string(constants.AdsBlockListIPsURL)).
|
||||
client.EXPECT().Get(ctx, string(constants.AdsBlockListIPsURL)).
|
||||
Return(tc.ads.content, 200, tc.ads.clientErr).Times(1)
|
||||
}
|
||||
if tc.surveillance.blocked {
|
||||
client.EXPECT().GetContent(string(constants.SurveillanceBlockListIPsURL)).
|
||||
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListIPsURL)).
|
||||
Return(tc.surveillance.content, 200, tc.surveillance.clientErr).Times(1)
|
||||
}
|
||||
lines, errs := buildBlockedIPs(client,
|
||||
tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked, tc.privateAddresses)
|
||||
lines, errs := buildBlockedIPs(ctx, client,
|
||||
tc.malicious.blocked, tc.ads.blocked,
|
||||
tc.surveillance.blocked, tc.privateAddresses)
|
||||
var errsString []string
|
||||
for _, err := range errs {
|
||||
errsString = append(errsString, err.Error())
|
||||
|
||||
@@ -13,9 +13,9 @@ import (
|
||||
)
|
||||
|
||||
type Configurator interface {
|
||||
DownloadRootHints(uid, gid int) error
|
||||
DownloadRootKey(uid, gid int) error
|
||||
MakeUnboundConf(settings settings.DNS, uid, gid int) (err error)
|
||||
DownloadRootHints(ctx context.Context, uid, gid int) error
|
||||
DownloadRootKey(ctx context.Context, uid, gid int) error
|
||||
MakeUnboundConf(ctx context.Context, settings settings.DNS, uid, gid int) (err error)
|
||||
UseDNSInternally(IP net.IP)
|
||||
UseDNSSystemWide(ip net.IP, keepNameserver bool) error
|
||||
Start(ctx context.Context, logLevel uint8) (stdout io.ReadCloser, waitFn func() error, err error)
|
||||
|
||||
@@ -164,15 +164,15 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady fun
|
||||
settings := l.GetSettings()
|
||||
|
||||
// Setup
|
||||
if err := l.conf.DownloadRootHints(l.uid, l.gid); err != nil {
|
||||
if err := l.conf.DownloadRootHints(ctx, l.uid, l.gid); err != nil {
|
||||
l.logAndWait(ctx, err)
|
||||
continue
|
||||
}
|
||||
if err := l.conf.DownloadRootKey(l.uid, l.gid); err != nil {
|
||||
if err := l.conf.DownloadRootKey(ctx, l.uid, l.gid); err != nil {
|
||||
l.logAndWait(ctx, err)
|
||||
continue
|
||||
}
|
||||
if err := l.conf.MakeUnboundConf(settings, l.uid, l.gid); err != nil {
|
||||
if err := l.conf.MakeUnboundConf(ctx, settings, l.uid, l.gid); err != nil {
|
||||
l.logAndWait(ctx, err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
@@ -8,9 +9,9 @@ import (
|
||||
"github.com/qdm12/golibs/files"
|
||||
)
|
||||
|
||||
func (c *configurator) DownloadRootHints(uid, gid int) error {
|
||||
func (c *configurator) DownloadRootHints(ctx context.Context, uid, gid int) error {
|
||||
c.logger.Info("downloading root hints from %s", constants.NamedRootURL)
|
||||
content, status, err := c.client.GetContent(string(constants.NamedRootURL))
|
||||
content, status, err := c.client.Get(ctx, string(constants.NamedRootURL))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if status != http.StatusOK {
|
||||
@@ -23,9 +24,9 @@ func (c *configurator) DownloadRootHints(uid, gid int) error {
|
||||
files.Permissions(0400))
|
||||
}
|
||||
|
||||
func (c *configurator) DownloadRootKey(uid, gid int) error {
|
||||
func (c *configurator) DownloadRootKey(ctx context.Context, uid, gid int) error {
|
||||
c.logger.Info("downloading root key from %s", constants.RootKeyURL)
|
||||
content, status, err := c.client.GetContent(string(constants.RootKeyURL))
|
||||
content, status, err := c.client.Get(ctx, string(constants.RootKeyURL))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if status != http.StatusOK {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
@@ -52,10 +53,11 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
ctx := context.Background()
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("downloading root hints from %s", constants.NamedRootURL).Times(1)
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
client.EXPECT().GetContent(string(constants.NamedRootURL)).
|
||||
client.EXPECT().Get(ctx, string(constants.NamedRootURL)).
|
||||
Return(tc.content, tc.status, tc.clientErr).Times(1)
|
||||
fileManager := mock_files.NewMockFileManager(mockCtrl)
|
||||
if tc.clientErr == nil && tc.status == http.StatusOK {
|
||||
@@ -67,7 +69,7 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
|
||||
Return(tc.writeErr).Times(1)
|
||||
}
|
||||
c := &configurator{logger: logger, client: client, fileManager: fileManager}
|
||||
err := c.DownloadRootHints(1000, 1000)
|
||||
err := c.DownloadRootHints(ctx, 1000, 1000)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
@@ -114,10 +116,11 @@ func Test_DownloadRootKey(t *testing.T) { //nolint:dupl
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
ctx := context.Background()
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("downloading root key from %s", constants.RootKeyURL).Times(1)
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
client.EXPECT().GetContent(string(constants.RootKeyURL)).
|
||||
client.EXPECT().Get(ctx, string(constants.RootKeyURL)).
|
||||
Return(tc.content, tc.status, tc.clientErr).Times(1)
|
||||
fileManager := mock_files.NewMockFileManager(mockCtrl)
|
||||
if tc.clientErr == nil && tc.status == http.StatusOK {
|
||||
@@ -129,7 +132,7 @@ func Test_DownloadRootKey(t *testing.T) { //nolint:dupl
|
||||
).Return(tc.writeErr).Times(1)
|
||||
}
|
||||
c := &configurator{logger: logger, client: client, fileManager: fileManager}
|
||||
err := c.DownloadRootKey(1000, 1001)
|
||||
err := c.DownloadRootKey(ctx, 1000, 1001)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
|
||||
Reference in New Issue
Block a user