Compare commits

...

19 Commits

Author SHA1 Message Date
Zoltán Papp
c495eaa549 Move interface near to engine 2024-10-10 14:46:51 +02:00
Zoltán Papp
b8026ad541 Merge branch 'main' into relay/fix/wg-roaming 2024-10-09 23:24:07 +02:00
Zoltán Papp
a5deeda727 Revert force install change 2024-10-09 19:20:20 +02:00
Zoltán Papp
5b2d5f8df1 Try to force install libpcap 2024-10-09 19:12:33 +02:00
Zoltán Papp
6369706ade Merge branch 'main' into relay/fix/wg-roaming 2024-10-09 18:54:30 +02:00
Zoltan Papp
e3dfbe5acf Add trace log 2024-10-09 14:07:35 +02:00
Zoltan Papp
deeb05047d Handle addr resolve error 2024-10-09 14:05:43 +02:00
Zoltan Papp
1814b07a4b Replace error check to errors.Is 2024-10-09 14:02:23 +02:00
Zoltán Papp
b04d19bb0a Fix nil pointer in error handling 2024-10-08 12:04:57 +02:00
Zoltán Papp
20815c9f90 Remove unused function 2024-10-07 13:28:21 +02:00
Zoltán Papp
ba3cdb30ee Remove unnecessary ctx cancel check 2024-10-07 13:05:11 +02:00
Zoltán Papp
1f25bb0751 Reducate cognitive complexity 2024-10-07 12:58:45 +02:00
Zoltán Papp
9e7aac3a56 Reducate cognitive complexity 2024-10-07 12:52:55 +02:00
Zoltán Papp
718d9526a7 Fix test 2024-10-07 12:45:21 +02:00
Zoltán Papp
48184ecf21 Fix eBPF pause handling 2024-10-07 12:40:53 +02:00
Zoltán Papp
f18ae8b925 Apply pause logic 2024-10-07 11:22:48 +02:00
Zoltán Papp
90d9dd4c08 Remove unused function from eBPF proxy 2024-10-07 10:35:53 +02:00
Zoltán Papp
acad98e328 Code cleaning 2024-10-03 02:29:46 +02:00
Zoltán Papp
9d75cc3273 Add pause function for proxies 2024-10-03 01:24:05 +02:00
19 changed files with 321 additions and 191 deletions

View File

@@ -141,7 +141,7 @@ type Engine struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wgInterface iface.IWGIface wgInterface IWGIface
wgProxyFactory *wgproxy.Factory wgProxyFactory *wgproxy.Factory
udpMux *bind.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
@@ -326,7 +326,7 @@ func (e *Engine) Start() error {
} }
e.dnsServer = dnsServer e.dnsServer = dnsServer
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes) e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface.(*iface.WGIface), e.statusRecorder, e.relayManager, initialRoutes)
beforePeerHook, afterPeerHook, err := e.routeManager.Init() beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil { if err != nil {
log.Errorf("Failed to initialize route manager: %s", err) log.Errorf("Failed to initialize route manager: %s", err)
@@ -921,7 +921,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
wgConfig := peer.WgConfig{ wgConfig := peer.WgConfig{
RemoteKey: pubKey, RemoteKey: pubKey,
WgListenPort: e.config.WgPort, WgListenPort: e.config.WgPort,
WgInterface: e.wgInterface, WgInterface: e.wgInterface.(*iface.WGIface),
AllowedIps: allowedIPs, AllowedIps: allowedIPs,
PreSharedKey: e.config.PreSharedKey, PreSharedKey: e.config.PreSharedKey,
} }

View File

@@ -242,13 +242,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
peer.NewRecorder("https://mgm"), peer.NewRecorder("https://mgm"),
nil) nil)
wgIface := &iface.MockWGIface{ wgIface := &MockWGIface{
RemovePeerFunc: func(peerKey string) error { RemovePeerFunc: func(peerKey string) error {
return nil return nil
}, },
} }
engine.wgInterface = wgIface engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil) engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface.(*iface.WGIface), engine.statusRecorder, relayMgr, nil)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
} }

