Upgrade dependencies

- Use of context for custom http client
- Remove unused nodeid for logger
- Upgrade shadowsocks dependency
This commit is contained in:
Quentin McGaw
2020-10-18 02:24:34 +00:00
parent b27e637894
commit 84c1f46ae4
16 changed files with 97 additions and 71 deletions

View File

@@ -1,6 +1,7 @@
package dns
import (
"context"
"fmt"
"net/http"
"sort"
@@ -13,9 +14,9 @@ import (
"github.com/qdm12/golibs/network"
)
func (c *configurator) MakeUnboundConf(settings settings.DNS, uid, gid int) (err error) {
func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DNS, uid, gid int) (err error) {
c.logger.Info("generating Unbound configuration")
lines, warnings := generateUnboundConf(settings, c.client, c.logger)
lines, warnings := generateUnboundConf(ctx, settings, c.client, c.logger)
for _, warning := range warnings {
c.logger.Warn(warning)
}
@@ -27,7 +28,7 @@ func (c *configurator) MakeUnboundConf(settings settings.DNS, uid, gid int) (err
}
// MakeUnboundConf generates an Unbound configuration from the user provided settings
func generateUnboundConf(settings settings.DNS, client network.Client, logger logging.Logger) (lines []string, warnings []error) {
func generateUnboundConf(ctx context.Context, settings settings.DNS, client network.Client, logger logging.Logger) (lines []string, warnings []error) {
doIPv6 := "no"
if settings.IPv6 {
doIPv6 = "yes"
@@ -70,7 +71,7 @@ func generateUnboundConf(settings settings.DNS, client network.Client, logger lo
}
// Block lists
hostnamesLines, ipsLines, warnings := buildBlocked(client,
hostnamesLines, ipsLines, warnings := buildBlocked(ctx, client,
settings.BlockMalicious, settings.BlockAds, settings.BlockSurveillance,
settings.AllowedHostnames, settings.PrivateAddresses,
)
@@ -129,18 +130,18 @@ func generateUnboundConf(settings settings.DNS, client network.Client, logger lo
return lines, warnings
}
func buildBlocked(client network.Client, blockMalicious, blockAds, blockSurveillance bool,
func buildBlocked(ctx context.Context, client network.Client, blockMalicious, blockAds, blockSurveillance bool,
allowedHostnames, privateAddresses []string) (hostnamesLines, ipsLines []string, errs []error) {
chHostnames := make(chan []string)
chIPs := make(chan []string)
chErrors := make(chan []error)
go func() {
lines, errs := buildBlockedHostnames(client, blockMalicious, blockAds, blockSurveillance, allowedHostnames)
lines, errs := buildBlockedHostnames(ctx, client, blockMalicious, blockAds, blockSurveillance, allowedHostnames)
chHostnames <- lines
chErrors <- errs
}()
go func() {
lines, errs := buildBlockedIPs(client, blockMalicious, blockAds, blockSurveillance, privateAddresses)
lines, errs := buildBlockedIPs(ctx, client, blockMalicious, blockAds, blockSurveillance, privateAddresses)
chIPs <- lines
chErrors <- errs
}()
@@ -159,8 +160,8 @@ func buildBlocked(client network.Client, blockMalicious, blockAds, blockSurveill
return hostnamesLines, ipsLines, errs
}
func getList(client network.Client, url string) (results []string, err error) {
content, status, err := client.GetContent(url)
func getList(ctx context.Context, client network.Client, url string) (results []string, err error) {
content, status, err := client.Get(ctx, url)
if err != nil {
return nil, err
} else if status != http.StatusOK {
@@ -184,7 +185,7 @@ func getList(client network.Client, url string) (results []string, err error) {
return results, nil
}
func buildBlockedHostnames(client network.Client, blockMalicious, blockAds, blockSurveillance bool,
func buildBlockedHostnames(ctx context.Context, client network.Client, blockMalicious, blockAds, blockSurveillance bool,
allowedHostnames []string) (lines []string, errs []error) {
chResults := make(chan []string)
chError := make(chan error)
@@ -192,7 +193,7 @@ func buildBlockedHostnames(client network.Client, blockMalicious, blockAds, bloc
if blockMalicious {
listsLeftToFetch++
go func() {
results, err := getList(client, string(constants.MaliciousBlockListHostnamesURL))
results, err := getList(ctx, client, string(constants.MaliciousBlockListHostnamesURL))
chResults <- results
chError <- err
}()
@@ -200,7 +201,7 @@ func buildBlockedHostnames(client network.Client, blockMalicious, blockAds, bloc
if blockAds {
listsLeftToFetch++
go func() {
results, err := getList(client, string(constants.AdsBlockListHostnamesURL))
results, err := getList(ctx, client, string(constants.AdsBlockListHostnamesURL))
chResults <- results
chError <- err
}()
@@ -208,7 +209,7 @@ func buildBlockedHostnames(client network.Client, blockMalicious, blockAds, bloc
if blockSurveillance {
listsLeftToFetch++
go func() {
results, err := getList(client, string(constants.SurveillanceBlockListHostnamesURL))
results, err := getList(ctx, client, string(constants.SurveillanceBlockListHostnamesURL))
chResults <- results
chError <- err
}()
@@ -236,7 +237,7 @@ func buildBlockedHostnames(client network.Client, blockMalicious, blockAds, bloc
return lines, errs
}
func buildBlockedIPs(client network.Client, blockMalicious, blockAds, blockSurveillance bool,
func buildBlockedIPs(ctx context.Context, client network.Client, blockMalicious, blockAds, blockSurveillance bool,
privateAddresses []string) (lines []string, errs []error) {
chResults := make(chan []string)
chError := make(chan error)
@@ -244,7 +245,7 @@ func buildBlockedIPs(client network.Client, blockMalicious, blockAds, blockSurve
if blockMalicious {
listsLeftToFetch++
go func() {
results, err := getList(client, string(constants.MaliciousBlockListIPsURL))
results, err := getList(ctx, client, string(constants.MaliciousBlockListIPsURL))
chResults <- results
chError <- err
}()
@@ -252,7 +253,7 @@ func buildBlockedIPs(client network.Client, blockMalicious, blockAds, blockSurve
if blockAds {
listsLeftToFetch++
go func() {
results, err := getList(client, string(constants.AdsBlockListIPsURL))
results, err := getList(ctx, client, string(constants.AdsBlockListIPsURL))
chResults <- results
chError <- err
}()
@@ -260,7 +261,7 @@ func buildBlockedIPs(client network.Client, blockMalicious, blockAds, blockSurve
if blockSurveillance {
listsLeftToFetch++
go func() {
results, err := getList(client, string(constants.SurveillanceBlockListIPsURL))
results, err := getList(ctx, client, string(constants.SurveillanceBlockListIPsURL))
chResults <- results
chError <- err
}()

View File

@@ -1,6 +1,7 @@
package dns
import (
"context"
"fmt"
"strings"
"testing"
@@ -31,15 +32,16 @@ func Test_generateUnboundConf(t *testing.T) {
}
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
ctx := context.Background()
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().GetContent(string(constants.MaliciousBlockListHostnamesURL)).
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListHostnamesURL)).
Return([]byte("b\na\nc"), 200, nil).Times(1)
client.EXPECT().GetContent(string(constants.MaliciousBlockListIPsURL)).
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListIPsURL)).
Return([]byte("c\nd\n"), 200, nil).Times(1)
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("%d hostnames blocked overall", 2).Times(1)
logger.EXPECT().Info("%d IP addresses blocked overall", 3).Times(1)
lines, warnings := generateUnboundConf(settings, client, logger)
lines, warnings := generateUnboundConf(ctx, settings, client, logger)
require.Len(t, warnings, 0)
expected := `
server:
@@ -232,26 +234,28 @@ func Test_buildBlocked(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
ctx := context.Background()
client := mock_network.NewMockClient(mockCtrl)
if tc.malicious.blocked {
client.EXPECT().GetContent(string(constants.MaliciousBlockListHostnamesURL)).
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListHostnamesURL)).
Return(tc.malicious.content, 200, tc.malicious.clientErr).Times(1)
client.EXPECT().GetContent(string(constants.MaliciousBlockListIPsURL)).
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListIPsURL)).
Return(tc.malicious.content, 200, tc.malicious.clientErr).Times(1)
}
if tc.ads.blocked {
client.EXPECT().GetContent(string(constants.AdsBlockListHostnamesURL)).
client.EXPECT().Get(ctx, string(constants.AdsBlockListHostnamesURL)).
Return(tc.ads.content, 200, tc.ads.clientErr).Times(1)
client.EXPECT().GetContent(string(constants.AdsBlockListIPsURL)).
client.EXPECT().Get(ctx, string(constants.AdsBlockListIPsURL)).
Return(tc.ads.content, 200, tc.ads.clientErr).Times(1)
}
if tc.surveillance.blocked {
client.EXPECT().GetContent(string(constants.SurveillanceBlockListHostnamesURL)).
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListHostnamesURL)).
Return(tc.surveillance.content, 200, tc.surveillance.clientErr).Times(1)
client.EXPECT().GetContent(string(constants.SurveillanceBlockListIPsURL)).
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListIPsURL)).
Return(tc.surveillance.content, 200, tc.surveillance.clientErr).Times(1)
}
hostnamesLines, ipsLines, errs := buildBlocked(client, tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked,
hostnamesLines, ipsLines, errs := buildBlocked(ctx, client,
tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked,
tc.allowedHostnames, tc.privateAddresses)
var errsString []string
for _, err := range errs {
@@ -284,10 +288,11 @@ func Test_getList(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
ctx := context.Background()
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().GetContent("irrelevant_url").
client.EXPECT().Get(ctx, "irrelevant_url").
Return(tc.content, tc.status, tc.clientErr).Times(1)
results, err := getList(client, "irrelevant_url")
results, err := getList(ctx, client, "irrelevant_url")
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
@@ -388,21 +393,23 @@ func Test_buildBlockedHostnames(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
ctx := context.Background()
client := mock_network.NewMockClient(mockCtrl)
if tc.malicious.blocked {
client.EXPECT().GetContent(string(constants.MaliciousBlockListHostnamesURL)).
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListHostnamesURL)).
Return(tc.malicious.content, 200, tc.malicious.clientErr).Times(1)
}
if tc.ads.blocked {
client.EXPECT().GetContent(string(constants.AdsBlockListHostnamesURL)).
client.EXPECT().Get(ctx, string(constants.AdsBlockListHostnamesURL)).
Return(tc.ads.content, 200, tc.ads.clientErr).Times(1)
}
if tc.surveillance.blocked {
client.EXPECT().GetContent(string(constants.SurveillanceBlockListHostnamesURL)).
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListHostnamesURL)).
Return(tc.surveillance.content, 200, tc.surveillance.clientErr).Times(1)
}
lines, errs := buildBlockedHostnames(client,
tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked, tc.allowedHostnames)
lines, errs := buildBlockedHostnames(ctx, client,
tc.malicious.blocked, tc.ads.blocked,
tc.surveillance.blocked, tc.allowedHostnames)
var errsString []string
for _, err := range errs {
errsString = append(errsString, err.Error())
@@ -504,21 +511,23 @@ func Test_buildBlockedIPs(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
ctx := context.Background()
client := mock_network.NewMockClient(mockCtrl)
if tc.malicious.blocked {
client.EXPECT().GetContent(string(constants.MaliciousBlockListIPsURL)).
client.EXPECT().Get(ctx, string(constants.MaliciousBlockListIPsURL)).
Return(tc.malicious.content, 200, tc.malicious.clientErr).Times(1)
}
if tc.ads.blocked {
client.EXPECT().GetContent(string(constants.AdsBlockListIPsURL)).
client.EXPECT().Get(ctx, string(constants.AdsBlockListIPsURL)).
Return(tc.ads.content, 200, tc.ads.clientErr).Times(1)
}
if tc.surveillance.blocked {
client.EXPECT().GetContent(string(constants.SurveillanceBlockListIPsURL)).
client.EXPECT().Get(ctx, string(constants.SurveillanceBlockListIPsURL)).
Return(tc.surveillance.content, 200, tc.surveillance.clientErr).Times(1)
}
lines, errs := buildBlockedIPs(client,
tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked, tc.privateAddresses)
lines, errs := buildBlockedIPs(ctx, client,
tc.malicious.blocked, tc.ads.blocked,
tc.surveillance.blocked, tc.privateAddresses)
var errsString []string
for _, err := range errs {
errsString = append(errsString, err.Error())

View File

@@ -13,9 +13,9 @@ import (
)
type Configurator interface {
DownloadRootHints(uid, gid int) error
DownloadRootKey(uid, gid int) error
MakeUnboundConf(settings settings.DNS, uid, gid int) (err error)
DownloadRootHints(ctx context.Context, uid, gid int) error
DownloadRootKey(ctx context.Context, uid, gid int) error
MakeUnboundConf(ctx context.Context, settings settings.DNS, uid, gid int) (err error)
UseDNSInternally(IP net.IP)
UseDNSSystemWide(ip net.IP, keepNameserver bool) error
Start(ctx context.Context, logLevel uint8) (stdout io.ReadCloser, waitFn func() error, err error)

View File

@@ -164,15 +164,15 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady fun
settings := l.GetSettings()
// Setup
if err := l.conf.DownloadRootHints(l.uid, l.gid); err != nil {
if err := l.conf.DownloadRootHints(ctx, l.uid, l.gid); err != nil {
l.logAndWait(ctx, err)
continue
}
if err := l.conf.DownloadRootKey(l.uid, l.gid); err != nil {
if err := l.conf.DownloadRootKey(ctx, l.uid, l.gid); err != nil {
l.logAndWait(ctx, err)
continue
}
if err := l.conf.MakeUnboundConf(settings, l.uid, l.gid); err != nil {
if err := l.conf.MakeUnboundConf(ctx, settings, l.uid, l.gid); err != nil {
l.logAndWait(ctx, err)
continue
}

View File

@@ -1,6 +1,7 @@
package dns
import (
"context"
"fmt"
"net/http"
@@ -8,9 +9,9 @@ import (
"github.com/qdm12/golibs/files"
)
func (c *configurator) DownloadRootHints(uid, gid int) error {
func (c *configurator) DownloadRootHints(ctx context.Context, uid, gid int) error {
c.logger.Info("downloading root hints from %s", constants.NamedRootURL)
content, status, err := c.client.GetContent(string(constants.NamedRootURL))
content, status, err := c.client.Get(ctx, string(constants.NamedRootURL))
if err != nil {
return err
} else if status != http.StatusOK {
@@ -23,9 +24,9 @@ func (c *configurator) DownloadRootHints(uid, gid int) error {
files.Permissions(0400))
}
func (c *configurator) DownloadRootKey(uid, gid int) error {
func (c *configurator) DownloadRootKey(ctx context.Context, uid, gid int) error {
c.logger.Info("downloading root key from %s", constants.RootKeyURL)
content, status, err := c.client.GetContent(string(constants.RootKeyURL))
content, status, err := c.client.Get(ctx, string(constants.RootKeyURL))
if err != nil {
return err
} else if status != http.StatusOK {

View File

@@ -1,6 +1,7 @@
package dns
import (
"context"
"fmt"
"net/http"
"testing"
@@ -52,10 +53,11 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading root hints from %s", constants.NamedRootURL).Times(1)
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().GetContent(string(constants.NamedRootURL)).
client.EXPECT().Get(ctx, string(constants.NamedRootURL)).
Return(tc.content, tc.status, tc.clientErr).Times(1)
fileManager := mock_files.NewMockFileManager(mockCtrl)
if tc.clientErr == nil && tc.status == http.StatusOK {
@@ -67,7 +69,7 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
Return(tc.writeErr).Times(1)
}
c := &configurator{logger: logger, client: client, fileManager: fileManager}
err := c.DownloadRootHints(1000, 1000)
err := c.DownloadRootHints(ctx, 1000, 1000)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
@@ -114,10 +116,11 @@ func Test_DownloadRootKey(t *testing.T) { //nolint:dupl
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading root key from %s", constants.RootKeyURL).Times(1)
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().GetContent(string(constants.RootKeyURL)).
client.EXPECT().Get(ctx, string(constants.RootKeyURL)).
Return(tc.content, tc.status, tc.clientErr).Times(1)
fileManager := mock_files.NewMockFileManager(mockCtrl)
if tc.clientErr == nil && tc.status == http.StatusOK {
@@ -129,7 +132,7 @@ func Test_DownloadRootKey(t *testing.T) { //nolint:dupl
).Return(tc.writeErr).Times(1)
}
c := &configurator{logger: logger, client: client, fileManager: fileManager}
err := c.DownloadRootKey(1000, 1001)
err := c.DownloadRootKey(ctx, 1000, 1001)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())