Compare commits

...

10 Commits

Author SHA1 Message Date
Pascal Fischer
ad8459ea2f add mysql support [WIP] 2024-09-27 13:44:50 +02:00
Zoltan Papp
4ebf6e1c4c [client] Close the remote conn in proxy (#2626)
Port the conn close call to eBPF proxy
2024-09-25 18:50:10 +02:00
pascal-fischer
1e4a0f77e2 Add get DB method to store (#2650) 2024-09-25 18:22:27 +02:00
Viktor Liu
b51d75204b [client] Anonymize relay address in status peers view (#2640) 2024-09-24 20:58:18 +02:00
Viktor Liu
e7d52c8c95 [client] Fix error count formatting (#2641) 2024-09-24 20:57:56 +02:00
Viktor Liu
ab82302c95 [client] Remove usage of custom dialer for localhost (#2639)
* Downgrade error log level for network monitor warnings

* Do not use custom dialer for localhost
2024-09-24 12:29:15 +02:00
pascal-fischer
d47be154ea [misc] Fix ip range posture check example (#2628) 2024-09-23 10:02:03 +02:00
Bethuel Mmbaga
35c892aea3 [management] Restrict accessible peers to user-owned peers for non-admins (#2618)
* Restrict accessible peers to user-owned peers for non-admin users

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add service user test

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* reuse account from token

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* return error when peer not found

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-20 12:36:58 +03:00
Zoltan Papp
fc4b37f7bc Exit from processConnResults after all tries (#2621)
* Exit from processConnResults after all tries

If all server is unavailable then the server picker never return
because we never close the result channel.
Count the number of the results and exit when we reached the
expected size
2024-09-19 13:49:28 +02:00
Zoltan Papp
6f0fd1d1b3 - Increase queue size and drop the overflowed messages (#2617)
- Explicit close the net.Conn in user space wgProxy when close the wgProxy
- Add extra logs
2024-09-19 13:49:09 +02:00
29 changed files with 784 additions and 291 deletions

View File

@@ -805,6 +805,9 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil { if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port) peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
} }
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
for i, route := range peer.Routes { for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route) peer.Routes[i] = a.AnonymizeIPString(route)
} }

View File

@@ -8,8 +8,8 @@ import (
) )
func formatError(es []error) string { func formatError(es []error) string {
if len(es) == 0 { if len(es) == 1 {
return fmt.Sprintf("0 error occurred:\n\t* %s", es[0]) return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
} }
points := make([]string, len(es)) points := make([]string, len(es))

View File

@@ -292,7 +292,7 @@ func (e *Engine) Start() error {
e.wgInterface = wgIface e.wgInterface = wgIface
userspace := e.wgInterface.IsUserspaceBind() userspace := e.wgInterface.IsUserspaceBind()
e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort) e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort)
if e.config.RosenpassEnabled { if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled") log.Infof("rosenpass is enabled")

View File

@@ -24,7 +24,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
defer func() { defer func() {
err := unix.Close(fd) err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) { if err != nil && !errors.Is(err, unix.EBADF) {
log.Errorf("Network monitor: failed to close routing socket: %v", err) log.Warnf("Network monitor: failed to close routing socket: %v", err)
} }
}() }()
@@ -32,7 +32,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
<-ctx.Done() <-ctx.Done()
err := unix.Close(fd) err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) { if err != nil && !errors.Is(err, unix.EBADF) {
log.Debugf("Network monitor: closed routing socket") log.Debugf("Network monitor: closed routing socket: %v", err)
} }
}() }()
@@ -45,12 +45,12 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
n, err := unix.Read(fd, buf) n, err := unix.Read(fd, buf)
if err != nil { if err != nil {
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Errorf("Network monitor: failed to read from routing socket: %v", err) log.Warnf("Network monitor: failed to read from routing socket: %v", err)
} }
continue continue
} }
if n < unix.SizeofRtMsghdr { if n < unix.SizeofRtMsghdr {
log.Errorf("Network monitor: read from routing socket returned less than expected: %d bytes", n) log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue continue
} }
@@ -61,7 +61,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
case unix.RTM_ADD, syscall.RTM_DELETE: case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n]) route, err := parseRouteMessage(buf[:n])
if err != nil { if err != nil {
log.Errorf("Network monitor: error parsing routing message: %v", err) log.Debugf("Network monitor: error parsing routing message: %v", err)
continue continue
} }

View File

@@ -518,18 +518,22 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
defer conn.mu.Unlock() defer conn.mu.Unlock()
if conn.ctx.Err() != nil { if conn.ctx.Err() != nil {
if err := rci.relayedConn.Close(); err != nil {
log.Warnf("failed to close unnecessary relayed connection: %v", err)
}
return return
} }
conn.log.Debugf("Relay connection is ready to use") conn.log.Debugf("Relay connection is ready to use")
conn.statusRelay.Set(StatusConnected) conn.statusRelay.Set(StatusConnected)
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx) wgProxy := conn.wgProxyFactory.GetProxy()
endpoint, err := wgProxy.AddTurnConn(rci.relayedConn) endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
if err != nil { if err != nil {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return return
} }
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
conn.endpointRelay = endpointUdpAddr conn.endpointRelay = endpointUdpAddr
@@ -771,13 +775,12 @@ func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr,
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
} }
conn.log.Debugf("setup ice turn connection") conn.log.Debugf("setup ice turn connection")
wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx) wgProxy := conn.wgProxyFactory.GetProxy()
ep, err := wgProxy.AddTurnConn(iceConnInfo.RemoteConn) ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn)
if err != nil { if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
err = wgProxy.CloseConn() if errClose := wgProxy.CloseConn(); errClose != nil {
if err != nil { conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
conn.log.Warnf("failed to close turn proxy connection: %v", err)
} }
return nil, nil, err return nil, nil, err
} }