View File

@@ -1,4 +1,4 @@
package iface package internal
import ( import (
"net" "net"

View File

@@ -1,6 +1,6 @@
//go:build !windows //go:build !windows
package iface package internal
import ( import (
"net" "net"

View File

@@ -1,4 +1,4 @@
package iface package internal
import ( import (
"net" "net"

View File

@@ -39,7 +39,7 @@ const (
type WgConfig struct { type WgConfig struct {
WgListenPort int WgListenPort int
RemoteKey string RemoteKey string
WgInterface iface.IWGIface WgInterface *iface.WGIface
AllowedIps string AllowedIps string
PreSharedKey *wgtypes.Key PreSharedKey *wgtypes.Key
} }
@@ -82,8 +82,6 @@ type Conn struct {
config ConnConfig config ConnConfig
statusRecorder *Status statusRecorder *Status
wgProxyFactory *wgproxy.Factory wgProxyFactory *wgproxy.Factory
wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy
signaler *Signaler signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager relayManager *relayClient.Manager
@@ -106,7 +104,8 @@ type Conn struct {
beforeAddPeerHooks []nbnet.AddHookFunc beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc afterRemovePeerHooks []nbnet.RemoveHookFunc
endpointRelay *net.UDPAddr wgProxyICE wgproxy.Proxy
wgProxyRelay wgproxy.Proxy
// for reconnection operations // for reconnection operations
iCEDisconnected chan bool iCEDisconnected chan bool
@@ -257,8 +256,7 @@ func (conn *Conn) Close() {
conn.wgProxyICE = nil conn.wgProxyICE = nil
} }
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) if err := conn.removeWgPeer(); err != nil {
if err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err) conn.log.Errorf("failed to remove wg endpoint: %v", err)
} }
@@ -430,54 +428,59 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
conn.log.Debugf("ICE connection is ready") conn.log.Debugf("ICE connection is ready")
conn.statusICE.Set(StatusConnected)
defer conn.updateIceState(iceConnInfo)
if conn.currentConnPriority > priority { if conn.currentConnPriority > priority {
conn.statusICE.Set(StatusConnected)
conn.updateIceState(iceConnInfo)
return return
} }
conn.log.Infof("set ICE to active connection") conn.log.Infof("set ICE to active connection")
endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo) var (
if err != nil { ep *net.UDPAddr
return wgProxy wgproxy.Proxy
err error
)
if iceConnInfo.RelayedOnLocal {
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
return
}
ep = wgProxy.EndpointAddr()
conn.wgProxyICE = wgProxy
} else {
directEp, err := net.ResolveUDPAddr("udp", iceConnInfo.RemoteConn.RemoteAddr().String())
if err != nil {
log.Errorf("failed to resolveUDPaddr")
conn.handleConfigurationFailure(err, nil)
return
}
ep = directEp
} }
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP) conn.log.Errorf("Before add peer hook failed: %v", err)
conn.connIDICE = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err)
}
} }
conn.workerRelay.DisableWgWatcher() conn.workerRelay.DisableWgWatcher()
err = conn.configureWGEndpoint(endpointUdpAddr) if conn.wgProxyRelay != nil {
if err != nil { conn.wgProxyRelay.Pause()
if wgProxy != nil { }
if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close turn connection: %v", err) if wgProxy != nil {
} wgProxy.Work()
} }
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
if err = conn.configureWGEndpoint(ep); err != nil {
conn.handleConfigurationFailure(err, wgProxy)
return return
} }
wgConfigWorkaround() wgConfigWorkaround()
if conn.wgProxyICE != nil {
if err := conn.wgProxyICE.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
conn.wgProxyICE = wgProxy
conn.currentConnPriority = priority conn.currentConnPriority = priority
conn.statusICE.Set(StatusConnected)
conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
} }
@@ -492,11 +495,18 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.log.Tracef("ICE connection state changed to %s", newState) conn.log.Tracef("ICE connection state changed to %s", newState)
if conn.wgProxyICE != nil {
if err := conn.wgProxyICE.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
// switch back to relay connection // switch back to relay connection
if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay { if conn.isReadyToUpgrade() {
conn.log.Debugf("ICE disconnected, set Relay to active connection") conn.log.Debugf("ICE disconnected, set Relay to active connection")
err := conn.configureWGEndpoint(conn.endpointRelay) conn.wgProxyRelay.Work()
if err != nil {
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err) conn.log.Errorf("failed to switch to relay conn: %v", err)
} }
conn.workerRelay.EnableWgWatcher(conn.ctx) conn.workerRelay.EnableWgWatcher(conn.ctx)
@@ -506,10 +516,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
changed := conn.statusICE.Get() != newState && newState != StatusConnecting changed := conn.statusICE.Get() != newState && newState != StatusConnecting
conn.statusICE.Set(newState) conn.statusICE.Set(newState)
select { conn.notifyReconnectLoopICEDisconnected(changed)
case conn.iCEDisconnected <- changed:
default:
}
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@@ -530,61 +537,48 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
if conn.ctx.Err() != nil { if conn.ctx.Err() != nil {
if err := rci.relayedConn.Close(); err != nil { if err := rci.relayedConn.Close(); err != nil {
log.Warnf("failed to close unnecessary relayed connection: %v", err) conn.log.Warnf("failed to close unnecessary relayed connection: %v", err)
} }
return return
} }
conn.log.Debugf("Relay connection is ready to use") conn.log.Debugf("Relay connection has been established, setup the WireGuard")
conn.statusRelay.Set(StatusConnected)
wgProxy := conn.wgProxyFactory.GetProxy() wgProxy, err := conn.newProxy(rci.relayedConn)
endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
if err != nil { if err != nil {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return return
} }
conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
conn.endpointRelay = endpointUdpAddr
conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) if conn.iceP2PIsActive() {
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
if conn.currentConnPriority > connPriorityRelay { conn.wgProxyRelay = wgProxy
if conn.statusICE.Get() == StatusConnected { conn.statusRelay.Set(StatusConnected)
log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
return return
}
} }
conn.connIDRelay = nbnet.GenerateConnID() if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
for _, hook := range conn.beforeAddPeerHooks { conn.log.Errorf("Before add peer hook failed: %v", err)
if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil {
conn.log.Errorf("Before add peer hook failed: %v", err)
}
} }
err = conn.configureWGEndpoint(endpointUdpAddr) wgProxy.Work()
if err != nil { if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
if err := wgProxy.CloseConn(); err != nil { if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close relay connection: %v", err) conn.log.Warnf("Failed to close relay connection: %v", err)
} }
conn.log.Errorf("Failed to update wg peer configuration: %v", err) conn.log.Errorf("Failed to update WireGuard peer configuration: %v", err)
return return
} }
conn.workerRelay.EnableWgWatcher(conn.ctx) conn.workerRelay.EnableWgWatcher(conn.ctx)
wgConfigWorkaround() wgConfigWorkaround()
if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
conn.wgProxyRelay = wgProxy
conn.currentConnPriority = connPriorityRelay conn.currentConnPriority = connPriorityRelay
conn.statusRelay.Set(StatusConnected)
conn.wgProxyRelay = wgProxy
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.log.Infof("start to communicate with peer via relay") conn.log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
} }
@@ -597,29 +591,23 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
return return
} }
log.Debugf("relay connection is disconnected") conn.log.Debugf("relay connection is disconnected")
if conn.currentConnPriority == connPriorityRelay { if conn.currentConnPriority == connPriorityRelay {
log.Debugf("clean up WireGuard config") conn.log.Debugf("clean up WireGuard config")
err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) if err := conn.removeWgPeer(); err != nil {
if err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err) conn.log.Errorf("failed to remove wg endpoint: %v", err)
} }
} }
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
conn.endpointRelay = nil
_ = conn.wgProxyRelay.CloseConn() _ = conn.wgProxyRelay.CloseConn()
conn.wgProxyRelay = nil conn.wgProxyRelay = nil
} }
changed := conn.statusRelay.Get() != StatusDisconnected changed := conn.statusRelay.Get() != StatusDisconnected
conn.statusRelay.Set(StatusDisconnected) conn.statusRelay.Set(StatusDisconnected)
conn.notifyReconnectLoopRelayDisconnected(changed)
select {
case conn.relayDisconnected <- changed:
default:
}
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@@ -627,9 +615,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {
Relayed: conn.isRelayed(), Relayed: conn.isRelayed(),
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
} }
if err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState); err != nil {
err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState)
if err != nil {
conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err) conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
} }
} }
@@ -765,6 +751,16 @@ func (conn *Conn) isConnected() bool {
return true return true
} }
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
conn.connIDICE = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connIDICE, ip); err != nil {
return err
}
}
return nil
}
func (conn *Conn) freeUpConnID() { func (conn *Conn) freeUpConnID() {
if conn.connIDRelay != "" { if conn.connIDRelay != "" {
for _, hook := range conn.afterRemovePeerHooks { for _, hook := range conn.afterRemovePeerHooks {
@@ -785,21 +781,52 @@ func (conn *Conn) freeUpConnID() {
} }
} }
func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) { func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
if !iceConnInfo.RelayedOnLocal { conn.log.Debugf("setup proxied WireGuard connection")
return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
}
conn.log.Debugf("setup ice turn connection")
wgProxy := conn.wgProxyFactory.GetProxy() wgProxy := conn.wgProxyFactory.GetProxy()
ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn) if err := wgProxy.AddTurnConn(conn.ctx, remoteConn); err != nil {
if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
if errClose := wgProxy.CloseConn(); errClose != nil { return nil, err
conn.log.Warnf("failed to close turn proxy connection: %v", errClose) }
} return wgProxy, nil
return nil, nil, err }
func (conn *Conn) isReadyToUpgrade() bool {
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
}
func (conn *Conn) iceP2PIsActive() bool {
return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected
}
func (conn *Conn) removeWgPeer() error {
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
}
func (conn *Conn) notifyReconnectLoopRelayDisconnected(changed bool) {
select {
case conn.relayDisconnected <- changed:
default:
}
}
func (conn *Conn) notifyReconnectLoopICEDisconnected(changed bool) {
select {
case conn.iCEDisconnected <- changed:
default:
}
}
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
if wgProxy != nil {
if ierr := wgProxy.CloseConn(); ierr != nil {
conn.log.Warnf("Failed to close wg proxy: %v", ierr)
}
}
if conn.wgProxyRelay != nil {
conn.wgProxyRelay.Work()
} }
return ep, wgProxy, nil
} }
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {

View File

@@ -43,7 +43,7 @@ type clientNetwork struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface iface.IWGIface wgInterface *iface.WGIface
routes map[route.ID]*route.Route routes map[route.ID]*route.Route
routeUpdate chan routesUpdate routeUpdate chan routesUpdate
peerStateUpdate chan struct{} peerStateUpdate chan struct{}
@@ -53,7 +53,7 @@ type clientNetwork struct {
updateSerial uint64 updateSerial uint64
} }
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork { func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface *iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{ client := &clientNetwork{
@@ -378,7 +378,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
} }
} }
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler { func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface *iface.WGIface) RouteHandler {
if rt.IsDynamic() { if rt.IsDynamic() {
dns := nbdns.NewServiceViaMemory(wgInterface) dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())) return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))

