mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 23:36:39 +00:00
Compare commits
11 Commits
v0.58.2
...
cli-ws-pro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fdf4f10d94 | ||
|
|
2b8a9f55c1 | ||
|
|
e7b5537dcc | ||
|
|
95794f53ce | ||
|
|
9bcd3ebed4 | ||
|
|
b85045e723 | ||
|
|
4d7e59f199 | ||
|
|
b5daec3b51 | ||
|
|
5e1a40c33f | ||
|
|
e8d301fdc9 | ||
|
|
17bab881f7 |
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros
|
||||
skip: go.mod,go.sum
|
||||
golangci:
|
||||
strategy:
|
||||
|
||||
67
.github/workflows/wasm-build-validation.yml
vendored
Normal file
67
.github/workflows/wasm-build-validation.yml
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
name: Wasm
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
js_lint:
|
||||
name: "JS / Lint"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Install golangci-lint
|
||||
uses: golangci/golangci-lint-action@d6238b002a20823d52840fda27e2d4891c5952dc
|
||||
with:
|
||||
version: latest
|
||||
install-mode: binary
|
||||
skip-cache: true
|
||||
skip-pkg-cache: true
|
||||
skip-build-cache: true
|
||||
- name: Run golangci-lint for WASM
|
||||
run: |
|
||||
GOOS=js GOARCH=wasm golangci-lint run --timeout=12m --out-format colored-line-number ./client/...
|
||||
continue-on-error: true
|
||||
|
||||
js_build:
|
||||
name: "JS / Build"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
- name: Build Wasm client
|
||||
run: GOOS=js GOARCH=wasm go build -o netbird.wasm ./client/wasm/cmd
|
||||
env:
|
||||
CGO_ENABLED: 0
|
||||
- name: Check Wasm build size
|
||||
run: |
|
||||
echo "Wasm build size:"
|
||||
ls -lh netbird.wasm
|
||||
|
||||
SIZE=$(stat -c%s netbird.wasm)
|
||||
SIZE_MB=$((SIZE / 1024 / 1024))
|
||||
|
||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||
|
||||
if [ ${SIZE} -gt 52428800 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 50MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
0
.gitmodules
vendored
Normal file
0
.gitmodules
vendored
Normal file
@@ -2,6 +2,18 @@ version: 2
|
||||
|
||||
project_name: netbird
|
||||
builds:
|
||||
- id: netbird-wasm
|
||||
dir: client/wasm/cmd
|
||||
binary: netbird
|
||||
env: [GOOS=js, GOARCH=wasm, CGO_ENABLED=0]
|
||||
goos:
|
||||
- js
|
||||
goarch:
|
||||
- wasm
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||
|
||||
- id: netbird
|
||||
dir: client
|
||||
binary: netbird
|
||||
@@ -115,6 +127,11 @@ archives:
|
||||
- builds:
|
||||
- netbird
|
||||
- netbird-static
|
||||
- id: netbird-wasm
|
||||
builds:
|
||||
- netbird-wasm
|
||||
name_template: "{{ .ProjectName }}_{{ .Version }}"
|
||||
format: binary
|
||||
|
||||
nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
|
||||
8
client/cmd/debug_js.go
Normal file
8
client/cmd/debug_js.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package cmd
|
||||
|
||||
import "context"
|
||||
|
||||
// SetupDebugHandler is a no-op for WASM
|
||||
func SetupDebugHandler(context.Context, interface{}, interface{}, interface{}, string) {
|
||||
// Debug handler not needed for WASM
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
client "github.com/netbirdio/netbird/client/server"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -114,7 +116,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
}
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -23,23 +23,29 @@ import (
|
||||
|
||||
var ErrClientAlreadyStarted = errors.New("client already started")
|
||||
var ErrClientNotStarted = errors.New("client not started")
|
||||
var ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
|
||||
// Client manages a netbird embedded client instance
|
||||
// Client manages a netbird embedded client instance.
|
||||
type Client struct {
|
||||
deviceName string
|
||||
config *profilemanager.Config
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
setupKey string
|
||||
jwtToken string
|
||||
connect *internal.ConnectClient
|
||||
}
|
||||
|
||||
// Options configures a new Client
|
||||
// Options configures a new Client.
|
||||
type Options struct {
|
||||
// DeviceName is this peer's name in the network
|
||||
DeviceName string
|
||||
// SetupKey is used for authentication
|
||||
SetupKey string
|
||||
// JWTToken is used for JWT-based authentication
|
||||
JWTToken string
|
||||
// PrivateKey is used for direct private key authentication
|
||||
PrivateKey string
|
||||
// ManagementURL overrides the default management server URL
|
||||
ManagementURL string
|
||||
// PreSharedKey is the pre-shared key for the WireGuard interface
|
||||
@@ -58,8 +64,35 @@ type Options struct {
|
||||
DisableClientRoutes bool
|
||||
}
|
||||
|
||||
// New creates a new netbird embedded client
|
||||
// validateCredentials checks that exactly one credential type is provided
|
||||
func (opts *Options) validateCredentials() error {
|
||||
credentialsProvided := 0
|
||||
if opts.SetupKey != "" {
|
||||
credentialsProvided++
|
||||
}
|
||||
if opts.JWTToken != "" {
|
||||
credentialsProvided++
|
||||
}
|
||||
if opts.PrivateKey != "" {
|
||||
credentialsProvided++
|
||||
}
|
||||
|
||||
if credentialsProvided == 0 {
|
||||
return fmt.Errorf("one of SetupKey, JWTToken, or PrivateKey must be provided")
|
||||
}
|
||||
if credentialsProvided > 1 {
|
||||
return fmt.Errorf("only one of SetupKey, JWTToken, or PrivateKey can be specified")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// New creates a new netbird embedded client.
|
||||
func New(opts Options) (*Client, error) {
|
||||
if err := opts.validateCredentials(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if opts.LogOutput != nil {
|
||||
logrus.SetOutput(opts.LogOutput)
|
||||
}
|
||||
@@ -107,9 +140,14 @@ func New(opts Options) (*Client, error) {
|
||||
return nil, fmt.Errorf("create config: %w", err)
|
||||
}
|
||||
|
||||
if opts.PrivateKey != "" {
|
||||
config.PrivateKey = opts.PrivateKey
|
||||
}
|
||||
|
||||
return &Client{
|
||||
deviceName: opts.DeviceName,
|
||||
setupKey: opts.SetupKey,
|
||||
jwtToken: opts.JWTToken,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
@@ -126,7 +164,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
ctx := internal.CtxInitState(context.Background())
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
|
||||
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
|
||||
@@ -187,6 +225,16 @@ func (c *Client) Stop(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfig returns a copy of the internal client config.
|
||||
func (c *Client) GetConfig() (profilemanager.Config, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.config == nil {
|
||||
return profilemanager.Config{}, ErrConfigNotInitialized
|
||||
}
|
||||
return *c.config, nil
|
||||
}
|
||||
|
||||
// Dial dials a network address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
@@ -211,7 +259,7 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
|
||||
return nsnet.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// ListenTCP listens on the given address in the netbird network
|
||||
// ListenTCP listens on the given address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
nsnet, addr, err := c.getNet()
|
||||
@@ -232,7 +280,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
return nsnet.ListenTCP(tcpAddr)
|
||||
}
|
||||
|
||||
// ListenUDP listens on the given address in the netbird network
|
||||
// ListenUDP listens on the given address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
||||
nsnet, addr, err := c.getNet()
|
||||
|
||||
@@ -4,15 +4,9 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
@@ -20,37 +14,10 @@ import (
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
|
||||
func WithCustomDialer() grpc.DialOption {
|
||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
if runtime.GOOS == "linux" {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
// the custom dialer requires root permissions which are not required for use cases run as non-root
|
||||
if currentUser.Uid != "0" {
|
||||
log.Debug("Not running as root, using standard dialer")
|
||||
dialer := &net.Dialer{}
|
||||
return dialer.DialContext(ctx, "tcp", addr)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to dial: %s", err)
|
||||
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
})
|
||||
}
|
||||
|
||||
// grpcDialBackoff is the backoff mechanism for the grpc calls
|
||||
// Backoff returns a backoff configuration for gRPC calls
|
||||
func Backoff(ctx context.Context) backoff.BackOff {
|
||||
b := backoff.NewExponentialBackOff()
|
||||
b.MaxElapsedTime = 10 * time.Second
|
||||
@@ -58,7 +25,9 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
||||
return backoff.WithContext(b, ctx)
|
||||
}
|
||||
|
||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
||||
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
if tlsEnabled {
|
||||
certPool, err := x509.SystemCertPool()
|
||||
@@ -68,7 +37,9 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.
|
||||
}
|
||||
|
||||
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
|
||||
RootCAs: certPool,
|
||||
// for js, outer websocket layer takes care of tls verification via WithCustomDialer
|
||||
InsecureSkipVerify: runtime.GOOS == "js",
|
||||
RootCAs: certPool,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -79,7 +50,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.
|
||||
connCtx,
|
||||
addr,
|
||||
transportOption,
|
||||
WithCustomDialer(),
|
||||
WithCustomDialer(tlsEnabled, component),
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 30 * time.Second,
|
||||
|
||||
44
client/grpc/dialer_generic.go
Normal file
44
client/grpc/dialer_generic.go
Normal file
@@ -0,0 +1,44 @@
|
||||
//go:build !js
|
||||
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/user"
|
||||
"runtime"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
if runtime.GOOS == "linux" {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
|
||||
}
|
||||
|
||||
// the custom dialer requires root permissions which are not required for use cases run as non-root
|
||||
if currentUser.Uid != "0" {
|
||||
log.Debug("Not running as root, using standard dialer")
|
||||
dialer := &net.Dialer{}
|
||||
return dialer.DialContext(ctx, "tcp", addr)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to dial: %s", err)
|
||||
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
})
|
||||
}
|
||||
13
client/grpc/dialer_js.go
Normal file
13
client/grpc/dialer_js.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/util/wsproxy/client"
|
||||
)
|
||||
|
||||
// WithCustomDialer returns a gRPC dial option that uses WebSocket transport for WASM/JS environments.
|
||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
||||
return client.WithWebSocketDialer(tlsEnabled, component)
|
||||
}
|
||||
@@ -1,5 +1,17 @@
|
||||
package bind
|
||||
|
||||
import wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
import (
|
||||
"net"
|
||||
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
type Endpoint = wgConn.StdNetEndpoint
|
||||
|
||||
func EndpointToUDPAddr(e Endpoint) *net.UDPAddr {
|
||||
return &net.UDPAddr{
|
||||
IP: e.Addr().AsSlice(),
|
||||
Port: int(e.Port()),
|
||||
Zone: e.Addr().Zone(),
|
||||
}
|
||||
}
|
||||
|
||||
7
client/iface/bind/error.go
Normal file
7
client/iface/bind/error.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package bind
|
||||
|
||||
import "fmt"
|
||||
|
||||
var (
|
||||
ErrUDPMUXNotSupported = fmt.Errorf("UDPMUX is not supported in WASM")
|
||||
)
|
||||
@@ -1,6 +1,9 @@
|
||||
//go:build !js
|
||||
|
||||
package bind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -20,11 +23,6 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
type RecvMessage struct {
|
||||
Endpoint *Endpoint
|
||||
Buffer []byte
|
||||
}
|
||||
|
||||
type receiverCreator struct {
|
||||
iceBind *ICEBind
|
||||
}
|
||||
@@ -42,37 +40,38 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD
|
||||
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
|
||||
type ICEBind struct {
|
||||
*wgConn.StdNetBind
|
||||
RecvChan chan RecvMessage
|
||||
|
||||
transportNet transport.Net
|
||||
filterFn udpmux.FilterFn
|
||||
endpoints map[netip.Addr]net.Conn
|
||||
endpointsMu sync.Mutex
|
||||
address wgaddr.Address
|
||||
mtu uint16
|
||||
|
||||
endpoints map[netip.Addr]net.Conn
|
||||
endpointsMu sync.Mutex
|
||||
recvChan chan recvMessage
|
||||
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
|
||||
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
|
||||
closedChan chan struct{}
|
||||
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
|
||||
closed bool
|
||||
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
address wgaddr.Address
|
||||
mtu uint16
|
||||
closedChan chan struct{}
|
||||
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
|
||||
closed bool
|
||||
activityRecorder *ActivityRecorder
|
||||
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *udpmux.UniversalUDPMuxDefault
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||
ib := &ICEBind{
|
||||
StdNetBind: b,
|
||||
RecvChan: make(chan RecvMessage, 1),
|
||||
transportNet: transportNet,
|
||||
filterFn: filterFn,
|
||||
address: address,
|
||||
mtu: mtu,
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
recvChan: make(chan recvMessage, 1),
|
||||
closedChan: make(chan struct{}),
|
||||
closed: true,
|
||||
mtu: mtu,
|
||||
address: address,
|
||||
activityRecorder: NewActivityRecorder(),
|
||||
}
|
||||
|
||||
@@ -83,10 +82,6 @@ func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wg
|
||||
return ib
|
||||
}
|
||||
|
||||
func (s *ICEBind) MTU() uint16 {
|
||||
return s.mtu
|
||||
}
|
||||
|
||||
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
||||
s.closed = false
|
||||
s.closedChanMu.Lock()
|
||||
@@ -139,6 +134,16 @@ func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
|
||||
delete(b.endpoints, fakeIP)
|
||||
}
|
||||
|
||||
func (b *ICEBind) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
|
||||
select {
|
||||
case <-b.closedChan:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case b.recvChan <- recvMessage{ep, buf}:
|
||||
}
|
||||
}
|
||||
|
||||
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
b.endpointsMu.Lock()
|
||||
conn, ok := b.endpoints[ep.DstIP()]
|
||||
@@ -271,7 +276,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
|
||||
select {
|
||||
case <-c.closedChan:
|
||||
return 0, net.ErrClosed
|
||||
case msg, ok := <-c.RecvChan:
|
||||
case msg, ok := <-c.recvChan:
|
||||
if !ok {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
6
client/iface/bind/recv_msg.go
Normal file
6
client/iface/bind/recv_msg.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package bind
|
||||
|
||||
type recvMessage struct {
|
||||
Endpoint *Endpoint
|
||||
Buffer []byte
|
||||
}
|
||||
125
client/iface/bind/relay_bind.go
Normal file
125
client/iface/bind/relay_bind.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
)
|
||||
|
||||
// RelayBindJS is a conn.Bind implementation for WebAssembly environments.
|
||||
// Do not limit to build only js, because we want to be able to run tests
|
||||
type RelayBindJS struct {
|
||||
*conn.StdNetBind
|
||||
|
||||
recvChan chan recvMessage
|
||||
endpoints map[netip.Addr]net.Conn
|
||||
endpointsMu sync.Mutex
|
||||
activityRecorder *ActivityRecorder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewRelayBindJS() *RelayBindJS {
|
||||
return &RelayBindJS{
|
||||
recvChan: make(chan recvMessage, 100),
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
activityRecorder: NewActivityRecorder(),
|
||||
}
|
||||
}
|
||||
|
||||
// Open creates a receive function for handling relay packets in WASM.
|
||||
func (s *RelayBindJS) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
|
||||
log.Debugf("Open: creating receive function for port %d", uport)
|
||||
|
||||
s.ctx, s.cancel = context.WithCancel(context.Background())
|
||||
|
||||
receiveFn := func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return 0, net.ErrClosed
|
||||
case msg, ok := <-s.recvChan:
|
||||
if !ok {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
copy(bufs[0], msg.Buffer)
|
||||
sizes[0] = len(msg.Buffer)
|
||||
eps[0] = conn.Endpoint(msg.Endpoint)
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("Open: receive function created, returning port %d", uport)
|
||||
return []conn.ReceiveFunc{receiveFn}, uport, nil
|
||||
}
|
||||
|
||||
func (s *RelayBindJS) Close() error {
|
||||
if s.cancel == nil {
|
||||
return nil
|
||||
}
|
||||
log.Debugf("close RelayBindJS")
|
||||
s.cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RelayBindJS) ReceiveFromEndpoint(ctx context.Context, ep *Endpoint, buf []byte) {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case s.recvChan <- recvMessage{ep, buf}:
|
||||
}
|
||||
}
|
||||
|
||||
// Send forwards packets through the relay connection for WASM.
|
||||
func (s *RelayBindJS) Send(bufs [][]byte, ep conn.Endpoint) error {
|
||||
if ep == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
fakeIP := ep.DstIP()
|
||||
|
||||
s.endpointsMu.Lock()
|
||||
relayConn, ok := s.endpoints[fakeIP]
|
||||
s.endpointsMu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, buf := range bufs {
|
||||
if _, err := relayConn.Write(buf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *RelayBindJS) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
|
||||
b.endpointsMu.Lock()
|
||||
b.endpoints[fakeIP] = conn
|
||||
b.endpointsMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *RelayBindJS) RemoveEndpoint(fakeIP netip.Addr) {
|
||||
s.endpointsMu.Lock()
|
||||
defer s.endpointsMu.Unlock()
|
||||
|
||||
delete(s.endpoints, fakeIP)
|
||||
}
|
||||
|
||||
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||
func (s *RelayBindJS) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
return nil, ErrUDPMUXNotSupported
|
||||
}
|
||||
|
||||
func (s *RelayBindJS) ActivityRecorder() *ActivityRecorder {
|
||||
return s.activityRecorder
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build linux || windows || freebsd
|
||||
//go:build linux || windows || freebsd || js || wasip1
|
||||
|
||||
package configurer
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !windows
|
||||
//go:build !windows && !js
|
||||
|
||||
package configurer
|
||||
|
||||
|
||||
23
client/iface/configurer/uapi_js.go
Normal file
23
client/iface/configurer/uapi_js.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type noopListener struct{}
|
||||
|
||||
func (n *noopListener) Accept() (net.Conn, error) {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
|
||||
func (n *noopListener) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *noopListener) Addr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func openUAPI(deviceName string) (net.Listener, error) {
|
||||
return &noopListener{}, nil
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
@@ -15,6 +17,12 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
type Bind interface {
|
||||
conn.Bind
|
||||
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
ActivityRecorder() *bind.ActivityRecorder
|
||||
}
|
||||
|
||||
type TunNetstackDevice struct {
|
||||
name string
|
||||
address wgaddr.Address
|
||||
@@ -22,7 +30,7 @@ type TunNetstackDevice struct {
|
||||
key string
|
||||
mtu uint16
|
||||
listenAddress string
|
||||
iceBind *bind.ICEBind
|
||||
bind Bind
|
||||
|
||||
device *device.Device
|
||||
filteredDevice *FilteredDevice
|
||||
@@ -33,7 +41,7 @@ type TunNetstackDevice struct {
|
||||
net *netstack.Net
|
||||
}
|
||||
|
||||
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
||||
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, bind Bind, listenAddress string) *TunNetstackDevice {
|
||||
return &TunNetstackDevice{
|
||||
name: name,
|
||||
address: address,
|
||||
@@ -41,7 +49,7 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri
|
||||
key: key,
|
||||
mtu: mtu,
|
||||
listenAddress: listenAddress,
|
||||
iceBind: iceBind,
|
||||
bind: bind,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,11 +74,11 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
||||
|
||||
t.device = device.NewDevice(
|
||||
t.filteredDevice,
|
||||
t.iceBind,
|
||||
t.bind,
|
||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||
)
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
_ = tunIface.Close()
|
||||
@@ -91,11 +99,15 @@ func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
udpMux, err := t.iceBind.GetICEMux()
|
||||
if err != nil {
|
||||
udpMux, err := t.bind.GetICEMux()
|
||||
if err != nil && !errors.Is(err, bind.ErrUDPMUXNotSupported) {
|
||||
return nil, err
|
||||
}
|
||||
t.udpMux = udpMux
|
||||
|
||||
if udpMux != nil {
|
||||
t.udpMux = udpMux
|
||||
}
|
||||
|
||||
log.Debugf("netstack device is ready to use")
|
||||
return udpMux, nil
|
||||
}
|
||||
|
||||
27
client/iface/device/device_netstack_test.go
Normal file
27
client/iface/device/device_netstack_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func TestNewNetstackDevice(t *testing.T) {
|
||||
privateKey, _ := wgtypes.GeneratePrivateKey()
|
||||
wgAddress, _ := wgaddr.ParseWGAddress("1.2.3.4/24")
|
||||
|
||||
relayBind := bind.NewRelayBindJS()
|
||||
nsTun := NewNetstackDevice("wtx", wgAddress, 1234, privateKey.String(), 1500, relayBind, netstack.ListenAddr())
|
||||
|
||||
cfgr, err := nsTun.Create()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create netstack device: %v", err)
|
||||
}
|
||||
if cfgr == nil {
|
||||
t.Fatal("expected non-nil configurer")
|
||||
}
|
||||
}
|
||||
6
client/iface/iface_destroy_js.go
Normal file
6
client/iface/iface_destroy_js.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package iface
|
||||
|
||||
// Destroy is a no-op on WASM
|
||||
func (w *WGIface) Destroy() error {
|
||||
return nil
|
||||
}
|
||||
@@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}
|
||||
return wgIFace, nil
|
||||
}
|
||||
@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: tun,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
41
client/iface/iface_new_freebsd.go
Normal file
41
client/iface/iface_new_freebsd.go
Normal file
@@ -0,0 +1,41 @@
|
||||
//go:build freebsd
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgIFace := &WGIface{}
|
||||
|
||||
if netstack.IsEnabled() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
if device.ModuleTunIsLoaded() {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("couldn't check or load tun module")
|
||||
}
|
||||
@@ -21,7 +21,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgIFace := &WGIface{
|
||||
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),
|
||||
userspaceBind: true,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
27
client/iface/iface_new_js.go
Normal file
27
client/iface/iface_new_js.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package iface
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// NewWGIFace creates a new WireGuard interface for WASM (always uses netstack mode)
|
||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
relayBind := bind.NewRelayBindJS()
|
||||
|
||||
wgIface := &WGIface{
|
||||
tun: device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, relayBind, netstack.ListenAddr()),
|
||||
userspaceBind: true,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(relayBind, opts.MTU),
|
||||
}
|
||||
|
||||
return wgIface, nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
//go:build linux && !android
|
||||
|
||||
package iface
|
||||
|
||||
@@ -25,7 +25,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU)
|
||||
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||
wgIFace.userspaceBind = true
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind, opts.MTU)
|
||||
return wgIFace, nil
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: tun,
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind, opts.MTU),
|
||||
}
|
||||
return wgIFace, nil
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !js
|
||||
|
||||
package netstack
|
||||
|
||||
import (
|
||||
|
||||
12
client/iface/netstack/env_js.go
Normal file
12
client/iface/netstack/env_js.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package netstack
|
||||
|
||||
const EnvUseNetstackMode = "NB_USE_NETSTACK_MODE"
|
||||
|
||||
// IsEnabled always returns true for js since it's the only mode available
|
||||
func IsEnabled() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func ListenAddr() string {
|
||||
return ""
|
||||
}
|
||||
@@ -16,28 +16,38 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
|
||||
)
|
||||
|
||||
type ProxyBind struct {
|
||||
Bind *bind.ICEBind
|
||||
|
||||
fakeNetIP *netip.AddrPort
|
||||
wgBindEndpoint *bind.Endpoint
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
|
||||
closeListener *listener.CloseListener
|
||||
type Bind interface {
|
||||
SetEndpoint(addr netip.Addr, conn net.Conn)
|
||||
RemoveEndpoint(addr netip.Addr)
|
||||
ReceiveFromEndpoint(ctx context.Context, ep *bind.Endpoint, buf []byte)
|
||||
}
|
||||
|
||||
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
|
||||
type ProxyBind struct {
|
||||
bind Bind
|
||||
|
||||
// wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address
|
||||
wgRelayedEndpoint *bind.Endpoint
|
||||
wgCurrentUsed *bind.Endpoint
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
|
||||
paused bool
|
||||
pausedCond *sync.Cond
|
||||
isStarted bool
|
||||
|
||||
closeListener *listener.CloseListener
|
||||
mtu uint16
|
||||
}
|
||||
|
||||
func NewProxyBind(bind Bind, mtu uint16) *ProxyBind {
|
||||
p := &ProxyBind{
|
||||
Bind: bind,
|
||||
bind: bind,
|
||||
closeListener: listener.NewCloseListener(),
|
||||
pausedCond: sync.NewCond(&sync.Mutex{}),
|
||||
mtu: mtu + bufsize.WGBufferOverhead,
|
||||
}
|
||||
|
||||
return p
|
||||
@@ -46,25 +56,25 @@ func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
|
||||
// AddTurnConn adds a new connection to the bind.
|
||||
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
|
||||
// WireGuard configuration.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages
|
||||
// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address
|
||||
// - remoteConn: The established TURN connection to the remote peer
|
||||
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
|
||||
fakeNetIP, err := fakeAddress(nbAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.fakeNetIP = fakeNetIP
|
||||
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
|
||||
p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
||||
return &net.UDPAddr{
|
||||
IP: p.fakeNetIP.Addr().AsSlice(),
|
||||
Port: int(p.fakeNetIP.Port()),
|
||||
Zone: p.fakeNetIP.Addr().Zone(),
|
||||
}
|
||||
return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint)
|
||||
}
|
||||
|
||||
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
|
||||
@@ -76,17 +86,21 @@ func (p *ProxyBind) Work() {
|
||||
return
|
||||
}
|
||||
|
||||
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
|
||||
p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn)
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
p.wgCurrentUsed = p.wgRelayedEndpoint
|
||||
|
||||
// Start the proxy only once
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func (p *ProxyBind) Pause() {
|
||||
@@ -94,9 +108,25 @@ func (p *ProxyBind) Pause() {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgCurrentUsed = addrToEndpoint(endpoint)
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
|
||||
ip, _ := netip.AddrFromSlice(addr.IP.To4())
|
||||
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
|
||||
return &bind.Endpoint{AddrPort: addrPort}
|
||||
}
|
||||
|
||||
func (p *ProxyBind) CloseConn() error {
|
||||
@@ -107,6 +137,10 @@ func (p *ProxyBind) CloseConn() error {
|
||||
}
|
||||
|
||||
func (p *ProxyBind) close() error {
|
||||
if p.remoteConn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.closeMu.Lock()
|
||||
defer p.closeMu.Unlock()
|
||||
|
||||
@@ -120,7 +154,12 @@ func (p *ProxyBind) close() error {
|
||||
|
||||
p.cancel()
|
||||
|
||||
p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
|
||||
p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr())
|
||||
|
||||
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
||||
return rErr
|
||||
@@ -136,7 +175,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
}()
|
||||
|
||||
for {
|
||||
buf := make([]byte, p.Bind.MTU()+bufsize.WGBufferOverhead)
|
||||
buf := make([]byte, p.mtu)
|
||||
n, err := p.remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
@@ -147,18 +186,13 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
p.pausedCond.L.Lock()
|
||||
for p.paused {
|
||||
p.pausedCond.Wait()
|
||||
}
|
||||
|
||||
msg := bind.RecvMessage{
|
||||
Endpoint: p.wgBindEndpoint,
|
||||
Buffer: buf[:n],
|
||||
}
|
||||
p.Bind.RecvChan <- msg
|
||||
p.pausedMu.Unlock()
|
||||
p.bind.ReceiveFromEndpoint(ctx, p.wgCurrentUsed, buf[:n])
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,9 +6,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
@@ -18,6 +16,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/iface/bufsize"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
@@ -27,6 +26,10 @@ const (
|
||||
loopbackAddr = "127.0.0.1"
|
||||
)
|
||||
|
||||
var (
|
||||
localHostNetIP = net.ParseIP("127.0.0.1")
|
||||
)
|
||||
|
||||
// WGEBPFProxy definition for proxy with EBPF support
|
||||
type WGEBPFProxy struct {
|
||||
localWGListenPort int
|
||||
@@ -64,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error {
|
||||
return err
|
||||
}
|
||||
|
||||
p.rawConn, err = p.prepareSenderRawSocket()
|
||||
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -214,57 +217,17 @@ generatePort:
|
||||
return p.lastUsedPort, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
||||
// Create a raw socket.
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
}
|
||||
|
||||
// Bind the socket to the "lo" interface.
|
||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the fwmark on the socket.
|
||||
err = nbnet.SetSocketOpt(fd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert the file descriptor to a PacketConn.
|
||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||
if file == nil {
|
||||
return nil, fmt.Errorf("converting fd to file failed")
|
||||
}
|
||||
packetConn, err := net.FilePacketConn(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
|
||||
localhost := net.ParseIP("127.0.0.1")
|
||||
|
||||
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
|
||||
payload := gopacket.Payload(data)
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: localhost,
|
||||
SrcIP: localhost,
|
||||
DstIP: localHostNetIP,
|
||||
SrcIP: endpointAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(port),
|
||||
SrcPort: layers.UDPPort(endpointAddr.Port),
|
||||
DstPort: layers.UDPPort(p.localWGListenPort),
|
||||
}
|
||||
|
||||
@@ -279,7 +242,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil {
|
||||
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
|
||||
return fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -18,41 +18,42 @@ import (
|
||||
|
||||
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
|
||||
type ProxyWrapper struct {
|
||||
WgeBPFProxy *WGEBPFProxy
|
||||
wgeBPFProxy *WGEBPFProxy
|
||||
|
||||
remoteConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
wgEndpointAddr *net.UDPAddr
|
||||
wgRelayedEndpointAddr *net.UDPAddr
|
||||
wgEndpointCurrentUsedAddr *net.UDPAddr
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
paused bool
|
||||
pausedCond *sync.Cond
|
||||
isStarted bool
|
||||
|
||||
closeListener *listener.CloseListener
|
||||
}
|
||||
|
||||
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
|
||||
func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
|
||||
return &ProxyWrapper{
|
||||
WgeBPFProxy: WgeBPFProxy,
|
||||
wgeBPFProxy: proxy,
|
||||
pausedCond: sync.NewCond(&sync.Mutex{}),
|
||||
closeListener: listener.NewCloseListener(),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
|
||||
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
|
||||
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add turn conn: %w", err)
|
||||
}
|
||||
p.remoteConn = remoteConn
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.wgEndpointAddr = addr
|
||||
p.wgRelayedEndpointAddr = addr
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
|
||||
return p.wgEndpointAddr
|
||||
return p.wgRelayedEndpointAddr
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
|
||||
@@ -64,14 +65,18 @@ func (p *ProxyWrapper) Work() {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
|
||||
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) Pause() {
|
||||
@@ -80,45 +85,59 @@ func (p *ProxyWrapper) Pause() {
|
||||
}
|
||||
|
||||
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
|
||||
p.pausedMu.Lock()
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
|
||||
p.wgEndpointCurrentUsedAddr = endpoint
|
||||
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
||||
func (e *ProxyWrapper) CloseConn() error {
|
||||
if e.cancel == nil {
|
||||
func (p *ProxyWrapper) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
return fmt.Errorf("proxy not started")
|
||||
}
|
||||
|
||||
e.cancel()
|
||||
p.cancel()
|
||||
|
||||
e.closeListener.SetCloseListener(nil)
|
||||
p.closeListener.SetCloseListener(nil)
|
||||
|
||||
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
return fmt.Errorf("close remote conn: %w", err)
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
|
||||
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
return fmt.Errorf("failed to close remote conn: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
|
||||
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
|
||||
defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port))
|
||||
|
||||
buf := make([]byte, p.WgeBPFProxy.mtu+bufsize.WGBufferOverhead)
|
||||
buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead)
|
||||
for {
|
||||
n, err := p.readFromRemote(ctx, buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
p.pausedCond.L.Lock()
|
||||
for p.paused {
|
||||
p.pausedCond.Wait()
|
||||
}
|
||||
|
||||
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
|
||||
p.pausedMu.Unlock()
|
||||
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
|
||||
p.pausedCond.L.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
@@ -137,7 +156,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
|
||||
}
|
||||
p.closeListener.Notify()
|
||||
if !errors.Is(err, io.EOF) {
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -39,7 +39,6 @@ func (w *KernelFactory) GetProxy() Proxy {
|
||||
}
|
||||
|
||||
return ebpf.NewProxyWrapper(w.ebpfProxy)
|
||||
|
||||
}
|
||||
|
||||
func (w *KernelFactory) Free() error {
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
)
|
||||
|
||||
// KernelFactory todo: check eBPF support on FreeBSD
|
||||
type KernelFactory struct {
|
||||
wgPort int
|
||||
mtu uint16
|
||||
}
|
||||
|
||||
func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
|
||||
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
|
||||
f := &KernelFactory{
|
||||
wgPort: wgPort,
|
||||
mtu: mtu,
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
func (w *KernelFactory) GetProxy() Proxy {
|
||||
return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu)
|
||||
}
|
||||
|
||||
func (w *KernelFactory) Free() error {
|
||||
return nil
|
||||
}
|
||||
@@ -3,24 +3,25 @@ package wgproxy
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
||||
)
|
||||
|
||||
type USPFactory struct {
|
||||
bind *bind.ICEBind
|
||||
bind proxyBind.Bind
|
||||
mtu uint16
|
||||
}
|
||||
|
||||
func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
|
||||
func NewUSPFactory(bind proxyBind.Bind, mtu uint16) *USPFactory {
|
||||
log.Infof("WireGuard Proxy Factory will produce bind proxy")
|
||||
f := &USPFactory{
|
||||
bind: iceBind,
|
||||
bind: bind,
|
||||
mtu: mtu,
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func (w *USPFactory) GetProxy() Proxy {
|
||||
return proxyBind.NewProxyBind(w.bind)
|
||||
return proxyBind.NewProxyBind(w.bind, w.mtu)
|
||||
}
|
||||
|
||||
func (w *USPFactory) Free() error {
|
||||
|
||||
@@ -11,6 +11,11 @@ type Proxy interface {
|
||||
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
|
||||
Work() // Work start or resume the proxy
|
||||
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
|
||||
|
||||
//RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused
|
||||
//and rewrite the src address to the endpoint address.
|
||||
//With this logic can avoid the package loss from relayed connections.
|
||||
RedirectAs(endpoint *net.UDPAddr)
|
||||
CloseConn() error
|
||||
SetDisconnectListener(disconnected func())
|
||||
}
|
||||
|
||||
@@ -3,54 +3,82 @@
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
)
|
||||
|
||||
func TestProxyCloseByRemoteConnEBPF(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTIONS") != "true" {
|
||||
t.Skip("Skipping test as it requires root privileges")
|
||||
}
|
||||
ctx := context.Background()
|
||||
func seedProxies() ([]proxyInstance, error) {
|
||||
pl := make([]proxyInstance, 0)
|
||||
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %s", err)
|
||||
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
{
|
||||
name: "ebpf proxy",
|
||||
proxy: &ebpf.ProxyWrapper{
|
||||
WgeBPFProxy: ebpfProxy,
|
||||
},
|
||||
},
|
||||
pEbpf := proxyInstance{
|
||||
name: "ebpf kernel proxy",
|
||||
proxy: ebpf.NewProxyWrapper(ebpfProxy),
|
||||
wgPort: 51831,
|
||||
closeFn: ebpfProxy.Free,
|
||||
}
|
||||
pl = append(pl, pEbpf)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
relayedConn := newMockConn()
|
||||
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
|
||||
_ = relayedConn.Close()
|
||||
if err := tt.proxy.CloseConn(); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
})
|
||||
pUDP := proxyInstance{
|
||||
name: "udp kernel proxy",
|
||||
proxy: udp.NewWGUDPProxy(51832, 1280),
|
||||
wgPort: 51832,
|
||||
closeFn: func() error { return nil },
|
||||
}
|
||||
pl = append(pl, pUDP)
|
||||
return pl, nil
|
||||
}
|
||||
|
||||
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
|
||||
pl := make([]proxyInstance, 0)
|
||||
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
|
||||
}
|
||||
|
||||
pEbpf := proxyInstance{
|
||||
name: "ebpf kernel proxy",
|
||||
proxy: ebpf.NewProxyWrapper(ebpfProxy),
|
||||
wgPort: 51831,
|
||||
closeFn: ebpfProxy.Free,
|
||||
}
|
||||
pl = append(pl, pEbpf)
|
||||
|
||||
pUDP := proxyInstance{
|
||||
name: "udp kernel proxy",
|
||||
proxy: udp.NewWGUDPProxy(51832, 1280),
|
||||
wgPort: 51832,
|
||||
closeFn: func() error { return nil },
|
||||
}
|
||||
pl = append(pl, pUDP)
|
||||
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
|
||||
endpointAddress := &net.UDPAddr{
|
||||
IP: net.IPv4(10, 0, 0, 1),
|
||||
Port: 1234,
|
||||
}
|
||||
|
||||
pBind := proxyInstance{
|
||||
name: "bind proxy",
|
||||
proxy: bindproxy.NewProxyBind(iceBind, 0),
|
||||
endpointAddr: endpointAddress,
|
||||
closeFn: func() error { return nil },
|
||||
}
|
||||
pl = append(pl, pBind)
|
||||
|
||||
return pl, nil
|
||||
}
|
||||
|
||||
39
client/iface/wgproxy/proxy_seed_test.go
Normal file
39
client/iface/wgproxy/proxy_seed_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
//go:build !linux
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
|
||||
)
|
||||
|
||||
func seedProxies() ([]proxyInstance, error) {
|
||||
// todo extend with Bind proxy
|
||||
pl := make([]proxyInstance, 0)
|
||||
return pl, nil
|
||||
}
|
||||
|
||||
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
|
||||
pl := make([]proxyInstance, 0)
|
||||
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
|
||||
endpointAddress := &net.UDPAddr{
|
||||
IP: net.IPv4(10, 0, 0, 1),
|
||||
Port: 1234,
|
||||
}
|
||||
|
||||
pBind := proxyInstance{
|
||||
name: "bind proxy",
|
||||
proxy: bindproxy.NewProxyBind(iceBind, 0),
|
||||
endpointAddr: endpointAddress,
|
||||
closeFn: func() error { return nil },
|
||||
}
|
||||
pl = append(pl, pBind)
|
||||
return pl, nil
|
||||
}
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build linux
|
||||
|
||||
package wgproxy
|
||||
|
||||
import (
|
||||
@@ -7,12 +5,9 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
|
||||
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -22,6 +17,14 @@ func TestMain(m *testing.M) {
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
type proxyInstance struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
wgPort int
|
||||
endpointAddr *net.UDPAddr
|
||||
closeFn func() error
|
||||
}
|
||||
|
||||
type mocConn struct {
|
||||
closeChan chan struct{}
|
||||
closed bool
|
||||
@@ -78,41 +81,21 @@ func (m *mocConn) SetWriteDeadline(t time.Time) error {
|
||||
func TestProxyCloseByRemoteConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
{
|
||||
name: "userspace proxy",
|
||||
proxy: udpProxy.NewWGUDPProxy(51830, 1280),
|
||||
},
|
||||
tests, err := seedProxyForProxyCloseByRemoteConn()
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
|
||||
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
|
||||
if err := ebpfProxy.Listen(); err != nil {
|
||||
t.Fatalf("failed to initialize ebpf proxy: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := ebpfProxy.Free(); err != nil {
|
||||
t.Errorf("failed to free ebpf proxy: %s", err)
|
||||
}
|
||||
}()
|
||||
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
|
||||
|
||||
tests = append(tests, struct {
|
||||
name string
|
||||
proxy Proxy
|
||||
}{
|
||||
name: "ebpf proxy",
|
||||
proxy: proxyWrapper,
|
||||
})
|
||||
}
|
||||
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
|
||||
defer func() {
|
||||
_ = relayedConn.Close()
|
||||
}()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr, _ := net.ResolveUDPAddr("udp", "100.108.135.221:51892")
|
||||
relayedConn := newMockConn()
|
||||
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
|
||||
err := tt.proxy.AddTurnConn(ctx, addr, relayedConn)
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
@@ -124,3 +107,104 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProxyRedirect todo extend the proxies with Bind proxy
|
||||
func TestProxyRedirect(t *testing.T) {
|
||||
tests, err := seedProxies()
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr)
|
||||
if err := tt.closeFn(); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) {
|
||||
t.Helper()
|
||||
|
||||
msgHelloFromRelay := []byte("hello from relay")
|
||||
msgRedirected := [][]byte{
|
||||
[]byte("hello 1. to p2p"),
|
||||
[]byte("hello 2. to p2p"),
|
||||
[]byte("hello 3. to p2p"),
|
||||
}
|
||||
|
||||
dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: wgPort})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to listen on udp port: %s", err)
|
||||
}
|
||||
|
||||
relayedServer, _ := net.ListenUDP("udp",
|
||||
&net.UDPAddr{
|
||||
IP: net.IPv4(127, 0, 0, 1),
|
||||
Port: 1234,
|
||||
},
|
||||
)
|
||||
|
||||
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
|
||||
|
||||
defer func() {
|
||||
_ = dummyWgListener.Close()
|
||||
_ = relayedConn.Close()
|
||||
_ = relayedServer.Close()
|
||||
}()
|
||||
|
||||
if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxy.CloseConn(); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
proxy.Work()
|
||||
|
||||
if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil {
|
||||
t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err)
|
||||
}
|
||||
|
||||
n, err := dummyWgListener.Read(make([]byte, 1024))
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
|
||||
if n != len(msgHelloFromRelay) {
|
||||
t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n)
|
||||
}
|
||||
|
||||
p2pEndpointAddr := &net.UDPAddr{
|
||||
IP: net.IPv4(192, 168, 0, 56),
|
||||
Port: 1234,
|
||||
}
|
||||
proxy.RedirectAs(p2pEndpointAddr)
|
||||
|
||||
for _, msg := range msgRedirected {
|
||||
if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(msgRedirected); i++ {
|
||||
buf := make([]byte, 1024)
|
||||
n, rAddr, err := dummyWgListener.ReadFrom(buf)
|
||||
if err != nil {
|
||||
t.Errorf("error: %v", err)
|
||||
}
|
||||
|
||||
if rAddr.String() != p2pEndpointAddr.String() {
|
||||
t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String())
|
||||
}
|
||||
if string(buf[:n]) != string(msgRedirected[i]) {
|
||||
t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
50
client/iface/wgproxy/rawsocket/rawsocket.go
Normal file
50
client/iface/wgproxy/rawsocket/rawsocket.go
Normal file
@@ -0,0 +1,50 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package rawsocket
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func PrepareSenderRawSocket() (net.PacketConn, error) {
|
||||
// Create a raw socket.
|
||||
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating raw socket failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
|
||||
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
|
||||
}
|
||||
|
||||
// Bind the socket to the "lo" interface.
|
||||
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||
}
|
||||
|
||||
// Set the fwmark on the socket.
|
||||
err = nbnet.SetSocketOpt(fd)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
||||
}
|
||||
|
||||
// Convert the file descriptor to a PacketConn.
|
||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||
if file == nil {
|
||||
return nil, fmt.Errorf("converting fd to file failed")
|
||||
}
|
||||
packetConn, err := net.FilePacketConn(file)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
|
||||
}
|
||||
|
||||
return packetConn, nil
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
@@ -21,16 +23,18 @@ type WGUDPProxy struct {
|
||||
localWGListenPort int
|
||||
mtu uint16
|
||||
|
||||
remoteConn net.Conn
|
||||
localConn net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
remoteConn net.Conn
|
||||
localConn net.Conn
|
||||
srcFakerConn *SrcFaker
|
||||
sendPkg func(data []byte) (int, error)
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
|
||||
pausedMu sync.Mutex
|
||||
paused bool
|
||||
isStarted bool
|
||||
paused bool
|
||||
pausedCond *sync.Cond
|
||||
isStarted bool
|
||||
|
||||
closeListener *listener.CloseListener
|
||||
}
|
||||
@@ -41,6 +45,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
|
||||
p := &WGUDPProxy{
|
||||
localWGListenPort: wgPort,
|
||||
mtu: mtu,
|
||||
pausedCond: sync.NewCond(&sync.Mutex{}),
|
||||
closeListener: listener.NewCloseListener(),
|
||||
}
|
||||
return p
|
||||
@@ -61,6 +66,7 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem
|
||||
|
||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||
p.localConn = localConn
|
||||
p.sendPkg = p.localConn.Write
|
||||
p.remoteConn = remoteConn
|
||||
|
||||
return err
|
||||
@@ -84,15 +90,24 @@ func (p *WGUDPProxy) Work() {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
p.pausedMu.Unlock()
|
||||
p.sendPkg = p.localConn.Write
|
||||
|
||||
if p.srcFakerConn != nil {
|
||||
if err := p.srcFakerConn.Close(); err != nil {
|
||||
log.Errorf("failed to close src faker conn: %s", err)
|
||||
}
|
||||
p.srcFakerConn = nil
|
||||
}
|
||||
|
||||
if !p.isStarted {
|
||||
p.isStarted = true
|
||||
go p.proxyToRemote(p.ctx)
|
||||
go p.proxyToLocal(p.ctx)
|
||||
}
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
// Pause pauses the proxy from receiving data from the remote peer
|
||||
@@ -101,9 +116,35 @@ func (p *WGUDPProxy) Pause() {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = true
|
||||
p.pausedMu.Unlock()
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
// RedirectAs start to use the fake sourced raw socket as package sender
|
||||
func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
|
||||
p.pausedCond.L.Lock()
|
||||
defer func() {
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
}()
|
||||
|
||||
p.paused = false
|
||||
if p.srcFakerConn != nil {
|
||||
if err := p.srcFakerConn.Close(); err != nil {
|
||||
log.Errorf("failed to close src faker conn: %s", err)
|
||||
}
|
||||
p.srcFakerConn = nil
|
||||
}
|
||||
srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create src faker conn: %s", err)
|
||||
// fallback to continue without redirecting
|
||||
p.paused = true
|
||||
return
|
||||
}
|
||||
p.srcFakerConn = srcFakerConn
|
||||
p.sendPkg = p.srcFakerConn.SendPkg
|
||||
}
|
||||
|
||||
// CloseConn close the localConn
|
||||
@@ -115,6 +156,8 @@ func (p *WGUDPProxy) CloseConn() error {
|
||||
}
|
||||
|
||||
func (p *WGUDPProxy) close() error {
|
||||
var result *multierror.Error
|
||||
|
||||
p.closeMu.Lock()
|
||||
defer p.closeMu.Unlock()
|
||||
|
||||
@@ -128,7 +171,11 @@ func (p *WGUDPProxy) close() error {
|
||||
|
||||
p.cancel()
|
||||
|
||||
var result *multierror.Error
|
||||
p.pausedCond.L.Lock()
|
||||
p.paused = false
|
||||
p.pausedCond.Signal()
|
||||
p.pausedCond.L.Unlock()
|
||||
|
||||
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
|
||||
}
|
||||
@@ -136,6 +183,13 @@ func (p *WGUDPProxy) close() error {
|
||||
if err := p.localConn.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
|
||||
}
|
||||
|
||||
if p.srcFakerConn != nil {
|
||||
if err := p.srcFakerConn.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err))
|
||||
}
|
||||
}
|
||||
|
||||
return cerrors.FormatErrorOrNil(result)
|
||||
}
|
||||
|
||||
@@ -194,14 +248,12 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
p.pausedMu.Lock()
|
||||
if p.paused {
|
||||
p.pausedMu.Unlock()
|
||||
continue
|
||||
p.pausedCond.L.Lock()
|
||||
for p.paused {
|
||||
p.pausedCond.Wait()
|
||||
}
|
||||
|
||||
_, err = p.localConn.Write(buf[:n])
|
||||
p.pausedMu.Unlock()
|
||||
_, err = p.sendPkg(buf[:n])
|
||||
p.pausedCond.L.Unlock()
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
|
||||
101
client/iface/wgproxy/udp/rawsocket.go
Normal file
101
client/iface/wgproxy/udp/rawsocket.go
Normal file
@@ -0,0 +1,101 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package udp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
|
||||
)
|
||||
|
||||
var (
|
||||
serializeOpts = gopacket.SerializeOptions{
|
||||
ComputeChecksums: true,
|
||||
FixLengths: true,
|
||||
}
|
||||
|
||||
localHostNetIPAddr = &net.IPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
}
|
||||
)
|
||||
|
||||
type SrcFaker struct {
|
||||
srcAddr *net.UDPAddr
|
||||
|
||||
rawSocket net.PacketConn
|
||||
ipH gopacket.SerializableLayer
|
||||
udpH gopacket.SerializableLayer
|
||||
layerBuffer gopacket.SerializeBuffer
|
||||
}
|
||||
|
||||
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
|
||||
rawSocket, err := rawsocket.PrepareSenderRawSocket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f := &SrcFaker{
|
||||
srcAddr: srcAddr,
|
||||
rawSocket: rawSocket,
|
||||
ipH: ipH,
|
||||
udpH: udpH,
|
||||
layerBuffer: gopacket.NewSerializeBuffer(),
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *SrcFaker) Close() error {
|
||||
return f.rawSocket.Close()
|
||||
}
|
||||
|
||||
func (f *SrcFaker) SendPkg(data []byte) (int, error) {
|
||||
defer func() {
|
||||
if err := f.layerBuffer.Clear(); err != nil {
|
||||
log.Errorf("failed to clear layer buffer: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
payload := gopacket.Payload(data)
|
||||
|
||||
err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("serialize layers: %w", err)
|
||||
}
|
||||
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("write to raw conn: %w", err)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
|
||||
ipH := &layers.IPv4{
|
||||
DstIP: net.ParseIP("127.0.0.1"),
|
||||
SrcIP: srcAddr.IP,
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolUDP,
|
||||
}
|
||||
udpH := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcAddr.Port),
|
||||
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
|
||||
}
|
||||
|
||||
err := udpH.SetNetworkLayerForChecksum(ipH)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
|
||||
}
|
||||
|
||||
return ipH, udpH, nil
|
||||
}
|
||||
@@ -240,15 +240,19 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
|
||||
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||
for i, domain := range domains {
|
||||
policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
||||
if r.gpo {
|
||||
policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
||||
}
|
||||
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
||||
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
||||
|
||||
singleDomain := []string{domain}
|
||||
|
||||
if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err)
|
||||
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
|
||||
}
|
||||
|
||||
if r.gpo {
|
||||
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("added NRPT entry for domain: %s", domain)
|
||||
@@ -401,6 +405,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
|
||||
}
|
||||
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
|
||||
}
|
||||
@@ -412,6 +417,7 @@ func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
|
||||
}
|
||||
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
|
||||
}
|
||||
|
||||
5
client/internal/dns/server_js.go
Normal file
5
client/internal/dns/server_js.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package dns
|
||||
|
||||
func (s *DefaultServer) initialize() (hostManager, error) {
|
||||
return &noopHostConfigurator{}, nil
|
||||
}
|
||||
19
client/internal/dns/unclean_shutdown_js.go
Normal file
19
client/internal/dns/unclean_shutdown_js.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type ShutdownState struct{}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "dns_state"
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ShutdownState) RestoreUncleanShutdownConfigs(context.Context) error {
|
||||
return nil
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -11,14 +12,18 @@ import (
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
var (
|
||||
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
|
||||
listenPort uint16 = 5353
|
||||
listenPortMu sync.RWMutex
|
||||
)
|
||||
|
||||
const (
|
||||
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
|
||||
ListenPort = 5353
|
||||
dnsTTL = 60 //seconds
|
||||
dnsTTL = 60 //seconds
|
||||
)
|
||||
|
||||
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
|
||||
@@ -35,12 +40,20 @@ type Manager struct {
|
||||
fwRules []firewall.Rule
|
||||
tcpRules []firewall.Rule
|
||||
dnsForwarder *DNSForwarder
|
||||
port uint16
|
||||
}
|
||||
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
|
||||
func ListenPort() uint16 {
|
||||
listenPortMu.RLock()
|
||||
defer listenPortMu.RUnlock()
|
||||
return listenPort
|
||||
}
|
||||
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager {
|
||||
return &Manager{
|
||||
firewall: fw,
|
||||
statusRecorder: statusRecorder,
|
||||
port: port,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,7 +67,13 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
||||
return err
|
||||
}
|
||||
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.firewall, m.statusRecorder)
|
||||
if m.port > 0 {
|
||||
listenPortMu.Lock()
|
||||
listenPort = m.port
|
||||
listenPortMu.Unlock()
|
||||
}
|
||||
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
|
||||
go func() {
|
||||
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
||||
// todo handle close error if it is exists
|
||||
@@ -94,7 +113,7 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
func (m *Manager) allowDNSFirewall() error {
|
||||
dport := &firewall.Port{
|
||||
IsRange: false,
|
||||
Values: []uint16{ListenPort},
|
||||
Values: []uint16{ListenPort()},
|
||||
}
|
||||
|
||||
if m.firewall == nil {
|
||||
|
||||
@@ -202,6 +202,9 @@ type Engine struct {
|
||||
// WireGuard interface monitor
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
wgIfaceMonitorWg sync.WaitGroup
|
||||
|
||||
// dns forwarder port
|
||||
dnsFwdPort uint16
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -244,6 +247,7 @@ func NewEngine(
|
||||
statusRecorder: statusRecorder,
|
||||
checks: checks,
|
||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||
dnsFwdPort: dnsfwd.ListenPort(),
|
||||
}
|
||||
|
||||
sm := profilemanager.NewServiceManager("")
|
||||
@@ -453,8 +457,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return fmt.Errorf("up wg interface: %w", err)
|
||||
}
|
||||
|
||||
|
||||
|
||||
// if inbound conns are blocked there is no need to create the ACL manager
|
||||
if e.firewall != nil && !e.config.BlockInbound {
|
||||
e.acl = acl.NewDefaultManager(e.firewall)
|
||||
@@ -466,14 +468,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return fmt.Errorf("initialize dns server: %w", err)
|
||||
}
|
||||
|
||||
iceCfg := icemaker.Config{
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
UDPMux: e.udpMux.SingleSocketUDPMux,
|
||||
UDPMuxSrflx: e.udpMux,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
}
|
||||
iceCfg := e.createICEConfig()
|
||||
|
||||
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
|
||||
e.connMgr.Start(e.ctx)
|
||||
@@ -1089,7 +1084,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
|
||||
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort))
|
||||
|
||||
// Ingress forward rules
|
||||
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
||||
@@ -1347,14 +1342,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
Addr: e.getRosenpassAddr(),
|
||||
PermissiveMode: e.config.RosenpassPermissive,
|
||||
},
|
||||
ICEConfig: icemaker.Config{
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
UDPMux: e.udpMux.SingleSocketUDPMux,
|
||||
UDPMuxSrflx: e.udpMux,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
},
|
||||
ICEConfig: e.createICEConfig(),
|
||||
}
|
||||
|
||||
serviceDependencies := peer.ServiceDependencies{
|
||||
@@ -1855,6 +1843,7 @@ func (e *Engine) GetWgAddr() netip.Addr {
|
||||
func (e *Engine) updateDNSForwarder(
|
||||
enabled bool,
|
||||
fwdEntries []*dnsfwd.ForwarderEntry,
|
||||
forwarderPort uint16,
|
||||
) {
|
||||
if e.config.DisableServerRoutes {
|
||||
return
|
||||
@@ -1871,16 +1860,20 @@ func (e *Engine) updateDNSForwarder(
|
||||
}
|
||||
|
||||
if len(fwdEntries) > 0 {
|
||||
if e.dnsForwardMgr == nil {
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
|
||||
|
||||
switch {
|
||||
case e.dnsForwardMgr == nil:
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
|
||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||
} else {
|
||||
case e.dnsFwdPort != forwarderPort:
|
||||
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
|
||||
e.restartDnsFwd(fwdEntries, forwarderPort)
|
||||
e.dnsFwdPort = forwarderPort
|
||||
|
||||
default:
|
||||
e.dnsForwardMgr.UpdateDomains(fwdEntries)
|
||||
}
|
||||
} else if e.dnsForwardMgr != nil {
|
||||
@@ -1890,6 +1883,20 @@ func (e *Engine) updateDNSForwarder(
|
||||
}
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) {
|
||||
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
|
||||
// stop and start the forwarder to apply the new port
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) GetNet() (*netstack.Net, error) {
|
||||
|
||||
19
client/internal/engine_generic.go
Normal file
19
client/internal/engine_generic.go
Normal file
@@ -0,0 +1,19 @@
|
||||
//go:build !js
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
)
|
||||
|
||||
// createICEConfig creates ICE configuration for non-WASM environments
|
||||
func (e *Engine) createICEConfig() icemaker.Config {
|
||||
return icemaker.Config{
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
UDPMux: e.udpMux.SingleSocketUDPMux,
|
||||
UDPMuxSrflx: e.udpMux,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
}
|
||||
}
|
||||
18
client/internal/engine_js.go
Normal file
18
client/internal/engine_js.go
Normal file
@@ -0,0 +1,18 @@
|
||||
//go:build js
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
)
|
||||
|
||||
// createICEConfig creates ICE configuration for WASM environment.
|
||||
func (e *Engine) createICEConfig() icemaker.Config {
|
||||
cfg := icemaker.Config{
|
||||
StunTurn: &e.stunTurn,
|
||||
InterfaceBlackList: e.config.IFaceBlackList,
|
||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
@@ -27,6 +27,10 @@ import (
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
@@ -42,10 +46,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
@@ -1584,7 +1586,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
}
|
||||
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -138,7 +138,7 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
|
||||
|
||||
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
|
||||
// check dns collection
|
||||
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) {
|
||||
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
12
client/internal/networkmonitor/check_change_js.go
Normal file
12
client/internal/networkmonitor/check_change_js.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
// No-op for WASM - network changes don't apply
|
||||
return nil
|
||||
}
|
||||
@@ -28,10 +28,6 @@ import (
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultWgKeepAlive = 25 * time.Second
|
||||
)
|
||||
|
||||
type ServiceDependencies struct {
|
||||
StatusRecorder *Status
|
||||
Signaler *Signaler
|
||||
@@ -117,6 +113,8 @@ type Conn struct {
|
||||
|
||||
// debug purpose
|
||||
dumpState *stateDump
|
||||
|
||||
endpointUpdater *EndpointUpdater
|
||||
}
|
||||
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
@@ -129,17 +127,18 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
connLog := log.WithField("peer", config.Key)
|
||||
|
||||
var conn = &Conn{
|
||||
Log: connLog,
|
||||
config: config,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
signaler: services.Signaler,
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
semaphore: services.Semaphore,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
||||
Log: connLog,
|
||||
config: config,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
signaler: services.Signaler,
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
semaphore: services.Semaphore,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
||||
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
@@ -249,7 +248,7 @@ func (conn *Conn) Close(signalToRemote bool) {
|
||||
conn.wgProxyICE = nil
|
||||
}
|
||||
|
||||
if err := conn.removeWgPeer(); err != nil {
|
||||
if err := conn.endpointUpdater.RemoveWgPeer(); err != nil {
|
||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||
}
|
||||
|
||||
@@ -375,12 +374,19 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
wgProxy.Work()
|
||||
}
|
||||
|
||||
if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil {
|
||||
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
|
||||
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
|
||||
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
|
||||
conn.handleConfigurationFailure(err, wgProxy)
|
||||
return
|
||||
}
|
||||
wgConfigWorkaround()
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
conn.Log.Debugf("redirect packets from relayed conn to WireGuard")
|
||||
conn.wgProxyRelay.RedirectAs(ep)
|
||||
}
|
||||
|
||||
conn.currentConnPriority = priority
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo)
|
||||
@@ -409,7 +415,8 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
conn.dumpState.SwitchToRelay()
|
||||
conn.wgProxyRelay.Work()
|
||||
|
||||
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
|
||||
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
|
||||
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
||||
}
|
||||
|
||||
@@ -418,6 +425,7 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
}()
|
||||
conn.wgProxyRelay.Work()
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
} else {
|
||||
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||
@@ -477,7 +485,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
}
|
||||
|
||||
wgProxy.Work()
|
||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
|
||||
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||
}
|
||||
@@ -545,17 +554,6 @@ func (conn *Conn) onGuardEvent() {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error {
|
||||
presharedKey := conn.presharedKey(remoteRPKey)
|
||||
return conn.config.WgConfig.WgInterface.UpdatePeer(
|
||||
conn.config.WgConfig.RemoteKey,
|
||||
conn.config.WgConfig.AllowedIps,
|
||||
defaultWgKeepAlive,
|
||||
addr,
|
||||
presharedKey,
|
||||
)
|
||||
}
|
||||
|
||||
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
@@ -698,10 +696,6 @@ func (conn *Conn) isICEActive() bool {
|
||||
return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
|
||||
}
|
||||
|
||||
func (conn *Conn) removeWgPeer() error {
|
||||
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
|
||||
conn.Log.Warnf("Failed to update wg peer configuration: %v", err)
|
||||
if wgProxy != nil {
|
||||
|
||||
105
client/internal/peer/endpoint.go
Normal file
105
client/internal/peer/endpoint.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultWgKeepAlive = 25 * time.Second
|
||||
fallbackDelay = 5 * time.Second
|
||||
)
|
||||
|
||||
type EndpointUpdater struct {
|
||||
log *logrus.Entry
|
||||
wgConfig WgConfig
|
||||
initiator bool
|
||||
|
||||
// mu protects updateWireGuardPeer and cancelFunc
|
||||
mu sync.Mutex
|
||||
cancelFunc func()
|
||||
updateWg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *EndpointUpdater {
|
||||
return &EndpointUpdater{
|
||||
log: log,
|
||||
wgConfig: wgConfig,
|
||||
initiator: initiator,
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigureWGEndpoint sets up the WireGuard endpoint configuration.
|
||||
// The initiator immediately configures the endpoint, while the non-initiator
|
||||
// waits for a fallback period before configuring to avoid handshake congestion.
|
||||
func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
if e.initiator {
|
||||
e.log.Debugf("configure up WireGuard as initiatr")
|
||||
return e.updateWireGuardPeer(addr, presharedKey)
|
||||
}
|
||||
|
||||
// prevent to run new update while cancel the previous update
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
|
||||
var ctx context.Context
|
||||
ctx, e.cancelFunc = context.WithCancel(context.Background())
|
||||
e.updateWg.Add(1)
|
||||
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
|
||||
|
||||
e.log.Debugf("configure up WireGuard and wait for handshake")
|
||||
return e.updateWireGuardPeer(nil, presharedKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) RemoveWgPeer() error {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
|
||||
if e.cancelFunc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
e.cancelFunc()
|
||||
e.cancelFunc = nil
|
||||
e.updateWg.Wait()
|
||||
}
|
||||
|
||||
// scheduleDelayedUpdate waits for the fallback period before updating the endpoint
|
||||
func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr, presharedKey *wgtypes.Key) {
|
||||
defer e.updateWg.Done()
|
||||
t := time.NewTimer(fallbackDelay)
|
||||
defer t.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
e.mu.Lock()
|
||||
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
|
||||
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
|
||||
}
|
||||
e.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
return e.wgConfig.WgInterface.UpdatePeer(
|
||||
e.wgConfig.RemoteKey,
|
||||
e.wgConfig.AllowedIps,
|
||||
defaultWgKeepAlive,
|
||||
endpoint,
|
||||
presharedKey,
|
||||
)
|
||||
}
|
||||
@@ -3,6 +3,8 @@ package guard
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -24,8 +26,8 @@ type ICEMonitor struct {
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
iceConfig icemaker.Config
|
||||
|
||||
currentCandidates []ice.Candidate
|
||||
candidatesMu sync.Mutex
|
||||
currentCandidatesAddress []string
|
||||
candidatesMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor {
|
||||
@@ -115,16 +117,21 @@ func (cm *ICEMonitor) updateCandidates(newCandidates []ice.Candidate) bool {
|
||||
cm.candidatesMu.Lock()
|
||||
defer cm.candidatesMu.Unlock()
|
||||
|
||||
if len(cm.currentCandidates) != len(newCandidates) {
|
||||
cm.currentCandidates = newCandidates
|
||||
newAddresses := make([]string, len(newCandidates))
|
||||
for i, c := range newCandidates {
|
||||
newAddresses[i] = c.Address()
|
||||
}
|
||||
sort.Strings(newAddresses)
|
||||
|
||||
if len(cm.currentCandidatesAddress) != len(newAddresses) {
|
||||
cm.currentCandidatesAddress = newAddresses
|
||||
return true
|
||||
}
|
||||
|
||||
for i, candidate := range cm.currentCandidates {
|
||||
if candidate.Address() != newCandidates[i].Address() {
|
||||
cm.currentCandidates = newCandidates
|
||||
return true
|
||||
}
|
||||
// Compare elements
|
||||
if !slices.Equal(cm.currentCandidatesAddress, newAddresses) {
|
||||
cm.currentCandidatesAddress = newAddresses
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
|
||||
@@ -24,8 +24,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const dnsTimeout = 8 * time.Second
|
||||
@@ -257,7 +257,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
|
||||
48
client/internal/routemanager/systemops/systemops_js.go
Normal file
48
client/internal/routemanager/systemops/systemops_js.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
var ErrRouteNotSupported = errors.New("route operations not supported on js")
|
||||
|
||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return ErrRouteNotSupported
|
||||
}
|
||||
|
||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||
return ErrRouteNotSupported
|
||||
}
|
||||
|
||||
func GetRoutesFromTable() ([]netip.Prefix, error) {
|
||||
return []netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||
return []netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
// GetDetailedRoutesFromTable returns empty routes for WASM.
|
||||
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
|
||||
return []DetailedRoute{}, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) AddVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return ErrRouteNotSupported
|
||||
}
|
||||
|
||||
func (r *SysOps) RemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) error {
|
||||
return ErrRouteNotSupported
|
||||
}
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, _ bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, _ bool) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux && !ios
|
||||
//go:build !linux && !ios && !js
|
||||
|
||||
package systemops
|
||||
|
||||
|
||||
@@ -10,23 +10,26 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
@@ -314,7 +317,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
}
|
||||
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import "context"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
|
||||
137
client/ssh/ssh_js.go
Normal file
137
client/ssh/ssh_js.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var ErrSSHNotSupported = errors.New("SSH is not supported in WASM environment")
|
||||
|
||||
// Server is a dummy SSH server interface for WASM.
|
||||
type Server interface {
|
||||
Start() error
|
||||
Stop() error
|
||||
EnableSSH(enabled bool)
|
||||
AddAuthorizedKey(peer string, key string) error
|
||||
RemoveAuthorizedKey(key string)
|
||||
}
|
||||
|
||||
type dummyServer struct{}
|
||||
|
||||
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
|
||||
return &dummyServer{}, nil
|
||||
}
|
||||
|
||||
func NewServer(addr string) Server {
|
||||
return &dummyServer{}
|
||||
}
|
||||
|
||||
func (s *dummyServer) Start() error {
|
||||
return ErrSSHNotSupported
|
||||
}
|
||||
|
||||
func (s *dummyServer) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dummyServer) EnableSSH(enabled bool) {
|
||||
}
|
||||
|
||||
func (s *dummyServer) AddAuthorizedKey(peer string, key string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dummyServer) RemoveAuthorizedKey(key string) {
|
||||
}
|
||||
|
||||
type Client struct{}
|
||||
|
||||
func NewClient(ctx context.Context, addr string, config interface{}, recorder *SessionRecorder) (*Client, error) {
|
||||
return nil, ErrSSHNotSupported
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Run(command []string) error {
|
||||
return ErrSSHNotSupported
|
||||
}
|
||||
|
||||
type SessionRecorder struct{}
|
||||
|
||||
func NewSessionRecorder() *SessionRecorder {
|
||||
return &SessionRecorder{}
|
||||
}
|
||||
|
||||
func (r *SessionRecorder) Record(session string, data []byte) {
|
||||
}
|
||||
|
||||
func GetUserShell() string {
|
||||
return "/bin/sh"
|
||||
}
|
||||
|
||||
func LookupUserInfo(username string) (string, string, error) {
|
||||
return "", "", ErrSSHNotSupported
|
||||
}
|
||||
|
||||
const DefaultSSHPort = 44338
|
||||
|
||||
const ED25519 = "ed25519"
|
||||
|
||||
func isRoot() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func GeneratePrivateKey(keyType string) ([]byte, error) {
|
||||
if keyType != ED25519 {
|
||||
return nil, errors.New("only ED25519 keys are supported in WASM")
|
||||
}
|
||||
|
||||
_, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pemBlock := &pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: pkcs8Bytes,
|
||||
}
|
||||
|
||||
pemBytes := pem.EncodeToMemory(pemBlock)
|
||||
return pemBytes, nil
|
||||
}
|
||||
|
||||
func GeneratePublicKey(privateKey []byte) ([]byte, error) {
|
||||
signer, err := ssh.ParsePrivateKey(privateKey)
|
||||
if err != nil {
|
||||
block, _ := pem.Decode(privateKey)
|
||||
if block != nil {
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
signer, err = ssh.NewSignerFromKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
pubKeyBytes := ssh.MarshalAuthorizedKey(signer.PublicKey())
|
||||
return []byte(strings.TrimSpace(string(pubKeyBytes))), nil
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
|
||||
231
client/system/info_js.go
Normal file
231
client/system/info_js.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall/js"
|
||||
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// UpdateStaticInfoAsync is a no-op on JS as there is no static info to update
|
||||
func UpdateStaticInfoAsync() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
// GetInfo retrieves system information for WASM environment
|
||||
func GetInfo(_ context.Context) *Info {
|
||||
info := &Info{
|
||||
GoOS: runtime.GOOS,
|
||||
Kernel: runtime.GOARCH,
|
||||
KernelVersion: runtime.GOARCH,
|
||||
Platform: runtime.GOARCH,
|
||||
OS: runtime.GOARCH,
|
||||
Hostname: "wasm-client",
|
||||
CPUs: runtime.NumCPU(),
|
||||
NetbirdVersion: version.NetbirdVersion(),
|
||||
}
|
||||
|
||||
collectBrowserInfo(info)
|
||||
collectLocationInfo(info)
|
||||
collectSystemInfo(info)
|
||||
return info
|
||||
}
|
||||
|
||||
func collectBrowserInfo(info *Info) {
|
||||
navigator := js.Global().Get("navigator")
|
||||
if navigator.IsUndefined() {
|
||||
return
|
||||
}
|
||||
|
||||
collectUserAgent(info, navigator)
|
||||
collectPlatform(info, navigator)
|
||||
collectCPUInfo(info, navigator)
|
||||
}
|
||||
|
||||
func collectUserAgent(info *Info, navigator js.Value) {
|
||||
ua := navigator.Get("userAgent")
|
||||
if ua.IsUndefined() {
|
||||
return
|
||||
}
|
||||
|
||||
userAgent := ua.String()
|
||||
os, osVersion := parseOSFromUserAgent(userAgent)
|
||||
if os != "" {
|
||||
info.OS = os
|
||||
}
|
||||
if osVersion != "" {
|
||||
info.OSVersion = osVersion
|
||||
}
|
||||
}
|
||||
|
||||
func collectPlatform(info *Info, navigator js.Value) {
|
||||
// Try regular platform property
|
||||
if plat := navigator.Get("platform"); !plat.IsUndefined() {
|
||||
if platStr := plat.String(); platStr != "" {
|
||||
info.Platform = platStr
|
||||
}
|
||||
}
|
||||
|
||||
// Try newer userAgentData API for more accurate platform
|
||||
userAgentData := navigator.Get("userAgentData")
|
||||
if userAgentData.IsUndefined() {
|
||||
return
|
||||
}
|
||||
|
||||
platformInfo := userAgentData.Get("platform")
|
||||
if !platformInfo.IsUndefined() {
|
||||
if platStr := platformInfo.String(); platStr != "" {
|
||||
info.Platform = platStr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectCPUInfo(info *Info, navigator js.Value) {
|
||||
hardwareConcurrency := navigator.Get("hardwareConcurrency")
|
||||
if !hardwareConcurrency.IsUndefined() {
|
||||
info.CPUs = hardwareConcurrency.Int()
|
||||
}
|
||||
}
|
||||
|
||||
func collectLocationInfo(info *Info) {
|
||||
location := js.Global().Get("location")
|
||||
if location.IsUndefined() {
|
||||
return
|
||||
}
|
||||
|
||||
if host := location.Get("hostname"); !host.IsUndefined() {
|
||||
hostnameStr := host.String()
|
||||
if hostnameStr != "" && hostnameStr != "localhost" {
|
||||
info.Hostname = hostnameStr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkFileAndProcess(_ []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
func collectSystemInfo(info *Info) {
|
||||
navigator := js.Global().Get("navigator")
|
||||
if navigator.IsUndefined() {
|
||||
return
|
||||
}
|
||||
|
||||
if vendor := navigator.Get("vendor"); !vendor.IsUndefined() {
|
||||
info.SystemManufacturer = vendor.String()
|
||||
}
|
||||
|
||||
if product := navigator.Get("product"); !product.IsUndefined() {
|
||||
info.SystemProductName = product.String()
|
||||
}
|
||||
|
||||
if userAgent := navigator.Get("userAgent"); !userAgent.IsUndefined() {
|
||||
ua := userAgent.String()
|
||||
info.Environment = detectEnvironmentFromUA(ua)
|
||||
}
|
||||
}
|
||||
|
||||
func parseOSFromUserAgent(userAgent string) (string, string) {
|
||||
if userAgent == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.Contains(userAgent, "Windows NT"):
|
||||
return parseWindowsVersion(userAgent)
|
||||
case strings.Contains(userAgent, "Mac OS X"):
|
||||
return parseMacOSVersion(userAgent)
|
||||
case strings.Contains(userAgent, "FreeBSD"):
|
||||
return "FreeBSD", ""
|
||||
case strings.Contains(userAgent, "OpenBSD"):
|
||||
return "OpenBSD", ""
|
||||
case strings.Contains(userAgent, "NetBSD"):
|
||||
return "NetBSD", ""
|
||||
case strings.Contains(userAgent, "Linux"):
|
||||
return parseLinuxVersion(userAgent)
|
||||
case strings.Contains(userAgent, "iPhone") || strings.Contains(userAgent, "iPad"):
|
||||
return parseiOSVersion(userAgent)
|
||||
case strings.Contains(userAgent, "CrOS"):
|
||||
return "ChromeOS", ""
|
||||
default:
|
||||
return "", ""
|
||||
}
|
||||
}
|
||||
|
||||
func parseWindowsVersion(userAgent string) (string, string) {
|
||||
switch {
|
||||
case strings.Contains(userAgent, "Windows NT 10.0; Win64; x64"):
|
||||
return "Windows", "10/11"
|
||||
case strings.Contains(userAgent, "Windows NT 10.0"):
|
||||
return "Windows", "10"
|
||||
case strings.Contains(userAgent, "Windows NT 6.3"):
|
||||
return "Windows", "8.1"
|
||||
case strings.Contains(userAgent, "Windows NT 6.2"):
|
||||
return "Windows", "8"
|
||||
case strings.Contains(userAgent, "Windows NT 6.1"):
|
||||
return "Windows", "7"
|
||||
default:
|
||||
return "Windows", "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func parseMacOSVersion(userAgent string) (string, string) {
|
||||
idx := strings.Index(userAgent, "Mac OS X ")
|
||||
if idx == -1 {
|
||||
return "macOS", "Unknown"
|
||||
}
|
||||
|
||||
versionStart := idx + len("Mac OS X ")
|
||||
versionEnd := strings.Index(userAgent[versionStart:], ")")
|
||||
if versionEnd <= 0 {
|
||||
return "macOS", "Unknown"
|
||||
}
|
||||
|
||||
ver := userAgent[versionStart : versionStart+versionEnd]
|
||||
ver = strings.ReplaceAll(ver, "_", ".")
|
||||
return "macOS", ver
|
||||
}
|
||||
|
||||
func parseLinuxVersion(userAgent string) (string, string) {
|
||||
if strings.Contains(userAgent, "Android") {
|
||||
return "Android", extractAndroidVersion(userAgent)
|
||||
}
|
||||
if strings.Contains(userAgent, "Ubuntu") {
|
||||
return "Ubuntu", ""
|
||||
}
|
||||
return "Linux", ""
|
||||
}
|
||||
|
||||
func parseiOSVersion(userAgent string) (string, string) {
|
||||
idx := strings.Index(userAgent, "OS ")
|
||||
if idx == -1 {
|
||||
return "iOS", "Unknown"
|
||||
}
|
||||
|
||||
versionStart := idx + 3
|
||||
versionEnd := strings.Index(userAgent[versionStart:], " ")
|
||||
if versionEnd <= 0 {
|
||||
return "iOS", "Unknown"
|
||||
}
|
||||
|
||||
ver := userAgent[versionStart : versionStart+versionEnd]
|
||||
ver = strings.ReplaceAll(ver, "_", ".")
|
||||
return "iOS", ver
|
||||
}
|
||||
|
||||
func extractAndroidVersion(userAgent string) string {
|
||||
if idx := strings.Index(userAgent, "Android "); idx != -1 {
|
||||
versionStart := idx + len("Android ")
|
||||
versionEnd := strings.IndexAny(userAgent[versionStart:], ";)")
|
||||
if versionEnd > 0 {
|
||||
return userAgent[versionStart : versionStart+versionEnd]
|
||||
}
|
||||
}
|
||||
return "Unknown"
|
||||
}
|
||||
|
||||
func detectEnvironmentFromUA(_ string) Environment {
|
||||
return Environment{}
|
||||
}
|
||||
245
client/wasm/cmd/main.go
Normal file
245
client/wasm/cmd/main.go
Normal file
@@ -0,0 +1,245 @@
|
||||
//go:build js
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const (
|
||||
clientStartTimeout = 30 * time.Second
|
||||
clientStopTimeout = 10 * time.Second
|
||||
defaultLogLevel = "warn"
|
||||
)
|
||||
|
||||
func main() {
|
||||
js.Global().Set("NetBirdClient", js.FuncOf(netBirdClientConstructor))
|
||||
|
||||
select {}
|
||||
}
|
||||
|
||||
func startClient(ctx context.Context, nbClient *netbird.Client) error {
|
||||
log.Info("Starting NetBird client...")
|
||||
if err := nbClient.Start(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Info("NetBird client started successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseClientOptions extracts NetBird options from JavaScript object
|
||||
func parseClientOptions(jsOptions js.Value) (netbird.Options, error) {
|
||||
options := netbird.Options{
|
||||
DeviceName: "dashboard-client",
|
||||
LogLevel: defaultLogLevel,
|
||||
}
|
||||
|
||||
if jwtToken := jsOptions.Get("jwtToken"); !jwtToken.IsNull() && !jwtToken.IsUndefined() {
|
||||
options.JWTToken = jwtToken.String()
|
||||
}
|
||||
|
||||
if setupKey := jsOptions.Get("setupKey"); !setupKey.IsNull() && !setupKey.IsUndefined() {
|
||||
options.SetupKey = setupKey.String()
|
||||
}
|
||||
|
||||
if privateKey := jsOptions.Get("privateKey"); !privateKey.IsNull() && !privateKey.IsUndefined() {
|
||||
options.PrivateKey = privateKey.String()
|
||||
}
|
||||
|
||||
if mgmtURL := jsOptions.Get("managementURL"); !mgmtURL.IsNull() && !mgmtURL.IsUndefined() {
|
||||
mgmtURLStr := mgmtURL.String()
|
||||
if mgmtURLStr != "" {
|
||||
options.ManagementURL = mgmtURLStr
|
||||
}
|
||||
}
|
||||
|
||||
if logLevel := jsOptions.Get("logLevel"); !logLevel.IsNull() && !logLevel.IsUndefined() {
|
||||
options.LogLevel = logLevel.String()
|
||||
}
|
||||
|
||||
if deviceName := jsOptions.Get("deviceName"); !deviceName.IsNull() && !deviceName.IsUndefined() {
|
||||
options.DeviceName = deviceName.String()
|
||||
}
|
||||
|
||||
return options, nil
|
||||
}
|
||||
|
||||
// createStartMethod creates the start method for the client
|
||||
func createStartMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), clientStartTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := startClient(ctx, client); err != nil {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
resolve.Invoke(js.ValueOf(true))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createStopMethod creates the stop method for the client
|
||||
func createStopMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), clientStopTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
log.Errorf("Error stopping client: %v", err)
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("NetBird client stopped")
|
||||
resolve.Invoke(js.ValueOf(true))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createSSHMethod creates the SSH connection method
|
||||
func createSSHMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf("error: requires host and port")
|
||||
}
|
||||
|
||||
host := args[0].String()
|
||||
port := args[1].Int()
|
||||
username := "root"
|
||||
if len(args) > 2 && args[2].String() != "" {
|
||||
username = args[2].String()
|
||||
}
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
sshClient := ssh.NewClient(client)
|
||||
|
||||
if err := sshClient.Connect(host, port, username); err != nil {
|
||||
reject.Invoke(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := sshClient.StartSession(80, 24); err != nil {
|
||||
if closeErr := sshClient.Close(); closeErr != nil {
|
||||
log.Errorf("Error closing SSH client: %v", closeErr)
|
||||
}
|
||||
reject.Invoke(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
jsInterface := ssh.CreateJSInterface(sshClient)
|
||||
resolve.Invoke(jsInterface)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createProxyRequestMethod creates the proxyRequest method
|
||||
func createProxyRequestMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: request details required")
|
||||
}
|
||||
|
||||
request := args[0]
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
response, err := http.ProxyRequest(client, request)
|
||||
if err != nil {
|
||||
reject.Invoke(err.Error())
|
||||
return
|
||||
}
|
||||
resolve.Invoke(response)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// createRDPProxyMethod creates the RDP proxy method
|
||||
func createRDPProxyMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf("error: hostname and port required")
|
||||
}
|
||||
|
||||
proxy := rdp.NewRDCleanPathProxy(client)
|
||||
return proxy.CreateProxy(args[0].String(), args[1].String())
|
||||
})
|
||||
}
|
||||
|
||||
// createPromise is a helper to create JavaScript promises
|
||||
func createPromise(handler func(resolve, reject js.Value)) js.Value {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||
resolve := promiseArgs[0]
|
||||
reject := promiseArgs[1]
|
||||
|
||||
go handler(resolve, reject)
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// createClientObject wraps the NetBird client in a JavaScript object
|
||||
func createClientObject(client *netbird.Client) js.Value {
|
||||
obj := make(map[string]interface{})
|
||||
|
||||
obj["start"] = createStartMethod(client)
|
||||
obj["stop"] = createStopMethod(client)
|
||||
obj["createSSHConnection"] = createSSHMethod(client)
|
||||
obj["proxyRequest"] = createProxyRequestMethod(client)
|
||||
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
||||
|
||||
return js.ValueOf(obj)
|
||||
}
|
||||
|
||||
// netBirdClientConstructor acts as a JavaScript constructor function
|
||||
func netBirdClientConstructor(this js.Value, args []js.Value) any {
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any {
|
||||
resolve := promiseArgs[0]
|
||||
reject := promiseArgs[1]
|
||||
|
||||
if len(args) < 1 {
|
||||
reject.Invoke(js.ValueOf("Options object required"))
|
||||
return nil
|
||||
}
|
||||
|
||||
go func() {
|
||||
options, err := parseClientOptions(args[0])
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if err := util.InitLog(options.LogLevel, util.LogConsole); err != nil {
|
||||
log.Warnf("Failed to initialize logging: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("Creating NetBird client with options: deviceName=%s, hasJWT=%v, hasSetupKey=%v, mgmtURL=%s",
|
||||
options.DeviceName, options.JWTToken != "", options.SetupKey != "", options.ManagementURL)
|
||||
|
||||
client, err := netbird.New(options)
|
||||
if err != nil {
|
||||
reject.Invoke(js.ValueOf(fmt.Sprintf("create client: %v", err)))
|
||||
return
|
||||
}
|
||||
|
||||
clientObj := createClientObject(client)
|
||||
log.Info("NetBird client created successfully")
|
||||
resolve.Invoke(clientObj)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
100
client/wasm/internal/http/http.go
Normal file
100
client/wasm/internal/http/http.go
Normal file
@@ -0,0 +1,100 @@
|
||||
//go:build js
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/http"
|
||||
"strings"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
)
|
||||
|
||||
const (
|
||||
httpTimeout = 30 * time.Second
|
||||
maxResponseSize = 1024 * 1024 // 1MB
|
||||
)
|
||||
|
||||
// performRequest executes an HTTP request through NetBird and returns the response and body
|
||||
func performRequest(nbClient *netbird.Client, method, url string, headers map[string]string, body []byte) (*http.Response, []byte, error) {
|
||||
httpClient := nbClient.NewHTTPClient()
|
||||
httpClient.Timeout = httpTimeout
|
||||
|
||||
req, err := http.NewRequest(method, url, strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
for key, value := range headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
log.Errorf("failed to close response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
return resp, respBody, nil
|
||||
}
|
||||
|
||||
// ProxyRequest performs a proxied HTTP request through NetBird and returns a JavaScript object
|
||||
func ProxyRequest(nbClient *netbird.Client, request js.Value) (js.Value, error) {
|
||||
url := request.Get("url").String()
|
||||
if url == "" {
|
||||
return js.Undefined(), fmt.Errorf("URL is required")
|
||||
}
|
||||
|
||||
method := "GET"
|
||||
if methodVal := request.Get("method"); !methodVal.IsNull() && !methodVal.IsUndefined() {
|
||||
method = strings.ToUpper(methodVal.String())
|
||||
}
|
||||
|
||||
var requestBody []byte
|
||||
if bodyVal := request.Get("body"); !bodyVal.IsNull() && !bodyVal.IsUndefined() {
|
||||
requestBody = []byte(bodyVal.String())
|
||||
}
|
||||
|
||||
requestHeaders := make(map[string]string)
|
||||
if headersVal := request.Get("headers"); !headersVal.IsNull() && !headersVal.IsUndefined() && headersVal.Type() == js.TypeObject {
|
||||
headerKeys := js.Global().Get("Object").Call("keys", headersVal)
|
||||
for i := 0; i < headerKeys.Length(); i++ {
|
||||
key := headerKeys.Index(i).String()
|
||||
value := headersVal.Get(key).String()
|
||||
requestHeaders[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
resp, body, err := performRequest(nbClient, method, url, requestHeaders, requestBody)
|
||||
if err != nil {
|
||||
return js.Undefined(), err
|
||||
}
|
||||
|
||||
result := js.Global().Get("Object").New()
|
||||
result.Set("status", resp.StatusCode)
|
||||
result.Set("statusText", resp.Status)
|
||||
result.Set("body", string(body))
|
||||
|
||||
headers := js.Global().Get("Object").New()
|
||||
for key, values := range resp.Header {
|
||||
if len(values) > 0 {
|
||||
headers.Set(strings.ToLower(key), values[0])
|
||||
}
|
||||
}
|
||||
result.Set("headers", headers)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
96
client/wasm/internal/rdp/cert_validation.go
Normal file
96
client/wasm/internal/rdp/cert_validation.go
Normal file
@@ -0,0 +1,96 @@
|
||||
//go:build js
|
||||
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
certValidationTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
|
||||
if !conn.wsHandlers.Get("onCertificateRequest").Truthy() {
|
||||
return false, fmt.Errorf("certificate validation handler not configured")
|
||||
}
|
||||
|
||||
certInfo := js.Global().Get("Object").New()
|
||||
certInfo.Set("ServerAddr", conn.destination)
|
||||
|
||||
certArray := js.Global().Get("Array").New()
|
||||
for i, certBytes := range certChain {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(certBytes))
|
||||
js.CopyBytesToJS(uint8Array, certBytes)
|
||||
certArray.SetIndex(i, uint8Array)
|
||||
}
|
||||
certInfo.Set("ServerCertChain", certArray)
|
||||
if len(certChain) > 0 {
|
||||
cert, err := x509.ParseCertificate(certChain[0])
|
||||
if err == nil {
|
||||
info := js.Global().Get("Object").New()
|
||||
info.Set("subject", cert.Subject.String())
|
||||
info.Set("issuer", cert.Issuer.String())
|
||||
info.Set("validFrom", cert.NotBefore.Format(time.RFC3339))
|
||||
info.Set("validTo", cert.NotAfter.Format(time.RFC3339))
|
||||
info.Set("serialNumber", cert.SerialNumber.String())
|
||||
certInfo.Set("CertificateInfo", info)
|
||||
}
|
||||
}
|
||||
|
||||
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
|
||||
|
||||
resultChan := make(chan bool)
|
||||
errorChan := make(chan error)
|
||||
|
||||
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
result := args[0].Bool()
|
||||
resultChan <- result
|
||||
return nil
|
||||
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
errorChan <- fmt.Errorf("certificate validation failed")
|
||||
return nil
|
||||
}))
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
if result {
|
||||
log.Info("Certificate accepted by user")
|
||||
} else {
|
||||
log.Info("Certificate rejected by user")
|
||||
}
|
||||
return result, nil
|
||||
case err := <-errorChan:
|
||||
return false, err
|
||||
case <-time.After(certValidationTimeout):
|
||||
return false, fmt.Errorf("certificate validation timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config {
|
||||
return &tls.Config{
|
||||
InsecureSkipVerify: true, // We'll validate manually after handshake
|
||||
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||
var certChain [][]byte
|
||||
for _, cert := range cs.PeerCertificates {
|
||||
certChain = append(certChain, cert.Raw)
|
||||
}
|
||||
|
||||
accepted, err := p.validateCertificateWithJS(conn, certChain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !accepted {
|
||||
return fmt.Errorf("certificate rejected by user")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
271
client/wasm/internal/rdp/rdcleanpath.go
Normal file
271
client/wasm/internal/rdp/rdcleanpath.go
Normal file
@@ -0,0 +1,271 @@
|
||||
//go:build js
|
||||
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
RDCleanPathVersion = 3390
|
||||
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
|
||||
RDCleanPathProxyScheme = "ws"
|
||||
)
|
||||
|
||||
type RDCleanPathPDU struct {
|
||||
Version int64 `asn1:"tag:0,explicit"`
|
||||
Error []byte `asn1:"tag:1,explicit,optional"`
|
||||
Destination string `asn1:"utf8,tag:2,explicit,optional"`
|
||||
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
|
||||
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
|
||||
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
|
||||
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
|
||||
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
|
||||
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
|
||||
}
|
||||
|
||||
type RDCleanPathProxy struct {
|
||||
nbClient interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
activeConnections map[string]*proxyConnection
|
||||
destinations map[string]string
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type proxyConnection struct {
|
||||
id string
|
||||
destination string
|
||||
rdpConn net.Conn
|
||||
tlsConn *tls.Conn
|
||||
wsHandlers js.Value
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewRDCleanPathProxy creates a new RDCleanPath proxy
|
||||
func NewRDCleanPathProxy(client interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}) *RDCleanPathProxy {
|
||||
return &RDCleanPathProxy{
|
||||
nbClient: client,
|
||||
activeConnections: make(map[string]*proxyConnection),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateProxy creates a new proxy endpoint for the given destination
|
||||
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||
destination := fmt.Sprintf("%s:%s", hostname, port)
|
||||
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
resolve := args[0]
|
||||
|
||||
go func() {
|
||||
proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections))
|
||||
|
||||
p.mu.Lock()
|
||||
if p.destinations == nil {
|
||||
p.destinations = make(map[string]string)
|
||||
}
|
||||
p.destinations[proxyID] = destination
|
||||
p.mu.Unlock()
|
||||
|
||||
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
|
||||
|
||||
// Register the WebSocket handler for this specific proxy
|
||||
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: requires WebSocket argument")
|
||||
}
|
||||
|
||||
ws := args[0]
|
||||
p.HandleWebSocketConnection(ws, proxyID)
|
||||
return nil
|
||||
}))
|
||||
|
||||
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
|
||||
resolve.Invoke(proxyURL)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// HandleWebSocketConnection handles incoming WebSocket connections from IronRDP
|
||||
func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string) {
|
||||
p.mu.Lock()
|
||||
destination := p.destinations[proxyID]
|
||||
p.mu.Unlock()
|
||||
|
||||
if destination == "" {
|
||||
log.Errorf("No destination found for proxy ID: %s", proxyID)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Don't defer cancel here - it will be called by cleanupConnection
|
||||
|
||||
conn := &proxyConnection{
|
||||
id: proxyID,
|
||||
destination: destination,
|
||||
wsHandlers: ws,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.activeConnections[proxyID] = conn
|
||||
p.mu.Unlock()
|
||||
|
||||
p.setupWebSocketHandlers(ws, conn)
|
||||
|
||||
log.Infof("RDCleanPath proxy WebSocket connection established for %s", proxyID)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
|
||||
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
data := args[0]
|
||||
go p.handleWebSocketMessage(conn, data)
|
||||
return nil
|
||||
}))
|
||||
|
||||
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
log.Debug("WebSocket closed by JavaScript")
|
||||
conn.cancel()
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
|
||||
if !data.InstanceOf(js.Global().Get("Uint8Array")) {
|
||||
return
|
||||
}
|
||||
|
||||
length := data.Get("length").Int()
|
||||
bytes := make([]byte, length)
|
||||
js.CopyBytesToGo(bytes, data)
|
||||
|
||||
if conn.rdpConn != nil || conn.tlsConn != nil {
|
||||
p.forwardToRDP(conn, bytes)
|
||||
return
|
||||
}
|
||||
|
||||
var pdu RDCleanPathPDU
|
||||
_, err := asn1.Unmarshal(bytes, &pdu)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to parse RDCleanPath PDU: %v", err)
|
||||
n := len(bytes)
|
||||
if n > 20 {
|
||||
n = 20
|
||||
}
|
||||
log.Warnf("First %d bytes: %x", n, bytes[:n])
|
||||
|
||||
if len(bytes) > 0 && bytes[0] == 0x03 {
|
||||
log.Debug("Received raw RDP packet instead of RDCleanPath PDU")
|
||||
go p.handleDirectRDP(conn, bytes)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
go p.processRDCleanPathPDU(conn, pdu)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) forwardToRDP(conn *proxyConnection, bytes []byte) {
|
||||
var writer io.Writer
|
||||
var connType string
|
||||
|
||||
if conn.tlsConn != nil {
|
||||
writer = conn.tlsConn
|
||||
connType = "TLS"
|
||||
} else if conn.rdpConn != nil {
|
||||
writer = conn.rdpConn
|
||||
connType = "TCP"
|
||||
} else {
|
||||
log.Error("No RDP connection available")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := writer.Write(bytes); err != nil {
|
||||
log.Errorf("Failed to write to %s: %v", connType, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []byte) {
|
||||
defer p.cleanupConnection(conn)
|
||||
|
||||
destination := conn.destination
|
||||
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||
return
|
||||
}
|
||||
conn.rdpConn = rdpConn
|
||||
|
||||
_, err = rdpConn.Write(firstPacket)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write first packet: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.sendToWebSocket(conn, response[:n])
|
||||
|
||||
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
|
||||
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
|
||||
log.Debugf("Cleaning up connection %s", conn.id)
|
||||
conn.cancel()
|
||||
if conn.tlsConn != nil {
|
||||
log.Debug("Closing TLS connection")
|
||||
if err := conn.tlsConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TLS connection: %v", err)
|
||||
}
|
||||
conn.tlsConn = nil
|
||||
}
|
||||
if conn.rdpConn != nil {
|
||||
log.Debug("Closing TCP connection")
|
||||
if err := conn.rdpConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TCP connection: %v", err)
|
||||
}
|
||||
conn.rdpConn = nil
|
||||
}
|
||||
p.mu.Lock()
|
||||
delete(p.activeConnections, conn.id)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
||||
if conn.wsHandlers.Get("receiveFromGo").Truthy() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer"))
|
||||
} else if conn.wsHandlers.Get("send").Truthy() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
|
||||
}
|
||||
}
|
||||
251
client/wasm/internal/rdp/rdcleanpath_handlers.go
Normal file
251
client/wasm/internal/rdp/rdcleanpath_handlers.go
Normal file
@@ -0,0 +1,251 @@
|
||||
//go:build js
|
||||
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"io"
|
||||
"syscall/js"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
|
||||
|
||||
if pdu.Version != RDCleanPathVersion {
|
||||
p.sendRDCleanPathError(conn, "Unsupported version")
|
||||
return
|
||||
}
|
||||
|
||||
destination := conn.destination
|
||||
if pdu.Destination != "" {
|
||||
destination = pdu.Destination
|
||||
}
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||
p.sendRDCleanPathError(conn, "Connection failed")
|
||||
p.cleanupConnection(conn)
|
||||
return
|
||||
}
|
||||
conn.rdpConn = rdpConn
|
||||
|
||||
// RDP always starts with X.224 negotiation, then determines if TLS is needed
|
||||
// Modern RDP (since Windows Vista/2008) typically requires TLS
|
||||
// The X.224 Connection Confirm response will indicate if TLS is required
|
||||
// For now, we'll attempt TLS for all connections as it's the modern default
|
||||
p.setupTLSConnection(conn, pdu)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
var x224Response []byte
|
||||
if len(pdu.X224ConnectionPDU) > 0 {
|
||||
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
|
||||
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to forward X.224")
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := conn.rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
|
||||
return
|
||||
}
|
||||
x224Response = response[:n]
|
||||
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
|
||||
}
|
||||
|
||||
tlsConfig := p.getTLSConfigWithValidation(conn)
|
||||
|
||||
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
|
||||
conn.tlsConn = tlsConn
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
log.Errorf("TLS handshake failed: %v", err)
|
||||
p.sendRDCleanPathError(conn, "TLS handshake failed")
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("TLS handshake successful")
|
||||
|
||||
// Certificate validation happens during handshake via VerifyConnection callback
|
||||
var certChain [][]byte
|
||||
connState := tlsConn.ConnectionState()
|
||||
if len(connState.PeerCertificates) > 0 {
|
||||
for _, cert := range connState.PeerCertificates {
|
||||
certChain = append(certChain, cert.Raw)
|
||||
}
|
||||
log.Debugf("Extracted %d certificates from TLS connection", len(certChain))
|
||||
}
|
||||
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
ServerAddr: conn.destination,
|
||||
ServerCertChain: certChain,
|
||||
}
|
||||
|
||||
if len(x224Response) > 0 {
|
||||
responsePDU.X224ConnectionPDU = x224Response
|
||||
}
|
||||
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
|
||||
log.Debug("Starting TLS forwarding")
|
||||
go p.forwardConnToWS(conn, conn.tlsConn, "TLS")
|
||||
go p.forwardWSToConn(conn, conn.tlsConn, "TLS")
|
||||
|
||||
<-conn.ctx.Done()
|
||||
log.Debug("TLS connection context done, cleaning up")
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
if len(pdu.X224ConnectionPDU) > 0 {
|
||||
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
|
||||
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to forward X.224")
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := conn.rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
|
||||
return
|
||||
}
|
||||
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
X224ConnectionPDU: response[:n],
|
||||
ServerAddr: conn.destination,
|
||||
}
|
||||
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
} else {
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
ServerAddr: conn.destination,
|
||||
}
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
}
|
||||
|
||||
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
|
||||
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
|
||||
|
||||
<-conn.ctx.Done()
|
||||
log.Debug("TCP connection context done, cleaning up")
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal RDCleanPath PDU: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Sending RDCleanPath PDU response (%d bytes)", len(data))
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) {
|
||||
pdu := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: []byte(errorMsg),
|
||||
}
|
||||
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal error PDU: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
|
||||
msgChan := make(chan []byte)
|
||||
errChan := make(chan error)
|
||||
|
||||
handler := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
if len(args) < 1 {
|
||||
errChan <- io.EOF
|
||||
return nil
|
||||
}
|
||||
|
||||
data := args[0]
|
||||
if data.InstanceOf(js.Global().Get("Uint8Array")) {
|
||||
length := data.Get("length").Int()
|
||||
bytes := make([]byte, length)
|
||||
js.CopyBytesToGo(bytes, data)
|
||||
msgChan <- bytes
|
||||
}
|
||||
return nil
|
||||
})
|
||||
defer handler.Release()
|
||||
|
||||
conn.wsHandlers.Set("onceGoMessage", handler)
|
||||
|
||||
select {
|
||||
case msg := <-msgChan:
|
||||
return msg, nil
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-conn.ctx.Done():
|
||||
return nil, conn.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) forwardWSToConn(conn *proxyConnection, dst io.Writer, connType string) {
|
||||
for {
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := p.readWebSocketMessage(conn)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Errorf("Failed to read from WebSocket: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
_, err = dst.Write(msg)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write to %s: %v", connType, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) forwardConnToWS(conn *proxyConnection, src io.Reader, connType string) {
|
||||
buffer := make([]byte, 32*1024)
|
||||
|
||||
for {
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
n, err := src.Read(buffer)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Errorf("Failed to read from %s: %v", connType, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
p.sendToWebSocket(conn, buffer[:n])
|
||||
}
|
||||
}
|
||||
}
|
||||
213
client/wasm/internal/ssh/client.go
Normal file
213
client/wasm/internal/ssh/client.go
Normal file
@@ -0,0 +1,213 @@
|
||||
//go:build js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
)
|
||||
|
||||
const (
|
||||
sshDialTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
func closeWithLog(c io.Closer, resource string) {
|
||||
if c != nil {
|
||||
if err := c.Close(); err != nil {
|
||||
logrus.Debugf("Failed to close %s: %v", resource, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
nbClient *netbird.Client
|
||||
sshClient *ssh.Client
|
||||
session *ssh.Session
|
||||
stdin io.WriteCloser
|
||||
stdout io.Reader
|
||||
stderr io.Reader
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewClient creates a new SSH client
|
||||
func NewClient(nbClient *netbird.Client) *Client {
|
||||
return &Client{
|
||||
nbClient: nbClient,
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes an SSH connection through NetBird network
|
||||
func (c *Client) Connect(host string, port int, username string) error {
|
||||
addr := fmt.Sprintf("%s:%d", host, port)
|
||||
logrus.Infof("SSH: Connecting to %s as %s", addr, username)
|
||||
|
||||
var authMethods []ssh.AuthMethod
|
||||
|
||||
nbConfig, err := c.nbClient.GetConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get NetBird config: %w", err)
|
||||
}
|
||||
if nbConfig.SSHKey == "" {
|
||||
return fmt.Errorf("no NetBird SSH key available - key should be generated during client initialization")
|
||||
}
|
||||
|
||||
signer, err := parseSSHPrivateKey([]byte(nbConfig.SSHKey))
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse NetBird SSH private key: %w", err)
|
||||
}
|
||||
|
||||
pubKey := signer.PublicKey()
|
||||
logrus.Infof("SSH: Using NetBird key authentication with public key type: %s", pubKey.Type())
|
||||
|
||||
authMethods = append(authMethods, ssh.PublicKeys(signer))
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
User: username,
|
||||
Auth: authMethods,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: sshDialTimeout,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := c.nbClient.Dial(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial %s: %w", addr, err)
|
||||
}
|
||||
|
||||
sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
|
||||
if err != nil {
|
||||
closeWithLog(conn, "connection after handshake error")
|
||||
return fmt.Errorf("SSH handshake: %w", err)
|
||||
}
|
||||
|
||||
c.sshClient = ssh.NewClient(sshConn, chans, reqs)
|
||||
logrus.Infof("SSH: Connected to %s", addr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartSession starts an SSH session with PTY
|
||||
func (c *Client) StartSession(cols, rows int) error {
|
||||
if c.sshClient == nil {
|
||||
return fmt.Errorf("SSH client not connected")
|
||||
}
|
||||
|
||||
session, err := c.sshClient.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.session = session
|
||||
|
||||
modes := ssh.TerminalModes{
|
||||
ssh.ECHO: 1,
|
||||
ssh.TTY_OP_ISPEED: 14400,
|
||||
ssh.TTY_OP_OSPEED: 14400,
|
||||
ssh.VINTR: 3,
|
||||
ssh.VQUIT: 28,
|
||||
ssh.VERASE: 127,
|
||||
}
|
||||
|
||||
if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
|
||||
closeWithLog(session, "session after PTY error")
|
||||
return fmt.Errorf("PTY request: %w", err)
|
||||
}
|
||||
|
||||
c.stdin, err = session.StdinPipe()
|
||||
if err != nil {
|
||||
closeWithLog(session, "session after stdin error")
|
||||
return fmt.Errorf("get stdin: %w", err)
|
||||
}
|
||||
|
||||
c.stdout, err = session.StdoutPipe()
|
||||
if err != nil {
|
||||
closeWithLog(session, "session after stdout error")
|
||||
return fmt.Errorf("get stdout: %w", err)
|
||||
}
|
||||
|
||||
c.stderr, err = session.StderrPipe()
|
||||
if err != nil {
|
||||
closeWithLog(session, "session after stderr error")
|
||||
return fmt.Errorf("get stderr: %w", err)
|
||||
}
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
closeWithLog(session, "session after shell error")
|
||||
return fmt.Errorf("start shell: %w", err)
|
||||
}
|
||||
|
||||
logrus.Info("SSH: Session started with PTY")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write sends data to the SSH session
|
||||
func (c *Client) Write(data []byte) (int, error) {
|
||||
c.mu.RLock()
|
||||
stdin := c.stdin
|
||||
c.mu.RUnlock()
|
||||
|
||||
if stdin == nil {
|
||||
return 0, fmt.Errorf("SSH session not started")
|
||||
}
|
||||
return stdin.Write(data)
|
||||
}
|
||||
|
||||
// Read reads data from the SSH session
|
||||
func (c *Client) Read(buffer []byte) (int, error) {
|
||||
c.mu.RLock()
|
||||
stdout := c.stdout
|
||||
c.mu.RUnlock()
|
||||
|
||||
if stdout == nil {
|
||||
return 0, fmt.Errorf("SSH session not started")
|
||||
}
|
||||
return stdout.Read(buffer)
|
||||
}
|
||||
|
||||
// Resize updates the terminal size
|
||||
func (c *Client) Resize(cols, rows int) error {
|
||||
c.mu.RLock()
|
||||
session := c.session
|
||||
c.mu.RUnlock()
|
||||
|
||||
if session == nil {
|
||||
return fmt.Errorf("SSH session not started")
|
||||
}
|
||||
return session.WindowChange(rows, cols)
|
||||
}
|
||||
|
||||
// Close closes the SSH connection
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.session != nil {
|
||||
closeWithLog(c.session, "SSH session")
|
||||
c.session = nil
|
||||
}
|
||||
if c.stdin != nil {
|
||||
closeWithLog(c.stdin, "stdin")
|
||||
c.stdin = nil
|
||||
}
|
||||
c.stdout = nil
|
||||
c.stderr = nil
|
||||
|
||||
if c.sshClient != nil {
|
||||
err := c.sshClient.Close()
|
||||
c.sshClient = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
78
client/wasm/internal/ssh/handlers.go
Normal file
78
client/wasm/internal/ssh/handlers.go
Normal file
@@ -0,0 +1,78 @@
|
||||
//go:build js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"io"
|
||||
"syscall/js"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// CreateJSInterface creates a JavaScript interface for the SSH client
|
||||
func CreateJSInterface(client *Client) js.Value {
|
||||
jsInterface := js.Global().Get("Object").Call("create", js.Null())
|
||||
|
||||
jsInterface.Set("write", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
|
||||
data := args[0]
|
||||
var bytes []byte
|
||||
|
||||
if data.Type() == js.TypeString {
|
||||
bytes = []byte(data.String())
|
||||
} else {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(data)
|
||||
length := uint8Array.Get("length").Int()
|
||||
bytes = make([]byte, length)
|
||||
js.CopyBytesToGo(bytes, uint8Array)
|
||||
}
|
||||
|
||||
_, err := client.Write(bytes)
|
||||
return js.ValueOf(err == nil)
|
||||
}))
|
||||
|
||||
jsInterface.Set("resize", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf(false)
|
||||
}
|
||||
cols := args[0].Int()
|
||||
rows := args[1].Int()
|
||||
err := client.Resize(cols, rows)
|
||||
return js.ValueOf(err == nil)
|
||||
}))
|
||||
|
||||
jsInterface.Set("close", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
client.Close()
|
||||
return js.Undefined()
|
||||
}))
|
||||
|
||||
go readLoop(client, jsInterface)
|
||||
|
||||
return jsInterface
|
||||
}
|
||||
|
||||
func readLoop(client *Client, jsInterface js.Value) {
|
||||
buffer := make([]byte, 4096)
|
||||
for {
|
||||
n, err := client.Read(buffer)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
logrus.Debugf("SSH read error: %v", err)
|
||||
}
|
||||
if onclose := jsInterface.Get("onclose"); !onclose.IsUndefined() {
|
||||
onclose.Invoke()
|
||||
}
|
||||
client.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if ondata := jsInterface.Get("ondata"); !ondata.IsUndefined() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(n)
|
||||
js.CopyBytesToJS(uint8Array, buffer[:n])
|
||||
ondata.Invoke(uint8Array)
|
||||
}
|
||||
}
|
||||
}
|
||||
50
client/wasm/internal/ssh/key.go
Normal file
50
client/wasm/internal/ssh/key.go
Normal file
@@ -0,0 +1,50 @@
|
||||
//go:build js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// parseSSHPrivateKey parses a private key in either SSH or PKCS8 format
|
||||
func parseSSHPrivateKey(keyPEM []byte) (ssh.Signer, error) {
|
||||
keyStr := string(keyPEM)
|
||||
if !strings.Contains(keyStr, "-----BEGIN") {
|
||||
keyPEM = []byte("-----BEGIN PRIVATE KEY-----\n" + keyStr + "\n-----END PRIVATE KEY-----")
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(keyPEM)
|
||||
if err == nil {
|
||||
return signer, nil
|
||||
}
|
||||
logrus.Debugf("SSH: Failed to parse as SSH format: %v", err)
|
||||
|
||||
block, _ := pem.Decode(keyPEM)
|
||||
if block == nil {
|
||||
keyPreview := string(keyPEM)
|
||||
if len(keyPreview) > 100 {
|
||||
keyPreview = keyPreview[:100]
|
||||
}
|
||||
return nil, fmt.Errorf("decode PEM block from key: %s", keyPreview)
|
||||
}
|
||||
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
logrus.Debugf("SSH: Failed to parse as PKCS8: %v", err)
|
||||
if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
|
||||
return ssh.NewSignerFromKey(rsaKey)
|
||||
}
|
||||
if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
|
||||
return ssh.NewSignerFromKey(ecKey)
|
||||
}
|
||||
return nil, fmt.Errorf("parse private key: %w", err)
|
||||
}
|
||||
|
||||
return ssh.NewSignerFromKey(key)
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !js
|
||||
|
||||
package encryption
|
||||
|
||||
import (
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
nbgrpc "github.com/netbirdio/netbird/client/grpc"
|
||||
"github.com/netbirdio/netbird/flow/proto"
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
"github.com/netbirdio/netbird/util/wsproxy"
|
||||
)
|
||||
|
||||
type GRPCClient struct {
|
||||
@@ -38,7 +39,8 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
|
||||
return nil, fmt.Errorf("parsing url: %w", err)
|
||||
}
|
||||
var opts []grpc.DialOption
|
||||
if parsedURL.Scheme == "https" {
|
||||
tlsEnabled := parsedURL.Scheme == "https"
|
||||
if tlsEnabled {
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil || certPool == nil {
|
||||
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
||||
@@ -53,7 +55,7 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
|
||||
}
|
||||
|
||||
opts = append(opts,
|
||||
nbgrpc.WithCustomDialer(),
|
||||
nbgrpc.WithCustomDialer(tlsEnabled, wsproxy.FlowComponent),
|
||||
grpc.WithIdleTimeout(interval*2),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 30 * time.Second,
|
||||
|
||||
4
go.mod
4
go.mod
@@ -37,7 +37,7 @@ require (
|
||||
github.com/c-robinson/iplib v1.0.3
|
||||
github.com/caddyserver/certmagic v0.21.3
|
||||
github.com/cilium/ebpf v0.15.0
|
||||
github.com/coder/websocket v1.8.12
|
||||
github.com/coder/websocket v1.8.13
|
||||
github.com/coreos/go-iptables v0.7.0
|
||||
github.com/creack/pty v1.1.18
|
||||
github.com/eko/gocache/lib/v4 v4.2.0
|
||||
@@ -102,6 +102,7 @@ require (
|
||||
goauthentik.io/api/v3 v3.2023051.3
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
|
||||
golang.org/x/mod v0.25.0
|
||||
golang.org/x/net v0.42.0
|
||||
golang.org/x/oauth2 v0.28.0
|
||||
golang.org/x/sync v0.16.0
|
||||
@@ -243,7 +244,6 @@ require (
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/image v0.18.0 // indirect
|
||||
golang.org/x/mod v0.25.0 // indirect
|
||||
golang.org/x/text v0.27.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.org/x/tools v0.34.0 // indirect
|
||||
|
||||
4
go.sum
4
go.sum
@@ -140,8 +140,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk
|
||||
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
||||
github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
||||
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
||||
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
|
||||
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
||||
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
|
||||
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
||||
github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII=
|
||||
github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
|
||||
@@ -45,6 +45,9 @@ services:
|
||||
- $SIGNAL_VOLUMENAME:/var/lib/netbird
|
||||
labels:
|
||||
- traefik.enable=true
|
||||
- traefik.http.routers.netbird-wsproxy-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/signal`)
|
||||
- traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal
|
||||
- traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=10000
|
||||
- traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
|
||||
- traefik.http.services.netbird-signal.loadbalancer.server.port=10000
|
||||
- traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c
|
||||
@@ -87,7 +90,9 @@ services:
|
||||
- traefik.http.routers.netbird-api.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/api`)
|
||||
- traefik.http.routers.netbird-api.service=netbird-api
|
||||
- traefik.http.services.netbird-api.loadbalancer.server.port=33073
|
||||
|
||||
- traefik.http.routers.netbird-wsproxy-mgmt.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/management`)
|
||||
- traefik.http.routers.netbird-wsproxy-mgmt.service=netbird-wsproxy-mgmt
|
||||
- traefik.http.services.netbird-wsproxy-mgmt.loadbalancer.server.port=33073
|
||||
- traefik.http.routers.netbird-management.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/management.ManagementService/`)
|
||||
- traefik.http.routers.netbird-management.service=netbird-management
|
||||
- traefik.http.services.netbird-management.loadbalancer.server.port=33073
|
||||
|
||||
@@ -621,9 +621,11 @@ renderCaddyfile() {
|
||||
# relay
|
||||
reverse_proxy /relay* relay:80
|
||||
# Signal
|
||||
reverse_proxy /ws-proxy/signal* signal:10000
|
||||
reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
|
||||
# Management
|
||||
reverse_proxy /api/* management:80
|
||||
reverse_proxy /ws-proxy/management* management:80
|
||||
reverse_proxy /management.ManagementService/* h2c://management:80
|
||||
# Zitadel
|
||||
reverse_proxy /zitadel.admin.v1.AdminService/* h2c://zitadel:8080
|
||||
|
||||
@@ -20,6 +20,10 @@ upstream management {
|
||||
# insert the grpc+http port of your management container here
|
||||
server 127.0.0.1:8012;
|
||||
}
|
||||
upstream relay {
|
||||
# insert the port of your relay container here
|
||||
server 127.0.0.1:33080;
|
||||
}
|
||||
|
||||
server {
|
||||
# HTTP server config
|
||||
@@ -52,6 +56,14 @@ server {
|
||||
location / {
|
||||
proxy_pass http://dashboard;
|
||||
}
|
||||
# Proxy Signal wsproxy endpoint
|
||||
location /ws-proxy/signal {
|
||||
proxy_pass http://signal;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "Upgrade";
|
||||
proxy_set_header Host $host;
|
||||
}
|
||||
# Proxy Signal
|
||||
location /signalexchange.SignalExchange/ {
|
||||
grpc_pass grpc://signal;
|
||||
@@ -64,6 +76,14 @@ server {
|
||||
location /api {
|
||||
proxy_pass http://management;
|
||||
}
|
||||
# Proxy Management wsproxy endpoint
|
||||
location /ws-proxy/management {
|
||||
proxy_pass http://management;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "Upgrade";
|
||||
proxy_set_header Host $host;
|
||||
}
|
||||
# Proxy Management grpc endpoint
|
||||
location /management.ManagementService/ {
|
||||
grpc_pass grpc://management;
|
||||
@@ -72,6 +92,14 @@ server {
|
||||
grpc_send_timeout 1d;
|
||||
grpc_socket_keepalive on;
|
||||
}
|
||||
# Proxy Relay
|
||||
location /relay {
|
||||
proxy_pass http://relay;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "Upgrade";
|
||||
proxy_set_header Host $host;
|
||||
}
|
||||
|
||||
ssl_certificate /etc/ssl/certs/ssl-cert-snakeoil.pem;
|
||||
ssl_certificate_key /etc/ssl/certs/ssl-cert-snakeoil.pem;
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMgmtDataDir = "/var/lib/netbird/"
|
||||
defaultMgmtConfigDir = "/etc/netbird"
|
||||
@@ -16,3 +22,5 @@ const (
|
||||
|
||||
defaultSingleAccModeDomain = "netbird.selfhosted"
|
||||
)
|
||||
|
||||
var defaultWSProxyAddr = fmt.Sprintf("127.0.0.1:%d", server.ManagementLegacyPort)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@@ -26,11 +27,11 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var newServer = func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server {
|
||||
return server.NewServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled)
|
||||
var newServer = func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool, wsProxyAddr string) server.Server {
|
||||
return server.NewServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled, wsProxyAddr)
|
||||
}
|
||||
|
||||
func SetNewServer(fn func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) server.Server) {
|
||||
func SetNewServer(fn func(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort int, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool, wsProxyAddr string) server.Server) {
|
||||
newServer = fn
|
||||
}
|
||||
|
||||
@@ -82,6 +83,10 @@ var (
|
||||
return fmt.Errorf("failed parsing the provided dns-domain. Valid status: %t, Length: %d", valid, len(dnsDomain))
|
||||
}
|
||||
|
||||
if _, _, err := net.SplitHostPort(wsProxyAddr); err != nil {
|
||||
return fmt.Errorf("invalid ws-proxy-backend-addr format: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
@@ -108,7 +113,7 @@ var (
|
||||
mgmtSingleAccModeDomain = ""
|
||||
}
|
||||
|
||||
srv := newServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled)
|
||||
srv := newServer(config, dnsDomain, mgmtSingleAccModeDomain, mgmtPort, mgmtMetricsPort, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled, wsProxyAddr)
|
||||
go func() {
|
||||
if err := srv.Start(cmd.Context()); err != nil {
|
||||
log.Fatalf("Server error: %v", err)
|
||||
|
||||
@@ -31,6 +31,7 @@ var (
|
||||
mgmtSingleAccModeDomain string
|
||||
certFile string
|
||||
certKey string
|
||||
wsProxyAddr string
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird-mgmt",
|
||||
@@ -57,6 +58,7 @@ func init() {
|
||||
mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
|
||||
mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
|
||||
mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location")
|
||||
mgmtCmd.Flags().StringVar(&wsProxyAddr, "ws-proxy-backend-addr", defaultWSProxyAddr, "WebSocket proxy backend address in host:port format that the proxy will dial")
|
||||
mgmtCmd.Flags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. datadir) have a precedence over configuration from this file")
|
||||
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
|
||||
mgmtCmd.Flags().StringVar(&mgmtSingleAccModeDomain, "single-account-mode-domain", defaultSingleAccModeDomain, "Enables single account mode. This means that all the users will be under the same account grouped by the specified domain. If the installation has more than one account, the property is ineffective. Enabled by default with the default domain "+defaultSingleAccModeDomain)
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
)
|
||||
|
||||
func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager {
|
||||
@@ -56,8 +58,8 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) EphemeralManager() *server.EphemeralManager {
|
||||
return Create(s, func() *server.EphemeralManager {
|
||||
return server.NewEphemeralManager(s.Store(), s.AccountManager())
|
||||
func (s *BaseServer) EphemeralManager() ephemeral.Manager {
|
||||
return Create(s, func() ephemeral.Manager {
|
||||
return manager.NewEphemeralManager(s.Store(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -65,6 +65,10 @@ func (s *BaseServer) AccountManager() account.Manager {
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create account manager: %v", err)
|
||||
}
|
||||
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
accountManager.SetEphemeralManager(s.EphemeralManager())
|
||||
})
|
||||
return accountManager
|
||||
})
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
@@ -22,6 +23,8 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/metrics"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/util/wsproxy"
|
||||
wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -54,6 +57,7 @@ type BaseServer struct {
|
||||
mgmtSingleAccModeDomain string
|
||||
mgmtMetricsPort int
|
||||
mgmtPort int
|
||||
wsProxyAddr string
|
||||
|
||||
listener net.Listener
|
||||
certManager *autocert.Manager
|
||||
@@ -65,7 +69,7 @@ type BaseServer struct {
|
||||
}
|
||||
|
||||
// NewServer initializes and configures a new Server instance
|
||||
func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer {
|
||||
func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool, wsProxyAddr string) *BaseServer {
|
||||
return &BaseServer{
|
||||
config: config,
|
||||
container: make(map[string]any),
|
||||
@@ -76,6 +80,7 @@ func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain strin
|
||||
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||
mgmtPort: mgmtPort,
|
||||
mgmtMetricsPort: mgmtMetricsPort,
|
||||
wsProxyAddr: wsProxyAddr,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,12 +97,6 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
s.PeersManager()
|
||||
s.GeoLocationManager()
|
||||
|
||||
for _, fn := range s.afterInit {
|
||||
if fn != nil {
|
||||
fn(s)
|
||||
}
|
||||
}
|
||||
|
||||
err := s.Metrics().Expose(srvCtx, s.mgmtMetricsPort, "/metrics")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to expose metrics: %v", err)
|
||||
@@ -147,7 +146,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
||||
}
|
||||
|
||||
rootHandler := handlerFunc(s.GRPCServer(), s.APIHandler())
|
||||
rootHandler := s.handlerFunc(s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter())
|
||||
switch {
|
||||
case s.certManager != nil:
|
||||
// a call to certManager.Listener() always creates a new listener so we do it once
|
||||
@@ -176,8 +175,15 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
for _, fn := range s.afterInit {
|
||||
if fn != nil {
|
||||
fn(s)
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
|
||||
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
|
||||
log.WithContext(ctx).Infof("WebSocket proxy configured to forward to: %s", s.wsProxyAddr)
|
||||
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
|
||||
|
||||
s.update = version.NewUpdate("nb/management")
|
||||
@@ -247,13 +253,17 @@ func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config)
|
||||
return util.DirectWriteJson(ctx, path, config)
|
||||
}
|
||||
|
||||
func handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler) http.Handler {
|
||||
func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
|
||||
wsProxy := wsproxyserver.New(s.wsProxyAddr, wsproxyserver.WithOTelMeter(meter))
|
||||
|
||||
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
grpcHeader := strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") ||
|
||||
strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")
|
||||
if request.ProtoMajor == 2 && grpcHeader {
|
||||
switch {
|
||||
case request.ProtoMajor == 2 && (strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc") ||
|
||||
strings.HasPrefix(request.Header.Get("Content-Type"), "application/grpc+proto")):
|
||||
gRPCHandler.ServeHTTP(writer, request)
|
||||
} else {
|
||||
case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
|
||||
wsProxy.Handler().ServeHTTP(writer, request)
|
||||
default:
|
||||
httpHandler.ServeHTTP(writer, request)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -35,6 +35,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
@@ -74,6 +75,7 @@ type DefaultAccountManager struct {
|
||||
ctx context.Context
|
||||
eventStore activity.Store
|
||||
geo geolocation.Geolocation
|
||||
ephemeralManager ephemeral.Manager
|
||||
|
||||
requestBuffer *AccountRequestBuffer
|
||||
|
||||
@@ -261,6 +263,10 @@ func BuildManager(
|
||||
return am, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) {
|
||||
am.ephemeralManager = em
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) startWarmup(ctx context.Context) {
|
||||
var initialInterval int64
|
||||
intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS")
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -56,7 +57,7 @@ type Manager interface {
|
||||
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
|
||||
AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
|
||||
@@ -125,5 +126,6 @@ type Manager interface {
|
||||
UpdateToPrimaryAccount(ctx context.Context, accountId string) error
|
||||
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
SetEphemeralManager(em ephemeral.Manager)
|
||||
AllowSync(string, uint64) bool
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account
|
||||
setupKey = key.Key
|
||||
}
|
||||
|
||||
_, _, _, err := manager.AddPeer(context.Background(), setupKey, userID, peer)
|
||||
_, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false)
|
||||
if err != nil {
|
||||
t.Error("expected to add new peer successfully after creating new account, but failed", err)
|
||||
}
|
||||
@@ -1048,10 +1048,10 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
}
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Errorf("expecting peer to be added, got failure %v", err)
|
||||
return
|
||||
@@ -1112,10 +1112,10 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
expectedUserID := userID
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Errorf("expecting peer to be added, got failure %v, account users: %v", err, account.CreatedBy)
|
||||
return
|
||||
@@ -1429,10 +1429,10 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
|
||||
peerKey := key.PublicKey().String()
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: peerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: peerKey},
|
||||
})
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Errorf("expecting peer to be added, got failure %v", err)
|
||||
return
|
||||
@@ -1805,11 +1805,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
LoginExpirationEnabled: true,
|
||||
})
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
@@ -1861,11 +1861,11 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
LoginExpirationEnabled: true,
|
||||
})
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
@@ -1904,11 +1904,11 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
LoginExpirationEnabled: true,
|
||||
})
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
|
||||
@@ -2952,14 +2952,14 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account,
|
||||
}
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
Status: &nbpeer.PeerStatus{
|
||||
Connected: true,
|
||||
LastSeen: time.Now().UTC(),
|
||||
},
|
||||
})
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("expecting peer to be added, got failure %v", err)
|
||||
}
|
||||
@@ -3552,16 +3552,16 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
|
||||
key2, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key1.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
|
||||
})
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer1")
|
||||
|
||||
peer2, _, _, err := manager.AddPeer(context.Background(), "", userID, &nbpeer.Peer{
|
||||
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: key2.PublicKey().String(),
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
|
||||
})
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer2")
|
||||
|
||||
t.Run("update peer IP successfully", func(t *testing.T) {
|
||||
|
||||
@@ -6,9 +6,11 @@ import (
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/mod/semver"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -18,6 +20,13 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
dnsForwarderPort = 22054
|
||||
oldForwarderPort = 5353
|
||||
)
|
||||
|
||||
const dnsForwarderPortMinVersion = "v0.59.0"
|
||||
|
||||
// DNSConfigCache is a thread-safe cache for DNS configuration components
|
||||
type DNSConfigCache struct {
|
||||
NameServerGroups sync.Map
|
||||
@@ -183,12 +192,45 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
|
||||
return validateGroups(settings.DisabledManagementGroups, groups)
|
||||
}
|
||||
|
||||
// computeForwarderPort checks if all peers in the account have updated to a specific version or newer.
|
||||
// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0.
|
||||
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 {
|
||||
if len(peers) == 0 {
|
||||
return oldForwarderPort
|
||||
}
|
||||
|
||||
reqVer := semver.Canonical(requiredVersion)
|
||||
|
||||
// Check if all peers have the required version or newer
|
||||
for _, peer := range peers {
|
||||
|
||||
// Development version is always supported
|
||||
if peer.Meta.WtVersion == "development" {
|
||||
continue
|
||||
}
|
||||
peerVersion := semver.Canonical("v" + peer.Meta.WtVersion)
|
||||
if peerVersion == "" {
|
||||
// If any peer doesn't have version info, return 0
|
||||
return oldForwarderPort
|
||||
}
|
||||
|
||||
// Compare versions
|
||||
if semver.Compare(peerVersion, reqVer) < 0 {
|
||||
return oldForwarderPort
|
||||
}
|
||||
}
|
||||
|
||||
// All peers have the required version or newer
|
||||
return dnsForwarderPort
|
||||
}
|
||||
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache, forwardPort int64) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
ServiceEnable: update.ServiceEnable,
|
||||
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
|
||||
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
|
||||
ForwarderPort: forwardPort,
|
||||
}
|
||||
|
||||
for _, zone := range update.CustomZones {
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -281,11 +280,11 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1)
|
||||
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer1, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2)
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer2, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -324,13 +323,13 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account.NameServerGroups[dnsNSGroup1] = &dns.NameServerGroup{
|
||||
account.NameServerGroups[dnsNSGroup1] = &nbdns.NameServerGroup{
|
||||
ID: dnsNSGroup1,
|
||||
Name: "ns-group-1",
|
||||
NameServers: []dns.NameServer{{
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr(savedPeer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
Port: dns.DefaultDNSPort,
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
Primary: true,
|
||||
Enabled: true,
|
||||
@@ -395,7 +394,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
toProtocolDNSConfig(testData, cache)
|
||||
toProtocolDNSConfig(testData, cache, dnsForwarderPort)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -403,7 +402,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache := &DNSConfigCache{}
|
||||
toProtocolDNSConfig(testData, cache)
|
||||
toProtocolDNSConfig(testData, cache, dnsForwarderPort)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -456,13 +455,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
}
|
||||
|
||||
// First run with config1
|
||||
result1 := toProtocolDNSConfig(config1, &cache)
|
||||
result1 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort)
|
||||
|
||||
// Second run with config2
|
||||
result2 := toProtocolDNSConfig(config2, &cache)
|
||||
result2 := toProtocolDNSConfig(config2, &cache, dnsForwarderPort)
|
||||
|
||||
// Third run with config1 again
|
||||
result3 := toProtocolDNSConfig(config1, &cache)
|
||||
result3 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort)
|
||||
|
||||
// Verify that result1 and result3 are identical
|
||||
if !reflect.DeepEqual(result1, result3) {
|
||||
@@ -483,6 +482,107 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeForwarderPort(t *testing.T) {
|
||||
// Test with empty peers list
|
||||
peers := []*nbpeer.Peer{}
|
||||
result := computeForwarderPort(peers, "v0.59.0")
|
||||
if result != oldForwarderPort {
|
||||
t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have old versions
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.57.0",
|
||||
},
|
||||
},
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.26.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != oldForwarderPort {
|
||||
t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have new versions
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.59.0",
|
||||
},
|
||||
},
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.59.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != dnsForwarderPort {
|
||||
t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have mixed versions
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.59.0",
|
||||
},
|
||||
},
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "0.57.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != oldForwarderPort {
|
||||
t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have empty version
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != oldForwarderPort {
|
||||
t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result)
|
||||
}
|
||||
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "development",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result == oldForwarderPort {
|
||||
t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result)
|
||||
}
|
||||
|
||||
// Test with peers that have unknown version string
|
||||
peers = []*nbpeer.Peer{
|
||||
{
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
WtVersion: "unknown",
|
||||
},
|
||||
},
|
||||
}
|
||||
result = computeForwarderPort(peers, "v0.59.0")
|
||||
if result != oldForwarderPort {
|
||||
t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
@@ -534,10 +634,10 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
}()
|
||||
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "ns-group", "ns-group", []dns.NameServer{{
|
||||
context.Background(), account.Id, "ns-group", "ns-group", []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
Port: dns.DefaultDNSPort,
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupB"},
|
||||
true, []string{}, true, userID, false,
|
||||
@@ -567,10 +667,10 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
}()
|
||||
|
||||
_, err = manager.CreateNameServerGroup(
|
||||
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
|
||||
context.Background(), account.Id, "ns-group-1", "ns-group-1", []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr(peer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
Port: dns.DefaultDNSPort,
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
[]string{"groupA"},
|
||||
true, []string{}, true, userID, false,
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -55,7 +56,7 @@ type GRPCServer struct {
|
||||
config *nbconfig.Config
|
||||
secretsManager SecretsManager
|
||||
appMetrics telemetry.AppMetrics
|
||||
ephemeralManager *EphemeralManager
|
||||
ephemeralManager ephemeral.Manager
|
||||
peerLocks sync.Map
|
||||
authManager auth.Manager
|
||||
|
||||
@@ -73,7 +74,7 @@ func NewServer(
|
||||
peersUpdateManager *PeersUpdateManager,
|
||||
secretsManager SecretsManager,
|
||||
appMetrics telemetry.AppMetrics,
|
||||
ephemeralManager *EphemeralManager,
|
||||
ephemeralManager ephemeral.Manager,
|
||||
authManager auth.Manager,
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||
) (*GRPCServer, error) {
|
||||
@@ -714,13 +715,13 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
|
||||
}
|
||||
}
|
||||
|
||||
func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse {
|
||||
func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
||||
response := &proto.SyncResponse{
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
@@ -731,11 +732,11 @@ func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.P
|
||||
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
|
||||
allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
allPeers = appendRemotePeerConfig(allPeers, networkMap.Peers, dnsName)
|
||||
response.RemotePeers = allPeers
|
||||
response.NetworkMap.RemotePeers = allPeers
|
||||
response.RemotePeersIsEmpty = len(allPeers) == 0
|
||||
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName)
|
||||
response.RemotePeers = remotePeers
|
||||
response.NetworkMap.RemotePeers = remotePeers
|
||||
response.RemotePeersIsEmpty = len(remotePeers) == 0
|
||||
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||
|
||||
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
|
||||
@@ -807,7 +808,14 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
|
||||
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
||||
}
|
||||
|
||||
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups)
|
||||
// Get all peers in the account for forwarder port computation
|
||||
allPeers, err := s.accountManager.GetStore().GetAccountPeers(ctx, store.LockingStrengthNone, peer.AccountID, "", "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("get account peers: %w", err)
|
||||
}
|
||||
dnsFwdPort := computeForwarderPort(allPeers, dnsForwarderPortMinVersion)
|
||||
|
||||
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||
if err != nil {
|
||||
|
||||
@@ -32,6 +32,7 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
|
||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/temporary-access", peersHandler.CreateTemporaryAccess).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
// NewHandler creates a new peers Handler
|
||||
@@ -318,6 +319,88 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||
}
|
||||
|
||||
func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PeerTemporaryAccessRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
newPeer := &nbpeer.Peer{}
|
||||
newPeer.FromAPITemporaryAccessRequest(&req)
|
||||
|
||||
targetPeer, err := h.accountManager.GetPeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
peer, _, _, err := h.accountManager.AddPeer(r.Context(), userAuth.AccountId, "", userAuth.UserId, newPeer, true)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
for _, rule := range req.Rules {
|
||||
protocol, portRange, err := types.ParseRuleString(rule)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
policy := &types.Policy{
|
||||
AccountID: userAuth.AccountId,
|
||||
Description: "Temporary access policy for peer " + peer.Name,
|
||||
Name: "Temporary access policy for peer " + peer.Name,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
Name: "Temporary access rule",
|
||||
Description: "Temporary access rule",
|
||||
Enabled: true,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
SourceResource: types.Resource{
|
||||
Type: types.ResourceTypePeer,
|
||||
ID: peer.ID,
|
||||
},
|
||||
DestinationResource: types.Resource{
|
||||
Type: types.ResourceTypePeer,
|
||||
ID: targetPeer.ID,
|
||||
},
|
||||
Bidirectional: false,
|
||||
Protocol: protocol,
|
||||
PortRanges: []types.RulePortRange{portRange},
|
||||
}},
|
||||
}
|
||||
|
||||
_, err = h.accountManager.SavePolicy(r.Context(), userAuth.AccountId, userAuth.UserId, policy, true)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
resp := &api.PeerTemporaryAccessResponse{
|
||||
Id: peer.ID,
|
||||
Name: peer.Name,
|
||||
Rules: req.Rules,
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
func toAccessiblePeers(netMap *types.NetworkMap, dnsDomain string) []api.AccessiblePeer {
|
||||
accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers))
|
||||
for _, p := range netMap.Peers {
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -460,7 +461,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
|
||||
|
||||
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
|
||||
ephemeralMgr := NewEphemeralManager(store, accountManager)
|
||||
ephemeralMgr := manager.NewEphemeralManager(store, accountManager)
|
||||
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
|
||||
if err != nil {
|
||||
return nil, nil, "", cleanup, err
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -228,7 +229,7 @@ func startServer(
|
||||
peersUpdateManager,
|
||||
secretsManager,
|
||||
nil,
|
||||
nil,
|
||||
&manager.EphemeralManager{},
|
||||
nil,
|
||||
server.MockIntegratedValidator{},
|
||||
)
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -41,7 +42,7 @@ type MockAccountManager struct {
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
|
||||
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
|
||||
AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error)
|
||||
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error)
|
||||
@@ -351,12 +352,14 @@ func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string
|
||||
// AddPeer mock implementation of AddPeer from server.AccountManager interface
|
||||
func (am *MockAccountManager) AddPeer(
|
||||
ctx context.Context,
|
||||
accountID string,
|
||||
setupKey string,
|
||||
userId string,
|
||||
peer *nbpeer.Peer,
|
||||
temporary bool,
|
||||
) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
if am.AddPeerFunc != nil {
|
||||
return am.AddPeerFunc(ctx, setupKey, userId, peer)
|
||||
return am.AddPeerFunc(ctx, accountID, setupKey, userId, peer, temporary)
|
||||
}
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented")
|
||||
}
|
||||
@@ -972,6 +975,11 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth n
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented")
|
||||
}
|
||||
|
||||
// SetEphemeralManager mocks SetEphemeralManager of the AccountManager interface
|
||||
func (am *MockAccountManager) SetEphemeralManager(em ephemeral.Manager) {
|
||||
// Mock implementation - does nothing
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) AllowSync(key string, hash uint64) bool {
|
||||
if am.AllowSyncFunc != nil {
|
||||
return am.AllowSyncFunc(key, hash)
|
||||
|
||||
@@ -876,11 +876,11 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer1)
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer1, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", userID, peer2)
|
||||
_, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer2, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user