View File

@@ -44,7 +44,7 @@ func TestNewConn_interfaceFilter(t *testing.T) {
} }
func TestConn_GetKey(t *testing.T) { func TestConn_GetKey(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
@@ -59,7 +59,7 @@ func TestConn_GetKey(t *testing.T) {
} }
func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
@@ -96,7 +96,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
} }
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()
@@ -132,7 +132,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestConn_Status(t *testing.T) { func TestConn_Status(t *testing.T) {
wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() { defer func() {
_ = wgProxyFactory.Free() _ = wgProxyFactory.Free()
}() }()

View File

@@ -1,4 +1,4 @@
package wgproxy package ebpf
import ( import (
"fmt" "fmt"

View File

@@ -1,4 +1,4 @@
package wgproxy package ebpf
import ( import (
"fmt" "fmt"

View File

@@ -1,6 +1,6 @@
//go:build linux && !android //go:build linux && !android
package wgproxy package ebpf
import ( import (
"context" "context"
@@ -13,47 +13,49 @@ import (
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/ebpf" "github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const (
loopbackAddr = "127.0.0.1"
)
// WGEBPFProxy definition for proxy with EBPF support // WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct { type WGEBPFProxy struct {
ebpfManager ebpfMgr.Manager
ctx context.Context
cancel context.CancelFunc
lastUsedPort uint16
localWGListenPort int localWGListenPort int
ebpfManager ebpfMgr.Manager
turnConnStore map[uint16]net.Conn turnConnStore map[uint16]net.Conn
turnConnMutex sync.Mutex turnConnMutex sync.Mutex
rawConn net.PacketConn lastUsedPort uint16
conn transport.UDPConn rawConn net.PacketConn
conn transport.UDPConn
ctx context.Context
ctxCancel context.CancelFunc
} }
// NewWGEBPFProxy create new WGEBPFProxy instance // NewWGEBPFProxy create new WGEBPFProxy instance
func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy { func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy") log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{ wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
ebpfManager: ebpf.GetEbpfManagerInstance(), ebpfManager: ebpf.GetEbpfManagerInstance(),
lastUsedPort: 0,
turnConnStore: make(map[uint16]net.Conn), turnConnStore: make(map[uint16]net.Conn),
} }
wgProxy.ctx, wgProxy.cancel = context.WithCancel(ctx)
return wgProxy return wgProxy
} }
// listen load ebpf program and listen the proxy // Listen load ebpf program and listen the proxy
func (p *WGEBPFProxy) listen() error { func (p *WGEBPFProxy) Listen() error {
pl := portLookup{} pl := portLookup{}
wgPorxyPort, err := pl.searchFreePort() wgPorxyPort, err := pl.searchFreePort()
if err != nil { if err != nil {
@@ -72,9 +74,11 @@ func (p *WGEBPFProxy) listen() error {
addr := net.UDPAddr{ addr := net.UDPAddr{
Port: wgPorxyPort, Port: wgPorxyPort,
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP(loopbackAddr),
} }
p.ctx, p.ctxCancel = context.WithCancel(context.Background())
conn, err := nbnet.ListenUDP("udp", &addr) conn, err := nbnet.ListenUDP("udp", &addr)
if err != nil { if err != nil {
cErr := p.Free() cErr := p.Free()
@@ -91,108 +95,112 @@ func (p *WGEBPFProxy) listen() error {
} }
// AddTurnConn add new turn connection for the proxy // AddTurnConn add new turn connection for the proxy
func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) { func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) {
wgEndpointPort, err := p.storeTurnConn(turnConn) wgEndpointPort, err := p.storeTurnConn(turnConn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
go p.proxyToLocal(wgEndpointPort, turnConn) go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort) log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
wgEndpoint := &net.UDPAddr{ wgEndpoint := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"), IP: net.ParseIP(loopbackAddr),
Port: int(wgEndpointPort), Port: int(wgEndpointPort),
} }
return wgEndpoint, nil return wgEndpoint, nil
} }
// CloseConn doing nothing because this type of proxy implementation does not store the connection // Free resources except the remoteConns will be keep open.
func (p *WGEBPFProxy) CloseConn() error {
return nil
}
// Free resources
func (p *WGEBPFProxy) Free() error { func (p *WGEBPFProxy) Free() error {
log.Debugf("free up ebpf wg proxy") log.Debugf("free up ebpf wg proxy")
var err1, err2, err3 error if p.ctx != nil && p.ctx.Err() != nil {
if p.conn != nil { //nolint
err1 = p.conn.Close() return nil
} }
err2 = p.ebpfManager.FreeWGProxy() p.ctxCancel()
if p.rawConn != nil {
err3 = p.rawConn.Close() var result *multierror.Error
if err := p.conn.Close(); err != nil {
result = multierror.Append(result, err)
} }
if err1 != nil { if err := p.ebpfManager.FreeWGProxy(); err != nil {
return err1 result = multierror.Append(result, err)
} }
if err2 != nil { if err := p.rawConn.Close(); err != nil {
return err2 result = multierror.Append(result, err)
} }
return nberrors.FormatErrorOrNil(result)
return err3
} }
func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) { func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
defer p.removeTurnConn(endpointPort)
var (
err error
n int
)
buf := make([]byte, 1500) buf := make([]byte, 1500)
var err error for ctx.Err() == nil {
defer func() { n, err = remoteConn.Read(buf)
p.removeTurnConn(endpointPort) if err != nil {
}() if ctx.Err() != nil {
for {
select {
case <-p.ctx.Done():
return
default:
var n int
n, err = remoteConn.Read(buf)
if err != nil {
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
return return
} }
err = p.sendPkg(buf[:n], endpointPort) if err != io.EOF {
if err != nil { log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
log.Errorf("failed to write out turn pkg to local conn: %v", err)
} }
return
}
if err := p.sendPkg(buf[:n], endpointPort); err != nil {
if ctx.Err() != nil || p.ctx.Err() != nil {
return
}
log.Errorf("failed to write out turn pkg to local conn: %v", err)
} }
} }
} }
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn // proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() { func (p *WGEBPFProxy) proxyToRemote() {
buf := make([]byte, 1500) buf := make([]byte, 1500)
for { for p.ctx.Err() == nil {
select { if err := p.readAndForwardPacket(buf); err != nil {
case <-p.ctx.Done(): if p.ctx.Err() != nil {
return
default:
n, addr, err := p.conn.ReadFromUDP(buf)
if err != nil {
log.Errorf("failed to read UDP pkg from WG: %s", err)
return return
} }
log.Errorf("failed to proxy packet to remote conn: %s", err)
p.turnConnMutex.Lock()
conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock()
if !ok {
log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
continue
}
_, err = conn.Write(buf[:n])
if err != nil {
log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
}
} }
} }
} }
func (p *WGEBPFProxy) readAndForwardPacket(buf []byte) error {
n, addr, err := p.conn.ReadFromUDP(buf)
if err != nil {
return fmt.Errorf("failed to read UDP packet from WG: %w", err)
}
p.turnConnMutex.Lock()
conn, ok := p.turnConnStore[uint16(addr.Port)]
p.turnConnMutex.Unlock()
if !ok {
if p.ctx.Err() == nil {
log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
}
return nil
}
if _, err := conn.Write(buf[:n]); err != nil {
return fmt.Errorf("failed to forward local WG packet (%d) to remote turn conn: %w", addr.Port, err)
}
return nil
}
func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) { func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
p.turnConnMutex.Lock() p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock() defer p.turnConnMutex.Unlock()
@@ -206,11 +214,14 @@ func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
} }
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) { func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
log.Debugf("remove turn conn from store by port: %d", turnConnID)
p.turnConnMutex.Lock() p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock() defer p.turnConnMutex.Unlock()
delete(p.turnConnStore, turnConnID)
_, ok := p.turnConnStore[turnConnID]
if ok {
log.Debugf("remove turn conn from store by port: %d", turnConnID)
}
delete(p.turnConnStore, turnConnID)
} }
func (p *WGEBPFProxy) nextFreePort() (uint16, error) { func (p *WGEBPFProxy) nextFreePort() (uint16, error) {

View File

@@ -1,14 +1,13 @@
//go:build linux && !android //go:build linux && !android
package wgproxy package ebpf
import ( import (
"context"
"testing" "testing"
) )
func TestWGEBPFProxy_connStore(t *testing.T) { func TestWGEBPFProxy_connStore(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1) wgProxy := NewWGEBPFProxy(1)
p, _ := wgProxy.storeTurnConn(nil) p, _ := wgProxy.storeTurnConn(nil)
if p != 1 { if p != 1 {
@@ -28,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) {
} }
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) { func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1) wgProxy := NewWGEBPFProxy(1)
_, _ = wgProxy.storeTurnConn(nil) _, _ = wgProxy.storeTurnConn(nil)
wgProxy.lastUsedPort = 65535 wgProxy.lastUsedPort = 65535
@@ -44,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
} }
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) { func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
wgProxy := NewWGEBPFProxy(context.Background(), 1) wgProxy := NewWGEBPFProxy(1)
for i := 0; i < 65535; i++ { for i := 0; i < 65535; i++ {
_, _ = wgProxy.storeTurnConn(nil) _, _ = wgProxy.storeTurnConn(nil)

View File

@@ -0,0 +1,44 @@
//go:build linux && !android
package ebpf
import (
"context"
"fmt"
"net"
)
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy
remoteConn net.Conn
cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread
}
func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
ctxConn, cancel := context.WithCancel(ctx)
addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
if err != nil {
cancel()
return nil, fmt.Errorf("add turn conn: %w", err)
}
e.remoteConn = remoteConn
e.cancel = cancel
return addr, err
}
// CloseConn close the remoteConn and automatically remove the conn instance from the map
func (e *ProxyWrapper) CloseConn() error {
if e.cancel == nil {
return fmt.Errorf("proxy not started")
}
e.cancel()
if err := e.remoteConn.Close(); err != nil {
return fmt.Errorf("failed to close remote conn: %w", err)
}
return nil
}

View File

@@ -1,22 +0,0 @@
package wgproxy
import "context"
type Factory struct {
wgPort int
ebpfProxy Proxy
}
func (w *Factory) GetProxy(ctx context.Context) Proxy {
if w.ebpfProxy != nil {
return w.ebpfProxy
}
return NewWGUserSpaceProxy(ctx, w.wgPort)
}
func (w *Factory) Free() error {
if w.ebpfProxy != nil {
return w.ebpfProxy.Free()
}
return nil
}

View File

@@ -3,20 +3,26 @@
package wgproxy package wgproxy
import ( import (
"context"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
) )
func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory { type Factory struct {
wgPort int
ebpfProxy *ebpf.WGEBPFProxy
}
func NewFactory(userspace bool, wgPort int) *Factory {
f := &Factory{wgPort: wgPort} f := &Factory{wgPort: wgPort}
if userspace { if userspace {
return f return f
} }
ebpfProxy := NewWGEBPFProxy(ctx, wgPort) ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
err := ebpfProxy.listen() err := ebpfProxy.Listen()
if err != nil { if err != nil {
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
return f return f
@@ -25,3 +31,20 @@ func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
f.ebpfProxy = ebpfProxy f.ebpfProxy = ebpfProxy
return f return f
} }
func (w *Factory) GetProxy() Proxy {
if w.ebpfProxy != nil {
p := &ebpf.ProxyWrapper{
WgeBPFProxy: w.ebpfProxy,
}
return p
}
return usp.NewWGUserSpaceProxy(w.wgPort)
}
func (w *Factory) Free() error {
if w.ebpfProxy == nil {
return nil
}
return w.ebpfProxy.Free()
}

View File

@@ -2,8 +2,20 @@
package wgproxy package wgproxy
import "context" import "github.com/netbirdio/netbird/client/internal/wgproxy/usp"
func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory { type Factory struct {
wgPort int
}
func NewFactory(_ bool, wgPort int) *Factory {
return &Factory{wgPort: wgPort} return &Factory{wgPort: wgPort}
} }
func (w *Factory) GetProxy() Proxy {
return usp.NewWGUserSpaceProxy(w.wgPort)
}
func (w *Factory) Free() error {
return nil
}

View File

@@ -1,12 +1,12 @@
package wgproxy package wgproxy
import ( import (
"context"
"net" "net"
) )
// Proxy is a transfer layer between the Turn connection and the WireGuard // Proxy is a transfer layer between the relayed connection and the WireGuard
type Proxy interface { type Proxy interface {
AddTurnConn(turnConn net.Conn) (net.Addr, error) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error)
CloseConn() error CloseConn() error
Free() error
} }

View File

@@ -0,0 +1,128 @@
//go:build linux
package wgproxy
import (
"context"
"io"
"net"
"os"
"runtime"
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/internal/wgproxy/usp"
"github.com/netbirdio/netbird/util"
)
func TestMain(m *testing.M) {
_ = util.InitLog("trace", "console")
code := m.Run()
os.Exit(code)
}
type mocConn struct {
closeChan chan struct{}
closed bool
}
func newMockConn() *mocConn {
return &mocConn{
closeChan: make(chan struct{}),
}
}
func (m *mocConn) Read(b []byte) (n int, err error) {
<-m.closeChan
return 0, io.EOF
}
func (m *mocConn) Write(b []byte) (n int, err error) {
<-m.closeChan
return 0, io.EOF
}
func (m *mocConn) Close() error {
if m.closed == true {
return nil
}
m.closed = true
close(m.closeChan)
return nil
}
func (m *mocConn) LocalAddr() net.Addr {
panic("implement me")
}
func (m *mocConn) RemoteAddr() net.Addr {
return &net.UDPAddr{
IP: net.ParseIP("172.16.254.1"),
}
}
func (m *mocConn) SetDeadline(t time.Time) error {
panic("implement me")
}
func (m *mocConn) SetReadDeadline(t time.Time) error {
panic("implement me")
}
func (m *mocConn) SetWriteDeadline(t time.Time) error {
panic("implement me")
}
func TestProxyCloseByRemoteConn(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
proxy Proxy
}{
{
name: "userspace proxy",
proxy: usp.NewWGUserSpaceProxy(51830),
},
}
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
proxyWrapper := &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
}
tests = append(tests, struct {
name string
proxy Proxy
}{
name: "ebpf proxy",
proxy: proxyWrapper,
})
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn()
_, err := tt.proxy.AddTurnConn(ctx, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}
_ = relayedConn.Close()
if err := tt.proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
})
}
}