View File

@@ -48,7 +48,7 @@ type Route struct {
currentPeerKey string currentPeerKey string
cancel context.CancelFunc cancel context.CancelFunc
statusRecorder *peer.Status statusRecorder *peer.Status
wgInterface iface.IWGIface wgInterface *iface.WGIface
resolverAddr string resolverAddr string
} }
@@ -58,7 +58,7 @@ func NewRoute(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration, interval time.Duration,
statusRecorder *peer.Status, statusRecorder *peer.Status,
wgInterface iface.IWGIface, wgInterface *iface.WGIface,
resolverAddr string, resolverAddr string,
) *Route { ) *Route {
return &Route{ return &Route{

View File

@@ -52,7 +52,7 @@ type DefaultManager struct {
sysOps *systemops.SysOps sysOps *systemops.SysOps
statusRecorder *peer.Status statusRecorder *peer.Status
relayMgr *relayClient.Manager relayMgr *relayClient.Manager
wgInterface iface.IWGIface wgInterface *iface.WGIface
pubKey string pubKey string
notifier *notifier.Notifier notifier *notifier.Notifier
routeRefCounter *refcounter.RouteRefCounter routeRefCounter *refcounter.RouteRefCounter
@@ -64,7 +64,7 @@ func NewManager(
ctx context.Context, ctx context.Context,
pubKey string, pubKey string,
dnsRouteInterval time.Duration, dnsRouteInterval time.Duration,
wgInterface iface.IWGIface, wgInterface *iface.WGIface,
statusRecorder *peer.Status, statusRecorder *peer.Status,
relayMgr *relayClient.Manager, relayMgr *relayClient.Manager,
initialRoutes []*route.Route, initialRoutes []*route.Route,

View File

@@ -11,6 +11,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
) )
func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (serverRouter, error) { func newServerRouter(context.Context, *iface.WGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
return nil, fmt.Errorf("server route not supported on this os") return nil, fmt.Errorf("server route not supported on this os")
} }

View File

@@ -22,11 +22,11 @@ type defaultServerRouter struct {
ctx context.Context ctx context.Context
routes map[route.ID]*route.Route routes map[route.ID]*route.Route
firewall firewall.Manager firewall firewall.Manager
wgInterface iface.IWGIface wgInterface *iface.WGIface
statusRecorder *peer.Status statusRecorder *peer.Status
} }
func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) { func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
return &defaultServerRouter{ return &defaultServerRouter{
ctx: ctx, ctx: ctx,
routes: make(map[route.ID]*route.Route), routes: make(map[route.ID]*route.Route),

View File

@@ -23,7 +23,7 @@ const (
) )
// Setup configures sysctl settings for RP filtering and source validation. // Setup configures sysctl settings for RP filtering and source validation.
func Setup(wgIface iface.IWGIface) (map[string]int, error) { func Setup(wgIface *iface.WGIface) (map[string]int, error) {
keys := map[string]int{} keys := map[string]int{}
var result *multierror.Error var result *multierror.Error

View File

@@ -19,7 +19,7 @@ type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
type SysOps struct { type SysOps struct {
refCounter *ExclusionCounter refCounter *ExclusionCounter
wgInterface iface.IWGIface wgInterface *iface.WGIface
// prefixes is tracking all the current added prefixes im memory // prefixes is tracking all the current added prefixes im memory
// (this is used in iOS as all route updates require a full table update) // (this is used in iOS as all route updates require a full table update)
//nolint //nolint
@@ -30,7 +30,7 @@ type SysOps struct {
notifier *notifier.Notifier notifier *notifier.Notifier
} }
func NewSysOps(wgInterface iface.IWGIface, notifier *notifier.Notifier) *SysOps { func NewSysOps(wgInterface *iface.WGIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{ return &SysOps{
wgInterface: wgInterface, wgInterface: wgInterface,
notifier: notifier, notifier: notifier,

View File

@@ -122,7 +122,7 @@ func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. // addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values. // If the next hop or interface is pointing to the VPN interface, it will return the initial values.
func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIface, initialNextHop Nexthop) (Nexthop, error) { func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
addr := prefix.Addr() addr := prefix.Addr()
switch { switch {
case addr.IsLoopback(), case addr.IsLoopback(),

View File

@@ -5,7 +5,6 @@ package ebpf
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"net" "net"
"os" "os"
"sync" "sync"
@@ -94,13 +93,12 @@ func (p *WGEBPFProxy) Listen() error {
} }
// AddTurnConn add new turn connection for the proxy // AddTurnConn add new turn connection for the proxy
func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) { func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (*net.UDPAddr, error) {
wgEndpointPort, err := p.storeTurnConn(turnConn) wgEndpointPort, err := p.storeTurnConn(turnConn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort) log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
wgEndpoint := &net.UDPAddr{ wgEndpoint := &net.UDPAddr{
@@ -137,35 +135,6 @@ func (p *WGEBPFProxy) Free() error {
return nberrors.FormatErrorOrNil(result) return nberrors.FormatErrorOrNil(result)
} }
func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
defer p.removeTurnConn(endpointPort)
var (
err error
n int
)
buf := make([]byte, 1500)
for ctx.Err() == nil {
n, err = remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
if err != io.EOF {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
return
}
if err := p.sendPkg(buf[:n], endpointPort); err != nil {
if ctx.Err() != nil || p.ctx.Err() != nil {
return
}
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn // proxyToRemote read messages from local WireGuard interface and forward it to remote conn
// From this go routine has only one instance. // From this go routine has only one instance.
func (p *WGEBPFProxy) proxyToRemote() { func (p *WGEBPFProxy) proxyToRemote() {
@@ -280,7 +249,7 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
return packetConn, nil return packetConn, nil
} }
func (p *WGEBPFProxy) sendPkg(data []byte, port uint16) error { func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
localhost := net.ParseIP("127.0.0.1") localhost := net.ParseIP("127.0.0.1")
payload := gopacket.Payload(data) payload := gopacket.Payload(data)

View File

@@ -4,8 +4,13 @@ package ebpf
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io"
"net" "net"
"sync"
log "github.com/sirupsen/logrus"
) )
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
@@ -13,20 +18,55 @@ type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy WgeBPFProxy *WGEBPFProxy
remoteConn net.Conn remoteConn net.Conn
cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread ctx context.Context
cancel context.CancelFunc
wgEndpointAddr *net.UDPAddr
pausedMu sync.Mutex
paused bool
isStarted bool
} }
func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
ctxConn, cancel := context.WithCancel(ctx) addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
if err != nil { if err != nil {
cancel() return fmt.Errorf("add turn conn: %w", err)
return nil, fmt.Errorf("add turn conn: %w", err)
} }
e.remoteConn = remoteConn p.remoteConn = remoteConn
e.cancel = cancel p.ctx, p.cancel = context.WithCancel(ctx)
return addr, err p.wgEndpointAddr = addr
return err
}
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr
}
func (p *ProxyWrapper) Work() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
}
func (p *ProxyWrapper) Pause() {
if p.remoteConn == nil {
return
}
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
p.pausedMu.Lock()
p.paused = true
p.pausedMu.Unlock()
} }
// CloseConn close the remoteConn and automatically remove the conn instance from the map // CloseConn close the remoteConn and automatically remove the conn instance from the map
@@ -42,3 +82,45 @@ func (e *ProxyWrapper) CloseConn() error {
} }
return nil return nil
} }
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
buf := make([]byte, 1500)
for {
n, err := p.readFromRemote(ctx, buf)
if err != nil {
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
p.pausedMu.Unlock()
if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) {
n, err := p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return 0, ctx.Err()
}
if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
}
return 0, err
}
return n, nil
}

View File

@@ -7,6 +7,9 @@ import (
// Proxy is a transfer layer between the relayed connection and the WireGuard // Proxy is a transfer layer between the relayed connection and the WireGuard
type Proxy interface { type Proxy interface {
AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) AddTurnConn(ctx context.Context, turnConn net.Conn) error
EndpointAddr() *net.UDPAddr
Work()
Pause()
CloseConn() error CloseConn() error
} }

View File

@@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn() relayedConn := newMockConn()
_, err := tt.proxy.AddTurnConn(ctx, relayedConn) err := tt.proxy.AddTurnConn(ctx, relayedConn)
if err != nil { if err != nil {
t.Errorf("error: %v", err) t.Errorf("error: %v", err)
} }

View File

@@ -15,13 +15,17 @@ import (
// WGUserSpaceProxy proxies // WGUserSpaceProxy proxies
type WGUserSpaceProxy struct { type WGUserSpaceProxy struct {
localWGListenPort int localWGListenPort int
ctx context.Context
cancel context.CancelFunc
remoteConn net.Conn remoteConn net.Conn
localConn net.Conn localConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex closeMu sync.Mutex
closed bool closed bool
pausedMu sync.Mutex
paused bool
isStarted bool
} }
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation // NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
@@ -33,24 +37,60 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
return p return p
} }
// AddTurnConn start the proxy with the given remote conn // AddTurnConn
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { // The provided Context must be non-nil. If the context expires before
p.ctx, p.cancel = context.WithCancel(ctx) // the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
p.remoteConn = remoteConn // connection.
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error {
var err error
dialer := net.Dialer{} dialer := net.Dialer{}
p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil { if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err) log.Errorf("failed dialing to local Wireguard port %s", err)
return nil, err return err
} }
go p.proxyToRemote() p.ctx, p.cancel = context.WithCancel(ctx)
go p.proxyToLocal() p.localConn = localConn
p.remoteConn = remoteConn
return p.localConn.LocalAddr(), err return err
}
func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
if p.localConn == nil {
return nil
}
endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
return endpointUdpAddr
}
// Work starts the proxy or resumes it if it was paused
func (p *WGUserSpaceProxy) Work() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
if !p.isStarted {
p.isStarted = true
go p.proxyToRemote(p.ctx)
go p.proxyToLocal(p.ctx)
}
}
// Pause pauses the proxy from receiving data from the remote peer
func (p *WGUserSpaceProxy) Pause() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = true
p.pausedMu.Unlock()
} }
// CloseConn close the localConn // CloseConn close the localConn
@@ -85,7 +125,7 @@ func (p *WGUserSpaceProxy) close() error {
} }
// proxyToRemote proxies from Wireguard to the RemoteKey // proxyToRemote proxies from Wireguard to the RemoteKey
func (p *WGUserSpaceProxy) proxyToRemote() { func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
defer func() { defer func() {
if err := p.close(); err != nil { if err := p.close(); err != nil {
log.Warnf("error in proxy to remote loop: %s", err) log.Warnf("error in proxy to remote loop: %s", err)
@@ -93,10 +133,10 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
}() }()
buf := make([]byte, 1500) buf := make([]byte, 1500)
for p.ctx.Err() == nil { for ctx.Err() == nil {
n, err := p.localConn.Read(buf) n, err := p.localConn.Read(buf)
if err != nil { if err != nil {
if p.ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
log.Debugf("failed to read from wg interface conn: %s", err) log.Debugf("failed to read from wg interface conn: %s", err)
@@ -105,7 +145,7 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
_, err = p.remoteConn.Write(buf[:n]) _, err = p.remoteConn.Write(buf[:n])
if err != nil { if err != nil {
if p.ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
@@ -116,7 +156,8 @@ func (p *WGUserSpaceProxy) proxyToRemote() {
} }
// proxyToLocal proxies from the Remote peer to local WireGuard // proxyToLocal proxies from the Remote peer to local WireGuard
func (p *WGUserSpaceProxy) proxyToLocal() { // if the proxy is paused it will drain the remote conn and drop the packets
func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
defer func() { defer func() {
if err := p.close(); err != nil { if err := p.close(); err != nil {
log.Warnf("error in proxy to local loop: %s", err) log.Warnf("error in proxy to local loop: %s", err)
@@ -124,19 +165,27 @@ func (p *WGUserSpaceProxy) proxyToLocal() {
}() }()
buf := make([]byte, 1500) buf := make([]byte, 1500)
for p.ctx.Err() == nil { for {
n, err := p.remoteConn.Read(buf) n, err := p.remoteConn.Read(buf)
if err != nil { if err != nil {
if p.ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return return
} }
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
_, err = p.localConn.Write(buf[:n]) _, err = p.localConn.Write(buf[:n])
p.pausedMu.Unlock()
if err != nil { if err != nil {
if p.ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
log.Debugf("failed to write to wg interface conn: %s", err) log.Debugf("failed to write to wg interface conn: %s", err)