Code maintenance: OS package for file system

- OS custom internal package for file system interaction
- Remove fileManager external dependency
- Closer API to Go's native API on the OS
- Create directories at startup
- Better testability
- Move Unsetenv to os interface
This commit is contained in:
Quentin McGaw
2020-12-29 00:55:31 +00:00
parent f5366c33bc
commit 73479bab26
43 changed files with 923 additions and 353 deletions

View File

@@ -10,6 +10,10 @@ issues:
linters:
- dupl
- maligned
- path: internal/os/alias\.go
linters:
- gochecknoglobals
text: IsNotExist is a global variable
linters:
disable-all: true

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"net"
"net/http"
"os"
nativeos "os"
"os/signal"
"strings"
"sync"
@@ -22,6 +22,7 @@ import (
gluetunLogging "github.com/qdm12/gluetun/internal/logging"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/openvpn"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/params"
"github.com/qdm12/gluetun/internal/publicip"
"github.com/qdm12/gluetun/internal/routing"
@@ -32,7 +33,6 @@ import (
"github.com/qdm12/gluetun/internal/updater"
versionpkg "github.com/qdm12/gluetun/internal/version"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network"
)
@@ -50,21 +50,24 @@ func main() {
buildInfo.Commit = commit
buildInfo.BuildDate = buildDate
ctx := context.Background()
os.Exit(_main(ctx, os.Args))
args := nativeos.Args
os := os.New()
nativeos.Exit(_main(ctx, args, os))
}
func _main(background context.Context, args []string) int { //nolint:gocognit,gocyclo
//nolint:gocognit,gocyclo
func _main(background context.Context, args []string, os os.OS) int {
if len(args) > 1 { // cli operation
var err error
switch args[1] {
case "healthcheck":
err = cli.HealthCheck(background)
case "clientkey":
err = cli.ClientKey(args[2:])
err = cli.ClientKey(args[2:], os.OpenFile)
case "openvpnconfig":
err = cli.OpenvpnConfig()
err = cli.OpenvpnConfig(os)
case "update":
err = cli.Update(args[2:])
err = cli.Update(args[2:], os)
default:
err = fmt.Errorf("command %q is unknown", args[1])
}
@@ -82,15 +85,14 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
httpClient := &http.Client{Timeout: clientTimeout}
client := network.NewClient(clientTimeout)
// Create configurators
fileManager := files.NewFileManager()
alpineConf := alpine.NewConfigurator(fileManager)
ovpnConf := openvpn.NewConfigurator(logger, fileManager)
dnsConf := dns.NewConfigurator(logger, client, fileManager)
alpineConf := alpine.NewConfigurator(os.OpenFile)
ovpnConf := openvpn.NewConfigurator(logger, os)
dnsConf := dns.NewConfigurator(logger, client, os.OpenFile)
routingConf := routing.NewRouting(logger)
firewallConf := firewall.NewConfigurator(logger, routingConf, fileManager)
firewallConf := firewall.NewConfigurator(logger, routingConf, os.OpenFile)
streamMerger := command.NewStreamMerger()
paramsReader := params.NewReader(logger, fileManager)
paramsReader := params.NewReader(logger, os)
fmt.Println(gluetunLogging.Splash(buildInfo))
printVersions(ctx, logger, map[string]func(ctx context.Context) (string, error){
@@ -106,8 +108,17 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
}
logger.Info(allSettings.String())
if err := os.MkdirAll("/tmp/gluetun", 0644); err != nil {
logger.Error(err)
return 1
}
if err := os.MkdirAll("/gluetun", 0644); err != nil {
logger.Error(err)
return 1
}
// TODO run this in a loop or in openvpn to reload from file without restarting
storage := storage.New(logger)
storage := storage.New(logger, os)
const updateServerFile = true
allServers, err := storage.SyncServers(constants.GetAllServers(), updateServerFile)
if err != nil {
@@ -124,8 +135,8 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
logger.Error(err)
return 1
}
err = fileManager.SetOwnership("/etc/unbound", uid, gid)
if err != nil {
if err := os.Chown("/etc/unbound", uid, gid); err != nil {
logger.Error(err)
return 1
}
@@ -219,7 +230,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady)
openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, uid, gid, allServers,
ovpnConf, firewallConf, routingConf, logger, httpClient, fileManager, streamMerger, cancel)
ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, streamMerger, cancel)
wg.Add(1)
// wait for restartOpenvpn
go openvpnLooper.Run(ctx, wg)
@@ -236,7 +247,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
go unboundLooper.Run(ctx, wg, signalDNSReady)
publicIPLooper := publicip.NewLooper(
client, logger, fileManager, allSettings.PublicIP, uid, gid)
client, logger, allSettings.PublicIP, uid, gid, os)
wg.Add(1)
go publicIPLooper.Run(ctx, wg)
wg.Add(1)
@@ -279,11 +290,11 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
// until openvpn is launched
_, _ = openvpnLooper.SetStatus(constants.Running) // TODO option to disable with variable
signalsCh := make(chan os.Signal, 1)
signalsCh := make(chan nativeos.Signal, 1)
signal.Notify(signalsCh,
syscall.SIGINT,
syscall.SIGTERM,
os.Interrupt,
nativeos.Interrupt,
)
shutdownErrorsCount := 0
select {
@@ -295,7 +306,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
}
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath)
if err := fileManager.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil {
if err := os.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil {
logger.Error(err)
shutdownErrorsCount++
}

View File

@@ -3,7 +3,7 @@ package alpine
import (
"os/user"
"github.com/qdm12/golibs/files"
"github.com/qdm12/gluetun/internal/os"
)
type Configurator interface {
@@ -11,15 +11,15 @@ type Configurator interface {
}
type configurator struct {
fileManager files.FileManager
lookupUID func(uid string) (*user.User, error)
lookupUser func(username string) (*user.User, error)
openFile os.OpenFileFunc
lookupUID func(uid string) (*user.User, error)
lookupUser func(username string) (*user.User, error)
}
func NewConfigurator(fileManager files.FileManager) Configurator {
func NewConfigurator(openFile os.OpenFileFunc) Configurator {
return &configurator{
fileManager: fileManager,
lookupUID: user.LookupId,
lookupUser: user.Lookup,
openFile: openFile,
lookupUID: user.LookupId,
lookupUser: user.Lookup,
}
}

View File

@@ -2,6 +2,7 @@ package alpine
import (
"fmt"
"os"
"os/user"
)
@@ -26,14 +27,15 @@ func (c *configurator) CreateUser(username string, uid int) (createdUsername str
return "", fmt.Errorf("cannot create user: user with name %s already exists for ID %s instead of %d",
username, u.Uid, uid)
}
passwd, err := c.fileManager.ReadFile("/etc/passwd")
file, err := c.openFile("/etc/passwd", os.O_APPEND|os.O_WRONLY, 0644)
if err != nil {
return "", fmt.Errorf("cannot create user: %w", err)
}
passwd = append(passwd, []byte(fmt.Sprintf("%s:x:%d:::/dev/null:/sbin/nologin\n", username, uid))...)
if err := c.fileManager.WriteToFile("/etc/passwd", passwd); err != nil {
return "", fmt.Errorf("cannot create user: %w", err)
s := fmt.Sprintf("%s:x:%d:::/dev/null:/sbin/nologin\n", username, uid)
_, err = file.WriteString(s)
if err != nil {
_ = file.Close()
return "", err
}
return username, nil
return username, file.Close()
}

View File

@@ -4,29 +4,40 @@ import (
"context"
"flag"
"fmt"
"io/ioutil"
"net/http"
"strings"
"time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/healthcheck"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/params"
"github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/gluetun/internal/storage"
"github.com/qdm12/gluetun/internal/updater"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
)
func ClientKey(args []string) error {
func ClientKey(args []string, openFile os.OpenFileFunc) error {
flagSet := flag.NewFlagSet("clientkey", flag.ExitOnError)
filepath := flagSet.String("path", string(constants.ClientKey), "file path to the client.key file")
if err := flagSet.Parse(args); err != nil {
return err
}
fileManager := files.NewFileManager()
data, err := fileManager.ReadFile(*filepath)
file, err := openFile(*filepath, os.O_RDONLY, 0)
if err != nil {
return err
}
data, err := ioutil.ReadAll(file)
if err != nil {
_ = file.Close()
return err
}
if err := file.Close(); err != nil {
return err
}
if err != nil {
return err
}
@@ -49,17 +60,17 @@ func HealthCheck(ctx context.Context) error {
return healthchecker.Check(ctx, url)
}
func OpenvpnConfig() error {
func OpenvpnConfig(os os.OS) error {
logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel)
if err != nil {
return err
}
paramsReader := params.NewReader(logger, files.NewFileManager())
paramsReader := params.NewReader(logger, os)
allSettings, err := settings.GetAllSettings(paramsReader)
if err != nil {
return err
}
allServers, err := storage.New(logger).SyncServers(constants.GetAllServers(), false)
allServers, err := storage.New(logger, os).SyncServers(constants.GetAllServers(), false)
if err != nil {
return err
}
@@ -81,7 +92,7 @@ func OpenvpnConfig() error {
return nil
}
func Update(args []string) error {
func Update(args []string, os os.OS) error {
options := settings.Updater{CLI: true}
var flushToFile bool
flagSet := flag.NewFlagSet("update", flag.ExitOnError)
@@ -110,7 +121,7 @@ func Update(args []string) error {
ctx := context.Background()
const clientTimeout = 10 * time.Second
httpClient := &http.Client{Timeout: clientTimeout}
storage := storage.New(logger)
storage := storage.New(logger, os)
const writeSync = false
currentServers, err := storage.SyncServers(constants.GetAllServers(), writeSync)
if err != nil {

View File

@@ -1,8 +0,0 @@
package constants
import "os"
const (
UserReadPermission os.FileMode = 0400
AllReadWritePermissions os.FileMode = 0666
)

View File

@@ -8,8 +8,8 @@ import (
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network"
)
@@ -21,11 +21,29 @@ func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DN
for _, warning := range warnings {
c.logger.Warn(warning)
}
return c.fileManager.WriteLinesToFile(
string(constants.UnboundConf),
lines,
files.Ownership(uid, gid),
files.Permissions(constants.UserReadPermission))
const filepath = string(constants.UnboundConf)
file, err := c.openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0400)
if err != nil {
return err
}
_, err = file.WriteString(strings.Join(lines, "\n"))
if err != nil {
_ = file.Close()
return err
}
if err := file.Chown(uid, gid); err != nil {
_ = file.Close()
return err
}
if err := file.Close(); err != nil {
return err
}
return nil
}
// MakeUnboundConf generates an Unbound configuration from the user provided settings.

View File

@@ -5,9 +5,9 @@ import (
"io"
"net"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network"
)
@@ -24,19 +24,20 @@ type Configurator interface {
}
type configurator struct {
logger logging.Logger
client network.Client
fileManager files.FileManager
commander command.Commander
lookupIP func(host string) ([]net.IP, error)
logger logging.Logger
client network.Client
openFile os.OpenFileFunc
commander command.Commander
lookupIP func(host string) ([]net.IP, error)
}
func NewConfigurator(logger logging.Logger, client network.Client, fileManager files.FileManager) Configurator {
func NewConfigurator(logger logging.Logger, client network.Client,
openFile os.OpenFileFunc) Configurator {
return &configurator{
logger: logger.WithPrefix("dns configurator: "),
client: client,
fileManager: fileManager,
commander: command.NewCommander(),
lookupIP: net.LookupIP,
logger: logger.WithPrefix("dns configurator: "),
client: client,
openFile: openFile,
commander: command.NewCommander(),
lookupIP: net.LookupIP,
}
}

View File

@@ -2,10 +2,12 @@ package dns
import (
"context"
"io/ioutil"
"net"
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/os"
)
// UseDNSInternally is to change the Go program DNS only.
@@ -23,10 +25,16 @@ func (c *configurator) UseDNSInternally(ip net.IP) {
// UseDNSSystemWide changes the nameserver to use for DNS system wide.
func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
c.logger.Info("using DNS address %s system wide", ip.String())
data, err := c.fileManager.ReadFile(string(constants.ResolvConf))
const filepath = string(constants.ResolvConf)
file, err := c.openFile(filepath, os.O_RDWR, 0644)
if err != nil {
return err
}
data, err := ioutil.ReadAll(file)
if err != nil {
_ = file.Close()
return err
}
s := strings.TrimSuffix(string(data), "\n")
lines := strings.Split(s, "\n")
if len(lines) == 1 && lines[0] == "" {
@@ -44,6 +52,11 @@ func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
if !found {
lines = append(lines, "nameserver "+ip.String())
}
data = []byte(strings.Join(lines, "\n"))
return c.fileManager.WriteToFile(string(constants.ResolvConf), data)
s = strings.Join(lines, "\n")
_, err = file.WriteString(s)
if err != nil {
_ = file.Close()
return err
}
return file.Close()
}

View File

@@ -2,12 +2,14 @@ package dns
import (
"fmt"
"io"
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/files/mock_files"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/os/mock_os"
"github.com/qdm12/golibs/logging/mock_logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -17,30 +19,36 @@ func Test_UseDNSSystemWide(t *testing.T) {
t.Parallel()
tests := map[string]struct {
data []byte
writtenData []byte
writtenData string
openErr error
readErr error
writeErr error
closeErr error
err error
}{
"no data": {
writtenData: []byte("nameserver 127.0.0.1"),
writtenData: "nameserver 127.0.0.1",
},
"open error": {
openErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"read error": {
readErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"write error": {
writtenData: []byte("nameserver 127.0.0.1"),
writtenData: "nameserver 127.0.0.1",
writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"lines without nameserver": {
data: []byte("abc\ndef\n"),
writtenData: []byte("abc\ndef\nnameserver 127.0.0.1"),
writtenData: "abc\ndef\nnameserver 127.0.0.1",
},
"lines with nameserver": {
data: []byte("abc\nnameserver abc def\ndef\n"),
writtenData: []byte("abc\nnameserver 127.0.0.1\ndef"),
writtenData: "abc\nnameserver 127.0.0.1\ndef",
},
}
for name, tc := range tests {
@@ -49,18 +57,43 @@ func Test_UseDNSSystemWide(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
fileManager := mock_files.NewMockFileManager(mockCtrl)
fileManager.EXPECT().ReadFile(string(constants.ResolvConf)).
Return(tc.data, tc.readErr)
if tc.readErr == nil {
fileManager.EXPECT().WriteToFile(string(constants.ResolvConf), tc.writtenData).
Return(tc.writeErr)
file := mock_os.NewMockFile(mockCtrl)
if tc.openErr == nil {
firstReadCall := file.EXPECT().
Read(gomock.AssignableToTypeOf([]byte{})).
DoAndReturn(func(b []byte) (int, error) {
copy(b, tc.data)
return len(tc.data), nil
})
readErr := tc.readErr
if readErr == nil {
readErr = io.EOF
}
finalReadCall := file.EXPECT().
Read(gomock.AssignableToTypeOf([]byte{})).
Return(0, readErr).After(firstReadCall)
if tc.readErr == nil {
writeCall := file.EXPECT().WriteString(tc.writtenData).
Return(0, tc.writeErr).After(finalReadCall)
file.EXPECT().Close().Return(tc.closeErr).After(writeCall)
} else {
file.EXPECT().Close().Return(tc.closeErr).After(finalReadCall)
}
}
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
assert.Equal(t, string(constants.ResolvConf), name)
assert.Equal(t, os.O_RDWR, flag)
assert.Equal(t, os.FileMode(0644), perm)
return file, tc.openErr
}
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("using DNS address %s system wide", "127.0.0.1")
c := &configurator{
fileManager: fileManager,
logger: logger,
openFile: openFile,
logger: logger,
}
err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1}, false)
if tc.err != nil {

View File

@@ -4,37 +4,46 @@ import (
"context"
"fmt"
"net/http"
"os"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/files"
)
func (c *configurator) DownloadRootHints(ctx context.Context, uid, gid int) error {
c.logger.Info("downloading root hints from %s", constants.NamedRootURL)
content, status, err := c.client.Get(ctx, string(constants.NamedRootURL))
if err != nil {
return err
} else if status != http.StatusOK {
return fmt.Errorf("HTTP status code is %d for %s", status, constants.NamedRootURL)
}
return c.fileManager.WriteToFile(
string(constants.RootHints),
content,
files.Ownership(uid, gid),
files.Permissions(constants.UserReadPermission))
return c.downloadAndSave(ctx, "root hints",
string(constants.NamedRootURL), string(constants.RootHints), uid, gid)
}
func (c *configurator) DownloadRootKey(ctx context.Context, uid, gid int) error {
c.logger.Info("downloading root key from %s", constants.RootKeyURL)
content, status, err := c.client.Get(ctx, string(constants.RootKeyURL))
return c.downloadAndSave(ctx, "root key",
string(constants.RootKeyURL), string(constants.RootKey), uid, gid)
}
func (c *configurator) downloadAndSave(ctx context.Context, logName, url, filepath string, uid, gid int) error {
c.logger.Info("downloading %s from %s", logName, url)
content, status, err := c.client.Get(ctx, url)
if err != nil {
return err
} else if status != http.StatusOK {
return fmt.Errorf("HTTP status code is %d for %s", status, constants.RootKeyURL)
return fmt.Errorf("HTTP status code is %d for %s", status, url)
}
return c.fileManager.WriteToFile(
string(constants.RootKey),
content,
files.Ownership(uid, gid),
files.Permissions(constants.UserReadPermission))
file, err := c.openFile(filepath, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0400)
if err != nil {
return err
}
_, err = file.Write(content)
if err != nil {
_ = file.Close()
return err
}
err = file.Chown(uid, gid)
if err != nil {
_ = file.Close()
return err
}
return file.Close()
}

View File

@@ -2,27 +2,31 @@ package dns
import (
"context"
"errors"
"fmt"
"net/http"
"testing"
"github.com/golang/mock/gomock"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/files/mock_files"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/os/mock_os"
"github.com/qdm12/golibs/logging/mock_logging"
"github.com/qdm12/golibs/network/mock_network"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
func Test_downloadAndSave(t *testing.T) {
t.Parallel()
tests := map[string]struct {
content []byte
status int
clientErr error
openErr error
writeErr error
chownErr error
closeErr error
err error
}{
"no data": {
@@ -36,11 +40,26 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
clientErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"open error": {
status: http.StatusOK,
openErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"write error": {
status: http.StatusOK,
writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"chown error": {
status: http.StatusOK,
chownErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"close error": {
status: http.StatusOK,
closeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"data": {
content: []byte("content"),
status: http.StatusOK,
@@ -52,23 +71,49 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading root hints from %s", constants.NamedRootURL)
logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL))
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().Get(ctx, string(constants.NamedRootURL)).
Return(tc.content, tc.status, tc.clientErr)
fileManager := mock_files.NewMockFileManager(mockCtrl)
if tc.clientErr == nil && tc.status == http.StatusOK {
fileManager.EXPECT().WriteToFile(
string(constants.RootHints),
tc.content,
gomock.AssignableToTypeOf(files.Ownership(0, 0)),
gomock.AssignableToTypeOf(files.Ownership(0, 0))).
Return(tc.writeErr)
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
return nil, nil
}
c := &configurator{logger: logger, client: client, fileManager: fileManager}
err := c.DownloadRootHints(ctx, 1000, 1000)
if tc.clientErr == nil && tc.status == http.StatusOK {
file := mock_os.NewMockFile(mockCtrl)
if tc.openErr == nil {
writeCall := file.EXPECT().Write(tc.content).
Return(0, tc.writeErr)
if tc.writeErr != nil {
file.EXPECT().Close().Return(tc.closeErr).After(writeCall)
} else {
chownCall := file.EXPECT().Chown(1000, 1000).Return(tc.chownErr).After(writeCall)
file.EXPECT().Close().Return(tc.closeErr).After(chownCall)
}
}
openFile = func(name string, flag int, perm os.FileMode) (os.File, error) {
assert.Equal(t, string(constants.RootHints), name)
assert.Equal(t, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, flag)
assert.Equal(t, os.FileMode(0400), perm)
return file, tc.openErr
}
}
c := &configurator{
logger: logger,
client: client,
openFile: openFile,
}
err := c.downloadAndSave(ctx, "root hints",
string(constants.NamedRootURL), string(constants.RootHints),
1000, 1000)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
@@ -79,65 +124,44 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
}
}
func Test_DownloadRootKey(t *testing.T) { //nolint:dupl
func Test_DownloadRootHints(t *testing.T) {
t.Parallel()
tests := map[string]struct {
content []byte
status int
clientErr error
writeErr error
err error
}{
"no data": {
status: http.StatusOK,
},
"bad status": {
status: http.StatusBadRequest,
err: fmt.Errorf("HTTP status code is 400 for https://raw.githubusercontent.com/qdm12/files/master/root.key.updated"), //nolint:lll
},
"client error": {
clientErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"write error": {
status: http.StatusOK,
writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"data": {
content: []byte("content"),
status: http.StatusOK,
},
}
for name, tc := range tests {
tc := tc
t.Run(name, func(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading root key from %s", constants.RootKeyURL)
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().Get(ctx, string(constants.RootKeyURL)).
Return(tc.content, tc.status, tc.clientErr)
fileManager := mock_files.NewMockFileManager(mockCtrl)
if tc.clientErr == nil && tc.status == http.StatusOK {
fileManager.EXPECT().WriteToFile(
string(constants.RootKey),
tc.content,
gomock.AssignableToTypeOf(files.Ownership(0, 0)),
gomock.AssignableToTypeOf(files.Ownership(0, 0)),
).Return(tc.writeErr)
}
c := &configurator{logger: logger, client: client, fileManager: fileManager}
err := c.DownloadRootKey(ctx, 1000, 1001)
if tc.err != nil {
require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error())
} else {
assert.NoError(t, err)
}
})
mockCtrl := gomock.NewController(t)
ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL))
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().Get(ctx, string(constants.NamedRootURL)).
Return(nil, http.StatusOK, errors.New("test"))
c := &configurator{
logger: logger,
client: client,
}
err := c.DownloadRootHints(ctx, 1000, 1000)
require.Error(t, err)
assert.Equal(t, "test", err.Error())
}
func Test_DownloadRootKey(t *testing.T) {
t.Parallel()
mockCtrl := gomock.NewController(t)
ctx := context.Background()
logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("downloading %s from %s", "root key", string(constants.RootKeyURL))
client := mock_network.NewMockClient(mockCtrl)
client.EXPECT().Get(ctx, string(constants.RootKeyURL)).
Return(nil, http.StatusOK, errors.New("test"))
c := &configurator{
logger: logger,
client: client,
}
err := c.DownloadRootKey(ctx, 1000, 1000)
require.Error(t, err)
assert.Equal(t, "test", err.Error())
}

View File

@@ -6,9 +6,9 @@ import (
"sync"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
)
@@ -29,7 +29,7 @@ type configurator struct { //nolint:maligned
commander command.Commander
logger logging.Logger
routing routing.Routing
fileManager files.FileManager // for custom iptables rules
openFile os.OpenFileFunc // for custom iptables rules
iptablesMutex sync.Mutex
debug bool
defaultInterface string
@@ -47,12 +47,12 @@ type configurator struct { //nolint:maligned
}
// NewConfigurator creates a new Configurator instance.
func NewConfigurator(logger logging.Logger, routing routing.Routing, fileManager files.FileManager) Configurator {
func NewConfigurator(logger logging.Logger, routing routing.Routing, openFile os.OpenFileFunc) Configurator {
return &configurator{
commander: command.NewCommander(),
logger: logger.WithPrefix("firewall: "),
routing: routing,
fileManager: fileManager,
openFile: openFile,
allowedInputPorts: make(map[uint16]string),
}
}

View File

@@ -3,7 +3,9 @@ package firewall
import (
"context"
"fmt"
"io/ioutil"
"net"
"os"
"strings"
"github.com/qdm12/gluetun/internal/models"
@@ -150,14 +152,18 @@ func (c *configurator) acceptInputToPort(ctx context.Context, intf string, port
}
func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
exists, err := c.fileManager.FileExists(filepath)
if err != nil {
return err
} else if !exists {
file, err := c.openFile(filepath, os.O_RDONLY, 0)
if os.IsNotExist(err) {
return nil
} else if err != nil {
return err
}
b, err := c.fileManager.ReadFile(filepath)
b, err := ioutil.ReadAll(file)
if err != nil {
_ = file.Close()
return err
}
if err := file.Close(); err != nil {
return err
}
lines := strings.Split(string(b), "\n")

View File

@@ -1,31 +1,68 @@
package openvpn
import (
"io/ioutil"
"os"
"strings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/golibs/files"
)
// WriteAuthFile writes the OpenVPN auth file to disk with the right permissions.
func (c *configurator) WriteAuthFile(user, password string, uid, gid int) error {
exists, err := c.fileManager.FileExists(string(constants.OpenVPNAuthConf))
if err != nil {
const filepath = string(constants.OpenVPNAuthConf)
file, err := c.os.OpenFile(filepath, os.O_RDONLY, 0)
if err != nil && !os.IsNotExist(err) {
return err
} else if exists {
data, err := c.fileManager.ReadFile(string(constants.OpenVPNAuthConf))
}
if os.IsNotExist(err) {
file, err = c.os.OpenFile(filepath, os.O_WRONLY|os.O_CREATE, 0400)
if err != nil {
return err
}
lines := strings.Split(string(data), "\n")
if len(lines) > 1 && lines[0] == user && lines[1] == password {
return nil
_, err = file.WriteString(user + "\n" + password)
if err != nil {
_ = file.Close()
return err
}
c.logger.Info("username and password changed", constants.OpenVPNAuthConf)
err = file.Chown(uid, gid)
if err != nil {
_ = file.Close()
return err
}
return file.Close()
}
return c.fileManager.WriteLinesToFile(
string(constants.OpenVPNAuthConf),
[]string{user, password},
files.Ownership(uid, gid),
files.Permissions(constants.UserReadPermission))
data, err := ioutil.ReadAll(file)
if err != nil {
_ = file.Close()
return err
}
if err := file.Close(); err != nil {
return err
}
lines := strings.Split(string(data), "\n")
if len(lines) > 1 && lines[0] == user && lines[1] == password {
return nil
}
c.logger.Info("username and password changed in %s", constants.OpenVPNAuthConf)
file, err = c.os.OpenFile(filepath, os.O_TRUNC|os.O_WRONLY, 0400)
if err != nil {
return err
}
_, err = file.WriteString(user + "\n" + password)
if err != nil {
_ = file.Close()
return err
}
err = file.Chown(uid, gid)
if err != nil {
_ = file.Close()
return err
}
return file.Close()
}

View File

@@ -4,17 +4,18 @@ import (
"context"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/provider"
"github.com/qdm12/gluetun/internal/routing"
"github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
)
@@ -43,7 +44,7 @@ type looper struct {
// Other objects
logger, pfLogger logging.Logger
client *http.Client
fileManager files.FileManager
openFile os.OpenFileFunc
streamMerger command.StreamMerger
cancel context.CancelFunc
// Internal channels and locks
@@ -57,7 +58,7 @@ type looper struct {
func NewLooper(settings settings.OpenVPN,
username string, uid, gid int, allServers models.AllServers,
conf Configurator, fw firewall.Configurator, routing routing.Routing,
logger logging.Logger, client *http.Client, fileManager files.FileManager,
logger logging.Logger, client *http.Client, openFile os.OpenFileFunc,
streamMerger command.StreamMerger, cancel context.CancelFunc) Looper {
return &looper{
state: state{
@@ -74,7 +75,7 @@ func NewLooper(settings settings.OpenVPN,
logger: logger.WithPrefix("openvpn: "),
pfLogger: logger.WithPrefix("port forwarding: "),
client: client,
fileManager: fileManager,
openFile: openFile,
streamMerger: streamMerger,
cancel: cancel,
start: make(chan struct{}),
@@ -115,8 +116,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
settings.Auth,
settings.Provider.ExtraConfigOptions,
)
if err := l.fileManager.WriteLinesToFile(string(constants.OpenVPNConf), lines,
files.Ownership(l.uid, l.gid), files.Permissions(constants.UserReadPermission)); err != nil {
if err := writeOpenvpnConf(lines, l.openFile); err != nil {
l.logger.Error(err)
l.cancel()
return
@@ -239,6 +240,22 @@ func (l *looper) portForward(ctx context.Context, wg *sync.WaitGroup,
return settings.Provider.PortForwarding.Filepath
}
providerConf.PortForward(ctx,
client, l.fileManager, l.pfLogger,
client, l.openFile, l.pfLogger,
gateway, l.fw, syncState)
}
func writeOpenvpnConf(lines []string, openFile os.OpenFileFunc) error {
const filepath = string(constants.OpenVPNConf)
file, err := openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil {
return err
}
_, err = file.WriteString(strings.Join(lines, "\n"))
if err != nil {
return err
}
if err := file.Close(); err != nil {
return err
}
return nil
}

View File

@@ -3,10 +3,9 @@ package openvpn
import (
"context"
"io"
"os"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/command"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"golang.org/x/sys/unix"
)
@@ -20,21 +19,19 @@ type Configurator interface {
}
type configurator struct {
fileManager files.FileManager
logger logging.Logger
commander command.Commander
openFile func(name string, flag int, perm os.FileMode) (*os.File, error)
mkDev func(major uint32, minor uint32) uint64
mkNod func(path string, mode uint32, dev int) error
logger logging.Logger
commander command.Commander
os os.OS
mkDev func(major uint32, minor uint32) uint64
mkNod func(path string, mode uint32, dev int) error
}
func NewConfigurator(logger logging.Logger, fileManager files.FileManager) Configurator {
func NewConfigurator(logger logging.Logger, os os.OS) Configurator {
return &configurator{
fileManager: fileManager,
logger: logger.WithPrefix("openvpn configurator: "),
commander: command.NewCommander(),
openFile: os.OpenFile,
mkDev: unix.Mkdev,
mkNod: unix.Mknod,
logger: logger.WithPrefix("openvpn configurator: "),
commander: command.NewCommander(),
os: os,
mkDev: unix.Mkdev,
mkNod: unix.Mknod,
}
}

View File

@@ -11,7 +11,7 @@ import (
// CheckTUN checks the tunnel device is present and accessible.
func (c *configurator) CheckTUN() error {
c.logger.Info("checking for device %s", constants.TunnelDevice)
f, err := c.openFile(string(constants.TunnelDevice), os.O_RDWR, 0)
f, err := c.os.OpenFile(string(constants.TunnelDevice), os.O_RDWR, 0)
if err != nil {
return fmt.Errorf("TUN device is not available: %w", err)
}
@@ -23,9 +23,10 @@ func (c *configurator) CheckTUN() error {
func (c *configurator) CreateTUN() error {
c.logger.Info("creating %s", constants.TunnelDevice)
if err := c.fileManager.CreateDir("/dev/net"); err != nil {
if err := c.os.MkdirAll("/dev/net", 0751); err != nil {
return err
}
const (
major = 10
minor = 200
@@ -34,8 +35,17 @@ func (c *configurator) CreateTUN() error {
if err := c.mkNod(string(constants.TunnelDevice), unix.S_IFCHR, int(dev)); err != nil {
return err
}
if err := c.fileManager.SetUserPermissions(string(constants.TunnelDevice), 0666); err != nil {
const filepath = string(constants.TunnelDevice)
file, err := c.os.OpenFile(filepath, os.O_WRONLY, 0666)
if err != nil {
return err
}
return nil
const readWriteAllPerms os.FileMode = 0666
if err := file.Chmod(readWriteAllPerms); err != nil {
_ = file.Close()
return err
}
return file.Close()
}

9
internal/os/alias.go Normal file
View File

@@ -0,0 +1,9 @@
package os
import nativeos "os"
// Aliases used for convenience so "os" does not have to be imported
type FileMode nativeos.FileMode
var IsNotExist = nativeos.IsNotExist

16
internal/os/constants.go Normal file
View File

@@ -0,0 +1,16 @@
package os
import (
nativeos "os"
)
// Constants used for convenience so "os" does not have to be imported
//nolint:golint
const (
O_CREATE = nativeos.O_CREATE
O_TRUNC = nativeos.O_TRUNC
O_WRONLY = nativeos.O_WRONLY
O_RDONLY = nativeos.O_RDONLY
O_RDWR = nativeos.O_RDWR
)

15
internal/os/file.go Normal file
View File

@@ -0,0 +1,15 @@
package os
import (
"io"
nativeos "os"
)
//go:generate mockgen -destination=mock_$GOPACKAGE/$GOFILE . File
type File interface {
io.ReadWriteCloser
WriteString(s string) (int, error)
Chown(uid, gid int) error
Chmod(mode nativeos.FileMode) error
}

10
internal/os/funcs.go Normal file
View File

@@ -0,0 +1,10 @@
package os
import (
nativeos "os"
)
type OpenFileFunc func(name string, flag int, perm FileMode) (File, error)
type MkdirAllFunc func(name string, perm nativeos.FileMode) error
type RemoveFunc func(name string) error
type ChownFunc func(name string, uid int, gid int) error

121
internal/os/mock_os/file.go Normal file
View File

@@ -0,0 +1,121 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/os (interfaces: File)
// Package mock_os is a generated GoMock package.
package mock_os
import (
gomock "github.com/golang/mock/gomock"
os "os"
reflect "reflect"
)
// MockFile is a mock of File interface
type MockFile struct {
ctrl *gomock.Controller
recorder *MockFileMockRecorder
}
// MockFileMockRecorder is the mock recorder for MockFile
type MockFileMockRecorder struct {
mock *MockFile
}
// NewMockFile creates a new mock instance
func NewMockFile(ctrl *gomock.Controller) *MockFile {
mock := &MockFile{ctrl: ctrl}
mock.recorder = &MockFileMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockFile) EXPECT() *MockFileMockRecorder {
return m.recorder
}
// Chmod mocks base method
func (m *MockFile) Chmod(arg0 os.FileMode) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Chmod", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Chmod indicates an expected call of Chmod
func (mr *MockFileMockRecorder) Chmod(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Chmod", reflect.TypeOf((*MockFile)(nil).Chmod), arg0)
}
// Chown mocks base method
func (m *MockFile) Chown(arg0, arg1 int) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Chown", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// Chown indicates an expected call of Chown
func (mr *MockFileMockRecorder) Chown(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Chown", reflect.TypeOf((*MockFile)(nil).Chown), arg0, arg1)
}
// Close mocks base method
func (m *MockFile) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close
func (mr *MockFileMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockFile)(nil).Close))
}
// Read mocks base method
func (m *MockFile) Read(arg0 []byte) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Read", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Read indicates an expected call of Read
func (mr *MockFileMockRecorder) Read(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockFile)(nil).Read), arg0)
}
// Write mocks base method
func (m *MockFile) Write(arg0 []byte) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Write", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Write indicates an expected call of Write
func (mr *MockFileMockRecorder) Write(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockFile)(nil).Write), arg0)
}
// WriteString mocks base method
func (m *MockFile) WriteString(arg0 string) (int, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "WriteString", arg0)
ret0, _ := ret[0].(int)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// WriteString indicates an expected call of WriteString
func (mr *MockFileMockRecorder) WriteString(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteString", reflect.TypeOf((*MockFile)(nil).WriteString), arg0)
}

121
internal/os/mock_os/os.go Normal file
View File

@@ -0,0 +1,121 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/qdm12/gluetun/internal/os (interfaces: OS)
// Package mock_os is a generated GoMock package.
package mock_os
import (
gomock "github.com/golang/mock/gomock"
os "github.com/qdm12/gluetun/internal/os"
os0 "os"
reflect "reflect"
)
// MockOS is a mock of OS interface
type MockOS struct {
ctrl *gomock.Controller
recorder *MockOSMockRecorder
}
// MockOSMockRecorder is the mock recorder for MockOS
type MockOSMockRecorder struct {
mock *MockOS
}
// NewMockOS creates a new mock instance
func NewMockOS(ctrl *gomock.Controller) *MockOS {
mock := &MockOS{ctrl: ctrl}
mock.recorder = &MockOSMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockOS) EXPECT() *MockOSMockRecorder {
return m.recorder
}
// Chown mocks base method
func (m *MockOS) Chown(arg0 string, arg1, arg2 int) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Chown", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// Chown indicates an expected call of Chown
func (mr *MockOSMockRecorder) Chown(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Chown", reflect.TypeOf((*MockOS)(nil).Chown), arg0, arg1, arg2)
}
// MkdirAll mocks base method
func (m *MockOS) MkdirAll(arg0 string, arg1 os.FileMode) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MkdirAll", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// MkdirAll indicates an expected call of MkdirAll
func (mr *MockOSMockRecorder) MkdirAll(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MkdirAll", reflect.TypeOf((*MockOS)(nil).MkdirAll), arg0, arg1)
}
// OpenFile mocks base method
func (m *MockOS) OpenFile(arg0 string, arg1 int, arg2 os.FileMode) (os.File, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OpenFile", arg0, arg1, arg2)
ret0, _ := ret[0].(os.File)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OpenFile indicates an expected call of OpenFile
func (mr *MockOSMockRecorder) OpenFile(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenFile", reflect.TypeOf((*MockOS)(nil).OpenFile), arg0, arg1, arg2)
}
// Remove mocks base method
func (m *MockOS) Remove(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Remove", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Remove indicates an expected call of Remove
func (mr *MockOSMockRecorder) Remove(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockOS)(nil).Remove), arg0)
}
// Stat mocks base method
func (m *MockOS) Stat(arg0 string) (os0.FileInfo, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Stat", arg0)
ret0, _ := ret[0].(os0.FileInfo)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Stat indicates an expected call of Stat
func (mr *MockOSMockRecorder) Stat(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stat", reflect.TypeOf((*MockOS)(nil).Stat), arg0)
}
// Unsetenv mocks base method
func (m *MockOS) Unsetenv(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Unsetenv", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Unsetenv indicates an expected call of Unsetenv
func (mr *MockOSMockRecorder) Unsetenv(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unsetenv", reflect.TypeOf((*MockOS)(nil).Unsetenv), arg0)
}

39
internal/os/os.go Normal file
View File

@@ -0,0 +1,39 @@
package os
import nativeos "os"
//go:generate mockgen -destination=mock_$GOPACKAGE/$GOFILE . OS
type OS interface {
OpenFile(name string, flag int, perm FileMode) (File, error)
MkdirAll(name string, perm FileMode) error
Remove(name string) error
Chown(name string, uid int, gid int) error
Unsetenv(key string) error
Stat(name string) (nativeos.FileInfo, error)
}
func New() OS {
return &os{}
}
type os struct{}
func (o *os) OpenFile(name string, flag int, perm FileMode) (File, error) {
return nativeos.OpenFile(name, flag, nativeos.FileMode(perm))
}
func (o *os) MkdirAll(name string, perm FileMode) error {
return nativeos.MkdirAll(name, nativeos.FileMode(perm))
}
func (o *os) Remove(name string) error {
return nativeos.Remove(name)
}
func (o *os) Chown(name string, uid, gid int) error {
return nativeos.Chown(name, uid, gid)
}
func (o *os) Unsetenv(key string) error {
return nativeos.Unsetenv(key)
}
func (o *os) Stat(name string) (nativeos.FileInfo, error) {
return nativeos.Stat(name)
}

View File

@@ -3,6 +3,8 @@ package params
import (
"encoding/pem"
"fmt"
"io/ioutil"
"os"
"strings"
"github.com/qdm12/gluetun/internal/constants"
@@ -32,10 +34,19 @@ func (p *reader) GetCyberghostClientKey() (clientKey string, err error) {
} else if len(clientKey) > 0 {
return clientKey, nil
}
content, err := p.fileManager.ReadFile(string(constants.ClientKey))
const filepath = string(constants.ClientKey)
file, err := p.os.OpenFile(filepath, os.O_RDONLY, 0)
if err != nil {
return "", err
}
content, err := ioutil.ReadAll(file)
if err != nil {
_ = file.Close()
return "", err
}
if err := file.Close(); err != nil {
return "", err
}
return extractClientKey(content)
}
@@ -55,10 +66,19 @@ func extractClientKey(b []byte) (key string, err error) {
// GetCyberghostClientCertificate obtains the client certificate to use for openvpn from the
// file at /gluetun/client.crt.
func (p *reader) GetCyberghostClientCertificate() (clientCertificate string, err error) {
content, err := p.fileManager.ReadFile(string(constants.ClientCertificate))
const filepath = string(constants.ClientCertificate)
file, err := p.os.OpenFile(filepath, os.O_RDONLY, 0)
if err != nil {
return "", err
}
content, err := ioutil.ReadAll(file)
if err != nil {
_ = file.Close()
return "", err
}
if err := file.Close(); err != nil {
return "", err
}
return extractClientCertificate(content)
}

View File

@@ -11,7 +11,7 @@ import (
// GetUser obtains the user to use to connect to the VPN servers.
func (r *reader) GetUser() (s string, err error) {
defer func() {
unsetenvErr := r.unsetEnv("USER")
unsetenvErr := r.os.Unsetenv("USER")
if err == nil {
err = unsetenvErr
}
@@ -22,7 +22,7 @@ func (r *reader) GetUser() (s string, err error) {
// GetPassword obtains the password to use to connect to the VPN servers.
func (r *reader) GetPassword(required bool) (s string, err error) {
defer func() {
unsetenvErr := r.unsetEnv("PASSWORD")
unsetenvErr := r.os.Unsetenv("PASSWORD")
if err == nil {
err = unsetenvErr
}

View File

@@ -2,11 +2,10 @@ package params
import (
"net"
"os"
"time"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging"
libparams "github.com/qdm12/golibs/params"
"github.com/qdm12/golibs/verification"
@@ -128,22 +127,20 @@ type Reader interface {
}
type reader struct {
envParams libparams.EnvParams
logger logging.Logger
verifier verification.Verifier
unsetEnv func(key string) error
fileManager files.FileManager
envParams libparams.EnvParams
logger logging.Logger
verifier verification.Verifier
os os.OS
}
// Newreader returns a paramsReadeer object to read parameters from
// environment variables.
func NewReader(logger logging.Logger, fileManager files.FileManager) Reader {
func NewReader(logger logging.Logger, os os.OS) Reader {
return &reader{
envParams: libparams.NewEnvParams(),
logger: logger,
verifier: verification.NewVerifier(),
unsetEnv: os.Unsetenv,
fileManager: fileManager,
envParams: libparams.NewEnvParams(),
logger: logger,
verifier: verification.NewVerifier(),
os: os,
}
}

View File

@@ -36,7 +36,7 @@ func (r *reader) GetShadowSocksPort() (port uint16, err error) {
// SHADOWSOCKS_PASSWORD.
func (r *reader) GetShadowSocksPassword() (password string, err error) {
defer func() {
unsetErr := r.unsetEnv("SHADOWSOCKS_PASSWORD")
unsetErr := r.os.Unsetenv("SHADOWSOCKS_PASSWORD")
if err == nil {
err = unsetErr
}

View File

@@ -11,7 +11,7 @@ import (
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging"
)
@@ -133,7 +133,7 @@ func (c *cyberghost) BuildConf(connection models.OpenVPNConnection, verbosity in
}
func (c *cyberghost) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for cyberghost")
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging"
)
@@ -128,7 +128,7 @@ func (m *mullvad) BuildConf(connection models.OpenVPNConnection,
}
func (m *mullvad) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for mullvad")
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging"
)
@@ -142,7 +142,7 @@ func (n *nordvpn) BuildConf(connection models.OpenVPNConnection, verbosity int,
}
func (n *nordvpn) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for nordvpn")
}

View File

@@ -19,7 +19,7 @@ import (
"github.com/qdm12/gluetun/internal/firewall"
gluetunLog "github.com/qdm12/gluetun/internal/logging"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging"
)
@@ -183,7 +183,7 @@ func (p *pia) BuildConf(connection models.OpenVPNConnection, verbosity int, user
//nolint:gocognit
func (p *pia) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
if !p.activeServer.PortForward {
pfLogger.Error("The server %s does not support port forwarding", p.activeServer.Region)
@@ -203,7 +203,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
return
}
defer pfLogger.Warn("loop exited")
data, err := readPIAPortForwardData(fileManager)
data, err := readPIAPortForwardData(openFile)
if err != nil {
pfLogger.Error(err)
}
@@ -222,7 +222,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
if !dataFound || expired {
tryUntilSuccessful(ctx, pfLogger, func() error {
data, err = refreshPIAPortForwardData(ctx, client, gateway, fileManager)
data, err = refreshPIAPortForwardData(ctx, client, gateway, openFile)
return err
})
if ctx.Err() != nil {
@@ -240,12 +240,9 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
return
}
filepath := syncState(data.Port)
filepath := string(syncState(data.Port))
pfLogger.Info("Writing port to %s", filepath)
if err := fileManager.WriteToFile(
string(filepath), []byte(fmt.Sprintf("%d", data.Port)),
files.Permissions(constants.AllReadWritePermissions),
); err != nil {
if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil {
pfLogger.Error(err)
}
@@ -281,7 +278,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123))
oldPort := data.Port
for {
data, err = refreshPIAPortForwardData(ctx, client, gateway, fileManager)
data, err = refreshPIAPortForwardData(ctx, client, gateway, openFile)
if err != nil {
pfLogger.Error(err)
continue
@@ -298,10 +295,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
}
filepath := syncState(data.Port)
pfLogger.Info("Writing port to %s", filepath)
if err := fileManager.WriteToFile(
string(filepath), []byte(fmt.Sprintf("%d", data.Port)),
files.Permissions(constants.AllReadWritePermissions),
); err != nil {
if err := writePortForwardedToFile(openFile, string(filepath), data.Port); err != nil {
pfLogger.Error(err)
}
if err := bindPIAPort(ctx, client, gateway, data); err != nil {
@@ -365,8 +359,8 @@ func newPIAHTTPClient(serverName string) (client *http.Client, err error) {
}
func refreshPIAPortForwardData(ctx context.Context, client *http.Client,
gateway net.IP, fileManager files.FileManager) (data piaPortForwardData, err error) {
data.Token, err = fetchPIAToken(ctx, fileManager, client)
gateway net.IP, openFile os.OpenFileFunc) (data piaPortForwardData, err error) {
data.Token, err = fetchPIAToken(ctx, openFile, client)
if err != nil {
return data, fmt.Errorf("cannot obtain token: %w", err)
}
@@ -374,7 +368,7 @@ func refreshPIAPortForwardData(ctx context.Context, client *http.Client,
if err != nil {
return data, fmt.Errorf("cannot obtain port forwarding data: %w", err)
}
if err := writePIAPortForwardData(fileManager, data); err != nil {
if err := writePIAPortForwardData(openFile, data); err != nil {
return data, fmt.Errorf("cannot persist port forwarding information to file: %w", err)
}
return data, nil
@@ -393,34 +387,39 @@ type piaPortForwardData struct {
Expiration time.Time `json:"expires_at"`
}
func readPIAPortForwardData(fileManager files.FileManager) (data piaPortForwardData, err error) {
func readPIAPortForwardData(openFile os.OpenFileFunc) (data piaPortForwardData, err error) {
const filepath = string(constants.PIAPortForward)
exists, err := fileManager.FileExists(filepath)
if err != nil {
return data, err
} else if !exists {
file, err := openFile(filepath, os.O_RDONLY, 0)
if os.IsNotExist(err) {
return data, nil
} else if err != nil {
return data, err
}
b, err := fileManager.ReadFile(filepath)
decoder := json.NewDecoder(file)
err = decoder.Decode(&data)
if err != nil {
_ = file.Close()
return data, err
}
if err := json.Unmarshal(b, &data); err != nil {
return data, err
}
return data, nil
return data, file.Close()
}
func writePIAPortForwardData(fileManager files.FileManager, data piaPortForwardData) (err error) {
b, err := json.Marshal(&data)
if err != nil {
return fmt.Errorf("cannot encode data: %w", err)
}
err = fileManager.WriteToFile(string(constants.PIAPortForward), b)
func writePIAPortForwardData(openFile os.OpenFileFunc, data piaPortForwardData) (err error) {
const filepath = string(constants.PIAPortForward)
file, err := openFile(filepath,
os.O_CREATE|os.O_TRUNC|os.O_WRONLY,
0644)
if err != nil {
return err
}
return nil
encoder := json.NewEncoder(file)
err = encoder.Encode(data)
if err != nil {
_ = file.Close()
return err
}
return file.Close()
}
func unpackPIAPayload(payload string) (port uint16, token string, expiration time.Time, err error) {
@@ -449,8 +448,9 @@ func packPIAPayload(port uint16, token string, expiration time.Time) (payload st
return payload, nil
}
func fetchPIAToken(ctx context.Context, fileManager files.FileManager, client *http.Client) (token string, err error) {
username, password, err := getOpenvpnCredentials(fileManager)
func fetchPIAToken(ctx context.Context, openFile os.OpenFileFunc,
client *http.Client) (token string, err error) {
username, password, err := getOpenvpnCredentials(openFile)
if err != nil {
return "", fmt.Errorf("cannot get Openvpn credentials: %w", err)
}
@@ -489,10 +489,19 @@ func fetchPIAToken(ctx context.Context, fileManager files.FileManager, client *h
return result.Token, nil
}
func getOpenvpnCredentials(fileManager files.FileManager) (username, password string, err error) {
authData, err := fileManager.ReadFile(string(constants.OpenVPNAuthConf))
func getOpenvpnCredentials(openFile os.OpenFileFunc) (username, password string, err error) {
const filepath = string(constants.OpenVPNAuthConf)
file, err := openFile(filepath, os.O_RDONLY, 0)
if err != nil {
return "", "", fmt.Errorf("cannot read openvpn auth file: %w", err)
return "", "", fmt.Errorf("cannot read openvpn auth file: %s", err)
}
authData, err := ioutil.ReadAll(file)
if err != nil {
_ = file.Close()
return "", "", fmt.Errorf("cannot read openvpn auth file: %s", err)
}
if err := file.Close(); err != nil {
return "", "", err
}
lines := strings.Split(string(authData), "\n")
const minLines = 2
@@ -586,3 +595,17 @@ func bindPIAPort(ctx context.Context, client *http.Client, gateway net.IP, data
}
return nil
}
func writePortForwardedToFile(openFile os.OpenFileFunc,
filepath string, port uint16) (err error) {
file, err := openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
if err != nil {
return err
}
_, err = file.Write([]byte(fmt.Sprintf("%d", port)))
if err != nil {
_ = file.Close()
return err
}
return file.Close()
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging"
)
@@ -117,7 +117,7 @@ func (s *privado) BuildConf(connection models.OpenVPNConnection, verbosity int,
}
func (s *privado) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for privado")
}

View File

@@ -8,7 +8,7 @@ import (
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging"
)
@@ -18,7 +18,7 @@ type Provider interface {
BuildConf(connection models.OpenVPNConnection, verbosity int, username string,
root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string)
PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath))
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging"
)
@@ -150,7 +150,7 @@ func (p *purevpn) BuildConf(connection models.OpenVPNConnection, verbosity int,
}
func (p *purevpn) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for purevpn")
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging"
)
@@ -138,7 +138,7 @@ func (s *surfshark) BuildConf(connection models.OpenVPNConnection, verbosity int
}
func (s *surfshark) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for surfshark")
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging"
)
@@ -119,7 +119,7 @@ func (v *vyprvpn) BuildConf(connection models.OpenVPNConnection, verbosity int,
}
func (v *vyprvpn) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for vyprvpn")
}

View File

@@ -11,7 +11,7 @@ import (
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/firewall"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/golibs/files"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging"
)
@@ -132,7 +132,7 @@ func (w *windscribe) BuildConf(connection models.OpenVPNConnection, verbosity in
}
func (w *windscribe) PortForward(ctx context.Context, client *http.Client,
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
syncState func(port uint16) (pfFilepath models.Filepath)) {
panic("port forwarding is not supported for windscribe")
}

27
internal/publicip/fs.go Normal file
View File

@@ -0,0 +1,27 @@
package publicip
import "github.com/qdm12/gluetun/internal/os"
func persistPublicIP(openFile os.OpenFileFunc,
filepath string, content string, uid, gid int) error {
file, err := openFile(
filepath,
os.O_TRUNC|os.O_WRONLY|os.O_CREATE,
0644)
if err != nil {
return err
}
_, err = file.WriteString(content)
if err != nil {
_ = file.Close()
return err
}
if err := file.Chown(uid, gid); err != nil {
_ = file.Close()
return err
}
return file.Close()
}

View File

@@ -8,8 +8,8 @@ import (
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/gluetun/internal/settings"
"github.com/qdm12/golibs/files"
"github.com/qdm12/golibs/logging"
"github.com/qdm12/golibs/network"
)
@@ -27,9 +27,9 @@ type Looper interface {
type looper struct {
state state
// Objects
getter IPGetter
logger logging.Logger
fileManager files.FileManager
getter IPGetter
logger logging.Logger
os os.OS
// Fixed settings
uid int
gid int
@@ -45,8 +45,9 @@ type looper struct {
timeSince func(time.Time) time.Duration
}
func NewLooper(client network.Client, logger logging.Logger, fileManager files.FileManager,
settings settings.PublicIP, uid, gid int) Looper {
func NewLooper(client network.Client, logger logging.Logger,
settings settings.PublicIP, uid, gid int,
os os.OS) Looper {
return &looper{
state: state{
status: constants.Stopped,
@@ -55,7 +56,7 @@ func NewLooper(client network.Client, logger logging.Logger, fileManager files.F
// Objects
getter: NewIPGetter(client),
logger: logger.WithPrefix("ip getter: "),
fileManager: fileManager,
os: os,
uid: uid,
gid: gid,
start: make(chan struct{}),
@@ -125,7 +126,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
close(errorCh)
filepath := l.GetSettings().IPFilepath
l.logger.Info("Removing ip file %s", filepath)
if err := l.fileManager.Remove(string(filepath)); err != nil {
if err := l.os.Remove(string(filepath)); err != nil {
l.logger.Error(err)
}
return
@@ -142,12 +143,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
getCancel()
l.state.setPublicIP(ip)
l.logger.Info("Public IP address is %s", ip)
const userReadWritePermissions = 0600
err := l.fileManager.WriteLinesToFile(
string(l.state.settings.IPFilepath),
[]string{ip.String()},
files.Ownership(l.uid, l.gid),
files.Permissions(userReadWritePermissions))
filepath := string(l.state.settings.IPFilepath)
err := persistPublicIP(l.os.OpenFile, filepath, ip.String(), l.uid, l.gid)
if err != nil {
l.logger.Error(err)
}

View File

@@ -1,10 +1,8 @@
package storage
import (
"io/ioutil"
"os"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/os"
"github.com/qdm12/golibs/logging"
)
@@ -14,17 +12,13 @@ type Storage interface {
}
type storage struct {
osStat func(name string) (os.FileInfo, error)
readFile func(filename string) (data []byte, err error)
writeFile func(filename string, data []byte, perm os.FileMode) error
logger logging.Logger
os os.OS
logger logging.Logger
}
func New(logger logging.Logger) Storage {
func New(logger logging.Logger, os os.OS) Storage {
return &storage{
osStat: os.Stat,
readFile: ioutil.ReadFile,
writeFile: ioutil.WriteFile,
logger: logger.WithPrefix("storage: "),
os: os,
logger: logger.WithPrefix("storage: "),
}
}

View File

@@ -3,10 +3,10 @@ package storage
import (
"encoding/json"
"fmt"
"os"
"reflect"
"github.com/qdm12/gluetun/internal/models"
"github.com/qdm12/gluetun/internal/os"
)
const (
@@ -29,14 +29,18 @@ func (s *storage) SyncServers(hardcodedServers models.AllServers, write bool) (
allServers models.AllServers, err error) {
// Eventually read file
var serversOnFile models.AllServers
_, err = s.osStat(jsonFilepath)
file, err := s.os.OpenFile(jsonFilepath, os.O_RDONLY, 0)
if err != nil && !os.IsNotExist(err) {
return allServers, err
}
if err == nil {
serversOnFile, err = s.readFromFile()
if err != nil {
var serversOnFile models.AllServers
decoder := json.NewDecoder(file)
if err := decoder.Decode(&serversOnFile); err != nil {
_ = file.Close()
return allServers, err
}
} else if !os.IsNotExist(err) {
return allServers, err
return allServers, file.Close()
}
// Merge data from file and hardcoded
@@ -51,24 +55,16 @@ func (s *storage) SyncServers(hardcodedServers models.AllServers, write bool) (
return allServers, s.FlushToFile(allServers)
}
func (s *storage) readFromFile() (servers models.AllServers, err error) {
bytes, err := s.readFile(jsonFilepath)
if err != nil {
return servers, err
}
if err := json.Unmarshal(bytes, &servers); err != nil {
return servers, err
}
return servers, nil
}
func (s *storage) FlushToFile(servers models.AllServers) error {
bytes, err := json.MarshalIndent(servers, "", " ")
file, err := s.os.OpenFile(jsonFilepath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
return fmt.Errorf("cannot write to file: %w", err)
}
if err := s.writeFile(jsonFilepath, bytes, 0644); err != nil {
return err
}
return nil
encoder := json.NewEncoder(file)
encoder.SetIndent("", " ")
if err := encoder.Encode(servers); err != nil {
_ = file.Close()
return fmt.Errorf("cannot write to file: %w", err)
}
return file.Close()
}