View File

@@ -1,120 +0,0 @@
package wgproxy
import (
"context"
"fmt"
"io"
"net"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
)
// WGUserSpaceProxy proxies
type WGUserSpaceProxy struct {
localWGListenPort int
ctx context.Context
cancel context.CancelFunc
remoteConn net.Conn
localConn net.Conn
}
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUserSpaceProxy{
localWGListenPort: wgPort,
}
p.ctx, p.cancel = context.WithCancel(ctx)
return p
}
// AddTurnConn start the proxy with the given remote conn
func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
p.remoteConn = turnConn
var err error
p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err
}
go p.proxyToRemote()
go p.proxyToLocal()
return p.localConn.LocalAddr(), err
}
// CloseConn close the localConn
func (p *WGUserSpaceProxy) CloseConn() error {
p.cancel()
if p.localConn == nil {
return nil
}
return p.localConn.Close()
}
// Free doing nothing because this implementation of proxy does not have global state
func (p *WGUserSpaceProxy) Free() error {
return nil
}
// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
// blocks
func (p *WGUserSpaceProxy) proxyToRemote() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
n, err := p.localConn.Read(buf)
if err != nil {
log.Debugf("failed to read from wg interface conn: %s", err)
continue
}
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
if err == io.EOF {
p.cancel()
} else {
log.Debugf("failed to write to remote conn: %s", err)
}
continue
}
}
}
}
// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
// blocks
func (p *WGUserSpaceProxy) proxyToLocal() {
buf := make([]byte, 1500)
for {
select {
case <-p.ctx.Done():
return
default:
n, err := p.remoteConn.Read(buf)
if err != nil {
if err == io.EOF {
p.cancel()
return
}
log.Errorf("failed to read from remote conn: %s", err)
continue
}
_, err = p.localConn.Write(buf[:n])
if err != nil {
log.Debugf("failed to write to wg interface conn: %s", err)
continue
}
}
}
}

