mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 09:16:40 +00:00
Compare commits
50 Commits
cli-ws-pro
...
feature/de
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b2d7121695 | ||
|
|
3288c4414f | ||
|
|
f3b0439211 | ||
|
|
ae801d77fb | ||
|
|
4545ab9a52 | ||
|
|
7f08983207 | ||
|
|
eddea14521 | ||
|
|
b9ef214ea5 | ||
|
|
709e24eb6f | ||
|
|
bf83549db2 | ||
|
|
804a3871fe | ||
|
|
6654e2dbf7 | ||
|
|
64d1edce27 | ||
|
|
bf0698e5aa | ||
|
|
fc15625963 | ||
|
|
a75dde33b9 | ||
|
|
d80d47a469 | ||
|
|
96f71ff1e1 | ||
|
|
2fe2af38d2 | ||
|
|
bb46e438aa | ||
|
|
cd9a867ad0 | ||
|
|
0f9bfeff7c | ||
|
|
f5301230bf | ||
|
|
429d7d6585 | ||
|
|
3cdb10cde7 | ||
|
|
11ba253ffb | ||
|
|
af95aabb03 | ||
|
|
3abae0bd17 | ||
|
|
8252ff41db | ||
|
|
277aa2b7cc | ||
|
|
bb37dc89ce | ||
|
|
14fe7c29cb | ||
|
|
158f3aceff | ||
|
|
bfa776c155 | ||
|
|
885b5c68ad | ||
|
|
b1ebac795d | ||
|
|
000e99e7f3 | ||
|
|
0d2e67983a | ||
|
|
5151f19d29 | ||
|
|
bedd3cabc9 | ||
|
|
d35a845dbd | ||
|
|
4e03f708a4 | ||
|
|
654aa9581d | ||
|
|
9021bb512b | ||
|
|
768332820e | ||
|
|
229c65ffa1 | ||
|
|
4d33567888 | ||
|
|
88467883fc | ||
|
|
954f40991f | ||
|
|
34341d95a9 |
@@ -4,7 +4,7 @@
|
|||||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||||
|
|
||||||
FROM alpine:3.22.0
|
FROM alpine:3.22.2
|
||||||
# iproute2: busybox doesn't display ip rules properly
|
# iproute2: busybox doesn't display ip rules properly
|
||||||
RUN apk add --no-cache \
|
RUN apk add --no-cache \
|
||||||
bash \
|
bash \
|
||||||
|
|||||||
@@ -307,8 +307,14 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||||
} else {
|
} else {
|
||||||
|
pm := profilemanager.NewProfileManager()
|
||||||
|
var profName string
|
||||||
|
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||||
|
profName = activeProf.Name
|
||||||
|
}
|
||||||
|
|
||||||
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||||
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""),
|
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return statusOutputString
|
return statusOutputString
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"os/exec"
|
||||||
"os/user"
|
"os/user"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -356,13 +357,21 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
|||||||
cmd.Println("")
|
cmd.Println("")
|
||||||
|
|
||||||
if !noBrowser {
|
if !noBrowser {
|
||||||
if err := open.Run(verificationURIComplete); err != nil {
|
if err := openBrowser(verificationURIComplete); err != nil {
|
||||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
|
||||||
|
func openBrowser(url string) error {
|
||||||
|
if browser := os.Getenv("BROWSER"); browser != "" {
|
||||||
|
return exec.Command(browser, url).Start()
|
||||||
|
}
|
||||||
|
return open.Run(url)
|
||||||
|
}
|
||||||
|
|
||||||
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||||
func isUnixRunningDesktop() bool {
|
func isUnixRunningDesktop() bool {
|
||||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||||
|
|||||||
@@ -400,7 +400,6 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Include action in the ipset name to prevent squashing rules with different actions
|
|
||||||
actionSuffix := ""
|
actionSuffix := ""
|
||||||
if action == firewall.ActionDrop {
|
if action == firewall.ActionDrop {
|
||||||
actionSuffix = "-drop"
|
actionSuffix = "-drop"
|
||||||
|
|||||||
@@ -4,12 +4,15 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/connectivity"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
@@ -17,6 +20,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
|
||||||
|
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
|
||||||
|
|
||||||
// Backoff returns a backoff configuration for gRPC calls
|
// Backoff returns a backoff configuration for gRPC calls
|
||||||
func Backoff(ctx context.Context) backoff.BackOff {
|
func Backoff(ctx context.Context) backoff.BackOff {
|
||||||
b := backoff.NewExponentialBackOff()
|
b := backoff.NewExponentialBackOff()
|
||||||
@@ -25,11 +31,32 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
|||||||
return backoff.WithContext(b, ctx)
|
return backoff.WithContext(b, ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// waitForConnectionReady blocks until the connection becomes ready or fails.
|
||||||
|
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
|
||||||
|
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
|
||||||
|
conn.Connect()
|
||||||
|
|
||||||
|
state := conn.GetState()
|
||||||
|
for state != connectivity.Ready && state != connectivity.Shutdown {
|
||||||
|
if !conn.WaitForStateChange(ctx, state) {
|
||||||
|
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
|
||||||
|
}
|
||||||
|
state = conn.GetState()
|
||||||
|
}
|
||||||
|
|
||||||
|
if state == connectivity.Shutdown {
|
||||||
|
return ErrConnectionShutdown
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
||||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||||
if tlsEnabled {
|
// for js, the outer websocket layer takes care of tls
|
||||||
|
if tlsEnabled && runtime.GOOS != "js" {
|
||||||
certPool, err := x509.SystemCertPool()
|
certPool, err := x509.SystemCertPool()
|
||||||
if err != nil || certPool == nil {
|
if err != nil || certPool == nil {
|
||||||
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
||||||
@@ -37,28 +64,28 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
|||||||
}
|
}
|
||||||
|
|
||||||
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||||||
// for js, outer websocket layer takes care of tls verification via WithCustomDialer
|
RootCAs: certPool,
|
||||||
InsecureSkipVerify: runtime.GOOS == "js",
|
|
||||||
RootCAs: certPool,
|
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
conn, err := grpc.NewClient(
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
conn, err := grpc.DialContext(
|
|
||||||
connCtx,
|
|
||||||
addr,
|
addr,
|
||||||
transportOption,
|
transportOption,
|
||||||
WithCustomDialer(tlsEnabled, component),
|
WithCustomDialer(tlsEnabled, component),
|
||||||
grpc.WithBlock(),
|
|
||||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
Time: 30 * time.Second,
|
Time: 30 * time.Second,
|
||||||
Timeout: 10 * time.Second,
|
Timeout: 10 * time.Second,
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("DialContext error: %v", err)
|
return nil, fmt.Errorf("new client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := waitForConnectionReady(ctx, conn); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import (
|
|||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
func WithCustomDialer(_ bool, _ string) grpc.DialOption {
|
||||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
if runtime.GOOS == "linux" {
|
if runtime.GOOS == "linux" {
|
||||||
currentUser, err := user.Current()
|
currentUser, err := user.Current()
|
||||||
@@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
|||||||
|
|
||||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to dial: %s", err)
|
|
||||||
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
||||||
}
|
}
|
||||||
return conn, nil
|
return conn, nil
|
||||||
|
|||||||
@@ -73,6 +73,44 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error {
|
||||||
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the existing peer to preserve its allowed IPs
|
||||||
|
existingPeer, err := c.getPeer(c.deviceName, peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get peer: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
removePeerCfg := wgtypes.PeerConfig{
|
||||||
|
PublicKey: peerKeyParsed,
|
||||||
|
Remove: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil {
|
||||||
|
return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Re-add the peer without the endpoint but same AllowedIPs
|
||||||
|
reAddPeerCfg := wgtypes.PeerConfig{
|
||||||
|
PublicKey: peerKeyParsed,
|
||||||
|
AllowedIPs: existingPeer.AllowedIPs,
|
||||||
|
ReplaceAllowedIPs: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
`error re-adding peer %s to interface %s with allowed IPs %v: %w`,
|
||||||
|
peerKey, c.deviceName, existingPeer.AllowedIPs, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) RemovePeer(peerKey string) error {
|
func (c *KernelConfigurer) RemovePeer(peerKey string) error {
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -106,6 +106,67 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error {
|
||||||
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse peer key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipcStr, err := c.device.IpcGet()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get IPC config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse current status to get allowed IPs for the peer
|
||||||
|
stats, err := parseStatus(c.deviceName, ipcStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse IPC config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var allowedIPs []net.IPNet
|
||||||
|
found := false
|
||||||
|
for _, peer := range stats.Peers {
|
||||||
|
if peer.PublicKey == peerKey {
|
||||||
|
allowedIPs = peer.AllowedIPs
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return fmt.Errorf("peer %s not found", peerKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove the peer from the WireGuard configuration
|
||||||
|
peer := wgtypes.PeerConfig{
|
||||||
|
PublicKey: peerKeyParsed,
|
||||||
|
Remove: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
config := wgtypes.Config{
|
||||||
|
Peers: []wgtypes.PeerConfig{peer},
|
||||||
|
}
|
||||||
|
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
|
||||||
|
return fmt.Errorf("failed to remove peer: %s", ipcErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the peer config
|
||||||
|
peer = wgtypes.PeerConfig{
|
||||||
|
PublicKey: peerKeyParsed,
|
||||||
|
ReplaceAllowedIPs: true,
|
||||||
|
AllowedIPs: allowedIPs,
|
||||||
|
}
|
||||||
|
|
||||||
|
config = wgtypes.Config{
|
||||||
|
Peers: []wgtypes.PeerConfig{peer},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.device.IpcSet(toWgUserspaceString(config)); err != nil {
|
||||||
|
return fmt.Errorf("remove endpoint address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -23,4 +23,5 @@ type WGTunDevice interface {
|
|||||||
FilteredDevice() *device.FilteredDevice
|
FilteredDevice() *device.FilteredDevice
|
||||||
Device() *wgdevice.Device
|
Device() *wgdevice.Device
|
||||||
GetNet() *netstack.Net
|
GetNet() *netstack.Net
|
||||||
|
GetICEBind() device.EndpointManager
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -150,6 +150,11 @@ func (t *WGTunDevice) GetNet() *netstack.Net {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetICEBind returns the ICEBind instance
|
||||||
|
func (t *WGTunDevice) GetICEBind() EndpointManager {
|
||||||
|
return t.iceBind
|
||||||
|
}
|
||||||
|
|
||||||
func routesToString(routes []string) string {
|
func routesToString(routes []string) string {
|
||||||
return strings.Join(routes, ";")
|
return strings.Join(routes, ";")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -154,3 +154,8 @@ func (t *TunDevice) assignAddr() error {
|
|||||||
func (t *TunDevice) GetNet() *netstack.Net {
|
func (t *TunDevice) GetNet() *netstack.Net {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetICEBind returns the ICEBind instance
|
||||||
|
func (t *TunDevice) GetICEBind() EndpointManager {
|
||||||
|
return t.iceBind
|
||||||
|
}
|
||||||
|
|||||||
@@ -144,3 +144,8 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
|||||||
func (t *TunDevice) GetNet() *netstack.Net {
|
func (t *TunDevice) GetNet() *netstack.Net {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetICEBind returns the ICEBind instance
|
||||||
|
func (t *TunDevice) GetICEBind() EndpointManager {
|
||||||
|
return t.iceBind
|
||||||
|
}
|
||||||
|
|||||||
@@ -179,3 +179,8 @@ func (t *TunKernelDevice) assignAddr() error {
|
|||||||
func (t *TunKernelDevice) GetNet() *netstack.Net {
|
func (t *TunKernelDevice) GetNet() *netstack.Net {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetICEBind returns nil for kernel mode devices
|
||||||
|
func (t *TunKernelDevice) GetICEBind() EndpointManager {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ type Bind interface {
|
|||||||
conn.Bind
|
conn.Bind
|
||||||
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
|
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
|
||||||
ActivityRecorder() *bind.ActivityRecorder
|
ActivityRecorder() *bind.ActivityRecorder
|
||||||
|
EndpointManager
|
||||||
}
|
}
|
||||||
|
|
||||||
type TunNetstackDevice struct {
|
type TunNetstackDevice struct {
|
||||||
@@ -155,3 +156,8 @@ func (t *TunNetstackDevice) Device() *device.Device {
|
|||||||
func (t *TunNetstackDevice) GetNet() *netstack.Net {
|
func (t *TunNetstackDevice) GetNet() *netstack.Net {
|
||||||
return t.net
|
return t.net
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetICEBind returns the bind instance
|
||||||
|
func (t *TunNetstackDevice) GetICEBind() EndpointManager {
|
||||||
|
return t.bind
|
||||||
|
}
|
||||||
|
|||||||
@@ -146,3 +146,8 @@ func (t *USPDevice) assignAddr() error {
|
|||||||
func (t *USPDevice) GetNet() *netstack.Net {
|
func (t *USPDevice) GetNet() *netstack.Net {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetICEBind returns the ICEBind instance
|
||||||
|
func (t *USPDevice) GetICEBind() EndpointManager {
|
||||||
|
return t.iceBind
|
||||||
|
}
|
||||||
|
|||||||
@@ -185,3 +185,8 @@ func (t *TunDevice) assignAddr() error {
|
|||||||
func (t *TunDevice) GetNet() *netstack.Net {
|
func (t *TunDevice) GetNet() *netstack.Net {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetICEBind returns the ICEBind instance
|
||||||
|
func (t *TunDevice) GetICEBind() EndpointManager {
|
||||||
|
return t.iceBind
|
||||||
|
}
|
||||||
|
|||||||
13
client/iface/device/endpoint_manager.go
Normal file
13
client/iface/device/endpoint_manager.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EndpointManager manages fake IP to connection mappings for userspace bind implementations.
|
||||||
|
// Implemented by bind.ICEBind and bind.RelayBindJS.
|
||||||
|
type EndpointManager interface {
|
||||||
|
SetEndpoint(fakeIP netip.Addr, conn net.Conn)
|
||||||
|
RemoveEndpoint(fakeIP netip.Addr)
|
||||||
|
}
|
||||||
@@ -21,4 +21,5 @@ type WGConfigurer interface {
|
|||||||
GetStats() (map[string]configurer.WGStats, error)
|
GetStats() (map[string]configurer.WGStats, error)
|
||||||
FullStats() (*configurer.Stats, error)
|
FullStats() (*configurer.Stats, error)
|
||||||
LastActivities() map[string]monotime.Time
|
LastActivities() map[string]monotime.Time
|
||||||
|
RemoveEndpointAddress(peerKey string) error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,4 +21,5 @@ type WGTunDevice interface {
|
|||||||
FilteredDevice() *device.FilteredDevice
|
FilteredDevice() *device.FilteredDevice
|
||||||
Device() *wgdevice.Device
|
Device() *wgdevice.Device
|
||||||
GetNet() *netstack.Net
|
GetNet() *netstack.Net
|
||||||
|
GetICEBind() device.EndpointManager
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,6 +80,17 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
|
|||||||
return w.wgProxyFactory.GetProxy()
|
return w.wgProxyFactory.GetProxy()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetBind returns the EndpointManager userspace bind mode.
|
||||||
|
func (w *WGIface) GetBind() device.EndpointManager {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
if w.tun == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return w.tun.GetICEBind()
|
||||||
|
}
|
||||||
|
|
||||||
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
|
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
|
||||||
func (w *WGIface) IsUserspaceBind() bool {
|
func (w *WGIface) IsUserspaceBind() bool {
|
||||||
return w.userspaceBind
|
return w.userspaceBind
|
||||||
@@ -148,6 +159,17 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
|
|||||||
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WGIface) RemoveEndpointAddress(peerKey string) error {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
if w.configurer == nil {
|
||||||
|
return ErrIfaceNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Removing endpoint address: %s", peerKey)
|
||||||
|
return w.configurer.RemoveEndpointAddress(peerKey)
|
||||||
|
}
|
||||||
|
|
||||||
// RemovePeer removes a Wireguard Peer from the interface iface
|
// RemovePeer removes a Wireguard Peer from the interface iface
|
||||||
func (w *WGIface) RemovePeer(peerKey string) error {
|
func (w *WGIface) RemovePeer(peerKey string) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
|
|||||||
@@ -29,11 +29,6 @@ type Manager interface {
|
|||||||
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
|
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
type protoMatch struct {
|
|
||||||
ips map[string]int
|
|
||||||
policyID []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultManager uses firewall manager to handle
|
// DefaultManager uses firewall manager to handle
|
||||||
type DefaultManager struct {
|
type DefaultManager struct {
|
||||||
firewall firewall.Manager
|
firewall firewall.Manager
|
||||||
@@ -86,21 +81,14 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||||
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
rules := networkMap.FirewallRules
|
||||||
|
|
||||||
enableSSH := networkMap.PeerConfig != nil &&
|
enableSSH := networkMap.PeerConfig != nil &&
|
||||||
networkMap.PeerConfig.SshConfig != nil &&
|
networkMap.PeerConfig.SshConfig != nil &&
|
||||||
networkMap.PeerConfig.SshConfig.SshEnabled
|
networkMap.PeerConfig.SshConfig.SshEnabled
|
||||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
|
|
||||||
enableSSH = enableSSH && !ok
|
|
||||||
}
|
|
||||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok {
|
|
||||||
enableSSH = enableSSH && !ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// if TCP protocol rules not squashed and SSH enabled
|
// If SSH enabled, add default firewall rule which accepts connection to any peer
|
||||||
// we add default firewall rule which accepts connection to any peer
|
// in the network by SSH (TCP port defined by ssh.DefaultSSHPort).
|
||||||
// in the network by SSH (TCP 22 port).
|
|
||||||
if enableSSH {
|
if enableSSH {
|
||||||
rules = append(rules, &mgmProto.FirewallRule{
|
rules = append(rules, &mgmProto.FirewallRule{
|
||||||
PeerIP: "0.0.0.0",
|
PeerIP: "0.0.0.0",
|
||||||
@@ -368,145 +356,6 @@ func (d *DefaultManager) getPeerRuleID(
|
|||||||
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
|
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
|
||||||
}
|
}
|
||||||
|
|
||||||
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
|
|
||||||
// to all peers in the network map to one rule which just accepts that type of the traffic.
|
|
||||||
//
|
|
||||||
// NOTE: It will not squash two rules for same protocol if one covers all peers in the network,
|
|
||||||
// but other has port definitions or has drop policy.
|
|
||||||
func (d *DefaultManager) squashAcceptRules(
|
|
||||||
networkMap *mgmProto.NetworkMap,
|
|
||||||
) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) {
|
|
||||||
totalIPs := 0
|
|
||||||
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
|
|
||||||
for range p.AllowedIps {
|
|
||||||
totalIPs++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
in := map[mgmProto.RuleProtocol]*protoMatch{}
|
|
||||||
out := map[mgmProto.RuleProtocol]*protoMatch{}
|
|
||||||
|
|
||||||
// trace which type of protocols was squashed
|
|
||||||
squashedRules := []*mgmProto.FirewallRule{}
|
|
||||||
squashedProtocols := map[mgmProto.RuleProtocol]struct{}{}
|
|
||||||
|
|
||||||
// this function we use to do calculation, can we squash the rules by protocol or not.
|
|
||||||
// We summ amount of Peers IP for given protocol we found in original rules list.
|
|
||||||
// But we zeroed the IP's for protocol if:
|
|
||||||
// 1. Any of the rule has DROP action type.
|
|
||||||
// 2. Any of rule contains Port.
|
|
||||||
//
|
|
||||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
|
||||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
|
||||||
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
|
|
||||||
r.Port != "" || !portInfoEmpty(r.PortInfo)
|
|
||||||
|
|
||||||
if hasPortRestrictions {
|
|
||||||
// Don't squash rules with port restrictions
|
|
||||||
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := protocols[r.Protocol]; !ok {
|
|
||||||
protocols[r.Protocol] = &protoMatch{
|
|
||||||
ips: map[string]int{},
|
|
||||||
// store the first encountered PolicyID for this protocol
|
|
||||||
policyID: r.PolicyID,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// special case, when we receive this all network IP address
|
|
||||||
// it means that rules for that protocol was already optimized on the
|
|
||||||
// management side
|
|
||||||
if r.PeerIP == "0.0.0.0" {
|
|
||||||
squashedRules = append(squashedRules, r)
|
|
||||||
squashedProtocols[r.Protocol] = struct{}{}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ipset := protocols[r.Protocol].ips
|
|
||||||
|
|
||||||
if _, ok := ipset[r.PeerIP]; ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ipset[r.PeerIP] = i
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, r := range networkMap.FirewallRules {
|
|
||||||
// calculate squash for different directions
|
|
||||||
if r.Direction == mgmProto.RuleDirection_IN {
|
|
||||||
addRuleToCalculationMap(i, r, in)
|
|
||||||
} else {
|
|
||||||
addRuleToCalculationMap(i, r, out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// order of squashing by protocol is important
|
|
||||||
// only for their first element ALL, it must be done first
|
|
||||||
protocolOrders := []mgmProto.RuleProtocol{
|
|
||||||
mgmProto.RuleProtocol_ALL,
|
|
||||||
mgmProto.RuleProtocol_ICMP,
|
|
||||||
mgmProto.RuleProtocol_TCP,
|
|
||||||
mgmProto.RuleProtocol_UDP,
|
|
||||||
}
|
|
||||||
|
|
||||||
squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) {
|
|
||||||
for _, protocol := range protocolOrders {
|
|
||||||
match, ok := matches[protocol]
|
|
||||||
if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 {
|
|
||||||
// don't squash if :
|
|
||||||
// 1. Rules not cover all peers in the network
|
|
||||||
// 2. Rules cover only one peer in the network.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// add special rule 0.0.0.0 which allows all IP's in our firewall implementations
|
|
||||||
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
|
|
||||||
PeerIP: "0.0.0.0",
|
|
||||||
Direction: direction,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: protocol,
|
|
||||||
PolicyID: match.policyID,
|
|
||||||
})
|
|
||||||
squashedProtocols[protocol] = struct{}{}
|
|
||||||
|
|
||||||
if protocol == mgmProto.RuleProtocol_ALL {
|
|
||||||
// if we have ALL traffic type squashed rule
|
|
||||||
// it allows all other type of traffic, so we can stop processing
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
squash(in, mgmProto.RuleDirection_IN)
|
|
||||||
squash(out, mgmProto.RuleDirection_OUT)
|
|
||||||
|
|
||||||
// if all protocol was squashed everything is allow and we can ignore all other rules
|
|
||||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
|
|
||||||
return squashedRules, squashedProtocols
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(squashedRules) == 0 {
|
|
||||||
return networkMap.FirewallRules, squashedProtocols
|
|
||||||
}
|
|
||||||
|
|
||||||
var rules []*mgmProto.FirewallRule
|
|
||||||
// filter out rules which was squashed from final list
|
|
||||||
// if we also have other not squashed rules.
|
|
||||||
for i, r := range networkMap.FirewallRules {
|
|
||||||
if _, ok := squashedProtocols[r.Protocol]; ok {
|
|
||||||
if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i {
|
|
||||||
continue
|
|
||||||
} else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rules = append(rules, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
return append(rules, squashedRules...), squashedProtocols
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||||
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||||
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
||||||
|
|||||||
@@ -188,492 +188,6 @@ func TestDefaultManagerStateless(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultManagerSquashRules(t *testing.T) {
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
|
||||||
{AllowedIps: []string{"10.93.0.1"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.2"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.3"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.4"}},
|
|
||||||
},
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
manager := &DefaultManager{}
|
|
||||||
rules, _ := manager.squashAcceptRules(networkMap)
|
|
||||||
assert.Equal(t, 2, len(rules))
|
|
||||||
|
|
||||||
r := rules[0]
|
|
||||||
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
|
||||||
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
|
|
||||||
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
|
||||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
|
||||||
|
|
||||||
r = rules[1]
|
|
||||||
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
|
||||||
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
|
|
||||||
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
|
||||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
|
||||||
{AllowedIps: []string{"10.93.0.1"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.2"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.3"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.4"}},
|
|
||||||
},
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_ALL,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_OUT,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_UDP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
manager := &DefaultManager{}
|
|
||||||
rules, _ := manager.squashAcceptRules(networkMap)
|
|
||||||
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
rules []*mgmProto.FirewallRule
|
|
||||||
expectedCount int
|
|
||||||
description string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "should not squash rules with port ranges",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Range_{
|
|
||||||
Range: &mgmProto.PortInfo_Range{
|
|
||||||
Start: 8080,
|
|
||||||
End: 8090,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Range_{
|
|
||||||
Range: &mgmProto.PortInfo_Range{
|
|
||||||
Start: 8080,
|
|
||||||
End: 8090,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Range_{
|
|
||||||
Range: &mgmProto.PortInfo_Range{
|
|
||||||
Start: 8080,
|
|
||||||
End: 8090,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Range_{
|
|
||||||
Range: &mgmProto.PortInfo_Range{
|
|
||||||
Start: 8080,
|
|
||||||
End: 8090,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 4,
|
|
||||||
description: "Rules with port ranges should not be squashed even if they cover all peers",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should not squash rules with specific ports",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Port{
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Port{
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Port{
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Port{
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 4,
|
|
||||||
description: "Rules with specific ports should not be squashed even if they cover all peers",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should not squash rules with legacy port field",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 4,
|
|
||||||
description: "Rules with legacy port field should not be squashed",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should not squash rules with DROP action",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_DROP,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_DROP,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_DROP,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_DROP,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 4,
|
|
||||||
description: "Rules with DROP action should not be squashed",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should squash rules without port restrictions",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 1,
|
|
||||||
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "mixed rules should not squash protocol with port restrictions",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{
|
|
||||||
PortSelection: &mgmProto.PortInfo_Port{
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 4,
|
|
||||||
description: "TCP should not be squashed because one rule has port restrictions",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "should squash UDP but not TCP when TCP has port restrictions",
|
|
||||||
rules: []*mgmProto.FirewallRule{
|
|
||||||
// TCP rules with port restrictions - should NOT be squashed
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
// UDP rules without port restrictions - SHOULD be squashed
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_UDP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.2",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_UDP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.3",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_UDP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "10.93.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_UDP,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
|
|
||||||
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
networkMap := &mgmProto.NetworkMap{
|
|
||||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
|
||||||
{AllowedIps: []string{"10.93.0.1"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.2"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.3"}},
|
|
||||||
{AllowedIps: []string{"10.93.0.4"}},
|
|
||||||
},
|
|
||||||
FirewallRules: tt.rules,
|
|
||||||
}
|
|
||||||
|
|
||||||
manager := &DefaultManager{}
|
|
||||||
rules, _ := manager.squashAcceptRules(networkMap)
|
|
||||||
|
|
||||||
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
|
|
||||||
|
|
||||||
// For squashed rules, verify we get the expected 0.0.0.0 rule
|
|
||||||
if tt.expectedCount == 1 {
|
|
||||||
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
|
|
||||||
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
|
|
||||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPortInfoEmpty(t *testing.T) {
|
func TestPortInfoEmpty(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
@@ -34,7 +35,6 @@ import (
|
|||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -289,15 +289,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
}
|
}
|
||||||
|
|
||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
if c.engine != nil && c.engine.wgInterface != nil {
|
engine := c.engine
|
||||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
|
c.engine = nil
|
||||||
if err := c.engine.Stop(); err != nil {
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
|
if engine != nil && engine.wgInterface != nil {
|
||||||
|
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
||||||
|
if err := engine.Stop(); err != nil {
|
||||||
log.Errorf("Failed to stop engine: %v", err)
|
log.Errorf("Failed to stop engine: %v", err)
|
||||||
}
|
}
|
||||||
c.engine = nil
|
|
||||||
}
|
}
|
||||||
c.engineMutex.Unlock()
|
|
||||||
c.statusRecorder.ClientTeardown()
|
c.statusRecorder.ClientTeardown()
|
||||||
|
|
||||||
backOff.Reset()
|
backOff.Reset()
|
||||||
@@ -382,19 +385,12 @@ func (c *ConnectClient) Status() StatusType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) Stop() error {
|
func (c *ConnectClient) Stop() error {
|
||||||
if c == nil {
|
engine := c.Engine()
|
||||||
return nil
|
if engine != nil {
|
||||||
|
if err := engine.Stop(); err != nil {
|
||||||
|
return fmt.Errorf("stop engine: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
c.engineMutex.Lock()
|
|
||||||
defer c.engineMutex.Unlock()
|
|
||||||
|
|
||||||
if c.engine == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err := c.engine.Stop(); err != nil {
|
|
||||||
return fmt.Errorf("stop engine: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ nftables.txt: Anonymized nftables rules with packet counters, if --system-info f
|
|||||||
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
||||||
config.txt: Anonymized configuration information of the NetBird client.
|
config.txt: Anonymized configuration information of the NetBird client.
|
||||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||||
state.json: Anonymized client state dump containing netbird states.
|
state.json: Anonymized client state dump containing netbird states for the active profile.
|
||||||
mutex.prof: Mutex profiling information.
|
mutex.prof: Mutex profiling information.
|
||||||
goroutine.prof: Goroutine profiling information.
|
goroutine.prof: Goroutine profiling information.
|
||||||
block.prof: Block profiling information.
|
block.prof: Block profiling information.
|
||||||
@@ -564,6 +564,8 @@ func (g *BundleGenerator) addStateFile() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debugf("Adding state file from: %s", path)
|
||||||
|
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
|||||||
@@ -14,6 +14,9 @@ type WGIface interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addWgShow() error {
|
func (g *BundleGenerator) addWgShow() error {
|
||||||
|
if g.statusRecorder == nil {
|
||||||
|
return fmt.Errorf("no status recorder available for wg show")
|
||||||
|
}
|
||||||
result, err := g.statusRecorder.PeersStatus()
|
result, err := g.statusRecorder.PeersStatus()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
@@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
var err error
|
|
||||||
|
|
||||||
if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
|
|
||||||
log.Errorf("failed to update shutdown state: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
searchDomains []string
|
searchDomains []string
|
||||||
matchDomains []string
|
matchDomains []string
|
||||||
)
|
)
|
||||||
|
|
||||||
err = s.recordSystemDNSSettings(true)
|
if err := s.recordSystemDNSSettings(true); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
|
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.RouteAll {
|
if config.RouteAll {
|
||||||
searchDomains = append(searchDomains, "\"\"")
|
searchDomains = append(searchDomains, "\"\"")
|
||||||
err = s.addLocalDNS()
|
if err := s.addLocalDNS(); err != nil {
|
||||||
if err != nil {
|
log.Warnf("failed to add local DNS: %v", err)
|
||||||
log.Infof("failed to enable split DNS")
|
|
||||||
}
|
}
|
||||||
|
s.updateState(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, dConf := range config.Domains {
|
for _, dConf := range config.Domains {
|
||||||
@@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
}
|
}
|
||||||
|
|
||||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||||
|
var err error
|
||||||
if len(matchDomains) != 0 {
|
if len(matchDomains) != 0 {
|
||||||
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
||||||
} else {
|
} else {
|
||||||
@@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add match domains: %w", err)
|
return fmt.Errorf("add match domains: %w", err)
|
||||||
}
|
}
|
||||||
|
s.updateState(stateManager)
|
||||||
|
|
||||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||||
if len(searchDomains) != 0 {
|
if len(searchDomains) != 0 {
|
||||||
@@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add search domains: %w", err)
|
return fmt.Errorf("add search domains: %w", err)
|
||||||
}
|
}
|
||||||
|
s.updateState(stateManager)
|
||||||
|
|
||||||
if err := s.flushDNSCache(); err != nil {
|
if err := s.flushDNSCache(); err != nil {
|
||||||
log.Errorf("failed to flush DNS cache: %v", err)
|
log.Errorf("failed to flush DNS cache: %v", err)
|
||||||
@@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) {
|
||||||
|
if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil {
|
||||||
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) string() string {
|
func (s *systemConfigurator) string() string {
|
||||||
return "scutil"
|
return "scutil"
|
||||||
}
|
}
|
||||||
@@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
|||||||
func (s *systemConfigurator) addLocalDNS() error {
|
func (s *systemConfigurator) addLocalDNS() error {
|
||||||
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
|
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
|
||||||
if err := s.recordSystemDNSSettings(true); err != nil {
|
if err := s.recordSystemDNSSettings(true); err != nil {
|
||||||
log.Errorf("Unable to get system DNS configuration")
|
|
||||||
return fmt.Errorf("recordSystemDNSSettings(): %w", err)
|
return fmt.Errorf("recordSystemDNSSettings(): %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||||
if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 {
|
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
|
||||||
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Info("Not enabling local DNS server")
|
log.Info("Not enabling local DNS server")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.addSearchDomains(
|
||||||
|
localKey,
|
||||||
|
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("add search domains: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
111
client/internal/dns/host_darwin_test.go
Normal file
111
client/internal/dns/host_darwin_test.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping scutil integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
stateFile := filepath.Join(tmpDir, "state.json")
|
||||||
|
|
||||||
|
sm := statemanager.New(stateFile)
|
||||||
|
sm.RegisterState(&ShutdownState{})
|
||||||
|
sm.Start()
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, sm.Stop(context.Background()))
|
||||||
|
}()
|
||||||
|
|
||||||
|
configurator := &systemConfigurator{
|
||||||
|
createdKeys: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
config := HostDNSConfig{
|
||||||
|
ServerIP: netip.MustParseAddr("100.64.0.1"),
|
||||||
|
ServerPort: 53,
|
||||||
|
RouteAll: true,
|
||||||
|
Domains: []DomainConfig{
|
||||||
|
{Domain: "example.com", MatchOnly: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, sm.PersistState(context.Background()))
|
||||||
|
|
||||||
|
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||||
|
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||||
|
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||||
|
_ = removeTestDNSKey(key)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||||
|
exists, err := checkDNSKeyExists(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
if exists {
|
||||||
|
t.Logf("Key %s exists before cleanup", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sm2 := statemanager.New(stateFile)
|
||||||
|
sm2.RegisterState(&ShutdownState{})
|
||||||
|
err = sm2.LoadState(&ShutdownState{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
state := sm2.GetState(&ShutdownState{})
|
||||||
|
if state == nil {
|
||||||
|
t.Skip("State not saved, skipping cleanup test")
|
||||||
|
}
|
||||||
|
|
||||||
|
shutdownState, ok := state.(*ShutdownState)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
err = shutdownState.Cleanup()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||||
|
exists, err := checkDNSKeyExists(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkDNSKeyExists(key string) (bool, error) {
|
||||||
|
cmd := exec.Command(scutilPath)
|
||||||
|
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
if strings.Contains(string(output), "No such key") {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return !strings.Contains(string(output), "No such key"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeTestDNSKey(key string) error {
|
||||||
|
cmd := exec.Command(scutilPath)
|
||||||
|
cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n")
|
||||||
|
_, err := cmd.CombinedOutput()
|
||||||
|
return err
|
||||||
|
}
|
||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/winregistry"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -178,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := stateManager.UpdateState(&ShutdownState{
|
r.updateState(stateManager)
|
||||||
Guid: r.guid,
|
|
||||||
GPO: r.gpo,
|
|
||||||
NRPTEntryCount: r.nrptEntryCount,
|
|
||||||
}); err != nil {
|
|
||||||
log.Errorf("failed to update shutdown state: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var searchDomains, matchDomains []string
|
var searchDomains, matchDomains []string
|
||||||
for _, dConf := range config.Domains {
|
for _, dConf := range config.Domains {
|
||||||
@@ -197,6 +192,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
|
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.removeDNSMatchPolicies(); err != nil {
|
||||||
|
log.Errorf("cleanup old dns match policies: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
if len(matchDomains) != 0 {
|
if len(matchDomains) != 0 {
|
||||||
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -204,19 +203,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
}
|
}
|
||||||
r.nrptEntryCount = count
|
r.nrptEntryCount = count
|
||||||
} else {
|
} else {
|
||||||
if err := r.removeDNSMatchPolicies(); err != nil {
|
|
||||||
return fmt.Errorf("remove dns match policies: %w", err)
|
|
||||||
}
|
|
||||||
r.nrptEntryCount = 0
|
r.nrptEntryCount = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := stateManager.UpdateState(&ShutdownState{
|
r.updateState(stateManager)
|
||||||
Guid: r.guid,
|
|
||||||
GPO: r.gpo,
|
|
||||||
NRPTEntryCount: r.nrptEntryCount,
|
|
||||||
}); err != nil {
|
|
||||||
log.Errorf("failed to update shutdown state: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.updateSearchDomains(searchDomains); err != nil {
|
if err := r.updateSearchDomains(searchDomains); err != nil {
|
||||||
return fmt.Errorf("update search domains: %w", err)
|
return fmt.Errorf("update search domains: %w", err)
|
||||||
@@ -227,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
|
||||||
|
if err := stateManager.UpdateState(&ShutdownState{
|
||||||
|
Guid: r.guid,
|
||||||
|
GPO: r.gpo,
|
||||||
|
NRPTEntryCount: r.nrptEntryCount,
|
||||||
|
}); err != nil {
|
||||||
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
||||||
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil {
|
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil {
|
||||||
return fmt.Errorf("adding dns setup for all failed: %w", err)
|
return fmt.Errorf("adding dns setup for all failed: %w", err)
|
||||||
@@ -273,9 +273,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s
|
|||||||
return fmt.Errorf("remove existing dns policy: %w", err)
|
return fmt.Errorf("remove existing dns policy: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
|
regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
|
return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
|
||||||
}
|
}
|
||||||
defer closer(regKey)
|
defer closer(regKey)
|
||||||
|
|
||||||
|
|||||||
102
client/internal/dns/host_windows_test.go
Normal file
102
client/internal/dns/host_windows_test.go
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sys/windows/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
||||||
|
// when the number of match domains decreases between configuration changes.
|
||||||
|
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping registry integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
defer cleanupRegistryKeys(t)
|
||||||
|
cleanupRegistryKeys(t)
|
||||||
|
|
||||||
|
testIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
|
||||||
|
// Create a test interface registry key so updateSearchDomains doesn't fail
|
||||||
|
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
|
||||||
|
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
|
||||||
|
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
|
||||||
|
require.NoError(t, err, "Should create test interface registry key")
|
||||||
|
testKey.Close()
|
||||||
|
defer func() {
|
||||||
|
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
|
||||||
|
}()
|
||||||
|
|
||||||
|
cfg := ®istryConfigurator{
|
||||||
|
guid: testGUID,
|
||||||
|
gpo: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
config5 := HostDNSConfig{
|
||||||
|
ServerIP: testIP,
|
||||||
|
Domains: []DomainConfig{
|
||||||
|
{Domain: "domain1.com", MatchOnly: true},
|
||||||
|
{Domain: "domain2.com", MatchOnly: true},
|
||||||
|
{Domain: "domain3.com", MatchOnly: true},
|
||||||
|
{Domain: "domain4.com", MatchOnly: true},
|
||||||
|
{Domain: "domain5.com", MatchOnly: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cfg.applyDNSConfig(config5, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify all 5 entries exist
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, exists, "Entry %d should exist after first config", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
config2 := HostDNSConfig{
|
||||||
|
ServerIP: testIP,
|
||||||
|
Domains: []DomainConfig{
|
||||||
|
{Domain: "domain1.com", MatchOnly: true},
|
||||||
|
{Domain: "domain2.com", MatchOnly: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cfg.applyDNSConfig(config2, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify first 2 entries exist
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, exists, "Entry %d should exist after second config", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify entries 2-4 are cleaned up
|
||||||
|
for i := 2; i < 5; i++ {
|
||||||
|
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func registryKeyExists(path string) (bool, error) {
|
||||||
|
k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
|
||||||
|
if err != nil {
|
||||||
|
if err == registry.ErrNotExist {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
k.Close()
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupRegistryKeys(*testing.T) {
|
||||||
|
cfg := ®istryConfigurator{nrptEntryCount: 10}
|
||||||
|
_ = cfg.removeDNSMatchPolicies()
|
||||||
|
}
|
||||||
@@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface {
|
|||||||
|
|
||||||
// DefaultServer dns server object
|
// DefaultServer dns server object
|
||||||
type DefaultServer struct {
|
type DefaultServer struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
|
shutdownWg sync.WaitGroup
|
||||||
// disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running.
|
// disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running.
|
||||||
// This is different from ServiceEnable=false from management which completely disables the DNS service.
|
// This is different from ServiceEnable=false from management which completely disables the DNS service.
|
||||||
disableSys bool
|
disableSys bool
|
||||||
@@ -318,6 +319,7 @@ func (s *DefaultServer) DnsIP() netip.Addr {
|
|||||||
// Stop stops the server
|
// Stop stops the server
|
||||||
func (s *DefaultServer) Stop() {
|
func (s *DefaultServer) Stop() {
|
||||||
s.ctxCancel()
|
s.ctxCancel()
|
||||||
|
s.shutdownWg.Wait()
|
||||||
|
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
@@ -507,8 +509,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
|
|
||||||
s.applyHostConfig()
|
s.applyHostConfig()
|
||||||
|
|
||||||
|
s.shutdownWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
// persist dns state right away
|
defer s.shutdownWg.Done()
|
||||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||||
log.Errorf("Failed to persist dns state: %v", err)
|
log.Errorf("Failed to persist dns state: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ const (
|
|||||||
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
|
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
|
||||||
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
|
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
|
||||||
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
|
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
|
||||||
|
systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS"
|
||||||
systemdDbusResolvConfModeForeign = "foreign"
|
systemdDbusResolvConfModeForeign = "foreign"
|
||||||
|
|
||||||
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
|
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
|
||||||
@@ -102,6 +103,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
|||||||
log.Warnf("failed to set DNSSEC to 'no': %v", err)
|
log.Warnf("failed to set DNSSEC to 'no': %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off
|
||||||
|
if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil {
|
||||||
|
log.Warnf("failed to set DNSOverTLS to 'no': %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
searchDomains []string
|
searchDomains []string
|
||||||
matchDomains []string
|
matchDomains []string
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ShutdownState struct {
|
type ShutdownState struct {
|
||||||
|
CreatedKeys []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Name() string {
|
func (s *ShutdownState) Name() string {
|
||||||
@@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error {
|
|||||||
return fmt.Errorf("create host manager: %w", err)
|
return fmt.Errorf("create host manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, key := range s.CreatedKeys {
|
||||||
|
manager.createdKeys[key] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
||||||
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
78
client/internal/dnsfwd/cache.go
Normal file
78
client/internal/dnsfwd/cache.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
package dnsfwd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
type cache struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
records map[string]*cacheEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
type cacheEntry struct {
|
||||||
|
ip4Addrs []netip.Addr
|
||||||
|
ip6Addrs []netip.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCache() *cache {
|
||||||
|
return &cache{
|
||||||
|
records: make(map[string]*cacheEntry),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
entry, exists := c.records[normalizeDomain(domain)]
|
||||||
|
if !exists {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch reqType {
|
||||||
|
case dns.TypeA:
|
||||||
|
return slices.Clone(entry.ip4Addrs), true
|
||||||
|
case dns.TypeAAAA:
|
||||||
|
return slices.Clone(entry.ip6Addrs), true
|
||||||
|
default:
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
norm := normalizeDomain(domain)
|
||||||
|
entry, exists := c.records[norm]
|
||||||
|
if !exists {
|
||||||
|
entry = &cacheEntry{}
|
||||||
|
c.records[norm] = entry
|
||||||
|
}
|
||||||
|
|
||||||
|
switch reqType {
|
||||||
|
case dns.TypeA:
|
||||||
|
entry.ip4Addrs = slices.Clone(addrs)
|
||||||
|
case dns.TypeAAAA:
|
||||||
|
entry.ip6Addrs = slices.Clone(addrs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// unset removes cached entries for the given domain and request type.
|
||||||
|
func (c *cache) unset(domain string) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
delete(c.records, normalizeDomain(domain))
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeDomain converts an input domain into a canonical form used as cache key:
|
||||||
|
// lowercase and fully-qualified (with trailing dot).
|
||||||
|
func normalizeDomain(domain string) string {
|
||||||
|
// dns.Fqdn ensures trailing dot; ToLower for consistent casing
|
||||||
|
return dns.Fqdn(strings.ToLower(domain))
|
||||||
|
}
|
||||||
86
client/internal/dnsfwd/cache_test.go
Normal file
86
client/internal/dnsfwd/cache_test.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package dnsfwd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustAddr(t *testing.T, s string) netip.Addr {
|
||||||
|
t.Helper()
|
||||||
|
a, err := netip.ParseAddr(s)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parse addr %s: %v", s, err)
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheNormalization(t *testing.T) {
|
||||||
|
c := newCache()
|
||||||
|
|
||||||
|
// Mixed case, without trailing dot
|
||||||
|
domainInput := "ExAmPlE.CoM"
|
||||||
|
ipv4 := []netip.Addr{mustAddr(t, "1.2.3.4")}
|
||||||
|
c.set(domainInput, 1 /* dns.TypeA */, ipv4)
|
||||||
|
|
||||||
|
// Lookup with lower, with trailing dot
|
||||||
|
if got, ok := c.get("example.com.", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
|
||||||
|
t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lookup with different casing again
|
||||||
|
if got, ok := c.get("EXAMPLE.COM", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
|
||||||
|
t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheSeparateTypes(t *testing.T) {
|
||||||
|
c := newCache()
|
||||||
|
|
||||||
|
domain := "test.local"
|
||||||
|
ipv4 := []netip.Addr{mustAddr(t, "10.0.0.1")}
|
||||||
|
ipv6 := []netip.Addr{mustAddr(t, "2001:db8::1")}
|
||||||
|
|
||||||
|
c.set(domain, 1 /* A */, ipv4)
|
||||||
|
c.set(domain, 28 /* AAAA */, ipv6)
|
||||||
|
|
||||||
|
got4, ok4 := c.get(domain, 1)
|
||||||
|
if !ok4 || len(got4) != 1 || got4[0] != ipv4[0] {
|
||||||
|
t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4)
|
||||||
|
}
|
||||||
|
|
||||||
|
got6, ok6 := c.get(domain, 28)
|
||||||
|
if !ok6 || len(got6) != 1 || got6[0] != ipv6[0] {
|
||||||
|
t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheCloneOnGetAndSet(t *testing.T) {
|
||||||
|
c := newCache()
|
||||||
|
domain := "clone.test"
|
||||||
|
|
||||||
|
src := []netip.Addr{mustAddr(t, "8.8.8.8")}
|
||||||
|
c.set(domain, 1, src)
|
||||||
|
|
||||||
|
// Mutate source slice; cache should be unaffected
|
||||||
|
src[0] = mustAddr(t, "9.9.9.9")
|
||||||
|
|
||||||
|
got, ok := c.get(domain, 1)
|
||||||
|
if !ok || len(got) != 1 || got[0].String() != "8.8.8.8" {
|
||||||
|
t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", got, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mutate returned slice; internal cache should remain unchanged
|
||||||
|
got[0] = mustAddr(t, "4.4.4.4")
|
||||||
|
got2, ok2 := c.get(domain, 1)
|
||||||
|
if !ok2 || len(got2) != 1 || got2[0].String() != "8.8.8.8" {
|
||||||
|
t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", got2, ok2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCacheMiss(t *testing.T) {
|
||||||
|
c := newCache()
|
||||||
|
if got, ok := c.get("missing.example", 1); ok || got != nil {
|
||||||
|
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@@ -46,6 +46,7 @@ type DNSForwarder struct {
|
|||||||
fwdEntries []*ForwarderEntry
|
fwdEntries []*ForwarderEntry
|
||||||
firewall firewaller
|
firewall firewaller
|
||||||
resolver resolver
|
resolver resolver
|
||||||
|
cache *cache
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
|
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
|
||||||
@@ -56,6 +57,7 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
|
|||||||
firewall: firewall,
|
firewall: firewall,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
resolver: net.DefaultResolver,
|
resolver: net.DefaultResolver,
|
||||||
|
cache: newCache(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,10 +105,39 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
|||||||
f.mutex.Lock()
|
f.mutex.Lock()
|
||||||
defer f.mutex.Unlock()
|
defer f.mutex.Unlock()
|
||||||
|
|
||||||
|
// remove cache entries for domains that no longer appear
|
||||||
|
f.removeStaleCacheEntries(f.fwdEntries, entries)
|
||||||
|
|
||||||
f.fwdEntries = entries
|
f.fwdEntries = entries
|
||||||
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
|
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// removeStaleCacheEntries unsets cache items for domains that were present
|
||||||
|
// in the old list but not present in the new list.
|
||||||
|
func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) {
|
||||||
|
if f.cache == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newSet := make(map[string]struct{}, len(newEntries))
|
||||||
|
for _, e := range newEntries {
|
||||||
|
if e == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newSet[e.Domain.PunycodeString()] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, e := range oldEntries {
|
||||||
|
if e == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pattern := e.Domain.PunycodeString()
|
||||||
|
if _, ok := newSet[pattern]; !ok {
|
||||||
|
f.cache.unset(pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) Close(ctx context.Context) error {
|
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||||
var result *multierror.Error
|
var result *multierror.Error
|
||||||
|
|
||||||
@@ -171,6 +202,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
|||||||
|
|
||||||
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||||
f.addIPsToResponse(resp, domain, ips)
|
f.addIPsToResponse(resp, domain, ips)
|
||||||
|
f.cache.set(domain, question.Qtype, ips)
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
@@ -282,29 +314,69 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns
|
|||||||
resp.Rcode = dns.RcodeSuccess
|
resp.Rcode = dns.RcodeSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
|
||||||
func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
|
func (f *DNSForwarder) handleDNSError(
|
||||||
|
ctx context.Context,
|
||||||
|
w dns.ResponseWriter,
|
||||||
|
question dns.Question,
|
||||||
|
resp *dns.Msg,
|
||||||
|
domain string,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
// Default to SERVFAIL; override below when appropriate.
|
||||||
|
resp.Rcode = dns.RcodeServerFailure
|
||||||
|
|
||||||
|
qType := question.Qtype
|
||||||
|
qTypeName := dns.TypeToString[qType]
|
||||||
|
|
||||||
|
// Prefer typed DNS errors; fall back to generic logging otherwise.
|
||||||
var dnsErr *net.DNSError
|
var dnsErr *net.DNSError
|
||||||
|
if !errors.As(err, &dnsErr) {
|
||||||
switch {
|
log.Warnf(errResolveFailed, domain, err)
|
||||||
case errors.As(err, &dnsErr):
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||||
resp.Rcode = dns.RcodeServerFailure
|
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||||
if dnsErr.IsNotFound {
|
|
||||||
f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
|
|
||||||
}
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if dnsErr.Server != "" {
|
// NotFound: set NXDOMAIN / appropriate code via helper.
|
||||||
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
|
if dnsErr.IsNotFound {
|
||||||
} else {
|
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
||||||
log.Warnf(errResolveFailed, domain, err)
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||||
|
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||||
}
|
}
|
||||||
default:
|
f.cache.set(domain, question.Qtype, nil)
|
||||||
resp.Rcode = dns.RcodeServerFailure
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upstream failed but we might have a cached answer—serve it if present.
|
||||||
|
if ips, ok := f.cache.get(domain, qType); ok {
|
||||||
|
if len(ips) > 0 {
|
||||||
|
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||||
|
f.addIPsToResponse(resp, domain, ips)
|
||||||
|
resp.Rcode = dns.RcodeSuccess
|
||||||
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||||
|
log.Errorf("failed to write cached DNS response: %v", writeErr)
|
||||||
|
}
|
||||||
|
} else { // send NXDOMAIN / appropriate code if cache is empty
|
||||||
|
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
||||||
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||||
|
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// No cache. Log with or without the server field for more context.
|
||||||
|
if dnsErr.Server != "" {
|
||||||
|
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
|
||||||
|
} else {
|
||||||
log.Warnf(errResolveFailed, domain, err)
|
log.Warnf(errResolveFailed, domain, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
// Write final failure response.
|
||||||
log.Errorf("failed to write failure DNS response: %v", err)
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||||
|
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -648,6 +648,95 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
|||||||
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
|
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensures that when the first query succeeds and populates the cache,
|
||||||
|
// a subsequent upstream failure still returns a successful response from cache.
|
||||||
|
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
d, err := domain.FromString("example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}}
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
ip := netip.MustParseAddr("1.2.3.4")
|
||||||
|
|
||||||
|
// First call resolves successfully and populates cache
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
|
||||||
|
Return([]netip.Addr{ip}, nil).Once()
|
||||||
|
|
||||||
|
// Second call fails upstream; forwarder should serve from cache
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
|
||||||
|
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
|
||||||
|
|
||||||
|
// First query: populate cache
|
||||||
|
q1 := &dns.Msg{}
|
||||||
|
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
|
w1 := &test.MockResponseWriter{}
|
||||||
|
resp1 := forwarder.handleDNSQuery(w1, q1)
|
||||||
|
require.NotNil(t, resp1)
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
|
require.Len(t, resp1.Answer, 1)
|
||||||
|
|
||||||
|
// Second query: serve from cache after upstream failure
|
||||||
|
q2 := &dns.Msg{}
|
||||||
|
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
|
var writtenResp *dns.Msg
|
||||||
|
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||||
|
_ = forwarder.handleDNSQuery(w2, q2)
|
||||||
|
|
||||||
|
require.NotNil(t, writtenResp, "expected response to be written")
|
||||||
|
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||||
|
require.Len(t, writtenResp.Answer, 1)
|
||||||
|
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verifies that cache normalization works across casing and trailing dot variations.
|
||||||
|
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||||
|
mockResolver := &MockResolver{}
|
||||||
|
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||||
|
forwarder.resolver = mockResolver
|
||||||
|
|
||||||
|
d, err := domain.FromString("ExAmPlE.CoM")
|
||||||
|
require.NoError(t, err)
|
||||||
|
entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}}
|
||||||
|
forwarder.UpdateDomains(entries)
|
||||||
|
|
||||||
|
ip := netip.MustParseAddr("9.8.7.6")
|
||||||
|
|
||||||
|
// Initial resolution with mixed case to populate cache
|
||||||
|
mixedQuery := "ExAmPlE.CoM"
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))).
|
||||||
|
Return([]netip.Addr{ip}, nil).Once()
|
||||||
|
|
||||||
|
q1 := &dns.Msg{}
|
||||||
|
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||||
|
w1 := &test.MockResponseWriter{}
|
||||||
|
resp1 := forwarder.handleDNSQuery(w1, q1)
|
||||||
|
require.NotNil(t, resp1)
|
||||||
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
|
require.Len(t, resp1.Answer, 1)
|
||||||
|
|
||||||
|
// Subsequent query without dot and upper case should hit cache even if upstream fails
|
||||||
|
// Forwarder lowercases and uses the question name as-is (no trailing dot here)
|
||||||
|
mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")).
|
||||||
|
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
|
||||||
|
|
||||||
|
q2 := &dns.Msg{}
|
||||||
|
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||||
|
var writtenResp *dns.Msg
|
||||||
|
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||||
|
_ = forwarder.handleDNSQuery(w2, q2)
|
||||||
|
|
||||||
|
require.NotNil(t, writtenResp)
|
||||||
|
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||||
|
require.Len(t, writtenResp.Answer, 1)
|
||||||
|
|
||||||
|
mockResolver.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||||
// Test complex overlapping pattern scenarios
|
// Test complex overlapping pattern scenarios
|
||||||
mockFirewall := &MockFirewall{}
|
mockFirewall := &MockFirewall{}
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ type Manager struct {
|
|||||||
fwRules []firewall.Rule
|
fwRules []firewall.Rule
|
||||||
tcpRules []firewall.Rule
|
tcpRules []firewall.Rule
|
||||||
dnsForwarder *DNSForwarder
|
dnsForwarder *DNSForwarder
|
||||||
port uint16
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListenPort() uint16 {
|
func ListenPort() uint16 {
|
||||||
@@ -49,11 +48,16 @@ func ListenPort() uint16 {
|
|||||||
return listenPort
|
return listenPort
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager {
|
func SetListenPort(port uint16) {
|
||||||
|
listenPortMu.Lock()
|
||||||
|
listenPort = port
|
||||||
|
listenPortMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
|
||||||
return &Manager{
|
return &Manager{
|
||||||
firewall: fw,
|
firewall: fw,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
port: port,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -67,12 +71,6 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.port > 0 {
|
|
||||||
listenPortMu.Lock()
|
|
||||||
listenPort = m.port
|
|
||||||
listenPortMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
|
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
|
||||||
go func() {
|
go func() {
|
||||||
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
||||||
|
|||||||
@@ -200,8 +200,10 @@ type Engine struct {
|
|||||||
flowManager nftypes.FlowManager
|
flowManager nftypes.FlowManager
|
||||||
|
|
||||||
// WireGuard interface monitor
|
// WireGuard interface monitor
|
||||||
wgIfaceMonitor *WGIfaceMonitor
|
wgIfaceMonitor *WGIfaceMonitor
|
||||||
wgIfaceMonitorWg sync.WaitGroup
|
|
||||||
|
// shutdownWg tracks all long-running goroutines to ensure clean shutdown
|
||||||
|
shutdownWg sync.WaitGroup
|
||||||
|
|
||||||
// dns forwarder port
|
// dns forwarder port
|
||||||
dnsFwdPort uint16
|
dnsFwdPort uint16
|
||||||
@@ -326,10 +328,6 @@ func (e *Engine) Stop() error {
|
|||||||
e.cancel()
|
e.cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
|
|
||||||
// Removing peers happens in the conn.Close() asynchronously
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
|
|
||||||
e.close()
|
e.close()
|
||||||
|
|
||||||
// stop flow manager after wg interface is gone
|
// stop flow manager after wg interface is gone
|
||||||
@@ -337,8 +335,6 @@ func (e *Engine) Stop() error {
|
|||||||
e.flowManager.Close()
|
e.flowManager.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("stopped Netbird Engine")
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -349,12 +345,52 @@ func (e *Engine) Stop() error {
|
|||||||
log.Errorf("failed to persist state: %v", err)
|
log.Errorf("failed to persist state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop WireGuard interface monitor and wait for it to exit
|
timeout := e.calculateShutdownTimeout()
|
||||||
e.wgIfaceMonitorWg.Wait()
|
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
||||||
|
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("stopped Netbird Engine")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
|
||||||
|
func (e *Engine) calculateShutdownTimeout() time.Duration {
|
||||||
|
peerCount := len(e.peerStore.PeersPubKey())
|
||||||
|
|
||||||
|
baseTimeout := 10 * time.Second
|
||||||
|
perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond
|
||||||
|
timeout := baseTimeout + perPeerTimeout
|
||||||
|
|
||||||
|
maxTimeout := 30 * time.Second
|
||||||
|
if timeout > maxTimeout {
|
||||||
|
timeout = maxTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
return timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitWithContext waits for WaitGroup with timeout, returns ctx.Err() on timeout.
|
||||||
|
func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
||||||
// Connections to remote peers are not established here.
|
// Connections to remote peers are not established here.
|
||||||
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
||||||
@@ -484,14 +520,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
|
|
||||||
// monitor WireGuard interface lifecycle and restart engine on changes
|
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||||
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||||
e.wgIfaceMonitorWg.Add(1)
|
e.shutdownWg.Add(1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer e.wgIfaceMonitorWg.Done()
|
defer e.shutdownWg.Done()
|
||||||
|
|
||||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
||||||
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||||
e.restartEngine()
|
e.triggerClientRestart()
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
log.Warnf("WireGuard interface monitor: %s", err)
|
log.Warnf("WireGuard interface monitor: %s", err)
|
||||||
}
|
}
|
||||||
@@ -892,7 +928,9 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create ssh server: %w", err)
|
return fmt.Errorf("create ssh server: %w", err)
|
||||||
}
|
}
|
||||||
|
e.shutdownWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer e.shutdownWg.Done()
|
||||||
// blocking
|
// blocking
|
||||||
err = e.sshServer.Start()
|
err = e.sshServer.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -950,7 +988,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
|
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
|
||||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||||
func (e *Engine) receiveManagementEvents() {
|
func (e *Engine) receiveManagementEvents() {
|
||||||
|
e.shutdownWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer e.shutdownWg.Done()
|
||||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to get system info with checks: %v", err)
|
log.Warnf("failed to get system info with checks: %v", err)
|
||||||
@@ -1368,7 +1408,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
|||||||
|
|
||||||
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
|
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
|
||||||
func (e *Engine) receiveSignalEvents() {
|
func (e *Engine) receiveSignalEvents() {
|
||||||
|
e.shutdownWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer e.shutdownWg.Done()
|
||||||
// connect to a stream of messages coming from the signal server
|
// connect to a stream of messages coming from the signal server
|
||||||
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
@@ -1724,8 +1766,10 @@ func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// restartEngine restarts the engine by cancelling the client context
|
// triggerClientRestart triggers a full client restart by cancelling the client context.
|
||||||
func (e *Engine) restartEngine() {
|
// Note: This does NOT just restart the engine - it cancels the entire client context,
|
||||||
|
// which causes the connect client's retry loop to create a completely new engine.
|
||||||
|
func (e *Engine) triggerClientRestart() {
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
@@ -1747,7 +1791,9 @@ func (e *Engine) startNetworkMonitor() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
e.networkMonitor = networkmonitor.New()
|
e.networkMonitor = networkmonitor.New()
|
||||||
|
e.shutdownWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer e.shutdownWg.Done()
|
||||||
if err := e.networkMonitor.Listen(e.ctx); err != nil {
|
if err := e.networkMonitor.Listen(e.ctx); err != nil {
|
||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
log.Infof("network monitor stopped")
|
log.Infof("network monitor stopped")
|
||||||
@@ -1757,8 +1803,8 @@ func (e *Engine) startNetworkMonitor() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("Network monitor: detected network change, restarting engine")
|
log.Infof("Network monitor: detected network change, triggering client restart")
|
||||||
e.restartEngine()
|
e.triggerClientRestart()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1849,6 +1895,10 @@ func (e *Engine) updateDNSForwarder(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if forwarderPort > 0 {
|
||||||
|
dnsfwd.SetListenPort(forwarderPort)
|
||||||
|
}
|
||||||
|
|
||||||
if !enabled {
|
if !enabled {
|
||||||
if e.dnsForwardMgr == nil {
|
if e.dnsForwardMgr == nil {
|
||||||
return
|
return
|
||||||
@@ -1862,7 +1912,7 @@ func (e *Engine) updateDNSForwarder(
|
|||||||
if len(fwdEntries) > 0 {
|
if len(fwdEntries) > 0 {
|
||||||
switch {
|
switch {
|
||||||
case e.dnsForwardMgr == nil:
|
case e.dnsForwardMgr == nil:
|
||||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
|
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
|
||||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||||
log.Errorf("failed to start DNS forward: %v", err)
|
log.Errorf("failed to start DNS forward: %v", err)
|
||||||
e.dnsForwardMgr = nil
|
e.dnsForwardMgr = nil
|
||||||
@@ -1892,7 +1942,7 @@ func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPor
|
|||||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||||
log.Errorf("failed to stop DNS forward: %v", err)
|
log.Errorf("failed to stop DNS forward: %v", err)
|
||||||
}
|
}
|
||||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
|
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
|
||||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||||
log.Errorf("failed to start DNS forward: %v", err)
|
log.Errorf("failed to start DNS forward: %v", err)
|
||||||
e.dnsForwardMgr = nil
|
e.dnsForwardMgr = nil
|
||||||
|
|||||||
@@ -105,6 +105,10 @@ type MockWGIface struct {
|
|||||||
LastActivitiesFunc func() map[string]monotime.Time
|
LastActivitiesFunc func() map[string]monotime.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
|
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
|
||||||
return nil, fmt.Errorf("not implemented")
|
return nil, fmt.Errorf("not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ type wgIfaceBase interface {
|
|||||||
UpdateAddr(newAddr string) error
|
UpdateAddr(newAddr string) error
|
||||||
GetProxy() wgproxy.Proxy
|
GetProxy() wgproxy.Proxy
|
||||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
|
RemoveEndpointAddress(key string) error
|
||||||
RemovePeer(peerKey string) error
|
RemovePeer(peerKey string) error
|
||||||
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||||
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||||
|
|||||||
82
client/internal/lazyconn/activity/lazy_conn.go
Normal file
82
client/internal/lazyconn/activity/lazy_conn.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package activity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// lazyConn detects activity when WireGuard attempts to send packets.
|
||||||
|
// It does not deliver packets, only signals that activity occurred.
|
||||||
|
type lazyConn struct {
|
||||||
|
activityCh chan struct{}
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// newLazyConn creates a new lazyConn for activity detection.
|
||||||
|
func newLazyConn() *lazyConn {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
return &lazyConn{
|
||||||
|
activityCh: make(chan struct{}, 1),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read blocks until the connection is closed.
|
||||||
|
func (c *lazyConn) Read(_ []byte) (n int, err error) {
|
||||||
|
<-c.ctx.Done()
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write signals activity detection when ICEBind routes packets to this endpoint.
|
||||||
|
func (c *lazyConn) Write(b []byte) (n int, err error) {
|
||||||
|
if c.ctx.Err() != nil {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case c.activityCh <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ActivityChan returns the channel that signals when activity is detected.
|
||||||
|
func (c *lazyConn) ActivityChan() <-chan struct{} {
|
||||||
|
return c.activityCh
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the connection.
|
||||||
|
func (c *lazyConn) Close() error {
|
||||||
|
c.cancel()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr returns the local address.
|
||||||
|
func (c *lazyConn) LocalAddr() net.Addr {
|
||||||
|
return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteAddr returns the remote address.
|
||||||
|
func (c *lazyConn) RemoteAddr() net.Addr {
|
||||||
|
return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline sets the read and write deadlines.
|
||||||
|
func (c *lazyConn) SetDeadline(_ time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline sets the deadline for future Read calls.
|
||||||
|
func (c *lazyConn) SetReadDeadline(_ time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteDeadline sets the deadline for future Write calls.
|
||||||
|
func (c *lazyConn) SetWriteDeadline(_ time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
127
client/internal/lazyconn/activity/listener_bind.go
Normal file
127
client/internal/lazyconn/activity/listener_bind.go
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
package activity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
type bindProvider interface {
|
||||||
|
GetBind() device.EndpointManager
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// lazyBindPort is an obscure port used for lazy peer endpoints to avoid confusion with real peers.
|
||||||
|
// The actual routing is done via fakeIP in ICEBind, not by this port.
|
||||||
|
lazyBindPort = 17473
|
||||||
|
)
|
||||||
|
|
||||||
|
// BindListener uses lazyConn with bind implementations for direct data passing in userspace bind mode.
|
||||||
|
type BindListener struct {
|
||||||
|
wgIface WgInterface
|
||||||
|
peerCfg lazyconn.PeerConfig
|
||||||
|
done sync.WaitGroup
|
||||||
|
|
||||||
|
lazyConn *lazyConn
|
||||||
|
bind device.EndpointManager
|
||||||
|
fakeIP netip.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBindListener creates a listener that passes data directly through bind using LazyConn.
|
||||||
|
// It automatically derives a unique fake IP from the peer's NetBird IP in the 127.2.x.x range.
|
||||||
|
func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyconn.PeerConfig) (*BindListener, error) {
|
||||||
|
fakeIP, err := deriveFakeIP(wgIface, cfg.AllowedIPs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("derive fake IP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
d := &BindListener{
|
||||||
|
wgIface: wgIface,
|
||||||
|
peerCfg: cfg,
|
||||||
|
bind: bind,
|
||||||
|
fakeIP: fakeIP,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.setupLazyConn(); err != nil {
|
||||||
|
return nil, fmt.Errorf("setup lazy connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
d.done.Add(1)
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP.
|
||||||
|
// Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y).
|
||||||
|
// It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface.
|
||||||
|
func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) {
|
||||||
|
if len(allowedIPs) == 0 {
|
||||||
|
return netip.Addr{}, fmt.Errorf("no allowed IPs for peer")
|
||||||
|
}
|
||||||
|
|
||||||
|
ourNetwork := wgIface.Address().Network
|
||||||
|
|
||||||
|
var peerIP netip.Addr
|
||||||
|
for _, allowedIP := range allowedIPs {
|
||||||
|
ip := allowedIP.Addr()
|
||||||
|
if !ip.Is4() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ourNetwork.Contains(ip) {
|
||||||
|
peerIP = ip
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !peerIP.IsValid() {
|
||||||
|
return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs")
|
||||||
|
}
|
||||||
|
|
||||||
|
octets := peerIP.As4()
|
||||||
|
fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]})
|
||||||
|
return fakeIP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *BindListener) setupLazyConn() error {
|
||||||
|
d.lazyConn = newLazyConn()
|
||||||
|
d.bind.SetEndpoint(d.fakeIP, d.lazyConn)
|
||||||
|
|
||||||
|
endpoint := &net.UDPAddr{
|
||||||
|
IP: d.fakeIP.AsSlice(),
|
||||||
|
Port: lazyBindPort,
|
||||||
|
}
|
||||||
|
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, endpoint, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadPackets blocks until activity is detected on the LazyConn or the listener is closed.
|
||||||
|
func (d *BindListener) ReadPackets() {
|
||||||
|
select {
|
||||||
|
case <-d.lazyConn.ActivityChan():
|
||||||
|
d.peerCfg.Log.Infof("activity detected via LazyConn")
|
||||||
|
case <-d.lazyConn.ctx.Done():
|
||||||
|
d.peerCfg.Log.Infof("exit from activity listener")
|
||||||
|
}
|
||||||
|
|
||||||
|
d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey)
|
||||||
|
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
|
||||||
|
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = d.lazyConn.Close()
|
||||||
|
d.bind.RemoveEndpoint(d.fakeIP)
|
||||||
|
d.done.Done()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the listener and cleans up resources.
|
||||||
|
func (d *BindListener) Close() {
|
||||||
|
d.peerCfg.Log.Infof("closing activity listener (LazyConn)")
|
||||||
|
|
||||||
|
if err := d.lazyConn.Close(); err != nil {
|
||||||
|
d.peerCfg.Log.Errorf("failed to close LazyConn: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
d.done.Wait()
|
||||||
|
}
|
||||||
291
client/internal/lazyconn/activity/listener_bind_test.go
Normal file
291
client/internal/lazyconn/activity/listener_bind_test.go
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
package activity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
|
)
|
||||||
|
|
||||||
|
func isBindListenerPlatform() bool {
|
||||||
|
return runtime.GOOS == "windows" || runtime.GOOS == "js"
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockEndpointManager implements device.EndpointManager for testing
|
||||||
|
type mockEndpointManager struct {
|
||||||
|
endpoints map[netip.Addr]net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockEndpointManager() *mockEndpointManager {
|
||||||
|
return &mockEndpointManager{
|
||||||
|
endpoints: make(map[netip.Addr]net.Conn),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockEndpointManager) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
|
||||||
|
m.endpoints[fakeIP] = conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockEndpointManager) RemoveEndpoint(fakeIP netip.Addr) {
|
||||||
|
delete(m.endpoints, fakeIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockEndpointManager) GetEndpoint(fakeIP netip.Addr) net.Conn {
|
||||||
|
return m.endpoints[fakeIP]
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockWGIfaceBind mocks WgInterface with bind support
|
||||||
|
type MockWGIfaceBind struct {
|
||||||
|
endpointMgr *mockEndpointManager
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIfaceBind) RemovePeer(string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIfaceBind) IsUserspaceBind() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIfaceBind) Address() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: netip.MustParseAddr("100.64.0.1"),
|
||||||
|
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockWGIfaceBind) GetBind() device.EndpointManager {
|
||||||
|
return m.endpointMgr
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBindListener_Creation(t *testing.T) {
|
||||||
|
mockEndpointMgr := newMockEndpointManager()
|
||||||
|
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||||
|
|
||||||
|
peer := &MocPeer{PeerID: "testPeer1"}
|
||||||
|
cfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peer.PeerID,
|
||||||
|
PeerConnID: peer.ConnID(),
|
||||||
|
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||||
|
Log: log.WithField("peer", "testPeer1"),
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
expectedFakeIP := netip.MustParseAddr("127.2.0.2")
|
||||||
|
conn := mockEndpointMgr.GetEndpoint(expectedFakeIP)
|
||||||
|
require.NotNil(t, conn, "Endpoint should be registered in mock endpoint manager")
|
||||||
|
|
||||||
|
_, ok := conn.(*lazyConn)
|
||||||
|
assert.True(t, ok, "Registered endpoint should be a lazyConn")
|
||||||
|
|
||||||
|
readPacketsDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
listener.ReadPackets()
|
||||||
|
close(readPacketsDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
listener.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-readPacketsDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for ReadPackets to exit after Close")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBindListener_ActivityDetection(t *testing.T) {
|
||||||
|
mockEndpointMgr := newMockEndpointManager()
|
||||||
|
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||||
|
|
||||||
|
peer := &MocPeer{PeerID: "testPeer1"}
|
||||||
|
cfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peer.PeerID,
|
||||||
|
PeerConnID: peer.ConnID(),
|
||||||
|
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||||
|
Log: log.WithField("peer", "testPeer1"),
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
activityDetected := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
listener.ReadPackets()
|
||||||
|
close(activityDetected)
|
||||||
|
}()
|
||||||
|
|
||||||
|
fakeIP := listener.fakeIP
|
||||||
|
conn := mockEndpointMgr.GetEndpoint(fakeIP)
|
||||||
|
require.NotNil(t, conn, "Endpoint should be registered")
|
||||||
|
|
||||||
|
_, err = conn.Write([]byte{0x01, 0x02, 0x03})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-activityDetected:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for activity detection")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity detection")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBindListener_Close(t *testing.T) {
|
||||||
|
mockEndpointMgr := newMockEndpointManager()
|
||||||
|
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||||
|
|
||||||
|
peer := &MocPeer{PeerID: "testPeer1"}
|
||||||
|
cfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peer.PeerID,
|
||||||
|
PeerConnID: peer.ConnID(),
|
||||||
|
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||||
|
Log: log.WithField("peer", "testPeer1"),
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readPacketsDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
listener.ReadPackets()
|
||||||
|
close(readPacketsDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
fakeIP := listener.fakeIP
|
||||||
|
listener.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-readPacketsDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for ReadPackets to exit after Close")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after Close")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_BindMode(t *testing.T) {
|
||||||
|
if !isBindListenerPlatform() {
|
||||||
|
t.Skip("BindListener only used on Windows/JS platforms")
|
||||||
|
}
|
||||||
|
|
||||||
|
mockEndpointMgr := newMockEndpointManager()
|
||||||
|
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||||
|
|
||||||
|
peer := &MocPeer{PeerID: "testPeer1"}
|
||||||
|
mgr := NewManager(mockIface)
|
||||||
|
defer mgr.Close()
|
||||||
|
|
||||||
|
cfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peer.PeerID,
|
||||||
|
PeerConnID: peer.ConnID(),
|
||||||
|
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||||
|
Log: log.WithField("peer", "testPeer1"),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := mgr.MonitorPeerActivity(cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
listener, exists := mgr.GetPeerListener(cfg.PeerConnID)
|
||||||
|
require.True(t, exists, "Peer listener should be found")
|
||||||
|
|
||||||
|
bindListener, ok := listener.(*BindListener)
|
||||||
|
require.True(t, ok, "Listener should be BindListener, got %T", listener)
|
||||||
|
|
||||||
|
fakeIP := bindListener.fakeIP
|
||||||
|
conn := mockEndpointMgr.GetEndpoint(fakeIP)
|
||||||
|
require.NotNil(t, conn, "Endpoint should be registered")
|
||||||
|
|
||||||
|
_, err = conn.Write([]byte{0x01, 0x02, 0x03})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case peerConnID := <-mgr.OnActivityChan:
|
||||||
|
assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match")
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for activity notification")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_BindMode_MultiplePeers(t *testing.T) {
|
||||||
|
if !isBindListenerPlatform() {
|
||||||
|
t.Skip("BindListener only used on Windows/JS platforms")
|
||||||
|
}
|
||||||
|
|
||||||
|
mockEndpointMgr := newMockEndpointManager()
|
||||||
|
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||||
|
|
||||||
|
peer1 := &MocPeer{PeerID: "testPeer1"}
|
||||||
|
peer2 := &MocPeer{PeerID: "testPeer2"}
|
||||||
|
mgr := NewManager(mockIface)
|
||||||
|
defer mgr.Close()
|
||||||
|
|
||||||
|
cfg1 := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peer1.PeerID,
|
||||||
|
PeerConnID: peer1.ConnID(),
|
||||||
|
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||||
|
Log: log.WithField("peer", "testPeer1"),
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg2 := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peer2.PeerID,
|
||||||
|
PeerConnID: peer2.ConnID(),
|
||||||
|
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.3/32")},
|
||||||
|
Log: log.WithField("peer", "testPeer2"),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := mgr.MonitorPeerActivity(cfg1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = mgr.MonitorPeerActivity(cfg2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
listener1, exists := mgr.GetPeerListener(cfg1.PeerConnID)
|
||||||
|
require.True(t, exists, "Peer1 listener should be found")
|
||||||
|
bindListener1 := listener1.(*BindListener)
|
||||||
|
|
||||||
|
listener2, exists := mgr.GetPeerListener(cfg2.PeerConnID)
|
||||||
|
require.True(t, exists, "Peer2 listener should be found")
|
||||||
|
bindListener2 := listener2.(*BindListener)
|
||||||
|
|
||||||
|
conn1 := mockEndpointMgr.GetEndpoint(bindListener1.fakeIP)
|
||||||
|
require.NotNil(t, conn1, "Peer1 endpoint should be registered")
|
||||||
|
_, err = conn1.Write([]byte{0x01})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
conn2 := mockEndpointMgr.GetEndpoint(bindListener2.fakeIP)
|
||||||
|
require.NotNil(t, conn2, "Peer2 endpoint should be registered")
|
||||||
|
_, err = conn2.Write([]byte{0x02})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
receivedPeers := make(map[peerid.ConnID]bool)
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
select {
|
||||||
|
case peerConnID := <-mgr.OnActivityChan:
|
||||||
|
receivedPeers[peerConnID] = true
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for activity notifications")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, receivedPeers[cfg1.PeerConnID], "Peer1 activity should be received")
|
||||||
|
assert.True(t, receivedPeers[cfg2.PeerConnID], "Peer2 activity should be received")
|
||||||
|
}
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
package activity
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewListener(t *testing.T) {
|
|
||||||
peer := &MocPeer{
|
|
||||||
PeerID: "examplePublicKey1",
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg := lazyconn.PeerConfig{
|
|
||||||
PublicKey: peer.PeerID,
|
|
||||||
PeerConnID: peer.ConnID(),
|
|
||||||
Log: log.WithField("peer", "examplePublicKey1"),
|
|
||||||
}
|
|
||||||
|
|
||||||
l, err := NewListener(MocWGIface{}, cfg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create listener: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
chanClosed := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
defer close(chanClosed)
|
|
||||||
l.ReadPackets()
|
|
||||||
}()
|
|
||||||
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
l.Close()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-chanClosed:
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -11,26 +11,27 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
|
// UDPListener uses UDP sockets for activity detection in kernel mode.
|
||||||
type Listener struct {
|
type UDPListener struct {
|
||||||
wgIface WgInterface
|
wgIface WgInterface
|
||||||
peerCfg lazyconn.PeerConfig
|
peerCfg lazyconn.PeerConfig
|
||||||
conn *net.UDPConn
|
conn *net.UDPConn
|
||||||
endpoint *net.UDPAddr
|
endpoint *net.UDPAddr
|
||||||
done sync.Mutex
|
done sync.Mutex
|
||||||
|
|
||||||
isClosed atomic.Bool // use to avoid error log when closing the listener
|
isClosed atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) {
|
// NewUDPListener creates a listener that detects activity via UDP socket reads.
|
||||||
d := &Listener{
|
func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener, error) {
|
||||||
|
d := &UDPListener{
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
peerCfg: cfg,
|
peerCfg: cfg,
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := d.newConn()
|
conn, err := d.newConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to creating activity listener: %v", err)
|
return nil, fmt.Errorf("create UDP connection: %v", err)
|
||||||
}
|
}
|
||||||
d.conn = conn
|
d.conn = conn
|
||||||
d.endpoint = conn.LocalAddr().(*net.UDPAddr)
|
d.endpoint = conn.LocalAddr().(*net.UDPAddr)
|
||||||
@@ -38,12 +39,14 @@ func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error
|
|||||||
if err := d.createEndpoint(); err != nil {
|
if err := d.createEndpoint(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
d.done.Lock()
|
d.done.Lock()
|
||||||
cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String())
|
cfg.Log.Infof("created activity listener: %s", d.conn.LocalAddr().(*net.UDPAddr).String())
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Listener) ReadPackets() {
|
// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed.
|
||||||
|
func (d *UDPListener) ReadPackets() {
|
||||||
for {
|
for {
|
||||||
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
|
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -64,15 +67,17 @@ func (d *Listener) ReadPackets() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
|
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
|
||||||
if err := d.removeEndpoint(); err != nil {
|
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
|
||||||
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection"
|
// Ignore close error as it may return "use of closed network connection" if already closed.
|
||||||
|
_ = d.conn.Close()
|
||||||
d.done.Unlock()
|
d.done.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Listener) Close() {
|
// Close stops the listener and cleans up resources.
|
||||||
|
func (d *UDPListener) Close() {
|
||||||
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())
|
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())
|
||||||
d.isClosed.Store(true)
|
d.isClosed.Store(true)
|
||||||
|
|
||||||
@@ -82,16 +87,12 @@ func (d *Listener) Close() {
|
|||||||
d.done.Lock()
|
d.done.Lock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Listener) removeEndpoint() error {
|
func (d *UDPListener) createEndpoint() error {
|
||||||
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Listener) createEndpoint() error {
|
|
||||||
d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String())
|
d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String())
|
||||||
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil)
|
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Listener) newConn() (*net.UDPConn, error) {
|
func (d *UDPListener) newConn() (*net.UDPConn, error) {
|
||||||
addr := &net.UDPAddr{
|
addr := &net.UDPAddr{
|
||||||
Port: 0,
|
Port: 0,
|
||||||
IP: listenIP,
|
IP: listenIP,
|
||||||
110
client/internal/lazyconn/activity/listener_udp_test.go
Normal file
110
client/internal/lazyconn/activity/listener_udp_test.go
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
package activity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUDPListener_Creation(t *testing.T) {
|
||||||
|
mockIface := &MocWGIface{}
|
||||||
|
|
||||||
|
peer := &MocPeer{PeerID: "testPeer1"}
|
||||||
|
cfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peer.PeerID,
|
||||||
|
PeerConnID: peer.ConnID(),
|
||||||
|
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||||
|
Log: log.WithField("peer", "testPeer1"),
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := NewUDPListener(mockIface, cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, listener.conn)
|
||||||
|
require.NotNil(t, listener.endpoint)
|
||||||
|
|
||||||
|
readPacketsDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
listener.ReadPackets()
|
||||||
|
close(readPacketsDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
listener.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-readPacketsDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for ReadPackets to exit after Close")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDPListener_ActivityDetection(t *testing.T) {
|
||||||
|
mockIface := &MocWGIface{}
|
||||||
|
|
||||||
|
peer := &MocPeer{PeerID: "testPeer1"}
|
||||||
|
cfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peer.PeerID,
|
||||||
|
PeerConnID: peer.ConnID(),
|
||||||
|
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||||
|
Log: log.WithField("peer", "testPeer1"),
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := NewUDPListener(mockIface, cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
activityDetected := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
listener.ReadPackets()
|
||||||
|
close(activityDetected)
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn, err := net.Dial("udp", listener.conn.LocalAddr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
_, err = conn.Write([]byte{0x01, 0x02, 0x03})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-activityDetected:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for activity detection")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDPListener_Close(t *testing.T) {
|
||||||
|
mockIface := &MocWGIface{}
|
||||||
|
|
||||||
|
peer := &MocPeer{PeerID: "testPeer1"}
|
||||||
|
cfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peer.PeerID,
|
||||||
|
PeerConnID: peer.ConnID(),
|
||||||
|
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||||
|
Log: log.WithField("peer", "testPeer1"),
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := NewUDPListener(mockIface, cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
readPacketsDone := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
listener.ReadPackets()
|
||||||
|
close(readPacketsDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
listener.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-readPacketsDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for ReadPackets to exit after Close")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, listener.isClosed.Load(), "Listener should be marked as closed")
|
||||||
|
}
|
||||||
@@ -1,21 +1,32 @@
|
|||||||
package activity
|
package activity
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// listener defines the contract for activity detection listeners.
|
||||||
|
type listener interface {
|
||||||
|
ReadPackets()
|
||||||
|
Close()
|
||||||
|
}
|
||||||
|
|
||||||
type WgInterface interface {
|
type WgInterface interface {
|
||||||
RemovePeer(peerKey string) error
|
RemovePeer(peerKey string) error
|
||||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
|
IsUserspaceBind() bool
|
||||||
|
Address() wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
@@ -23,7 +34,7 @@ type Manager struct {
|
|||||||
|
|
||||||
wgIface WgInterface
|
wgIface WgInterface
|
||||||
|
|
||||||
peers map[peerid.ConnID]*Listener
|
peers map[peerid.ConnID]listener
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@@ -33,7 +44,7 @@ func NewManager(wgIface WgInterface) *Manager {
|
|||||||
m := &Manager{
|
m := &Manager{
|
||||||
OnActivityChan: make(chan peerid.ConnID, 1),
|
OnActivityChan: make(chan peerid.ConnID, 1),
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
peers: make(map[peerid.ConnID]*Listener),
|
peers: make(map[peerid.ConnID]listener),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
}
|
}
|
||||||
return m
|
return m
|
||||||
@@ -48,16 +59,38 @@ func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
listener, err := NewListener(m.wgIface, peerCfg)
|
listener, err := m.createListener(peerCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
m.peers[peerCfg.PeerConnID] = listener
|
|
||||||
|
|
||||||
|
m.peers[peerCfg.PeerConnID] = listener
|
||||||
go m.waitForTraffic(listener, peerCfg.PeerConnID)
|
go m.waitForTraffic(listener, peerCfg.PeerConnID)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) {
|
||||||
|
if !m.wgIface.IsUserspaceBind() {
|
||||||
|
return NewUDPListener(m.wgIface, peerCfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BindListener is only used on Windows and JS platforms:
|
||||||
|
// - JS: Cannot listen to UDP sockets
|
||||||
|
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
||||||
|
// gateway points to, preventing them from reaching the loopback interface.
|
||||||
|
// BindListener bypasses this by passing data directly through the bind.
|
||||||
|
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
|
||||||
|
return NewUDPListener(m.wgIface, peerCfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, ok := m.wgIface.(bindProvider)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider")
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewBindListener(m.wgIface, provider.GetBind(), peerCfg)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) {
|
func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
@@ -82,8 +115,8 @@ func (m *Manager) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) {
|
func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) {
|
||||||
listener.ReadPackets()
|
l.ReadPackets()
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
if _, ok := m.peers[peerConnID]; !ok {
|
if _, ok := m.peers[peerConnID]; !ok {
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
)
|
)
|
||||||
@@ -30,16 +31,26 @@ func (m MocWGIface) RemovePeer(string) error {
|
|||||||
|
|
||||||
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
|
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add this method to the Manager struct
|
func (m MocWGIface) IsUserspaceBind() bool {
|
||||||
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (*Listener, bool) {
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MocWGIface) Address() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: netip.MustParseAddr("100.64.0.1"),
|
||||||
|
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerListener is a test helper to access listeners
|
||||||
|
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
listener, exists := m.peers[peerConnID]
|
l, exists := m.peers[peerConnID]
|
||||||
return listener, exists
|
return l, exists
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_MonitorPeerActivity(t *testing.T) {
|
func TestManager_MonitorPeerActivity(t *testing.T) {
|
||||||
@@ -65,7 +76,12 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
|
|||||||
t.Fatalf("peer listener not found")
|
t.Fatalf("peer listener not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
|
// Get the UDP listener's address for triggering
|
||||||
|
udpListener, ok := listener.(*UDPListener)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected UDPListener")
|
||||||
|
}
|
||||||
|
if err := trigger(udpListener.conn.LocalAddr().String()); err != nil {
|
||||||
t.Fatalf("failed to trigger activity: %v", err)
|
t.Fatalf("failed to trigger activity: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -97,7 +113,9 @@ func TestManager_RemovePeerActivity(t *testing.T) {
|
|||||||
t.Fatalf("failed to monitor peer activity: %v", err)
|
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()
|
listener, _ := mgr.GetPeerListener(peerCfg1.PeerConnID)
|
||||||
|
udpListener, _ := listener.(*UDPListener)
|
||||||
|
addr := udpListener.conn.LocalAddr().String()
|
||||||
|
|
||||||
mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID)
|
mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID)
|
||||||
|
|
||||||
@@ -147,7 +165,8 @@ func TestManager_MultiPeerActivity(t *testing.T) {
|
|||||||
t.Fatalf("peer listener for peer1 not found")
|
t.Fatalf("peer listener for peer1 not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
|
udpListener1, _ := listener.(*UDPListener)
|
||||||
|
if err := trigger(udpListener1.conn.LocalAddr().String()); err != nil {
|
||||||
t.Fatalf("failed to trigger activity: %v", err)
|
t.Fatalf("failed to trigger activity: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,7 +175,8 @@ func TestManager_MultiPeerActivity(t *testing.T) {
|
|||||||
t.Fatalf("peer listener for peer2 not found")
|
t.Fatalf("peer listener for peer2 not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
|
udpListener2, _ := listener.(*UDPListener)
|
||||||
|
if err := trigger(udpListener2.conn.LocalAddr().String()); err != nil {
|
||||||
t.Fatalf("failed to trigger activity: %v", err)
|
t.Fatalf("failed to trigger activity: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/monotime"
|
"github.com/netbirdio/netbird/monotime"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -14,5 +15,6 @@ type WGIface interface {
|
|||||||
RemovePeer(peerKey string) error
|
RemovePeer(peerKey string) error
|
||||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
|
Address() wgaddr.Address
|
||||||
LastActivities() map[string]monotime.Time
|
LastActivities() map[string]monotime.Time
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import (
|
|||||||
// Manager handles netflow tracking and logging
|
// Manager handles netflow tracking and logging
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
|
shutdownWg sync.WaitGroup
|
||||||
logger nftypes.FlowLogger
|
logger nftypes.FlowLogger
|
||||||
flowConfig *nftypes.FlowConfig
|
flowConfig *nftypes.FlowConfig
|
||||||
conntrack nftypes.ConnTracker
|
conntrack nftypes.ConnTracker
|
||||||
@@ -105,8 +106,15 @@ func (m *Manager) resetClient() error {
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
m.cancel = cancel
|
m.cancel = cancel
|
||||||
|
|
||||||
go m.receiveACKs(ctx, flowClient)
|
m.shutdownWg.Add(2)
|
||||||
go m.startSender(ctx)
|
go func() {
|
||||||
|
defer m.shutdownWg.Done()
|
||||||
|
m.receiveACKs(ctx, flowClient)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer m.shutdownWg.Done()
|
||||||
|
m.startSender(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -176,11 +184,12 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error {
|
|||||||
// Close cleans up all resources
|
// Close cleans up all resources
|
||||||
func (m *Manager) Close() {
|
func (m *Manager) Close() {
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
|
||||||
|
|
||||||
if err := m.disableFlow(); err != nil {
|
if err := m.disableFlow(); err != nil {
|
||||||
log.Warnf("failed to disable flow manager: %v", err)
|
log.Warnf("failed to disable flow manager: %v", err)
|
||||||
}
|
}
|
||||||
|
m.mux.Unlock()
|
||||||
|
|
||||||
|
m.shutdownWg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLogger returns the flow logger
|
// GetLogger returns the flow logger
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
|
//go:build dragonfly || freebsd || netbsd || openbsd
|
||||||
|
|
||||||
package networkmonitor
|
package networkmonitor
|
||||||
|
|
||||||
|
|||||||
344
client/internal/networkmonitor/check_change_darwin.go
Normal file
344
client/internal/networkmonitor/check_change_darwin.go
Normal file
@@ -0,0 +1,344 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"hash/fnv"
|
||||||
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/net/route"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
|
)
|
||||||
|
|
||||||
|
// todo: refactor to not use static functions
|
||||||
|
|
||||||
|
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||||
|
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open routing socket: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := unix.Close(fd)
|
||||||
|
if err != nil && !errors.Is(err, unix.EBADF) {
|
||||||
|
log.Warnf("Network monitor: failed to close routing socket: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
routeChanged := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
routeCheck(ctx, fd, nexthopv4, nexthopv6)
|
||||||
|
close(routeChanged)
|
||||||
|
}()
|
||||||
|
|
||||||
|
wakeUp := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
wakeUpListen(ctx)
|
||||||
|
close(wakeUp)
|
||||||
|
}()
|
||||||
|
|
||||||
|
gatewayChanged := make(chan string)
|
||||||
|
go func() {
|
||||||
|
gatewayPoll(ctx, nexthopv4, nexthopv6, gatewayChanged)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-routeChanged:
|
||||||
|
log.Infof("route change detected via routing socket")
|
||||||
|
return nil
|
||||||
|
case <-wakeUp:
|
||||||
|
log.Infof("wakeup detected via sleep hash change")
|
||||||
|
return nil
|
||||||
|
case reason := <-gatewayChanged:
|
||||||
|
log.Infof("gateway change detected via polling: %s", reason)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func routeCheck(ctx context.Context, fd int, nexthopv4 systemops.Nexthop, nexthopv6 systemops.Nexthop) {
|
||||||
|
for {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 2048)
|
||||||
|
n, err := unix.Read(fd, buf)
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
||||||
|
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if n < unix.SizeofRtMsghdr {
|
||||||
|
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||||
|
|
||||||
|
switch msg.Type {
|
||||||
|
// handle route changes
|
||||||
|
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||||
|
route, err := parseRouteMessage(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if route.Dst.Bits() != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
intf := "<nil>"
|
||||||
|
if route.Interface != nil {
|
||||||
|
intf = route.Interface.Name
|
||||||
|
}
|
||||||
|
switch msg.Type {
|
||||||
|
case unix.RTM_ADD:
|
||||||
|
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||||
|
return
|
||||||
|
case unix.RTM_DELETE:
|
||||||
|
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||||
|
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||||
|
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msgs) != 1 {
|
||||||
|
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, ok := msgs[0].(*route.RouteMessage)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
return systemops.MsgToRoute(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func wakeUpListen(ctx context.Context) {
|
||||||
|
log.Infof("start to watch for system wakeups")
|
||||||
|
var (
|
||||||
|
initialHash uint32
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
// Keep retrying until initial sysctl succeeds or context is canceled
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
initialHash, err = readSleepTimeHash()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to detect initial sleep time: %v", err)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
|
||||||
|
return
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Infof("initial wakeup hash: %d", initialHash)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(5 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
lastCheck := time.Now()
|
||||||
|
const maxTickerDrift = 1 * time.Minute
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Info("context canceled, stopping wakeUpListen")
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-ticker.C:
|
||||||
|
now := time.Now()
|
||||||
|
elapsed := now.Sub(lastCheck)
|
||||||
|
|
||||||
|
// If more time passed than expected, system likely slept (informational only)
|
||||||
|
if elapsed > maxTickerDrift {
|
||||||
|
upOut, err := exec.Command("uptime").Output()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to run uptime command: %v", err)
|
||||||
|
upOut = []byte("unknown")
|
||||||
|
}
|
||||||
|
log.Infof("Time drift detected (potential wakeup): expected ~5s, actual %s, uptime: %s", elapsed, upOut)
|
||||||
|
|
||||||
|
currentV4, errV4 := systemops.GetNextHop(netip.IPv4Unspecified())
|
||||||
|
currentV6, errV6 := systemops.GetNextHop(netip.IPv6Unspecified())
|
||||||
|
if errV4 == nil {
|
||||||
|
log.Infof("Current IPv4 default gateway: %s via %s", currentV4.IP, currentV4.Intf.Name)
|
||||||
|
} else {
|
||||||
|
log.Debugf("No IPv4 default gateway: %v", errV4)
|
||||||
|
}
|
||||||
|
if errV6 == nil {
|
||||||
|
log.Infof("Current IPv6 default gateway: %s via %s", currentV6.IP, currentV6.Intf.Name)
|
||||||
|
} else {
|
||||||
|
log.Debugf("No IPv6 default gateway: %v", errV6)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newHash, err := readSleepTimeHash()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to read sleep time hash: %v", err)
|
||||||
|
lastCheck = now
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if newHash == initialHash {
|
||||||
|
log.Debugf("no wakeup detected (hash unchanged: %d, time drift: %s)", initialHash, elapsed)
|
||||||
|
lastCheck = now
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
upOut, err := exec.Command("uptime").Output()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to run uptime command: %v", err)
|
||||||
|
upOut = []byte("unknown")
|
||||||
|
}
|
||||||
|
log.Infof("Wakeup detected via hash change: %d -> %d, uptime: %s", initialHash, newHash, upOut)
|
||||||
|
|
||||||
|
currentV4, errV4 := systemops.GetNextHop(netip.IPv4Unspecified())
|
||||||
|
currentV6, errV6 := systemops.GetNextHop(netip.IPv6Unspecified())
|
||||||
|
if errV4 == nil {
|
||||||
|
log.Infof("Current IPv4 default gateway after wakeup: %s via %s", currentV4.IP, currentV4.Intf.Name)
|
||||||
|
} else {
|
||||||
|
log.Debugf("No IPv4 default gateway after wakeup: %v", errV4)
|
||||||
|
}
|
||||||
|
if errV6 == nil {
|
||||||
|
log.Infof("Current IPv6 default gateway after wakeup: %s via %s", currentV6.IP, currentV6.Intf.Name)
|
||||||
|
} else {
|
||||||
|
log.Debugf("No IPv6 default gateway after wakeup: %v", errV6)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func readSleepTimeHash() (uint32, error) {
|
||||||
|
cmd := exec.Command("sysctl", "kern.sleeptime")
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to run sysctl: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h, err := hash(out)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to compute hash: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func hash(data []byte) (uint32, error) {
|
||||||
|
hasher := fnv.New32a()
|
||||||
|
if _, err := hasher.Write(data); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return hasher.Sum32(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// gatewayPoll polls the default gateway every 5 seconds to detect changes that might be missed by routing socket or wake-up detection.
|
||||||
|
func gatewayPoll(ctx context.Context, initialV4, initialV6 systemops.Nexthop, changed chan<- string) {
|
||||||
|
ticker := time.NewTicker(5 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
log.Infof("Gateway polling started - initial v4: %s via %v, v6: %s via %v",
|
||||||
|
initialV4.IP, initialV4.Intf, initialV6.IP, initialV6.Intf)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Debug("context canceled, stopping gateway polling")
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-ticker.C:
|
||||||
|
currentV4, errV4 := systemops.GetNextHop(netip.IPv4Unspecified())
|
||||||
|
currentV6, errV6 := systemops.GetNextHop(netip.IPv6Unspecified())
|
||||||
|
|
||||||
|
var reason string
|
||||||
|
|
||||||
|
if errV4 == nil && initialV4.IP.IsValid() {
|
||||||
|
if currentV4.IP.Compare(initialV4.IP) != 0 {
|
||||||
|
reason = fmt.Sprintf("IPv4 gateway changed from %s to %s", initialV4.IP, currentV4.IP)
|
||||||
|
log.Infof("Gateway poll detected change: %s", reason)
|
||||||
|
changed <- reason
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if initialV4.Intf != nil && currentV4.Intf != nil && currentV4.Intf.Name != initialV4.Intf.Name {
|
||||||
|
reason = fmt.Sprintf("IPv4 interface changed from %s to %s", initialV4.Intf.Name, currentV4.Intf.Name)
|
||||||
|
log.Infof("Gateway poll detected change: %s", reason)
|
||||||
|
changed <- reason
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else if errV4 == nil && !initialV4.IP.IsValid() {
|
||||||
|
reason = "IPv4 gateway appeared"
|
||||||
|
log.Infof("Gateway poll detected change: %s (new: %s)", reason, currentV4.IP)
|
||||||
|
changed <- reason
|
||||||
|
return
|
||||||
|
} else if errV4 != nil && initialV4.IP.IsValid() {
|
||||||
|
reason = "IPv4 gateway disappeared"
|
||||||
|
log.Infof("Gateway poll detected change: %s", reason)
|
||||||
|
changed <- reason
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if errV6 == nil && initialV6.IP.IsValid() {
|
||||||
|
if currentV6.IP.Compare(initialV6.IP) != 0 {
|
||||||
|
reason = fmt.Sprintf("IPv6 gateway changed from %s to %s", initialV6.IP, currentV6.IP)
|
||||||
|
log.Infof("Gateway poll detected change: %s", reason)
|
||||||
|
changed <- reason
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if initialV6.Intf != nil && currentV6.Intf != nil && currentV6.Intf.Name != initialV6.Intf.Name {
|
||||||
|
reason = fmt.Sprintf("IPv6 interface changed from %s to %s", initialV6.Intf.Name, currentV6.Intf.Name)
|
||||||
|
log.Infof("Gateway poll detected change: %s", reason)
|
||||||
|
changed <- reason
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else if errV6 == nil && !initialV6.IP.IsValid() {
|
||||||
|
reason = "IPv6 gateway appeared"
|
||||||
|
log.Infof("Gateway poll detected change: %s (new: %s)", reason, currentV6.IP)
|
||||||
|
changed <- reason
|
||||||
|
return
|
||||||
|
} else if errV6 != nil && initialV6.IP.IsValid() {
|
||||||
|
reason = "IPv6 gateway disappeared"
|
||||||
|
log.Infof("Gateway poll detected change: %s", reason)
|
||||||
|
changed <- reason
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Gateway poll: no change detected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -88,6 +88,7 @@ func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
|||||||
event := make(chan struct{}, 1)
|
event := make(chan struct{}, 1)
|
||||||
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
|
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
|
||||||
|
|
||||||
|
log.Infof("start watching for network changes")
|
||||||
// debounce changes
|
// debounce changes
|
||||||
timer := time.NewTimer(0)
|
timer := time.NewTimer(0)
|
||||||
timer.Stop()
|
timer.Stop()
|
||||||
|
|||||||
@@ -171,9 +171,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
|||||||
|
|
||||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
|
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
|
||||||
|
|
||||||
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
|
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
|
||||||
if !isForceRelayed() {
|
if !isForceRelayed() {
|
||||||
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
|
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
|
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
|
||||||
@@ -430,6 +430,9 @@ func (conn *Conn) onICEStateDisconnected() {
|
|||||||
} else {
|
} else {
|
||||||
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||||
conn.currentConnPriority = conntype.None
|
conn.currentConnPriority = conntype.None
|
||||||
|
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
|
||||||
|
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
changed := conn.statusICE.Get() != worker.StatusDisconnected
|
changed := conn.statusICE.Get() != worker.StatusDisconnected
|
||||||
@@ -523,6 +526,9 @@ func (conn *Conn) onRelayDisconnected() {
|
|||||||
if conn.currentConnPriority == conntype.Relay {
|
if conn.currentConnPriority == conntype.Relay {
|
||||||
conn.Log.Debugf("clean up WireGuard config")
|
conn.Log.Debugf("clean up WireGuard config")
|
||||||
conn.currentConnPriority = conntype.None
|
conn.currentConnPriority = conntype.None
|
||||||
|
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
|
||||||
|
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.wgProxyRelay != nil {
|
if conn.wgProxyRelay != nil {
|
||||||
|
|||||||
@@ -79,10 +79,10 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
onNewOffeChan := make(chan struct{})
|
onNewOfferChan := make(chan struct{})
|
||||||
|
|
||||||
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
|
conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
|
||||||
onNewOffeChan <- struct{}{}
|
onNewOfferChan <- struct{}{}
|
||||||
})
|
})
|
||||||
|
|
||||||
conn.OnRemoteOffer(OfferAnswer{
|
conn.OnRemoteOffer(OfferAnswer{
|
||||||
@@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-onNewOffeChan:
|
case <-onNewOfferChan:
|
||||||
// success
|
// success
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Error("expected to receive a new offer notification, but timed out")
|
t.Error("expected to receive a new offer notification, but timed out")
|
||||||
@@ -118,10 +118,10 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
onNewOffeChan := make(chan struct{})
|
onNewOfferChan := make(chan struct{})
|
||||||
|
|
||||||
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
|
conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
|
||||||
onNewOffeChan <- struct{}{}
|
onNewOfferChan <- struct{}{}
|
||||||
})
|
})
|
||||||
|
|
||||||
conn.OnRemoteAnswer(OfferAnswer{
|
conn.OnRemoteAnswer(OfferAnswer{
|
||||||
@@ -136,7 +136,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-onNewOffeChan:
|
case <-onNewOfferChan:
|
||||||
// success
|
// success
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Error("expected to receive a new offer notification, but timed out")
|
t.Error("expected to receive a new offer notification, but timed out")
|
||||||
|
|||||||
20
client/internal/peer/guard/env.go
Normal file
20
client/internal/peer/guard/env.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package guard
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
envICEMonitorPeriod = "NB_ICE_MONITOR_PERIOD"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetICEMonitorPeriod() time.Duration {
|
||||||
|
if envVal := os.Getenv(envICEMonitorPeriod); envVal != "" {
|
||||||
|
if seconds, err := strconv.Atoi(envVal); err == nil && seconds > 0 {
|
||||||
|
return time.Duration(seconds) * time.Second
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultCandidatesMonitorPeriod
|
||||||
|
}
|
||||||
@@ -16,8 +16,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
candidatesMonitorPeriod = 5 * time.Minute
|
defaultCandidatesMonitorPeriod = 5 * time.Minute
|
||||||
candidateGatheringTimeout = 5 * time.Second
|
candidateGatheringTimeout = 5 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
type ICEMonitor struct {
|
type ICEMonitor struct {
|
||||||
@@ -25,16 +25,19 @@ type ICEMonitor struct {
|
|||||||
|
|
||||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
iceConfig icemaker.Config
|
iceConfig icemaker.Config
|
||||||
|
tickerPeriod time.Duration
|
||||||
|
|
||||||
currentCandidatesAddress []string
|
currentCandidatesAddress []string
|
||||||
candidatesMu sync.Mutex
|
candidatesMu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor {
|
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config, period time.Duration) *ICEMonitor {
|
||||||
|
log.Debugf("prepare ICE monitor with period: %s", period)
|
||||||
cm := &ICEMonitor{
|
cm := &ICEMonitor{
|
||||||
ReconnectCh: make(chan struct{}, 1),
|
ReconnectCh: make(chan struct{}, 1),
|
||||||
iFaceDiscover: iFaceDiscover,
|
iFaceDiscover: iFaceDiscover,
|
||||||
iceConfig: config,
|
iceConfig: config,
|
||||||
|
tickerPeriod: period,
|
||||||
}
|
}
|
||||||
return cm
|
return cm
|
||||||
}
|
}
|
||||||
@@ -46,7 +49,12 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ticker := time.NewTicker(candidatesMonitorPeriod)
|
// Initial check to populate the candidates for later comparison
|
||||||
|
if _, err := cm.handleCandidateTick(ctx, ufrag, pwd); err != nil {
|
||||||
|
log.Warnf("Failed to check initial ICE candidates: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(cm.tickerPeriod)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
|||||||
@@ -19,11 +19,11 @@ type SRWatcher struct {
|
|||||||
signalClient chNotifier
|
signalClient chNotifier
|
||||||
relayManager chNotifier
|
relayManager chNotifier
|
||||||
|
|
||||||
listeners map[chan struct{}]struct{}
|
listeners map[chan struct{}]struct{}
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
shutdownWg sync.WaitGroup
|
||||||
iceConfig ice.Config
|
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
|
iceConfig ice.Config
|
||||||
cancelIceMonitor context.CancelFunc
|
cancelIceMonitor context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,8 +51,12 @@ func (w *SRWatcher) Start() {
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
w.cancelIceMonitor = cancel
|
w.cancelIceMonitor = cancel
|
||||||
|
|
||||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig)
|
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
w.shutdownWg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer w.shutdownWg.Done()
|
||||||
|
iceMonitor.Start(ctx, w.onICEChanged)
|
||||||
|
}()
|
||||||
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
||||||
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
||||||
|
|
||||||
@@ -60,14 +64,16 @@ func (w *SRWatcher) Start() {
|
|||||||
|
|
||||||
func (w *SRWatcher) Close() {
|
func (w *SRWatcher) Close() {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
if w.cancelIceMonitor == nil {
|
if w.cancelIceMonitor == nil {
|
||||||
|
w.mu.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.cancelIceMonitor()
|
w.cancelIceMonitor()
|
||||||
w.signalClient.SetOnReconnectedListener(nil)
|
w.signalClient.SetOnReconnectedListener(nil)
|
||||||
w.relayManager.SetOnReconnectedListener(nil)
|
w.relayManager.SetOnReconnectedListener(nil)
|
||||||
|
w.mu.Unlock()
|
||||||
|
|
||||||
|
w.shutdownWg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *SRWatcher) NewListener() chan struct{} {
|
func (w *SRWatcher) NewListener() chan struct{} {
|
||||||
|
|||||||
@@ -44,13 +44,19 @@ type OfferAnswer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Handshaker struct {
|
type Handshaker struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
config ConnConfig
|
config ConnConfig
|
||||||
signaler *Signaler
|
signaler *Signaler
|
||||||
ice *WorkerICE
|
ice *WorkerICE
|
||||||
relay *WorkerRelay
|
relay *WorkerRelay
|
||||||
onNewOfferListeners []*OfferListener
|
// relayListener is not blocking because the listener is using a goroutine to process the messages
|
||||||
|
// and it will only keep the latest message if multiple offers are received in a short time
|
||||||
|
// this is to avoid blocking the handshaker if the listener is doing some heavy processing
|
||||||
|
// and also to avoid processing old offers if multiple offers are received in a short time
|
||||||
|
// the listener will always process the latest offer
|
||||||
|
relayListener *AsyncOfferListener
|
||||||
|
iceListener func(remoteOfferAnswer *OfferAnswer)
|
||||||
|
|
||||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||||
remoteOffersCh chan OfferAnswer
|
remoteOffersCh chan OfferAnswer
|
||||||
@@ -70,28 +76,39 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||||
l := NewOfferListener(offer)
|
h.relayListener = NewAsyncOfferListener(offer)
|
||||||
h.onNewOfferListeners = append(h.onNewOfferListeners, l)
|
}
|
||||||
|
|
||||||
|
func (h *Handshaker) AddICEListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||||
|
h.iceListener = offer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshaker) Listen(ctx context.Context) {
|
func (h *Handshaker) Listen(ctx context.Context) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||||
// received confirmation from the remote peer -> ready to proceed
|
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||||
|
if h.relayListener != nil {
|
||||||
|
h.relayListener.Notify(&remoteOfferAnswer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.iceListener != nil {
|
||||||
|
h.iceListener(&remoteOfferAnswer)
|
||||||
|
}
|
||||||
|
|
||||||
if err := h.sendAnswer(); err != nil {
|
if err := h.sendAnswer(); err != nil {
|
||||||
h.log.Errorf("failed to send remote offer confirmation: %s", err)
|
h.log.Errorf("failed to send remote offer confirmation: %s", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, listener := range h.onNewOfferListeners {
|
|
||||||
listener.Notify(&remoteOfferAnswer)
|
|
||||||
}
|
|
||||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
|
||||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||||
for _, listener := range h.onNewOfferListeners {
|
if h.relayListener != nil {
|
||||||
listener.Notify(&remoteOfferAnswer)
|
h.relayListener.Notify(&remoteOfferAnswer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.iceListener != nil {
|
||||||
|
h.iceListener(&remoteOfferAnswer)
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
h.log.Infof("stop listening for remote offers and answers")
|
h.log.Infof("stop listening for remote offers and answers")
|
||||||
|
|||||||
@@ -13,20 +13,20 @@ func (oa *OfferAnswer) SessionIDString() string {
|
|||||||
return oa.SessionID.String()
|
return oa.SessionID.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
type OfferListener struct {
|
type AsyncOfferListener struct {
|
||||||
fn callbackFunc
|
fn callbackFunc
|
||||||
running bool
|
running bool
|
||||||
latest *OfferAnswer
|
latest *OfferAnswer
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOfferListener(fn callbackFunc) *OfferListener {
|
func NewAsyncOfferListener(fn callbackFunc) *AsyncOfferListener {
|
||||||
return &OfferListener{
|
return &AsyncOfferListener{
|
||||||
fn: fn,
|
fn: fn,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
|
func (o *AsyncOfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
|
||||||
o.mu.Lock()
|
o.mu.Lock()
|
||||||
defer o.mu.Unlock()
|
defer o.mu.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ func Test_newOfferListener(t *testing.T) {
|
|||||||
runChan <- struct{}{}
|
runChan <- struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
hl := NewOfferListener(longRunningFn)
|
hl := NewAsyncOfferListener(longRunningFn)
|
||||||
|
|
||||||
hl.Notify(dummyOfferAnswer)
|
hl.Notify(dummyOfferAnswer)
|
||||||
hl.Notify(dummyOfferAnswer)
|
hl.Notify(dummyOfferAnswer)
|
||||||
|
|||||||
@@ -18,4 +18,5 @@ type WGIface interface {
|
|||||||
GetStats() (map[string]configurer.WGStats, error)
|
GetStats() (map[string]configurer.WGStats, error)
|
||||||
GetProxy() wgproxy.Proxy
|
GetProxy() wgproxy.Proxy
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
|
RemoveEndpointAddress(key string) error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -92,23 +92,16 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
|
|||||||
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||||
w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
|
w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
|
||||||
w.muxAgent.Lock()
|
w.muxAgent.Lock()
|
||||||
|
defer w.muxAgent.Unlock()
|
||||||
|
|
||||||
if w.agentConnecting {
|
if w.agent != nil || w.agentConnecting {
|
||||||
w.log.Debugf("agent connection is in progress, skipping the offer")
|
|
||||||
w.muxAgent.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if w.agent != nil {
|
|
||||||
// backward compatibility with old clients that do not send session ID
|
// backward compatibility with old clients that do not send session ID
|
||||||
if remoteOfferAnswer.SessionID == nil {
|
if remoteOfferAnswer.SessionID == nil {
|
||||||
w.log.Debugf("agent already exists, skipping the offer")
|
w.log.Debugf("agent already exists, skipping the offer")
|
||||||
w.muxAgent.Unlock()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if w.remoteSessionID == *remoteOfferAnswer.SessionID {
|
if w.remoteSessionID == *remoteOfferAnswer.SessionID {
|
||||||
w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
|
w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
|
||||||
w.muxAgent.Unlock()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.log.Debugf("agent already exists, recreate the connection")
|
w.log.Debugf("agent already exists, recreate the connection")
|
||||||
@@ -116,6 +109,12 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
if err := w.agent.Close(); err != nil {
|
if err := w.agent.Close(); err != nil {
|
||||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sessionID, err := NewICESessionID()
|
||||||
|
if err != nil {
|
||||||
|
w.log.Errorf("failed to create new session ID: %s", err)
|
||||||
|
}
|
||||||
|
w.sessionID = sessionID
|
||||||
w.agent = nil
|
w.agent = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,18 +125,23 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
preferredCandidateTypes = icemaker.CandidateTypes()
|
preferredCandidateTypes = icemaker.CandidateTypes()
|
||||||
}
|
}
|
||||||
|
|
||||||
w.log.Debugf("recreate ICE agent")
|
if remoteOfferAnswer.SessionID != nil {
|
||||||
|
w.log.Debugf("recreate ICE agent: %s / %s", w.sessionID, *remoteOfferAnswer.SessionID)
|
||||||
|
}
|
||||||
dialerCtx, dialerCancel := context.WithCancel(w.ctx)
|
dialerCtx, dialerCancel := context.WithCancel(w.ctx)
|
||||||
agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
|
agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.log.Errorf("failed to recreate ICE Agent: %s", err)
|
w.log.Errorf("failed to recreate ICE Agent: %s", err)
|
||||||
w.muxAgent.Unlock()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.agent = agent
|
w.agent = agent
|
||||||
w.agentDialerCancel = dialerCancel
|
w.agentDialerCancel = dialerCancel
|
||||||
w.agentConnecting = true
|
w.agentConnecting = true
|
||||||
w.muxAgent.Unlock()
|
if remoteOfferAnswer.SessionID != nil {
|
||||||
|
w.remoteSessionID = *remoteOfferAnswer.SessionID
|
||||||
|
} else {
|
||||||
|
w.remoteSessionID = ""
|
||||||
|
}
|
||||||
|
|
||||||
go w.connect(dialerCtx, agent, remoteOfferAnswer)
|
go w.connect(dialerCtx, agent, remoteOfferAnswer)
|
||||||
}
|
}
|
||||||
@@ -293,9 +297,6 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
|
|||||||
w.muxAgent.Lock()
|
w.muxAgent.Lock()
|
||||||
w.agentConnecting = false
|
w.agentConnecting = false
|
||||||
w.lastSuccess = time.Now()
|
w.lastSuccess = time.Now()
|
||||||
if remoteOfferAnswer.SessionID != nil {
|
|
||||||
w.remoteSessionID = *remoteOfferAnswer.SessionID
|
|
||||||
}
|
|
||||||
w.muxAgent.Unlock()
|
w.muxAgent.Unlock()
|
||||||
|
|
||||||
// todo: the potential problem is a race between the onConnectionStateChange
|
// todo: the potential problem is a race between the onConnectionStateChange
|
||||||
@@ -309,16 +310,17 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
w.muxAgent.Lock()
|
w.muxAgent.Lock()
|
||||||
// todo review does it make sense to generate new session ID all the time when w.agent==agent
|
|
||||||
sessionID, err := NewICESessionID()
|
|
||||||
if err != nil {
|
|
||||||
w.log.Errorf("failed to create new session ID: %s", err)
|
|
||||||
}
|
|
||||||
w.sessionID = sessionID
|
|
||||||
|
|
||||||
if w.agent == agent {
|
if w.agent == agent {
|
||||||
|
// consider to remove from here and move to the OnNewOffer
|
||||||
|
sessionID, err := NewICESessionID()
|
||||||
|
if err != nil {
|
||||||
|
w.log.Errorf("failed to create new session ID: %s", err)
|
||||||
|
}
|
||||||
|
w.sessionID = sessionID
|
||||||
w.agent = nil
|
w.agent = nil
|
||||||
w.agentConnecting = false
|
w.agentConnecting = false
|
||||||
|
w.remoteSessionID = ""
|
||||||
}
|
}
|
||||||
w.muxAgent.Unlock()
|
w.muxAgent.Unlock()
|
||||||
}
|
}
|
||||||
@@ -395,11 +397,12 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
|||||||
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
||||||
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
||||||
|
|
||||||
|
w.closeAgent(agent, dialerCancel)
|
||||||
|
|
||||||
if w.lastKnownState == ice.ConnectionStateConnected {
|
if w.lastKnownState == ice.ConnectionStateConnected {
|
||||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||||
w.conn.onICEStateDisconnected()
|
w.conn.onICEStateDisconnected()
|
||||||
}
|
}
|
||||||
w.closeAgent(agent, dialerCancel)
|
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -195,6 +195,7 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
|||||||
config := &Config{
|
config := &Config{
|
||||||
// defaults to false only for new (post 0.26) configurations
|
// defaults to false only for new (post 0.26) configurations
|
||||||
ServerSSHAllowed: util.False(),
|
ServerSSHAllowed: util.False(),
|
||||||
|
WgPort: iface.DefaultWgPort,
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := config.apply(input); err != nil {
|
if _, err := config.apply(input); err != nil {
|
||||||
|
|||||||
@@ -5,11 +5,14 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -141,6 +144,95 @@ func TestHiddenPreSharedKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewProfileDefaults(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tempDir, "config.json")
|
||||||
|
|
||||||
|
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||||
|
ConfigPath: configPath,
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "should create new config")
|
||||||
|
|
||||||
|
assert.Equal(t, DefaultManagementURL, config.ManagementURL.String(), "ManagementURL should have default")
|
||||||
|
assert.Equal(t, DefaultAdminURL, config.AdminURL.String(), "AdminURL should have default")
|
||||||
|
assert.NotEmpty(t, config.PrivateKey, "PrivateKey should be generated")
|
||||||
|
assert.NotEmpty(t, config.SSHKey, "SSHKey should be generated")
|
||||||
|
assert.Equal(t, iface.WgInterfaceDefault, config.WgIface, "WgIface should have default")
|
||||||
|
assert.Equal(t, iface.DefaultWgPort, config.WgPort, "WgPort should default to 51820")
|
||||||
|
assert.Equal(t, uint16(iface.DefaultMTU), config.MTU, "MTU should have default")
|
||||||
|
assert.Equal(t, dynamic.DefaultInterval, config.DNSRouteInterval, "DNSRouteInterval should have default")
|
||||||
|
assert.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set")
|
||||||
|
assert.NotNil(t, config.DisableNotifications, "DisableNotifications should be set")
|
||||||
|
assert.NotEmpty(t, config.IFaceBlackList, "IFaceBlackList should have defaults")
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
|
||||||
|
assert.NotNil(t, config.NetworkMonitor, "NetworkMonitor should be set on Windows/macOS")
|
||||||
|
assert.True(t, *config.NetworkMonitor, "NetworkMonitor should be enabled by default on Windows/macOS")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWireguardPortZeroExplicit(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tempDir, "config.json")
|
||||||
|
|
||||||
|
// Create a new profile with explicit port 0 (random port)
|
||||||
|
explicitZero := 0
|
||||||
|
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||||
|
ConfigPath: configPath,
|
||||||
|
WireguardPort: &explicitZero,
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "should create config with explicit port 0")
|
||||||
|
|
||||||
|
assert.Equal(t, 0, config.WgPort, "WgPort should be 0 when explicitly set by user")
|
||||||
|
|
||||||
|
// Verify it persists
|
||||||
|
readConfig, err := GetConfig(configPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, readConfig.WgPort, "WgPort should remain 0 after reading from file")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWireguardPortDefaultVsExplicit(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
wireguardPort *int
|
||||||
|
expectedPort int
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no port specified uses default",
|
||||||
|
wireguardPort: nil,
|
||||||
|
expectedPort: iface.DefaultWgPort,
|
||||||
|
description: "When user doesn't specify port, default to 51820",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit zero for random port",
|
||||||
|
wireguardPort: func() *int { v := 0; return &v }(),
|
||||||
|
expectedPort: 0,
|
||||||
|
description: "When user explicitly sets 0, use 0 for random port",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "explicit custom port",
|
||||||
|
wireguardPort: func() *int { v := 52000; return &v }(),
|
||||||
|
expectedPort: 52000,
|
||||||
|
description: "When user sets custom port, use that port",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tempDir, "config.json")
|
||||||
|
|
||||||
|
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||||
|
ConfigPath: configPath,
|
||||||
|
WireguardPort: tt.wireguardPort,
|
||||||
|
})
|
||||||
|
require.NoError(t, err, tt.description)
|
||||||
|
assert.Equal(t, tt.expectedPort, config.WgPort, tt.description)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUpdateOldManagementURL(t *testing.T) {
|
func TestUpdateOldManagementURL(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ type DefaultManager struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
stop context.CancelFunc
|
stop context.CancelFunc
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
|
shutdownWg sync.WaitGroup
|
||||||
clientNetworks map[route.HAUniqueID]*client.Watcher
|
clientNetworks map[route.HAUniqueID]*client.Watcher
|
||||||
routeSelector *routeselector.RouteSelector
|
routeSelector *routeselector.RouteSelector
|
||||||
serverRouter *server.Router
|
serverRouter *server.Router
|
||||||
@@ -106,7 +107,7 @@ type DefaultManager struct {
|
|||||||
func NewManager(config ManagerConfig) *DefaultManager {
|
func NewManager(config ManagerConfig) *DefaultManager {
|
||||||
mCTX, cancel := context.WithCancel(config.Context)
|
mCTX, cancel := context.WithCancel(config.Context)
|
||||||
notifier := notifier.NewNotifier()
|
notifier := notifier.NewNotifier()
|
||||||
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
|
sysOps := systemops.New(config.WGInterface, notifier)
|
||||||
|
|
||||||
if runtime.GOOS == "windows" && config.WGInterface != nil {
|
if runtime.GOOS == "windows" && config.WGInterface != nil {
|
||||||
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
|
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
|
||||||
@@ -273,6 +274,7 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
|
|||||||
// Stop stops the manager watchers and clean firewall rules
|
// Stop stops the manager watchers and clean firewall rules
|
||||||
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
||||||
m.stop()
|
m.stop()
|
||||||
|
m.shutdownWg.Wait()
|
||||||
if m.serverRouter != nil {
|
if m.serverRouter != nil {
|
||||||
m.serverRouter.CleanUp()
|
m.serverRouter.CleanUp()
|
||||||
}
|
}
|
||||||
@@ -474,7 +476,11 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
|||||||
}
|
}
|
||||||
clientNetworkWatcher := client.NewWatcher(config)
|
clientNetworkWatcher := client.NewWatcher(config)
|
||||||
m.clientNetworks[id] = clientNetworkWatcher
|
m.clientNetworks[id] = clientNetworkWatcher
|
||||||
go clientNetworkWatcher.Start()
|
m.shutdownWg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer m.shutdownWg.Done()
|
||||||
|
clientNetworkWatcher.Start()
|
||||||
|
}()
|
||||||
clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes})
|
clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -516,7 +522,11 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
|
|||||||
}
|
}
|
||||||
clientNetworkWatcher = client.NewWatcher(config)
|
clientNetworkWatcher = client.NewWatcher(config)
|
||||||
m.clientNetworks[id] = clientNetworkWatcher
|
m.clientNetworks[id] = clientNetworkWatcher
|
||||||
go clientNetworkWatcher.Start()
|
m.shutdownWg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer m.shutdownWg.Done()
|
||||||
|
clientNetworkWatcher.Start()
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
update := client.RoutesUpdate{
|
update := client.RoutesUpdate{
|
||||||
UpdateSerial: updateSerial,
|
UpdateSerial: updateSerial,
|
||||||
|
|||||||
8
client/internal/routemanager/systemops/flush_nonbsd.go
Normal file
8
client/internal/routemanager/systemops/flush_nonbsd.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd)
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
// FlushMarkedRoutes is a no-op on non-BSD platforms.
|
||||||
|
func (r *SysOps) FlushMarkedRoutes() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Cleanup() error {
|
func (s *ShutdownState) Cleanup() error {
|
||||||
sysops := NewSysOps(nil, nil)
|
sysOps := New(nil, nil)
|
||||||
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
|
sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable)
|
||||||
sysops.refCounter.LoadData((*ExclusionCounter)(s))
|
sysOps.refCounter.LoadData((*ExclusionCounter)(s))
|
||||||
|
|
||||||
return sysops.refCounter.Flush()
|
return sysOps.refCounter.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
|
func (s *ShutdownState) MarshalJSON() ([]byte, error) {
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ type SysOps struct {
|
|||||||
localSubnetsCacheTime time.Time
|
localSubnetsCacheTime time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
||||||
return &SysOps{
|
return &SysOps{
|
||||||
wgInterface: wgInterface,
|
wgInterface: wgInterface,
|
||||||
notifier: notifier,
|
notifier: notifier,
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) {
|
|||||||
_, intf = setupDummyInterface(t)
|
_, intf = setupDummyInterface(t)
|
||||||
nexthop = Nexthop{netip.Addr{}, intf}
|
nexthop = Nexthop{netip.Addr{}, intf}
|
||||||
|
|
||||||
r := NewSysOps(nil, nil)
|
r := New(nil, nil)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
for i := 0; i < 1024; i++ {
|
for i := 0; i < 1024; i++ {
|
||||||
@@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin
|
|||||||
|
|
||||||
nexthop := Nexthop{netip.Addr{}, netIntf}
|
nexthop := Nexthop{netip.Addr{}, netIntf}
|
||||||
|
|
||||||
r := NewSysOps(nil, nil)
|
r := New(nil, nil)
|
||||||
err = r.addToRouteTable(prefix, nexthop)
|
err = r.addToRouteTable(prefix, nexthop)
|
||||||
require.NoError(t, err, "Failed to add route to table")
|
require.NoError(t, err, "Failed to add route to table")
|
||||||
|
|
||||||
|
|||||||
@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
|
|||||||
|
|
||||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := New(wgInterface, nil)
|
||||||
advancedRouting := nbnet.AdvancedRouting()
|
advancedRouting := nbnet.AdvancedRouting()
|
||||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
|
|||||||
|
|
||||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := New(wgInterface, nil)
|
||||||
advancedRouting := nbnet.AdvancedRouting()
|
advancedRouting := nbnet.AdvancedRouting()
|
||||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) {
|
|||||||
assert.NoError(t, wgInterface.Close())
|
assert.NoError(t, wgInterface.Close())
|
||||||
})
|
})
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := New(wgInterface, nil)
|
||||||
advancedRouting := nbnet.AdvancedRouting()
|
advancedRouting := nbnet.AdvancedRouting()
|
||||||
err := r.SetupRouting(nil, nil, advancedRouting)
|
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||||
require.NoError(t, err, "setupRouting should not return err")
|
require.NoError(t, err, "setupRouting should not return err")
|
||||||
|
|||||||
@@ -7,19 +7,39 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/route"
|
"golang.org/x/net/route"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
|
||||||
|
)
|
||||||
|
|
||||||
|
var routeProtoFlag int
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
switch os.Getenv(envRouteProtoFlag) {
|
||||||
|
case "2":
|
||||||
|
routeProtoFlag = unix.RTF_PROTO2
|
||||||
|
case "3":
|
||||||
|
routeProtoFlag = unix.RTF_PROTO3
|
||||||
|
default:
|
||||||
|
routeProtoFlag = unix.RTF_PROTO1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
@@ -28,12 +48,88 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout
|
|||||||
return r.cleanupRefCounter(stateManager)
|
return r.cleanupRefCounter(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
|
||||||
|
func (r *SysOps) FlushMarkedRoutes() error {
|
||||||
|
rib, err := retryFetchRIB()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("fetch routing table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse routing table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
flushedCount := 0
|
||||||
|
|
||||||
|
for _, msg := range msgs {
|
||||||
|
rtMsg, ok := msg.(*route.RouteMessage)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if rtMsg.Flags&routeProtoFlag == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
routeInfo, err := MsgToRoute(rtMsg)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Skipping route flush: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
nexthop := Nexthop{
|
||||||
|
IP: routeInfo.Gw,
|
||||||
|
Intf: routeInfo.Interface,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
flushedCount++
|
||||||
|
log.Debugf("Flushed marked route: %s", routeInfo.Dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
if flushedCount > 0 {
|
||||||
|
log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
|
if prefix.IsSingleIP() {
|
||||||
|
log.Debugf("Adding single IP route: %s via %s", prefix, formatNexthop(nexthop))
|
||||||
|
}
|
||||||
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
|
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
return r.routeSocket(unix.RTM_DELETE, prefix, nexthop)
|
if prefix.IsSingleIP() {
|
||||||
|
log.Debugf("Removing single IP route: %s via %s", prefix, formatNexthop(nexthop))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.routeSocket(unix.RTM_DELETE, prefix, nexthop); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix.IsSingleIP() {
|
||||||
|
log.Debugf("Route removal completed for %s, verifying...", prefix)
|
||||||
|
if exists := r.verifyRouteRemoved(prefix); exists {
|
||||||
|
log.Warnf("Route %s still exists in routing table after removal", prefix)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Verified route %s successfully removed", prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) error {
|
func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
@@ -105,7 +201,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func(
|
|||||||
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
|
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
|
||||||
msg = &route.RouteMessage{
|
msg = &route.RouteMessage{
|
||||||
Type: action,
|
Type: action,
|
||||||
Flags: unix.RTF_UP,
|
Flags: unix.RTF_UP | routeProtoFlag,
|
||||||
Version: unix.RTM_VERSION,
|
Version: unix.RTM_VERSION,
|
||||||
Seq: r.getSeq(),
|
Seq: r.getSeq(),
|
||||||
}
|
}
|
||||||
@@ -200,3 +296,51 @@ func prefixToRouteNetmask(prefix netip.Prefix) (route.Addr, error) {
|
|||||||
|
|
||||||
return nil, fmt.Errorf("unknown IP version in prefix: %s", prefix.Addr().String())
|
return nil, fmt.Errorf("unknown IP version in prefix: %s", prefix.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// formatNexthop returns a string representation of the nexthop for logging.
|
||||||
|
func formatNexthop(nexthop Nexthop) string {
|
||||||
|
if nexthop.IP.IsValid() {
|
||||||
|
return nexthop.IP.String()
|
||||||
|
}
|
||||||
|
if nexthop.Intf != nil {
|
||||||
|
return nexthop.Intf.Name
|
||||||
|
}
|
||||||
|
return "direct"
|
||||||
|
}
|
||||||
|
|
||||||
|
// verifyRouteRemoved checks if a route still exists in the routing table.
|
||||||
|
func (r *SysOps) verifyRouteRemoved(prefix netip.Prefix) bool {
|
||||||
|
rib, err := retryFetchRIB()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Failed to fetch RIB for route verification: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Failed to parse RIB for route verification: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, msg := range msgs {
|
||||||
|
rtMsg, ok := msg.(*route.RouteMessage)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if rtMsg.Flags&routeProtoFlag == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
routeInfo, err := MsgToRoute(rtMsg)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if routeInfo.Dst == prefix {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage,
|
|||||||
data, err := os.ReadFile(m.filePath)
|
data, err := os.ReadFile(m.filePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
log.Debug("state file does not exist")
|
log.Debugf("state file %s does not exist", m.filePath)
|
||||||
return nil, nil // nolint:nilnil
|
return nil, nil // nolint:nilnil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("read state file: %w", err)
|
return nil, fmt.Errorf("read state file: %w", err)
|
||||||
|
|||||||
59
client/internal/winregistry/volatile_windows.go
Normal file
59
client/internal/winregistry/volatile_windows.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
package winregistry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
advapi = syscall.NewLazyDLL("advapi32.dll")
|
||||||
|
regCreateKeyExW = advapi.NewProc("RegCreateKeyExW")
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Registry key options
|
||||||
|
regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted
|
||||||
|
regOptionVolatile = 0x1 // Key is not preserved when system is rebooted
|
||||||
|
|
||||||
|
// Registry disposition values
|
||||||
|
regCreatedNewKey = 0x1
|
||||||
|
regOpenedExistingKey = 0x2
|
||||||
|
)
|
||||||
|
|
||||||
|
// CreateVolatileKey creates a volatile registry key named path under open key root.
|
||||||
|
// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed.
|
||||||
|
// The access parameter specifies the access rights for the key to be created.
|
||||||
|
//
|
||||||
|
// Volatile keys are stored in memory and are automatically deleted when the system is shut down.
|
||||||
|
// This provides automatic cleanup without requiring manual registry maintenance.
|
||||||
|
func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) {
|
||||||
|
pathPtr, err := syscall.UTF16PtrFromString(path)
|
||||||
|
if err != nil {
|
||||||
|
return 0, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
handle syscall.Handle
|
||||||
|
disposition uint32
|
||||||
|
)
|
||||||
|
|
||||||
|
ret, _, _ := regCreateKeyExW.Call(
|
||||||
|
uintptr(root),
|
||||||
|
uintptr(unsafe.Pointer(pathPtr)),
|
||||||
|
0, // reserved
|
||||||
|
0, // class
|
||||||
|
uintptr(regOptionVolatile), // options - volatile key
|
||||||
|
uintptr(access), // desired access
|
||||||
|
0, // security attributes
|
||||||
|
uintptr(unsafe.Pointer(&handle)),
|
||||||
|
uintptr(unsafe.Pointer(&disposition)),
|
||||||
|
)
|
||||||
|
|
||||||
|
if ret != 0 {
|
||||||
|
return 0, false, syscall.Errno(ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
return registry.Key(handle), disposition == regOpenedExistingKey, nil
|
||||||
|
}
|
||||||
@@ -17,8 +17,7 @@ type Conn struct {
|
|||||||
ID hooks.ConnectionID
|
ID hooks.ConnectionID
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection.
|
||||||
// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
|
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
return closeConn(c.ID, c.Conn)
|
return closeConn(c.ID, c.Conn)
|
||||||
}
|
}
|
||||||
@@ -29,7 +28,7 @@ type TCPConn struct {
|
|||||||
ID hooks.ConnectionID
|
ID hooks.ConnectionID
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection.
|
// Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection.
|
||||||
func (c *TCPConn) Close() error {
|
func (c *TCPConn) Close() error {
|
||||||
return closeConn(c.ID, c.TCPConn)
|
return closeConn(c.ID, c.TCPConn)
|
||||||
}
|
}
|
||||||
@@ -37,13 +36,16 @@ func (c *TCPConn) Close() error {
|
|||||||
// closeConn is a helper function to close connections and execute close hooks.
|
// closeConn is a helper function to close connections and execute close hooks.
|
||||||
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
|
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
|
||||||
err := conn.Close()
|
err := conn.Close()
|
||||||
|
cleanupConnID(id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupConnID executes close hooks for a connection ID.
|
||||||
|
func cleanupConnID(id hooks.ConnectionID) {
|
||||||
closeHooks := hooks.GetCloseHooks()
|
closeHooks := hooks.GetCloseHooks()
|
||||||
for _, hook := range closeHooks {
|
for _, hook := range closeHooks {
|
||||||
if err := hook(id); err != nil {
|
if err := hook(id); err != nil {
|
||||||
log.Errorf("Error executing close hook: %v", err)
|
log.Errorf("Error executing close hook: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro
|
|||||||
}
|
}
|
||||||
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
|
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := conn.Close(); err != nil {
|
if err := conn.Close(); err != nil {
|
||||||
log.Errorf("failed to close connection: %v", err)
|
log.Errorf("failed to close connection: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
|
|||||||
|
|
||||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
cleanupConnID(connID)
|
||||||
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
|
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str
|
|||||||
|
|
||||||
ips, err := resolver.LookupIPAddr(ctx, host)
|
ips, err := resolver.LookupIPAddr(ctx, host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to resolve address %s: %w", address, err)
|
return fmt.Errorf("resolve address %s: %w", address, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
|
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
|||||||
return c.PacketConn.WriteTo(b, addr)
|
return c.PacketConn.WriteTo(b, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
|
// Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection.
|
||||||
func (c *PacketConn) Close() error {
|
func (c *PacketConn) Close() error {
|
||||||
defer c.seenAddrs.Clear()
|
defer c.seenAddrs.Clear()
|
||||||
return closeConn(c.ID, c.PacketConn)
|
return closeConn(c.ID, c.PacketConn)
|
||||||
@@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
|||||||
return c.UDPConn.WriteTo(b, addr)
|
return c.UDPConn.WriteTo(b, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
|
// Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection.
|
||||||
func (c *UDPConn) Close() error {
|
func (c *UDPConn) Close() error {
|
||||||
defer c.seenAddrs.Clear()
|
defer c.seenAddrs.Clear()
|
||||||
return closeConn(c.ID, c.UDPConn)
|
return closeConn(c.ID, c.UDPConn)
|
||||||
|
|||||||
@@ -353,6 +353,13 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
|||||||
config.CustomDNSAddress = []byte{}
|
config.CustomDNSAddress = []byte{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
|
||||||
|
|
||||||
|
if msg.DnsRouteInterval != nil {
|
||||||
|
interval := msg.DnsRouteInterval.AsDuration()
|
||||||
|
config.DNSRouteInterval = &interval
|
||||||
|
}
|
||||||
|
|
||||||
config.RosenpassEnabled = msg.RosenpassEnabled
|
config.RosenpassEnabled = msg.RosenpassEnabled
|
||||||
config.RosenpassPermissive = msg.RosenpassPermissive
|
config.RosenpassPermissive = msg.RosenpassPermissive
|
||||||
config.DisableAutoConnect = msg.DisableAutoConnect
|
config.DisableAutoConnect = msg.DisableAutoConnect
|
||||||
|
|||||||
298
client/server/setconfig_test.go
Normal file
298
client/server/setconfig_test.go
Normal file
@@ -0,0 +1,298 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/protobuf/types/known/durationpb"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSetConfig_AllFieldsSaved ensures that all fields in SetConfigRequest are properly saved to the config.
|
||||||
|
// This test uses reflection to detect when new fields are added but not handled in SetConfig.
|
||||||
|
func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
||||||
|
origDefaultConfigPath := profilemanager.DefaultConfigPath
|
||||||
|
origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
|
||||||
|
profilemanager.ConfigDirOverride = tempDir
|
||||||
|
profilemanager.DefaultConfigPathDir = tempDir
|
||||||
|
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
|
||||||
|
profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
profilemanager.DefaultConfigPathDir = origDefaultProfileDir
|
||||||
|
profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
|
||||||
|
profilemanager.DefaultConfigPath = origDefaultConfigPath
|
||||||
|
profilemanager.ConfigDirOverride = ""
|
||||||
|
})
|
||||||
|
|
||||||
|
currUser, err := user.Current()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
profName := "test-profile"
|
||||||
|
|
||||||
|
ic := profilemanager.ConfigInput{
|
||||||
|
ConfigPath: filepath.Join(tempDir, profName+".json"),
|
||||||
|
ManagementURL: "https://api.netbird.io:443",
|
||||||
|
}
|
||||||
|
_, err = profilemanager.UpdateOrCreateConfig(ic)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pm := profilemanager.ServiceManager{}
|
||||||
|
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||||
|
Name: profName,
|
||||||
|
Username: currUser.Username,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
s := New(ctx, "console", "", false, false)
|
||||||
|
|
||||||
|
rosenpassEnabled := true
|
||||||
|
rosenpassPermissive := true
|
||||||
|
serverSSHAllowed := true
|
||||||
|
interfaceName := "utun100"
|
||||||
|
wireguardPort := int64(51820)
|
||||||
|
preSharedKey := "test-psk"
|
||||||
|
disableAutoConnect := true
|
||||||
|
networkMonitor := true
|
||||||
|
disableClientRoutes := true
|
||||||
|
disableServerRoutes := true
|
||||||
|
disableDNS := true
|
||||||
|
disableFirewall := true
|
||||||
|
blockLANAccess := true
|
||||||
|
disableNotifications := true
|
||||||
|
lazyConnectionEnabled := true
|
||||||
|
blockInbound := true
|
||||||
|
mtu := int64(1280)
|
||||||
|
|
||||||
|
req := &proto.SetConfigRequest{
|
||||||
|
ProfileName: profName,
|
||||||
|
Username: currUser.Username,
|
||||||
|
ManagementUrl: "https://new-api.netbird.io:443",
|
||||||
|
AdminURL: "https://new-admin.netbird.io",
|
||||||
|
RosenpassEnabled: &rosenpassEnabled,
|
||||||
|
RosenpassPermissive: &rosenpassPermissive,
|
||||||
|
ServerSSHAllowed: &serverSSHAllowed,
|
||||||
|
InterfaceName: &interfaceName,
|
||||||
|
WireguardPort: &wireguardPort,
|
||||||
|
OptionalPreSharedKey: &preSharedKey,
|
||||||
|
DisableAutoConnect: &disableAutoConnect,
|
||||||
|
NetworkMonitor: &networkMonitor,
|
||||||
|
DisableClientRoutes: &disableClientRoutes,
|
||||||
|
DisableServerRoutes: &disableServerRoutes,
|
||||||
|
DisableDns: &disableDNS,
|
||||||
|
DisableFirewall: &disableFirewall,
|
||||||
|
BlockLanAccess: &blockLANAccess,
|
||||||
|
DisableNotifications: &disableNotifications,
|
||||||
|
LazyConnectionEnabled: &lazyConnectionEnabled,
|
||||||
|
BlockInbound: &blockInbound,
|
||||||
|
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
|
||||||
|
CleanNATExternalIPs: false,
|
||||||
|
CustomDNSAddress: []byte("1.1.1.1:53"),
|
||||||
|
ExtraIFaceBlacklist: []string{"eth1", "eth2"},
|
||||||
|
DnsLabels: []string{"label1", "label2"},
|
||||||
|
CleanDNSLabels: false,
|
||||||
|
DnsRouteInterval: durationpb.New(2 * time.Minute),
|
||||||
|
Mtu: &mtu,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = s.SetConfig(ctx, req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
profState := profilemanager.ActiveProfileState{
|
||||||
|
Name: profName,
|
||||||
|
Username: currUser.Username,
|
||||||
|
}
|
||||||
|
cfgPath, err := profState.FilePath()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cfg, err := profilemanager.GetConfig(cfgPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, "https://new-api.netbird.io:443", cfg.ManagementURL.String())
|
||||||
|
require.Equal(t, "https://new-admin.netbird.io:443", cfg.AdminURL.String())
|
||||||
|
require.Equal(t, rosenpassEnabled, cfg.RosenpassEnabled)
|
||||||
|
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
|
||||||
|
require.NotNil(t, cfg.ServerSSHAllowed)
|
||||||
|
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
|
||||||
|
require.Equal(t, interfaceName, cfg.WgIface)
|
||||||
|
require.Equal(t, int(wireguardPort), cfg.WgPort)
|
||||||
|
require.Equal(t, preSharedKey, cfg.PreSharedKey)
|
||||||
|
require.Equal(t, disableAutoConnect, cfg.DisableAutoConnect)
|
||||||
|
require.NotNil(t, cfg.NetworkMonitor)
|
||||||
|
require.Equal(t, networkMonitor, *cfg.NetworkMonitor)
|
||||||
|
require.Equal(t, disableClientRoutes, cfg.DisableClientRoutes)
|
||||||
|
require.Equal(t, disableServerRoutes, cfg.DisableServerRoutes)
|
||||||
|
require.Equal(t, disableDNS, cfg.DisableDNS)
|
||||||
|
require.Equal(t, disableFirewall, cfg.DisableFirewall)
|
||||||
|
require.Equal(t, blockLANAccess, cfg.BlockLANAccess)
|
||||||
|
require.NotNil(t, cfg.DisableNotifications)
|
||||||
|
require.Equal(t, disableNotifications, *cfg.DisableNotifications)
|
||||||
|
require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled)
|
||||||
|
require.Equal(t, blockInbound, cfg.BlockInbound)
|
||||||
|
require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs)
|
||||||
|
require.Equal(t, "1.1.1.1:53", cfg.CustomDNSAddress)
|
||||||
|
// IFaceBlackList contains defaults + extras
|
||||||
|
require.Contains(t, cfg.IFaceBlackList, "eth1")
|
||||||
|
require.Contains(t, cfg.IFaceBlackList, "eth2")
|
||||||
|
require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList())
|
||||||
|
require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval)
|
||||||
|
require.Equal(t, uint16(mtu), cfg.MTU)
|
||||||
|
|
||||||
|
verifyAllFieldsCovered(t, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// verifyAllFieldsCovered uses reflection to ensure we're testing all fields in SetConfigRequest.
|
||||||
|
// If a new field is added to SetConfigRequest, this function will fail the test,
|
||||||
|
// forcing the developer to update both the SetConfig handler and this test.
|
||||||
|
func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
metadataFields := map[string]bool{
|
||||||
|
"state": true, // protobuf internal
|
||||||
|
"sizeCache": true, // protobuf internal
|
||||||
|
"unknownFields": true, // protobuf internal
|
||||||
|
"Username": true, // metadata
|
||||||
|
"ProfileName": true, // metadata
|
||||||
|
"CleanNATExternalIPs": true, // control flag for clearing
|
||||||
|
"CleanDNSLabels": true, // control flag for clearing
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedFields := map[string]bool{
|
||||||
|
"ManagementUrl": true,
|
||||||
|
"AdminURL": true,
|
||||||
|
"RosenpassEnabled": true,
|
||||||
|
"RosenpassPermissive": true,
|
||||||
|
"ServerSSHAllowed": true,
|
||||||
|
"InterfaceName": true,
|
||||||
|
"WireguardPort": true,
|
||||||
|
"OptionalPreSharedKey": true,
|
||||||
|
"DisableAutoConnect": true,
|
||||||
|
"NetworkMonitor": true,
|
||||||
|
"DisableClientRoutes": true,
|
||||||
|
"DisableServerRoutes": true,
|
||||||
|
"DisableDns": true,
|
||||||
|
"DisableFirewall": true,
|
||||||
|
"BlockLanAccess": true,
|
||||||
|
"DisableNotifications": true,
|
||||||
|
"LazyConnectionEnabled": true,
|
||||||
|
"BlockInbound": true,
|
||||||
|
"NatExternalIPs": true,
|
||||||
|
"CustomDNSAddress": true,
|
||||||
|
"ExtraIFaceBlacklist": true,
|
||||||
|
"DnsLabels": true,
|
||||||
|
"DnsRouteInterval": true,
|
||||||
|
"Mtu": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
val := reflect.ValueOf(req).Elem()
|
||||||
|
typ := val.Type()
|
||||||
|
|
||||||
|
var unexpectedFields []string
|
||||||
|
for i := 0; i < val.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
fieldName := field.Name
|
||||||
|
|
||||||
|
if metadataFields[fieldName] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !expectedFields[fieldName] {
|
||||||
|
unexpectedFields = append(unexpectedFields, fieldName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(unexpectedFields) > 0 {
|
||||||
|
t.Fatalf("New field(s) detected in SetConfigRequest: %v", unexpectedFields)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCLIFlags_MappedToSetConfig ensures all CLI flags that modify config are properly mapped to SetConfigRequest.
|
||||||
|
// This test catches bugs where a new CLI flag is added but not wired to the SetConfigRequest in setupSetConfigReq.
|
||||||
|
func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
||||||
|
// Map of CLI flag names to their corresponding SetConfigRequest field names.
|
||||||
|
// This map must be updated when adding new config-related CLI flags.
|
||||||
|
flagToField := map[string]string{
|
||||||
|
"management-url": "ManagementUrl",
|
||||||
|
"admin-url": "AdminURL",
|
||||||
|
"enable-rosenpass": "RosenpassEnabled",
|
||||||
|
"rosenpass-permissive": "RosenpassPermissive",
|
||||||
|
"allow-server-ssh": "ServerSSHAllowed",
|
||||||
|
"interface-name": "InterfaceName",
|
||||||
|
"wireguard-port": "WireguardPort",
|
||||||
|
"preshared-key": "OptionalPreSharedKey",
|
||||||
|
"disable-auto-connect": "DisableAutoConnect",
|
||||||
|
"network-monitor": "NetworkMonitor",
|
||||||
|
"disable-client-routes": "DisableClientRoutes",
|
||||||
|
"disable-server-routes": "DisableServerRoutes",
|
||||||
|
"disable-dns": "DisableDns",
|
||||||
|
"disable-firewall": "DisableFirewall",
|
||||||
|
"block-lan-access": "BlockLanAccess",
|
||||||
|
"block-inbound": "BlockInbound",
|
||||||
|
"enable-lazy-connection": "LazyConnectionEnabled",
|
||||||
|
"external-ip-map": "NatExternalIPs",
|
||||||
|
"dns-resolver-address": "CustomDNSAddress",
|
||||||
|
"extra-iface-blacklist": "ExtraIFaceBlacklist",
|
||||||
|
"extra-dns-labels": "DnsLabels",
|
||||||
|
"dns-router-interval": "DnsRouteInterval",
|
||||||
|
"mtu": "Mtu",
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).
|
||||||
|
fieldsWithoutCLIFlags := map[string]bool{
|
||||||
|
"DisableNotifications": true, // Only settable via UI
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all SetConfigRequest fields to verify our map is complete.
|
||||||
|
req := &proto.SetConfigRequest{}
|
||||||
|
val := reflect.ValueOf(req).Elem()
|
||||||
|
typ := val.Type()
|
||||||
|
|
||||||
|
var unmappedFields []string
|
||||||
|
for i := 0; i < val.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
fieldName := field.Name
|
||||||
|
|
||||||
|
// Skip protobuf internal fields and metadata fields.
|
||||||
|
if fieldName == "state" || fieldName == "sizeCache" || fieldName == "unknownFields" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if fieldName == "Username" || fieldName == "ProfileName" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if fieldName == "CleanNATExternalIPs" || fieldName == "CleanDNSLabels" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this field is either mapped to a CLI flag or explicitly documented as having no CLI flag.
|
||||||
|
mappedToCLI := false
|
||||||
|
for _, mappedField := range flagToField {
|
||||||
|
if mappedField == fieldName {
|
||||||
|
mappedToCLI = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hasNoCLIFlag := fieldsWithoutCLIFlags[fieldName]
|
||||||
|
|
||||||
|
if !mappedToCLI && !hasNoCLIFlag {
|
||||||
|
unmappedFields = append(unmappedFields, fieldName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(unmappedFields) > 0 {
|
||||||
|
t.Fatalf("SetConfigRequest field(s) not documented: %v\n"+
|
||||||
|
"Either add the CLI flag to flagToField map, or if there's no CLI flag for this field, "+
|
||||||
|
"add it to fieldsWithoutCLIFlags map with a comment explaining why.", unmappedFields)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("All SetConfigRequest fields are properly documented")
|
||||||
|
}
|
||||||
@@ -10,7 +10,9 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error {
|
|||||||
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clean up any remaining routes independently of the state file
|
||||||
|
if !nbnet.AdvancedRouting() {
|
||||||
|
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -205,15 +205,18 @@ func mapPeers(
|
|||||||
localICEEndpoint := ""
|
localICEEndpoint := ""
|
||||||
remoteICEEndpoint := ""
|
remoteICEEndpoint := ""
|
||||||
relayServerAddress := ""
|
relayServerAddress := ""
|
||||||
connType := "P2P"
|
connType := "-"
|
||||||
lastHandshake := time.Time{}
|
lastHandshake := time.Time{}
|
||||||
transferReceived := int64(0)
|
transferReceived := int64(0)
|
||||||
transferSent := int64(0)
|
transferSent := int64(0)
|
||||||
|
|
||||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||||
|
|
||||||
if pbPeerState.Relayed {
|
if isPeerConnected {
|
||||||
connType = "Relayed"
|
connType = "P2P"
|
||||||
|
if pbPeerState.Relayed {
|
||||||
|
connType = "Relayed"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) {
|
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) {
|
||||||
|
|||||||
@@ -31,7 +31,6 @@ import (
|
|||||||
"fyne.io/systray"
|
"fyne.io/systray"
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/skratchdot/open-golang/open"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
@@ -633,7 +632,7 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
|
func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
|
||||||
err := open.Run(loginResp.VerificationURIComplete)
|
err := openURL(loginResp.VerificationURIComplete)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("opening the verification uri in the browser failed: %v", err)
|
log.Errorf("opening the verification uri in the browser failed: %v", err)
|
||||||
return err
|
return err
|
||||||
@@ -1354,7 +1353,13 @@ func (s *serviceClient) updateConfig() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL.
|
// showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL.
|
||||||
func (s *serviceClient) showLoginURL() {
|
// It also starts a background goroutine that periodically checks if the client is already connected
|
||||||
|
// and closes the window if so. The goroutine can be cancelled by the returned CancelFunc, and it is
|
||||||
|
// also cancelled when the window is closed.
|
||||||
|
func (s *serviceClient) showLoginURL() context.CancelFunc {
|
||||||
|
|
||||||
|
// create a cancellable context for the background check goroutine
|
||||||
|
ctx, cancel := context.WithCancel(s.ctx)
|
||||||
|
|
||||||
resIcon := fyne.NewStaticResource("netbird.png", iconAbout)
|
resIcon := fyne.NewStaticResource("netbird.png", iconAbout)
|
||||||
|
|
||||||
@@ -1363,6 +1368,8 @@ func (s *serviceClient) showLoginURL() {
|
|||||||
s.wLoginURL.Resize(fyne.NewSize(400, 200))
|
s.wLoginURL.Resize(fyne.NewSize(400, 200))
|
||||||
s.wLoginURL.SetIcon(resIcon)
|
s.wLoginURL.SetIcon(resIcon)
|
||||||
}
|
}
|
||||||
|
// ensure goroutine is cancelled when the window is closed
|
||||||
|
s.wLoginURL.SetOnClosed(func() { cancel() })
|
||||||
// add a description label
|
// add a description label
|
||||||
label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.")
|
label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.")
|
||||||
|
|
||||||
@@ -1443,10 +1450,46 @@ func (s *serviceClient) showLoginURL() {
|
|||||||
)
|
)
|
||||||
s.wLoginURL.SetContent(container.NewCenter(content))
|
s.wLoginURL.SetContent(container.NewCenter(content))
|
||||||
|
|
||||||
|
// start a goroutine to check connection status and close the window if connected
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(5 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
conn, err := s.getSrvClient(failFastTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if status.Status == string(internal.StatusConnected) {
|
||||||
|
if s.wLoginURL != nil {
|
||||||
|
s.wLoginURL.Close()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
s.wLoginURL.Show()
|
s.wLoginURL.Show()
|
||||||
|
|
||||||
|
// return cancel func so callers can stop the background goroutine if desired
|
||||||
|
return cancel
|
||||||
}
|
}
|
||||||
|
|
||||||
func openURL(url string) error {
|
func openURL(url string) error {
|
||||||
|
if browser := os.Getenv("BROWSER"); browser != "" {
|
||||||
|
return exec.Command(browser, url).Start()
|
||||||
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
switch runtime.GOOS {
|
switch runtime.GOOS {
|
||||||
case "windows":
|
case "windows":
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/skratchdot/open-golang/open"
|
"github.com/skratchdot/open-golang/open"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
uptypes "github.com/netbirdio/netbird/upload-server/types"
|
uptypes "github.com/netbirdio/netbird/upload-server/types"
|
||||||
@@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData(
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pm := profilemanager.NewProfileManager()
|
||||||
|
var profName string
|
||||||
|
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||||
|
profName = activeProf.Name
|
||||||
|
}
|
||||||
|
|
||||||
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Failed to get post-up status: %v", err)
|
log.Warnf("Failed to get post-up status: %v", err)
|
||||||
@@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData(
|
|||||||
|
|
||||||
var postUpStatusOutput string
|
var postUpStatusOutput string
|
||||||
if postUpStatus != nil {
|
if postUpStatus != nil {
|
||||||
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "")
|
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||||
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||||
}
|
}
|
||||||
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||||
@@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData(
|
|||||||
|
|
||||||
var preDownStatusOutput string
|
var preDownStatusOutput string
|
||||||
if preDownStatus != nil {
|
if preDownStatus != nil {
|
||||||
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "")
|
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
|
||||||
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||||
}
|
}
|
||||||
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
|
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
|
||||||
@@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
|||||||
return nil, fmt.Errorf("get client: %v", err)
|
return nil, fmt.Errorf("get client: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pm := profilemanager.NewProfileManager()
|
||||||
|
var profName string
|
||||||
|
if activeProf, err := pm.GetActiveProfile(); err == nil {
|
||||||
|
profName = activeProf.Name
|
||||||
|
}
|
||||||
|
|
||||||
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to get status for debug bundle: %v", err)
|
log.Warnf("failed to get status for debug bundle: %v", err)
|
||||||
@@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
|
|||||||
|
|
||||||
var statusOutput string
|
var statusOutput string
|
||||||
if statusResp != nil {
|
if statusResp != nil {
|
||||||
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "")
|
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
|
||||||
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -73,8 +73,8 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config {
|
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config {
|
||||||
return &tls.Config{
|
config := &tls.Config{
|
||||||
InsecureSkipVerify: true, // We'll validate manually after handshake
|
InsecureSkipVerify: true, // We'll validate manually after handshake
|
||||||
VerifyConnection: func(cs tls.ConnectionState) error {
|
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||||
var certChain [][]byte
|
var certChain [][]byte
|
||||||
@@ -93,4 +93,15 @@ func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tl
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3
|
||||||
|
if requiresCredSSP {
|
||||||
|
config.MinVersion = tls.VersionTLS12
|
||||||
|
config.MaxVersion = tls.VersionTLS12
|
||||||
|
} else {
|
||||||
|
config.MinVersion = tls.VersionTLS12
|
||||||
|
config.MaxVersion = tls.VersionTLS13
|
||||||
|
}
|
||||||
|
|
||||||
|
return config
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,11 +6,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/asn1"
|
"encoding/asn1"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall/js"
|
"syscall/js"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
@@ -19,18 +21,34 @@ const (
|
|||||||
RDCleanPathVersion = 3390
|
RDCleanPathVersion = 3390
|
||||||
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
|
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
|
||||||
RDCleanPathProxyScheme = "ws"
|
RDCleanPathProxyScheme = "ws"
|
||||||
|
|
||||||
|
rdpDialTimeout = 15 * time.Second
|
||||||
|
|
||||||
|
GeneralErrorCode = 1
|
||||||
|
WSAETimedOut = 10060
|
||||||
|
WSAEConnRefused = 10061
|
||||||
|
WSAEConnAborted = 10053
|
||||||
|
WSAEConnReset = 10054
|
||||||
|
WSAEGenericError = 10050
|
||||||
)
|
)
|
||||||
|
|
||||||
type RDCleanPathPDU struct {
|
type RDCleanPathPDU struct {
|
||||||
Version int64 `asn1:"tag:0,explicit"`
|
Version int64 `asn1:"tag:0,explicit"`
|
||||||
Error []byte `asn1:"tag:1,explicit,optional"`
|
Error RDCleanPathErr `asn1:"tag:1,explicit,optional"`
|
||||||
Destination string `asn1:"utf8,tag:2,explicit,optional"`
|
Destination string `asn1:"utf8,tag:2,explicit,optional"`
|
||||||
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
|
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
|
||||||
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
|
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
|
||||||
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
|
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
|
||||||
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
|
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
|
||||||
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
|
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
|
||||||
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
|
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RDCleanPathErr struct {
|
||||||
|
ErrorCode int16 `asn1:"tag:0,explicit"`
|
||||||
|
HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"`
|
||||||
|
WSALastError int16 `asn1:"tag:2,explicit,optional"`
|
||||||
|
TLSAlertCode int8 `asn1:"tag:3,explicit,optional"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type RDCleanPathProxy struct {
|
type RDCleanPathProxy struct {
|
||||||
@@ -210,9 +228,13 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
|||||||
destination := conn.destination
|
destination := conn.destination
|
||||||
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
|
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
|
||||||
|
|
||||||
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
|
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||||
|
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
conn.rdpConn = rdpConn
|
conn.rdpConn = rdpConn
|
||||||
@@ -220,6 +242,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
|||||||
_, err = rdpConn.Write(firstPacket)
|
_, err = rdpConn.Write(firstPacket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to write first packet: %v", err)
|
log.Errorf("Failed to write first packet: %v", err)
|
||||||
|
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,6 +250,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
|||||||
n, err := rdpConn.Read(response)
|
n, err := rdpConn.Read(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to read X.224 response: %v", err)
|
log.Errorf("Failed to read X.224 response: %v", err)
|
||||||
|
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -269,3 +293,52 @@ func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
|||||||
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
|
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||||
|
data, err := asn1.Marshal(pdu)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to marshal error PDU: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.sendToWebSocket(conn, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func errorToWSACode(err error) int16 {
|
||||||
|
if err == nil {
|
||||||
|
return WSAEGenericError
|
||||||
|
}
|
||||||
|
var netErr *net.OpError
|
||||||
|
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||||
|
return WSAETimedOut
|
||||||
|
}
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return WSAETimedOut
|
||||||
|
}
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
return WSAEConnAborted
|
||||||
|
}
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
return WSAEConnReset
|
||||||
|
}
|
||||||
|
return WSAEGenericError
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWSAError(err error) RDCleanPathPDU {
|
||||||
|
return RDCleanPathPDU{
|
||||||
|
Version: RDCleanPathVersion,
|
||||||
|
Error: RDCleanPathErr{
|
||||||
|
ErrorCode: GeneralErrorCode,
|
||||||
|
WSALastError: errorToWSACode(err),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTPError(statusCode int16) RDCleanPathPDU {
|
||||||
|
return RDCleanPathPDU{
|
||||||
|
Version: RDCleanPathVersion,
|
||||||
|
Error: RDCleanPathErr{
|
||||||
|
ErrorCode: GeneralErrorCode,
|
||||||
|
HTTPStatusCode: statusCode,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
package rdp
|
package rdp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/asn1"
|
"encoding/asn1"
|
||||||
"io"
|
"io"
|
||||||
@@ -11,11 +12,17 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP)
|
||||||
|
protocolSSL = 0x00000001
|
||||||
|
protocolHybridEx = 0x00000008
|
||||||
|
)
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||||
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
|
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
|
||||||
|
|
||||||
if pdu.Version != RDCleanPathVersion {
|
if pdu.Version != RDCleanPathVersion {
|
||||||
p.sendRDCleanPathError(conn, "Unsupported version")
|
p.sendRDCleanPathError(conn, newHTTPError(400))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -24,10 +31,13 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
|
|||||||
destination = pdu.Destination
|
destination = pdu.Destination
|
||||||
}
|
}
|
||||||
|
|
||||||
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
|
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||||
p.sendRDCleanPathError(conn, "Connection failed")
|
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||||
p.cleanupConnection(conn)
|
p.cleanupConnection(conn)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -40,6 +50,34 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
|
|||||||
p.setupTLSConnection(conn, pdu)
|
p.setupTLSConnection(conn, pdu)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required.
|
||||||
|
// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags.
|
||||||
|
// Returns (requiresTLS12, selectedProtocol, detectionSuccessful).
|
||||||
|
func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) {
|
||||||
|
const minResponseLength = 19
|
||||||
|
|
||||||
|
if len(x224Response) < minResponseLength {
|
||||||
|
return false, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per X.224 specification:
|
||||||
|
// x224Response[0] == 0x03: Length of X.224 header (3 bytes)
|
||||||
|
// x224Response[5] == 0xD0: X.224 Data TPDU code
|
||||||
|
if x224Response[0] != 0x03 || x224Response[5] != 0xD0 {
|
||||||
|
return false, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if x224Response[11] == 0x02 {
|
||||||
|
flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 |
|
||||||
|
uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24
|
||||||
|
|
||||||
|
hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0
|
||||||
|
return hasNLA, flags, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||||
var x224Response []byte
|
var x224Response []byte
|
||||||
if len(pdu.X224ConnectionPDU) > 0 {
|
if len(pdu.X224ConnectionPDU) > 0 {
|
||||||
@@ -47,7 +85,7 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
|
|||||||
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to write X.224 PDU: %v", err)
|
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||||
p.sendRDCleanPathError(conn, "Failed to forward X.224")
|
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,21 +93,32 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
|
|||||||
n, err := conn.rdpConn.Read(response)
|
n, err := conn.rdpConn.Read(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to read X.224 response: %v", err)
|
log.Errorf("Failed to read X.224 response: %v", err)
|
||||||
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
|
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
x224Response = response[:n]
|
x224Response = response[:n]
|
||||||
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
|
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsConfig := p.getTLSConfigWithValidation(conn)
|
requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response)
|
||||||
|
if detected {
|
||||||
|
if requiresCredSSP {
|
||||||
|
log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol)
|
||||||
|
} else {
|
||||||
|
log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3")
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP)
|
||||||
|
|
||||||
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
|
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
|
||||||
conn.tlsConn = tlsConn
|
conn.tlsConn = tlsConn
|
||||||
|
|
||||||
if err := tlsConn.Handshake(); err != nil {
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
log.Errorf("TLS handshake failed: %v", err)
|
log.Errorf("TLS handshake failed: %v", err)
|
||||||
p.sendRDCleanPathError(conn, "TLS handshake failed")
|
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,47 +155,6 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
|
|||||||
p.cleanupConnection(conn)
|
p.cleanupConnection(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
|
||||||
if len(pdu.X224ConnectionPDU) > 0 {
|
|
||||||
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
|
|
||||||
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to write X.224 PDU: %v", err)
|
|
||||||
p.sendRDCleanPathError(conn, "Failed to forward X.224")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
response := make([]byte, 1024)
|
|
||||||
n, err := conn.rdpConn.Read(response)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to read X.224 response: %v", err)
|
|
||||||
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
responsePDU := RDCleanPathPDU{
|
|
||||||
Version: RDCleanPathVersion,
|
|
||||||
X224ConnectionPDU: response[:n],
|
|
||||||
ServerAddr: conn.destination,
|
|
||||||
}
|
|
||||||
|
|
||||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
|
||||||
} else {
|
|
||||||
responsePDU := RDCleanPathPDU{
|
|
||||||
Version: RDCleanPathVersion,
|
|
||||||
ServerAddr: conn.destination,
|
|
||||||
}
|
|
||||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
|
||||||
}
|
|
||||||
|
|
||||||
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
|
|
||||||
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
|
|
||||||
|
|
||||||
<-conn.ctx.Done()
|
|
||||||
log.Debug("TCP connection context done, cleaning up")
|
|
||||||
p.cleanupConnection(conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||||
data, err := asn1.Marshal(pdu)
|
data, err := asn1.Marshal(pdu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -158,21 +166,6 @@ func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDClean
|
|||||||
p.sendToWebSocket(conn, data)
|
p.sendToWebSocket(conn, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) {
|
|
||||||
pdu := RDCleanPathPDU{
|
|
||||||
Version: RDCleanPathVersion,
|
|
||||||
Error: []byte(errorMsg),
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := asn1.Marshal(pdu)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to marshal error PDU: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
p.sendToWebSocket(conn, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
|
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
|
||||||
msgChan := make(chan []byte)
|
msgChan := make(chan []byte)
|
||||||
errChan := make(chan error)
|
errChan := make(chan error)
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -62,7 +62,7 @@ require (
|
|||||||
github.com/miekg/dns v1.1.59
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/nadoo/ipset v0.5.0
|
github.com/nadoo/ipset v0.5.0
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||||
github.com/oschwald/maxminddb-golang v1.12.0
|
github.com/oschwald/maxminddb-golang v1.12.0
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
|||||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ=
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||||
|
|||||||
@@ -185,12 +185,15 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then
|
|||||||
echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME"
|
echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME"
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
export NETBIRD_SIGNAL_PROTOCOL="https"
|
|
||||||
unset NETBIRD_LETSENCRYPT_DOMAIN
|
unset NETBIRD_LETSENCRYPT_DOMAIN
|
||||||
unset NETBIRD_MGMT_API_CERT_FILE
|
unset NETBIRD_MGMT_API_CERT_FILE
|
||||||
unset NETBIRD_MGMT_API_CERT_KEY_FILE
|
unset NETBIRD_MGMT_API_CERT_KEY_FILE
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then
|
||||||
|
export NETBIRD_SIGNAL_PROTOCOL="https"
|
||||||
|
fi
|
||||||
|
|
||||||
# Check if management identity provider is set
|
# Check if management identity provider is set
|
||||||
if [ -n "$NETBIRD_MGMT_IDP" ]; then
|
if [ -n "$NETBIRD_MGMT_IDP" ]; then
|
||||||
EXTRA_CONFIG={}
|
EXTRA_CONFIG={}
|
||||||
|
|||||||
@@ -40,13 +40,21 @@ services:
|
|||||||
signal:
|
signal:
|
||||||
<<: *default
|
<<: *default
|
||||||
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
|
image: netbirdio/signal:$NETBIRD_SIGNAL_TAG
|
||||||
|
depends_on:
|
||||||
|
- dashboard
|
||||||
volumes:
|
volumes:
|
||||||
- $SIGNAL_VOLUMENAME:/var/lib/netbird
|
- $SIGNAL_VOLUMENAME:/var/lib/netbird
|
||||||
|
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro
|
||||||
ports:
|
ports:
|
||||||
- $NETBIRD_SIGNAL_PORT:80
|
- $NETBIRD_SIGNAL_PORT:80
|
||||||
# # port and command for Let's Encrypt validation
|
# # port and command for Let's Encrypt validation
|
||||||
# - 443:443
|
# - 443:443
|
||||||
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
# command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"]
|
||||||
|
command: [
|
||||||
|
"--cert-file", "$NETBIRD_MGMT_API_CERT_FILE",
|
||||||
|
"--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE",
|
||||||
|
"--log-file", "console"
|
||||||
|
]
|
||||||
|
|
||||||
# Relay
|
# Relay
|
||||||
relay:
|
relay:
|
||||||
|
|||||||
@@ -47,8 +47,9 @@ services:
|
|||||||
- traefik.enable=true
|
- traefik.enable=true
|
||||||
- traefik.http.routers.netbird-wsproxy-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/signal`)
|
- traefik.http.routers.netbird-wsproxy-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/signal`)
|
||||||
- traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal
|
- traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal
|
||||||
- traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=10000
|
- traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80
|
||||||
- traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
|
- traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
|
||||||
|
- traefik.http.routers.netbird-signal.service=netbird-signal
|
||||||
- traefik.http.services.netbird-signal.loadbalancer.server.port=10000
|
- traefik.http.services.netbird-signal.loadbalancer.server.port=10000
|
||||||
- traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c
|
- traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c
|
||||||
|
|
||||||
|
|||||||
@@ -621,7 +621,7 @@ renderCaddyfile() {
|
|||||||
# relay
|
# relay
|
||||||
reverse_proxy /relay* relay:80
|
reverse_proxy /relay* relay:80
|
||||||
# Signal
|
# Signal
|
||||||
reverse_proxy /ws-proxy/signal* signal:10000
|
reverse_proxy /ws-proxy/signal* signal:80
|
||||||
reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
|
reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
|
||||||
# Management
|
# Management
|
||||||
reverse_proxy /api/* management:80
|
reverse_proxy /api/* management:80
|
||||||
@@ -682,17 +682,6 @@ renderManagementJson() {
|
|||||||
"URI": "stun:$NETBIRD_DOMAIN:3478"
|
"URI": "stun:$NETBIRD_DOMAIN:3478"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"TURNConfig": {
|
|
||||||
"Turns": [
|
|
||||||
{
|
|
||||||
"Proto": "udp",
|
|
||||||
"URI": "turn:$NETBIRD_DOMAIN:3478",
|
|
||||||
"Username": "$TURN_USER",
|
|
||||||
"Password": "$TURN_PASSWORD"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"TimeBasedCredentials": false
|
|
||||||
},
|
|
||||||
"Relay": {
|
"Relay": {
|
||||||
"Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"],
|
"Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"],
|
||||||
"CredentialsTTL": "24h",
|
"CredentialsTTL": "24h",
|
||||||
|
|||||||
@@ -35,7 +35,13 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
|
|||||||
|
|
||||||
func (s *BaseServer) PermissionsManager() permissions.Manager {
|
func (s *BaseServer) PermissionsManager() permissions.Manager {
|
||||||
return Create(s, func() permissions.Manager {
|
return Create(s, func() permissions.Manager {
|
||||||
return integrations.InitPermissionsManager(s.Store())
|
manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter())
|
||||||
|
|
||||||
|
s.AfterInit(func(s *BaseServer) {
|
||||||
|
manager.SetAccountManager(s.AccountManager())
|
||||||
|
})
|
||||||
|
|
||||||
|
return manager
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -252,7 +251,7 @@ func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
|
func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
|
||||||
wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), ManagementLegacyPort), wsproxyserver.WithOTelMeter(meter))
|
wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter))
|
||||||
|
|
||||||
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||||
switch {
|
switch {
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ type Manager interface {
|
|||||||
GetIdpManager() idp.Manager
|
GetIdpManager() idp.Manager
|
||||||
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||||
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error)
|
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
|
||||||
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||||
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
|
||||||
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
|
|||||||
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
|
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
|
||||||
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
|
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
|
||||||
|
|
||||||
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
|
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
|
||||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||||
@@ -86,7 +86,9 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, valid := validPeers[peer.ID]
|
_, valid := validPeers[peer.ID]
|
||||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid))
|
reason := invalidPeers[peer.ID]
|
||||||
|
|
||||||
|
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -147,16 +149,17 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
|
|||||||
|
|
||||||
grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0)
|
grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0)
|
||||||
|
|
||||||
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err)
|
log.WithContext(ctx).Errorf("failed to get validated peers: %v", err)
|
||||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, valid := validPeers[peer.ID]
|
_, valid := validPeers[peer.ID]
|
||||||
|
reason := invalidPeers[peer.ID]
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid))
|
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
|
func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
|
||||||
@@ -240,22 +243,25 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0))
|
respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0))
|
||||||
}
|
}
|
||||||
|
|
||||||
validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err)
|
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
|
||||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.setApprovalRequiredFlag(respBody, validPeersMap)
|
h.setApprovalRequiredFlag(respBody, validPeersMap, invalidPeersMap)
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, respBody)
|
util.WriteJSONObject(r.Context(), w, respBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
|
func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersMap map[string]struct{}, invalidPeersMap map[string]string) {
|
||||||
for _, peer := range respBody {
|
for _, peer := range respBody {
|
||||||
_, ok := approvedPeersMap[peer.Id]
|
_, ok := validPeersMap[peer.Id]
|
||||||
if !ok {
|
if !ok {
|
||||||
peer.ApprovalRequired = true
|
peer.ApprovalRequired = true
|
||||||
|
|
||||||
|
reason := invalidPeersMap[peer.Id]
|
||||||
|
peer.DisapprovalReason = &reason
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -304,7 +310,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
|
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
|
||||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||||
@@ -430,13 +436,13 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer {
|
func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool, reason string) *api.Peer {
|
||||||
osVersion := peer.Meta.OSVersion
|
osVersion := peer.Meta.OSVersion
|
||||||
if osVersion == "" {
|
if osVersion == "" {
|
||||||
osVersion = peer.Meta.Core
|
osVersion = peer.Meta.Core
|
||||||
}
|
}
|
||||||
|
|
||||||
return &api.Peer{
|
apiPeer := &api.Peer{
|
||||||
CreatedAt: peer.CreatedAt,
|
CreatedAt: peer.CreatedAt,
|
||||||
Id: peer.ID,
|
Id: peer.ID,
|
||||||
Name: peer.Name,
|
Name: peer.Name,
|
||||||
@@ -465,6 +471,12 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
|
|||||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||||
Ephemeral: peer.Ephemeral,
|
Ephemeral: peer.Ephemeral,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !approved {
|
||||||
|
apiPeer.DisapprovalReason = &reason
|
||||||
|
}
|
||||||
|
|
||||||
|
return apiPeer
|
||||||
}
|
}
|
||||||
|
|
||||||
func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch {
|
func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch {
|
||||||
|
|||||||
@@ -7,9 +7,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
|||||||
@@ -26,9 +26,11 @@ type mockHTTPClient struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||||
body, err := io.ReadAll(req.Body)
|
if req.Body != nil {
|
||||||
if err == nil {
|
body, err := io.ReadAll(req.Body)
|
||||||
c.reqBody = string(body)
|
if err == nil {
|
||||||
|
c.reqBody = string(body)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return &http.Response{
|
return &http.Response{
|
||||||
StatusCode: c.code,
|
StatusCode: c.code,
|
||||||
|
|||||||
@@ -201,6 +201,12 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
|
|||||||
APIToken: config.ExtraConfig["ApiToken"],
|
APIToken: config.ExtraConfig["ApiToken"],
|
||||||
}
|
}
|
||||||
return NewJumpCloudManager(jumpcloudConfig, appMetrics)
|
return NewJumpCloudManager(jumpcloudConfig, appMetrics)
|
||||||
|
case "pocketid":
|
||||||
|
pocketidConfig := PocketIdClientConfig{
|
||||||
|
APIToken: config.ExtraConfig["ApiToken"],
|
||||||
|
ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"],
|
||||||
|
}
|
||||||
|
return NewPocketIdManager(pocketidConfig, appMetrics)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user