Compare commits

...

16 Commits

Author SHA1 Message Date
Viktor Liu
0ef476b014 [client] Fix state dump panic (#3519) 2025-03-16 15:13:04 +01:00
Zoltan Papp
6f82e96d6a [client] Set info logs (#3504)
collect and log connection stats per peer every 10 minutes
2025-03-14 22:34:41 +01:00
Viktor Liu
a2faae5d62 [client] Fix anonymized addresses documentation (#3505) 2025-03-14 11:38:16 +01:00
Zoltan Papp
4a3cbcd38a Nil check on route manager (#3486) 2025-03-13 00:04:00 +01:00
Misha Bragin
c2980bc8cf Update link to kubernetes operator (#3489) 2025-03-12 21:18:19 +01:00
Pascal Fischer
67ae871ce4 [management] return empty array instead of null on networks endpoints (#3480) 2025-03-11 00:20:54 +01:00
Maycon Santos
39ff5e833a [misc] Update slack invite link (#3479) 2025-03-11 00:12:11 +01:00
Zoltan Papp
cd9eff5331 Increase the timeout to 50 sec (#3481) 2025-03-10 18:23:47 +01:00
Viktor Liu
80ceb80197 [client] Ignore candidates that are part of the the wireguard subnet (#3472) 2025-03-10 13:59:21 +01:00
Zoltan Papp
636a0e2475 [client] Fix engine restart (#3435)
- Refactor the network monitoring to handle one event and it after return
- In the engine restart cancel the upper layer context and the responsibility of the engine stop will be the upper layer
- Before triggering a restart, the engine checks whether the state is already down. This helps avoid unnecessary delayed network restart events.
2025-03-10 13:32:12 +01:00
Viktor Liu
e66e329bf6 [client] Add option to autostart netbird ui in the Windows installer (#3469) 2025-03-10 13:19:17 +01:00
Zoltan Papp
aaa23beeec [client] Prevent to block channel writing (#3474)
The "runningChan" provides feedback to the UI or any client about whether the service is up and running. If the client exits earlier than when the service successfully starts, then this channel causes a block.

- Added timeout for reading the channel to ensure we don't cause blocks for too long for the caller
- Modified channel writing operations to be non-blocking
2025-03-10 13:17:09 +01:00
Zoltan Papp
6bef474e9e [client] Prevent panic in case of double close call (#3475)
Prevent panic in case of double close call
2025-03-10 13:16:28 +01:00
Maycon Santos
81040ff80a [docs] Update typo (#3477) 2025-03-10 11:52:36 +01:00
Viktor Liu
c73481aee4 [client] Enable windows stderr logs by default (#3476) 2025-03-10 11:30:49 +01:00
Viktor Liu
fc1da94520 [client, management] Add port forwarding (#3275)
Add initial support to ingress ports on the client code.

- new types where added
- new protocol messages and controller
2025-03-09 16:06:43 +01:00
149 changed files with 5193 additions and 1584 deletions

View File

@@ -258,7 +258,7 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
@@ -325,8 +325,8 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
@@ -392,7 +392,7 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
@@ -461,7 +461,7 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04
steps:

View File

@@ -12,7 +12,7 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a>
<br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<br>
@@ -29,13 +29,13 @@
<br/>
See <a href="https://netbird.io/docs/">Documentation</a>
<br/>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a>
<br/>
</strong>
<br>
<a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github">
Webinar: Securely Access Kubernetes without Port Forwarding and Jump Hosts
<a href="https://github.com/netbirdio/kubernetes-operator">
New: NetBird Kubernetes Operator
</a>
</p>

View File

@@ -26,7 +26,7 @@ type Anonymizer struct {
}
func DefaultAddresses() (netip.Addr, netip.Addr) {
// 192.51.100.0, 100::
// 198.51.100.0, 100::
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
}

View File

@@ -0,0 +1,98 @@
package cmd
import (
"fmt"
"sort"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var forwardingRulesCmd = &cobra.Command{
Use: "forwarding",
Short: "List forwarding rules",
Long: `Commands to list forwarding rules.`,
}
var forwardingRulesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List forwarding rules",
Example: " netbird forwarding list",
Long: "Commands to list forwarding rules.",
RunE: listForwardingRules,
}
func listForwardingRules(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ForwardingRules(cmd.Context(), &proto.EmptyRequest{})
if err != nil {
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
}
if len(resp.GetRules()) == 0 {
cmd.Println("No forwarding rules available.")
return nil
}
printForwardingRules(cmd, resp.GetRules())
return nil
}
func printForwardingRules(cmd *cobra.Command, rules []*proto.ForwardingRule) {
cmd.Println("Available forwarding rules:")
// Sort rules by translated address
sort.Slice(rules, func(i, j int) bool {
if rules[i].GetTranslatedAddress() != rules[j].GetTranslatedAddress() {
return rules[i].GetTranslatedAddress() < rules[j].GetTranslatedAddress()
}
if rules[i].GetProtocol() != rules[j].GetProtocol() {
return rules[i].GetProtocol() < rules[j].GetProtocol()
}
return getFirstPort(rules[i].GetDestinationPort()) < getFirstPort(rules[j].GetDestinationPort())
})
var lastIP string
for _, rule := range rules {
dPort := portToString(rule.GetDestinationPort())
tPort := portToString(rule.GetTranslatedPort())
if lastIP != rule.GetTranslatedAddress() {
lastIP = rule.GetTranslatedAddress()
cmd.Printf("\nTranslated peer: %s\n", rule.GetTranslatedHostname())
}
cmd.Printf(" Local %s/%s to %s:%s\n", rule.GetProtocol(), dPort, rule.GetTranslatedAddress(), tPort)
}
}
func getFirstPort(portInfo *proto.PortInfo) int {
switch v := portInfo.PortSelection.(type) {
case *proto.PortInfo_Port:
return int(v.Port)
case *proto.PortInfo_Range_:
return int(v.Range.GetStart())
default:
return 0
}
}
func portToString(translatedPort *proto.PortInfo) string {
switch v := translatedPort.PortSelection.(type) {
case *proto.PortInfo_Port:
return fmt.Sprintf("%d", v.Port)
case *proto.PortInfo_Range_:
return fmt.Sprintf("%d-%d", v.Range.GetStart(), v.Range.GetEnd())
default:
return "No port specified"
}
}

View File

@@ -145,6 +145,7 @@ func init() {
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
@@ -153,6 +154,8 @@ func init() {
networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
forwardingRulesCmd.AddCommand(forwardingRulesListCmd)
debugCmd.AddCommand(debugBundleCmd)
debugCmd.AddCommand(logCmd)
logCmd.AddCommand(logLevelCmd)

View File

@@ -10,6 +10,7 @@ import (
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -89,7 +90,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock())
if err != nil {
t.Fatal(err)
}

View File

@@ -134,10 +134,11 @@ func (c *Client) Start(startCtx context.Context) error {
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
run := make(chan error, 1)
run := make(chan struct{}, 1)
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {
run <- err
clientErr <- err
}
}()
@@ -147,13 +148,9 @@ func (c *Client) Start(startCtx context.Context) error {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
}
return startCtx.Err()
case err := <-run:
if err != nil {
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after failed to startup. Stop error: %w. Start error: %w", stopErr, err)
}
return fmt.Errorf("startup: %w", err)
}
case err := <-clientErr:
return fmt.Errorf("startup: %w", err)
case <-run:
}
c.connect = client

View File

@@ -4,12 +4,13 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
Address() device.WGAddress
Address() wgaddr.Address
IsUserspaceBind() bool
SetFilter(device.PacketFilter) error
GetDevice() *device.FilteredDevice

View File

@@ -30,10 +30,8 @@ type entry struct {
}
type aclManager struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
routingFwChainName string
iptablesClient *iptables.IPTables
wgIface iFaceMapper
entries aclEntries
optionalEntries map[string][]entry
ipsetStore *ipsetStore
@@ -41,12 +39,10 @@ type aclManager struct {
stateManager *statemanager.Manager
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
m := &aclManager{
iptablesClient: iptablesClient,
wgIface: wgIface,
routingFwChainName: routingFwChainName,
iptablesClient: iptablesClient,
wgIface: wgIface,
entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(),
@@ -314,9 +310,12 @@ func (m *aclManager) seedInitialEntries() {
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
// Inbound is handled by our ACLs, the rest is dropped.
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
}
func (m *aclManager) seedInitialOptionalEntries() {

View File

@@ -13,7 +13,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -31,7 +31,7 @@ type Manager struct {
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
Address() wgaddr.Address
IsUserspaceBind() bool
}
@@ -52,7 +52,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
return nil, fmt.Errorf("create router: %w", err)
}
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
m.aclMgr, err = newAclManager(iptablesClient, wgIface)
if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err)
}
@@ -226,6 +226,22 @@ func (m *Manager) DisableRouting() error {
return nil
}
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteDNATRule(rule)
}
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -10,15 +10,15 @@ import (
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
@@ -31,7 +31,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
AddressFunc func() iface.WGAddress
AddressFunc func() wgaddr.Address
}
func (i *iFaceMock) Name() string {
@@ -41,7 +41,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set")
}
func (i *iFaceMock) Address() iface.WGAddress {
func (i *iFaceMock) Address() wgaddr.Address {
if i.AddressFunc != nil {
return i.AddressFunc()
}
@@ -117,8 +117,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
@@ -184,8 +184,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),

View File

@@ -16,6 +16,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
@@ -23,22 +24,36 @@ import (
// constants needed to manage and create iptable rules
const (
tableFilter = "filter"
tableNat = "nat"
tableMangle = "mangle"
tableFilter = "filter"
tableNat = "nat"
tableMangle = "mangle"
chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD"
chainRTFWDIN = "NETBIRD-RT-FWD-IN"
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
jumpPre = "jump-pre"
jumpNat = "jump-nat"
matchSet = "--match-set"
jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post"
matchSet = "--match-set"
dnatSuffix = "_dnat"
snatSuffix = "_snat"
fwdSuffix = "_fwd"
)
type ruleInfo struct {
chain string
table string
rule []string
}
type routeFilteringRuleParams struct {
Sources []netip.Prefix
Destination netip.Prefix
@@ -62,6 +77,7 @@ type router struct {
legacyManagement bool
stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState
}
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
@@ -69,6 +85,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router,
iptablesClient: iptablesClient,
rules: make(map[string][]string),
wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
}
r.ipsetCounter = refcounter.New(
@@ -139,9 +156,9 @@ func (r *router) AddRouteFiltering(
var err error
if action == firewall.ActionDrop {
// after the established rule
err = r.iptablesClient.Insert(tableFilter, chainRTFWD, 2, rule...)
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
} else {
err = r.iptablesClient.Append(tableFilter, chainRTFWD, rule...)
err = r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...)
}
if err != nil {
@@ -156,12 +173,12 @@ func (r *router) AddRouteFiltering(
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.GetRuleID()
ruleKey := rule.ID()
if rule, exists := r.rules[ruleKey]; exists {
setName := r.findSetNameInRule(rule)
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("delete route rule: %v", err)
}
delete(r.rules, ruleKey)
@@ -212,6 +229,10 @@ func (r *router) deleteIpSet(setName string) error {
// AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
@@ -238,6 +259,10 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
@@ -264,7 +289,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
}
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
if err := r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
@@ -277,7 +302,7 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
delete(r.rules, ruleKey)
@@ -305,7 +330,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
continue
}
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
@@ -343,9 +368,11 @@ func (r *router) cleanUpDefaultForwardRules() error {
chain string
table string
}{
{chainRTFWD, tableFilter},
{chainRTNAT, tableNat},
{chainRTFWDIN, tableFilter},
{chainRTFWDOUT, tableFilter},
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil {
@@ -365,16 +392,22 @@ func (r *router) createContainers() error {
chain string
table string
}{
{chainRTFWD, tableFilter},
{chainRTFWDIN, tableFilter},
{chainRTFWDOUT, tableFilter},
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
} {
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
}
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
if err := r.insertEstablishedRule(chainRTFWDIN); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.insertEstablishedRule(chainRTFWDOUT); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
@@ -415,27 +448,6 @@ func (r *router) addPostroutingRules() error {
return nil
}
func (r *router) createAndSetupChain(chain string) error {
table := r.getTableForChain(chain)
if err := r.iptablesClient.NewChain(table, chain); err != nil {
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
}
return nil
}
func (r *router) getTableForChain(chain string) string {
switch chain {
case chainRTNAT:
return tableNat
case chainRTPRE:
return tableMangle
default:
return tableFilter
}
}
func (r *router) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished()
@@ -454,28 +466,43 @@ func (r *router) addJumpRules() error {
// Jump to NAT chain
natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat jump rule: %v", err)
return fmt.Errorf("add nat postrouting jump rule: %v", err)
}
r.rules[jumpNat] = natRule
r.rules[jumpNatPost] = natRule
// Jump to prerouting chain
// Jump to mangle prerouting chain
preRule := []string{"-j", chainRTPRE}
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
return fmt.Errorf("add prerouting jump rule: %v", err)
return fmt.Errorf("add mangle prerouting jump rule: %v", err)
}
r.rules[jumpPre] = preRule
r.rules[jumpManglePre] = preRule
// Jump to nat prerouting chain
rdrRule := []string{"-j", chainRTRDR}
if err := r.iptablesClient.Insert(tableNat, chainPREROUTING, 1, rdrRule...); err != nil {
return fmt.Errorf("add nat prerouting jump rule: %v", err)
}
r.rules[jumpNatPre] = rdrRule
return nil
}
func (r *router) cleanJumpRules() error {
for _, ruleKey := range []string{jumpNat, jumpPre} {
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} {
if rule, exists := r.rules[ruleKey]; exists {
table := tableNat
chain := chainPOSTROUTING
if ruleKey == jumpPre {
var table, chain string
switch ruleKey {
case jumpNatPost:
table = tableNat
chain = chainPOSTROUTING
case jumpManglePre:
table = tableMangle
chain = chainPREROUTING
case jumpNatPre:
table = tableNat
chain = chainPREROUTING
default:
return fmt.Errorf("unknown jump rule: %s", ruleKey)
}
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
@@ -520,6 +547,8 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
}
r.rules[ruleKey] = rule
r.updateState()
return nil
}
@@ -535,6 +564,7 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
log.Debugf("marking rule %s not found", ruleKey)
}
r.updateState()
return nil
}
@@ -564,6 +594,137 @@ func (r *router) updateState() {
}
}
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleKey := rule.ID()
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
return rule, nil
}
toDestination := rule.TranslatedAddress.String()
switch {
case len(rule.TranslatedPort.Values) == 0:
// no translated port, use original port
case len(rule.TranslatedPort.Values) == 1:
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
// need the "/originalport" suffix to avoid dnat port randomization
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
default:
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
proto := strings.ToLower(string(rule.Protocol))
rules := make(map[string]ruleInfo, 3)
// DNAT rule
dnatRule := []string{
"!", "-i", r.wgIface.Name(),
"-p", proto,
"-j", "DNAT",
"--to-destination", toDestination,
}
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
rules[ruleKey+dnatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTRDR,
rule: dnatRule,
}
// SNAT rule
snatRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "MASQUERADE",
}
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleKey+snatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTNAT,
rule: snatRule,
}
// Forward filtering rule, if fwd policy is DROP
forwardRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "ACCEPT",
}
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleKey+fwdSuffix] = ruleInfo{
table: tableFilter,
chain: chainRTFWDOUT,
rule: forwardRule,
}
for key, ruleInfo := range rules {
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
log.Errorf("rollback failed: %v", rollbackErr)
}
return nil, fmt.Errorf("add rule %s: %w", key, err)
}
r.rules[key] = ruleInfo.rule
}
r.updateState()
return rule, nil
}
func (r *router) rollbackRules(rules map[string]ruleInfo) error {
var merr *multierror.Error
for key, ruleInfo := range rules {
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
// On rollback error, add to rules map for next cleanup
r.rules[key] = ruleInfo.rule
}
}
if merr != nil {
r.updateState()
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleKey := rule.ID()
var merr *multierror.Error
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
}
delete(r.rules, ruleKey+dnatSuffix)
}
if snatRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
}
delete(r.rules, ruleKey+snatSuffix)
}
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
}
delete(r.rules, ruleKey+fwdSuffix)
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string

View File

@@ -39,12 +39,14 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
}()
// Now 5 rules:
// 1. established rule in forward chain
// 2. jump rule to NAT chain
// 3. jump rule to PRE chain
// 4. static outbound masquerade rule
// 5. static return masquerade rule
require.Len(t, manager.rules, 5, "should have created rules map")
// 1. established rule forward in
// 2. estbalished rule forward out
// 3. jump rule to POST nat chain
// 4. jump rule to PRE mangle chain
// 5. jump rule to PRE nat chain
// 6. static outbound masquerade rule
// 7. static return masquerade rule
require.Len(t, manager.rules, 7, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
@@ -332,14 +334,14 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()]
rule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "Rule not found in internal map")
// Log the internal rule
t.Logf("Internal rule: %v", rule)
// Check if the rule exists in iptables
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables")

View File

@@ -12,6 +12,6 @@ type Rule struct {
}
// GetRuleID returns the rule id
func (r *Rule) GetRuleID() string {
func (r *Rule) ID() string {
return r.ruleID
}

View File

@@ -4,21 +4,20 @@ import (
"fmt"
"sync"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() device.WGAddress {
func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress
}

View File

@@ -26,8 +26,8 @@ const (
// Each firewall type for different OS can use different type
// of the properties to hold data of the created rule
type Rule interface {
// GetRuleID returns the rule id
GetRuleID() string
// ID returns the rule id
ID() string
}
// RuleDirection is the traffic direction which a rule is applied
@@ -105,6 +105,12 @@ type Manager interface {
EnableRouting() error
DisableRouting() error
// AddDNATRule adds a DNAT rule
AddDNATRule(ForwardRule) (Rule, error)
// DeleteDNATRule deletes a DNAT rule
DeleteDNATRule(Rule) error
}
func GenKey(format string, pair RouterPair) string {

View File

@@ -0,0 +1,27 @@
package manager
import (
"fmt"
"net/netip"
)
// ForwardRule todo figure out better place to this to avoid circular imports
type ForwardRule struct {
Protocol Protocol
DestinationPort Port
TranslatedAddress netip.Addr
TranslatedPort Port
}
func (r ForwardRule) ID() string {
id := fmt.Sprintf("%s;%s;%s;%s",
r.Protocol,
r.DestinationPort.String(),
r.TranslatedAddress.String(),
r.TranslatedPort.String())
return id
}
func (r ForwardRule) String() string {
return fmt.Sprintf("protocol: %s, destinationPort: %s, translatedAddress: %s, translatedPort: %s", r.Protocol, r.DestinationPort.String(), r.TranslatedAddress.String(), r.TranslatedPort.String())
}

View File

@@ -1,30 +1,12 @@
package manager
import (
"fmt"
"strconv"
)
// Protocol is the protocol of the port
type Protocol string
const (
// ProtocolTCP is the TCP protocol
ProtocolTCP Protocol = "tcp"
// ProtocolUDP is the UDP protocol
ProtocolUDP Protocol = "udp"
// ProtocolICMP is the ICMP protocol
ProtocolICMP Protocol = "icmp"
// ProtocolALL cover all supported protocols
ProtocolALL Protocol = "all"
// ProtocolUnknown unknown protocol
ProtocolUnknown Protocol = "unknown"
)
// Port of the address for firewall rule
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
type Port struct {
// IsRange is true Values contains two values, the first is the start port, the second is the end port
IsRange bool
@@ -33,6 +15,25 @@ type Port struct {
Values []uint16
}
func NewPort(ports ...int) (*Port, error) {
if len(ports) == 0 {
return nil, fmt.Errorf("no port provided")
}
ports16 := make([]uint16, len(ports))
for i, port := range ports {
if port < 1 || port > 65535 {
return nil, fmt.Errorf("invalid port number: %d (must be between 1-65535)", port)
}
ports16[i] = uint16(port)
}
return &Port{
IsRange: len(ports) > 1,
Values: ports16,
}, nil
}
// String interface implementation
func (p *Port) String() string {
var ports string

View File

@@ -0,0 +1,19 @@
package manager
// Protocol is the protocol of the port
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
type Protocol string
const (
// ProtocolTCP is the TCP protocol
ProtocolTCP Protocol = "tcp"
// ProtocolUDP is the UDP protocol
ProtocolUDP Protocol = "udp"
// ProtocolICMP is the ICMP protocol
ProtocolICMP Protocol = "icmp"
// ProtocolALL cover all supported protocols
ProtocolALL Protocol = "all"
)

View File

@@ -127,7 +127,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.GetRuleID())
delete(m.rules, r.ID())
return m.rConn.Flush()
}
@@ -141,7 +141,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.GetRuleID())
delete(m.rules, r.ID())
return m.rConn.Flush()
}
@@ -176,7 +176,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
return err
}
delete(m.rules, r.GetRuleID())
delete(m.rules, r.ID())
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {

View File

@@ -14,7 +14,7 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -29,7 +29,7 @@ const (
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
Address() wgaddr.Address
IsUserspaceBind() bool
}
@@ -342,6 +342,22 @@ func (m *Manager) Flush() error {
return m.aclManager.Flush()
}
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteDNATRule(rule)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {

View File

@@ -16,15 +16,15 @@ import (
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
@@ -37,7 +37,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
AddressFunc func() iface.WGAddress
AddressFunc func() wgaddr.Address
}
func (i *iFaceMock) Name() string {
@@ -47,7 +47,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set")
}
func (i *iFaceMock) Address() iface.WGAddress {
func (i *iFaceMock) Address() wgaddr.Address {
if i.AddressFunc != nil {
return i.AddressFunc()
}
@@ -171,8 +171,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),

View File

@@ -14,23 +14,31 @@ import (
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/google/nftables/xt"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameForward = "FORWARD"
tableNat = "nat"
chainNameNatPrerouting = "PREROUTING"
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameForward = "FORWARD"
userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif"
dnatSuffix = "_dnat"
snatSuffix = "_snat"
)
const refreshRulesMapError = "refresh rules map: %w"
@@ -49,16 +57,18 @@ type router struct {
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool
}
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
r := &router{
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
wgIface: wgIface,
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
}
r.ipsetCounter = refcounter.New(
@@ -98,7 +108,52 @@ func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear()
return r.removeAcceptForwardRules()
var merr *multierror.Error
if err := r.removeAcceptForwardRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err))
}
if err := r.removeNatPreroutingRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeNatPreroutingRules() error {
table := &nftables.Table{
Name: tableNat,
Family: nftables.TableFamilyIPv4,
}
chain := &nftables.Chain{
Name: chainNameNatPrerouting,
Table: table,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
}
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from nat table: %w", err)
}
var merr *multierror.Error
// Delete rules that have our UserData suffix
for _, rule := range rules {
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), dnatSuffix) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
@@ -133,14 +188,22 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT,
})
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingRdr,
Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
// Chain is created by acl manager
// TODO: move creation to a common place
r.chains[chainNamePrerouting] = &nftables.Chain{
Name: chainNamePrerouting,
Table: r.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
}
// Add the single NAT rule that matches on mark
@@ -281,7 +344,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleKey := rule.GetRuleID()
ruleKey := rule.ID()
nftRule, exists := r.rules[ruleKey]
if !exists {
log.Debugf("route rule %s not found", ruleKey)
@@ -410,6 +473,10 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
// AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
@@ -836,6 +903,10 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
// RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
@@ -896,6 +967,269 @@ func (r *router) refreshRulesMap() error {
return nil
}
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleKey := rule.ID()
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
return rule, nil
}
protoNum, err := protoToInt(rule.Protocol)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
if err := r.addDnatRedirect(rule, protoNum, ruleKey); err != nil {
return nil, err
}
r.addDnatMasq(rule, protoNum, ruleKey)
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
// TODO: find chains with drop policies and add rules there
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush rules: %w", err)
}
return &rule, nil
}
func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleKey string) error {
dnatExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
}
dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...)
// shifted translated port is not supported in nftables, so we hand this over to xtables
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
return r.addXTablesRedirect(dnatExprs, ruleKey, rule)
}
}
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
if err != nil {
return err
}
dnatExprs = append(dnatExprs, additionalExprs...)
dnatExprs = append(dnatExprs,
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: regProtoMin,
RegProtoMax: regProtoMax,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: dnatExprs,
UserData: []byte(ruleKey + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleKey+dnatSuffix] = dnatRule
return nil
}
func (r *router) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
switch {
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
return r.handlePortRange(rule)
case len(rule.TranslatedPort.Values) == 0:
return r.handleAddressOnly(rule)
case len(rule.TranslatedPort.Values) == 1:
return r.handleSinglePort(rule)
default:
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
}
func (r *router) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
&expr.Immediate{
Register: 3,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
},
}
return exprs, 2, 3, nil
}
func (r *router) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
return exprs, 0, 0, nil
}
func (r *router) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
}
return exprs, 2, 0, nil
}
func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule firewall.ForwardRule) error {
dnatExprs = append(dnatExprs,
&expr.Counter{},
&expr.Target{
Name: "DNAT",
Rev: 2,
Info: &xt.NatRange2{
NatRange: xt.NatRange{
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
MinIP: rule.TranslatedAddress.AsSlice(),
MaxIP: rule.TranslatedAddress.AsSlice(),
MinPort: rule.TranslatedPort.Values[0],
MaxPort: rule.TranslatedPort.Values[1],
},
BasePort: rule.DestinationPort.Values[0],
},
},
)
dnatRule := &nftables.Rule{
Table: &nftables.Table{
Name: tableNat,
Family: nftables.TableFamilyIPv4,
},
Chain: &nftables.Chain{
Name: chainNameNatPrerouting,
Table: r.filterTable,
Type: nftables.ChainTypeNAT,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
},
Exprs: dnatExprs,
UserData: []byte(ruleKey + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleKey+dnatSuffix] = dnatRule
return nil
}
func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey string) {
masqExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
masqExprs = append(masqExprs, applyPort(&rule.TranslatedPort, false)...)
masqExprs = append(masqExprs, &expr.Masq{})
masqRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: masqExprs,
UserData: []byte(ruleKey + snatSuffix),
}
r.conn.AddRule(masqRule)
r.rules[ruleKey+snatSuffix] = masqRule
}
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleKey := rule.ID()
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.conn.DelRule(dnatRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
}
}
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if err := r.conn.DelRule(masqRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
if merr == nil {
delete(r.rules, ruleKey+dnatSuffix)
delete(r.rules, ruleKey+snatSuffix)
}
return nberrors.FormatErrorOrNil(merr)
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
var offset uint32
@@ -959,15 +1293,11 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
if port.IsRange && len(port.Values) == 2 {
// Handle port range
exprs = append(exprs,
&expr.Cmp{
Op: expr.CmpOpGte,
&expr.Range{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(port.Values[0]),
},
&expr.Cmp{
Op: expr.CmpOpLte,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(port.Values[1]),
FromData: binaryutil.BigEndian.PutUint16(port.Values[0]),
ToData: binaryutil.BigEndian.PutUint16(port.Values[1]),
},
)
} else {

View File

@@ -319,7 +319,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
})
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()]
rule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "Rule not found in internal map")
t.Log("Internal rule expressions:")
@@ -336,7 +336,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
var nftRule *nftables.Rule
for _, rule := range rules {
if string(rule.UserData) == ruleKey.GetRuleID() {
if string(rule.UserData) == ruleKey.ID() {
nftRule = rule
break
}
@@ -595,16 +595,20 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
payloadFound = true
}
case *expr.Cmp:
if port.IsRange {
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
case *expr.Range:
if port.IsRange && len(port.Values) == 2 {
fromPort := binary.BigEndian.Uint16(ex.FromData)
toPort := binary.BigEndian.Uint16(ex.ToData)
if fromPort == port.Values[0] && toPort == port.Values[1] {
portMatchFound = true
}
} else {
}
case *expr.Cmp:
if !port.IsRange {
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
portValue := binary.BigEndian.Uint16(ex.Data)
for _, p := range port.Values {
if uint16(p) == portValue {
if p == portValue {
portMatchFound = true
break
}

View File

@@ -16,6 +16,6 @@ type Rule struct {
}
// GetRuleID returns the rule id
func (r *Rule) GetRuleID() string {
func (r *Rule) ID() string {
return r.ruleID
}

View File

@@ -3,21 +3,20 @@ package nftables
import (
"fmt"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() device.WGAddress {
func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress
}

View File

@@ -3,14 +3,14 @@ package common
import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(device.PacketFilter) error
Address() iface.WGAddress
Address() wgaddr.Address
GetWGDevice() *wgdevice.Device
GetDevice() *device.FilteredDevice
}

View File

@@ -1,6 +1,7 @@
package conntrack
import (
"context"
"net"
"sync"
"time"
@@ -39,8 +40,8 @@ type ICMPTracker struct {
connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs
}
@@ -50,16 +51,18 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
timeout = DefaultICMPTimeout
}
ctx, cancel := context.WithCancel(context.Background())
tracker := &ICMPTracker{
logger: logger,
connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
done: make(chan struct{}),
tickerCancel: cancel,
ipPool: NewPreallocatedIPs(),
}
go tracker.cleanupRoutine()
go tracker.cleanupRoutine(ctx)
return tracker
}
@@ -119,12 +122,14 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
conn.Sequence == seq
}
func (t *ICMPTracker) cleanupRoutine() {
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
defer t.tickerCancel()
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
case <-ctx.Done():
return
}
}
@@ -146,8 +151,7 @@ func (t *ICMPTracker) cleanup() {
// Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
t.tickerCancel()
t.mutex.Lock()
for _, conn := range t.connections {

View File

@@ -3,6 +3,7 @@ package conntrack
// TODO: Send RST packets for invalid/timed-out connections
import (
"context"
"net"
"sync"
"sync/atomic"
@@ -85,23 +86,26 @@ type TCPTracker struct {
connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
done chan struct{}
tickerCancel context.CancelFunc
timeout time.Duration
ipPool *PreallocatedIPs
}
// NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
ctx, cancel := context.WithCancel(context.Background())
tracker := &TCPTracker{
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
done: make(chan struct{}),
tickerCancel: cancel,
timeout: timeout,
ipPool: NewPreallocatedIPs(),
}
go tracker.cleanupRoutine()
go tracker.cleanupRoutine(ctx)
return tracker
}
@@ -315,12 +319,14 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
return false
}
func (t *TCPTracker) cleanupRoutine() {
func (t *TCPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
case <-ctx.Done():
return
}
}
@@ -355,8 +361,7 @@ func (t *TCPTracker) cleanup() {
// Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
t.tickerCancel()
// Clean up all remaining IPs
t.mutex.Lock()

View File

@@ -1,6 +1,7 @@
package conntrack
import (
"context"
"net"
"sync"
"time"
@@ -26,8 +27,8 @@ type UDPTracker struct {
connections map[ConnKey]*UDPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs
}
@@ -37,16 +38,18 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
timeout = DefaultUDPTimeout
}
ctx, cancel := context.WithCancel(context.Background())
tracker := &UDPTracker{
logger: logger,
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval),
done: make(chan struct{}),
tickerCancel: cancel,
ipPool: NewPreallocatedIPs(),
}
go tracker.cleanupRoutine()
go tracker.cleanupRoutine(ctx)
return tracker
}
@@ -103,12 +106,14 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
}
// cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine() {
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
case <-ctx.Done():
return
}
}
@@ -131,8 +136,7 @@ func (t *UDPTracker) cleanup() {
// Close stops the cleanup routine and releases resources
func (t *UDPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
t.tickerCancel()
t.mutex.Lock()
for _, conn := range t.connections {

View File

@@ -1,6 +1,7 @@
package conntrack
import (
"context"
"net"
"testing"
"time"
@@ -34,7 +35,7 @@ func TestNewUDPTracker(t *testing.T) {
assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections)
assert.NotNil(t, tracker.cleanupTicker)
assert.NotNil(t, tracker.done)
assert.NotNil(t, tracker.tickerCancel)
})
}
}
@@ -154,18 +155,21 @@ func TestUDPTracker_Cleanup(t *testing.T) {
timeout := 50 * time.Millisecond
cleanupInterval := 25 * time.Millisecond
ctx, tickerCancel := context.WithCancel(context.Background())
defer tickerCancel()
// Create tracker with custom cleanup interval
tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(cleanupInterval),
done: make(chan struct{}),
tickerCancel: tickerCancel,
ipPool: NewPreallocatedIPs(),
logger: logger,
}
// Start cleanup routine
go tracker.cleanupRoutine()
go tracker.cleanupRoutine(ctx)
// Add some connections
connections := []struct {

View File

@@ -6,19 +6,19 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func TestLocalIPManager(t *testing.T) {
tests := []struct {
name string
setupAddr iface.WGAddress
setupAddr wgaddr.Address
testIP net.IP
expected bool
}{
{
name: "Localhost range",
setupAddr: iface.WGAddress{
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
@@ -30,7 +30,7 @@ func TestLocalIPManager(t *testing.T) {
},
{
name: "Localhost standard address",
setupAddr: iface.WGAddress{
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
@@ -42,7 +42,7 @@ func TestLocalIPManager(t *testing.T) {
},
{
name: "Localhost range edge",
setupAddr: iface.WGAddress{
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
@@ -54,7 +54,7 @@ func TestLocalIPManager(t *testing.T) {
},
{
name: "Local IP matches",
setupAddr: iface.WGAddress{
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
@@ -66,7 +66,7 @@ func TestLocalIPManager(t *testing.T) {
},
{
name: "Local IP doesn't match",
setupAddr: iface.WGAddress{
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
@@ -78,7 +78,7 @@ func TestLocalIPManager(t *testing.T) {
},
{
name: "IPv6 address",
setupAddr: iface.WGAddress{
setupAddr: wgaddr.Address{
IP: net.ParseIP("fe80::1"),
Network: &net.IPNet{
IP: net.ParseIP("fe80::"),
@@ -95,7 +95,7 @@ func TestLocalIPManager(t *testing.T) {
manager := newLocalIPManager()
mock := &IFaceMock{
AddressFunc: func() iface.WGAddress {
AddressFunc: func() wgaddr.Address {
return tt.setupAddr
},
}

View File

@@ -24,8 +24,8 @@ type PeerRule struct {
udpHook func([]byte) bool
}
// GetRuleID returns the rule id
func (r *PeerRule) GetRuleID() string {
// ID returns the rule id
func (r *PeerRule) ID() string {
return r.id
}
@@ -39,7 +39,7 @@ type RouteRule struct {
action firewall.Action
}
// GetRuleID returns the rule id
func (r *RouteRule) GetRuleID() string {
// ID returns the rule id
func (r *RouteRule) ID() string {
return r.id
}

View File

@@ -42,6 +42,8 @@ const (
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
)
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
// RuleSet is a set of rules grouped by a string key
type RuleSet map[string]PeerRule
@@ -437,7 +439,7 @@ func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
ruleID := rule.GetRuleID()
ruleID := rule.ID()
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
return r.id == ruleID
})
@@ -478,6 +480,22 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
}
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}
// DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte) bool {
return m.processOutgoingHooks(packetData)

View File

@@ -12,9 +12,9 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func TestPeerACLFiltering(t *testing.T) {
@@ -26,8 +26,8 @@ func TestPeerACLFiltering(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: localIP,
Network: wgNet,
}
@@ -288,8 +288,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: localIP,
Network: wgNet,
}

View File

@@ -16,15 +16,15 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
var logger = log.NewFromLogrus(logrus.StandardLogger())
type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error
AddressFunc func() iface.WGAddress
AddressFunc func() wgaddr.Address
GetWGDeviceFunc func() *wgdevice.Device
GetDeviceFunc func() *device.FilteredDevice
}
@@ -50,9 +50,9 @@ func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
return i.SetFilterFunc(iface)
}
func (i *IFaceMock) Address() iface.WGAddress {
func (i *IFaceMock) Address() wgaddr.Address {
if i.AddressFunc == nil {
return iface.WGAddress{}
return wgaddr.Address{}
}
return i.AddressFunc()
}
@@ -135,7 +135,7 @@ func TestManagerDeleteRule(t *testing.T) {
}
for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok {
t.Errorf("rule2 is not in the incomingRules")
}
}
@@ -149,7 +149,7 @@ func TestManagerDeleteRule(t *testing.T) {
}
for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok {
if _, ok := m.incomingRules[ip.String()][r.ID()]; ok {
t.Errorf("rule2 is not in the incomingRules")
}
}
@@ -268,8 +268,8 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),

View File

@@ -13,6 +13,8 @@ import (
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type RecvMessage struct {
@@ -51,9 +53,10 @@ type ICEBind struct {
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
address wgaddr.Address
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
@@ -63,6 +66,7 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
address: address,
}
rc := receiverCreator{
@@ -142,9 +146,10 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {

View File

@@ -17,6 +17,8 @@ import (
"github.com/pion/logging"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// FilterFn is a function that filters out candidates based on the address.
@@ -41,6 +43,7 @@ type UniversalUDPMuxParams struct {
XORMappedAddrCacheTTL time.Duration
Net transport.Net
FilterFn FilterFn
WGAddress wgaddr.Address
}
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
@@ -64,6 +67,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
mux: m,
logger: params.Logger,
filterFn: params.FilterFn,
address: params.WGAddress,
}
// embed UDPMux
@@ -118,6 +122,7 @@ type udpConn struct {
filterFn FilterFn
// TODO: reset cache on route changes
addrCache sync.Map
address wgaddr.Address
}
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
@@ -159,6 +164,11 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil
}
if u.address.Network.Contains(a.AsSlice()) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
if isRouted, prefix, err := u.filterFn(a); err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else {

View File

@@ -9,13 +9,14 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice

View File

@@ -13,11 +13,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
type WGTunDevice struct {
address WGAddress
address wgaddr.Address
port int
key string
mtu int
@@ -31,7 +32,7 @@ type WGTunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
return &WGTunDevice{
address: address,
port: port,
@@ -93,7 +94,7 @@ func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *WGTunDevice) UpdateAddr(addr WGAddress) error {
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
// todo implement
return nil
}
@@ -123,7 +124,7 @@ func (t *WGTunDevice) DeviceName() string {
return t.name
}
func (t *WGTunDevice) WgAddress() WGAddress {
func (t *WGTunDevice) WgAddress() wgaddr.Address {
return t.address
}

View File

@@ -13,11 +13,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type TunDevice struct {
name string
address WGAddress
address wgaddr.Address
port int
key string
mtu int
@@ -29,7 +30,7 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -85,7 +86,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *TunDevice) UpdateAddr(address WGAddress) error {
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address
return t.assignAddr()
}
@@ -106,7 +107,7 @@ func (t *TunDevice) Close() error {
return nil
}
func (t *TunDevice) WgAddress() WGAddress {
func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address
}

View File

@@ -14,11 +14,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type TunDevice struct {
name string
address WGAddress
address wgaddr.Address
port int
key string
iceBind *bind.ICEBind
@@ -30,7 +31,7 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -120,11 +121,11 @@ func (t *TunDevice) Close() error {
return nil
}
func (t *TunDevice) WgAddress() WGAddress {
func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address
}
func (t *TunDevice) UpdateAddr(addr WGAddress) error {
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
// todo implement
return nil
}

View File

@@ -14,12 +14,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/sharedsock"
)
type TunKernelDevice struct {
name string
address WGAddress
address wgaddr.Address
wgPort int
key string
mtu int
@@ -34,7 +35,7 @@ type TunKernelDevice struct {
filterFn bind.FilterFn
}
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
ctx, cancel := context.WithCancel(context.Background())
return &TunKernelDevice{
ctx: ctx,
@@ -99,9 +100,10 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err
}
bindParams := bind.UniversalUDPMuxParams{
UDPConn: rawSock,
Net: t.transportNet,
FilterFn: t.filterFn,
UDPConn: rawSock,
Net: t.transportNet,
FilterFn: t.filterFn,
WGAddress: t.address,
}
mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx)
@@ -112,7 +114,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return t.udpMux, nil
}
func (t *TunKernelDevice) UpdateAddr(address WGAddress) error {
func (t *TunKernelDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address
return t.assignAddr()
}
@@ -145,7 +147,7 @@ func (t *TunKernelDevice) Close() error {
return closErr
}
func (t *TunKernelDevice) WgAddress() WGAddress {
func (t *TunKernelDevice) WgAddress() wgaddr.Address {
return t.address
}

View File

@@ -13,12 +13,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
)
type TunNetstackDevice struct {
name string
address WGAddress
address wgaddr.Address
port int
key string
mtu int
@@ -34,7 +35,7 @@ type TunNetstackDevice struct {
net *netstack.Net
}
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{
name: name,
address: address,
@@ -97,7 +98,7 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *TunNetstackDevice) UpdateAddr(WGAddress) error {
func (t *TunNetstackDevice) UpdateAddr(wgaddr.Address) error {
return nil
}
@@ -116,7 +117,7 @@ func (t *TunNetstackDevice) Close() error {
return nil
}
func (t *TunNetstackDevice) WgAddress() WGAddress {
func (t *TunNetstackDevice) WgAddress() wgaddr.Address {
return t.address
}

View File

@@ -12,11 +12,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type USPDevice struct {
name string
address WGAddress
address wgaddr.Address
port int
key string
mtu int
@@ -28,7 +29,7 @@ type USPDevice struct {
configurer WGConfigurer
}
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode")
return &USPDevice{
@@ -93,7 +94,7 @@ func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *USPDevice) UpdateAddr(address WGAddress) error {
func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address
return t.assignAddr()
}
@@ -113,7 +114,7 @@ func (t *USPDevice) Close() error {
return nil
}
func (t *USPDevice) WgAddress() WGAddress {
func (t *USPDevice) WgAddress() wgaddr.Address {
return t.address
}

View File

@@ -13,13 +13,14 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
type TunDevice struct {
name string
address WGAddress
address wgaddr.Address
port int
key string
mtu int
@@ -32,7 +33,7 @@ type TunDevice struct {
configurer WGConfigurer
}
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{
name: name,
address: address,
@@ -118,7 +119,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
func (t *TunDevice) UpdateAddr(address WGAddress) error {
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address
return t.assignAddr()
}
@@ -139,7 +140,7 @@ func (t *TunDevice) Close() error {
}
return nil
}
func (t *TunDevice) WgAddress() WGAddress {
func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address
}

View File

@@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/freebsd"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type wgLink struct {
@@ -56,7 +57,7 @@ func (l *wgLink) up() error {
return nil
}
func (l *wgLink) assignAddr(address WGAddress) error {
func (l *wgLink) assignAddr(address wgaddr.Address) error {
link, err := freebsd.LinkByName(l.name)
if err != nil {
return fmt.Errorf("link by name: %w", err)

View File

@@ -8,6 +8,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type wgLink struct {
@@ -90,7 +92,7 @@ func (l *wgLink) up() error {
return nil
}
func (l *wgLink) assignAddr(address WGAddress) error {
func (l *wgLink) assignAddr(address wgaddr.Address) error {
//delete existing addresses
list, err := netlink.AddrList(l, 0)
if err != nil {

View File

@@ -7,13 +7,14 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice

View File

@@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
@@ -28,8 +29,6 @@ const (
WgInterfaceDefault = configurer.WgInterfaceDefault
)
type WGAddress = device.WGAddress
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
Free() error
@@ -72,7 +71,7 @@ func (w *WGIface) Name() string {
}
// Address returns the interface address
func (w *WGIface) Address() device.WGAddress {
func (w *WGIface) Address() wgaddr.Address {
return w.tun.WgAddress()
}
@@ -103,7 +102,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock()
defer w.mu.Unlock()
addr, err := device.ParseWGAddress(newAddr)
addr, err := wgaddr.ParseWGAddress(newAddr)
if err != nil {
return err
}

View File

@@ -3,17 +3,18 @@ package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace := &WGIface{
userspaceBind: true,

View File

@@ -6,17 +6,18 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
var tun WGTunDevice
if netstack.IsEnabled() {

View File

@@ -5,17 +5,18 @@ package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),

View File

@@ -8,12 +8,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
@@ -21,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
@@ -34,7 +35,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)

View File

@@ -4,16 +4,17 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
var tun WGTunDevice
if netstack.IsEnabled() {

View File

@@ -1,29 +1,29 @@
package device
package wgaddr
import (
"fmt"
"net"
)
// WGAddress WireGuard parsed address
type WGAddress struct {
// Address WireGuard parsed address
type Address struct {
IP net.IP
Network *net.IPNet
}
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (WGAddress, error) {
func ParseWGAddress(address string) (Address, error) {
ip, network, err := net.ParseCIDR(address)
if err != nil {
return WGAddress{}, err
return Address{}, err
}
return WGAddress{
return Address{
IP: ip,
Network: network,
}, nil
}
func (addr WGAddress) String() string {
func (addr Address) String() string {
maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
}

View File

@@ -22,6 +22,8 @@
!define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}"
!define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}"
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
Unicode True
######################################################################
@@ -68,6 +70,9 @@ ShowInstDetails Show
!insertmacro MUI_PAGE_DIRECTORY
; Custom page for autostart checkbox
Page custom AutostartPage AutostartPageLeave
!insertmacro MUI_PAGE_INSTFILES
!insertmacro MUI_PAGE_FINISH
@@ -80,8 +85,36 @@ ShowInstDetails Show
!insertmacro MUI_LANGUAGE "English"
; Variables for autostart option
Var AutostartCheckbox
Var AutostartEnabled
######################################################################
; Function to create the autostart options page
Function AutostartPage
!insertmacro MUI_HEADER_TEXT "Startup Options" "Configure how ${APP_NAME} launches with Windows."
nsDialogs::Create 1018
Pop $0
${If} $0 == error
Abort
${EndIf}
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
Pop $AutostartCheckbox
${NSD_Check} $AutostartCheckbox ; Default to checked
StrCpy $AutostartEnabled "1" ; Default to enabled
nsDialogs::Show
FunctionEnd
; Function to handle leaving the autostart page
Function AutostartPageLeave
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
FunctionEnd
Function GetAppFromCommand
Exch $1
Push $2
@@ -163,6 +196,16 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Create autostart registry entry based on checkbox
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
EnVar::SetHKLM
EnVar::AddValueEx "path" "$INSTDIR"
@@ -186,7 +229,10 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
# kill ui client
ExecWait `taskkill /im ${UI_APP_EXE}.exe`
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
# wait the service uninstall take unblock the executable
Sleep 3000

View File

@@ -12,7 +12,7 @@ import (
type RuleID string
func (r RuleID) GetRuleID() string {
func (r RuleID) ID() string {
return string(r)
}

View File

@@ -245,7 +245,7 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
return "", fmt.Errorf("add route rule: %w", err)
}
return id.RuleID(addedRule.GetRuleID()), nil
return id.RuleID(addedRule.ID()), nil
}
func (d *DefaultManager) protoRuleToFirewallRule(
@@ -515,7 +515,7 @@ func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
for _, rules := range newRulePairs {
for _, rule := range rules {
if err := d.firewall.DeletePeerRule(rule); err != nil {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.ID(), err)
}
}
}

View File

@@ -8,7 +8,7 @@ import (
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
@@ -45,7 +45,7 @@ func TestDefaultManager(t *testing.T) {
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip,
Network: network,
}).AnyTimes()
@@ -74,7 +74,7 @@ func TestDefaultManager(t *testing.T) {
t.Run("add extra rules", func(t *testing.T) {
existedPairs := map[string]struct{}{}
for id := range acl.peerRulesPairs {
existedPairs[id.GetRuleID()] = struct{}{}
existedPairs[id.ID()] = struct{}{}
}
// remove first rule
@@ -100,7 +100,7 @@ func TestDefaultManager(t *testing.T) {
// check that old rule was removed
previousCount := 0
for id := range acl.peerRulesPairs {
if _, ok := existedPairs[id.GetRuleID()]; ok {
if _, ok := existedPairs[id.ID()]; ok {
previousCount++
}
}
@@ -339,7 +339,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip,
Network: network,
}).AnyTimes()

View File

@@ -10,8 +10,8 @@ import (
gomock "github.com/golang/mock/gomock"
wgdevice "golang.zx2c4.com/wireguard/device"
iface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// MockIFaceMapper is a mock of IFaceMapper interface.
@@ -38,10 +38,10 @@ func (m *MockIFaceMapper) EXPECT() *MockIFaceMapperMockRecorder {
}
// Address mocks base method.
func (m *MockIFaceMapper) Address() iface.WGAddress {
func (m *MockIFaceMapper) Address() wgaddr.Address {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Address")
ret0, _ := ret[0].(iface.WGAddress)
ret0, _ := ret[0].(wgaddr.Address)
return ret0
}

View File

@@ -61,7 +61,7 @@ func NewConnectClient(
}
// Run with main logic.
func (c *ConnectClient) Run(runningChan chan error) error {
func (c *ConnectClient) Run(runningChan chan struct{}) error {
return c.run(MobileDependency{}, runningChan)
}
@@ -102,7 +102,7 @@ func (c *ConnectClient) RunOniOS(
return c.run(mobileDependency, nil)
}
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error {
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
defer func() {
if r := recover(); r != nil {
rec := c.statusRecorder
@@ -159,10 +159,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
defer c.statusRecorder.ClientStop()
runningChanOpen := true
operation := func() error {
// if context cancelled we not start new backoff cycle
if c.isContextCancelled() {
if c.ctx.Err() != nil {
return nil
}
@@ -282,10 +281,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
if runningChan != nil && runningChanOpen {
runningChan <- nil
close(runningChan)
runningChanOpen = false
if runningChan != nil {
select {
case runningChan <- struct{}{}:
default:
}
}
<-engineCtx.Done()
@@ -379,15 +379,6 @@ func (c *ConnectClient) Stop() error {
return nil
}
func (c *ConnectClient) isContextCancelled() bool {
select {
case <-c.ctx.Done():
return true
default:
return false
}
}
// SetNetworkMapPersistence enables or disables network map persistence.
// When enabled, the last received network map will be stored and can be retrieved
// through the Engine's getLatestNetworkMap method. When disabled, any stored

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
@@ -37,9 +38,9 @@ func (w *mocWGIface) Name() string {
panic("implement me")
}
func (w *mocWGIface) Address() iface.WGAddress {
func (w *mocWGIface) Address() wgaddr.Address {
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
return iface.WGAddress{
return wgaddr.Address{
IP: ip,
Network: network,
}

View File

@@ -5,15 +5,15 @@ package dns
import (
"net"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() iface.WGAddress
Address() wgaddr.Address
ToInterface() *net.Interface
IsUserspaceBind() bool
GetFilter() device.PacketFilter

View File

@@ -1,15 +1,15 @@
package dns
import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() iface.WGAddress
Address() wgaddr.Address
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice

View File

@@ -25,7 +25,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
@@ -33,6 +33,7 @@ import (
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard"
@@ -169,10 +170,11 @@ type Engine struct {
statusRecorder *peer.Status
firewall manager.Manager
routeManager routemanager.Manager
acl acl.Manager
dnsForwardMgr *dnsfwd.Manager
firewall firewallManager.Manager
routeManager routemanager.Manager
acl acl.Manager
dnsForwardMgr *dnsfwd.Manager
ingressGatewayMgr *ingressgw.Manager
dnsServer dns.Server
@@ -266,6 +268,13 @@ func (e *Engine) Stop() error {
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
if e.ingressGatewayMgr != nil {
if err := e.ingressGatewayMgr.Close(); err != nil {
log.Warnf("failed to cleanup forward rules: %v", err)
}
e.ingressGatewayMgr = nil
}
if e.routeManager != nil {
e.routeManager.Stop(e.stateManager)
}
@@ -469,15 +478,15 @@ func (e *Engine) initFirewall() error {
}
rosenpassPort := e.rpManager.GetAddress().Port
port := manager.Port{Values: []uint16{uint16(rosenpassPort)}}
port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
// this rule is static and will be torn down on engine down by the firewall manager
if _, err := e.firewall.AddPeerFiltering(
net.IP{0, 0, 0, 0},
manager.ProtocolUDP,
firewallManager.ProtocolUDP,
nil,
&port,
manager.ActionAccept,
firewallManager.ActionAccept,
"",
"",
); err != nil {
@@ -505,10 +514,10 @@ func (e *Engine) blockLanAccess() {
if _, err := e.firewall.AddRouteFiltering(
[]netip.Prefix{v4},
network,
manager.ProtocolALL,
firewallManager.ProtocolALL,
nil,
nil,
manager.ActionDrop,
firewallManager.ActionDrop,
); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add fw rule for network %s: %w", network, err))
}
@@ -912,6 +921,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
// Ingress forward rules
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil {
log.Errorf("failed to update forward rules, err: %v", err)
}
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
e.updateOfflinePeers(networkMap.GetOfflinePeers())
@@ -1482,7 +1496,7 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
}
// GetFirewallManager returns the firewall manager
func (e *Engine) GetFirewallManager() manager.Manager {
func (e *Engine) GetFirewallManager() firewallManager.Manager {
return e.firewall
}
@@ -1575,16 +1589,19 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
}
// restartEngine restarts the engine by cancelling the client context
func (e *Engine) restartEngine() {
log.Info("restarting engine")
CtxGetState(e.ctx).Set(StatusConnecting)
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
if e.ctx.Err() != nil {
return
}
log.Info("restarting engine")
CtxGetState(e.ctx).Set(StatusConnecting)
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
log.Infof("cancelling client, engine will be recreated")
log.Infof("cancelling client context, engine will be recreated")
e.clientCancel()
}
@@ -1596,34 +1613,17 @@ func (e *Engine) startNetworkMonitor() {
e.networkMonitor = networkmonitor.New()
go func() {
var mu sync.Mutex
var debounceTimer *time.Timer
// Start the network monitor with a callback, Start will block until the monitor is stopped,
// a network change is detected, or an error occurs on start up
err := e.networkMonitor.Start(e.ctx, func() {
// This function is called when a network change is detected
mu.Lock()
defer mu.Unlock()
if debounceTimer != nil {
log.Infof("Network monitor: detected network change, reset debounceTimer")
debounceTimer.Stop()
if err := e.networkMonitor.Listen(e.ctx); err != nil {
if errors.Is(err, context.Canceled) {
log.Infof("network monitor stopped")
return
}
// Set a new timer to debounce rapid network changes
debounceTimer = time.AfterFunc(2*time.Second, func() {
// This function is called after the debounce period
mu.Lock()
defer mu.Unlock()
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
})
})
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
log.Errorf("Network monitor: %v", err)
log.Errorf("network monitor error: %v", err)
return
}
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
}()
}
@@ -1770,6 +1770,74 @@ func (e *Engine) Address() (netip.Addr, error) {
return ip.Unmap(), nil
}
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
if e.firewall == nil {
log.Warn("firewall is disabled, not updating forwarding rules")
return nil
}
if len(rules) == 0 {
if e.ingressGatewayMgr == nil {
return nil
}
err := e.ingressGatewayMgr.Close()
e.ingressGatewayMgr = nil
e.statusRecorder.SetIngressGwMgr(nil)
return err
}
if e.ingressGatewayMgr == nil {
mgr := ingressgw.NewManager(e.firewall)
e.ingressGatewayMgr = mgr
e.statusRecorder.SetIngressGwMgr(mgr)
}
var merr *multierror.Error
forwardingRules := make([]firewallManager.ForwardRule, 0, len(rules))
for _, rule := range rules {
proto, err := convertToFirewallProtocol(rule.GetProtocol())
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("failed to convert protocol '%s': %w", rule.GetProtocol(), err))
continue
}
dstPortInfo, err := convertPortInfo(rule.GetDestinationPort())
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("invalid destination port '%v': %w", rule.GetDestinationPort(), err))
continue
}
translateIP, err := convertToIP(rule.GetTranslatedAddress())
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("failed to convert translated address '%s': %w", rule.GetTranslatedAddress(), err))
continue
}
translatePort, err := convertPortInfo(rule.GetTranslatedPort())
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("invalid translate port '%v': %w", rule.GetTranslatedPort(), err))
continue
}
forwardRule := firewallManager.ForwardRule{
Protocol: proto,
DestinationPort: *dstPortInfo,
TranslatedAddress: translateIP,
TranslatedPort: *translatePort,
}
forwardingRules = append(forwardingRules, forwardRule)
}
log.Infof("updating forwarding rules: %d", len(forwardingRules))
if err := e.ingressGatewayMgr.Update(forwardingRules); err != nil {
log.Errorf("failed to update forwarding rules: %v", err)
}
return nberrors.FormatErrorOrNil(merr)
}
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
for _, check := range checks {

View File

@@ -31,6 +31,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -44,6 +45,7 @@ import (
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -74,7 +76,7 @@ type MockWGIface struct {
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
IsUserspaceBindFunc func() bool
NameFunc func() string
AddressFunc func() device.WGAddress
AddressFunc func() wgaddr.Address
ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error
@@ -113,7 +115,7 @@ func (m *MockWGIface) Name() string {
return m.NameFunc()
}
func (m *MockWGIface) Address() device.WGAddress {
func (m *MockWGIface) Address() wgaddr.Address {
return m.AddressFunc()
}
@@ -363,8 +365,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
RemovePeerFunc: func(peerKey string) error {
return nil
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
@@ -1433,7 +1435,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock())
if err != nil {
return nil, "", err
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
@@ -20,7 +21,7 @@ type wgIfaceBase interface {
CreateOnAndroid(routeRange []string, ip string, domains []string) error
IsUserspaceBind() bool
Name() string
Address() device.WGAddress
Address() wgaddr.Address
ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error

View File

@@ -0,0 +1,107 @@
package ingressgw
import (
"fmt"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
type DNATFirewall interface {
AddDNATRule(fwdRule firewall.ForwardRule) (firewall.Rule, error)
DeleteDNATRule(rule firewall.Rule) error
}
type RulePair struct {
firewall.ForwardRule
firewall.Rule
}
type Manager struct {
dnatFirewall DNATFirewall
rules map[string]RulePair // keys is the ID of the ForwardRule
rulesMu sync.Mutex
}
func NewManager(dnatFirewall DNATFirewall) *Manager {
return &Manager{
dnatFirewall: dnatFirewall,
rules: make(map[string]RulePair),
}
}
func (h *Manager) Update(forwardRules []firewall.ForwardRule) error {
h.rulesMu.Lock()
defer h.rulesMu.Unlock()
var mErr *multierror.Error
toDelete := make(map[string]RulePair, len(h.rules))
for id, r := range h.rules {
toDelete[id] = r
}
// Process new/updated rules
for _, fwdRule := range forwardRules {
id := fwdRule.ID()
if _, ok := h.rules[id]; ok {
delete(toDelete, id)
continue
}
rule, err := h.dnatFirewall.AddDNATRule(fwdRule)
if err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("add forward rule '%s': %v", fwdRule.String(), err))
continue
}
log.Infof("forward rule has been added '%s'", fwdRule)
h.rules[id] = RulePair{
ForwardRule: fwdRule,
Rule: rule,
}
}
// Remove deleted rules
for id, rulePair := range toDelete {
if err := h.dnatFirewall.DeleteDNATRule(rulePair.Rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete forward rule '%s': %v", rulePair.ForwardRule.String(), err))
}
log.Infof("forward rule has been deleted '%s'", rulePair.ForwardRule)
delete(h.rules, id)
}
return nberrors.FormatErrorOrNil(mErr)
}
func (h *Manager) Close() error {
h.rulesMu.Lock()
defer h.rulesMu.Unlock()
log.Infof("clean up all (%d) forward rules", len(h.rules))
var mErr *multierror.Error
for _, rule := range h.rules {
if err := h.dnatFirewall.DeleteDNATRule(rule.Rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete forward rule '%s': %v", rule, err))
}
}
h.rules = make(map[string]RulePair)
return nberrors.FormatErrorOrNil(mErr)
}
func (h *Manager) Rules() []firewall.ForwardRule {
h.rulesMu.Lock()
defer h.rulesMu.Unlock()
rules := make([]firewall.ForwardRule, 0, len(h.rules))
for _, rulePair := range h.rules {
rules = append(rules, rulePair.ForwardRule)
}
return rules
}

View File

@@ -0,0 +1,281 @@
package ingressgw
import (
"fmt"
"net/netip"
"testing"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var (
_ firewall.Rule = (*MocFwRule)(nil)
_ DNATFirewall = &MockDNATFirewall{}
)
type MocFwRule struct {
id string
}
func (m *MocFwRule) ID() string {
return string(m.id)
}
type MockDNATFirewall struct {
throwError bool
}
func (m *MockDNATFirewall) AddDNATRule(fwdRule firewall.ForwardRule) (firewall.Rule, error) {
if m.throwError {
return nil, fmt.Errorf("moc error")
}
fwRule := &MocFwRule{
id: fwdRule.ID(),
}
return fwRule, nil
}
func (m *MockDNATFirewall) DeleteDNATRule(rule firewall.Rule) error {
if m.throwError {
return fmt.Errorf("moc error")
}
return nil
}
func (m *MockDNATFirewall) forceToThrowErrors() {
m.throwError = true
}
func TestManager_AddRule(t *testing.T) {
fw := &MockDNATFirewall{}
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
updates := []firewall.ForwardRule{
{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
},
{
Protocol: firewall.ProtocolUDP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}}
if err := mgr.Update(updates); err != nil {
t.Errorf("unexpected error: %v", err)
}
rules := mgr.Rules()
if len(rules) != len(updates) {
t.Errorf("unexpected rules count: %d", len(rules))
}
}
func TestManager_UpdateRule(t *testing.T) {
fw := &MockDNATFirewall{}
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
ruleTCP := firewall.ForwardRule{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
ruleUDP := firewall.ForwardRule{
Protocol: firewall.ProtocolUDP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.2"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleUDP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
rules := mgr.Rules()
if len(rules) != 1 {
t.Errorf("unexpected rules count: %d", len(rules))
}
if rules[0].TranslatedAddress.String() != ruleUDP.TranslatedAddress.String() {
t.Errorf("unexpected rule: %v", rules[0])
}
if rules[0].TranslatedPort.String() != ruleUDP.TranslatedPort.String() {
t.Errorf("unexpected rule: %v", rules[0])
}
if rules[0].DestinationPort.String() != ruleUDP.DestinationPort.String() {
t.Errorf("unexpected rule: %v", rules[0])
}
if rules[0].Protocol != ruleUDP.Protocol {
t.Errorf("unexpected rule: %v", rules[0])
}
}
func TestManager_ExtendRules(t *testing.T) {
fw := &MockDNATFirewall{}
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
ruleTCP := firewall.ForwardRule{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}
ruleUDP := firewall.ForwardRule{
Protocol: firewall.ProtocolUDP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.2"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP, ruleUDP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
rules := mgr.Rules()
if len(rules) != 2 {
t.Errorf("unexpected rules count: %d", len(rules))
}
}
func TestManager_UnderlingError(t *testing.T) {
fw := &MockDNATFirewall{}
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
ruleTCP := firewall.ForwardRule{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}
ruleUDP := firewall.ForwardRule{
Protocol: firewall.ProtocolUDP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.2"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
fw.forceToThrowErrors()
if err := mgr.Update([]firewall.ForwardRule{ruleTCP, ruleUDP}); err == nil {
t.Errorf("expected error")
}
rules := mgr.Rules()
if len(rules) != 1 {
t.Errorf("unexpected rules count: %d", len(rules))
}
}
func TestManager_Cleanup(t *testing.T) {
fw := &MockDNATFirewall{}
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
ruleTCP := firewall.ForwardRule{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
if err := mgr.Update([]firewall.ForwardRule{}); err != nil {
t.Errorf("unexpected error: %v", err)
}
rules := mgr.Rules()
if len(rules) != 0 {
t.Errorf("unexpected rules count: %d", len(rules))
}
}
func TestManager_DeleteBrokenRule(t *testing.T) {
fw := &MockDNATFirewall{}
// force to throw errors when Add DNAT Rule
fw.forceToThrowErrors()
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
ruleTCP := firewall.ForwardRule{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err == nil {
t.Errorf("unexpected error: %v", err)
}
rules := mgr.Rules()
if len(rules) != 0 {
t.Errorf("unexpected rules count: %d", len(rules))
}
// simulate that to remove a broken rule
if err := mgr.Update([]firewall.ForwardRule{}); err != nil {
t.Errorf("unexpected error: %v", err)
}
if err := mgr.Close(); err != nil {
t.Errorf("unexpected error: %v", err)
}
}
func TestManager_Close(t *testing.T) {
fw := &MockDNATFirewall{}
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
ruleTCP := firewall.ForwardRule{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
if err := mgr.Close(); err != nil {
t.Errorf("unexpected error: %v", err)
}
rules := mgr.Rules()
if len(rules) != 0 {
t.Errorf("unexpected rules count: %d", len(rules))
}
}

View File

@@ -0,0 +1,58 @@
package internal
import (
"errors"
"fmt"
"net"
"net/netip"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewallManager.Protocol, error) {
switch protocol {
case mgmProto.RuleProtocol_TCP:
return firewallManager.ProtocolTCP, nil
case mgmProto.RuleProtocol_UDP:
return firewallManager.ProtocolUDP, nil
case mgmProto.RuleProtocol_ICMP:
return firewallManager.ProtocolICMP, nil
case mgmProto.RuleProtocol_ALL:
return firewallManager.ProtocolALL, nil
default:
return "", fmt.Errorf("invalid protocol type: %s", protocol.String())
}
}
func convertPortInfo(portInfo *mgmProto.PortInfo) (*firewallManager.Port, error) {
if portInfo == nil {
return nil, errors.New("portInfo cannot be nil")
}
if portInfo.GetPort() != 0 {
return firewallManager.NewPort(int(portInfo.GetPort()))
}
if portInfo.GetRange() != nil {
return firewallManager.NewPort(int(portInfo.GetRange().Start), int(portInfo.GetRange().End))
}
return nil, fmt.Errorf("invalid portInfo: %v", portInfo)
}
func convertToIP(rawIP []byte) (netip.Addr, error) {
if rawIP == nil {
return netip.Addr{}, errors.New("input bytes cannot be nil")
}
if len(rawIP) != net.IPv4len && len(rawIP) != net.IPv6len {
return netip.Addr{}, fmt.Errorf("invalid IP length: %d", len(rawIP))
}
if len(rawIP) == net.IPv4len {
return netip.AddrFrom4([4]byte(rawIP)), nil
}
return netip.AddrFrom16([16]byte(rawIP)), nil
}

View File

@@ -16,7 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("failed to open routing socket: %v", err)
@@ -28,18 +28,10 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
}
}()
go func() {
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Debugf("Network monitor: closed routing socket: %v", err)
}
}()
for {
select {
case <-ctx.Done():
return ErrStopped
return ctx.Err()
default:
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
@@ -76,11 +68,11 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
switch msg.Type {
case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
go callback()
return nil
case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
go callback()
return nil
}
}
}

View File

@@ -14,7 +14,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
if nexthopv4.Intf == nil && nexthopv6.Intf == nil {
return errors.New("no interfaces available")
}
@@ -31,8 +31,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
for {
select {
case <-ctx.Done():
return ErrStopped
return ctx.Err()
// handle route changes
case route := <-routeChan:
// default route and main table
@@ -43,12 +42,10 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
// triggered on added/replaced routes
case syscall.RTM_NEWROUTE:
log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback()
return nil
case syscall.RTM_DELROUTE:
if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) {
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback()
return nil
}
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
routeMonitor, err := systemops.NewRouteMonitor(ctx)
if err != nil {
return fmt.Errorf("failed to create route monitor: %w", err)
@@ -24,20 +24,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
for {
select {
case <-ctx.Done():
return ErrStopped
return ctx.Err()
case route := <-routeMonitor.RouteUpdates():
if route.Destination.Bits() != 0 {
continue
}
if routeChanged(route, nexthopv4, nexthopv6, callback) {
break
if routeChanged(route, nexthopv4, nexthopv6) {
return nil
}
}
}
}
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) bool {
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop) bool {
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
@@ -51,18 +51,15 @@ func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Ne
case systemops.RouteModified:
// TODO: get routing table to figure out if our route is affected for modified routes
log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf)
go callback()
return true
case systemops.RouteAdded:
if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP {
log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf)
go callback()
return true
}
case systemops.RouteDeleted:
if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf)
go callback()
return true
}
}

View File

@@ -1,12 +1,27 @@
//go:build !ios && !android
package networkmonitor
import (
"context"
"errors"
"fmt"
"net/netip"
"runtime/debug"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
var ErrStopped = errors.New("monitor has been stopped")
const (
debounceTime = 2 * time.Second
)
var checkChangeFn = checkChange
// NetworkMonitor watches for changes in network configuration.
type NetworkMonitor struct {
@@ -19,3 +34,99 @@ type NetworkMonitor struct {
func New() *NetworkMonitor {
return &NetworkMonitor{}
}
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
nw.mu.Lock()
if nw.cancel != nil {
nw.mu.Unlock()
return errors.New("network monitor already started")
}
ctx, nw.cancel = context.WithCancel(ctx)
defer nw.cancel()
nw.wg.Add(1)
nw.mu.Unlock()
defer nw.wg.Done()
var nexthop4, nexthop6 systemops.Nexthop
operation := func() error {
var errv4, errv6 error
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
if errv4 != nil && errv6 != nil {
return errors.New("failed to get default next hops")
}
if errv4 == nil {
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
}
if errv6 == nil {
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
}
// continue if either route was found
return nil
}
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
if err := backoff.Retry(operation, expBackOff); err != nil {
return fmt.Errorf("failed to get default next hops: %w", err)
}
// recover in case sys ops panic
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
}
}()
event := make(chan struct{}, 1)
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
// debounce changes
timer := time.NewTimer(0)
timer.Stop()
for {
select {
case <-event:
timer.Reset(debounceTime)
case <-timer.C:
return nil
case <-ctx.Done():
timer.Stop()
return ctx.Err()
}
}
}
// Stop stops the network monitor.
func (nw *NetworkMonitor) Stop() {
nw.mu.Lock()
defer nw.mu.Unlock()
if nw.cancel == nil {
return
}
nw.cancel()
nw.wg.Wait()
}
func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) {
for {
if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil {
close(event)
return
}
// prevent blocking
select {
case event <- struct{}{}:
default:
}
}
}

View File

@@ -1,82 +0,0 @@
//go:build !ios && !android
package networkmonitor
import (
"context"
"errors"
"fmt"
"net/netip"
"runtime/debug"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error) {
if ctx.Err() != nil {
return ctx.Err()
}
nw.mu.Lock()
ctx, nw.cancel = context.WithCancel(ctx)
nw.mu.Unlock()
nw.wg.Add(1)
defer nw.wg.Done()
var nexthop4, nexthop6 systemops.Nexthop
operation := func() error {
var errv4, errv6 error
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
if errv4 != nil && errv6 != nil {
return errors.New("failed to get default next hops")
}
if errv4 == nil {
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
}
if errv6 == nil {
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
}
// continue if either route was found
return nil
}
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
if err := backoff.Retry(operation, expBackOff); err != nil {
return fmt.Errorf("failed to get default next hops: %w", err)
}
// recover in case sys ops panic
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
}
}()
if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil {
return fmt.Errorf("check change: %w", err)
}
return nil
}
// Stop stops the network monitor.
func (nw *NetworkMonitor) Stop() {
nw.mu.Lock()
defer nw.mu.Unlock()
if nw.cancel != nil {
nw.cancel()
nw.wg.Wait()
}
}

View File

@@ -2,10 +2,21 @@
package networkmonitor
import "context"
import (
"context"
"fmt"
)
func (nw *NetworkMonitor) Start(context.Context, func()) error {
return nil
type NetworkMonitor struct {
}
// New creates a new network monitor.
func New() *NetworkMonitor {
return &NetworkMonitor{}
}
func (nw *NetworkMonitor) Listen(_ context.Context) error {
return fmt.Errorf("network monitor not supported on mobile platforms")
}
func (nw *NetworkMonitor) Stop() {

View File

@@ -0,0 +1,99 @@
package networkmonitor
import (
"context"
"errors"
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
type MocMultiEvent struct {
counter int
}
func (m *MocMultiEvent) checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
if m.counter == 0 {
<-ctx.Done()
return ctx.Err()
}
time.Sleep(1 * time.Second)
m.counter--
return nil
}
func TestNetworkMonitor_Close(t *testing.T) {
checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
<-ctx.Done()
return ctx.Err()
}
nw := New()
var resErr error
done := make(chan struct{})
go func() {
resErr = nw.Listen(context.Background())
close(done)
}()
time.Sleep(1 * time.Second) // wait for the goroutine to start
nw.Stop()
<-done
if !errors.Is(resErr, context.Canceled) {
t.Errorf("unexpected error: %v", resErr)
}
}
func TestNetworkMonitor_Event(t *testing.T) {
checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
timeout, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
select {
case <-ctx.Done():
return ctx.Err()
case <-timeout.Done():
return nil
}
}
nw := New()
defer nw.Stop()
var resErr error
done := make(chan struct{})
go func() {
resErr = nw.Listen(context.Background())
close(done)
}()
<-done
if !errors.Is(resErr, nil) {
t.Errorf("unexpected error: %v", nil)
}
}
func TestNetworkMonitor_MultiEvent(t *testing.T) {
eventsRepeated := 3
me := &MocMultiEvent{counter: eventsRepeated}
checkChangeFn = me.checkChange
nw := New()
defer nw.Stop()
done := make(chan struct{})
started := time.Now()
go func() {
if resErr := nw.Listen(context.Background()); resErr != nil {
t.Errorf("unexpected error: %v", resErr)
}
close(done)
}()
<-done
expectedResponseTime := time.Duration(eventsRepeated)*time.Second + debounceTime
if time.Since(started) < expectedResponseTime {
t.Errorf("unexpected duration: %v", time.Since(started))
}
}

View File

@@ -114,6 +114,9 @@ type Conn struct {
guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup
// debug purpose
dumpState *stateDump
}
// NewConn creates a new not opened Conn to the remote peer.
@@ -137,10 +140,11 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(),
semaphore: semaphore,
dumpState: newStateDump(connLog),
}
ctrl := isController(config)
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager)
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager, conn.dumpState)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
@@ -160,6 +164,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
go conn.handshaker.Listen()
go conn.dumpState.Start(ctx)
return conn, nil
}
@@ -193,6 +198,7 @@ func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
defer conn.semaphore.Done(conn.ctx)
conn.waitInitialRandomSleepTime(ctx)
conn.dumpState.SendOffer()
err := conn.handshaker.sendOffer()
if err != nil {
conn.log.Errorf("failed to send initial offer: %v", err)
@@ -251,12 +257,14 @@ func (conn *Conn) Close() {
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
conn.log.Debugf("OnRemoteAnswer, status ICE: %s, status relay: %s", conn.statusICE, conn.statusRelay)
conn.dumpState.RemoteAnswer()
conn.log.Infof("OnRemoteAnswer, status ICE: %s, status relay: %s", conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteAnswer(answer)
}
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
conn.dumpState.RemoteCandidate()
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
}
@@ -278,7 +286,8 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
}
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
conn.log.Debugf("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
conn.dumpState.RemoteOffer()
conn.log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteOffer(offer)
}
@@ -322,6 +331,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
}
conn.log.Infof("set ICE to active connection")
conn.dumpState.P2PConnected()
var (
ep *net.UDPAddr
@@ -329,6 +339,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
err error
)
if iceConnInfo.RelayedOnLocal {
conn.dumpState.NewLocalProxy()
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
@@ -390,6 +401,7 @@ func (conn *Conn) onICEStateDisconnected() {
// switch back to relay connection
if conn.isReadyToUpgrade() {
conn.log.Infof("ICE disconnected, set Relay to active connection")
conn.dumpState.SwitchToRelay()
conn.wgProxyRelay.Work()
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
@@ -432,6 +444,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return
}
conn.dumpState.RelayConnected()
conn.log.Debugf("Relay connection has been established, setup the WireGuard")
wgProxy, err := conn.newProxy(rci.relayedConn)
@@ -439,6 +452,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return
}
conn.dumpState.NewLocalProxy()
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
@@ -481,10 +495,10 @@ func (conn *Conn) onRelayDisconnected() {
return
}
conn.log.Debugf("relay connection is disconnected")
conn.log.Infof("relay connection is disconnected")
if conn.currentConnPriority == connPriorityRelay {
conn.log.Debugf("clean up WireGuard config")
conn.log.Infof("clean up WireGuard config")
if err := conn.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err)
}
@@ -516,7 +530,8 @@ func (conn *Conn) listenGuardEvent(ctx context.Context) {
for {
select {
case <-conn.guard.Reconnect:
conn.log.Debugf("send offer to peer")
conn.log.Infof("send offer to peer")
conn.dumpState.SendOffer()
if err := conn.handshaker.SendOffer(); err != nil {
conn.log.Errorf("failed to send offer: %v", err)
}

View File

@@ -76,19 +76,19 @@ func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAn
func (h *Handshaker) Listen() {
for {
h.log.Debugf("wait for remote offer confirmation")
h.log.Info("wait for remote offer confirmation")
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation()
if err != nil {
var connectionClosedError *ConnectionClosedError
if errors.As(err, &connectionClosedError) {
h.log.Tracef("stop handshaker")
h.log.Info("exit from handshaker")
return
}
h.log.Errorf("failed to received remote offer confirmation: %s", err)
continue
}
h.log.Debugf("received connection confirmation, running version %s and with remote WireGuard listen port %d", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort)
h.log.Infof("received connection confirmation, running version %s and with remote WireGuard listen port %d", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort)
for _, listener := range h.onNewOfferListeners {
go listener(remoteOfferAnswer)
}
@@ -108,7 +108,7 @@ func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) bool {
case h.remoteOffersCh <- offer:
return true
default:
h.log.Debugf("OnRemoteOffer skipping message because is not ready")
h.log.Warnf("OnRemoteOffer skipping message because is not ready")
// connection might not be ready yet to receive so we ignore the message
return false
}
@@ -131,8 +131,7 @@ func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
select {
case remoteOfferAnswer := <-h.remoteOffersCh:
// received confirmation from the remote peer -> ready to proceed
err := h.sendAnswer()
if err != nil {
if err := h.sendAnswer(); err != nil {
return nil, err
}
return &remoteOfferAnswer, nil
@@ -168,7 +167,7 @@ func (h *Handshaker) sendOffer() error {
}
func (h *Handshaker) sendAnswer() error {
h.log.Debugf("sending answer")
h.log.Infof("sending answer")
uFrag, pwd := h.ice.GetLocalUserCredentials()
answer := OfferAnswer{

View File

@@ -8,6 +8,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
@@ -16,4 +17,5 @@ type WGIface interface {
RemovePeer(peerKey string) error
GetStats(peerKey string) (configurer.WGStats, error)
GetProxy() wgproxy.Proxy
Address() wgaddr.Address
}

View File

@@ -0,0 +1,112 @@
package peer
import (
"context"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
type stateDump struct {
log *log.Entry
sentOffer int
remoteOffer int
remoteAnswer int
remoteCandidate int
p2pConnected int
switchToRelay int
wgCheckSuccess int
relayConnected int
localProxies int
mu sync.Mutex
}
func newStateDump(log *log.Entry) *stateDump {
return &stateDump{
log: log,
}
}
func (s *stateDump) Start(ctx context.Context) {
ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.dumpState()
case <-ctx.Done():
return
}
}
}
func (s *stateDump) RemoteOffer() {
s.mu.Lock()
defer s.mu.Unlock()
s.remoteOffer++
}
func (s *stateDump) RemoteCandidate() {
s.mu.Lock()
defer s.mu.Unlock()
s.remoteCandidate++
}
func (s *stateDump) SendOffer() {
s.mu.Lock()
defer s.mu.Unlock()
s.sentOffer++
}
func (s *stateDump) dumpState() {
s.mu.Lock()
defer s.mu.Unlock()
s.log.Infof("Dump stat: SentOffer: %d, RemoteOffer: %d, RemoteAnswer: %d, RemoteCandidate: %d, P2PConnected: %d, SwitchToRelay: %d, WGCheckSuccess: %d, RelayConnected: %d, LocalProxies: %d",
s.sentOffer, s.remoteOffer, s.remoteAnswer, s.remoteCandidate, s.p2pConnected, s.switchToRelay, s.wgCheckSuccess, s.relayConnected, s.localProxies)
}
func (s *stateDump) RemoteAnswer() {
s.mu.Lock()
defer s.mu.Unlock()
s.remoteAnswer++
}
func (s *stateDump) P2PConnected() {
s.mu.Lock()
defer s.mu.Unlock()
s.p2pConnected++
}
func (s *stateDump) SwitchToRelay() {
s.mu.Lock()
defer s.mu.Unlock()
s.switchToRelay++
}
func (s *stateDump) WGcheckSuccess() {
s.mu.Lock()
defer s.mu.Unlock()
s.wgCheckSuccess++
}
func (s *stateDump) RelayConnected() {
s.mu.Lock()
defer s.mu.Unlock()
s.relayConnected++
}
func (s *stateDump) NewLocalProxy() {
s.mu.Lock()
defer s.mu.Unlock()
s.localProxies++
}

View File

@@ -14,7 +14,9 @@ import (
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/domain"
@@ -132,13 +134,14 @@ type NSGroupState struct {
// FullStatus contains the full state held by the Status instance
type FullStatus struct {
Peers []State
ManagementState ManagementState
SignalState SignalState
LocalPeerState LocalPeerState
RosenpassState RosenpassState
Relays []relay.ProbeResult
NSGroupStates []NSGroupState
Peers []State
ManagementState ManagementState
SignalState SignalState
LocalPeerState LocalPeerState
RosenpassState RosenpassState
Relays []relay.ProbeResult
NSGroupStates []NSGroupState
NumOfForwardingRules int
}
// Status holds a state of peers, signal, management connections and relays
@@ -171,6 +174,8 @@ type Status struct {
eventMux sync.RWMutex
eventStreams map[string]chan *proto.SystemEvent
eventQueue *EventQueue
ingressGwMgr *ingressgw.Manager
}
// NewRecorder returns a new Status instance
@@ -193,6 +198,12 @@ func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
d.relayMgr = manager
}
func (d *Status) SetIngressGwMgr(ingressGwMgr *ingressgw.Manager) {
d.mux.Lock()
defer d.mux.Unlock()
d.ingressGwMgr = ingressGwMgr
}
// ReplaceOfflinePeers replaces
func (d *Status) ReplaceOfflinePeers(replacement []State) {
d.mux.Lock()
@@ -235,6 +246,18 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
return state, nil
}
func (d *Status) PeerByIP(ip string) (string, bool) {
d.mux.Lock()
defer d.mux.Unlock()
for _, state := range d.peers {
if state.IP == ip {
return state.FQDN, true
}
}
return "", false
}
// RemovePeer removes peer from Daemon status map
func (d *Status) RemovePeer(peerPubKey string) error {
d.mux.Lock()
@@ -734,6 +757,16 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
return append(relayStates, relayState)
}
func (d *Status) ForwardingRules() []firewall.ForwardRule {
d.mux.Lock()
defer d.mux.Unlock()
if d.ingressGwMgr == nil {
return nil
}
return d.ingressGwMgr.Rules()
}
func (d *Status) GetDNSStates() []NSGroupState {
d.mux.Lock()
defer d.mux.Unlock()
@@ -751,11 +784,12 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo
// GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus {
fullStatus := FullStatus{
ManagementState: d.GetManagementState(),
SignalState: d.GetSignalState(),
Relays: d.GetRelayStates(),
RosenpassState: d.GetRosenpassState(),
NSGroupStates: d.GetDNSStates(),
ManagementState: d.GetManagementState(),
SignalState: d.GetSignalState(),
Relays: d.GetRelayStates(),
RosenpassState: d.GetRosenpassState(),
NSGroupStates: d.GetDNSStates(),
NumOfForwardingRules: len(d.ForwardingRules()),
}
d.mux.Lock()

View File

@@ -27,6 +27,7 @@ type WGWatcher struct {
log *log.Entry
wgIfaceStater WGInterfaceStater
peerKey string
stateDump *stateDump
ctx context.Context
ctxCancel context.CancelFunc
@@ -34,11 +35,12 @@ type WGWatcher struct {
waitGroup sync.WaitGroup
}
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string) *WGWatcher {
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
return &WGWatcher{
log: log,
wgIfaceStater: wgIfaceStater,
peerKey: peerKey,
stateDump: stateDump,
}
}
@@ -105,6 +107,7 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
resetTime := time.Until(handshake.Add(checkPeriod))
timer.Reset(resetTime)
w.stateDump.WGcheckSuccess()
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
case <-ctx.Done():

View File

@@ -43,7 +43,7 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
mlog := log.WithField("peer", "tet")
mocWgIface := &MocWgIface{}
watcher := NewWGWatcher(mlog, mocWgIface, "")
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump(mlog))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -72,7 +72,7 @@ func TestWGWatcher_ReEnable(t *testing.T) {
mlog := log.WithField("peer", "tet")
mocWgIface := &MocWgIface{}
watcher := NewWGWatcher(mlog, mocWgIface, "")
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump(mlog))
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

View File

@@ -358,6 +358,12 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
}
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}
var routePrefixes []netip.Prefix
for _, routes := range clientRoutes {
if len(routes) > 0 && routes[0] != nil {
@@ -365,14 +371,8 @@ func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool
}
}
addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}
for _, prefix := range routePrefixes {
// default route is
// default route is handled by route exclusion / ip rules
if prefix.Bits() == 0 {
continue
}

View File

@@ -33,14 +33,14 @@ type WorkerRelay struct {
wgWatcher *WGWatcher
}
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService) *WorkerRelay {
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
r := &WorkerRelay{
log: log,
isController: ctrl,
config: config,
conn: conn,
relayManager: relayManager,
wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key),
wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key, stateDump),
}
return r
}

View File

@@ -302,7 +302,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error
// If the chosen route is the same as the current route, do nothing
if c.currentChosen != nil && c.currentChosen.ID == newChosenID &&
c.currentChosen.IsEqual(c.routes[newChosenID]) {
c.currentChosen.Equal(c.routes[newChosenID]) {
return nil
}

View File

@@ -3,9 +3,9 @@ package iface
import (
"net"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type wgIfaceBase interface {
@@ -13,7 +13,7 @@ type wgIfaceBase interface {
RemoveAllowedIP(peerKey string, allowedIP string) error
Name() string
Address() iface.WGAddress
Address() wgaddr.Address
ToInterface() *net.Interface
IsUserspaceBind() bool
GetFilter() device.PacketFilter

View File

@@ -0,0 +1,51 @@
package ipfwdstate
import (
"fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
// IPForwardingState is a struct that keeps track of the IP forwarding state.
// todo: read initial state of the IP forwarding from the system and reset the state based on it
type IPForwardingState struct {
enabledCounter int
}
func NewIPForwardingState() *IPForwardingState {
return &IPForwardingState{}
}
func (f *IPForwardingState) RequestForwarding() error {
if f.enabledCounter != 0 {
f.enabledCounter++
return nil
}
if err := systemops.EnableIPForwarding(); err != nil {
return fmt.Errorf("failed to enable IP forwarding with sysctl: %w", err)
}
f.enabledCounter = 1
log.Info("IP forwarding enabled")
return nil
}
func (f *IPForwardingState) ReleaseForwarding() error {
if f.enabledCounter == 0 {
return nil
}
if f.enabledCounter > 1 {
f.enabledCounter--
return nil
}
// if failed to disable IP forwarding we anyway decrement the counter
f.enabledCounter = 0
// todo call systemops.DisableIPForwarding()
return nil
}

View File

@@ -13,7 +13,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/route"
)
@@ -41,7 +40,7 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
for routeID := range m.routes {
update, found := routesMap[routeID]
if !found || !update.IsEqual(m.routes[routeID]) {
if !found || !update.Equal(m.routes[routeID]) {
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
}
}
@@ -71,9 +70,6 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
}
if len(m.routes) > 0 {
if err := systemops.EnableIPForwarding(); err != nil {
return fmt.Errorf("enable ip forwarding: %w", err)
}
if err := m.firewall.EnableRouting(); err != nil {
return fmt.Errorf("enable routing: %w", err)
}

File diff suppressed because it is too large Load Diff

View File

@@ -8,6 +8,8 @@ option go_package = "/proto";
package daemon;
message EmptyRequest {}
service DaemonService {
// Login uses setup key to prepare configuration for the daemon.
rpc Login(LoginRequest) returns (LoginResponse) {}
@@ -37,6 +39,8 @@ service DaemonService {
// Deselect specific routes
rpc DeselectNetworks(SelectNetworksRequest) returns (SelectNetworksResponse) {}
rpc ForwardingRules(EmptyRequest) returns (ForwardingRulesResponse) {}
// DebugBundle creates a debug bundle
rpc DebugBundle(DebugBundleRequest) returns (DebugBundleResponse) {}
@@ -267,10 +271,12 @@ message FullStatus {
repeated PeerState peers = 4;
repeated RelayState relays = 5;
repeated NSGroupState dns_servers = 6;
int32 NumberOfForwardingRules = 8;
repeated SystemEvent events = 7;
}
// Networks
message ListNetworksRequest {
}
@@ -291,7 +297,6 @@ message IPList {
repeated string ips = 1;
}
message Network {
string ID = 1;
string range = 2;
@@ -300,6 +305,33 @@ message Network {
map<string, IPList> resolvedIPs = 5;
}
// ForwardingRules
message PortInfo {
oneof portSelection {
uint32 port = 1;
Range range = 2;
}
message Range {
uint32 start = 1;
uint32 end = 2;
}
}
message ForwardingRule {
string protocol = 1;
PortInfo destinationPort = 2;
string translatedAddress = 3;
string translatedHostname = 4;
PortInfo translatedPort = 5;
}
message ForwardingRulesResponse {
repeated ForwardingRule rules = 1;
}
// DebugBundler
message DebugBundleRequest {
bool anonymize = 1;
string status = 2;

View File

@@ -37,6 +37,7 @@ type DaemonServiceClient interface {
SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error)
// Deselect specific routes
DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error)
ForwardingRules(ctx context.Context, in *EmptyRequest, opts ...grpc.CallOption) (*ForwardingRulesResponse, error)
// DebugBundle creates a debug bundle
DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error)
// GetLogLevel gets the log level of the daemon
@@ -145,6 +146,15 @@ func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNe
return out, nil
}
func (c *daemonServiceClient) ForwardingRules(ctx context.Context, in *EmptyRequest, opts ...grpc.CallOption) (*ForwardingRulesResponse, error) {
out := new(ForwardingRulesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/ForwardingRules", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error) {
out := new(DebugBundleResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/DebugBundle", in, out, opts...)
@@ -281,6 +291,7 @@ type DaemonServiceServer interface {
SelectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error)
// Deselect specific routes
DeselectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error)
ForwardingRules(context.Context, *EmptyRequest) (*ForwardingRulesResponse, error)
// DebugBundle creates a debug bundle
DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error)
// GetLogLevel gets the log level of the daemon
@@ -332,6 +343,9 @@ func (UnimplementedDaemonServiceServer) SelectNetworks(context.Context, *SelectN
func (UnimplementedDaemonServiceServer) DeselectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method DeselectNetworks not implemented")
}
func (UnimplementedDaemonServiceServer) ForwardingRules(context.Context, *EmptyRequest) (*ForwardingRulesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ForwardingRules not implemented")
}
func (UnimplementedDaemonServiceServer) DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method DebugBundle not implemented")
}
@@ -537,6 +551,24 @@ func _DaemonService_DeselectNetworks_Handler(srv interface{}, ctx context.Contex
return interceptor(ctx, in, info, handler)
}
func _DaemonService_ForwardingRules_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EmptyRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).ForwardingRules(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/ForwardingRules",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).ForwardingRules(ctx, req.(*EmptyRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_DebugBundle_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(DebugBundleRequest)
if err := dec(in); err != nil {
@@ -763,6 +795,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "DeselectNetworks",
Handler: _DaemonService_DeselectNetworks_Handler,
},
{
MethodName: "ForwardingRules",
Handler: _DaemonService_ForwardingRules_Handler,
},
{
MethodName: "DebugBundle",
Handler: _DaemonService_DebugBundle_Handler,

View File

@@ -53,7 +53,7 @@ The files in this bundle have been anonymized to protect sensitive information.
IP Addresses
IPv4 addresses are replaced with addresses starting from 192.51.100.0
IPv4 addresses are replaced with addresses starting from 198.51.100.0
IPv6 addresses are replaced with addresses starting from 100::
IP addresses from non public ranges and well known addresses are not anonymized (e.g. 8.8.8.8, 100.64.0.0/10, addresses starting with 192.168., 172.16., 10., etc.).

View File

@@ -0,0 +1,54 @@
package server
import (
"context"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/proto"
)
func (s *Server) ForwardingRules(context.Context, *proto.EmptyRequest) (*proto.ForwardingRulesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
rules := s.statusRecorder.ForwardingRules()
responseRules := make([]*proto.ForwardingRule, 0, len(rules))
for _, rule := range rules {
respRule := &proto.ForwardingRule{
Protocol: string(rule.Protocol),
DestinationPort: portToProto(rule.DestinationPort),
TranslatedAddress: rule.TranslatedAddress.String(),
TranslatedHostname: s.hostNameByTranslateAddress(rule.TranslatedAddress.String()),
TranslatedPort: portToProto(rule.TranslatedPort),
}
responseRules = append(responseRules, respRule)
}
return &proto.ForwardingRulesResponse{Rules: responseRules}, nil
}
func (s *Server) hostNameByTranslateAddress(ip string) string {
hostName, ok := s.statusRecorder.PeerByIP(ip)
if !ok {
return ip
}
return hostName
}
func portToProto(port firewall.Port) *proto.PortInfo {
var portInfo proto.PortInfo
if !port.IsRange {
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(port.Values[0])}
} else {
portInfo.PortSelection = &proto.PortInfo_Range_{
Range: &proto.PortInfo_Range{
Start: uint32(port.Values[0]),
End: uint32(port.Values[1]),
},
}
}
return &portInfo
}

View File

@@ -36,8 +36,13 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
return nil, fmt.Errorf("not connected")
}
routesMap := engine.GetRouteManager().GetClientRoutesWithNetID()
routeSelector := engine.GetRouteManager().GetRouteSelector()
routeMgr := engine.GetRouteManager()
if routeMgr == nil {
return nil, fmt.Errorf("no route manager")
}
routesMap := routeMgr.GetClientRoutesWithNetID()
routeSelector := routeMgr.GetRouteSelector()
var routes []*selectRoute
for id, rt := range routesMap {
@@ -123,6 +128,10 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
}
routeManager := engine.GetRouteManager()
if routeManager == nil {
return nil, fmt.Errorf("no route manager")
}
routeSelector := routeManager.GetRouteSelector()
if req.GetAll() {
routeSelector.SelectAllRoutes()
@@ -165,6 +174,10 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
}
routeManager := engine.GetRouteManager()
if routeManager == nil {
return nil, fmt.Errorf("no route manager")
}
routeSelector := routeManager.GetRouteSelector()
if req.GetAll() {
routeSelector.DeselectAllRoutes()

View File

@@ -3,7 +3,7 @@ package server
import (
"fmt"
"os"
"path/filepath"
"path"
"syscall"
log "github.com/sirupsen/logrus"
@@ -12,7 +12,6 @@ import (
)
const (
windowsPanicLogEnvVar = "NB_WINDOWS_PANIC_LOG"
// STD_ERROR_HANDLE ((DWORD)-12) = 4294967284
stdErrorHandle = ^uintptr(11)
)
@@ -25,13 +24,10 @@ var (
)
func handlePanicLog() error {
logPath := os.Getenv(windowsPanicLogEnvVar)
if logPath == "" {
return nil
}
// TODO: move this to a central location
logDir := path.Join(os.Getenv("PROGRAMDATA"), "Netbird")
logPath := path.Join(logDir, "netbird.err")
// Ensure the directory exists
logDir := filepath.Dir(logPath)
if err := os.MkdirAll(logDir, 0750); err != nil {
return fmt.Errorf("create panic log directory: %w", err)
}
@@ -39,13 +35,11 @@ func handlePanicLog() error {
return fmt.Errorf("enforce permission on panic log file: %w", err)
}
// Open log file with append mode
f, err := os.OpenFile(logPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
if err != nil {
return fmt.Errorf("open panic log file: %w", err)
}
// Redirect stderr to the file
if err = redirectStderr(f); err != nil {
if closeErr := f.Close(); closeErr != nil {
log.Warnf("failed to close file after redirect error: %v", closeErr)
@@ -59,7 +53,6 @@ func handlePanicLog() error {
// redirectStderr redirects stderr to the provided file
func redirectStderr(f *os.File) error {
// Get the current process's stderr handle
if err := setStdHandle(f); err != nil {
return fmt.Errorf("failed to set stderr handle: %w", err)
}

View File

@@ -160,7 +160,7 @@ func (s *Server) Start() error {
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
runningChan chan error,
runningChan chan struct{},
) {
backOff := getConnectWithBackoff(ctx)
retryStarted := false
@@ -628,20 +628,21 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
runningChan := make(chan error)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan)
timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
defer cancel()
runningChan := make(chan struct{}, 1) // buffered channel to do not lose the signal
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan)
for {
select {
case err := <-runningChan:
if err != nil {
log.Debugf("waiting for engine to become ready failed: %s", err)
} else {
return &proto.UpResponse{}, nil
}
case <-runningChan:
return &proto.UpResponse{}, nil
case <-callerCtx.Done():
log.Debug("context done, stopping the wait for engine to become ready")
return nil, callerCtx.Err()
case <-timeoutCtx.Done():
log.Debug("up is timed out, stopping the wait for engine to become ready")
return nil, timeoutCtx.Err()
}
}
}
@@ -810,6 +811,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
for _, peerState := range fullStatus.Peers {
pbPeerState := &proto.PeerState{

View File

@@ -20,6 +20,7 @@ import (
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -128,7 +129,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock())
if err != nil {
return nil, "", err
}

Some files were not shown because too many files have changed in this diff Show More