View File

@@ -0,0 +1,146 @@
package usp
import (
"context"
"fmt"
"net"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/errors"
)
// WGUserSpaceProxy proxies
type WGUserSpaceProxy struct {
localWGListenPort int
ctx context.Context
cancel context.CancelFunc
remoteConn net.Conn
localConn net.Conn
closeMu sync.Mutex
closed bool
}
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUserSpaceProxy{
localWGListenPort: wgPort,
}
return p
}
// AddTurnConn start the proxy with the given remote conn
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
p.ctx, p.cancel = context.WithCancel(ctx)
p.remoteConn = remoteConn
var err error
dialer := net.Dialer{}
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err
}
go p.proxyToRemote()
go p.proxyToLocal()
return p.localConn.LocalAddr(), err
}
// CloseConn close the localConn
func (p *WGUserSpaceProxy) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
}
return p.close()
}
func (p *WGUserSpaceProxy) close() error {
p.closeMu.Lock()
defer p.closeMu.Unlock()
// prevent double close
if p.closed {
return nil
}
p.closed = true
p.cancel()
var result *multierror.Error
if err := p.remoteConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
}
if err := p.localConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
}
return errors.FormatErrorOrNil(result)
}
// proxyToRemote proxies from Wireguard to the RemoteKey
func (p *WGUserSpaceProxy) proxyToRemote() {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to remote loop: %s", err)
}
}()
buf := make([]byte, 1500)
for p.ctx.Err() == nil {
n, err := p.localConn.Read(buf)
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Debugf("failed to read from wg interface conn: %s", err)
return
}
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Debugf("failed to write to remote conn: %s", err)
return
}
}
}
// proxyToLocal proxies from the Remote peer to local WireGuard
func (p *WGUserSpaceProxy) proxyToLocal() {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to local loop: %s", err)
}
}()
buf := make([]byte, 1500)
for p.ctx.Err() == nil {
n, err := p.remoteConn.Read(buf)
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return
}
_, err = p.localConn.Write(buf[:n])
if err != nil {
if p.ctx.Err() != nil {
return
}
log.Debugf("failed to write to wg interface conn: %s", err)
continue
}
}
}

