mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-24 11:16:38 +00:00
Merge branch 'main' into local/engine-restart
This commit is contained in:
@@ -146,12 +146,11 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
|
||||
ipsetName := ""
|
||||
if ipset.name == "" {
|
||||
d.ipsetCounter++
|
||||
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
||||
}
|
||||
ipsetName = ipset.name
|
||||
ipsetName := ipset.name
|
||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||
if err != nil {
|
||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux
|
||||
//go:build !linux || android
|
||||
|
||||
package acl
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
)
|
||||
|
||||
@@ -17,6 +19,9 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := fm.AllowNetbird(); err != nil {
|
||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
return newDefaultManager(fm), nil
|
||||
}
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build !android
|
||||
|
||||
package acl
|
||||
|
||||
import (
|
||||
@@ -7,26 +9,68 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall/iptables"
|
||||
"github.com/netbirdio/netbird/client/firewall/nftables"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal/checkfw"
|
||||
)
|
||||
|
||||
// Create creates a firewall manager instance for the Linux
|
||||
func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
||||
func Create(iface IFaceMapper) (*DefaultManager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
// for the userspace packet filtering firewall
|
||||
var fm firewall.Manager
|
||||
var err error
|
||||
|
||||
checkResult := checkfw.Check()
|
||||
switch checkResult {
|
||||
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
|
||||
log.Debug("creating an iptables firewall manager for access control")
|
||||
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
|
||||
if fm, err = iptables.Create(iface, ipv6Supported); err != nil {
|
||||
log.Infof("failed to create iptables manager for access control: %s", err)
|
||||
}
|
||||
case checkfw.NFTABLES:
|
||||
log.Debug("creating an nftables firewall manager for access control")
|
||||
if fm, err = nftables.Create(iface); err != nil {
|
||||
log.Debugf("failed to create nftables manager for access control: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
var resetHookForUserspace func() error
|
||||
if fm != nil && err == nil {
|
||||
// err shadowing is used here, to ignore this error
|
||||
if err := fm.AllowNetbird(); err != nil {
|
||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
resetHookForUserspace = fm.Reset
|
||||
}
|
||||
|
||||
if iface.IsUserspaceBind() {
|
||||
// use userspace packet filtering firewall
|
||||
if fm, err = uspfilter.Create(iface); err != nil {
|
||||
usfm, err := uspfilter.Create(iface)
|
||||
if err != nil {
|
||||
log.Debugf("failed to create userspace filtering firewall: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if fm, err = nftables.Create(iface); err != nil {
|
||||
log.Debugf("failed to create nftables manager: %s", err)
|
||||
// fallback to iptables
|
||||
if fm, err = iptables.Create(iface); err != nil {
|
||||
log.Errorf("failed to create iptables manager: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// set kernel space firewall Reset as hook for userspace firewall
|
||||
// manager Reset method, to clean up
|
||||
if resetHookForUserspace != nil {
|
||||
usfm.SetResetHook(resetHookForUserspace)
|
||||
}
|
||||
|
||||
// to be consistent for any future extensions.
|
||||
// ignore this error
|
||||
if err := usfm.AllowNetbird(); err != nil {
|
||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
fm = usfm
|
||||
}
|
||||
|
||||
if fm == nil || err != nil {
|
||||
log.Errorf("failed to create firewall manager: %s", err)
|
||||
// no firewall manager found or initialized correctly
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newDefaultManager(fm), nil
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
@@ -32,13 +34,22 @@ func TestDefaultManager(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iface := mocks.NewMockIFaceMapper(ctrl)
|
||||
iface.EXPECT().IsUserspaceBind().Return(true)
|
||||
// iface.EXPECT().Name().Return("lo")
|
||||
iface.EXPECT().SetFilter(gomock.Any())
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true)
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse IP address: %v", err)
|
||||
}
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
acl, err := Create(iface)
|
||||
acl, err := Create(ifaceMock)
|
||||
if err != nil {
|
||||
t.Errorf("create ACL manager: %v", err)
|
||||
return
|
||||
@@ -311,13 +322,22 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
iface := mocks.NewMockIFaceMapper(ctrl)
|
||||
iface.EXPECT().IsUserspaceBind().Return(true)
|
||||
// iface.EXPECT().Name().Return("lo")
|
||||
iface.EXPECT().SetFilter(gomock.Any())
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true)
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse IP address: %v", err)
|
||||
}
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
|
||||
// we receive one rule from the management so for testing purposes ignore it
|
||||
acl, err := Create(iface)
|
||||
acl, err := Create(ifaceMock)
|
||||
if err != nil {
|
||||
t.Errorf("create ACL manager: %v", err)
|
||||
return
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
@@ -57,34 +57,50 @@ func (t TokenInfo) GetTokenToUse() string {
|
||||
return t.AccessToken
|
||||
}
|
||||
|
||||
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration.
|
||||
// NewOAuthFlow initializes and returns the appropriate OAuth flow based on the management configuration
|
||||
//
|
||||
// It starts by initializing the PKCE.If this process fails, it resorts to the Device Code Flow,
|
||||
// and if that also fails, the authentication process is deemed unsuccessful
|
||||
//
|
||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||
func NewOAuthFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||
log.Debug("getting device authorization flow info")
|
||||
|
||||
// Try to initialize the Device Authorization Flow
|
||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err == nil {
|
||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||
if runtime.GOOS == "linux" && !isLinuxRunningDesktop() {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
|
||||
log.Debugf("getting device authorization flow info failed with error: %v", err)
|
||||
log.Debugf("falling back to pkce authorization flow info")
|
||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
|
||||
if err != nil {
|
||||
// fallback to device code flow
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
return pkceFlow, nil
|
||||
}
|
||||
|
||||
// If Device Authorization Flow failed, try the PKCE Authorization Flow
|
||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
}
|
||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
|
||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
s, ok := gstatus.FromError(err)
|
||||
if ok && s.Code() == codes.NotFound {
|
||||
return nil, fmt.Errorf("no SSO provider returned from management. " +
|
||||
"If you are using hosting Netbird see documentation at " +
|
||||
"https://github.com/netbirdio/netbird/tree/main/management for details")
|
||||
"Please proceed with setting up this device using setup keys " +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
} else if ok && s.Code() == codes.Unimplemented {
|
||||
return nil, fmt.Errorf("the management server, %s, does not support SSO providers, "+
|
||||
"please update your server or use Setup Keys to login", config.ManagementURL)
|
||||
} else {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
return nil, fmt.Errorf("getting device authorization flow info failed with error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net"
|
||||
@@ -78,7 +79,7 @@ func (p *PKCEAuthorizationFlow) GetClientID(_ context.Context) string {
|
||||
}
|
||||
|
||||
// RequestAuthInfo requests a authorization code login flow information.
|
||||
func (p *PKCEAuthorizationFlow) RequestAuthInfo(_ context.Context) (AuthFlowInfo, error) {
|
||||
func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowInfo, error) {
|
||||
state, err := randomBytesInHex(24)
|
||||
if err != nil {
|
||||
return AuthFlowInfo{}, fmt.Errorf("could not generate random state: %v", err)
|
||||
@@ -112,60 +113,37 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
|
||||
tokenChan := make(chan *oauth2.Token, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go p.startServer(tokenChan, errChan)
|
||||
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
|
||||
if err != nil {
|
||||
return TokenInfo{}, fmt.Errorf("failed to parse redirect URL: %v", err)
|
||||
}
|
||||
|
||||
server := &http.Server{Addr: fmt.Sprintf(":%s", parsedURL.Port())}
|
||||
defer func() {
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
log.Errorf("failed to close the server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go p.startServer(server, tokenChan, errChan)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return TokenInfo{}, ctx.Err()
|
||||
case token := <-tokenChan:
|
||||
return p.handleOAuthToken(token)
|
||||
return p.parseOAuthToken(token)
|
||||
case err := <-errChan:
|
||||
return TokenInfo{}, err
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errChan chan<- error) {
|
||||
parsedURL, err := url.Parse(p.oAuthConfig.RedirectURL)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to parse redirect URL: %v", err)
|
||||
return
|
||||
}
|
||||
port := parsedURL.Port()
|
||||
|
||||
server := http.Server{Addr: fmt.Sprintf(":%s", port)}
|
||||
defer func() {
|
||||
if err := server.Shutdown(context.Background()); err != nil {
|
||||
log.Errorf("error while shutting down pkce flow server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
||||
tokenValidatorFunc := func() (*oauth2.Token, error) {
|
||||
query := req.URL.Query()
|
||||
|
||||
if authError := query.Get(queryError); authError != "" {
|
||||
authErrorDesc := query.Get(queryErrorDesc)
|
||||
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
||||
}
|
||||
|
||||
// Prevent timing attacks on state
|
||||
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
||||
return nil, fmt.Errorf("invalid state")
|
||||
}
|
||||
|
||||
code := query.Get(queryCode)
|
||||
if code == "" {
|
||||
return nil, fmt.Errorf("missing code")
|
||||
}
|
||||
|
||||
return p.oAuthConfig.Exchange(
|
||||
req.Context(),
|
||||
code,
|
||||
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
|
||||
)
|
||||
}
|
||||
|
||||
token, err := tokenValidatorFunc()
|
||||
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
|
||||
token, err := p.handleRequest(req)
|
||||
if err != nil {
|
||||
renderPKCEFlowTmpl(w, err)
|
||||
errChan <- fmt.Errorf("PKCE authorization flow failed: %v", err)
|
||||
@@ -176,12 +154,38 @@ func (p *PKCEAuthorizationFlow) startServer(tokenChan chan<- *oauth2.Token, errC
|
||||
tokenChan <- token
|
||||
})
|
||||
|
||||
if err := server.ListenAndServe(); err != nil {
|
||||
server.Handler = mux
|
||||
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
errChan <- err
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PKCEAuthorizationFlow) handleOAuthToken(token *oauth2.Token) (TokenInfo, error) {
|
||||
func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, error) {
|
||||
query := req.URL.Query()
|
||||
|
||||
if authError := query.Get(queryError); authError != "" {
|
||||
authErrorDesc := query.Get(queryErrorDesc)
|
||||
return nil, fmt.Errorf("%s.%s", authError, authErrorDesc)
|
||||
}
|
||||
|
||||
// Prevent timing attacks on the state
|
||||
if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 {
|
||||
return nil, fmt.Errorf("invalid state")
|
||||
}
|
||||
|
||||
code := query.Get(queryCode)
|
||||
if code == "" {
|
||||
return nil, fmt.Errorf("missing code")
|
||||
}
|
||||
|
||||
return p.oAuthConfig.Exchange(
|
||||
req.Context(),
|
||||
code,
|
||||
oauth2.SetAuthURLParam("code_verifier", p.codeVerifier),
|
||||
)
|
||||
}
|
||||
|
||||
func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, error) {
|
||||
tokenInfo := TokenInfo{
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
@@ -60,3 +61,8 @@ func isValidAccessToken(token string, audience string) error {
|
||||
|
||||
return fmt.Errorf("invalid JWT token audience field")
|
||||
}
|
||||
|
||||
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
|
||||
func isLinuxRunningDesktop() bool {
|
||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||
}
|
||||
|
||||
3
client/internal/checkfw/check.go
Normal file
3
client/internal/checkfw/check.go
Normal file
@@ -0,0 +1,3 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package checkfw
|
||||
56
client/internal/checkfw/check_linux.go
Normal file
56
client/internal/checkfw/check_linux.go
Normal file
@@ -0,0 +1,56 @@
|
||||
//go:build !android
|
||||
|
||||
package checkfw
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
)
|
||||
|
||||
const (
|
||||
// UNKNOWN is the default value for the firewall type for unknown firewall type
|
||||
UNKNOWN FWType = iota
|
||||
// IPTABLES is the value for the iptables firewall type
|
||||
IPTABLES
|
||||
// IPTABLESWITHV6 is the value for the iptables firewall type with ipv6
|
||||
IPTABLESWITHV6
|
||||
// NFTABLES is the value for the nftables firewall type
|
||||
NFTABLES
|
||||
)
|
||||
|
||||
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
|
||||
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
// Check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||
func Check() FWType {
|
||||
nf := nftables.Conn{}
|
||||
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
|
||||
return NFTABLES
|
||||
}
|
||||
|
||||
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err == nil {
|
||||
if isIptablesClientAvailable(ip) {
|
||||
ipSupport := IPTABLES
|
||||
ipv6, ip6Err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if ip6Err == nil {
|
||||
if isIptablesClientAvailable(ipv6) {
|
||||
ipSupport = IPTABLESWITHV6
|
||||
}
|
||||
}
|
||||
return ipSupport
|
||||
}
|
||||
}
|
||||
|
||||
return UNKNOWN
|
||||
}
|
||||
|
||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||
_, err := client.ListChains("filter")
|
||||
return err == nil
|
||||
}
|
||||
@@ -23,9 +23,6 @@ func TestGetConfig(t *testing.T) {
|
||||
assert.Equal(t, config.ManagementURL.String(), DefaultManagementURL)
|
||||
assert.Equal(t, config.AdminURL.String(), DefaultAdminURL)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
managementURL := "https://test.management.url:33071"
|
||||
adminURL := "https://app.admin.url:443"
|
||||
path := filepath.Join(t.TempDir(), "config.json")
|
||||
|
||||
@@ -192,8 +192,6 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
|
||||
state.Set(StatusConnected)
|
||||
|
||||
statusRecorder.ClientStart()
|
||||
|
||||
<-engineCtx.Done()
|
||||
statusRecorder.ClientTeardown()
|
||||
|
||||
@@ -214,6 +212,7 @@ func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
return nil
|
||||
}
|
||||
|
||||
statusRecorder.ClientStart()
|
||||
err = backoff.Retry(operation, backOff)
|
||||
if err != nil {
|
||||
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
|
||||
|
||||
@@ -238,7 +238,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
hostUpdate := s.currentConfig
|
||||
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
|
||||
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
||||
hostUpdate.routeAll = false
|
||||
}
|
||||
|
||||
|
||||
@@ -777,7 +777,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
newNet, err := stdnet.NewNet(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("create stdnet: %v", err)
|
||||
return nil, nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet)
|
||||
|
||||
@@ -11,6 +11,9 @@ import (
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -24,10 +27,11 @@ type serviceViaListener struct {
|
||||
dnsMux *dns.ServeMux
|
||||
customAddr *netip.AddrPort
|
||||
server *dns.Server
|
||||
runtimeIP string
|
||||
runtimePort int
|
||||
listenIP string
|
||||
listenPort int
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
ebpfService ebpfMgr.Manager
|
||||
}
|
||||
|
||||
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
|
||||
@@ -43,6 +47,7 @@ func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *service
|
||||
UDPSize: 65535,
|
||||
},
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -55,13 +60,21 @@ func (s *serviceViaListener) Listen() error {
|
||||
}
|
||||
|
||||
var err error
|
||||
s.runtimeIP, s.runtimePort, err = s.evalRuntimeAddress()
|
||||
s.listenIP, s.listenPort, err = s.evalListenAddress()
|
||||
if err != nil {
|
||||
log.Errorf("failed to eval runtime address: %s", err)
|
||||
return err
|
||||
}
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
|
||||
|
||||
if s.shouldApplyPortFwd() {
|
||||
s.ebpfService = ebpf.GetEbpfManagerInstance()
|
||||
err = s.ebpfService.LoadDNSFwd(s.listenIP, s.listenPort)
|
||||
if err != nil {
|
||||
log.Warnf("failed to load DNS port forwarder, custom port may not work well on some Linux operating systems: %s", err)
|
||||
s.ebpfService = nil
|
||||
}
|
||||
}
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
@@ -69,9 +82,10 @@ func (s *serviceViaListener) Listen() error {
|
||||
|
||||
err := s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -90,6 +104,13 @@ func (s *serviceViaListener) Stop() {
|
||||
if err != nil {
|
||||
log.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
}
|
||||
|
||||
if s.ebpfService != nil {
|
||||
err = s.ebpfService.FreeDNSFwd()
|
||||
if err != nil {
|
||||
log.Errorf("stopping traffic forwarder returned an error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||
@@ -101,11 +122,18 @@ func (s *serviceViaListener) DeregisterMux(pattern string) {
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RuntimePort() int {
|
||||
return s.runtimePort
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if s.ebpfService != nil {
|
||||
return defaultPort
|
||||
} else {
|
||||
return s.listenPort
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RuntimeIP() string {
|
||||
return s.runtimeIP
|
||||
return s.listenIP
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) setListenerStatus(running bool) {
|
||||
@@ -136,10 +164,30 @@ func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
|
||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) evalRuntimeAddress() (string, int, error) {
|
||||
func (s *serviceViaListener) evalListenAddress() (string, int, error) {
|
||||
if s.customAddr != nil {
|
||||
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
|
||||
}
|
||||
|
||||
return s.getFirstListenerAvailable()
|
||||
}
|
||||
|
||||
// shouldApplyPortFwd decides whether to apply eBPF program to capture DNS traffic on port 53.
|
||||
// This is needed because on some operating systems if we start a DNS server not on a default port 53, the domain name
|
||||
// resolution won't work.
|
||||
// So, in case we are running on Linux and picked a non-default port (53) we should fall back to the eBPF solution that will capture
|
||||
// traffic on port 53 and forward it to a local DNS server running on 5053.
|
||||
func (s *serviceViaListener) shouldApplyPortFwd() bool {
|
||||
if runtime.GOOS != "linux" {
|
||||
return false
|
||||
}
|
||||
|
||||
if s.customAddr != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if s.listenPort == defaultPort {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -54,13 +54,16 @@ type bpfSpecs struct {
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfProgramSpecs struct {
|
||||
NbWgProxy *ebpf.ProgramSpec `ebpf:"nb_wg_proxy"`
|
||||
NbXdpProg *ebpf.ProgramSpec `ebpf:"nb_xdp_prog"`
|
||||
}
|
||||
|
||||
// bpfMapSpecs contains maps before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfMapSpecs struct {
|
||||
NbFeatures *ebpf.MapSpec `ebpf:"nb_features"`
|
||||
NbMapDnsIp *ebpf.MapSpec `ebpf:"nb_map_dns_ip"`
|
||||
NbMapDnsPort *ebpf.MapSpec `ebpf:"nb_map_dns_port"`
|
||||
NbWgProxySettingsMap *ebpf.MapSpec `ebpf:"nb_wg_proxy_settings_map"`
|
||||
}
|
||||
|
||||
@@ -83,11 +86,17 @@ func (o *bpfObjects) Close() error {
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfMaps struct {
|
||||
NbFeatures *ebpf.Map `ebpf:"nb_features"`
|
||||
NbMapDnsIp *ebpf.Map `ebpf:"nb_map_dns_ip"`
|
||||
NbMapDnsPort *ebpf.Map `ebpf:"nb_map_dns_port"`
|
||||
NbWgProxySettingsMap *ebpf.Map `ebpf:"nb_wg_proxy_settings_map"`
|
||||
}
|
||||
|
||||
func (m *bpfMaps) Close() error {
|
||||
return _BpfClose(
|
||||
m.NbFeatures,
|
||||
m.NbMapDnsIp,
|
||||
m.NbMapDnsPort,
|
||||
m.NbWgProxySettingsMap,
|
||||
)
|
||||
}
|
||||
@@ -96,12 +105,12 @@ func (m *bpfMaps) Close() error {
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfPrograms struct {
|
||||
NbWgProxy *ebpf.Program `ebpf:"nb_wg_proxy"`
|
||||
NbXdpProg *ebpf.Program `ebpf:"nb_xdp_prog"`
|
||||
}
|
||||
|
||||
func (p *bpfPrograms) Close() error {
|
||||
return _BpfClose(
|
||||
p.NbWgProxy,
|
||||
p.NbXdpProg,
|
||||
)
|
||||
}
|
||||
|
||||
BIN
client/internal/ebpf/ebpf/bpf_bpfeb.o
Normal file
BIN
client/internal/ebpf/ebpf/bpf_bpfeb.o
Normal file
Binary file not shown.
@@ -54,13 +54,16 @@ type bpfSpecs struct {
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfProgramSpecs struct {
|
||||
NbWgProxy *ebpf.ProgramSpec `ebpf:"nb_wg_proxy"`
|
||||
NbXdpProg *ebpf.ProgramSpec `ebpf:"nb_xdp_prog"`
|
||||
}
|
||||
|
||||
// bpfMapSpecs contains maps before they are loaded into the kernel.
|
||||
//
|
||||
// It can be passed ebpf.CollectionSpec.Assign.
|
||||
type bpfMapSpecs struct {
|
||||
NbFeatures *ebpf.MapSpec `ebpf:"nb_features"`
|
||||
NbMapDnsIp *ebpf.MapSpec `ebpf:"nb_map_dns_ip"`
|
||||
NbMapDnsPort *ebpf.MapSpec `ebpf:"nb_map_dns_port"`
|
||||
NbWgProxySettingsMap *ebpf.MapSpec `ebpf:"nb_wg_proxy_settings_map"`
|
||||
}
|
||||
|
||||
@@ -83,11 +86,17 @@ func (o *bpfObjects) Close() error {
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfMaps struct {
|
||||
NbFeatures *ebpf.Map `ebpf:"nb_features"`
|
||||
NbMapDnsIp *ebpf.Map `ebpf:"nb_map_dns_ip"`
|
||||
NbMapDnsPort *ebpf.Map `ebpf:"nb_map_dns_port"`
|
||||
NbWgProxySettingsMap *ebpf.Map `ebpf:"nb_wg_proxy_settings_map"`
|
||||
}
|
||||
|
||||
func (m *bpfMaps) Close() error {
|
||||
return _BpfClose(
|
||||
m.NbFeatures,
|
||||
m.NbMapDnsIp,
|
||||
m.NbMapDnsPort,
|
||||
m.NbWgProxySettingsMap,
|
||||
)
|
||||
}
|
||||
@@ -96,12 +105,12 @@ func (m *bpfMaps) Close() error {
|
||||
//
|
||||
// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign.
|
||||
type bpfPrograms struct {
|
||||
NbWgProxy *ebpf.Program `ebpf:"nb_wg_proxy"`
|
||||
NbXdpProg *ebpf.Program `ebpf:"nb_xdp_prog"`
|
||||
}
|
||||
|
||||
func (p *bpfPrograms) Close() error {
|
||||
return _BpfClose(
|
||||
p.NbWgProxy,
|
||||
p.NbXdpProg,
|
||||
)
|
||||
}
|
||||
|
||||
BIN
client/internal/ebpf/ebpf/bpf_bpfel.o
Normal file
BIN
client/internal/ebpf/ebpf/bpf_bpfel.o
Normal file
Binary file not shown.
51
client/internal/ebpf/ebpf/dns_fwd_linux.go
Normal file
51
client/internal/ebpf/ebpf/dns_fwd_linux.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
mapKeyDNSIP uint32 = 0
|
||||
mapKeyDNSPort uint32 = 1
|
||||
)
|
||||
|
||||
func (tf *GeneralManager) LoadDNSFwd(ip string, dnsPort int) error {
|
||||
log.Debugf("load ebpf DNS forwarder: address: %s:%d", ip, dnsPort)
|
||||
tf.lock.Lock()
|
||||
defer tf.lock.Unlock()
|
||||
|
||||
err := tf.loadXdp()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tf.bpfObjs.NbMapDnsIp.Put(mapKeyDNSIP, ip2int(ip))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tf.bpfObjs.NbMapDnsPort.Put(mapKeyDNSPort, uint16(dnsPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tf.setFeatureFlag(featureFlagDnsForwarder)
|
||||
err = tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tf *GeneralManager) FreeDNSFwd() error {
|
||||
log.Debugf("free ebpf DNS forwarder")
|
||||
return tf.unsetFeatureFlag(featureFlagDnsForwarder)
|
||||
}
|
||||
|
||||
func ip2int(ipString string) uint32 {
|
||||
ip := net.ParseIP(ipString)
|
||||
return binary.BigEndian.Uint32(ip.To4())
|
||||
}
|
||||
116
client/internal/ebpf/ebpf/manager_linux.go
Normal file
116
client/internal/ebpf/ebpf/manager_linux.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/cilium/ebpf/link"
|
||||
"github.com/cilium/ebpf/rlimit"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
mapKeyFeatures uint32 = 0
|
||||
|
||||
featureFlagWGProxy = 0b00000001
|
||||
featureFlagDnsForwarder = 0b00000010
|
||||
)
|
||||
|
||||
var (
|
||||
singleton manager.Manager
|
||||
singletonLock = &sync.Mutex{}
|
||||
)
|
||||
|
||||
// required packages libbpf-dev, libc6-dev-i386-amd64-cross
|
||||
|
||||
// GeneralManager is used to load multiple eBPF programs with a custom check (if then) done in prog.c
|
||||
// The manager simply adds a feature (byte) of each program to a map that is shared between the userspace and kernel.
|
||||
// When packet arrives, the C code checks for each feature (if it is set) and executes each enabled program (e.g., dns_fwd.c and wg_proxy.c).
|
||||
//
|
||||
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang-14 bpf src/prog.c -- -I /usr/x86_64-linux-gnu/include
|
||||
type GeneralManager struct {
|
||||
lock sync.Mutex
|
||||
link link.Link
|
||||
featureFlags uint16
|
||||
bpfObjs bpfObjects
|
||||
}
|
||||
|
||||
// GetEbpfManagerInstance return a static eBpf Manager instance
|
||||
func GetEbpfManagerInstance() manager.Manager {
|
||||
singletonLock.Lock()
|
||||
defer singletonLock.Unlock()
|
||||
if singleton != nil {
|
||||
return singleton
|
||||
}
|
||||
singleton = &GeneralManager{}
|
||||
return singleton
|
||||
}
|
||||
|
||||
func (tf *GeneralManager) setFeatureFlag(feature uint16) {
|
||||
tf.featureFlags = tf.featureFlags | feature
|
||||
}
|
||||
|
||||
func (tf *GeneralManager) loadXdp() error {
|
||||
if tf.link != nil {
|
||||
return nil
|
||||
}
|
||||
// it required for Docker
|
||||
err := rlimit.RemoveMemlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
iFace, err := net.InterfaceByName("lo")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// load pre-compiled programs into the kernel.
|
||||
err = loadBpfObjects(&tf.bpfObjs, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tf.link, err = link.AttachXDP(link.XDPOptions{
|
||||
Program: tf.bpfObjs.NbXdpProg,
|
||||
Interface: iFace.Index,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
_ = tf.bpfObjs.Close()
|
||||
tf.link = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tf *GeneralManager) unsetFeatureFlag(feature uint16) error {
|
||||
tf.lock.Lock()
|
||||
defer tf.lock.Unlock()
|
||||
tf.featureFlags &^= feature
|
||||
|
||||
if tf.link == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if tf.featureFlags == 0 {
|
||||
return tf.close()
|
||||
}
|
||||
|
||||
return tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
|
||||
}
|
||||
|
||||
func (tf *GeneralManager) close() error {
|
||||
log.Debugf("detach ebpf program ")
|
||||
err := tf.bpfObjs.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed to close eBpf objects: %s", err)
|
||||
}
|
||||
|
||||
err = tf.link.Close()
|
||||
tf.link = nil
|
||||
return err
|
||||
}
|
||||
40
client/internal/ebpf/ebpf/manager_linux_test.go
Normal file
40
client/internal/ebpf/ebpf/manager_linux_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestManager_setFeatureFlag(t *testing.T) {
|
||||
mgr := GeneralManager{}
|
||||
mgr.setFeatureFlag(featureFlagWGProxy)
|
||||
if mgr.featureFlags != 1 {
|
||||
t.Errorf("invalid faeture state")
|
||||
}
|
||||
|
||||
mgr.setFeatureFlag(featureFlagDnsForwarder)
|
||||
if mgr.featureFlags != 3 {
|
||||
t.Errorf("invalid faeture state")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_unsetFeatureFlag(t *testing.T) {
|
||||
mgr := GeneralManager{}
|
||||
mgr.setFeatureFlag(featureFlagWGProxy)
|
||||
mgr.setFeatureFlag(featureFlagDnsForwarder)
|
||||
|
||||
err := mgr.unsetFeatureFlag(featureFlagWGProxy)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %s", err)
|
||||
}
|
||||
if mgr.featureFlags != 2 {
|
||||
t.Errorf("invalid faeture state, expected: %d, got: %d", 2, mgr.featureFlags)
|
||||
}
|
||||
|
||||
err = mgr.unsetFeatureFlag(featureFlagDnsForwarder)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %s", err)
|
||||
}
|
||||
if mgr.featureFlags != 0 {
|
||||
t.Errorf("invalid faeture state, expected: %d, got: %d", 0, mgr.featureFlags)
|
||||
}
|
||||
}
|
||||
64
client/internal/ebpf/ebpf/src/dns_fwd.c
Normal file
64
client/internal/ebpf/ebpf/src/dns_fwd.c
Normal file
@@ -0,0 +1,64 @@
|
||||
const __u32 map_key_dns_ip = 0;
|
||||
const __u32 map_key_dns_port = 1;
|
||||
|
||||
struct bpf_map_def SEC("maps") nb_map_dns_ip = {
|
||||
.type = BPF_MAP_TYPE_ARRAY,
|
||||
.key_size = sizeof(__u32),
|
||||
.value_size = sizeof(__u32),
|
||||
.max_entries = 10,
|
||||
};
|
||||
|
||||
struct bpf_map_def SEC("maps") nb_map_dns_port = {
|
||||
.type = BPF_MAP_TYPE_ARRAY,
|
||||
.key_size = sizeof(__u32),
|
||||
.value_size = sizeof(__u16),
|
||||
.max_entries = 10,
|
||||
};
|
||||
|
||||
__be32 dns_ip = 0;
|
||||
__be16 dns_port = 0;
|
||||
|
||||
// 13568 is 53 in big endian
|
||||
__be16 GENERAL_DNS_PORT = 13568;
|
||||
|
||||
bool read_settings() {
|
||||
__u16 *port_value;
|
||||
__u32 *ip_value;
|
||||
|
||||
// read dns ip
|
||||
ip_value = bpf_map_lookup_elem(&nb_map_dns_ip, &map_key_dns_ip);
|
||||
if(!ip_value) {
|
||||
return false;
|
||||
}
|
||||
dns_ip = htonl(*ip_value);
|
||||
|
||||
// read dns port
|
||||
port_value = bpf_map_lookup_elem(&nb_map_dns_port, &map_key_dns_port);
|
||||
if (!port_value) {
|
||||
return false;
|
||||
}
|
||||
dns_port = htons(*port_value);
|
||||
return true;
|
||||
}
|
||||
|
||||
int xdp_dns_fwd(struct iphdr *ip, struct udphdr *udp) {
|
||||
if (dns_port == 0) {
|
||||
if(!read_settings()){
|
||||
return XDP_PASS;
|
||||
}
|
||||
bpf_printk("dns port: %d", ntohs(dns_port));
|
||||
bpf_printk("dns ip: %d", ntohl(dns_ip));
|
||||
}
|
||||
|
||||
if (udp->dest == GENERAL_DNS_PORT && ip->daddr == dns_ip) {
|
||||
udp->dest = dns_port;
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
if (udp->source == dns_port && ip->saddr == dns_ip) {
|
||||
udp->source = GENERAL_DNS_PORT;
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
return XDP_PASS;
|
||||
}
|
||||
66
client/internal/ebpf/ebpf/src/prog.c
Normal file
66
client/internal/ebpf/ebpf/src/prog.c
Normal file
@@ -0,0 +1,66 @@
|
||||
#include <stdbool.h>
|
||||
#include <linux/if_ether.h> // ETH_P_IP
|
||||
#include <linux/udp.h>
|
||||
#include <linux/ip.h>
|
||||
#include <netinet/in.h>
|
||||
#include <linux/bpf.h>
|
||||
#include <bpf/bpf_helpers.h>
|
||||
#include "dns_fwd.c"
|
||||
#include "wg_proxy.c"
|
||||
|
||||
#define bpf_printk(fmt, ...) \
|
||||
({ \
|
||||
char ____fmt[] = fmt; \
|
||||
bpf_trace_printk(____fmt, sizeof(____fmt), ##__VA_ARGS__); \
|
||||
})
|
||||
|
||||
const __u16 flag_feature_wg_proxy = 0b01;
|
||||
const __u16 flag_feature_dns_fwd = 0b10;
|
||||
|
||||
const __u32 map_key_features = 0;
|
||||
struct bpf_map_def SEC("maps") nb_features = {
|
||||
.type = BPF_MAP_TYPE_ARRAY,
|
||||
.key_size = sizeof(__u32),
|
||||
.value_size = sizeof(__u16),
|
||||
.max_entries = 10,
|
||||
};
|
||||
|
||||
SEC("xdp")
|
||||
int nb_xdp_prog(struct xdp_md *ctx) {
|
||||
__u16 *features;
|
||||
features = bpf_map_lookup_elem(&nb_features, &map_key_features);
|
||||
if (!features) {
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
void *data = (void *)(long)ctx->data;
|
||||
void *data_end = (void *)(long)ctx->data_end;
|
||||
struct ethhdr *eth = data;
|
||||
struct iphdr *ip = (data + sizeof(struct ethhdr));
|
||||
struct udphdr *udp = (data + sizeof(struct ethhdr) + sizeof(struct iphdr));
|
||||
|
||||
// return early if not enough data
|
||||
if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) + sizeof(struct udphdr) > data_end){
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
// skip non IPv4 packages
|
||||
if (eth->h_proto != htons(ETH_P_IP)) {
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
// skip non UPD packages
|
||||
if (ip->protocol != IPPROTO_UDP) {
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
if (*features & flag_feature_dns_fwd) {
|
||||
xdp_dns_fwd(ip, udp);
|
||||
}
|
||||
|
||||
if (*features & flag_feature_wg_proxy) {
|
||||
xdp_wg_proxy(ip, udp);
|
||||
}
|
||||
return XDP_PASS;
|
||||
}
|
||||
char _license[] SEC("license") = "GPL";
|
||||
54
client/internal/ebpf/ebpf/src/wg_proxy.c
Normal file
54
client/internal/ebpf/ebpf/src/wg_proxy.c
Normal file
@@ -0,0 +1,54 @@
|
||||
const __u32 map_key_proxy_port = 0;
|
||||
const __u32 map_key_wg_port = 1;
|
||||
|
||||
struct bpf_map_def SEC("maps") nb_wg_proxy_settings_map = {
|
||||
.type = BPF_MAP_TYPE_ARRAY,
|
||||
.key_size = sizeof(__u32),
|
||||
.value_size = sizeof(__u16),
|
||||
.max_entries = 10,
|
||||
};
|
||||
|
||||
__u16 proxy_port = 0;
|
||||
__u16 wg_port = 0;
|
||||
|
||||
bool read_port_settings() {
|
||||
__u16 *value;
|
||||
value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_proxy_port);
|
||||
if (!value) {
|
||||
return false;
|
||||
}
|
||||
|
||||
proxy_port = *value;
|
||||
|
||||
value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_wg_port);
|
||||
if (!value) {
|
||||
return false;
|
||||
}
|
||||
wg_port = htons(*value);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int xdp_wg_proxy(struct iphdr *ip, struct udphdr *udp) {
|
||||
if (proxy_port == 0 || wg_port == 0) {
|
||||
if (!read_port_settings()){
|
||||
return XDP_PASS;
|
||||
}
|
||||
bpf_printk("proxy port: %d, wg port: %d", proxy_port, wg_port);
|
||||
}
|
||||
|
||||
// 2130706433 = 127.0.0.1
|
||||
if (ip->daddr != htonl(2130706433)) {
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
if (udp->source != wg_port){
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
__be16 new_src_port = udp->dest;
|
||||
__be16 new_dst_port = htons(proxy_port);
|
||||
udp->dest = new_dst_port;
|
||||
udp->source = new_src_port;
|
||||
return XDP_PASS;
|
||||
}
|
||||
41
client/internal/ebpf/ebpf/wg_proxy_linux.go
Normal file
41
client/internal/ebpf/ebpf/wg_proxy_linux.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package ebpf
|
||||
|
||||
import log "github.com/sirupsen/logrus"
|
||||
|
||||
const (
|
||||
mapKeyProxyPort uint32 = 0
|
||||
mapKeyWgPort uint32 = 1
|
||||
)
|
||||
|
||||
func (tf *GeneralManager) LoadWgProxy(proxyPort, wgPort int) error {
|
||||
log.Debugf("load ebpf WG proxy")
|
||||
tf.lock.Lock()
|
||||
defer tf.lock.Unlock()
|
||||
|
||||
err := tf.loadXdp()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tf.bpfObjs.NbWgProxySettingsMap.Put(mapKeyProxyPort, uint16(proxyPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tf.bpfObjs.NbWgProxySettingsMap.Put(mapKeyWgPort, uint16(wgPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tf.setFeatureFlag(featureFlagWGProxy)
|
||||
err = tf.bpfObjs.NbFeatures.Put(mapKeyFeatures, tf.featureFlags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tf *GeneralManager) FreeWGProxy() error {
|
||||
log.Debugf("free ebpf WG proxy")
|
||||
return tf.unsetFeatureFlag(featureFlagWGProxy)
|
||||
}
|
||||
15
client/internal/ebpf/instantiater_linux.go
Normal file
15
client/internal/ebpf/instantiater_linux.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build !android
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf/ebpf"
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
)
|
||||
|
||||
// GetEbpfManagerInstance is a wrapper function. This encapsulation is required because if the code import the internal
|
||||
// ebpf package the Go compiler will include the object files. But it is not supported on Android. It can cause instant
|
||||
// panic on older Android version.
|
||||
func GetEbpfManagerInstance() manager.Manager {
|
||||
return ebpf.GetEbpfManagerInstance()
|
||||
}
|
||||
10
client/internal/ebpf/instantiater_nonlinux.go
Normal file
10
client/internal/ebpf/instantiater_nonlinux.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build !linux || android
|
||||
|
||||
package ebpf
|
||||
|
||||
import "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
|
||||
// GetEbpfManagerInstance return error because ebpf is not supported on all os
|
||||
func GetEbpfManagerInstance() manager.Manager {
|
||||
panic("unsupported os")
|
||||
}
|
||||
9
client/internal/ebpf/manager/manager.go
Normal file
9
client/internal/ebpf/manager/manager.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package manager
|
||||
|
||||
// Manager is used to load multiple eBPF programs. E.g., current DNS programs and WireGuard proxy
|
||||
type Manager interface {
|
||||
LoadDNSFwd(ip string, dnsPort int) error
|
||||
FreeDNSFwd() error
|
||||
LoadWgProxy(proxyPort, wgPort int) error
|
||||
FreeWGProxy() error
|
||||
}
|
||||
@@ -995,14 +995,12 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
log.Warnf("invalid external IP, %s, ignoring external IP mapping '%s'", external, mapping)
|
||||
break
|
||||
}
|
||||
if externalIP != nil {
|
||||
mappedIP := externalIP.String()
|
||||
if internalIP != nil {
|
||||
mappedIP = mappedIP + "/" + internalIP.String()
|
||||
}
|
||||
mappedIPs = append(mappedIPs, mappedIP)
|
||||
log.Infof("parsed external IP mapping of '%s' as '%s'", mapping, mappedIP)
|
||||
mappedIP := externalIP.String()
|
||||
if internalIP != nil {
|
||||
mappedIP = mappedIP + "/" + internalIP.String()
|
||||
}
|
||||
mappedIPs = append(mappedIPs, mappedIP)
|
||||
log.Infof("parsed external IP mapping of '%s' as '%s'", mapping, mappedIP)
|
||||
}
|
||||
if len(mappedIPs) != len(e.config.NATExternalIPs) {
|
||||
log.Warnf("one or more external IP mappings failed to parse, ignoring all mappings")
|
||||
|
||||
@@ -1046,15 +1046,15 @@ func startManagement(dataDir string) (*grpc.Server, string, error) {
|
||||
peersUpdateManager := server.NewPeersUpdateManager()
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", nil
|
||||
return nil, "", err
|
||||
}
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
|
||||
eventStore)
|
||||
eventStore, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
|
||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil)
|
||||
mgmtServer, err := server.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func (n *notifier) clientStart() {
|
||||
n.serverStateLock.Lock()
|
||||
defer n.serverStateLock.Unlock()
|
||||
n.currentClientState = true
|
||||
n.lastNotification = stateConnected
|
||||
n.lastNotification = stateConnecting
|
||||
n.notify(n.lastNotification)
|
||||
}
|
||||
|
||||
@@ -114,7 +114,7 @@ func (n *notifier) calculateState(managementConn, signalConn bool) int {
|
||||
return stateConnected
|
||||
}
|
||||
|
||||
if !managementConn && !signalConn {
|
||||
if !managementConn && !signalConn && !n.currentClientState {
|
||||
return stateDisconnected
|
||||
}
|
||||
|
||||
|
||||
@@ -155,7 +155,10 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
||||
|
||||
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
||||
state, err := c.statusRecorder.GetPeer(peerKey)
|
||||
if err != nil || state.ConnStatus != peer.StatusConnected {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if state.ConnStatus != peer.StatusConnected {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/checkfw"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -26,15 +28,20 @@ func genKey(format string, input string) string {
|
||||
return fmt.Sprintf(format, input)
|
||||
}
|
||||
|
||||
// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
|
||||
func NewFirewall(parentCTX context.Context) firewallManager {
|
||||
manager, err := newNFTablesManager(parentCTX)
|
||||
if err == nil {
|
||||
log.Debugf("nftables firewall manager will be used")
|
||||
return manager
|
||||
// newFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
|
||||
func newFirewall(parentCTX context.Context) (firewallManager, error) {
|
||||
checkResult := checkfw.Check()
|
||||
switch checkResult {
|
||||
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
|
||||
log.Debug("creating an iptables firewall manager for route rules")
|
||||
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
|
||||
return newIptablesManager(parentCTX, ipv6Supported)
|
||||
case checkfw.NFTABLES:
|
||||
log.Info("creating an nftables firewall manager for route rules")
|
||||
return newNFTablesManager(parentCTX), nil
|
||||
}
|
||||
log.Debugf("fallback to iptables firewall manager: %s", err)
|
||||
return newIptablesManager(parentCTX)
|
||||
|
||||
return nil, fmt.Errorf("couldn't initialize nftables or iptables clients. Using a dummy firewall manager for route rules")
|
||||
}
|
||||
|
||||
func getInPair(pair routerPair) routerPair {
|
||||
|
||||
@@ -3,24 +3,13 @@
|
||||
|
||||
package routemanager
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
type unimplementedFirewall struct{}
|
||||
|
||||
func (unimplementedFirewall) RestoreOrCreateContainers() error {
|
||||
return nil
|
||||
}
|
||||
func (unimplementedFirewall) InsertRoutingRules(pair routerPair) error {
|
||||
return nil
|
||||
}
|
||||
func (unimplementedFirewall) RemoveRoutingRules(pair routerPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (unimplementedFirewall) CleanRoutingRules() {
|
||||
}
|
||||
|
||||
// NewFirewall returns an unimplemented Firewall manager
|
||||
func NewFirewall(parentCtx context.Context) firewallManager {
|
||||
return unimplementedFirewall{}
|
||||
// newFirewall returns a nil manager
|
||||
func newFirewall(context.Context) (firewallManager, error) {
|
||||
return nil, fmt.Errorf("firewall not supported on %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
@@ -49,30 +49,28 @@ type iptablesManager struct {
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newIptablesManager(parentCtx context.Context) *iptablesManager {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
func newIptablesManager(parentCtx context.Context, ipv6Supported bool) (*iptablesManager, error) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
log.Debugf("failed to initialize iptables for ipv4: %s", err)
|
||||
} else if !isIptablesClientAvailable(ipv4Client) {
|
||||
log.Infof("iptables is missing for ipv4")
|
||||
ipv4Client = nil
|
||||
}
|
||||
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if err != nil {
|
||||
log.Debugf("failed to initialize iptables for ipv6: %s", err)
|
||||
} else if !isIptablesClientAvailable(ipv6Client) {
|
||||
log.Infof("iptables is missing for ipv6")
|
||||
ipv6Client = nil
|
||||
return nil, fmt.Errorf("failed to initialize iptables for ipv4: %s", err)
|
||||
}
|
||||
|
||||
return &iptablesManager{
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
manager := &iptablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
ipv4Client: ipv4Client,
|
||||
ipv6Client: ipv6Client,
|
||||
rules: make(map[string]map[string][]string),
|
||||
}
|
||||
|
||||
if ipv6Supported {
|
||||
manager.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if err != nil {
|
||||
log.Warnf("failed to initialize iptables for ipv6: %s. Routes for this protocol won't be applied.", err)
|
||||
}
|
||||
}
|
||||
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
// CleanRoutingRules cleans existing iptables resources that we created by the agent
|
||||
@@ -395,6 +393,10 @@ func (i *iptablesManager) insertRoutingRule(keyFormat, table, chain, jump string
|
||||
ipVersion = ipv6
|
||||
}
|
||||
|
||||
if iptablesClient == nil {
|
||||
return fmt.Errorf("unable to insert iptables routing rules. Iptables client is not initialized")
|
||||
}
|
||||
|
||||
ruleKey := genKey(keyFormat, pair.ID)
|
||||
rule := genRuleSpec(jump, ruleKey, pair.source, pair.destination)
|
||||
existingRule, found := i.rules[ipVersion][ruleKey]
|
||||
@@ -459,6 +461,10 @@ func (i *iptablesManager) removeRoutingRule(keyFormat, table, chain string, pair
|
||||
ipVersion = ipv6
|
||||
}
|
||||
|
||||
if iptablesClient == nil {
|
||||
return fmt.Errorf("unable to remove iptables routing rules. Iptables client is not initialized")
|
||||
}
|
||||
|
||||
ruleKey := genKey(keyFormat, pair.ID)
|
||||
existingRule, found := i.rules[ipVersion][ruleKey]
|
||||
if found {
|
||||
@@ -479,8 +485,3 @@ func getIptablesRuleType(table string) string {
|
||||
}
|
||||
return ruleType
|
||||
}
|
||||
|
||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||
_, err := client.ListChains("filter")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
@@ -16,11 +16,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
manager := newIptablesManager(context.TODO())
|
||||
manager, err := newIptablesManager(context.TODO(), true)
|
||||
require.NoError(t, err, "should return a valid iptables manager")
|
||||
|
||||
defer manager.CleanRoutingRules()
|
||||
|
||||
err := manager.RestoreOrCreateContainers()
|
||||
err = manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6")
|
||||
|
||||
@@ -27,7 +27,7 @@ type DefaultManager struct {
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
clientNetworks map[string]*clientNetwork
|
||||
serverRouter *serverRouter
|
||||
serverRouter serverRouter
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
pubKey string
|
||||
@@ -36,13 +36,17 @@ type DefaultManager struct {
|
||||
|
||||
// NewManager returns a new route manager
|
||||
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(ctx)
|
||||
srvRouter, err := newServerRouter(ctx, wgInterface)
|
||||
if err != nil {
|
||||
log.Errorf("server router is not supported: %s", err)
|
||||
}
|
||||
|
||||
mCTX, cancel := context.WithCancel(ctx)
|
||||
dm := &DefaultManager{
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
clientNetworks: make(map[string]*clientNetwork),
|
||||
serverRouter: newServerRouter(ctx, wgInterface),
|
||||
serverRouter: srvRouter,
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
@@ -59,7 +63,9 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
||||
// Stop stops the manager watchers and clean firewall rules
|
||||
func (m *DefaultManager) Stop() {
|
||||
m.stop()
|
||||
m.serverRouter.cleanUp()
|
||||
if m.serverRouter != nil {
|
||||
m.serverRouter.cleanUp()
|
||||
}
|
||||
m.ctx = nil
|
||||
}
|
||||
|
||||
@@ -77,9 +83,12 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
||||
|
||||
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
|
||||
m.notifier.onNewRoutes(newClientRoutesIDMap)
|
||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
if m.serverRouter != nil {
|
||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -3,11 +3,12 @@ package routemanager
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/pion/transport/v2/stdnet"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/transport/v2/stdnet"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@@ -30,7 +31,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
inputInitRoutes []*route.Route
|
||||
inputRoutes []*route.Route
|
||||
inputSerial uint64
|
||||
shouldCheckServerRoutes bool
|
||||
removeSrvRouter bool
|
||||
serverRoutesExpected int
|
||||
clientNetworkWatchersExpected int
|
||||
}{
|
||||
@@ -87,7 +88,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
shouldCheckServerRoutes: runtime.GOOS == "linux",
|
||||
serverRoutesExpected: 2,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
},
|
||||
@@ -116,10 +116,38 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
shouldCheckServerRoutes: runtime.GOOS == "linux",
|
||||
serverRoutesExpected: 1,
|
||||
clientNetworkWatchersExpected: 1,
|
||||
},
|
||||
{
|
||||
name: "Should Create 1 Route For Client and Skip Server Route On Empty Server Router",
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("100.64.30.250/30"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
{
|
||||
ID: "b",
|
||||
NetID: "routeB",
|
||||
Peer: remotePeerKey1,
|
||||
Network: netip.MustParsePrefix("8.8.9.9/32"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
removeSrvRouter: true,
|
||||
serverRoutesExpected: 0,
|
||||
clientNetworkWatchersExpected: 1,
|
||||
},
|
||||
{
|
||||
name: "Should Create 1 HA Route and 1 Standalone",
|
||||
inputRoutes: []*route.Route{
|
||||
@@ -174,25 +202,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
inputSerial: 1,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
},
|
||||
{
|
||||
name: "No Server Routes Should Be Added To Non Linux",
|
||||
inputRoutes: []*route.Route{
|
||||
{
|
||||
ID: "a",
|
||||
NetID: "routeA",
|
||||
Peer: localPeerKey,
|
||||
Network: netip.MustParsePrefix("1.2.3.4/32"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Metric: 9999,
|
||||
Masquerade: false,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
shouldCheckServerRoutes: runtime.GOOS != "linux",
|
||||
serverRoutesExpected: 0,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
},
|
||||
{
|
||||
name: "Remove 1 Client Route",
|
||||
inputInitRoutes: []*route.Route{
|
||||
@@ -335,7 +344,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
},
|
||||
inputRoutes: []*route.Route{},
|
||||
inputSerial: 1,
|
||||
shouldCheckServerRoutes: true,
|
||||
serverRoutesExpected: 0,
|
||||
clientNetworkWatchersExpected: 0,
|
||||
},
|
||||
@@ -384,7 +392,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
},
|
||||
},
|
||||
inputSerial: 1,
|
||||
shouldCheckServerRoutes: runtime.GOOS == "linux",
|
||||
serverRoutesExpected: 2,
|
||||
clientNetworkWatchersExpected: 1,
|
||||
},
|
||||
@@ -409,6 +416,10 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
|
||||
defer routeManager.Stop()
|
||||
|
||||
if testCase.removeSrvRouter {
|
||||
routeManager.serverRouter = nil
|
||||
}
|
||||
|
||||
if len(testCase.inputInitRoutes) > 0 {
|
||||
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
||||
require.NoError(t, err, "should update routes with init routes")
|
||||
@@ -419,8 +430,9 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
|
||||
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
|
||||
|
||||
if testCase.shouldCheckServerRoutes {
|
||||
require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match")
|
||||
if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
|
||||
sr := routeManager.serverRouter.(*defaultServerRouter)
|
||||
require.Len(t, sr.routes, testCase.serverRoutesExpected, "server networks size should match")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -86,10 +86,10 @@ type nftablesManager struct {
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
|
||||
func newNFTablesManager(parentCtx context.Context) *nftablesManager {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
|
||||
mgr := &nftablesManager{
|
||||
return &nftablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
@@ -97,18 +97,6 @@ func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
defaultForwardRules: make([]*nftables.Rule, 2),
|
||||
}
|
||||
|
||||
err := mgr.isSupported()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = mgr.readFilterTable()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
// CleanRoutingRules cleans existing nftables rules from the system
|
||||
@@ -147,6 +135,10 @@ func (n *nftablesManager) RestoreOrCreateContainers() error {
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
if table.Name == "filter" {
|
||||
n.filterTable = table
|
||||
continue
|
||||
}
|
||||
if table.Name == nftablesTable {
|
||||
if table.Family == nftables.TableFamilyIPv4 {
|
||||
n.tableIPv4 = table
|
||||
@@ -259,21 +251,6 @@ func (n *nftablesManager) refreshRulesMap() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *nftablesManager) readFilterTable() error {
|
||||
tables, err := n.conn.ListTables()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, t := range tables {
|
||||
if t.Name == "filter" {
|
||||
n.filterTable = t
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *nftablesManager) eraseDefaultForwardRule() error {
|
||||
if n.defaultForwardRules[0] == nil {
|
||||
return nil
|
||||
@@ -544,14 +521,6 @@ func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *nftablesManager) isSupported() error {
|
||||
_, err := n.conn.ListChains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables is not supported: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getPayloadDirectives get expression directives based on ip version and direction
|
||||
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
|
||||
switch {
|
||||
|
||||
@@ -10,20 +10,23 @@ import (
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/checkfw"
|
||||
)
|
||||
|
||||
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
|
||||
manager, err := newNFTablesManager(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create nftables manager: %s", err)
|
||||
if checkfw.Check() != checkfw.NFTABLES {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
manager := newNFTablesManager(context.TODO())
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer manager.CleanRoutingRules()
|
||||
|
||||
err = manager.RestoreOrCreateContainers()
|
||||
err := manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
||||
@@ -126,19 +129,19 @@ func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
if checkfw.Check() != checkfw.NFTABLES {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
for _, testCase := range insertRuleTestCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
manager, err := newNFTablesManager(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create nftables manager: %s", err)
|
||||
}
|
||||
manager := newNFTablesManager(context.TODO())
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer manager.CleanRoutingRules()
|
||||
|
||||
err = manager.RestoreOrCreateContainers()
|
||||
err := manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.InsertRoutingRules(testCase.inputPair)
|
||||
@@ -226,19 +229,19 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
if checkfw.Check() != checkfw.NFTABLES {
|
||||
t.Skip("nftables not supported on this OS")
|
||||
}
|
||||
|
||||
for _, testCase := range removeRuleTestCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
manager, err := newNFTablesManager(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create nftables manager: %s", err)
|
||||
}
|
||||
manager := newNFTablesManager(context.TODO())
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer manager.CleanRoutingRules()
|
||||
|
||||
err = manager.RestoreOrCreateContainers()
|
||||
err := manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
table := manager.tableIPv4
|
||||
|
||||
9
client/internal/routemanager/server.go
Normal file
9
client/internal/routemanager/server.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package routemanager
|
||||
|
||||
import "github.com/netbirdio/netbird/route"
|
||||
|
||||
type serverRouter interface {
|
||||
updateRoutes(map[string]*route.Route) error
|
||||
removeFromServerNetwork(*route.Route) error
|
||||
cleanUp()
|
||||
}
|
||||
@@ -2,20 +2,11 @@ package routemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type serverRouter struct {
|
||||
func newServerRouter(context.Context, *iface.WGIface) (serverRouter, error) {
|
||||
return nil, fmt.Errorf("server route not supported on this os")
|
||||
}
|
||||
|
||||
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
|
||||
return &serverRouter{}
|
||||
}
|
||||
|
||||
func (r *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *serverRouter) cleanUp() {}
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type serverRouter struct {
|
||||
type defaultServerRouter struct {
|
||||
mux sync.Mutex
|
||||
ctx context.Context
|
||||
routes map[string]*route.Route
|
||||
@@ -21,16 +21,21 @@ type serverRouter struct {
|
||||
wgInterface *iface.WGIface
|
||||
}
|
||||
|
||||
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) *serverRouter {
|
||||
return &serverRouter{
|
||||
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRouter, error) {
|
||||
firewall, err := newFirewall(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &defaultServerRouter{
|
||||
ctx: ctx,
|
||||
routes: make(map[string]*route.Route),
|
||||
firewall: NewFirewall(ctx),
|
||||
firewall: firewall,
|
||||
wgInterface: wgInterface,
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
|
||||
func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) error {
|
||||
serverRoutesToRemove := make([]string, 0)
|
||||
|
||||
if len(routesMap) > 0 {
|
||||
@@ -81,7 +86,7 @@ func (m *serverRouter) updateRoutes(routesMap map[string]*route.Route) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
|
||||
func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not removing from server network because context is done")
|
||||
@@ -98,7 +103,7 @@ func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverRouter) addToServerNetwork(route *route.Route) error {
|
||||
func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not adding to server network because context is done")
|
||||
@@ -115,6 +120,6 @@ func (m *serverRouter) addToServerNetwork(route *route.Route) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverRouter) cleanUp() {
|
||||
func (m *defaultServerRouter) cleanUp() {
|
||||
m.firewall.CleanRoutingRules()
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ func InterfaceFilter(disallowList []string) func(string) bool {
|
||||
|
||||
for _, s := range disallowList {
|
||||
if strings.HasPrefix(iFace, s) {
|
||||
log.Debugf("ignoring interface %s - it is not allowed", iFace)
|
||||
log.Tracef("ignoring interface %s - it is not allowed", iFace)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -1,84 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"net"
|
||||
|
||||
"github.com/cilium/ebpf/link"
|
||||
"github.com/cilium/ebpf/rlimit"
|
||||
)
|
||||
|
||||
const (
|
||||
mapKeyProxyPort uint32 = 0
|
||||
mapKeyWgPort uint32 = 1
|
||||
)
|
||||
|
||||
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang-14 bpf src/portreplace.c --
|
||||
|
||||
// EBPF is a wrapper for eBPF program
|
||||
type EBPF struct {
|
||||
link link.Link
|
||||
}
|
||||
|
||||
// NewEBPF create new EBPF instance
|
||||
func NewEBPF() *EBPF {
|
||||
return &EBPF{}
|
||||
}
|
||||
|
||||
// Load load ebpf program
|
||||
func (l *EBPF) Load(proxyPort, wgPort int) error {
|
||||
// it required for Docker
|
||||
err := rlimit.RemoveMemlock()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ifce, err := net.InterfaceByName("lo")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Load pre-compiled programs into the kernel.
|
||||
objs := bpfObjects{}
|
||||
err = loadBpfObjects(&objs, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = objs.Close()
|
||||
}()
|
||||
|
||||
err = objs.NbWgProxySettingsMap.Put(mapKeyProxyPort, uint16(proxyPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = objs.NbWgProxySettingsMap.Put(mapKeyWgPort, uint16(wgPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = objs.NbWgProxySettingsMap.Close()
|
||||
}()
|
||||
|
||||
l.link, err = link.AttachXDP(link.XDPOptions{
|
||||
Program: objs.NbWgProxy,
|
||||
Interface: ifce.Index,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Free ebpf program
|
||||
func (l *EBPF) Free() error {
|
||||
if l.link != nil {
|
||||
return l.link.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package ebpf
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_newEBPF(t *testing.T) {
|
||||
ebpf := NewEBPF()
|
||||
err := ebpf.Load(1234, 51892)
|
||||
defer func() {
|
||||
_ = ebpf.Free()
|
||||
}()
|
||||
if err != nil {
|
||||
t.Errorf("%s", err)
|
||||
}
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
#include <stdbool.h>
|
||||
#include <linux/if_ether.h> // ETH_P_IP
|
||||
#include <linux/udp.h>
|
||||
#include <linux/ip.h>
|
||||
#include <netinet/in.h>
|
||||
#include <linux/bpf.h>
|
||||
#include <bpf/bpf_helpers.h>
|
||||
|
||||
#define bpf_printk(fmt, ...) \
|
||||
({ \
|
||||
char ____fmt[] = fmt; \
|
||||
bpf_trace_printk(____fmt, sizeof(____fmt), ##__VA_ARGS__); \
|
||||
})
|
||||
|
||||
const __u32 map_key_proxy_port = 0;
|
||||
const __u32 map_key_wg_port = 1;
|
||||
|
||||
struct bpf_map_def SEC("maps") nb_wg_proxy_settings_map = {
|
||||
.type = BPF_MAP_TYPE_ARRAY,
|
||||
.key_size = sizeof(__u32),
|
||||
.value_size = sizeof(__u16),
|
||||
.max_entries = 10,
|
||||
};
|
||||
|
||||
__u16 proxy_port = 0;
|
||||
__u16 wg_port = 0;
|
||||
|
||||
bool read_port_settings() {
|
||||
__u16 *value;
|
||||
value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_proxy_port);
|
||||
if(!value) {
|
||||
return false;
|
||||
}
|
||||
|
||||
proxy_port = *value;
|
||||
|
||||
value = bpf_map_lookup_elem(&nb_wg_proxy_settings_map, &map_key_wg_port);
|
||||
if(!value) {
|
||||
return false;
|
||||
}
|
||||
wg_port = *value;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
SEC("xdp")
|
||||
int nb_wg_proxy(struct xdp_md *ctx) {
|
||||
if(proxy_port == 0 || wg_port == 0) {
|
||||
if(!read_port_settings()){
|
||||
return XDP_PASS;
|
||||
}
|
||||
bpf_printk("proxy port: %d, wg port: %d", proxy_port, wg_port);
|
||||
}
|
||||
|
||||
void *data = (void *)(long)ctx->data;
|
||||
void *data_end = (void *)(long)ctx->data_end;
|
||||
struct ethhdr *eth = data;
|
||||
struct iphdr *ip = (data + sizeof(struct ethhdr));
|
||||
struct udphdr *udp = (data + sizeof(struct ethhdr) + sizeof(struct iphdr));
|
||||
|
||||
// return early if not enough data
|
||||
if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) + sizeof(struct udphdr) > data_end){
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
// skip non IPv4 packages
|
||||
if (eth->h_proto != htons(ETH_P_IP)) {
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
if (ip->protocol != IPPROTO_UDP) {
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
// 2130706433 = 127.0.0.1
|
||||
if (ip->daddr != htonl(2130706433)) {
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
if (udp->source != htons(wg_port)){
|
||||
return XDP_PASS;
|
||||
}
|
||||
|
||||
__be16 new_src_port = udp->dest;
|
||||
__be16 new_dst_port = htons(proxy_port);
|
||||
udp->dest = new_dst_port;
|
||||
udp->source = new_src_port;
|
||||
return XDP_PASS;
|
||||
}
|
||||
char _license[] SEC("license") = "GPL";
|
||||
@@ -12,15 +12,15 @@ import (
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
ebpf2 "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
|
||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||
)
|
||||
|
||||
// WGEBPFProxy definition for proxy with EBPF support
|
||||
type WGEBPFProxy struct {
|
||||
ebpf *ebpf2.EBPF
|
||||
ebpfManager ebpfMgr.Manager
|
||||
lastUsedPort uint16
|
||||
localWGListenPort int
|
||||
|
||||
@@ -36,7 +36,7 @@ func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
|
||||
log.Debugf("instantiate ebpf proxy")
|
||||
wgProxy := &WGEBPFProxy{
|
||||
localWGListenPort: wgPort,
|
||||
ebpf: ebpf2.NewEBPF(),
|
||||
ebpfManager: ebpf.GetEbpfManagerInstance(),
|
||||
lastUsedPort: 0,
|
||||
turnConnStore: make(map[uint16]net.Conn),
|
||||
}
|
||||
@@ -56,7 +56,7 @@ func (p *WGEBPFProxy) Listen() error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = p.ebpf.Load(wgPorxyPort, p.localWGListenPort)
|
||||
err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -110,7 +110,7 @@ func (p *WGEBPFProxy) Free() error {
|
||||
err1 = p.conn.Close()
|
||||
}
|
||||
|
||||
err2 = p.ebpf.Free()
|
||||
err2 = p.ebpfManager.FreeWGProxy()
|
||||
if p.rawConn != nil {
|
||||
err3 = p.rawConn.Close()
|
||||
}
|
||||
@@ -135,6 +135,7 @@ func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
|
||||
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
|
||||
}
|
||||
p.removeTurnConn(endpointPort)
|
||||
log.Infof("stop forward turn packages to port: %d. error: %s", endpointPort, err)
|
||||
return
|
||||
}
|
||||
err = p.sendPkg(buf[:n], endpointPort)
|
||||
@@ -158,7 +159,7 @@ func (p *WGEBPFProxy) proxyToRemote() {
|
||||
conn, ok := p.turnConnStore[uint16(addr.Port)]
|
||||
p.turnConnMutex.Unlock()
|
||||
if !ok {
|
||||
log.Errorf("turn conn not found by port: %d", addr.Port)
|
||||
log.Infof("turn conn not found by port: %d", addr.Port)
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user