Code maintenance: OS package for file system
- OS custom internal package for file system interaction - Remove fileManager external dependency - Closer API to Go's native API on the OS - Create directories at startup - Better testability - Move Unsetenv to os interface
This commit is contained in:
@@ -10,6 +10,10 @@ issues:
|
||||
linters:
|
||||
- dupl
|
||||
- maligned
|
||||
- path: internal/os/alias\.go
|
||||
linters:
|
||||
- gochecknoglobals
|
||||
text: IsNotExist is a global variable
|
||||
|
||||
linters:
|
||||
disable-all: true
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
nativeos "os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
gluetunLogging "github.com/qdm12/gluetun/internal/logging"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/openvpn"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/gluetun/internal/params"
|
||||
"github.com/qdm12/gluetun/internal/publicip"
|
||||
"github.com/qdm12/gluetun/internal/routing"
|
||||
@@ -32,7 +33,6 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/updater"
|
||||
versionpkg "github.com/qdm12/gluetun/internal/version"
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/network"
|
||||
)
|
||||
@@ -50,21 +50,24 @@ func main() {
|
||||
buildInfo.Commit = commit
|
||||
buildInfo.BuildDate = buildDate
|
||||
ctx := context.Background()
|
||||
os.Exit(_main(ctx, os.Args))
|
||||
args := nativeos.Args
|
||||
os := os.New()
|
||||
nativeos.Exit(_main(ctx, args, os))
|
||||
}
|
||||
|
||||
func _main(background context.Context, args []string) int { //nolint:gocognit,gocyclo
|
||||
//nolint:gocognit,gocyclo
|
||||
func _main(background context.Context, args []string, os os.OS) int {
|
||||
if len(args) > 1 { // cli operation
|
||||
var err error
|
||||
switch args[1] {
|
||||
case "healthcheck":
|
||||
err = cli.HealthCheck(background)
|
||||
case "clientkey":
|
||||
err = cli.ClientKey(args[2:])
|
||||
err = cli.ClientKey(args[2:], os.OpenFile)
|
||||
case "openvpnconfig":
|
||||
err = cli.OpenvpnConfig()
|
||||
err = cli.OpenvpnConfig(os)
|
||||
case "update":
|
||||
err = cli.Update(args[2:])
|
||||
err = cli.Update(args[2:], os)
|
||||
default:
|
||||
err = fmt.Errorf("command %q is unknown", args[1])
|
||||
}
|
||||
@@ -82,15 +85,14 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
||||
httpClient := &http.Client{Timeout: clientTimeout}
|
||||
client := network.NewClient(clientTimeout)
|
||||
// Create configurators
|
||||
fileManager := files.NewFileManager()
|
||||
alpineConf := alpine.NewConfigurator(fileManager)
|
||||
ovpnConf := openvpn.NewConfigurator(logger, fileManager)
|
||||
dnsConf := dns.NewConfigurator(logger, client, fileManager)
|
||||
alpineConf := alpine.NewConfigurator(os.OpenFile)
|
||||
ovpnConf := openvpn.NewConfigurator(logger, os)
|
||||
dnsConf := dns.NewConfigurator(logger, client, os.OpenFile)
|
||||
routingConf := routing.NewRouting(logger)
|
||||
firewallConf := firewall.NewConfigurator(logger, routingConf, fileManager)
|
||||
firewallConf := firewall.NewConfigurator(logger, routingConf, os.OpenFile)
|
||||
streamMerger := command.NewStreamMerger()
|
||||
|
||||
paramsReader := params.NewReader(logger, fileManager)
|
||||
paramsReader := params.NewReader(logger, os)
|
||||
fmt.Println(gluetunLogging.Splash(buildInfo))
|
||||
|
||||
printVersions(ctx, logger, map[string]func(ctx context.Context) (string, error){
|
||||
@@ -106,8 +108,17 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
||||
}
|
||||
logger.Info(allSettings.String())
|
||||
|
||||
if err := os.MkdirAll("/tmp/gluetun", 0644); err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
}
|
||||
if err := os.MkdirAll("/gluetun", 0644); err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
}
|
||||
|
||||
// TODO run this in a loop or in openvpn to reload from file without restarting
|
||||
storage := storage.New(logger)
|
||||
storage := storage.New(logger, os)
|
||||
const updateServerFile = true
|
||||
allServers, err := storage.SyncServers(constants.GetAllServers(), updateServerFile)
|
||||
if err != nil {
|
||||
@@ -124,8 +135,8 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
||||
logger.Error(err)
|
||||
return 1
|
||||
}
|
||||
err = fileManager.SetOwnership("/etc/unbound", uid, gid)
|
||||
if err != nil {
|
||||
|
||||
if err := os.Chown("/etc/unbound", uid, gid); err != nil {
|
||||
logger.Error(err)
|
||||
return 1
|
||||
}
|
||||
@@ -219,7 +230,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
||||
go collectStreamLines(ctx, streamMerger, logger, signalTunnelReady)
|
||||
|
||||
openvpnLooper := openvpn.NewLooper(allSettings.OpenVPN, nonRootUsername, uid, gid, allServers,
|
||||
ovpnConf, firewallConf, routingConf, logger, httpClient, fileManager, streamMerger, cancel)
|
||||
ovpnConf, firewallConf, routingConf, logger, httpClient, os.OpenFile, streamMerger, cancel)
|
||||
wg.Add(1)
|
||||
// wait for restartOpenvpn
|
||||
go openvpnLooper.Run(ctx, wg)
|
||||
@@ -236,7 +247,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
||||
go unboundLooper.Run(ctx, wg, signalDNSReady)
|
||||
|
||||
publicIPLooper := publicip.NewLooper(
|
||||
client, logger, fileManager, allSettings.PublicIP, uid, gid)
|
||||
client, logger, allSettings.PublicIP, uid, gid, os)
|
||||
wg.Add(1)
|
||||
go publicIPLooper.Run(ctx, wg)
|
||||
wg.Add(1)
|
||||
@@ -279,11 +290,11 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
||||
// until openvpn is launched
|
||||
_, _ = openvpnLooper.SetStatus(constants.Running) // TODO option to disable with variable
|
||||
|
||||
signalsCh := make(chan os.Signal, 1)
|
||||
signalsCh := make(chan nativeos.Signal, 1)
|
||||
signal.Notify(signalsCh,
|
||||
syscall.SIGINT,
|
||||
syscall.SIGTERM,
|
||||
os.Interrupt,
|
||||
nativeos.Interrupt,
|
||||
)
|
||||
shutdownErrorsCount := 0
|
||||
select {
|
||||
@@ -295,7 +306,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go
|
||||
}
|
||||
if allSettings.OpenVPN.Provider.PortForwarding.Enabled {
|
||||
logger.Info("Clearing forwarded port status file %s", allSettings.OpenVPN.Provider.PortForwarding.Filepath)
|
||||
if err := fileManager.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil {
|
||||
if err := os.Remove(string(allSettings.OpenVPN.Provider.PortForwarding.Filepath)); err != nil {
|
||||
logger.Error(err)
|
||||
shutdownErrorsCount++
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package alpine
|
||||
import (
|
||||
"os/user"
|
||||
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
)
|
||||
|
||||
type Configurator interface {
|
||||
@@ -11,15 +11,15 @@ type Configurator interface {
|
||||
}
|
||||
|
||||
type configurator struct {
|
||||
fileManager files.FileManager
|
||||
lookupUID func(uid string) (*user.User, error)
|
||||
lookupUser func(username string) (*user.User, error)
|
||||
openFile os.OpenFileFunc
|
||||
lookupUID func(uid string) (*user.User, error)
|
||||
lookupUser func(username string) (*user.User, error)
|
||||
}
|
||||
|
||||
func NewConfigurator(fileManager files.FileManager) Configurator {
|
||||
func NewConfigurator(openFile os.OpenFileFunc) Configurator {
|
||||
return &configurator{
|
||||
fileManager: fileManager,
|
||||
lookupUID: user.LookupId,
|
||||
lookupUser: user.Lookup,
|
||||
openFile: openFile,
|
||||
lookupUID: user.LookupId,
|
||||
lookupUser: user.Lookup,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package alpine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
)
|
||||
|
||||
@@ -26,14 +27,15 @@ func (c *configurator) CreateUser(username string, uid int) (createdUsername str
|
||||
return "", fmt.Errorf("cannot create user: user with name %s already exists for ID %s instead of %d",
|
||||
username, u.Uid, uid)
|
||||
}
|
||||
passwd, err := c.fileManager.ReadFile("/etc/passwd")
|
||||
file, err := c.openFile("/etc/passwd", os.O_APPEND|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot create user: %w", err)
|
||||
}
|
||||
passwd = append(passwd, []byte(fmt.Sprintf("%s:x:%d:::/dev/null:/sbin/nologin\n", username, uid))...)
|
||||
|
||||
if err := c.fileManager.WriteToFile("/etc/passwd", passwd); err != nil {
|
||||
return "", fmt.Errorf("cannot create user: %w", err)
|
||||
s := fmt.Sprintf("%s:x:%d:::/dev/null:/sbin/nologin\n", username, uid)
|
||||
_, err = file.WriteString(s)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return "", err
|
||||
}
|
||||
return username, nil
|
||||
return username, file.Close()
|
||||
}
|
||||
|
||||
@@ -4,29 +4,40 @@ import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/healthcheck"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/gluetun/internal/params"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
"github.com/qdm12/gluetun/internal/storage"
|
||||
"github.com/qdm12/gluetun/internal/updater"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
func ClientKey(args []string) error {
|
||||
func ClientKey(args []string, openFile os.OpenFileFunc) error {
|
||||
flagSet := flag.NewFlagSet("clientkey", flag.ExitOnError)
|
||||
filepath := flagSet.String("path", string(constants.ClientKey), "file path to the client.key file")
|
||||
if err := flagSet.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
fileManager := files.NewFileManager()
|
||||
data, err := fileManager.ReadFile(*filepath)
|
||||
file, err := openFile(*filepath, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := ioutil.ReadAll(file)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -49,17 +60,17 @@ func HealthCheck(ctx context.Context) error {
|
||||
return healthchecker.Check(ctx, url)
|
||||
}
|
||||
|
||||
func OpenvpnConfig() error {
|
||||
func OpenvpnConfig(os os.OS) error {
|
||||
logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
paramsReader := params.NewReader(logger, files.NewFileManager())
|
||||
paramsReader := params.NewReader(logger, os)
|
||||
allSettings, err := settings.GetAllSettings(paramsReader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
allServers, err := storage.New(logger).SyncServers(constants.GetAllServers(), false)
|
||||
allServers, err := storage.New(logger, os).SyncServers(constants.GetAllServers(), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -81,7 +92,7 @@ func OpenvpnConfig() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func Update(args []string) error {
|
||||
func Update(args []string, os os.OS) error {
|
||||
options := settings.Updater{CLI: true}
|
||||
var flushToFile bool
|
||||
flagSet := flag.NewFlagSet("update", flag.ExitOnError)
|
||||
@@ -110,7 +121,7 @@ func Update(args []string) error {
|
||||
ctx := context.Background()
|
||||
const clientTimeout = 10 * time.Second
|
||||
httpClient := &http.Client{Timeout: clientTimeout}
|
||||
storage := storage.New(logger)
|
||||
storage := storage.New(logger, os)
|
||||
const writeSync = false
|
||||
currentServers, err := storage.SyncServers(constants.GetAllServers(), writeSync)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
package constants
|
||||
|
||||
import "os"
|
||||
|
||||
const (
|
||||
UserReadPermission os.FileMode = 0400
|
||||
AllReadWritePermissions os.FileMode = 0666
|
||||
)
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/network"
|
||||
)
|
||||
@@ -21,11 +21,29 @@ func (c *configurator) MakeUnboundConf(ctx context.Context, settings settings.DN
|
||||
for _, warning := range warnings {
|
||||
c.logger.Warn(warning)
|
||||
}
|
||||
return c.fileManager.WriteLinesToFile(
|
||||
string(constants.UnboundConf),
|
||||
lines,
|
||||
files.Ownership(uid, gid),
|
||||
files.Permissions(constants.UserReadPermission))
|
||||
|
||||
const filepath = string(constants.UnboundConf)
|
||||
file, err := c.openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0400)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = file.WriteString(strings.Join(lines, "\n"))
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := file.Chown(uid, gid); err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MakeUnboundConf generates an Unbound configuration from the user provided settings.
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/network"
|
||||
)
|
||||
@@ -24,19 +24,20 @@ type Configurator interface {
|
||||
}
|
||||
|
||||
type configurator struct {
|
||||
logger logging.Logger
|
||||
client network.Client
|
||||
fileManager files.FileManager
|
||||
commander command.Commander
|
||||
lookupIP func(host string) ([]net.IP, error)
|
||||
logger logging.Logger
|
||||
client network.Client
|
||||
openFile os.OpenFileFunc
|
||||
commander command.Commander
|
||||
lookupIP func(host string) ([]net.IP, error)
|
||||
}
|
||||
|
||||
func NewConfigurator(logger logging.Logger, client network.Client, fileManager files.FileManager) Configurator {
|
||||
func NewConfigurator(logger logging.Logger, client network.Client,
|
||||
openFile os.OpenFileFunc) Configurator {
|
||||
return &configurator{
|
||||
logger: logger.WithPrefix("dns configurator: "),
|
||||
client: client,
|
||||
fileManager: fileManager,
|
||||
commander: command.NewCommander(),
|
||||
lookupIP: net.LookupIP,
|
||||
logger: logger.WithPrefix("dns configurator: "),
|
||||
client: client,
|
||||
openFile: openFile,
|
||||
commander: command.NewCommander(),
|
||||
lookupIP: net.LookupIP,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,10 +2,12 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
)
|
||||
|
||||
// UseDNSInternally is to change the Go program DNS only.
|
||||
@@ -23,10 +25,16 @@ func (c *configurator) UseDNSInternally(ip net.IP) {
|
||||
// UseDNSSystemWide changes the nameserver to use for DNS system wide.
|
||||
func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
|
||||
c.logger.Info("using DNS address %s system wide", ip.String())
|
||||
data, err := c.fileManager.ReadFile(string(constants.ResolvConf))
|
||||
const filepath = string(constants.ResolvConf)
|
||||
file, err := c.openFile(filepath, os.O_RDWR, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := ioutil.ReadAll(file)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
s := strings.TrimSuffix(string(data), "\n")
|
||||
lines := strings.Split(s, "\n")
|
||||
if len(lines) == 1 && lines[0] == "" {
|
||||
@@ -44,6 +52,11 @@ func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
|
||||
if !found {
|
||||
lines = append(lines, "nameserver "+ip.String())
|
||||
}
|
||||
data = []byte(strings.Join(lines, "\n"))
|
||||
return c.fileManager.WriteToFile(string(constants.ResolvConf), data)
|
||||
s = strings.Join(lines, "\n")
|
||||
_, err = file.WriteString(s)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
@@ -2,12 +2,14 @@ package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/golibs/files/mock_files"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/gluetun/internal/os/mock_os"
|
||||
"github.com/qdm12/golibs/logging/mock_logging"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -17,30 +19,36 @@ func Test_UseDNSSystemWide(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
data []byte
|
||||
writtenData []byte
|
||||
writtenData string
|
||||
openErr error
|
||||
readErr error
|
||||
writeErr error
|
||||
closeErr error
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
writtenData: []byte("nameserver 127.0.0.1"),
|
||||
writtenData: "nameserver 127.0.0.1",
|
||||
},
|
||||
"open error": {
|
||||
openErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"read error": {
|
||||
readErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"write error": {
|
||||
writtenData: []byte("nameserver 127.0.0.1"),
|
||||
writtenData: "nameserver 127.0.0.1",
|
||||
writeErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"lines without nameserver": {
|
||||
data: []byte("abc\ndef\n"),
|
||||
writtenData: []byte("abc\ndef\nnameserver 127.0.0.1"),
|
||||
writtenData: "abc\ndef\nnameserver 127.0.0.1",
|
||||
},
|
||||
"lines with nameserver": {
|
||||
data: []byte("abc\nnameserver abc def\ndef\n"),
|
||||
writtenData: []byte("abc\nnameserver 127.0.0.1\ndef"),
|
||||
writtenData: "abc\nnameserver 127.0.0.1\ndef",
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
@@ -49,18 +57,43 @@ func Test_UseDNSSystemWide(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
fileManager := mock_files.NewMockFileManager(mockCtrl)
|
||||
fileManager.EXPECT().ReadFile(string(constants.ResolvConf)).
|
||||
Return(tc.data, tc.readErr)
|
||||
if tc.readErr == nil {
|
||||
fileManager.EXPECT().WriteToFile(string(constants.ResolvConf), tc.writtenData).
|
||||
Return(tc.writeErr)
|
||||
|
||||
file := mock_os.NewMockFile(mockCtrl)
|
||||
if tc.openErr == nil {
|
||||
firstReadCall := file.EXPECT().
|
||||
Read(gomock.AssignableToTypeOf([]byte{})).
|
||||
DoAndReturn(func(b []byte) (int, error) {
|
||||
copy(b, tc.data)
|
||||
return len(tc.data), nil
|
||||
})
|
||||
readErr := tc.readErr
|
||||
if readErr == nil {
|
||||
readErr = io.EOF
|
||||
}
|
||||
finalReadCall := file.EXPECT().
|
||||
Read(gomock.AssignableToTypeOf([]byte{})).
|
||||
Return(0, readErr).After(firstReadCall)
|
||||
if tc.readErr == nil {
|
||||
writeCall := file.EXPECT().WriteString(tc.writtenData).
|
||||
Return(0, tc.writeErr).After(finalReadCall)
|
||||
file.EXPECT().Close().Return(tc.closeErr).After(writeCall)
|
||||
} else {
|
||||
file.EXPECT().Close().Return(tc.closeErr).After(finalReadCall)
|
||||
}
|
||||
}
|
||||
|
||||
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
|
||||
assert.Equal(t, string(constants.ResolvConf), name)
|
||||
assert.Equal(t, os.O_RDWR, flag)
|
||||
assert.Equal(t, os.FileMode(0644), perm)
|
||||
return file, tc.openErr
|
||||
}
|
||||
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("using DNS address %s system wide", "127.0.0.1")
|
||||
c := &configurator{
|
||||
fileManager: fileManager,
|
||||
logger: logger,
|
||||
openFile: openFile,
|
||||
logger: logger,
|
||||
}
|
||||
err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1}, false)
|
||||
if tc.err != nil {
|
||||
|
||||
@@ -4,37 +4,46 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/golibs/files"
|
||||
)
|
||||
|
||||
func (c *configurator) DownloadRootHints(ctx context.Context, uid, gid int) error {
|
||||
c.logger.Info("downloading root hints from %s", constants.NamedRootURL)
|
||||
content, status, err := c.client.Get(ctx, string(constants.NamedRootURL))
|
||||
if err != nil {
|
||||
return err
|
||||
} else if status != http.StatusOK {
|
||||
return fmt.Errorf("HTTP status code is %d for %s", status, constants.NamedRootURL)
|
||||
}
|
||||
return c.fileManager.WriteToFile(
|
||||
string(constants.RootHints),
|
||||
content,
|
||||
files.Ownership(uid, gid),
|
||||
files.Permissions(constants.UserReadPermission))
|
||||
return c.downloadAndSave(ctx, "root hints",
|
||||
string(constants.NamedRootURL), string(constants.RootHints), uid, gid)
|
||||
}
|
||||
|
||||
func (c *configurator) DownloadRootKey(ctx context.Context, uid, gid int) error {
|
||||
c.logger.Info("downloading root key from %s", constants.RootKeyURL)
|
||||
content, status, err := c.client.Get(ctx, string(constants.RootKeyURL))
|
||||
return c.downloadAndSave(ctx, "root key",
|
||||
string(constants.RootKeyURL), string(constants.RootKey), uid, gid)
|
||||
}
|
||||
|
||||
func (c *configurator) downloadAndSave(ctx context.Context, logName, url, filepath string, uid, gid int) error {
|
||||
c.logger.Info("downloading %s from %s", logName, url)
|
||||
content, status, err := c.client.Get(ctx, url)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if status != http.StatusOK {
|
||||
return fmt.Errorf("HTTP status code is %d for %s", status, constants.RootKeyURL)
|
||||
return fmt.Errorf("HTTP status code is %d for %s", status, url)
|
||||
}
|
||||
return c.fileManager.WriteToFile(
|
||||
string(constants.RootKey),
|
||||
content,
|
||||
files.Ownership(uid, gid),
|
||||
files.Permissions(constants.UserReadPermission))
|
||||
|
||||
file, err := c.openFile(filepath, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0400)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = file.Write(content)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
err = file.Chown(uid, gid)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
@@ -2,27 +2,31 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/files/mock_files"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/gluetun/internal/os/mock_os"
|
||||
"github.com/qdm12/golibs/logging/mock_logging"
|
||||
"github.com/qdm12/golibs/network/mock_network"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
|
||||
func Test_downloadAndSave(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
content []byte
|
||||
status int
|
||||
clientErr error
|
||||
openErr error
|
||||
writeErr error
|
||||
chownErr error
|
||||
closeErr error
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
@@ -36,11 +40,26 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
|
||||
clientErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"open error": {
|
||||
status: http.StatusOK,
|
||||
openErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"write error": {
|
||||
status: http.StatusOK,
|
||||
writeErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"chown error": {
|
||||
status: http.StatusOK,
|
||||
chownErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"close error": {
|
||||
status: http.StatusOK,
|
||||
closeErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"data": {
|
||||
content: []byte("content"),
|
||||
status: http.StatusOK,
|
||||
@@ -52,23 +71,49 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
|
||||
ctx := context.Background()
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("downloading root hints from %s", constants.NamedRootURL)
|
||||
logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL))
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
client.EXPECT().Get(ctx, string(constants.NamedRootURL)).
|
||||
Return(tc.content, tc.status, tc.clientErr)
|
||||
fileManager := mock_files.NewMockFileManager(mockCtrl)
|
||||
if tc.clientErr == nil && tc.status == http.StatusOK {
|
||||
fileManager.EXPECT().WriteToFile(
|
||||
string(constants.RootHints),
|
||||
tc.content,
|
||||
gomock.AssignableToTypeOf(files.Ownership(0, 0)),
|
||||
gomock.AssignableToTypeOf(files.Ownership(0, 0))).
|
||||
Return(tc.writeErr)
|
||||
|
||||
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
|
||||
return nil, nil
|
||||
}
|
||||
c := &configurator{logger: logger, client: client, fileManager: fileManager}
|
||||
err := c.DownloadRootHints(ctx, 1000, 1000)
|
||||
|
||||
if tc.clientErr == nil && tc.status == http.StatusOK {
|
||||
file := mock_os.NewMockFile(mockCtrl)
|
||||
if tc.openErr == nil {
|
||||
writeCall := file.EXPECT().Write(tc.content).
|
||||
Return(0, tc.writeErr)
|
||||
if tc.writeErr != nil {
|
||||
file.EXPECT().Close().Return(tc.closeErr).After(writeCall)
|
||||
} else {
|
||||
chownCall := file.EXPECT().Chown(1000, 1000).Return(tc.chownErr).After(writeCall)
|
||||
file.EXPECT().Close().Return(tc.closeErr).After(chownCall)
|
||||
}
|
||||
}
|
||||
|
||||
openFile = func(name string, flag int, perm os.FileMode) (os.File, error) {
|
||||
assert.Equal(t, string(constants.RootHints), name)
|
||||
assert.Equal(t, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, flag)
|
||||
assert.Equal(t, os.FileMode(0400), perm)
|
||||
return file, tc.openErr
|
||||
}
|
||||
}
|
||||
|
||||
c := &configurator{
|
||||
logger: logger,
|
||||
client: client,
|
||||
openFile: openFile,
|
||||
}
|
||||
|
||||
err := c.downloadAndSave(ctx, "root hints",
|
||||
string(constants.NamedRootURL), string(constants.RootHints),
|
||||
1000, 1000)
|
||||
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
@@ -79,65 +124,44 @@ func Test_DownloadRootHints(t *testing.T) { //nolint:dupl
|
||||
}
|
||||
}
|
||||
|
||||
func Test_DownloadRootKey(t *testing.T) { //nolint:dupl
|
||||
func Test_DownloadRootHints(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := map[string]struct {
|
||||
content []byte
|
||||
status int
|
||||
clientErr error
|
||||
writeErr error
|
||||
err error
|
||||
}{
|
||||
"no data": {
|
||||
status: http.StatusOK,
|
||||
},
|
||||
"bad status": {
|
||||
status: http.StatusBadRequest,
|
||||
err: fmt.Errorf("HTTP status code is 400 for https://raw.githubusercontent.com/qdm12/files/master/root.key.updated"), //nolint:lll
|
||||
},
|
||||
"client error": {
|
||||
clientErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"write error": {
|
||||
status: http.StatusOK,
|
||||
writeErr: fmt.Errorf("error"),
|
||||
err: fmt.Errorf("error"),
|
||||
},
|
||||
"data": {
|
||||
content: []byte("content"),
|
||||
status: http.StatusOK,
|
||||
},
|
||||
}
|
||||
for name, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
defer mockCtrl.Finish()
|
||||
ctx := context.Background()
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("downloading root key from %s", constants.RootKeyURL)
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
client.EXPECT().Get(ctx, string(constants.RootKeyURL)).
|
||||
Return(tc.content, tc.status, tc.clientErr)
|
||||
fileManager := mock_files.NewMockFileManager(mockCtrl)
|
||||
if tc.clientErr == nil && tc.status == http.StatusOK {
|
||||
fileManager.EXPECT().WriteToFile(
|
||||
string(constants.RootKey),
|
||||
tc.content,
|
||||
gomock.AssignableToTypeOf(files.Ownership(0, 0)),
|
||||
gomock.AssignableToTypeOf(files.Ownership(0, 0)),
|
||||
).Return(tc.writeErr)
|
||||
}
|
||||
c := &configurator{logger: logger, client: client, fileManager: fileManager}
|
||||
err := c.DownloadRootKey(ctx, 1000, 1001)
|
||||
if tc.err != nil {
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tc.err.Error(), err.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
ctx := context.Background()
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("downloading %s from %s", "root hints", string(constants.NamedRootURL))
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
client.EXPECT().Get(ctx, string(constants.NamedRootURL)).
|
||||
Return(nil, http.StatusOK, errors.New("test"))
|
||||
|
||||
c := &configurator{
|
||||
logger: logger,
|
||||
client: client,
|
||||
}
|
||||
|
||||
err := c.DownloadRootHints(ctx, 1000, 1000)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, "test", err.Error())
|
||||
}
|
||||
|
||||
func Test_DownloadRootKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
mockCtrl := gomock.NewController(t)
|
||||
|
||||
ctx := context.Background()
|
||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
||||
logger.EXPECT().Info("downloading %s from %s", "root key", string(constants.RootKeyURL))
|
||||
client := mock_network.NewMockClient(mockCtrl)
|
||||
client.EXPECT().Get(ctx, string(constants.RootKeyURL)).
|
||||
Return(nil, http.StatusOK, errors.New("test"))
|
||||
|
||||
c := &configurator{
|
||||
logger: logger,
|
||||
client: client,
|
||||
}
|
||||
|
||||
err := c.DownloadRootKey(ctx, 1000, 1000)
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, "test", err.Error())
|
||||
}
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/gluetun/internal/routing"
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
@@ -29,7 +29,7 @@ type configurator struct { //nolint:maligned
|
||||
commander command.Commander
|
||||
logger logging.Logger
|
||||
routing routing.Routing
|
||||
fileManager files.FileManager // for custom iptables rules
|
||||
openFile os.OpenFileFunc // for custom iptables rules
|
||||
iptablesMutex sync.Mutex
|
||||
debug bool
|
||||
defaultInterface string
|
||||
@@ -47,12 +47,12 @@ type configurator struct { //nolint:maligned
|
||||
}
|
||||
|
||||
// NewConfigurator creates a new Configurator instance.
|
||||
func NewConfigurator(logger logging.Logger, routing routing.Routing, fileManager files.FileManager) Configurator {
|
||||
func NewConfigurator(logger logging.Logger, routing routing.Routing, openFile os.OpenFileFunc) Configurator {
|
||||
return &configurator{
|
||||
commander: command.NewCommander(),
|
||||
logger: logger.WithPrefix("firewall: "),
|
||||
routing: routing,
|
||||
fileManager: fileManager,
|
||||
openFile: openFile,
|
||||
allowedInputPorts: make(map[uint16]string),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,9 @@ package firewall
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
@@ -150,14 +152,18 @@ func (c *configurator) acceptInputToPort(ctx context.Context, intf string, port
|
||||
}
|
||||
|
||||
func (c *configurator) runUserPostRules(ctx context.Context, filepath string, remove bool) error {
|
||||
exists, err := c.fileManager.FileExists(filepath)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if !exists {
|
||||
file, err := c.openFile(filepath, os.O_RDONLY, 0)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
b, err := c.fileManager.ReadFile(filepath)
|
||||
b, err := ioutil.ReadAll(file)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
lines := strings.Split(string(b), "\n")
|
||||
|
||||
@@ -1,31 +1,68 @@
|
||||
package openvpn
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/golibs/files"
|
||||
)
|
||||
|
||||
// WriteAuthFile writes the OpenVPN auth file to disk with the right permissions.
|
||||
func (c *configurator) WriteAuthFile(user, password string, uid, gid int) error {
|
||||
exists, err := c.fileManager.FileExists(string(constants.OpenVPNAuthConf))
|
||||
if err != nil {
|
||||
const filepath = string(constants.OpenVPNAuthConf)
|
||||
file, err := c.os.OpenFile(filepath, os.O_RDONLY, 0)
|
||||
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
} else if exists {
|
||||
data, err := c.fileManager.ReadFile(string(constants.OpenVPNAuthConf))
|
||||
}
|
||||
|
||||
if os.IsNotExist(err) {
|
||||
file, err = c.os.OpenFile(filepath, os.O_WRONLY|os.O_CREATE, 0400)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lines := strings.Split(string(data), "\n")
|
||||
if len(lines) > 1 && lines[0] == user && lines[1] == password {
|
||||
return nil
|
||||
_, err = file.WriteString(user + "\n" + password)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
c.logger.Info("username and password changed", constants.OpenVPNAuthConf)
|
||||
err = file.Chown(uid, gid)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
return file.Close()
|
||||
}
|
||||
return c.fileManager.WriteLinesToFile(
|
||||
string(constants.OpenVPNAuthConf),
|
||||
[]string{user, password},
|
||||
files.Ownership(uid, gid),
|
||||
files.Permissions(constants.UserReadPermission))
|
||||
|
||||
data, err := ioutil.ReadAll(file)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
if len(lines) > 1 && lines[0] == user && lines[1] == password {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.logger.Info("username and password changed in %s", constants.OpenVPNAuthConf)
|
||||
file, err = c.os.OpenFile(filepath, os.O_TRUNC|os.O_WRONLY, 0400)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = file.WriteString(user + "\n" + password)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
err = file.Chown(uid, gid)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
@@ -4,17 +4,18 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/gluetun/internal/provider"
|
||||
"github.com/qdm12/gluetun/internal/routing"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
@@ -43,7 +44,7 @@ type looper struct {
|
||||
// Other objects
|
||||
logger, pfLogger logging.Logger
|
||||
client *http.Client
|
||||
fileManager files.FileManager
|
||||
openFile os.OpenFileFunc
|
||||
streamMerger command.StreamMerger
|
||||
cancel context.CancelFunc
|
||||
// Internal channels and locks
|
||||
@@ -57,7 +58,7 @@ type looper struct {
|
||||
func NewLooper(settings settings.OpenVPN,
|
||||
username string, uid, gid int, allServers models.AllServers,
|
||||
conf Configurator, fw firewall.Configurator, routing routing.Routing,
|
||||
logger logging.Logger, client *http.Client, fileManager files.FileManager,
|
||||
logger logging.Logger, client *http.Client, openFile os.OpenFileFunc,
|
||||
streamMerger command.StreamMerger, cancel context.CancelFunc) Looper {
|
||||
return &looper{
|
||||
state: state{
|
||||
@@ -74,7 +75,7 @@ func NewLooper(settings settings.OpenVPN,
|
||||
logger: logger.WithPrefix("openvpn: "),
|
||||
pfLogger: logger.WithPrefix("port forwarding: "),
|
||||
client: client,
|
||||
fileManager: fileManager,
|
||||
openFile: openFile,
|
||||
streamMerger: streamMerger,
|
||||
cancel: cancel,
|
||||
start: make(chan struct{}),
|
||||
@@ -115,8 +116,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||
settings.Auth,
|
||||
settings.Provider.ExtraConfigOptions,
|
||||
)
|
||||
if err := l.fileManager.WriteLinesToFile(string(constants.OpenVPNConf), lines,
|
||||
files.Ownership(l.uid, l.gid), files.Permissions(constants.UserReadPermission)); err != nil {
|
||||
|
||||
if err := writeOpenvpnConf(lines, l.openFile); err != nil {
|
||||
l.logger.Error(err)
|
||||
l.cancel()
|
||||
return
|
||||
@@ -239,6 +240,22 @@ func (l *looper) portForward(ctx context.Context, wg *sync.WaitGroup,
|
||||
return settings.Provider.PortForwarding.Filepath
|
||||
}
|
||||
providerConf.PortForward(ctx,
|
||||
client, l.fileManager, l.pfLogger,
|
||||
client, l.openFile, l.pfLogger,
|
||||
gateway, l.fw, syncState)
|
||||
}
|
||||
|
||||
func writeOpenvpnConf(lines []string, openFile os.OpenFileFunc) error {
|
||||
const filepath = string(constants.OpenVPNConf)
|
||||
file, err := openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = file.WriteString(strings.Join(lines, "\n"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,10 +3,9 @@ package openvpn
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
@@ -20,21 +19,19 @@ type Configurator interface {
|
||||
}
|
||||
|
||||
type configurator struct {
|
||||
fileManager files.FileManager
|
||||
logger logging.Logger
|
||||
commander command.Commander
|
||||
openFile func(name string, flag int, perm os.FileMode) (*os.File, error)
|
||||
mkDev func(major uint32, minor uint32) uint64
|
||||
mkNod func(path string, mode uint32, dev int) error
|
||||
logger logging.Logger
|
||||
commander command.Commander
|
||||
os os.OS
|
||||
mkDev func(major uint32, minor uint32) uint64
|
||||
mkNod func(path string, mode uint32, dev int) error
|
||||
}
|
||||
|
||||
func NewConfigurator(logger logging.Logger, fileManager files.FileManager) Configurator {
|
||||
func NewConfigurator(logger logging.Logger, os os.OS) Configurator {
|
||||
return &configurator{
|
||||
fileManager: fileManager,
|
||||
logger: logger.WithPrefix("openvpn configurator: "),
|
||||
commander: command.NewCommander(),
|
||||
openFile: os.OpenFile,
|
||||
mkDev: unix.Mkdev,
|
||||
mkNod: unix.Mknod,
|
||||
logger: logger.WithPrefix("openvpn configurator: "),
|
||||
commander: command.NewCommander(),
|
||||
os: os,
|
||||
mkDev: unix.Mkdev,
|
||||
mkNod: unix.Mknod,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
// CheckTUN checks the tunnel device is present and accessible.
|
||||
func (c *configurator) CheckTUN() error {
|
||||
c.logger.Info("checking for device %s", constants.TunnelDevice)
|
||||
f, err := c.openFile(string(constants.TunnelDevice), os.O_RDWR, 0)
|
||||
f, err := c.os.OpenFile(string(constants.TunnelDevice), os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("TUN device is not available: %w", err)
|
||||
}
|
||||
@@ -23,9 +23,10 @@ func (c *configurator) CheckTUN() error {
|
||||
|
||||
func (c *configurator) CreateTUN() error {
|
||||
c.logger.Info("creating %s", constants.TunnelDevice)
|
||||
if err := c.fileManager.CreateDir("/dev/net"); err != nil {
|
||||
if err := c.os.MkdirAll("/dev/net", 0751); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
const (
|
||||
major = 10
|
||||
minor = 200
|
||||
@@ -34,8 +35,17 @@ func (c *configurator) CreateTUN() error {
|
||||
if err := c.mkNod(string(constants.TunnelDevice), unix.S_IFCHR, int(dev)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.fileManager.SetUserPermissions(string(constants.TunnelDevice), 0666); err != nil {
|
||||
|
||||
const filepath = string(constants.TunnelDevice)
|
||||
file, err := c.os.OpenFile(filepath, os.O_WRONLY, 0666)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
const readWriteAllPerms os.FileMode = 0666
|
||||
if err := file.Chmod(readWriteAllPerms); err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
9
internal/os/alias.go
Normal file
9
internal/os/alias.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package os
|
||||
|
||||
import nativeos "os"
|
||||
|
||||
// Aliases used for convenience so "os" does not have to be imported
|
||||
|
||||
type FileMode nativeos.FileMode
|
||||
|
||||
var IsNotExist = nativeos.IsNotExist
|
||||
16
internal/os/constants.go
Normal file
16
internal/os/constants.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package os
|
||||
|
||||
import (
|
||||
nativeos "os"
|
||||
)
|
||||
|
||||
// Constants used for convenience so "os" does not have to be imported
|
||||
|
||||
//nolint:golint
|
||||
const (
|
||||
O_CREATE = nativeos.O_CREATE
|
||||
O_TRUNC = nativeos.O_TRUNC
|
||||
O_WRONLY = nativeos.O_WRONLY
|
||||
O_RDONLY = nativeos.O_RDONLY
|
||||
O_RDWR = nativeos.O_RDWR
|
||||
)
|
||||
15
internal/os/file.go
Normal file
15
internal/os/file.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package os
|
||||
|
||||
import (
|
||||
"io"
|
||||
nativeos "os"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=mock_$GOPACKAGE/$GOFILE . File
|
||||
|
||||
type File interface {
|
||||
io.ReadWriteCloser
|
||||
WriteString(s string) (int, error)
|
||||
Chown(uid, gid int) error
|
||||
Chmod(mode nativeos.FileMode) error
|
||||
}
|
||||
10
internal/os/funcs.go
Normal file
10
internal/os/funcs.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package os
|
||||
|
||||
import (
|
||||
nativeos "os"
|
||||
)
|
||||
|
||||
type OpenFileFunc func(name string, flag int, perm FileMode) (File, error)
|
||||
type MkdirAllFunc func(name string, perm nativeos.FileMode) error
|
||||
type RemoveFunc func(name string) error
|
||||
type ChownFunc func(name string, uid int, gid int) error
|
||||
121
internal/os/mock_os/file.go
Normal file
121
internal/os/mock_os/file.go
Normal file
@@ -0,0 +1,121 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/os (interfaces: File)
|
||||
|
||||
// Package mock_os is a generated GoMock package.
|
||||
package mock_os
|
||||
|
||||
import (
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
os "os"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockFile is a mock of File interface
|
||||
type MockFile struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockFileMockRecorder
|
||||
}
|
||||
|
||||
// MockFileMockRecorder is the mock recorder for MockFile
|
||||
type MockFileMockRecorder struct {
|
||||
mock *MockFile
|
||||
}
|
||||
|
||||
// NewMockFile creates a new mock instance
|
||||
func NewMockFile(ctrl *gomock.Controller) *MockFile {
|
||||
mock := &MockFile{ctrl: ctrl}
|
||||
mock.recorder = &MockFileMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockFile) EXPECT() *MockFileMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Chmod mocks base method
|
||||
func (m *MockFile) Chmod(arg0 os.FileMode) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Chmod", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Chmod indicates an expected call of Chmod
|
||||
func (mr *MockFileMockRecorder) Chmod(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Chmod", reflect.TypeOf((*MockFile)(nil).Chmod), arg0)
|
||||
}
|
||||
|
||||
// Chown mocks base method
|
||||
func (m *MockFile) Chown(arg0, arg1 int) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Chown", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Chown indicates an expected call of Chown
|
||||
func (mr *MockFileMockRecorder) Chown(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Chown", reflect.TypeOf((*MockFile)(nil).Chown), arg0, arg1)
|
||||
}
|
||||
|
||||
// Close mocks base method
|
||||
func (m *MockFile) Close() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close
|
||||
func (mr *MockFileMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockFile)(nil).Close))
|
||||
}
|
||||
|
||||
// Read mocks base method
|
||||
func (m *MockFile) Read(arg0 []byte) (int, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Read", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Read indicates an expected call of Read
|
||||
func (mr *MockFileMockRecorder) Read(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockFile)(nil).Read), arg0)
|
||||
}
|
||||
|
||||
// Write mocks base method
|
||||
func (m *MockFile) Write(arg0 []byte) (int, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Write", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Write indicates an expected call of Write
|
||||
func (mr *MockFileMockRecorder) Write(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockFile)(nil).Write), arg0)
|
||||
}
|
||||
|
||||
// WriteString mocks base method
|
||||
func (m *MockFile) WriteString(arg0 string) (int, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "WriteString", arg0)
|
||||
ret0, _ := ret[0].(int)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// WriteString indicates an expected call of WriteString
|
||||
func (mr *MockFileMockRecorder) WriteString(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteString", reflect.TypeOf((*MockFile)(nil).WriteString), arg0)
|
||||
}
|
||||
121
internal/os/mock_os/os.go
Normal file
121
internal/os/mock_os/os.go
Normal file
@@ -0,0 +1,121 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/gluetun/internal/os (interfaces: OS)
|
||||
|
||||
// Package mock_os is a generated GoMock package.
|
||||
package mock_os
|
||||
|
||||
import (
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
os "github.com/qdm12/gluetun/internal/os"
|
||||
os0 "os"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockOS is a mock of OS interface
|
||||
type MockOS struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockOSMockRecorder
|
||||
}
|
||||
|
||||
// MockOSMockRecorder is the mock recorder for MockOS
|
||||
type MockOSMockRecorder struct {
|
||||
mock *MockOS
|
||||
}
|
||||
|
||||
// NewMockOS creates a new mock instance
|
||||
func NewMockOS(ctrl *gomock.Controller) *MockOS {
|
||||
mock := &MockOS{ctrl: ctrl}
|
||||
mock.recorder = &MockOSMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockOS) EXPECT() *MockOSMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Chown mocks base method
|
||||
func (m *MockOS) Chown(arg0 string, arg1, arg2 int) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Chown", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Chown indicates an expected call of Chown
|
||||
func (mr *MockOSMockRecorder) Chown(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Chown", reflect.TypeOf((*MockOS)(nil).Chown), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// MkdirAll mocks base method
|
||||
func (m *MockOS) MkdirAll(arg0 string, arg1 os.FileMode) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MkdirAll", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// MkdirAll indicates an expected call of MkdirAll
|
||||
func (mr *MockOSMockRecorder) MkdirAll(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MkdirAll", reflect.TypeOf((*MockOS)(nil).MkdirAll), arg0, arg1)
|
||||
}
|
||||
|
||||
// OpenFile mocks base method
|
||||
func (m *MockOS) OpenFile(arg0 string, arg1 int, arg2 os.FileMode) (os.File, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OpenFile", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(os.File)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// OpenFile indicates an expected call of OpenFile
|
||||
func (mr *MockOSMockRecorder) OpenFile(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OpenFile", reflect.TypeOf((*MockOS)(nil).OpenFile), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// Remove mocks base method
|
||||
func (m *MockOS) Remove(arg0 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Remove", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Remove indicates an expected call of Remove
|
||||
func (mr *MockOSMockRecorder) Remove(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockOS)(nil).Remove), arg0)
|
||||
}
|
||||
|
||||
// Stat mocks base method
|
||||
func (m *MockOS) Stat(arg0 string) (os0.FileInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Stat", arg0)
|
||||
ret0, _ := ret[0].(os0.FileInfo)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Stat indicates an expected call of Stat
|
||||
func (mr *MockOSMockRecorder) Stat(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stat", reflect.TypeOf((*MockOS)(nil).Stat), arg0)
|
||||
}
|
||||
|
||||
// Unsetenv mocks base method
|
||||
func (m *MockOS) Unsetenv(arg0 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Unsetenv", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Unsetenv indicates an expected call of Unsetenv
|
||||
func (mr *MockOSMockRecorder) Unsetenv(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unsetenv", reflect.TypeOf((*MockOS)(nil).Unsetenv), arg0)
|
||||
}
|
||||
39
internal/os/os.go
Normal file
39
internal/os/os.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package os
|
||||
|
||||
import nativeos "os"
|
||||
|
||||
//go:generate mockgen -destination=mock_$GOPACKAGE/$GOFILE . OS
|
||||
|
||||
type OS interface {
|
||||
OpenFile(name string, flag int, perm FileMode) (File, error)
|
||||
MkdirAll(name string, perm FileMode) error
|
||||
Remove(name string) error
|
||||
Chown(name string, uid int, gid int) error
|
||||
Unsetenv(key string) error
|
||||
Stat(name string) (nativeos.FileInfo, error)
|
||||
}
|
||||
|
||||
func New() OS {
|
||||
return &os{}
|
||||
}
|
||||
|
||||
type os struct{}
|
||||
|
||||
func (o *os) OpenFile(name string, flag int, perm FileMode) (File, error) {
|
||||
return nativeos.OpenFile(name, flag, nativeos.FileMode(perm))
|
||||
}
|
||||
func (o *os) MkdirAll(name string, perm FileMode) error {
|
||||
return nativeos.MkdirAll(name, nativeos.FileMode(perm))
|
||||
}
|
||||
func (o *os) Remove(name string) error {
|
||||
return nativeos.Remove(name)
|
||||
}
|
||||
func (o *os) Chown(name string, uid, gid int) error {
|
||||
return nativeos.Chown(name, uid, gid)
|
||||
}
|
||||
func (o *os) Unsetenv(key string) error {
|
||||
return nativeos.Unsetenv(key)
|
||||
}
|
||||
func (o *os) Stat(name string) (nativeos.FileInfo, error) {
|
||||
return nativeos.Stat(name)
|
||||
}
|
||||
@@ -3,6 +3,8 @@ package params
|
||||
import (
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
@@ -32,10 +34,19 @@ func (p *reader) GetCyberghostClientKey() (clientKey string, err error) {
|
||||
} else if len(clientKey) > 0 {
|
||||
return clientKey, nil
|
||||
}
|
||||
content, err := p.fileManager.ReadFile(string(constants.ClientKey))
|
||||
const filepath = string(constants.ClientKey)
|
||||
file, err := p.os.OpenFile(filepath, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
content, err := ioutil.ReadAll(file)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return "", err
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return extractClientKey(content)
|
||||
}
|
||||
|
||||
@@ -55,10 +66,19 @@ func extractClientKey(b []byte) (key string, err error) {
|
||||
// GetCyberghostClientCertificate obtains the client certificate to use for openvpn from the
|
||||
// file at /gluetun/client.crt.
|
||||
func (p *reader) GetCyberghostClientCertificate() (clientCertificate string, err error) {
|
||||
content, err := p.fileManager.ReadFile(string(constants.ClientCertificate))
|
||||
const filepath = string(constants.ClientCertificate)
|
||||
file, err := p.os.OpenFile(filepath, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
content, err := ioutil.ReadAll(file)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return "", err
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return extractClientCertificate(content)
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
// GetUser obtains the user to use to connect to the VPN servers.
|
||||
func (r *reader) GetUser() (s string, err error) {
|
||||
defer func() {
|
||||
unsetenvErr := r.unsetEnv("USER")
|
||||
unsetenvErr := r.os.Unsetenv("USER")
|
||||
if err == nil {
|
||||
err = unsetenvErr
|
||||
}
|
||||
@@ -22,7 +22,7 @@ func (r *reader) GetUser() (s string, err error) {
|
||||
// GetPassword obtains the password to use to connect to the VPN servers.
|
||||
func (r *reader) GetPassword(required bool) (s string, err error) {
|
||||
defer func() {
|
||||
unsetenvErr := r.unsetEnv("PASSWORD")
|
||||
unsetenvErr := r.os.Unsetenv("PASSWORD")
|
||||
if err == nil {
|
||||
err = unsetenvErr
|
||||
}
|
||||
|
||||
@@ -2,11 +2,10 @@ package params
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
libparams "github.com/qdm12/golibs/params"
|
||||
"github.com/qdm12/golibs/verification"
|
||||
@@ -128,22 +127,20 @@ type Reader interface {
|
||||
}
|
||||
|
||||
type reader struct {
|
||||
envParams libparams.EnvParams
|
||||
logger logging.Logger
|
||||
verifier verification.Verifier
|
||||
unsetEnv func(key string) error
|
||||
fileManager files.FileManager
|
||||
envParams libparams.EnvParams
|
||||
logger logging.Logger
|
||||
verifier verification.Verifier
|
||||
os os.OS
|
||||
}
|
||||
|
||||
// Newreader returns a paramsReadeer object to read parameters from
|
||||
// environment variables.
|
||||
func NewReader(logger logging.Logger, fileManager files.FileManager) Reader {
|
||||
func NewReader(logger logging.Logger, os os.OS) Reader {
|
||||
return &reader{
|
||||
envParams: libparams.NewEnvParams(),
|
||||
logger: logger,
|
||||
verifier: verification.NewVerifier(),
|
||||
unsetEnv: os.Unsetenv,
|
||||
fileManager: fileManager,
|
||||
envParams: libparams.NewEnvParams(),
|
||||
logger: logger,
|
||||
verifier: verification.NewVerifier(),
|
||||
os: os,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ func (r *reader) GetShadowSocksPort() (port uint16, err error) {
|
||||
// SHADOWSOCKS_PASSWORD.
|
||||
func (r *reader) GetShadowSocksPassword() (password string, err error) {
|
||||
defer func() {
|
||||
unsetErr := r.unsetEnv("SHADOWSOCKS_PASSWORD")
|
||||
unsetErr := r.os.Unsetenv("SHADOWSOCKS_PASSWORD")
|
||||
if err == nil {
|
||||
err = unsetErr
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
@@ -133,7 +133,7 @@ func (c *cyberghost) BuildConf(connection models.OpenVPNConnection, verbosity in
|
||||
}
|
||||
|
||||
func (c *cyberghost) PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath)) {
|
||||
panic("port forwarding is not supported for cyberghost")
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
@@ -128,7 +128,7 @@ func (m *mullvad) BuildConf(connection models.OpenVPNConnection,
|
||||
}
|
||||
|
||||
func (m *mullvad) PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath)) {
|
||||
panic("port forwarding is not supported for mullvad")
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
@@ -142,7 +142,7 @@ func (n *nordvpn) BuildConf(connection models.OpenVPNConnection, verbosity int,
|
||||
}
|
||||
|
||||
func (n *nordvpn) PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath)) {
|
||||
panic("port forwarding is not supported for nordvpn")
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
gluetunLog "github.com/qdm12/gluetun/internal/logging"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
@@ -183,7 +183,7 @@ func (p *pia) BuildConf(connection models.OpenVPNConnection, verbosity int, user
|
||||
|
||||
//nolint:gocognit
|
||||
func (p *pia) PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath)) {
|
||||
if !p.activeServer.PortForward {
|
||||
pfLogger.Error("The server %s does not support port forwarding", p.activeServer.Region)
|
||||
@@ -203,7 +203,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
|
||||
return
|
||||
}
|
||||
defer pfLogger.Warn("loop exited")
|
||||
data, err := readPIAPortForwardData(fileManager)
|
||||
data, err := readPIAPortForwardData(openFile)
|
||||
if err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
@@ -222,7 +222,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
|
||||
|
||||
if !dataFound || expired {
|
||||
tryUntilSuccessful(ctx, pfLogger, func() error {
|
||||
data, err = refreshPIAPortForwardData(ctx, client, gateway, fileManager)
|
||||
data, err = refreshPIAPortForwardData(ctx, client, gateway, openFile)
|
||||
return err
|
||||
})
|
||||
if ctx.Err() != nil {
|
||||
@@ -240,12 +240,9 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
|
||||
return
|
||||
}
|
||||
|
||||
filepath := syncState(data.Port)
|
||||
filepath := string(syncState(data.Port))
|
||||
pfLogger.Info("Writing port to %s", filepath)
|
||||
if err := fileManager.WriteToFile(
|
||||
string(filepath), []byte(fmt.Sprintf("%d", data.Port)),
|
||||
files.Permissions(constants.AllReadWritePermissions),
|
||||
); err != nil {
|
||||
if err := writePortForwardedToFile(openFile, filepath, data.Port); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
|
||||
@@ -281,7 +278,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
|
||||
pfLogger.Warn("Forward port has expired on %s, getting another one", data.Expiration.Format(time.RFC1123))
|
||||
oldPort := data.Port
|
||||
for {
|
||||
data, err = refreshPIAPortForwardData(ctx, client, gateway, fileManager)
|
||||
data, err = refreshPIAPortForwardData(ctx, client, gateway, openFile)
|
||||
if err != nil {
|
||||
pfLogger.Error(err)
|
||||
continue
|
||||
@@ -298,10 +295,7 @@ func (p *pia) PortForward(ctx context.Context, client *http.Client,
|
||||
}
|
||||
filepath := syncState(data.Port)
|
||||
pfLogger.Info("Writing port to %s", filepath)
|
||||
if err := fileManager.WriteToFile(
|
||||
string(filepath), []byte(fmt.Sprintf("%d", data.Port)),
|
||||
files.Permissions(constants.AllReadWritePermissions),
|
||||
); err != nil {
|
||||
if err := writePortForwardedToFile(openFile, string(filepath), data.Port); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
if err := bindPIAPort(ctx, client, gateway, data); err != nil {
|
||||
@@ -365,8 +359,8 @@ func newPIAHTTPClient(serverName string) (client *http.Client, err error) {
|
||||
}
|
||||
|
||||
func refreshPIAPortForwardData(ctx context.Context, client *http.Client,
|
||||
gateway net.IP, fileManager files.FileManager) (data piaPortForwardData, err error) {
|
||||
data.Token, err = fetchPIAToken(ctx, fileManager, client)
|
||||
gateway net.IP, openFile os.OpenFileFunc) (data piaPortForwardData, err error) {
|
||||
data.Token, err = fetchPIAToken(ctx, openFile, client)
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("cannot obtain token: %w", err)
|
||||
}
|
||||
@@ -374,7 +368,7 @@ func refreshPIAPortForwardData(ctx context.Context, client *http.Client,
|
||||
if err != nil {
|
||||
return data, fmt.Errorf("cannot obtain port forwarding data: %w", err)
|
||||
}
|
||||
if err := writePIAPortForwardData(fileManager, data); err != nil {
|
||||
if err := writePIAPortForwardData(openFile, data); err != nil {
|
||||
return data, fmt.Errorf("cannot persist port forwarding information to file: %w", err)
|
||||
}
|
||||
return data, nil
|
||||
@@ -393,34 +387,39 @@ type piaPortForwardData struct {
|
||||
Expiration time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
func readPIAPortForwardData(fileManager files.FileManager) (data piaPortForwardData, err error) {
|
||||
func readPIAPortForwardData(openFile os.OpenFileFunc) (data piaPortForwardData, err error) {
|
||||
const filepath = string(constants.PIAPortForward)
|
||||
exists, err := fileManager.FileExists(filepath)
|
||||
if err != nil {
|
||||
return data, err
|
||||
} else if !exists {
|
||||
file, err := openFile(filepath, os.O_RDONLY, 0)
|
||||
if os.IsNotExist(err) {
|
||||
return data, nil
|
||||
} else if err != nil {
|
||||
return data, err
|
||||
}
|
||||
b, err := fileManager.ReadFile(filepath)
|
||||
|
||||
decoder := json.NewDecoder(file)
|
||||
err = decoder.Decode(&data)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return data, err
|
||||
}
|
||||
if err := json.Unmarshal(b, &data); err != nil {
|
||||
return data, err
|
||||
}
|
||||
return data, nil
|
||||
return data, file.Close()
|
||||
}
|
||||
|
||||
func writePIAPortForwardData(fileManager files.FileManager, data piaPortForwardData) (err error) {
|
||||
b, err := json.Marshal(&data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot encode data: %w", err)
|
||||
}
|
||||
err = fileManager.WriteToFile(string(constants.PIAPortForward), b)
|
||||
func writePIAPortForwardData(openFile os.OpenFileFunc, data piaPortForwardData) (err error) {
|
||||
const filepath = string(constants.PIAPortForward)
|
||||
file, err := openFile(filepath,
|
||||
os.O_CREATE|os.O_TRUNC|os.O_WRONLY,
|
||||
0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
encoder := json.NewEncoder(file)
|
||||
err = encoder.Encode(data)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
func unpackPIAPayload(payload string) (port uint16, token string, expiration time.Time, err error) {
|
||||
@@ -449,8 +448,9 @@ func packPIAPayload(port uint16, token string, expiration time.Time) (payload st
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func fetchPIAToken(ctx context.Context, fileManager files.FileManager, client *http.Client) (token string, err error) {
|
||||
username, password, err := getOpenvpnCredentials(fileManager)
|
||||
func fetchPIAToken(ctx context.Context, openFile os.OpenFileFunc,
|
||||
client *http.Client) (token string, err error) {
|
||||
username, password, err := getOpenvpnCredentials(openFile)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("cannot get Openvpn credentials: %w", err)
|
||||
}
|
||||
@@ -489,10 +489,19 @@ func fetchPIAToken(ctx context.Context, fileManager files.FileManager, client *h
|
||||
return result.Token, nil
|
||||
}
|
||||
|
||||
func getOpenvpnCredentials(fileManager files.FileManager) (username, password string, err error) {
|
||||
authData, err := fileManager.ReadFile(string(constants.OpenVPNAuthConf))
|
||||
func getOpenvpnCredentials(openFile os.OpenFileFunc) (username, password string, err error) {
|
||||
const filepath = string(constants.OpenVPNAuthConf)
|
||||
file, err := openFile(filepath, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("cannot read openvpn auth file: %w", err)
|
||||
return "", "", fmt.Errorf("cannot read openvpn auth file: %s", err)
|
||||
}
|
||||
authData, err := ioutil.ReadAll(file)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return "", "", fmt.Errorf("cannot read openvpn auth file: %s", err)
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
lines := strings.Split(string(authData), "\n")
|
||||
const minLines = 2
|
||||
@@ -586,3 +595,17 @@ func bindPIAPort(ctx context.Context, client *http.Client, gateway net.IP, data
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writePortForwardedToFile(openFile os.OpenFileFunc,
|
||||
filepath string, port uint16) (err error) {
|
||||
file, err := openFile(filepath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = file.Write([]byte(fmt.Sprintf("%d", port)))
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
@@ -117,7 +117,7 @@ func (s *privado) BuildConf(connection models.OpenVPNConnection, verbosity int,
|
||||
}
|
||||
|
||||
func (s *privado) PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath)) {
|
||||
panic("port forwarding is not supported for privado")
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ type Provider interface {
|
||||
BuildConf(connection models.OpenVPNConnection, verbosity int, username string,
|
||||
root bool, cipher, auth string, extras models.ExtraConfigOptions) (lines []string)
|
||||
PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath))
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
@@ -150,7 +150,7 @@ func (p *purevpn) BuildConf(connection models.OpenVPNConnection, verbosity int,
|
||||
}
|
||||
|
||||
func (p *purevpn) PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath)) {
|
||||
panic("port forwarding is not supported for purevpn")
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
@@ -138,7 +138,7 @@ func (s *surfshark) BuildConf(connection models.OpenVPNConnection, verbosity int
|
||||
}
|
||||
|
||||
func (s *surfshark) PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath)) {
|
||||
panic("port forwarding is not supported for surfshark")
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
@@ -119,7 +119,7 @@ func (v *vyprvpn) BuildConf(connection models.OpenVPNConnection, verbosity int,
|
||||
}
|
||||
|
||||
func (v *vyprvpn) PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath)) {
|
||||
panic("port forwarding is not supported for vyprvpn")
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/firewall"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
@@ -132,7 +132,7 @@ func (w *windscribe) BuildConf(connection models.OpenVPNConnection, verbosity in
|
||||
}
|
||||
|
||||
func (w *windscribe) PortForward(ctx context.Context, client *http.Client,
|
||||
fileManager files.FileManager, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
openFile os.OpenFileFunc, pfLogger logging.Logger, gateway net.IP, fw firewall.Configurator,
|
||||
syncState func(port uint16) (pfFilepath models.Filepath)) {
|
||||
panic("port forwarding is not supported for windscribe")
|
||||
}
|
||||
|
||||
27
internal/publicip/fs.go
Normal file
27
internal/publicip/fs.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package publicip
|
||||
|
||||
import "github.com/qdm12/gluetun/internal/os"
|
||||
|
||||
func persistPublicIP(openFile os.OpenFileFunc,
|
||||
filepath string, content string, uid, gid int) error {
|
||||
file, err := openFile(
|
||||
filepath,
|
||||
os.O_TRUNC|os.O_WRONLY|os.O_CREATE,
|
||||
0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = file.WriteString(content)
|
||||
if err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := file.Chown(uid, gid); err != nil {
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return file.Close()
|
||||
}
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
|
||||
"github.com/qdm12/gluetun/internal/constants"
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/gluetun/internal/settings"
|
||||
"github.com/qdm12/golibs/files"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
"github.com/qdm12/golibs/network"
|
||||
)
|
||||
@@ -27,9 +27,9 @@ type Looper interface {
|
||||
type looper struct {
|
||||
state state
|
||||
// Objects
|
||||
getter IPGetter
|
||||
logger logging.Logger
|
||||
fileManager files.FileManager
|
||||
getter IPGetter
|
||||
logger logging.Logger
|
||||
os os.OS
|
||||
// Fixed settings
|
||||
uid int
|
||||
gid int
|
||||
@@ -45,8 +45,9 @@ type looper struct {
|
||||
timeSince func(time.Time) time.Duration
|
||||
}
|
||||
|
||||
func NewLooper(client network.Client, logger logging.Logger, fileManager files.FileManager,
|
||||
settings settings.PublicIP, uid, gid int) Looper {
|
||||
func NewLooper(client network.Client, logger logging.Logger,
|
||||
settings settings.PublicIP, uid, gid int,
|
||||
os os.OS) Looper {
|
||||
return &looper{
|
||||
state: state{
|
||||
status: constants.Stopped,
|
||||
@@ -55,7 +56,7 @@ func NewLooper(client network.Client, logger logging.Logger, fileManager files.F
|
||||
// Objects
|
||||
getter: NewIPGetter(client),
|
||||
logger: logger.WithPrefix("ip getter: "),
|
||||
fileManager: fileManager,
|
||||
os: os,
|
||||
uid: uid,
|
||||
gid: gid,
|
||||
start: make(chan struct{}),
|
||||
@@ -125,7 +126,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||
close(errorCh)
|
||||
filepath := l.GetSettings().IPFilepath
|
||||
l.logger.Info("Removing ip file %s", filepath)
|
||||
if err := l.fileManager.Remove(string(filepath)); err != nil {
|
||||
if err := l.os.Remove(string(filepath)); err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
return
|
||||
@@ -142,12 +143,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup) {
|
||||
getCancel()
|
||||
l.state.setPublicIP(ip)
|
||||
l.logger.Info("Public IP address is %s", ip)
|
||||
const userReadWritePermissions = 0600
|
||||
err := l.fileManager.WriteLinesToFile(
|
||||
string(l.state.settings.IPFilepath),
|
||||
[]string{ip.String()},
|
||||
files.Ownership(l.uid, l.gid),
|
||||
files.Permissions(userReadWritePermissions))
|
||||
filepath := string(l.state.settings.IPFilepath)
|
||||
err := persistPublicIP(l.os.OpenFile, filepath, ip.String(), l.uid, l.gid)
|
||||
if err != nil {
|
||||
l.logger.Error(err)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
"github.com/qdm12/golibs/logging"
|
||||
)
|
||||
|
||||
@@ -14,17 +12,13 @@ type Storage interface {
|
||||
}
|
||||
|
||||
type storage struct {
|
||||
osStat func(name string) (os.FileInfo, error)
|
||||
readFile func(filename string) (data []byte, err error)
|
||||
writeFile func(filename string, data []byte, perm os.FileMode) error
|
||||
logger logging.Logger
|
||||
os os.OS
|
||||
logger logging.Logger
|
||||
}
|
||||
|
||||
func New(logger logging.Logger) Storage {
|
||||
func New(logger logging.Logger, os os.OS) Storage {
|
||||
return &storage{
|
||||
osStat: os.Stat,
|
||||
readFile: ioutil.ReadFile,
|
||||
writeFile: ioutil.WriteFile,
|
||||
logger: logger.WithPrefix("storage: "),
|
||||
os: os,
|
||||
logger: logger.WithPrefix("storage: "),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,10 +3,10 @@ package storage
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
|
||||
"github.com/qdm12/gluetun/internal/models"
|
||||
"github.com/qdm12/gluetun/internal/os"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -29,14 +29,18 @@ func (s *storage) SyncServers(hardcodedServers models.AllServers, write bool) (
|
||||
allServers models.AllServers, err error) {
|
||||
// Eventually read file
|
||||
var serversOnFile models.AllServers
|
||||
_, err = s.osStat(jsonFilepath)
|
||||
file, err := s.os.OpenFile(jsonFilepath, os.O_RDONLY, 0)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return allServers, err
|
||||
}
|
||||
if err == nil {
|
||||
serversOnFile, err = s.readFromFile()
|
||||
if err != nil {
|
||||
var serversOnFile models.AllServers
|
||||
decoder := json.NewDecoder(file)
|
||||
if err := decoder.Decode(&serversOnFile); err != nil {
|
||||
_ = file.Close()
|
||||
return allServers, err
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
return allServers, err
|
||||
return allServers, file.Close()
|
||||
}
|
||||
|
||||
// Merge data from file and hardcoded
|
||||
@@ -51,24 +55,16 @@ func (s *storage) SyncServers(hardcodedServers models.AllServers, write bool) (
|
||||
return allServers, s.FlushToFile(allServers)
|
||||
}
|
||||
|
||||
func (s *storage) readFromFile() (servers models.AllServers, err error) {
|
||||
bytes, err := s.readFile(jsonFilepath)
|
||||
if err != nil {
|
||||
return servers, err
|
||||
}
|
||||
if err := json.Unmarshal(bytes, &servers); err != nil {
|
||||
return servers, err
|
||||
}
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
func (s *storage) FlushToFile(servers models.AllServers) error {
|
||||
bytes, err := json.MarshalIndent(servers, "", " ")
|
||||
file, err := s.os.OpenFile(jsonFilepath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot write to file: %w", err)
|
||||
}
|
||||
if err := s.writeFile(jsonFilepath, bytes, 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
encoder := json.NewEncoder(file)
|
||||
encoder.SetIndent("", " ")
|
||||
if err := encoder.Encode(servers); err != nil {
|
||||
_ = file.Close()
|
||||
return fmt.Errorf("cannot write to file: %w", err)
|
||||
}
|
||||
return file.Close()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user