4
go.mod
View File

@@ -95,9 +95,10 @@ require (
golang.org/x/term v0.21.0 golang.org/x/term v0.21.0
google.golang.org/api v0.177.0 google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.5.7
gorm.io/driver/postgres v1.5.7 gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.3 gorm.io/driver/sqlite v1.5.3
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde gorm.io/gorm v1.25.7
nhooyr.io/websocket v1.8.11 nhooyr.io/websocket v1.8.11
) )
@@ -151,6 +152,7 @@ require (
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/go-text/render v0.1.0 // indirect github.com/go-text/render v0.1.0 // indirect
github.com/go-text/typesetting v0.1.0 // indirect github.com/go-text/typesetting v0.1.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect

8
go.sum
View File

@@ -238,6 +238,8 @@ github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
@@ -1224,12 +1226,14 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo=
gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM=
gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM=
gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA=
gorm.io/driver/sqlite v1.5.3 h1:7/0dUgX28KAcopdfbRWWl68Rflh6osa4rDh+m51KL2g= gorm.io/driver/sqlite v1.5.3 h1:7/0dUgX28KAcopdfbRWWl68Rflh6osa4rDh+m51KL2g=
gorm.io/driver/sqlite v1.5.3/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4= gorm.io/driver/sqlite v1.5.3/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4cOO2PZra2PFD7Mfeg= gorm.io/gorm v1.25.7 h1:VsD6acwRjz2zFxGO50gPO6AkNs7KKnvfzUjHQhZDz/A=
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY= gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ=

View File

@@ -950,7 +950,7 @@ components:
type: array type: array
items: items:
type: string type: string
example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"] example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"]
action: action:
description: Action to take upon policy match description: Action to take upon policy match
type: string type: string

