Compare commits

...

5 Commits

Author SHA1 Message Date
Maycon Santos
7bf9793f85 Support environment vars (#155)
* updage flag values from environment variables

* add log and removing unused constants

* removing unused code

* Docker build client

* fix indentation

* Documentation with docker command

* use docker volume
2021-11-15 09:11:50 +01:00
Maycon Santos
fcbf980588 Stop service before uninstall (#158) 2021-11-14 21:30:18 +01:00
Mikhail Bragin
d08e5efbce fix: too many open files caused by agent not being closed (#154)
* fix: too many open files caused by agent not being closed after unsuccessful attempts to start a peer connection (happens when no network available)

* fix: minor refactor to consider signal status
2021-11-14 19:41:17 +01:00
Maycon Santos
95ef8547f3 Signal management arm builds (#152)
* Add arm builds for Signal and Management services

* adding arm's binary version
2021-11-07 13:11:03 +01:00
Mikhail Bragin
ed1e4dfc51 refactor signal client sync func (#147)
* refactor: move goroutine that runs Signal Client Receive to the engine for better control

* chore: fix comments typo

* test: fix golint

* chore: comments update

* chore: consider connection state=READY in signal and management clients

* chore: fix typos

* test: fix signal ping-pong test

* chore: add wait condition to signal client

* refactor: add stream status to the Signal client

* refactor: defer mutex unlock
2021-11-06 15:00:13 +01:00
16 changed files with 417 additions and 147 deletions

View File

@@ -37,6 +37,7 @@ builds:
goarch: goarch:
- amd64 - amd64
- arm64 - arm64
- arm
ldflags: ldflags:
- -s -w -X main.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X main.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: '{{ .CommitTimestamp }}'
@@ -50,6 +51,7 @@ builds:
goarch: goarch:
- amd64 - amd64
- arm64 - arm64
- arm
ldflags: ldflags:
- -s -w -X main.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X main.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: '{{ .CommitTimestamp }}'
@@ -83,6 +85,52 @@ nfpms:
postinstall: "release_files/post_install.sh" postinstall: "release_files/post_install.sh"
preremove: "release_files/pre_remove.sh" preremove: "release_files/pre_remove.sh"
dockers: dockers:
- image_templates:
- wiretrustee/wiretrustee:{{ .Version }}-amd64
ids:
- wiretrustee
goarch: amd64
use: buildx
dockerfile: client/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=wiretrustee@wiretrustee.com"
- image_templates:
- wiretrustee/wiretrustee:{{ .Version }}-arm64v8
ids:
- wiretrustee
goarch: arm64
use: buildx
dockerfile: client/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=wiretrustee@wiretrustee.com"
- image_templates:
- wiretrustee/wiretrustee:{{ .Version }}-arm
ids:
- wiretrustee
goarch: arm
goarm: 6
use: buildx
dockerfile: client/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=wiretrustee@wiretrustee.com"
- image_templates: - image_templates:
- wiretrustee/signal:{{ .Version }}-amd64 - wiretrustee/signal:{{ .Version }}-amd64
ids: ids:
@@ -113,6 +161,22 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=wiretrustee@wiretrustee.com" - "--label=maintainer=wiretrustee@wiretrustee.com"
- image_templates:
- wiretrustee/signal:{{ .Version }}-arm
ids:
- wiretrustee-signal
goarch: arm
goarm: 6
use: buildx
dockerfile: signal/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=wiretrustee@wiretrustee.com"
- image_templates: - image_templates:
- wiretrustee/management:{{ .Version }}-amd64 - wiretrustee/management:{{ .Version }}-amd64
ids: ids:
@@ -143,6 +207,22 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=wiretrustee@wiretrustee.com" - "--label=maintainer=wiretrustee@wiretrustee.com"
- image_templates:
- wiretrustee/management:{{ .Version }}-arm
ids:
- wiretrustee-mgmt
goarch: arm
goarm: 6
use: buildx
dockerfile: management/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=wiretrustee@wiretrustee.com"
- image_templates: - image_templates:
- wiretrustee/management:{{ .Version }}-debug-amd64 - wiretrustee/management:{{ .Version }}-debug-amd64
ids: ids:
@@ -174,30 +254,63 @@ dockers:
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=wiretrustee@wiretrustee.com" - "--label=maintainer=wiretrustee@wiretrustee.com"
- image_templates:
- wiretrustee/management:{{ .Version }}-debug-arm
ids:
- wiretrustee-mgmt
goarch: arm
goarm: 6
use: buildx
dockerfile: management/Dockerfile.debug
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=wiretrustee@wiretrustee.com"
docker_manifests: docker_manifests:
- name_template: wiretrustee/wiretrustee:{{ .Version }}
image_templates:
- wiretrustee/wiretrustee:{{ .Version }}-arm64v8
- wiretrustee/wiretrustee:{{ .Version }}-arm
- wiretrustee/wiretrustee:{{ .Version }}-amd64
- name_template: wiretrustee/wiretrustee:latest
image_templates:
- wiretrustee/wiretrustee:{{ .Version }}-arm64v8
- wiretrustee/wiretrustee:{{ .Version }}-arm
- wiretrustee/wiretrustee:{{ .Version }}-amd64
- name_template: wiretrustee/signal:{{ .Version }} - name_template: wiretrustee/signal:{{ .Version }}
image_templates: image_templates:
- wiretrustee/signal:{{ .Version }}-arm64v8 - wiretrustee/signal:{{ .Version }}-arm64v8
- wiretrustee/signal:{{ .Version }}-arm
- wiretrustee/signal:{{ .Version }}-amd64 - wiretrustee/signal:{{ .Version }}-amd64
- name_template: wiretrustee/signal:latest - name_template: wiretrustee/signal:latest
image_templates: image_templates:
- wiretrustee/signal:{{ .Version }}-arm64v8 - wiretrustee/signal:{{ .Version }}-arm64v8
- wiretrustee/signal:{{ .Version }}-arm
- wiretrustee/signal:{{ .Version }}-amd64 - wiretrustee/signal:{{ .Version }}-amd64
- name_template: wiretrustee/management:{{ .Version }} - name_template: wiretrustee/management:{{ .Version }}
image_templates: image_templates:
- wiretrustee/management:{{ .Version }}-arm64v8 - wiretrustee/management:{{ .Version }}-arm64v8
- wiretrustee/management:{{ .Version }}-arm
- wiretrustee/management:{{ .Version }}-amd64 - wiretrustee/management:{{ .Version }}-amd64
- name_template: wiretrustee/management:latest - name_template: wiretrustee/management:latest
image_templates: image_templates:
- wiretrustee/management:{{ .Version }}-arm64v8 - wiretrustee/management:{{ .Version }}-arm64v8
- wiretrustee/management:{{ .Version }}-arm
- wiretrustee/management:{{ .Version }}-amd64 - wiretrustee/management:{{ .Version }}-amd64
- name_template: wiretrustee/management:debug-latest - name_template: wiretrustee/management:debug-latest
image_templates: image_templates:
- wiretrustee/management:{{ .Version }}-debug-arm64v8 - wiretrustee/management:{{ .Version }}-debug-arm64v8
- wiretrustee/management:{{ .Version }}-debug-arm
- wiretrustee/management:{{ .Version }}-debug-amd64 - wiretrustee/management:{{ .Version }}-debug-amd64
brews: brews:

View File

@@ -145,6 +145,11 @@ For **Windows** systems, start powershell as administrator and:
```shell ```shell
wiretrustee up --setup-key <SETUP KEY> wiretrustee up --setup-key <SETUP KEY>
``` ```
For **Docker**, you can run with the following command:
```shell
docker run --network host --privileged --rm -d -e WT_SETUP_KEY=<SETUP KEY> -v wiretrustee-client:/etc/wiretrustee wiretrustee/wiretrustee:<TAG>
```
> TAG > 0.3.0 version
Alternatively, if you are hosting your own Management Service provide `--management-url` property pointing to your Management Service: Alternatively, if you are hosting your own Management Service provide `--management-url` property pointing to your Management Service:
```shell ```shell

4
client/Dockerfile Normal file
View File

@@ -0,0 +1,4 @@
FROM gcr.io/distroless/base:debug
ENV WT_LOG_FILE=console
ENTRYPOINT [ "/go/bin/wiretrustee","up"]
COPY wiretrustee /go/bin/wiretrustee

View File

@@ -18,12 +18,12 @@ import (
) )
var ( var (
setupKey string
loginCmd = &cobra.Command{ loginCmd = &cobra.Command{
Use: "login", Use: "login",
Short: "login to the Wiretrustee Management Service (first run)", Short: "login to the Wiretrustee Management Service (first run)",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars()
err := util.InitLog(logLevel, logFile) err := util.InitLog(logLevel, logFile)
if err != nil { if err != nil {
log.Errorf("failed initializing log %v", err) log.Errorf("failed initializing log %v", err)
@@ -151,6 +151,3 @@ func promptPeerSetupKey() (string, error) {
return "", s.Err() return "", s.Err()
} }
//func init() {
//}

View File

@@ -4,19 +4,15 @@ import (
"fmt" "fmt"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/wiretrustee/wiretrustee/client/internal" "github.com/wiretrustee/wiretrustee/client/internal"
"os" "os"
"os/signal" "os/signal"
"runtime" "runtime"
"strings"
"syscall" "syscall"
) )
const (
// ExitSetupFailed defines exit code
ExitSetupFailed = 1
DefaultConfigPath = ""
)
var ( var (
configPath string configPath string
defaultConfigPath string defaultConfigPath string
@@ -24,6 +20,7 @@ var (
defaultLogFile string defaultLogFile string
logFile string logFile string
managementURL string managementURL string
setupKey string
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "wiretrustee", Use: "wiretrustee",
Short: "", Short: "",
@@ -75,3 +72,28 @@ func SetupCloseHandler() {
} }
}() }()
} }
// SetFlagsFromEnvVars reads and updates flag values from environment variables with prefix WT_
func SetFlagsFromEnvVars() {
flags := rootCmd.PersistentFlags()
flags.VisitAll(func(f *pflag.Flag) {
envVar := FlagNameToEnvVar(f.Name)
if value, present := os.LookupEnv(envVar); present {
err := flags.Set(f.Name, value)
if err != nil {
log.Infof("unable to configure flag %s using variable %s, err: %v", f.Name, envVar, err)
}
}
})
}
// FlagNameToEnvVar converts flag name to environment var name adding a prefix,
// replacing dashes and making all uppercase (e.g. setup-keys is converted to WT_SETUP_KEYS)
func FlagNameToEnvVar(f string) string {
prefix := "WT_"
parsed := strings.ReplaceAll(f, "-", "_")
upper := strings.ToUpper(parsed)
return prefix + upper
}

View File

@@ -34,6 +34,3 @@ var (
Short: "manages wiretrustee service", Short: "manages wiretrustee service",
} }
) )
func init() {
}

View File

@@ -8,7 +8,7 @@ import (
"time" "time"
) )
func (p *program) Start(s service.Service) error { func (p *program) Start(service.Service) error {
// Start should not block. Do the actual work async. // Start should not block. Do the actual work async.
log.Info("starting service") //nolint log.Info("starting service") //nolint
@@ -22,7 +22,7 @@ func (p *program) Start(s service.Service) error {
return nil return nil
} }
func (p *program) Stop(s service.Service) error { func (p *program) Stop(service.Service) error {
go func() { go func() {
stopCh <- 1 stopCh <- 1
}() }()
@@ -41,6 +41,7 @@ var (
Use: "run", Use: "run",
Short: "runs wiretrustee as service", Short: "runs wiretrustee as service",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
SetFlagsFromEnvVars()
err := util.InitLog(logLevel, logFile) err := util.InitLog(logLevel, logFile)
if err != nil { if err != nil {
@@ -75,6 +76,8 @@ var (
Use: "start", Use: "start",
Short: "starts wiretrustee service", Short: "starts wiretrustee service",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars()
err := util.InitLog(logLevel, logFile) err := util.InitLog(logLevel, logFile)
if err != nil { if err != nil {
log.Errorf("failed initializing log %v", err) log.Errorf("failed initializing log %v", err)
@@ -101,6 +104,8 @@ var (
Use: "stop", Use: "stop",
Short: "stops wiretrustee service", Short: "stops wiretrustee service",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
SetFlagsFromEnvVars()
err := util.InitLog(logLevel, logFile) err := util.InitLog(logLevel, logFile)
if err != nil { if err != nil {
log.Errorf("failed initializing log %v", err) log.Errorf("failed initializing log %v", err)
@@ -125,6 +130,8 @@ var (
Use: "restart", Use: "restart",
Short: "restarts wiretrustee service", Short: "restarts wiretrustee service",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
SetFlagsFromEnvVars()
err := util.InitLog(logLevel, logFile) err := util.InitLog(logLevel, logFile)
if err != nil { if err != nil {
log.Errorf("failed initializing log %v", err) log.Errorf("failed initializing log %v", err)
@@ -143,6 +150,3 @@ var (
}, },
} }
) )
func init() {
}

View File

@@ -10,6 +10,7 @@ var (
Use: "install", Use: "install",
Short: "installs wiretrustee service", Short: "installs wiretrustee service",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars()
svcConfig := newSVCConfig() svcConfig := newSVCConfig()
@@ -49,6 +50,7 @@ var (
Use: "uninstall", Use: "uninstall",
Short: "uninstalls wiretrustee service from system", Short: "uninstalls wiretrustee service from system",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
SetFlagsFromEnvVars()
s, err := newSVC(&program{}, newSVCConfig()) s, err := newSVC(&program{}, newSVCConfig())
if err != nil { if err != nil {
@@ -65,6 +67,3 @@ var (
}, },
} }
) )
func init() {
}

View File

@@ -21,7 +21,7 @@ var (
Use: "up", Use: "up",
Short: "install, login and start wiretrustee client", Short: "install, login and start wiretrustee client",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars()
err := loginCmd.RunE(cmd, args) err := loginCmd.RunE(cmd, args)
if err != nil { if err != nil {
return err return err

View File

@@ -106,6 +106,7 @@ SectionEnd
Section Uninstall Section Uninstall
${INSTALL_TYPE} ${INSTALL_TYPE}
Exec '"$INSTDIR\${MAIN_APP_EXE}" service stop'
Exec '"$INSTDIR\${MAIN_APP_EXE}" service uninstall' Exec '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
# wait the service uninstall take unblock the executable # wait the service uninstall take unblock the executable
Sleep 3000 Sleep 3000

View File

@@ -138,12 +138,18 @@ func (conn *Connection) Open(timeout time.Duration) error {
return !ok return !ok
}, },
}) })
conn.agent = a
if err != nil { if err != nil {
return err return err
} }
conn.agent = a
defer func() {
err := conn.agent.Close()
if err != nil {
return
}
}()
err = conn.listenOnLocalCandidates() err = conn.listenOnLocalCandidates()
if err != nil { if err != nil {
return err return err

View File

@@ -4,7 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
ice "github.com/pion/ice/v2" "github.com/pion/ice/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/wiretrustee/wiretrustee/iface" "github.com/wiretrustee/wiretrustee/iface"
mgm "github.com/wiretrustee/wiretrustee/management/client" mgm "github.com/wiretrustee/wiretrustee/management/client"
@@ -142,12 +142,17 @@ func (e *Engine) initializePeer(peer Peer) {
RandomizationFactor: backoff.DefaultRandomizationFactor, RandomizationFactor: backoff.DefaultRandomizationFactor,
Multiplier: backoff.DefaultMultiplier, Multiplier: backoff.DefaultMultiplier,
MaxInterval: 5 * time.Second, MaxInterval: 5 * time.Second,
MaxElapsedTime: time.Duration(0), //never stop MaxElapsedTime: 0, //never stop
Stop: backoff.Stop, Stop: backoff.Stop,
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
}, e.ctx) }, e.ctx)
operation := func() error { operation := func() error {
if e.signal.GetStatus() != signal.StreamConnected {
return fmt.Errorf("not opening connection to peer because Signal is unavailable")
}
_, err := e.openPeerConnection(e.wgPort, e.config.WgPrivateKey, peer) _, err := e.openPeerConnection(e.wgPort, e.config.WgPrivateKey, peer)
e.peerMux.Lock() e.peerMux.Lock()
defer e.peerMux.Unlock() defer e.peerMux.Unlock()
@@ -157,7 +162,6 @@ func (e *Engine) initializePeer(peer Peer) {
} }
if err != nil { if err != nil {
log.Warnln(err)
log.Debugf("retrying connection because of error: %s", err.Error()) log.Debugf("retrying connection because of error: %s", err.Error())
return err return err
} }
@@ -332,6 +336,8 @@ func (e *Engine) receiveManagementEvents() {
return nil return nil
}) })
if err != nil { if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
e.cancel() e.cancel()
return return
} }
@@ -414,8 +420,10 @@ func (e *Engine) updatePeers(remotePeers []*mgmProto.RemotePeerConfig) error {
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers // receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
func (e *Engine) receiveSignalEvents() { func (e *Engine) receiveSignalEvents() {
go func() {
// connect to a stream of messages coming from the signal server // connect to a stream of messages coming from the signal server
e.signal.Receive(func(msg *sProto.Message) error { err := e.signal.Receive(func(msg *sProto.Message) error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
@@ -476,6 +484,13 @@ func (e *Engine) receiveSignalEvents() {
return nil return nil
}) })
if err != nil {
e.signal.WaitConnected() // happens if signal is unavailable for a long time.
// We want to cancel the operation of the whole client
e.cancel()
return
}
}()
e.signal.WaitStreamConnected()
} }

1
go.mod
View File

@@ -15,6 +15,7 @@ require (
github.com/rs/cors v1.8.0 github.com/rs/cors v1.8.0
github.com/sirupsen/logrus v1.7.0 github.com/sirupsen/logrus v1.7.0
github.com/spf13/cobra v1.1.3 github.com/spf13/cobra v1.1.3
github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.1.0 github.com/vishvananda/netlink v1.1.0
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c

View File

@@ -3,6 +3,7 @@ package client
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"fmt"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/wiretrustee/wiretrustee/client/system" "github.com/wiretrustee/wiretrustee/client/system"
@@ -10,6 +11,7 @@ import (
"github.com/wiretrustee/wiretrustee/management/proto" "github.com/wiretrustee/wiretrustee/management/proto"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"io" "io"
@@ -71,12 +73,18 @@ func defaultBackoff(ctx context.Context) backoff.BackOff {
RandomizationFactor: backoff.DefaultRandomizationFactor, RandomizationFactor: backoff.DefaultRandomizationFactor,
Multiplier: backoff.DefaultMultiplier, Multiplier: backoff.DefaultMultiplier,
MaxInterval: 10 * time.Second, MaxInterval: 10 * time.Second,
MaxElapsedTime: 30 * time.Minute, //stop after an 30 min of trying, the error will be propagated to the general retry of the client MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client
Stop: backoff.Stop, Stop: backoff.Stop,
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
}, ctx) }, ctx)
} }
// ready indicates whether the client is okay and ready to be used
// for now it just checks whether gRPC connection to the service is ready
func (c *Client) ready() bool {
return c.conn.GetState() == connectivity.Ready
}
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages // Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function // Blocking request. The result will be sent via msgHandler callback function
func (c *Client) Sync(msgHandler func(msg *proto.SyncResponse) error) error { func (c *Client) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
@@ -85,6 +93,12 @@ func (c *Client) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
operation := func() error { operation := func() error {
log.Debugf("management connection state %v", c.conn.GetState())
if !c.ready() {
return fmt.Errorf("no connection to management")
}
// todo we already have it since we did the Login, maybe cache it locally? // todo we already have it since we did the Login, maybe cache it locally?
serverPubKey, err := c.GetServerPublicKey() serverPubKey, err := c.GetServerPublicKey()
if err != nil { if err != nil {
@@ -98,7 +112,7 @@ func (c *Client) Sync(msgHandler func(msg *proto.SyncResponse) error) error {
return err return err
} }
log.Infof("connected to the Management Service Stream") log.Infof("connected to the Management Service stream")
// blocking until error // blocking until error
err = c.receiveEvents(stream, *serverPubKey, msgHandler) err = c.receiveEvents(stream, *serverPubKey, msgHandler)
@@ -139,7 +153,7 @@ func (c *Client) receiveEvents(stream proto.ManagementService_SyncClient, server
for { for {
update, err := stream.Recv() update, err := stream.Recv()
if err == io.EOF { if err == io.EOF {
log.Errorf("managment stream was closed: %s", err) log.Errorf("Management stream has been closed by server: %s", err)
return err return err
} }
if err != nil { if err != nil {
@@ -165,6 +179,10 @@ func (c *Client) receiveEvents(stream proto.ManagementService_SyncClient, server
// GetServerPublicKey returns server Wireguard public key (used later for encrypting messages sent to the server) // GetServerPublicKey returns server Wireguard public key (used later for encrypting messages sent to the server)
func (c *Client) GetServerPublicKey() (*wgtypes.Key, error) { func (c *Client) GetServerPublicKey() (*wgtypes.Key, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management")
}
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting
defer cancel() defer cancel()
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{}) resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
@@ -181,6 +199,9 @@ func (c *Client) GetServerPublicKey() (*wgtypes.Key, error) {
} }
func (c *Client) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) { func (c *Client) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management")
}
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req) loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
if err != nil { if err != nil {
log.Errorf("failed to encrypt message: %s", err) log.Errorf("failed to encrypt message: %s", err)

View File

@@ -11,6 +11,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
@@ -23,6 +24,12 @@ import (
// A set of tools to exchange connection details (Wireguard endpoints) with the remote peer. // A set of tools to exchange connection details (Wireguard endpoints) with the remote peer.
// Status is the status of the client
type Status string
const StreamConnected Status = "Connected"
const StreamDisconnected Status = "Disconnected"
// Client Wraps the Signal Exchange Service gRpc client // Client Wraps the Signal Exchange Service gRpc client
type Client struct { type Client struct {
key wgtypes.Key key wgtypes.Key
@@ -30,8 +37,15 @@ type Client struct {
signalConn *grpc.ClientConn signalConn *grpc.ClientConn
ctx context.Context ctx context.Context
stream proto.SignalExchange_ConnectStreamClient stream proto.SignalExchange_ConnectStreamClient
//waiting group to notify once stream is connected // connectedCh used to notify goroutines waiting for the connection to the Signal stream
connWg *sync.WaitGroup //todo use a channel instead?? connectedCh chan struct{}
mux sync.Mutex
// StreamConnected indicates whether this client is StreamConnected to the Signal stream
status Status
}
func (c *Client) GetStatus() Status {
return c.status
} }
// Close Closes underlying connections to the Signal Exchange // Close Closes underlying connections to the Signal Exchange
@@ -65,13 +79,13 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
return nil, err return nil, err
} }
var wg sync.WaitGroup
return &Client{ return &Client{
realClient: proto.NewSignalExchangeClient(conn), realClient: proto.NewSignalExchangeClient(conn),
ctx: ctx, ctx: ctx,
signalConn: conn, signalConn: conn,
key: key, key: key,
connWg: &wg, mux: sync.Mutex{},
status: StreamDisconnected,
}, nil }, nil
} }
@@ -82,7 +96,7 @@ func defaultBackoff(ctx context.Context) backoff.BackOff {
RandomizationFactor: backoff.DefaultRandomizationFactor, RandomizationFactor: backoff.DefaultRandomizationFactor,
Multiplier: backoff.DefaultMultiplier, Multiplier: backoff.DefaultMultiplier,
MaxInterval: 10 * time.Second, MaxInterval: 10 * time.Second,
MaxElapsedTime: 30 * time.Minute, //stop after an 30 min of trying, the error will be propagated to the general retry of the client MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client
Stop: backoff.Stop, Stop: backoff.Stop,
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
}, ctx) }, ctx)
@@ -91,25 +105,37 @@ func defaultBackoff(ctx context.Context) backoff.BackOff {
// Receive Connects to the Signal Exchange message stream and starts receiving messages. // Receive Connects to the Signal Exchange message stream and starts receiving messages.
// The messages will be handled by msgHandler function provided. // The messages will be handled by msgHandler function provided.
// This function runs a goroutine underneath and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart) // This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart)
// The key is the identifier of our Peer (could be Wireguard public key) // The connection retry logic will try to reconnect for 30 min and if wasn't successful will propagate the error to the function caller.
func (c *Client) Receive(msgHandler func(msg *proto.Message) error) { func (c *Client) Receive(msgHandler func(msg *proto.Message) error) error {
c.connWg.Add(1)
go func() {
var backOff = defaultBackoff(c.ctx) var backOff = defaultBackoff(c.ctx)
operation := func() error { operation := func() error {
c.notifyStreamDisconnected()
log.Debugf("signal connection state %v", c.signalConn.GetState())
if !c.ready() {
return fmt.Errorf("no connection to signal")
}
// connect to Signal stream identifying ourselves with a public Wireguard key
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
stream, err := c.connect(c.key.PublicKey().String()) stream, err := c.connect(c.key.PublicKey().String())
if err != nil { if err != nil {
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err) log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
c.connWg.Add(1)
return err return err
} }
c.notifyStreamConnected()
log.Infof("connected to the Signal Service stream")
// start receiving messages from the Signal stream (from other peers through signal)
err = c.receive(stream, msgHandler) err = c.receive(stream, msgHandler)
if err != nil { if err != nil {
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
backOff.Reset() backOff.Reset()
return err return err
} }
@@ -120,9 +146,35 @@ func (c *Client) Receive(msgHandler func(msg *proto.Message) error) {
err := backoff.Retry(operation, backOff) err := backoff.Retry(operation, backOff)
if err != nil { if err != nil {
log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err) log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err)
return return err
} }
}()
return nil
}
func (c *Client) notifyStreamDisconnected() {
c.mux.Lock()
defer c.mux.Unlock()
c.status = StreamDisconnected
}
func (c *Client) notifyStreamConnected() {
c.mux.Lock()
defer c.mux.Unlock()
c.status = StreamConnected
if c.connectedCh != nil {
// there are goroutines waiting on this channel -> release them
close(c.connectedCh)
c.connectedCh = nil
}
}
func (c *Client) getStreamStatusChan() <-chan struct{} {
c.mux.Lock()
defer c.mux.Unlock()
if c.connectedCh == nil {
c.connectedCh = make(chan struct{})
}
return c.connectedCh
} }
func (c *Client) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) { func (c *Client) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) {
@@ -147,24 +199,37 @@ func (c *Client) connect(key string) (proto.SignalExchange_ConnectStreamClient,
if len(registered) == 0 { if len(registered) == 0 {
return nil, fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams") return nil, fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams")
} }
//connection established we are good to use the stream
c.connWg.Done()
log.Infof("connected to the Signal Exchange Stream")
return stream, nil return stream, nil
} }
// WaitConnected waits until the client is connected to the message stream // ready indicates whether the client is okay and ready to be used
func (c *Client) WaitConnected() { // for now it just checks whether gRPC connection to the service is in state Ready
c.connWg.Wait() func (c *Client) ready() bool {
return c.signalConn.GetState() == connectivity.Ready
}
// WaitStreamConnected waits until the client is connected to the Signal stream
func (c *Client) WaitStreamConnected() {
if c.status == StreamConnected {
return
}
ch := c.getStreamStatusChan()
select {
case <-c.ctx.Done():
case <-ch:
}
} }
// SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server // SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server
// The Client.Receive method must be called before sending messages to establish initial connection to the Signal Exchange // The Client.Receive method must be called before sending messages to establish initial connection to the Signal Exchange
// Client.connWg can be used to wait // Client.connWg can be used to wait
func (c *Client) SendToStream(msg *proto.EncryptedMessage) error { func (c *Client) SendToStream(msg *proto.EncryptedMessage) error {
if !c.ready() {
return fmt.Errorf("no connection to signal")
}
if c.stream == nil { if c.stream == nil {
return fmt.Errorf("connection to the Signal Exchnage has not been established yet. Please call Client.Receive before sending messages") return fmt.Errorf("connection to the Signal Exchnage has not been established yet. Please call Client.Receive before sending messages")
} }
@@ -221,13 +286,17 @@ func (c *Client) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, er
// Send sends a message to the remote Peer through the Signal Exchange. // Send sends a message to the remote Peer through the Signal Exchange.
func (c *Client) Send(msg *proto.Message) error { func (c *Client) Send(msg *proto.Message) error {
if !c.ready() {
return fmt.Errorf("no connection to signal")
}
encryptedMessage, err := c.encryptMessage(msg) encryptedMessage, err := c.encryptMessage(msg)
if err != nil { if err != nil {
return err return err
} }
_, err = c.realClient.Send(context.TODO(), encryptedMessage) _, err = c.realClient.Send(context.TODO(), encryptedMessage)
if err != nil { if err != nil {
log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err) //log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err)
return err return err
} }
@@ -244,10 +313,10 @@ func (c *Client) receive(stream proto.SignalExchange_ConnectStreamClient,
log.Warnf("stream canceled (usually indicates shutdown)") log.Warnf("stream canceled (usually indicates shutdown)")
return err return err
} else if s.Code() == codes.Unavailable { } else if s.Code() == codes.Unavailable {
log.Warnf("server has been stopped") log.Warnf("Signal Service is unavailable")
return err return err
} else if err == io.EOF { } else if err == io.EOF {
log.Warnf("stream closed by server") log.Warnf("Signal Service stream closed by server")
return err return err
} else if err != nil { } else if err != nil {
return err return err

View File

@@ -48,17 +48,24 @@ var _ = Describe("Client", func() {
// connect PeerA to Signal // connect PeerA to Signal
keyA, _ := wgtypes.GenerateKey() keyA, _ := wgtypes.GenerateKey()
clientA := createSignalClient(addr, keyA) clientA := createSignalClient(addr, keyA)
clientA.Receive(func(msg *sigProto.Message) error { go func() {
err := clientA.Receive(func(msg *sigProto.Message) error {
receivedOnA = msg.GetBody().GetPayload() receivedOnA = msg.GetBody().GetPayload()
msgReceived.Done() msgReceived.Done()
return nil return nil
}) })
clientA.WaitConnected() if err != nil {
return
}
}()
clientA.WaitStreamConnected()
// connect PeerB to Signal // connect PeerB to Signal
keyB, _ := wgtypes.GenerateKey() keyB, _ := wgtypes.GenerateKey()
clientB := createSignalClient(addr, keyB) clientB := createSignalClient(addr, keyB)
clientB.Receive(func(msg *sigProto.Message) error {
go func() {
err := clientB.Receive(func(msg *sigProto.Message) error {
receivedOnB = msg.GetBody().GetPayload() receivedOnB = msg.GetBody().GetPayload()
err := clientB.Send(&sigProto.Message{ err := clientB.Send(&sigProto.Message{
Key: keyB.PublicKey().String(), Key: keyB.PublicKey().String(),
@@ -71,7 +78,12 @@ var _ = Describe("Client", func() {
msgReceived.Done() msgReceived.Done()
return nil return nil
}) })
clientB.WaitConnected() if err != nil {
return
}
}()
clientB.WaitStreamConnected()
// PeerA initiates ping-pong // PeerA initiates ping-pong
err := clientA.Send(&sigProto.Message{ err := clientA.Send(&sigProto.Message{
@@ -100,11 +112,15 @@ var _ = Describe("Client", func() {
key, _ := wgtypes.GenerateKey() key, _ := wgtypes.GenerateKey()
client := createSignalClient(addr, key) client := createSignalClient(addr, key)
client.Receive(func(msg *sigProto.Message) error { go func() {
err := client.Receive(func(msg *sigProto.Message) error {
return nil return nil
}) })
client.WaitConnected() if err != nil {
return
}
}()
client.WaitStreamConnected()
Expect(client).NotTo(BeNil()) Expect(client).NotTo(BeNil())
}) })
}) })