View File

@@ -7,8 +7,6 @@ import (
"net/http" "net/http"
"github.com/gorilla/mux" "github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
@@ -16,6 +14,7 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
) )
// PeersHandler is a handler that returns peers of the account // PeersHandler is a handler that returns peers of the account
@@ -215,7 +214,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r) claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(r.Context(), claims) account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -228,6 +227,21 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request
return return
} }
// If the user is regular user and does not own the peer
// with the given peerID return an empty list
if !user.HasAdminPower() && !user.IsServiceUser {
peer, ok := account.Peers[peerID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w)
return
}
if peer.UserID != user.Id {
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
return
}
}
dnsDomain := h.accountManager.GetDNSDomain() dnsDomain := h.accountManager.GetDNSDomain()
validPeers, err := h.accountManager.GetValidatedPeers(account) validPeers, err := h.accountManager.GetValidatedPeers(account)

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
@@ -12,20 +13,30 @@ import (
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/magiconair/properties/assert" "github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
) )
const testPeerID = "test_peer" type ctxKey string
const noUpdateChannelTestPeerID = "no-update-channel"
const (
testPeerID = "test_peer"
noUpdateChannelTestPeerID = "no-update-channel"
adminUser = "admin_user"
regularUser = "regular_user"
serviceUser = "service_user"
userIDKey ctxKey = "user_id"
)
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return &PeersHandler{ return &PeersHandler{
@@ -60,21 +71,57 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return "netbird.selfhosted" return "netbird.selfhosted"
}, },
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user") peersMap := make(map[string]*nbpeer.Peer)
return &server.Account{ for _, peer := range peers {
peersMap[peer.ID] = peer.Copy()
}
policy := &server.Policy{
ID: "policy",
AccountID: claims.AccountId,
Name: "policy",
Enabled: true,
Rules: []*server.PolicyRule{
{
ID: "rule",
Name: "rule",
Enabled: true,
Action: "accept",
Destinations: []string{"group1"},
Sources: []string{"group1"},
Bidirectional: true,
Protocol: "all",
Ports: []string{"80"},
},
},
}
srvUser := server.NewRegularUser(serviceUser)
srvUser.IsServiceUser = true
account := &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Domain: "hotmail.com", Domain: "hotmail.com",
Peers: map[string]*nbpeer.Peer{ Peers: peersMap,
peers[0].ID: peers[0],
peers[1].ID: peers[1],
},
Users: map[string]*server.User{ Users: map[string]*server.User{
"test_user": user, adminUser: server.NewAdminUser(adminUser),
regularUser: server.NewRegularUser(regularUser),
serviceUser: srvUser,
},
Groups: map[string]*nbgroup.Group{
"group1": {
ID: "group1",
AccountID: claims.AccountId,
Name: "group1",
Issued: "api",
Peers: maps.Keys(peersMap),
},
}, },
Settings: &server.Settings{ Settings: &server.Settings{
PeerLoginExpirationEnabled: true, PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour, PeerLoginExpiration: time.Hour,
}, },
Policies: []*server.Policy{policy},
Network: &server.Network{ Network: &server.Network{
Identifier: "ciclqisab2ss43jdn8q0", Identifier: "ciclqisab2ss43jdn8q0",
Net: net.IPNet{ Net: net.IPNet{
@@ -83,7 +130,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
}, },
Serial: 51, Serial: 51,
}, },
}, user, nil }
return account, account.Users[claims.UserId], nil
}, },
HasConnectedChannelFunc: func(peerID string) bool { HasConnectedChannelFunc: func(peerID string) bool {
statuses := make(map[string]struct{}) statuses := make(map[string]struct{})
@@ -99,8 +148,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
userID := r.Context().Value(userIDKey).(string)
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: "test_user", UserId: userID,
Domain: "hotmail.com", Domain: "hotmail.com",
AccountId: "test_id", AccountId: "test_id",
} }
@@ -197,6 +247,8 @@ func TestGetPeers(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
ctx := context.WithValue(context.Background(), userIDKey, "admin_user")
req = req.WithContext(ctx)
router := mux.NewRouter() router := mux.NewRouter()
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET") router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
@@ -251,3 +303,119 @@ func TestGetPeers(t *testing.T) {
}) })
} }
} }
func TestGetAccessiblePeers(t *testing.T) {
peer1 := &nbpeer.Peer{
ID: "peer1",
Key: "key1",
IP: net.ParseIP("100.64.0.1"),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "peer1",
LoginExpirationEnabled: false,
UserID: regularUser,
}
peer2 := &nbpeer.Peer{
ID: "peer2",
Key: "key2",
IP: net.ParseIP("100.64.0.2"),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "peer2",
LoginExpirationEnabled: false,
UserID: adminUser,
}
peer3 := &nbpeer.Peer{
ID: "peer3",
Key: "key3",
IP: net.ParseIP("100.64.0.3"),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "peer3",
LoginExpirationEnabled: false,
UserID: regularUser,
}
p := initTestMetaData(peer1, peer2, peer3)
tt := []struct {
name string
peerID string
callerUserID string
expectedStatus int
expectedPeers []string
}{
{
name: "non admin user can access owned peer",
peerID: "peer1",
callerUserID: regularUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer2", "peer3"},
},
{
name: "non admin user can't access unowned peer",
peerID: "peer2",
callerUserID: regularUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{},
},
{
name: "admin user can access owned peer",
peerID: "peer2",
callerUserID: adminUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer3"},
},
{
name: "admin user can access unowned peer",
peerID: "peer3",
callerUserID: adminUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer2"},
},
{
name: "service user can access unowned peer",
peerID: "peer3",
callerUserID: serviceUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer2"},
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID)
req = req.WithContext(ctx)
router := mux.NewRouter()
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
if res.StatusCode != tc.expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v", res.StatusCode, tc.expectedStatus)
}
body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("failed to read response body: %v", err)
}
defer res.Body.Close()
var accessiblePeers []api.AccessiblePeer
err = json.Unmarshal(body, &accessiblePeers)
if err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
peerIDs := make([]string, len(accessiblePeers))
for i, peer := range accessiblePeers {
peerIDs[i] = peer.Id
}
assert.ElementsMatch(t, peerIDs, tc.expectedPeers)
})
}
}

View File

@@ -15,6 +15,7 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
@@ -844,6 +845,16 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
return NewSqlStore(ctx, db, PostgresStoreEngine, metrics) return NewSqlStore(ctx, db, PostgresStoreEngine, metrics)
} }
// NewMysqlStore creates a new MySql store.
func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
db, err := gorm.Open(mysql.Open(dsn), getGormConfig())
if err != nil {
return nil, err
}
return NewSqlStore(ctx, db, MySqlStoreEngine, metrics)
}
func getGormConfig() *gorm.Config { func getGormConfig() *gorm.Config {
return &gorm.Config{ return &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent), Logger: logger.Default.LogMode(logger.Silent),
@@ -861,6 +872,15 @@ func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics) (Store,
return NewPostgresqlStore(ctx, dsn, metrics) return NewPostgresqlStore(ctx, dsn, metrics)
} }
// newMySqlStore initializes a new MySql store.
func newMySqlStore(ctx context.Context, metrics telemetry.AppMetrics) (Store, error) {
dsn, ok := os.LookupEnv(mySqlDsnEnv)
if !ok {
return nil, fmt.Errorf("%s is not set", mySqlDsnEnv)
}
return NewMysqlStore(ctx, dsn, metrics)
}
// NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir. // NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir.
func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) { func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) {
store, err := NewSqliteStore(ctx, dataDir, metrics) store, err := NewSqliteStore(ctx, dataDir, metrics)
@@ -1024,3 +1044,7 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store {
db: tx, db: tx,
} }
} }
func (s *SqlStore) GetDB() *gorm.DB {
return s.db
}

View File

@@ -95,8 +95,10 @@ const (
FileStoreEngine StoreEngine = "jsonfile" FileStoreEngine StoreEngine = "jsonfile"
SqliteStoreEngine StoreEngine = "sqlite" SqliteStoreEngine StoreEngine = "sqlite"
PostgresStoreEngine StoreEngine = "postgres" PostgresStoreEngine StoreEngine = "postgres"
MySqlStoreEngine StoreEngine = "mysql"
postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN" postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
mySqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
) )
func getStoreEngineFromEnv() StoreEngine { func getStoreEngineFromEnv() StoreEngine {
@@ -107,11 +109,12 @@ func getStoreEngineFromEnv() StoreEngine {
} }
value := StoreEngine(strings.ToLower(kind)) value := StoreEngine(strings.ToLower(kind))
if value == SqliteStoreEngine || value == PostgresStoreEngine { switch value {
case SqliteStoreEngine, PostgresStoreEngine, MySqlStoreEngine:
return value return value
default:
return SqliteStoreEngine
} }
return SqliteStoreEngine
} }
// getStoreEngine determines the store engine to use. // getStoreEngine determines the store engine to use.
@@ -158,6 +161,9 @@ func NewStore(ctx context.Context, kind StoreEngine, dataDir string, metrics tel
case PostgresStoreEngine: case PostgresStoreEngine:
log.WithContext(ctx).Info("using Postgres store engine") log.WithContext(ctx).Info("using Postgres store engine")
return newPostgresStore(ctx, metrics) return newPostgresStore(ctx, metrics)
case MySqlStoreEngine:
log.WithContext(ctx).Info("using MySQL store engine")
return newMySqlStore(ctx, metrics)
default: default:
return nil, fmt.Errorf("unsupported kind of store: %s", kind) return nil, fmt.Errorf("unsupported kind of store: %s", kind)
} }

View File

@@ -70,7 +70,7 @@ type User struct {
// Blocked indicates whether the user is blocked. Blocked users can't use the system. // Blocked indicates whether the user is blocked. Blocked users can't use the system.
Blocked bool Blocked bool
// LastLogin is the last time the user logged in to IdP // LastLogin is the last time the user logged in to IdP
LastLogin time.Time LastLogin time.Time `gorm:"type:TIMESTAMP;null;default:null"`
// CreatedAt records the time the user was created // CreatedAt records the time the user was created
CreatedAt time.Time CreatedAt time.Time

View File

@@ -58,7 +58,10 @@ func (m *Msg) Free() {
m.bufPool.Put(m.bufPtr) m.bufPool.Put(m.bufPtr)
} }
// connContainer is a container for the connection to the peer. It is responsible for managing the messages from the
// server and forwarding them to the upper layer content reader.
type connContainer struct { type connContainer struct {
log *log.Entry
conn *Conn conn *Conn
messages chan Msg messages chan Msg
msgChanLock sync.Mutex msgChanLock sync.Mutex
@@ -67,10 +70,10 @@ type connContainer struct {
cancel context.CancelFunc cancel context.CancelFunc
} }
func newConnContainer(conn *Conn, messages chan Msg) *connContainer { func newConnContainer(log *log.Entry, conn *Conn, messages chan Msg) *connContainer {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &connContainer{ return &connContainer{
log: log,
conn: conn, conn: conn,
messages: messages, messages: messages,
ctx: ctx, ctx: ctx,
@@ -91,6 +94,10 @@ func (cc *connContainer) writeMsg(msg Msg) {
case cc.messages <- msg: case cc.messages <- msg:
case <-cc.ctx.Done(): case <-cc.ctx.Done():
msg.Free() msg.Free()
default:
msg.Free()
cc.log.Infof("message queue is full")
// todo consider to close the connection
} }
} }
@@ -141,8 +148,8 @@ type Client struct {
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect // NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client { func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
hashedID, hashedStringId := messages.HashID(peerID) hashedID, hashedStringId := messages.HashID(peerID)
return &Client{ c := &Client{
log: log.WithFields(log.Fields{"client_id": hashedStringId, "relay": serverURL}), log: log.WithFields(log.Fields{"relay": serverURL}),
parentCtx: ctx, parentCtx: ctx,
connectionURL: serverURL, connectionURL: serverURL,
authTokenStore: authTokenStore, authTokenStore: authTokenStore,
@@ -155,6 +162,8 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token
}, },
conns: make(map[string]*connContainer), conns: make(map[string]*connContainer),
} }
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId)
return c
} }
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs. // Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
@@ -203,10 +212,10 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
} }
c.log.Infof("open connection to peer: %s", hashedStringID) c.log.Infof("open connection to peer: %s", hashedStringID)
msgChannel := make(chan Msg, 2) msgChannel := make(chan Msg, 100)
conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL) conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
c.conns[hashedStringID] = newConnContainer(conn, msgChannel) c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel)
return conn, nil return conn, nil
} }
@@ -455,7 +464,10 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in
} }
c.log.Errorf("health check timeout") c.log.Errorf("health check timeout")
internalStopFlag.set() internalStopFlag.set()
_ = conn.Close() // ignore the err because the readLoop will handle it if err := conn.Close(); err != nil {
// ignore the err handling because the readLoop will handle it
c.log.Warnf("failed to close connection: %s", err)
}
return return
case <-c.parentCtx.Done(): case <-c.parentCtx.Done():
err := c.close(true) err := c.close(true)
@@ -486,6 +498,7 @@ func (c *Client) closeConn(connReference *Conn, id string) error {
if container.conn != connReference { if container.conn != connReference {
return fmt.Errorf("conn reference mismatch") return fmt.Errorf("conn reference mismatch")
} }
c.log.Infof("free up connection to peer: %s", id)
delete(c.conns, id) delete(c.conns, id)
container.close() container.close()

View File

@@ -35,12 +35,15 @@ func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*C
connResultChan := make(chan connResult, totalServers) connResultChan := make(chan connResult, totalServers)
successChan := make(chan connResult, 1) successChan := make(chan connResult, 1)
concurrentLimiter := make(chan struct{}, maxConcurrentServers) concurrentLimiter := make(chan struct{}, maxConcurrentServers)
for _, url := range urls { for _, url := range urls {
// todo check if we have a successful connection so we do not need to connect to other servers
concurrentLimiter <- struct{}{} concurrentLimiter <- struct{}{}
go func(url string) { go func(url string) {
defer func() { <-concurrentLimiter }() defer func() {
<-concurrentLimiter
}()
sp.startConnection(parentCtx, connResultChan, url) sp.startConnection(parentCtx, connResultChan, url)
}(url) }(url)
} }
@@ -72,7 +75,8 @@ func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan con
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) { func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
var hasSuccess bool var hasSuccess bool
for cr := range resultChan { for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
cr := <-resultChan
if cr.Err != nil { if cr.Err != nil {
log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err) log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
continue continue

View File

@@ -0,0 +1,31 @@
package client
import (
"context"
"errors"
"testing"
"time"
)
func TestServerPicker_UnavailableServers(t *testing.T) {
sp := ServerPicker{
TokenStore: nil,
PeerID: "test",
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
go func() {
_, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"})
if err == nil {
t.Error(err)
}
cancel()
}()
<-ctx.Done()
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
t.Errorf("PickServer() took too long to complete")
}
}