Compare commits

...

15 Commits

Author SHA1 Message Date
Zoltan Papp
0c4c2a8cb2 Listen on all address 2023-12-13 22:39:12 +01:00
Zoltan Papp
9c6c944e46 Add socks5 proxy 2023-12-13 16:58:31 +01:00
Zoltan Papp
cec4232860 Netstack poc 2023-12-12 17:35:16 +01:00
Maycon Santos
2d1dfa3ae7 Fix jwks validation and flag/config overriding (#1380)
Ensure the jwks expiresInTime is not zero and add a log indicating the new expiration time

Replace the configuration property only when the flag is being used
2023-12-12 14:56:27 +01:00
Yury Gargay
5961c8330e Fix SaveOrAddUser and GetPeers methods in MockAccountManager (#1374) 2023-12-11 17:32:10 +01:00
Bethuel Mmbaga
d275d411aa Enable JWT group-based user authorization (#1368)
* Extend management API to support list of allowed JWT groups (#1366)

* Add JWTAllowGroups settings to account management

* Return an empty group list if jwt allow groups is not set

* Add JwtAllowGroups to account settings in handler test

* Add JWT group-based user authorization (#1373)

* Add JWTAllowGroups settings to account management

* Return an empty group list if jwt allow groups is not set

* Add JwtAllowGroups to account settings in handler test

* Implement user access validation authentication based on JWT groups

* Remove the slices package import due to compatibility issues with the gitHub workflow(s) Go version

* Refactor auth middleware and test for extracted claim handling

* Optimize JWT group check in auth middleware to cover nil and empty allowed groups
2023-12-11 18:59:15 +03:00
Yury Gargay
5ecafef5d2 Fix ListUsers method in MockAccountManager (#1367) 2023-12-11 15:00:02 +01:00
Yury Gargay
d073a250cc Specify ref for sync tag workflow (#1365) 2023-12-08 14:18:49 +01:00
Yury Gargay
a1c48468ab Add Dev Container Support section in contributing guideline (#1363) 2023-12-08 11:54:50 +01:00
Maycon Santos
dd1e730454 Update API descriptions and examples (#1364) 2023-12-08 11:39:33 +01:00
Yury Gargay
050f140245 Add sync-tag.yml GitHub workflow (#1362) 2023-12-08 10:55:31 +01:00
Zoltan Papp
006ba32086 Fix/acl for forward (#1305)
Fix ACL on routed traffic and code refactor
2023-12-08 10:48:21 +01:00
Yury Gargay
b03343bc4d Add sync-main.yml GitHub workflow (#1359) 2023-12-06 17:51:11 +01:00
pascal-fischer
36d62f1844 Merge pull request #1358 from netbirdio/fix/tests-after-peer-validation
Fix tests after peer validation
2023-12-06 15:39:37 +01:00
Pascal Fischer
08733ed8d5 update tests 2023-12-06 15:02:10 +01:00
77 changed files with 4161 additions and 3996 deletions

22
.github/workflows/sync-main.yml vendored Normal file
View File

@@ -0,0 +1,22 @@
name: sync main
on:
push:
branches:
- main
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
trigger_sync_main:
runs-on: ubuntu-latest
steps:
- name: Trigger main branch sync
uses: benc-uk/workflow-dispatch@v1
with:
workflow: sync-main.yml
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "sha": "${{ github.sha }}" }'

23
.github/workflows/sync-tag.yml vendored Normal file
View File

@@ -0,0 +1,23 @@
name: sync tag
on:
push:
tags:
- 'v*'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
trigger_sync_tag:
runs-on: ubuntu-latest
steps:
- name: Trigger release tag sync
uses: benc-uk/workflow-dispatch@v1
with:
workflow: sync-tag.yml
ref: main
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'

View File

@@ -19,6 +19,7 @@ If you haven't already, join our slack workspace [here](https://join.slack.com/t
- [Development setup](#development-setup)
- [Requirements](#requirements)
- [Local NetBird setup](#local-netbird-setup)
- [Dev Container Support](#dev-container-support)
- [Build and start](#build-and-start)
- [Test suite](#test-suite)
- [Checklist before submitting a PR](#checklist-before-submitting-a-pr)
@@ -135,6 +136,48 @@ checked out and set up:
go mod tidy
```
### Dev Container Support
If you prefer using a dev container for development, NetBird now includes support for dev containers.
Dev containers provide a consistent and isolated development environment, making it easier for contributors to get started quickly. Follow the steps below to set up NetBird in a dev container.
#### 1. Prerequisites:
* Install Docker on your machine: [Docker Installation Guide](https://docs.docker.com/get-docker/)
* Install Visual Studio Code: [VS Code Installation Guide](https://code.visualstudio.com/download)
* If you prefer JetBrains Goland please follow this [manual](https://www.jetbrains.com/help/go/connect-to-devcontainer.html)
#### 2. Clone the Repository:
Clone the repository following previous [Local NetBird setup](#local-netbird-setup).
#### 3. Open in project in IDE of your choice:
**VScode**:
Open the project folder in Visual Studio Code:
```bash
code .
```
When you open the project in VS Code, it will detect the presence of a dev container configuration.
Click on the green "Reopen in Container" button in the bottom-right corner of VS Code.
**Goland**:
Open GoLand and select `"File" > "Open"` to open the NetBird project folder.
GoLand will detect the dev container configuration and prompt you to open the project in the container. Accept the prompt.
#### 4. Wait for the Container to Build:
VsCode or GoLand will use the specified Docker image to build the dev container. This might take some time, depending on your internet connection.
#### 6. Development:
Once the container is built, you can start developing within the dev container. All the necessary dependencies and configurations are set up within the container.
### Build and start
#### Client

32
client/firewall/create.go Normal file
View File

@@ -0,0 +1,32 @@
//go:build !linux || android
package firewall
import (
"context"
"fmt"
"runtime"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
)
// NewFirewall creates a firewall manager instance
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}
// use userspace packet filtering firewall
fm, err := uspfilter.Create(iface)
if err != nil {
return nil, err
}
err = fm.AllowNetbird()
if err != nil {
log.Warnf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
}

View File

@@ -0,0 +1,107 @@
//go:build !android
package firewall
import (
"context"
"fmt"
"os"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
log "github.com/sirupsen/logrus"
nbiptables "github.com/netbirdio/netbird/client/firewall/iptables"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
)
const (
// UNKNOWN is the default value for the firewall type for unknown firewall type
UNKNOWN FWType = iota
// IPTABLES is the value for the iptables firewall type
IPTABLES
// NFTABLES is the value for the nftables firewall type
NFTABLES
)
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
var fm firewall.Manager
var errFw error
switch check() {
case IPTABLES:
log.Debug("creating an iptables firewall manager")
fm, errFw = nbiptables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create iptables manager: %s", errFw)
}
case NFTABLES:
log.Debug("creating an nftables firewall manager")
fm, errFw = nbnftables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create nftables manager: %s", errFw)
}
default:
errFw = fmt.Errorf("no firewall manager found")
log.Debug("no firewall manager found, try to use userspace packet filtering firewall")
}
if iface.IsUserspaceBind() {
var errUsp error
if errFw == nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else {
fm, errUsp = uspfilter.Create(iface)
}
if errUsp != nil {
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
return nil, errUsp
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
}
if errFw != nil {
return nil, errFw
}
return fm, nil
}
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() FWType {
nf := nftables.Conn{}
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
return NFTABLES
}
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return UNKNOWN
}
if isIptablesClientAvailable(ip) {
return IPTABLES
}
return UNKNOWN
}
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}

11
client/firewall/iface.go Normal file
View File

@@ -0,0 +1,11 @@
package firewall
import "github.com/netbirdio/netbird/iface"
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
SetFilter(iface.PacketFilter) error
}

View File

@@ -0,0 +1,473 @@
package iptables
import (
"fmt"
"net"
"strconv"
"github.com/coreos/go-iptables/iptables"
"github.com/google/uuid"
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
tableName = "filter"
// rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT"
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
postRoutingMark = "0x000007e4"
)
type aclManager struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
routeingFwChainName string
entries map[string][][]string
ipsetStore *ipsetStore
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) {
m := &aclManager{
iptablesClient: iptablesClient,
wgIface: wgIface,
routeingFwChainName: routeingFwChainName,
entries: make(map[string][][]string),
ipsetStore: newIpsetStore(),
}
err := ipset.Init()
if err != nil {
return nil, fmt.Errorf("failed to init ipset: %w", err)
}
m.seedInitialEntries()
err = m.cleanChains()
if err != nil {
return nil, err
}
err = m.createDefaultChains()
if err != nil {
return nil, err
}
return m, nil
}
func (m *aclManager) AddFiltering(
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
var dPortVal, sPortVal string
if dPort != nil && dPort.Values != nil {
// TODO: we support only one port per rule in current implementation of ACLs
dPortVal = strconv.Itoa(dPort.Values[0])
}
if sPort != nil && sPort.Values != nil {
sPortVal = strconv.Itoa(sPort.Values[0])
}
var chain string
if direction == firewall.RuleDirectionOUT {
chain = chainNameOutputRules
} else {
chain = chainNameInputRules
}
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
}
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
ipList.addIP(ip.String())
return []firewall.Rule{&Rule{
ruleID: uuid.New().String(),
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
specs: specs,
}}, nil
}
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
}
if err := ipset.Create(ipsetName); err != nil {
return nil, fmt.Errorf("failed to create ipset: %w", err)
}
if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
}
ipList := newIpList(ip.String())
m.ipsetStore.addIpList(ipsetName, ipList)
}
ok, err := m.iptablesClient.Exists("filter", chain, specs...)
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
if ok {
return nil, fmt.Errorf("rule already exists")
}
if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil {
return nil, err
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
}
if !shouldAddToPrerouting(protocol, dPort, direction) {
return []firewall.Rule{rule}, nil
}
rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip)
if err != nil {
return []firewall.Rule{rule}, err
}
return []firewall.Rule{rule, rulePrerouting}, nil
}
// DeleteRule from the firewall by rule definition
func (m *aclManager) DeleteRule(rule firewall.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
if r.chain == "PREROUTING" {
goto DELETERULE
}
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok {
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
return fmt.Errorf("failed to delete ip from ipset: %w", err)
}
delete(ipsetList.ips, r.ip)
}
// if after delete, set still contains other IPs,
// no need to delete firewall rule and we should exit here
if len(ipsetList.ips) != 0 {
return nil
}
// we delete last IP from the set, that means we need to delete
// set itself and associated firewall rule too
m.ipsetStore.deleteIpset(r.ipsetName)
if err := ipset.Destroy(r.ipsetName); err != nil {
log.Errorf("delete empty ipset: %v", err)
}
}
DELETERULE:
var table string
if r.chain == "PREROUTING" {
table = "mangle"
} else {
table = "filter"
}
err := m.iptablesClient.Delete(table, r.chain, r.specs...)
if err != nil {
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
}
return err
}
func (m *aclManager) Reset() error {
return m.cleanChains()
}
func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
var src []string
if ipsetName != "" {
src = []string{"-m", "set", "--set", ipsetName, "src"}
} else {
src = []string{"-s", ip.String()}
}
specs := []string{
"-d", m.wgIface.Address().IP.String(),
"-p", protocol,
"--dport", port,
"-j", "MARK", "--set-mark", postRoutingMark,
}
specs = append(src, specs...)
ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...)
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
if ok {
return nil, fmt.Errorf("rule already exists")
}
if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
return nil, err
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
chain: "PREROUTING",
}
return rule, nil
}
// todo write less destructive cleanup mechanism
func (m *aclManager) cleanChains() error {
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
rules := m.entries["OUTPUT"]
for _, rule := range rules {
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
if err != nil {
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
return err
}
}
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
for _, rule := range m.entries["INPUT"] {
err := m.iptablesClient.DeleteIfExists(tableName, "INPUT", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
for _, rule := range m.entries["FORWARD"] {
err := m.iptablesClient.DeleteIfExists(tableName, "FORWARD", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err)
return err
}
}
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
for _, rule := range m.entries["PREROUTING"] {
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
if err != nil {
log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
return err
}
}
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
if err := ipset.Destroy(ipsetName); err != nil {
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
}
m.ipsetStore.deleteIpset(ipsetName)
}
return nil
}
func (m *aclManager) createDefaultChains() error {
// chain netbird-acl-input-rules
if err := m.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil {
log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err)
return err
}
// chain netbird-acl-output-rules
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
return err
}
for chainName, rules := range m.entries {
for _, rule := range rules {
if chainName == "FORWARD" {
// position 2 because we add it after router's, jump rule
if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
} else {
if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
}
}
}
return nil
}
func (m *aclManager) seedInitialEntries() {
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("FORWARD",
[]string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
m.appendToEntries("FORWARD",
[]string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
m.appendToEntries("PREROUTING",
[]string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark})
}
func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec)
}
// filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
) (specs []string) {
matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0
if ip.String() == "0.0.0.0" {
matchByIP = false
}
switch direction {
case firewall.RuleDirectionIN:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
}
}
case firewall.RuleDirectionOUT:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
} else {
specs = append(specs, "-d", ip.String())
}
}
}
if protocol != "all" {
specs = append(specs, "-p", protocol)
}
if sPort != "" {
specs = append(specs, "--sport", sPort)
}
if dPort != "" {
specs = append(specs, "--dport", dPort)
}
return append(specs, "-j", actionToStr(action))
}
func actionToStr(action firewall.Action) string {
if action == firewall.ActionAccept {
return "ACCEPT"
}
return "DROP"
}
func transformIPsetName(ipsetName string, sPort, dPort string) string {
switch {
case ipsetName == "":
return ""
case sPort != "" && dPort != "":
return ipsetName + "-sport-dport"
case sPort != "":
return ipsetName + "-sport"
case dPort != "":
return ipsetName + "-dport"
default:
return ipsetName
}
}
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
if proto == "all" {
return false
}
if direction != firewall.RuleDirectionIN {
return false
}
if dPort == nil {
return false
}
return true
}

View File

@@ -1,43 +1,27 @@
package iptables
import (
"context"
"fmt"
"net"
"strconv"
"sync"
"github.com/coreos/go-iptables/iptables"
"github.com/google/uuid"
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus"
fw "github.com/netbirdio/netbird/client/firewall"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface"
)
const (
// ChainInputFilterName is the name of the chain that is used for filtering incoming packets
ChainInputFilterName = "NETBIRD-ACL-INPUT"
// ChainOutputFilterName is the name of the chain that is used for filtering outgoing packets
ChainOutputFilterName = "NETBIRD-ACL-OUTPUT"
)
// dropAllDefaultRule in the Netbird chain
var dropAllDefaultRule = []string{"-j", "DROP"}
// Manager of iptables firewall
type Manager struct {
mutex sync.Mutex
wgIface iFaceMapper
ipv4Client *iptables.IPTables
ipv6Client *iptables.IPTables
inputDefaultRuleSpecs []string
outputDefaultRuleSpecs []string
wgIface iFaceMapper
rulesets map[string]ruleset
aclMgr *aclManager
router *routerManager
}
// iFaceMapper defines subset methods of interface required for manager
@@ -47,47 +31,29 @@ type iFaceMapper interface {
IsUserspaceBind() bool
}
type ruleset struct {
rule *Rule
ips map[string]string
}
// Create iptables firewall manager
func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
m := &Manager{
wgIface: wgIface,
inputDefaultRuleSpecs: []string{
"-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()},
outputDefaultRuleSpecs: []string{
"-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()},
rulesets: make(map[string]ruleset),
}
err := ipset.Init()
if err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
// init clients for booth ipv4 and ipv6
m.ipv4Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
}
if ipv6Supported {
m.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Warnf("ip6tables is not installed in the system or not supported: %v. Access rules for this protocol won't be applied.", err)
}
m := &Manager{
wgIface: wgIface,
ipv4Client: iptablesClient,
}
if m.ipv4Client == nil && m.ipv6Client == nil {
return nil, fmt.Errorf("iptables is not installed in the system or not enough permissions to use it")
m.router, err = newRouterManager(context, iptablesClient)
if err != nil {
log.Debugf("failed to initialize route related chains: %s", err)
return nil, err
}
m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
if err != nil {
log.Debugf("failed to initialize ACL manager: %s", err)
return nil, err
}
if err := m.Reset(); err != nil {
return nil, fmt.Errorf("failed to reset firewall: %v", err)
}
return m, nil
}
@@ -96,159 +62,44 @@ func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
// Comment will be ignored because some system this feature is not supported
func (m *Manager) AddFiltering(
ip net.IP,
protocol fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
comment string,
) (fw.Rule, error) {
) ([]firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
client, err := m.client(ip)
if err != nil {
return nil, err
}
var dPortVal, sPortVal string
if dPort != nil && dPort.Values != nil {
// TODO: we support only one port per rule in current implementation of ACLs
dPortVal = strconv.Itoa(dPort.Values[0])
}
if sPort != nil && sPort.Values != nil {
sPortVal = strconv.Itoa(sPort.Values[0])
}
ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal)
ruleID := uuid.New().String()
if ipsetName != "" {
rs, rsExists := m.rulesets[ipsetName]
if !rsExists {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q before use it: %v", ipsetName, err)
}
if err := ipset.Create(ipsetName); err != nil {
return nil, fmt.Errorf("failed to create ipset: %w", err)
}
}
if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
}
if rsExists {
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
rs.ips[ip.String()] = ruleID
return &Rule{
ruleID: ruleID,
ipsetName: ipsetName,
ip: ip.String(),
dst: direction == fw.RuleDirectionOUT,
v6: ip.To4() == nil,
}, nil
}
// this is new ipset so we need to create firewall rule for it
}
specs := m.filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
if direction == fw.RuleDirectionOUT {
ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
if err != nil {
return nil, fmt.Errorf("check is output rule already exists: %w", err)
}
if ok {
return nil, fmt.Errorf("input rule already exists")
}
if err := client.Insert("filter", ChainOutputFilterName, 1, specs...); err != nil {
return nil, err
}
} else {
ok, err := client.Exists("filter", ChainInputFilterName, specs...)
if err != nil {
return nil, fmt.Errorf("check is input rule already exists: %w", err)
}
if ok {
return nil, fmt.Errorf("input rule already exists")
}
if err := client.Insert("filter", ChainInputFilterName, 1, specs...); err != nil {
return nil, err
}
}
rule := &Rule{
ruleID: ruleID,
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
dst: direction == fw.RuleDirectionOUT,
v6: ip.To4() == nil,
}
if ipsetName != "" {
// ipset name is defined and it means that this rule was created
// for it, need to associate it with ruleset
m.rulesets[ipsetName] = ruleset{
rule: rule,
ips: map[string]string{rule.ip: ruleID},
}
}
return rule, nil
return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
}
// DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule fw.Rule) error {
func (m *Manager) DeleteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
return m.aclMgr.DeleteRule(rule)
}
client := m.ipv4Client
if r.v6 {
if m.ipv6Client == nil {
return fmt.Errorf("ipv6 is not supported")
}
client = m.ipv6Client
}
func (m *Manager) IsServerRouteSupported() bool {
return true
}
if rs, ok := m.rulesets[r.ipsetName]; ok {
// delete IP from ruleset IPs list and ipset
if _, ok := rs.ips[r.ip]; ok {
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
return fmt.Errorf("failed to delete ip from ipset: %w", err)
}
delete(rs.ips, r.ip)
}
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
// if after delete, set still contains other IPs,
// no need to delete firewall rule and we should exit here
if len(rs.ips) != 0 {
return nil
}
return m.router.InsertRoutingRules(pair)
}
// we delete last IP from the set, that means we need to delete
// set itself and associated firewall rule too
delete(m.rulesets, r.ipsetName)
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := ipset.Destroy(r.ipsetName); err != nil {
log.Errorf("delete empty ipset: %v", err)
}
r = rs.rule
}
if r.dst {
return client.Delete("filter", ChainOutputFilterName, r.specs...)
}
return client.Delete("filter", ChainInputFilterName, r.specs...)
return m.router.RemoveRoutingRules(pair)
}
// Reset firewall to the default state
@@ -256,223 +107,49 @@ func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.reset(m.ipv4Client, "filter"); err != nil {
return fmt.Errorf("clean ipv4 firewall ACL input chain: %w", err)
errAcl := m.aclMgr.Reset()
if errAcl != nil {
log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl)
}
if m.ipv6Client != nil {
if err := m.reset(m.ipv6Client, "filter"); err != nil {
return fmt.Errorf("clean ipv6 firewall ACL input chain: %w", err)
}
errMgr := m.router.Reset()
if errMgr != nil {
log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
return errMgr
}
return nil
return errAcl
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if m.wgIface.IsUserspaceBind() {
_, err := m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"",
)
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
}
_, err = m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
fw.RuleDirectionOUT,
fw.ActionAccept,
"",
"",
)
return err
if !m.wgIface.IsUserspaceBind() {
return nil
}
return nil
_, err := m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
firewall.RuleDirectionIN,
firewall.ActionAccept,
"",
"",
)
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
}
_, err = m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
firewall.RuleDirectionOUT,
firewall.ActionAccept,
"",
"",
)
return err
}
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// reset firewall chain, clear it and drop it
func (m *Manager) reset(client *iptables.IPTables, table string) error {
ok, err := client.ChainExists(table, ChainInputFilterName)
if err != nil {
return fmt.Errorf("failed to check if input chain exists: %w", err)
}
if ok {
if ok, err := client.Exists("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
return err
} else if ok {
if err := client.Delete("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
log.WithError(err).Errorf("failed to delete default input rule: %v", err)
}
}
}
ok, err = client.ChainExists(table, ChainOutputFilterName)
if err != nil {
return fmt.Errorf("failed to check if output chain exists: %w", err)
}
if ok {
if ok, err := client.Exists("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
return err
} else if ok {
if err := client.Delete("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
log.WithError(err).Errorf("failed to delete default output rule: %v", err)
}
}
}
if err := client.ClearAndDeleteChain(table, ChainInputFilterName); err != nil {
log.Errorf("failed to clear and delete input chain: %v", err)
return nil
}
if err := client.ClearAndDeleteChain(table, ChainOutputFilterName); err != nil {
log.Errorf("failed to clear and delete input chain: %v", err)
return nil
}
for ipsetName := range m.rulesets {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
if err := ipset.Destroy(ipsetName); err != nil {
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
}
delete(m.rulesets, ipsetName)
}
return nil
}
// filterRuleSpecs returns the specs of a filtering rule
func (m *Manager) filterRuleSpecs(
ip net.IP, protocol string, sPort, dPort string, direction fw.RuleDirection, action fw.Action, ipsetName string,
) (specs []string) {
matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0
if s := ip.String(); s == "0.0.0.0" || s == "::" {
matchByIP = false
}
switch direction {
case fw.RuleDirectionIN:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
}
}
case fw.RuleDirectionOUT:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
} else {
specs = append(specs, "-d", ip.String())
}
}
}
if protocol != "all" {
specs = append(specs, "-p", protocol)
}
if sPort != "" {
specs = append(specs, "--sport", sPort)
}
if dPort != "" {
specs = append(specs, "--dport", dPort)
}
return append(specs, "-j", m.actionToStr(action))
}
// rawClient returns corresponding iptables client for the given ip
func (m *Manager) rawClient(ip net.IP) (*iptables.IPTables, error) {
if ip.To4() != nil {
return m.ipv4Client, nil
}
if m.ipv6Client == nil {
return nil, fmt.Errorf("ipv6 is not supported")
}
return m.ipv6Client, nil
}
// client returns client with initialized chain and default rules
func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
client, err := m.rawClient(ip)
if err != nil {
return nil, err
}
ok, err := client.ChainExists("filter", ChainInputFilterName)
if err != nil {
return nil, fmt.Errorf("failed to check if chain exists: %w", err)
}
if !ok {
if err := client.NewChain("filter", ChainInputFilterName); err != nil {
return nil, fmt.Errorf("failed to create input chain: %w", err)
}
if err := client.AppendUnique("filter", ChainInputFilterName, dropAllDefaultRule...); err != nil {
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
}
if err := client.Insert("filter", "INPUT", 1, m.inputDefaultRuleSpecs...); err != nil {
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
}
}
ok, err = client.ChainExists("filter", ChainOutputFilterName)
if err != nil {
return nil, fmt.Errorf("failed to check if chain exists: %w", err)
}
if !ok {
if err := client.NewChain("filter", ChainOutputFilterName); err != nil {
return nil, fmt.Errorf("failed to create output chain: %w", err)
}
if err := client.AppendUnique("filter", ChainOutputFilterName, dropAllDefaultRule...); err != nil {
return nil, fmt.Errorf("failed to create default drop all in netbird output chain: %w", err)
}
if err := client.AppendUnique("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
return nil, fmt.Errorf("failed to create output chain jump rule: %w", err)
}
}
return client, nil
}
func (m *Manager) actionToStr(action fw.Action) string {
if action == fw.ActionAccept {
return "ACCEPT"
}
return "DROP"
}
func (m *Manager) transformIPsetName(ipsetName string, sPort, dPort string) string {
switch {
case ipsetName == "":
return ""
case sPort != "" && dPort != "":
return ipsetName + "-sport-dport"
case sPort != "":
return ipsetName + "-sport"
case dPort != "":
return ipsetName + "-dport"
default:
return ipsetName
}
}

View File

@@ -1,6 +1,7 @@
package iptables
import (
"context"
"fmt"
"net"
"testing"
@@ -9,7 +10,7 @@ import (
"github.com/coreos/go-iptables/iptables"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface"
)
@@ -55,7 +56,7 @@ func TestIptablesManager(t *testing.T) {
}
// just check on the local interface
manager, err := Create(mock, true)
manager, err := Create(context.Background(), mock)
require.NoError(t, err)
time.Sleep(time.Second)
@@ -67,17 +68,20 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second)
}()
var rule1 fw.Rule
var rule1 []fw.Rule
t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
}
})
var rule2 fw.Rule
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
@@ -87,21 +91,28 @@ func TestIptablesManager(t *testing.T) {
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule")
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
for _, r := range rule2 {
rr := r.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
}
})
t.Run("delete first rule", func(t *testing.T) {
err := manager.DeleteRule(rule1)
require.NoError(t, err, "failed to delete rule")
for _, r := range rule1 {
err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...)
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
}
})
t.Run("delete second rule", func(t *testing.T) {
err := manager.DeleteRule(rule2)
require.NoError(t, err, "failed to delete rule")
for _, r := range rule2 {
err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
}
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
})
t.Run("reset check", func(t *testing.T) {
@@ -114,11 +125,11 @@ func TestIptablesManager(t *testing.T) {
err = manager.Reset()
require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", ChainInputFilterName)
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
require.NoError(t, err, "failed check chain exists")
if ok {
require.NoErrorf(t, err, "chain '%v' still exists after Reset", ChainInputFilterName)
require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules)
}
})
}
@@ -143,7 +154,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
}
// just check on the local interface
manager, err := Create(mock, true)
manager, err := Create(context.Background(), mock)
require.NoError(t, err)
time.Sleep(time.Second)
@@ -155,7 +166,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second)
}()
var rule1 fw.Rule
var rule1 []fw.Rule
t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
@@ -165,12 +176,14 @@ func TestIptablesManagerIPSet(t *testing.T) {
)
require.NoError(t, err, "failed to add rule")
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
require.Equal(t, rule1.(*Rule).ipsetName, "default-dport", "ipset name must be set")
require.Equal(t, rule1.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
}
})
var rule2 fw.Rule
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
@@ -180,23 +193,29 @@ func TestIptablesManagerIPSet(t *testing.T) {
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range",
)
require.NoError(t, err, "failed to add rule")
require.Equal(t, rule2.(*Rule).ipsetName, "default-sport", "ipset name must be set")
require.Equal(t, rule2.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
for _, r := range rule2 {
require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
require.Equal(t, r.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
}
})
t.Run("delete first rule", func(t *testing.T) {
err := manager.DeleteRule(rule1)
require.NoError(t, err, "failed to delete rule")
for _, r := range rule1 {
err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.rulesets, rule1.(*Rule).ruleID, "rule must be removed form the ruleset index")
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
}
})
t.Run("delete second rule", func(t *testing.T) {
err := manager.DeleteRule(rule2)
require.NoError(t, err, "failed to delete rule")
for _, r := range rule2 {
err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty")
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
}
})
t.Run("reset check", func(t *testing.T) {
@@ -206,7 +225,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
}
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
t.Helper()
t.Helper()
exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
require.NoError(t, err, "failed to check rule")
require.Falsef(t, !exists && mustExists, "rule '%v' does not exist", rulespec)
@@ -232,7 +251,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(mock, true)
manager, err := Create(context.Background(), mock)
require.NoError(t, err)
time.Sleep(time.Second)
@@ -243,7 +262,6 @@ func TestIptablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second)
}()
_, err = manager.client(net.ParseIP("10.20.0.100"))
require.NoError(t, err)
ip := net.ParseIP("10.20.0.100")

View File

@@ -0,0 +1,340 @@
//go:build !android
package iptables
import (
"context"
"fmt"
"strings"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
Ipv4Forwarding = "netbird-rt-forwarding"
ipv4Nat = "netbird-rt-nat"
)
// constants needed to manage and create iptable rules
const (
tableFilter = "filter"
tableNat = "nat"
chainFORWARD = "FORWARD"
chainPOSTROUTING = "POSTROUTING"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
)
type routerManager struct {
ctx context.Context
stop context.CancelFunc
iptablesClient *iptables.IPTables
rules map[string][]string
}
func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) {
ctx, cancel := context.WithCancel(parentCtx)
m := &routerManager{
ctx: ctx,
stop: cancel,
iptablesClient: iptablesClient,
rules: make(map[string][]string),
}
err := m.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to cleanup routing rules: %s", err)
return nil, err
}
err = m.createContainers()
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
}
return m, err
}
// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair))
if err != nil {
return err
}
if !pair.Masquerade {
return nil
}
err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
if err != nil {
return err
}
return nil
}
// insertRoutingRule inserts an iptable rule
func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
var err error
ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, ruleKey, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
}
delete(i.rules, ruleKey)
}
err = i.iptablesClient.Insert(table, chain, 1, rule...)
if err != nil {
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
}
i.rules[ruleKey] = rule
return nil
}
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error {
err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair))
if err != nil {
return err
}
if !pair.Masquerade {
return nil
}
err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair))
if err != nil {
return err
}
return nil
}
func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error {
var err error
ruleKey := firewall.GenKey(keyFormat, pair.ID)
existingRule, found := i.rules[ruleKey]
if found {
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
}
}
delete(i.rules, ruleKey)
return nil
}
func (i *routerManager) RouteingFwChainName() string {
return chainRTFWD
}
func (i *routerManager) Reset() error {
err := i.cleanUpDefaultForwardRules()
if err != nil {
return err
}
i.rules = make(map[string][]string)
return nil
}
func (i *routerManager) cleanUpDefaultForwardRules() error {
err := i.cleanJumpRules()
if err != nil {
return err
}
log.Debug("flushing routing related tables")
ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD)
if err != nil {
log.Errorf("failed check chain %s,error: %v", chainRTFWD, err)
return err
} else if ok {
err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD)
if err != nil {
log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err)
return err
}
}
ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT)
if err != nil {
log.Errorf("failed check chain %s,error: %v", chainRTNAT, err)
return err
} else if ok {
err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT)
if err != nil {
log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err)
return err
}
}
return nil
}
func (i *routerManager) createContainers() error {
if i.rules[Ipv4Forwarding] != nil {
return nil
}
errMSGFormat := "failed creating chain %s,error: %v"
err := i.createChain(tableFilter, chainRTFWD)
if err != nil {
return fmt.Errorf(errMSGFormat, chainRTFWD, err)
}
err = i.createChain(tableNat, chainRTNAT)
if err != nil {
return fmt.Errorf(errMSGFormat, chainRTNAT, err)
}
err = i.addJumpRules()
if err != nil {
return fmt.Errorf("error while creating jump rules: %v", err)
}
return nil
}
// addJumpRules create jump rules to send packets to NetBird chains
func (i *routerManager) addJumpRules() error {
rule := []string{"-j", chainRTFWD}
err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...)
if err != nil {
return err
}
i.rules[Ipv4Forwarding] = rule
rule = []string{"-j", chainRTNAT}
err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
if err != nil {
return err
}
i.rules[ipv4Nat] = rule
return nil
}
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
func (i *routerManager) cleanJumpRules() error {
var err error
errMSGFormat := "failed cleaning rule from chain %s,err: %v"
rule, found := i.rules[Ipv4Forwarding]
if found {
err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, chainFORWARD, err)
}
}
rule, found = i.rules[ipv4Nat]
if found {
err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err)
}
}
rules, err := i.iptablesClient.List("nat", "POSTROUTING")
if err != nil {
return fmt.Errorf("failed to list rules: %s", err)
}
for _, ruleString := range rules {
if !strings.Contains(ruleString, "NETBIRD") {
continue
}
rule := strings.Fields(ruleString)
err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...)
if err != nil {
return fmt.Errorf("failed to delete postrouting jump rule: %s", err)
}
}
rules, err = i.iptablesClient.List(tableFilter, "FORWARD")
if err != nil {
return fmt.Errorf("failed to list rules in FORWARD chain: %s", err)
}
for _, ruleString := range rules {
if !strings.Contains(ruleString, "NETBIRD") {
continue
}
rule := strings.Fields(ruleString)
err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...)
if err != nil {
return fmt.Errorf("failed to delete FORWARD jump rule: %s", err)
}
}
return nil
}
func (i *routerManager) createChain(table, newChain string) error {
chains, err := i.iptablesClient.ListChains(table)
if err != nil {
return fmt.Errorf("couldn't get %s table chains, error: %v", table, err)
}
shouldCreateChain := true
for _, chain := range chains {
if chain == newChain {
shouldCreateChain = false
}
}
if shouldCreateChain {
err = i.iptablesClient.NewChain(table, newChain)
if err != nil {
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
}
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
if err != nil {
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
}
}
return nil
}
// genRuleSpec generates rule specification with comment identifier
func genRuleSpec(jump, id, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id}
}
func getIptablesRuleType(table string) string {
ruleType := "forwarding"
if table == tableNat {
ruleType = "nat"
}
return ruleType
}

View File

@@ -0,0 +1,229 @@
//go:build !android
package iptables
import (
"context"
"os/exec"
"testing"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
)
func isIptablesSupported() bool {
_, err4 := exec.LookPath("iptables")
return err4 == nil
}
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "should return a valid iptables manager")
defer func() {
_ = manager.Reset()
}()
require.Len(t, manager.rules, 2, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
require.True(t, exists, "forwarding rule should exist")
exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting rule should exist")
pair := firewall.RouterPair{
ID: "abc",
Source: "100.100.100.1/32",
Destination: "100.100.100.0/24",
Masquerade: true,
}
forward4RuleKey := firewall.GenKey(firewall.ForwardingFormat, pair.ID)
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.Source, pair.Destination)
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error")
nat4RuleKey := firewall.GenKey(firewall.NatFormat, pair.ID)
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.Source, pair.Destination)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
}
func TestIptablesManager_InsertRoutingRules(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "shouldn't return error")
defer func() {
err := manager.Reset()
if err != nil {
log.Errorf("failed to reset iptables manager: %s", err)
}
}()
err = manager.InsertRoutingRules(testCase.InputPair)
require.NoError(t, err, "forwarding pair should be inserted")
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.True(t, exists, "forwarding rule should exist")
foundRule, found := manager.rules[forwardRuleKey]
require.True(t, found, "forwarding rule should exist in the manager map")
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.True(t, exists, "income forwarding rule should exist")
foundRule, found = manager.rules[inForwardRuleKey]
require.True(t, found, "income forwarding rule should exist in the manager map")
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade {
require.True(t, exists, "nat rule should be created")
foundNatRule, foundNat := manager.rules[natRuleKey]
require.True(t, foundNat, "nat rule should exist in the map")
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[natRuleKey]
require.False(t, foundNat, "nat rule should not exist in the map")
}
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade {
require.True(t, exists, "income nat rule should be created")
foundNatRule, foundNat := manager.rules[inNatRuleKey]
require.True(t, foundNat, "income nat rule should exist in the map")
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[inNatRuleKey]
require.False(t, foundNat, "income nat rule should not exist in the map")
}
})
}
}
func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "shouldn't return error")
defer func() {
_ = manager.Reset()
}()
require.NoError(t, err, "shouldn't return error")
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
require.NoError(t, err, "inserting rule should not return error")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
require.NoError(t, err, "inserting rule should not return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveRoutingRules(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.False(t, exists, "forwarding rule should not exist")
_, found := manager.rules[forwardRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.False(t, exists, "income forwarding rule should not exist")
_, found = manager.rules[inForwardRuleKey]
require.False(t, found, "income forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "nat rule should not exist")
_, found = manager.rules[natRuleKey]
require.False(t, found, "nat rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "income nat rule should not exist")
_, found = manager.rules[inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map")
})
}
}

View File

@@ -7,8 +7,7 @@ type Rule struct {
specs []string
ip string
dst bool
v6 bool
chain string
}
// GetRuleID returns the rule id

View File

@@ -0,0 +1,50 @@
package iptables
type ipList struct {
ips map[string]struct{}
}
func newIpList(ip string) ipList {
ips := make(map[string]struct{})
ips[ip] = struct{}{}
return ipList{
ips: ips,
}
}
func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{}
}
type ipsetStore struct {
ipsets map[string]ipList // ipsetName -> ruleset
}
func newIpsetStore() *ipsetStore {
return &ipsetStore{
ipsets: make(map[string]ipList),
}
}
func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok
}
func (s *ipsetStore) addIpList(ipsetName string, list ipList) {
s.ipsets[ipsetName] = list
}
func (s *ipsetStore) deleteIpset(ipsetName string) {
s.ipsets[ipsetName] = ipList{}
delete(s.ipsets, ipsetName)
}
func (s *ipsetStore) ipsetNames() []string {
names := make([]string, 0, len(s.ipsets))
for name := range s.ipsets {
names = append(names, name)
}
return names
}

View File

@@ -1,9 +1,17 @@
package firewall
package manager
import (
"fmt"
"net"
)
const (
NatFormat = "netbird-nat-%s"
ForwardingFormat = "netbird-fwd-%s"
InNatFormat = "netbird-nat-in-%s"
InForwardingFormat = "netbird-fwd-in-%s"
)
// Rule abstraction should be implemented by each firewall manager
//
// Each firewall type for different OS can use different type
@@ -27,10 +35,8 @@ const (
type Action int
const (
// ActionUnknown is a unknown action
ActionUnknown Action = iota
// ActionAccept is the action to accept a packet
ActionAccept
ActionAccept Action = iota
// ActionDrop is the action to drop a packet
ActionDrop
)
@@ -56,16 +62,27 @@ type Manager interface {
action Action,
ipsetName string,
comment string,
) (Rule, error)
) ([]Rule, error)
// DeleteRule from the firewall by rule definition
DeleteRule(rule Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
// InsertRoutingRules inserts a routing firewall rule
InsertRoutingRules(pair RouterPair) error
// RemoveRoutingRules removes a routing firewall rule
RemoveRoutingRules(pair RouterPair) error
// Reset firewall to the default state
Reset() error
// Flush the changes to firewall controller
Flush() error
// TODO: migrate routemanager firewal actions to this interface
}
func GenKey(format string, input string) string {
return fmt.Sprintf(format, input)
}

View File

@@ -1,4 +1,4 @@
package firewall
package manager
import (
"strconv"

View File

@@ -0,0 +1,18 @@
package manager
type RouterPair struct {
ID string
Source string
Destination string
Masquerade bool
}
func GetInPair(pair RouterPair) RouterPair {
return RouterPair{
ID: pair.ID,
// invert Source/Destination
Source: pair.Destination,
Destination: pair.Source,
Masquerade: pair.Masquerade,
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,85 @@
package nftables
import (
"net"
)
type ipsetStore struct {
ipsetReference map[string]int
ipsets map[string]map[string]struct{} // ipsetName -> list of ips
}
func newIpsetStore() *ipsetStore {
return &ipsetStore{
ipsetReference: make(map[string]int),
ipsets: make(map[string]map[string]struct{}),
}
}
func (s *ipsetStore) ips(ipsetName string) (map[string]struct{}, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok
}
func (s *ipsetStore) newIpset(ipsetName string) map[string]struct{} {
s.ipsetReference[ipsetName] = 0
ipList := make(map[string]struct{})
s.ipsets[ipsetName] = ipList
return ipList
}
func (s *ipsetStore) deleteIpset(ipsetName string) {
delete(s.ipsetReference, ipsetName)
delete(s.ipsets, ipsetName)
}
func (s *ipsetStore) DeleteIpFromSet(ipsetName string, ip net.IP) {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return
}
delete(ipList, ip.String())
}
func (s *ipsetStore) AddIpToSet(ipsetName string, ip net.IP) {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return
}
ipList[ip.String()] = struct{}{}
}
func (s *ipsetStore) IsIpInSet(ipsetName string, ip net.IP) bool {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return false
}
_, ok = ipList[ip.String()]
return ok
}
func (s *ipsetStore) AddReferenceToIpset(ipsetName string) {
s.ipsetReference[ipsetName]++
}
func (s *ipsetStore) DeleteReferenceFromIpSet(ipsetName string) {
r, ok := s.ipsetReference[ipsetName]
if !ok {
return
}
if r == 0 {
return
}
s.ipsetReference[ipsetName]--
}
func (s *ipsetStore) HasReferenceToSet(ipsetName string) bool {
if _, ok := s.ipsetReference[ipsetName]; !ok {
return false
}
if s.ipsetReference[ipsetName] == 0 {
return false
}
return true
}

View File

@@ -2,90 +2,52 @@ package nftables
import (
"bytes"
"encoding/binary"
"context"
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"time"
"github.com/google/nftables"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/iface"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
// FilterTableName is the name of the table that is used for filtering by the Netbird client
FilterTableName = "netbird-acl"
// FilterInputChainName is the name of the chain that is used for filtering incoming packets
FilterInputChainName = "netbird-acl-input-filter"
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
FilterOutputChainName = "netbird-acl-output-filter"
AllowNetbirdInputRuleID = "allow Netbird incoming traffic"
// tableName is the name of the table that is used for filtering by the Netbird client
tableName = "netbird"
)
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
// Manager of iptables firewall
type Manager struct {
mutex sync.Mutex
rConn *nftables.Conn
sConn *nftables.Conn
tableIPv4 *nftables.Table
tableIPv6 *nftables.Table
filterInputChainIPv4 *nftables.Chain
filterOutputChainIPv4 *nftables.Chain
filterInputChainIPv6 *nftables.Chain
filterOutputChainIPv6 *nftables.Chain
rulesetManager *rulesetManager
setRemovedIPs map[string]struct{}
setRemoved map[string]*nftables.Set
mutex sync.Mutex
rConn *nftables.Conn
wgIface iFaceMapper
}
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
router *router
aclManager *AclManager
}
// Create nftables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) {
// sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for booth type of operations
// overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting())
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
m := &Manager{
rConn: &nftables.Conn{},
wgIface: wgIface,
}
workTable, err := m.createWorkTable()
if err != nil {
return nil, err
}
m := &Manager{
rConn: &nftables.Conn{},
sConn: sConn,
rulesetManager: newRuleManager(),
setRemovedIPs: map[string]struct{}{},
setRemoved: map[string]*nftables.Set{},
wgIface: wgIface,
m.router, err = newRouter(context, workTable)
if err != nil {
return nil, err
}
if err := m.Reset(); err != nil {
m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
if err != nil {
return nil, err
}
@@ -98,649 +60,58 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
// rule ID as comment for the rule
func (m *Manager) AddFiltering(
ip net.IP,
proto fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
comment string,
) (fw.Rule, error) {
) ([]firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
var (
err error
ipset *nftables.Set
table *nftables.Table
chain *nftables.Chain
)
if direction == fw.RuleDirectionOUT {
table, chain, err = m.chain(
ip,
FilterOutputChainName,
nftables.ChainHookOutput,
nftables.ChainPriorityFilter,
nftables.ChainTypeFilter)
} else {
table, chain, err = m.chain(
ip,
FilterInputChainName,
nftables.ChainHookInput,
nftables.ChainPriorityFilter,
nftables.ChainTypeFilter)
}
if err != nil {
return nil, err
}
rawIP := ip.To4()
if rawIP == nil {
rawIP = ip.To16()
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
}
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
if ipsetName != "" {
// if we already have set with given name, just add ip to the set
// and return rule with new ID in other case let's create rule
// with fresh created set and set element
var isSetNew bool
ipset, err = m.rConn.GetSetByName(table, ipsetName)
if err != nil {
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err)
}
isSetNew = true
}
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
return nil, fmt.Errorf("add set element for the first time: %v", err)
}
if err := m.sConn.Flush(); err != nil {
return nil, fmt.Errorf("flush add elements: %v", err)
}
if !isSetNew {
// if we already have nftables rules with set for given direction
// just add new rule to the ruleset and return new fw.Rule object
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
return m.rulesetManager.addRule(ruleset, rawIP)
}
// if ipset exists but it is not linked to rule for given direction
// create new rule for direction and bind ipset to it later
}
}
ifaceKey := expr.MetaKeyIIFNAME
if direction == fw.RuleDirectionOUT {
ifaceKey = expr.MetaKeyOIFNAME
}
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}
if proto != "all" {
expressions = append(expressions, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(9),
Len: uint32(1),
})
var protoData []byte
switch proto {
case fw.ProtocolTCP:
protoData = []byte{unix.IPPROTO_TCP}
case fw.ProtocolUDP:
protoData = []byte{unix.IPPROTO_UDP}
case fw.ProtocolICMP:
protoData = []byte{unix.IPPROTO_ICMP}
default:
return nil, fmt.Errorf("unsupported protocol: %s", proto)
}
expressions = append(expressions, &expr.Cmp{
Register: 1,
Op: expr.CmpOpEq,
Data: protoData,
})
}
// check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value
// in that case not add IP match expression into the rule definition
if !bytes.HasPrefix(anyIP, rawIP) {
// source address position
addrLen := uint32(len(rawIP))
addrOffset := uint32(12)
if addrLen == 16 {
addrOffset = 8
}
// change to destination address position if need
if direction == fw.RuleDirectionOUT {
addrOffset += addrLen
}
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: addrOffset,
Len: addrLen,
},
)
// add individual IP for match if no ipset defined
if ipset == nil {
expressions = append(expressions,
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rawIP,
},
)
} else {
expressions = append(expressions,
&expr.Lookup{
SourceRegister: 1,
SetName: ipsetName,
SetID: ipset.ID,
},
)
}
}
if sPort != nil && len(sPort.Values) != 0 {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 0,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*sPort),
},
)
}
if dPort != nil && len(dPort.Values) != 0 {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*dPort),
},
)
}
if action == fw.ActionAccept {
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
} else {
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
}
userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
rule := m.rConn.InsertRule(&nftables.Rule{
Table: table,
Chain: chain,
Position: 0,
Exprs: expressions,
UserData: userData,
})
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush insert rule: %v", err)
}
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
return m.rulesetManager.addRule(ruleset, rawIP)
}
// getRulesetID returns ruleset ID based on given parameters
func (m *Manager) getRulesetID(
ip net.IP,
proto fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
) string {
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
if sPort != nil {
rulesetID += sPort.String()
}
rulesetID += ":"
if dPort != nil {
rulesetID += dPort.String()
}
rulesetID += ":"
rulesetID += strconv.Itoa(int(action))
if ipsetName == "" {
return "ip:" + ip.String() + rulesetID
}
return "set:" + ipsetName + rulesetID
}
// createSet in given table by name
func (m *Manager) createSet(
table *nftables.Table,
rawIP []byte,
name string,
) (*nftables.Set, error) {
keyType := nftables.TypeIPAddr
if len(rawIP) == 16 {
keyType = nftables.TypeIP6Addr
}
// else we create new ipset and continue creating rule
ipset := &nftables.Set{
Name: name,
Table: table,
Dynamic: true,
KeyType: keyType,
}
if err := m.rConn.AddSet(ipset, nil); err != nil {
return nil, fmt.Errorf("create set: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush created set: %v", err)
}
return ipset, nil
}
// chain returns the chain for the given IP address with specific settings
func (m *Manager) chain(
ip net.IP,
name string,
hook nftables.ChainHook,
priority nftables.ChainPriority,
cType nftables.ChainType,
) (*nftables.Table, *nftables.Chain, error) {
var err error
getChain := func(c *nftables.Chain, tf nftables.TableFamily) (*nftables.Chain, error) {
if c != nil {
return c, nil
}
return m.createChainIfNotExists(tf, FilterTableName, name, hook, priority, cType)
}
if ip.To4() != nil {
if name == FilterInputChainName {
m.filterInputChainIPv4, err = getChain(m.filterInputChainIPv4, nftables.TableFamilyIPv4)
return m.tableIPv4, m.filterInputChainIPv4, err
}
m.filterOutputChainIPv4, err = getChain(m.filterOutputChainIPv4, nftables.TableFamilyIPv4)
return m.tableIPv4, m.filterOutputChainIPv4, err
}
if name == FilterInputChainName {
m.filterInputChainIPv6, err = getChain(m.filterInputChainIPv6, nftables.TableFamilyIPv6)
return m.tableIPv4, m.filterInputChainIPv6, err
}
m.filterOutputChainIPv6, err = getChain(m.filterOutputChainIPv6, nftables.TableFamilyIPv6)
return m.tableIPv4, m.filterOutputChainIPv6, err
}
// table returns the table for the given family of the IP address
func (m *Manager) table(
family nftables.TableFamily, tableName string,
) (*nftables.Table, error) {
// we cache access to Netbird ACL table only
if tableName != FilterTableName {
return m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
}
if family == nftables.TableFamilyIPv4 {
if m.tableIPv4 != nil {
return m.tableIPv4, nil
}
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
if err != nil {
return nil, err
}
m.tableIPv4 = table
return m.tableIPv4, nil
}
if m.tableIPv6 != nil {
return m.tableIPv6, nil
}
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6, tableName)
if err != nil {
return nil, err
}
m.tableIPv6 = table
return m.tableIPv6, nil
}
func (m *Manager) createTableIfNotExists(
family nftables.TableFamily, tableName string,
) (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(family)
if err != nil {
return nil, fmt.Errorf("list of tables: %w", err)
}
for _, t := range tables {
if t.Name == tableName {
return t, nil
}
}
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
if err := m.rConn.Flush(); err != nil {
return nil, err
}
return table, nil
}
func (m *Manager) createChainIfNotExists(
family nftables.TableFamily,
tableName string,
name string,
hooknum nftables.ChainHook,
priority nftables.ChainPriority,
chainType nftables.ChainType,
) (*nftables.Chain, error) {
table, err := m.table(family, tableName)
if err != nil {
return nil, err
}
chains, err := m.rConn.ListChainsOfTableFamily(family)
if err != nil {
return nil, fmt.Errorf("list of chains: %w", err)
}
for _, c := range chains {
if c.Name == name && c.Table.Name == table.Name {
return c, nil
}
}
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: name,
Table: table,
Hooknum: hooknum,
Priority: priority,
Type: chainType,
Policy: &polAccept,
}
chain = m.rConn.AddChain(chain)
ifaceKey := expr.MetaKeyIIFNAME
shiftDSTAddr := 0
if name == FilterOutputChainName {
ifaceKey = expr.MetaKeyOIFNAME
shiftDSTAddr = 1
}
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}
mask, _ := netip.AddrFromSlice(m.wgIface.Address().Network.Mask)
if m.wgIface.Address().IP.To4() == nil {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To16())
expressions = append(expressions,
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(8 + (16 * shiftDSTAddr)),
Len: 16,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 16,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: mask.Unmap().AsSlice(),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{Kind: expr.VerdictAccept},
)
} else {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
expressions = append(expressions,
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(12 + (4 * shiftDSTAddr)),
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{Kind: expr.VerdictAccept},
)
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: expressions,
})
expressions = []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: expressions,
})
if err := m.rConn.Flush(); err != nil {
return nil, err
}
return chain, nil
return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
}
// DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule fw.Rule) error {
func (m *Manager) DeleteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
nativeRule, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
if nativeRule.nftRule == nil {
return nil
}
if nativeRule.nftSet != nil {
// call twice of delete set element raises error
// so we need to check if element is already removed
key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip)
if _, ok := m.setRemovedIPs[key]; !ok {
err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}})
if err != nil {
log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err)
}
if err := m.sConn.Flush(); err != nil {
return err
}
m.setRemovedIPs[key] = struct{}{}
}
}
if m.rulesetManager.deleteRule(nativeRule) {
// deleteRule indicates that we still have IP in the ruleset
// it means we should not remove the nftables rule but need to update set
// so we prepare IP to be removed from set on the next flush call
return nil
}
// ruleset doesn't contain IP anymore (or contains only one), remove nft rule
if err := m.rConn.DelRule(nativeRule.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return err
}
nativeRule.nftRule = nil
if nativeRule.nftSet != nil {
if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok {
m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet
}
nativeRule.nftSet = nil
}
return nil
return m.aclManager.DeleteRule(rule)
}
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
chains, err := m.rConn.ListChains()
if err != nil {
return fmt.Errorf("list of chains: %w", err)
}
for _, c := range chains {
// delete Netbird allow input traffic rule if it exists
if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(AllowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
}
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
m.rConn.DelChain(c)
}
}
tables, err := m.rConn.ListTables()
if err != nil {
return fmt.Errorf("list of tables: %w", err)
}
for _, t := range tables {
if t.Name == FilterTableName {
m.rConn.DelTable(t)
}
}
return m.rConn.Flush()
func (m *Manager) IsServerRouteSupported() bool {
return true
}
// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets
func (m *Manager) Flush() error {
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.flushWithBackoff(); err != nil {
return err
}
return m.router.InsertRoutingRules(pair)
}
// set must be removed after flush rule changes
// otherwise we will get error
for _, s := range m.setRemoved {
m.rConn.FlushSet(s)
m.rConn.DelSet(s)
}
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if len(m.setRemoved) > 0 {
if err := m.flushWithBackoff(); err != nil {
return err
}
}
m.setRemovedIPs = map[string]struct{}{}
m.setRemoved = map[string]*nftables.Set{}
if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil {
log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil {
log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err)
}
return nil
return m.router.RemoveRoutingRules(pair)
}
// AllowNetbird allows netbird interface traffic
// todo review this method usage
func (m *Manager) AllowNetbird() error {
m.mutex.Lock()
defer m.mutex.Unlock()
tf := nftables.TableFamilyIPv4
if m.wgIface.Address().IP.To4() == nil {
tf = nftables.TableFamilyIPv6
}
chains, err := m.rConn.ListChainsOfTableFamily(tf)
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list of chains: %w", err)
}
@@ -777,47 +148,75 @@ func (m *Manager) AllowNetbird() error {
return nil
}
func (m *Manager) flushWithBackoff() (err error) {
backoff := 4
backoffTime := 1000 * time.Millisecond
for i := 0; ; i++ {
err = m.rConn.Flush()
if err != nil {
if !strings.Contains(err.Error(), "busy") {
return
}
log.Error("failed to flush nftables, retrying...")
if i == backoff-1 {
return err
}
time.Sleep(backoffTime)
backoffTime *= 2
continue
}
break
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
chains, err := m.rConn.ListChains()
if err != nil {
return fmt.Errorf("list of chains: %w", err)
}
return
for _, c := range chains {
// delete Netbird allow input traffic rule if it exists
if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
}
}
m.router.ResetForwardRules()
tables, err := m.rConn.ListTables()
if err != nil {
return fmt.Errorf("list of tables: %w", err)
}
for _, t := range tables {
if t.Name == tableName {
m.rConn.DelTable(t)
}
}
return m.rConn.Flush()
}
func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error {
if table == nil || chain == nil {
return nil
}
// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets
// todo review this method usage
func (m *Manager) Flush() error {
m.mutex.Lock()
defer m.mutex.Unlock()
list, err := m.rConn.GetRules(table, chain)
return m.aclManager.Flush()
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return err
return nil, fmt.Errorf("list of tables: %w", err)
}
for _, rule := range list {
if len(rule.UserData) != 0 {
if err := m.rulesetManager.setNftRuleHandle(rule); err != nil {
log.Errorf("failed to set rule handle: %v", err)
}
for _, t := range tables {
if t.Name == tableName {
m.rConn.DelTable(t)
}
}
return nil
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
err = m.rConn.Flush()
return table, err
}
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
@@ -835,7 +234,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
Kind: expr.VerdictAccept,
},
},
UserData: []byte(AllowNetbirdInputRuleID),
UserData: []byte(allowNetbirdInputRuleID),
}
_ = m.rConn.InsertRule(rule)
}
@@ -857,15 +256,3 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable
}
return nil
}
func encodePort(port fw.Port) []byte {
bs := make([]byte, 2)
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
return bs
}
func ifname(n string) []byte {
b := make([]byte, 16)
copy(b, []byte(n+"\x00"))
return b
}

View File

@@ -1,6 +1,7 @@
package nftables
import (
"context"
"fmt"
"net"
"net/netip"
@@ -12,7 +13,7 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface"
)
@@ -53,7 +54,7 @@ func TestNftablesManager(t *testing.T) {
}
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(context.Background(), mock)
require.NoError(t, err)
time.Sleep(time.Second * 3)
@@ -82,14 +83,10 @@ func TestNftablesManager(t *testing.T) {
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
// test expectations:
// 1) regular rule
// 2) "accept extra routed traffic rule" for the interface
// 3) "drop all rule" for the interface
require.Len(t, rules, 3, "expected 3 rules")
require.Len(t, rules, 1, "expected 1 rules")
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
@@ -137,18 +134,17 @@ func TestNftablesManager(t *testing.T) {
}
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
err = manager.DeleteRule(rule)
require.NoError(t, err, "failed to delete rule")
for _, r := range rule {
err = manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
}
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
// test expectations:
// 1) "accept extra routed traffic rule" for the interface
// 2) "drop all rule" for the interface
require.Len(t, rules, 2, "expected 2 rules after deletion")
require.Len(t, rules, 0, "expected 0 rules after deletion")
err = manager.Reset()
require.NoError(t, err, "failed to reset")
@@ -173,7 +169,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(context.Background(), mock)
require.NoError(t, err)
time.Sleep(time.Second * 3)

View File

@@ -0,0 +1,413 @@
package nftables
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/manager"
)
const (
chainNameRouteingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-nat"
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
)
// some presets for building nftable rules
var (
zeroXor = binaryutil.NativeEndian.PutUint32(0)
exprCounterAccept = []expr.Any{
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
)
type router struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule
isDefaultFwdRulesEnabled bool
}
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
r := &router{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
}
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, err
}
}
err = r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.createContainers()
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
}
return r, err
}
func (r *router) RouteingFwChainName() string {
return chainNameRouteingFw
}
// ResetForwardRules cleans existing nftables default forward rules from the system
func (r *router) ResetForwardRules() {
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to reset forward rules: %s", err)
}
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
if table.Name == "filter" {
return table, nil
}
}
return nil, errFilterTableNotFound
}
func (r *router) createContainers() error {
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRouteingFw,
Table: r.workTable,
})
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1,
Type: nftables.ChainTypeNAT,
})
err := r.refreshRulesMap()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
if err != nil {
return err
}
err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
if err != nil {
return err
}
if pair.Masquerade {
err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
if err != nil {
return err
}
err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
if err != nil {
return err
}
}
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
log.Debugf("add default accept forward rule")
r.acceptForwardRule(pair.Source)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
}
return nil
}
// insertRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
var expression []expr.Any
if isNat {
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
} else {
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
}
ruleKey := manager.GenKey(format, pair.ID)
_, exists := r.rules[ruleKey]
if exists {
err := r.removeRoutingRule(format, pair)
if err != nil {
return err
}
}
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainName],
Exprs: expression,
UserData: []byte(ruleKey),
})
return nil
}
func (r *router) acceptForwardRule(sourceNetwork string) {
src := generateCIDRMatcherExpressions(true, sourceNetwork)
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
var exprs []expr.Any
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
Kind: expr.VerdictAccept,
})...)
rule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleSrc),
}
r.conn.AddRule(rule)
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
Kind: expr.VerdictAccept,
})...)
rule = &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleDst),
}
r.conn.AddRule(rule)
r.isDefaultFwdRulesEnabled = true
}
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.removeRoutingRule(manager.ForwardingFormat, pair)
if err != nil {
return err
}
err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
if err != nil {
return err
}
err = r.removeRoutingRule(manager.NatFormat, pair)
if err != nil {
return err
}
err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
if err != nil {
return err
}
if len(r.rules) == 0 {
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed rules for %s", pair.Destination)
return nil
}
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
ruleKey := manager.GenKey(format, pair.ID)
rule, found := r.rules[ruleKey]
if found {
ruleType := "forwarding"
if rule.Chain.Type == nftables.ChainTypeNAT {
ruleType = "nat"
}
err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
}
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
delete(r.rules, ruleKey)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
}
}
}
return nil
}
func (r *router) cleanUpDefaultForwardRules() error {
if r.filterTable == nil {
r.isDefaultFwdRulesEnabled = false
return nil
}
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return err
}
var rules []*nftables.Rule
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name {
continue
}
if chain.Name != "FORWARD" {
continue
}
rules, err = r.conn.GetRules(r.filterTable, chain)
if err != nil {
return err
}
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
err := r.conn.DelRule(rule)
if err != nil {
return err
}
}
}
r.isDefaultFwdRulesEnabled = false
return r.conn.Flush()
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
ip, network, _ := net.ParseCIDR(cidr)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
var offSet uint32
if source {
offSet = 12 // src offset
} else {
offSet = 16 // dst offset
}
return []expr.Any{
// fetch src add
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offSet,
Len: 4,
},
// net mask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Mask: network.Mask,
Xor: zeroXor,
},
// net address
&expr.Cmp{
Register: 1,
Data: add.AsSlice(),
},
}
}

View File

@@ -0,0 +1,280 @@
//go:build !android
package nftables
import (
"context"
"testing"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
"github.com/google/nftables/expr"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
)
const (
// UNKNOWN is the default value for the firewall type for unknown firewall type
UNKNOWN = iota
// IPTABLES is the value for the iptables firewall type
IPTABLES
// NFTABLES is the value for the nftables firewall type
NFTABLES
)
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
if err != nil {
t.Fatal(err)
}
defer deleteWorkTable()
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table)
require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{}
defer manager.ResetForwardRules()
require.NoError(t, err, "shouldn't return error")
err = manager.InsertRoutingRules(testCase.InputPair)
defer func() {
_ = manager.RemoveRoutingRules(testCase.InputPair)
}()
require.NoError(t, err, "forwarding pair should be inserted")
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.InputPair.Masquerade {
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
testingExpression = append(sourceExp, destExp...) //nolint:gocritic
inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
found = 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.InputPair.Masquerade {
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
})
}
}
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
if err != nil {
t.Fatal(err)
}
defer deleteWorkTable()
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table)
require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{}
defer manager.ResetForwardRules()
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(forwardRuleKey),
})
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(natRuleKey),
})
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(inForwardRuleKey),
})
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(inNatRuleKey),
})
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
manager.ResetForwardRules()
err = manager.RemoveRoutingRules(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 {
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
}
}
}
})
}
}
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() int {
nf := nftables.Conn{}
if _, err := nf.ListChains(); err == nil {
return NFTABLES
}
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return UNKNOWN
}
if isIptablesClientAvailable(ip) {
return IPTABLES
}
return UNKNOWN
}
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}
func createWorkTable() (*nftables.Table, error) {
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, err
}
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, err
}
for _, t := range tables {
if t.Name == tableName {
sConn.DelTable(t)
}
}
table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
err = sConn.Flush()
return table, err
}
func deleteWorkTable() {
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return
}
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return
}
for _, t := range tables {
if t.Name == tableName {
sConn.DelTable(t)
}
}
}

View File

@@ -1,6 +1,8 @@
package nftables
import (
"net"
"github.com/google/nftables"
)
@@ -8,9 +10,8 @@ import (
type Rule struct {
nftRule *nftables.Rule
nftSet *nftables.Set
ruleID string
ip []byte
ruleID string
ip net.IP
}
// GetRuleID returns the rule id

View File

@@ -1,115 +0,0 @@
package nftables
import (
"bytes"
"fmt"
"github.com/google/nftables"
"github.com/rs/xid"
)
// nftRuleset links native firewall rule and ipset to ACL generated rules
type nftRuleset struct {
nftRule *nftables.Rule
nftSet *nftables.Set
issuedRules map[string]*Rule
rulesetID string
}
type rulesetManager struct {
rulesets map[string]*nftRuleset
nftSetName2rulesetID map[string]string
issuedRuleID2rulesetID map[string]string
}
func newRuleManager() *rulesetManager {
return &rulesetManager{
rulesets: map[string]*nftRuleset{},
nftSetName2rulesetID: map[string]string{},
issuedRuleID2rulesetID: map[string]string{},
}
}
func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) {
ruleset, ok := r.rulesets[rulesetID]
return ruleset, ok
}
func (r *rulesetManager) createRuleset(
rulesetID string,
nftRule *nftables.Rule,
nftSet *nftables.Set,
) *nftRuleset {
ruleset := nftRuleset{
rulesetID: rulesetID,
nftRule: nftRule,
nftSet: nftSet,
issuedRules: map[string]*Rule{},
}
r.rulesets[ruleset.rulesetID] = &ruleset
if nftSet != nil {
r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID
}
return &ruleset
}
func (r *rulesetManager) addRule(
ruleset *nftRuleset,
ip []byte,
) (*Rule, error) {
if _, ok := r.rulesets[ruleset.rulesetID]; !ok {
return nil, fmt.Errorf("ruleset not found")
}
rule := Rule{
nftRule: ruleset.nftRule,
nftSet: ruleset.nftSet,
ruleID: xid.New().String(),
ip: ip,
}
ruleset.issuedRules[rule.ruleID] = &rule
r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID
return &rule, nil
}
// deleteRule from ruleset and returns true if contains other rules
func (r *rulesetManager) deleteRule(rule *Rule) bool {
rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID]
if !ok {
return false
}
ruleset := r.rulesets[rulesetID]
if ruleset.nftRule == nil {
return false
}
delete(r.issuedRuleID2rulesetID, rule.ruleID)
delete(ruleset.issuedRules, rule.ruleID)
if len(ruleset.issuedRules) == 0 {
delete(r.rulesets, ruleset.rulesetID)
if rule.nftSet != nil {
delete(r.nftSetName2rulesetID, rule.nftSet.Name)
}
return false
}
return true
}
// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number
//
// This is important to do, because after we add rule to the nftables we can't update it until
// we set correct handle value to it.
func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error {
split := bytes.Split(nftRule.UserData, []byte(" "))
ruleset, ok := r.rulesets[string(split[0])]
if !ok {
return fmt.Errorf("ruleset not found")
}
*ruleset.nftRule = *nftRule
return nil
}

View File

@@ -1,122 +0,0 @@
package nftables
import (
"testing"
"github.com/google/nftables"
"github.com/stretchr/testify/require"
)
func TestRulesetManager_createRuleset(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{
UserData: []byte(rulesetID),
}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
require.NotNil(t, ruleset, "createRuleset() failed")
require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect")
require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect")
}
func TestRulesetManager_addRule(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.1.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
require.NotEqual(t, rule.ruleID, "ruleID is empty")
require.EqualValues(t, rule.ip, ip, "ip is incorrect")
require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset")
require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager")
ruleset2 := &nftRuleset{
rulesetID: "ruleset-2",
}
_, err = rulesetManager.addRule(ruleset2, ip)
require.Error(t, err, "addRule() should have failed")
}
func TestRulesetManager_deleteRule(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.1.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
ip2 := []byte("192.168.1.1")
rule2, err := rulesetManager.addRule(ruleset, ip2)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule2, "rule should not be nil")
hasNext := rulesetManager.deleteRule(rule)
require.True(t, hasNext, "deleteRule() should have returned true")
// Check that the rule is no longer in the manager.
require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted")
hasNext = rulesetManager.deleteRule(rule2)
require.False(t, hasNext, "deleteRule() should have returned false")
}
func TestRulesetManager_setNftRuleHandle(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.0.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
nftRuleCopy := nftRule
nftRuleCopy.Handle = 2
nftRuleCopy.UserData = []byte(rulesetID)
err = rulesetManager.setNftRuleHandle(&nftRuleCopy)
require.NoError(t, err, "setNftRuleHandle() failed")
// check correct work with references
require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect")
}
func TestRulesetManager_getRuleset(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
nftSet := nftables.Set{
ID: 2,
}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet)
require.NotNil(t, ruleset, "createRuleset() failed")
find, ok := rulesetManager.getRuleset(rulesetID)
require.True(t, ok, "getRuleset() failed")
require.Equal(t, ruleset, find, "getRulesetBySetID() failed")
_, ok = rulesetManager.getRuleset("does-not-exist")
require.False(t, ok, "getRuleset() failed")
}

View File

@@ -0,0 +1,47 @@
//go:build !android
package test
import firewall "github.com/netbirdio/netbird/client/firewall/manager"
var (
InsertRuleTestCases = []struct {
Name string
InputPair firewall.RouterPair
}{
{
Name: "Insert Forwarding IPV4 Rule",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: "100.100.100.1/32",
Destination: "100.100.200.0/24",
Masquerade: false,
},
},
{
Name: "Insert Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: "100.100.100.1/32",
Destination: "100.100.200.0/24",
Masquerade: true,
},
},
}
RemoveRuleTestCases = []struct {
Name string
InputPair firewall.RouterPair
IpVersion string
}{
{
Name: "Remove Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: "100.100.100.1/32",
Destination: "100.100.200.0/24",
Masquerade: true,
},
},
}
)

View File

@@ -1,4 +1,4 @@
//go:build !windows && !linux
//go:build !windows
package uspfilter
@@ -10,10 +10,16 @@ func (m *Manager) Reset() error {
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.nativeFirewall != nil {
return m.nativeFirewall.Reset()
}
return nil
}
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
if m.nativeFirewall != nil {
return m.nativeFirewall.AllowNetbird()
}
return nil
}

View File

@@ -1,21 +0,0 @@
package uspfilter
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return nil
}
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.resetHook != nil {
return m.resetHook()
}
return nil
}

View File

@@ -5,7 +5,7 @@ import (
"github.com/google/gopacket"
fw "github.com/netbirdio/netbird/client/firewall"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
// Rule to handle management of rules
@@ -15,7 +15,7 @@ type Rule struct {
ipLayer gopacket.LayerType
matchByIP bool
protoLayer gopacket.LayerType
direction fw.RuleDirection
direction firewall.RuleDirection
sPort uint16
dPort uint16
drop bool

View File

@@ -10,12 +10,16 @@ import (
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
fw "github.com/netbirdio/netbird/client/firewall"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface"
)
const layerTypeAll = 0
var (
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(iface.PacketFilter) error
@@ -27,12 +31,12 @@ type RuleSet map[string]Rule
// Manager userspace firewall manager
type Manager struct {
outgoingRules map[string]RuleSet
incomingRules map[string]RuleSet
wgNetwork *net.IPNet
decoders sync.Pool
wgIface IFaceMapper
resetHook func() error
outgoingRules map[string]RuleSet
incomingRules map[string]RuleSet
wgNetwork *net.IPNet
decoders sync.Pool
wgIface IFaceMapper
nativeFirewall firewall.Manager
mutex sync.RWMutex
}
@@ -52,6 +56,20 @@ type decoder struct {
// Create userspace firewall manager constructor
func Create(iface IFaceMapper) (*Manager, error) {
return create(iface)
}
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
mgr, err := create(iface)
if err != nil {
return nil, err
}
mgr.nativeFirewall = nativeFirewall
return mgr, nil
}
func create(iface IFaceMapper) (*Manager, error) {
m := &Manager{
decoders: sync.Pool{
New: func() any {
@@ -77,27 +95,50 @@ func Create(iface IFaceMapper) (*Manager, error) {
return m, nil
}
func (m *Manager) IsServerRouteSupported() bool {
if m.nativeFirewall == nil {
return false
} else {
return true
}
}
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
return m.nativeFirewall.InsertRoutingRules(pair)
}
// RemoveRoutingRules removes a routing firewall rule
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
return m.nativeFirewall.RemoveRoutingRules(pair)
}
// AddFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddFiltering(
ip net.IP,
proto fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
comment string,
) (fw.Rule, error) {
) ([]firewall.Rule, error) {
r := Rule{
id: uuid.New().String(),
ip: ip,
ipLayer: layers.LayerTypeIPv6,
matchByIP: true,
direction: direction,
drop: action == fw.ActionDrop,
drop: action == firewall.ActionDrop,
comment: comment,
}
if ipNormalized := ip.To4(); ipNormalized != nil {
@@ -118,21 +159,21 @@ func (m *Manager) AddFiltering(
}
switch proto {
case fw.ProtocolTCP:
case firewall.ProtocolTCP:
r.protoLayer = layers.LayerTypeTCP
case fw.ProtocolUDP:
case firewall.ProtocolUDP:
r.protoLayer = layers.LayerTypeUDP
case fw.ProtocolICMP:
case firewall.ProtocolICMP:
r.protoLayer = layers.LayerTypeICMPv4
if r.ipLayer == layers.LayerTypeIPv6 {
r.protoLayer = layers.LayerTypeICMPv6
}
case fw.ProtocolALL:
case firewall.ProtocolALL:
r.protoLayer = layerTypeAll
}
m.mutex.Lock()
if direction == fw.RuleDirectionIN {
if direction == firewall.RuleDirectionIN {
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(RuleSet)
}
@@ -144,12 +185,11 @@ func (m *Manager) AddFiltering(
m.outgoingRules[r.ip.String()][r.id] = r
}
m.mutex.Unlock()
return &r, nil
return []firewall.Rule{&r}, nil
}
// DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule fw.Rule) error {
func (m *Manager) DeleteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -158,7 +198,7 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
}
if r.direction == fw.RuleDirectionIN {
if r.direction == firewall.RuleDirectionIN {
_, ok := m.incomingRules[r.ip.String()][r.id]
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
@@ -322,7 +362,7 @@ func (m *Manager) AddUDPPacketHook(
protoLayer: layers.LayerTypeUDP,
dPort: dPort,
ipLayer: layers.LayerTypeIPv6,
direction: fw.RuleDirectionOUT,
direction: firewall.RuleDirectionOUT,
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
udpHook: hook,
}
@@ -333,7 +373,7 @@ func (m *Manager) AddUDPPacketHook(
m.mutex.Lock()
if in {
r.direction = fw.RuleDirectionIN
r.direction = firewall.RuleDirectionIN
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]Rule)
}
@@ -370,8 +410,3 @@ func (m *Manager) RemovePacketHook(hookID string) error {
}
return fmt.Errorf("hook with given id not found")
}
// SetResetHook which will be executed in the end of Reset method
func (m *Manager) SetResetHook(hook func() error) {
m.resetHook = hook
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface"
)
@@ -125,24 +125,32 @@ func TestManagerDeleteRule(t *testing.T) {
return
}
err = m.DeleteRule(rule)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
for _, r := range rule {
err = m.DeleteRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
}
}
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok {
t.Errorf("rule2 is not in the incomingRules")
for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
t.Errorf("rule2 is not in the incomingRules")
}
}
err = m.DeleteRule(rule2)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
for _, r := range rule2 {
err = m.DeleteRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
}
}
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok {
t.Errorf("rule2 is not in the incomingRules")
for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok {
t.Errorf("rule2 is not in the incomingRules")
}
}
}

View File

@@ -11,42 +11,27 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/iface"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
SetFilter(iface.PacketFilter) error
}
// Manager is a ACL rules manager
type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap)
Stop()
}
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
manager firewall.Manager
firewall firewall.Manager
ipsetCounter int
rulesPairs map[string][]firewall.Rule
mutex sync.Mutex
}
type ipsetInfo struct {
name string
ipCount int
}
func newDefaultManager(fm firewall.Manager) *DefaultManager {
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
return &DefaultManager{
manager: fm,
firewall: fm,
rulesPairs: make(map[string][]firewall.Rule),
}
}
@@ -69,13 +54,13 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
time.Since(start), total)
}()
if d.manager == nil {
if d.firewall == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return
}
defer func() {
if err := d.manager.Flush(); err != nil {
if err := d.firewall.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err)
}
}()
@@ -125,57 +110,35 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
)
}
applyFailed := false
newRulePairs := make(map[string][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]*ipsetInfo)
// calculate which IP's can be grouped in by which ipset
// to do that we use rule selector (which is just rule properties without IP's)
for _, r := range rules {
selector := d.getRuleGroupingSelector(r)
ipset, ok := ipsetByRuleSelectors[selector]
if !ok {
ipset = &ipsetInfo{}
}
ipset.ipCount++
ipsetByRuleSelectors[selector] = ipset
}
ipsetByRuleSelectors := make(map[string]string)
for _, r := range rules {
// if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
if ipset.name == "" {
selector := d.getRuleGroupingSelector(r)
ipsetName, ok := ipsetByRuleSelectors[selector]
if !ok {
d.ipsetCounter++
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
ipsetByRuleSelectors[selector] = ipsetName
}
ipsetName := ipset.name
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil {
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
applyFailed = true
d.rollBack(newRulePairs)
break
}
newRulePairs[pairID] = rulePair
}
if applyFailed {
log.Error("failed to apply firewall rules, rollback ACL to previous state")
for _, rules := range newRulePairs {
for _, rule := range rules {
if err := d.manager.DeleteRule(rule); err != nil {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
continue
}
}
if len(rules) > 0 {
d.rulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
}
return
}
for pairID, rules := range d.rulesPairs {
if _, ok := newRulePairs[pairID]; !ok {
for _, rule := range rules {
if err := d.manager.DeleteRule(rule); err != nil {
if err := d.firewall.DeleteRule(rule); err != nil {
log.Errorf("failed to delete firewall rule: %v", err)
continue
}
@@ -186,16 +149,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
d.rulesPairs = newRulePairs
}
// Stop ACL controller and clear firewall state
func (d *DefaultManager) Stop() {
d.mutex.Lock()
defer d.mutex.Unlock()
if err := d.manager.Reset(); err != nil {
log.WithError(err).Error("reset firewall state")
}
}
func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule,
ipsetName string,
@@ -205,14 +158,14 @@ func (d *DefaultManager) protoRuleToFirewallRule(
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
}
protocol := convertToFirewallProtocol(r.Protocol)
if protocol == firewall.ProtocolUnknown {
return "", nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol)
protocol, err := convertToFirewallProtocol(r.Protocol)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
}
action := convertFirewallAction(r.Action)
if action == firewall.ActionUnknown {
return "", nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action)
action, err := convertFirewallAction(r.Action)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
}
var port *firewall.Port
@@ -232,7 +185,6 @@ func (d *DefaultManager) protoRuleToFirewallRule(
}
var rules []firewall.Rule
var err error
switch r.Direction {
case mgmProto.FirewallRule_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
@@ -246,7 +198,6 @@ func (d *DefaultManager) protoRuleToFirewallRule(
return "", nil, err
}
d.rulesPairs[ruleID] = rules
return ruleID, rules, nil
}
@@ -259,24 +210,24 @@ func (d *DefaultManager) addInRules(
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.manager.AddFiltering(
rule, err := d.firewall.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
rules = append(rules, rule)
rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) {
return rules, nil
}
rule, err = d.manager.AddFiltering(
rule, err = d.firewall.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
return append(rules, rule), nil
return append(rules, rule...), nil
}
func (d *DefaultManager) addOutRules(
@@ -288,24 +239,24 @@ func (d *DefaultManager) addOutRules(
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.manager.AddFiltering(
rule, err := d.firewall.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
rules = append(rules, rule)
rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) {
return rules, nil
}
rule, err = d.manager.AddFiltering(
rule, err = d.firewall.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
return append(rules, rule), nil
return append(rules, rule...), nil
}
// getRuleID() returns unique ID for the rule based on its parameters.
@@ -461,18 +412,29 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
}
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) {
log.Debugf("rollback ACL to previous state")
for _, rules := range newRulePairs {
for _, rule := range rules {
if err := d.firewall.DeleteRule(rule); err != nil {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
}
}
}
}
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) {
switch protocol {
case mgmProto.FirewallRule_TCP:
return firewall.ProtocolTCP
return firewall.ProtocolTCP, nil
case mgmProto.FirewallRule_UDP:
return firewall.ProtocolUDP
return firewall.ProtocolUDP, nil
case mgmProto.FirewallRule_ICMP:
return firewall.ProtocolICMP
return firewall.ProtocolICMP, nil
case mgmProto.FirewallRule_ALL:
return firewall.ProtocolALL
return firewall.ProtocolALL, nil
default:
return firewall.ProtocolUnknown
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
}
}
@@ -480,13 +442,13 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
}
func convertFirewallAction(action mgmProto.FirewallRuleAction) firewall.Action {
func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) {
switch action {
case mgmProto.FirewallRule_ACCEPT:
return firewall.ActionAccept
return firewall.ActionAccept, nil
case mgmProto.FirewallRule_DROP:
return firewall.ActionDrop
return firewall.ActionDrop, nil
default:
return firewall.ActionUnknown
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
}
}

View File

@@ -1,28 +0,0 @@
//go:build !linux || android
package acl
import (
"fmt"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
)
// Create creates a firewall manager instance
func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
if iface.IsUserspaceBind() {
// use userspace packet filtering firewall
fm, err := uspfilter.Create(iface)
if err != nil {
return nil, err
}
if err := fm.AllowNetbird(); err != nil {
log.Warnf("failed to allow netbird interface traffic: %v", err)
}
return newDefaultManager(fm), nil
}
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}

View File

@@ -1,77 +0,0 @@
//go:build !android
package acl
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/iptables"
"github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/checkfw"
)
// Create creates a firewall manager instance for the Linux
func Create(iface IFaceMapper) (*DefaultManager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
var fm firewall.Manager
var err error
checkResult := checkfw.Check()
switch checkResult {
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
log.Debug("creating an iptables firewall manager for access control")
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
if fm, err = iptables.Create(iface, ipv6Supported); err != nil {
log.Infof("failed to create iptables manager for access control: %s", err)
}
case checkfw.NFTABLES:
log.Debug("creating an nftables firewall manager for access control")
if fm, err = nftables.Create(iface); err != nil {
log.Debugf("failed to create nftables manager for access control: %s", err)
}
}
var resetHookForUserspace func() error
if fm != nil && err == nil {
// err shadowing is used here, to ignore this error
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
resetHookForUserspace = fm.Reset
}
if iface.IsUserspaceBind() {
// use userspace packet filtering firewall
usfm, err := uspfilter.Create(iface)
if err != nil {
log.Debugf("failed to create userspace filtering firewall: %s", err)
return nil, err
}
// set kernel space firewall Reset as hook for userspace firewall
// manager Reset method, to clean up
if resetHookForUserspace != nil {
usfm.SetResetHook(resetHookForUserspace)
}
// to be consistent for any future extensions.
// ignore this error
if err := usfm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
fm = usfm
}
if fm == nil || err != nil {
log.Errorf("failed to create firewall manager: %s", err)
// no firewall manager found or initialized correctly
return nil, err
}
return newDefaultManager(fm), nil
}

View File

@@ -1,11 +1,14 @@
package acl
import (
"context"
"net"
"testing"
"github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/iface"
mgmProto "github.com/netbirdio/netbird/management/proto"
@@ -49,12 +52,15 @@ func TestDefaultManager(t *testing.T) {
}).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
acl, err := Create(ifaceMock)
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
if err != nil {
t.Errorf("create ACL manager: %v", err)
t.Errorf("create firewall: %v", err)
return
}
defer acl.Stop()
defer func(fw manager.Manager) {
_ = fw.Reset()
}(fw)
acl := NewDefaultManager(fw)
t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap)
@@ -339,12 +345,15 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
acl, err := Create(ifaceMock)
fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
if err != nil {
t.Errorf("create ACL manager: %v", err)
t.Errorf("create firewall: %v", err)
return
}
defer acl.Stop()
defer func(fw manager.Manager) {
_ = fw.Reset()
}(fw)
acl := NewDefaultManager(fw)
acl.ApplyFiltering(networkMap)

View File

@@ -1,3 +0,0 @@
//go:build !linux || android
package checkfw

View File

@@ -1,56 +0,0 @@
//go:build !android
package checkfw
import (
"os"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
)
const (
// UNKNOWN is the default value for the firewall type for unknown firewall type
UNKNOWN FWType = iota
// IPTABLES is the value for the iptables firewall type
IPTABLES
// IPTABLESWITHV6 is the value for the iptables firewall type with ipv6
IPTABLESWITHV6
// NFTABLES is the value for the nftables firewall type
NFTABLES
)
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
// Check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func Check() FWType {
nf := nftables.Conn{}
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
return NFTABLES
}
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err == nil {
if isIptablesClientAvailable(ip) {
ipSupport := IPTABLES
ipv6, ip6Err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if ip6Err == nil {
if isIptablesClientAvailable(ipv6) {
ipSupport = IPTABLESWITHV6
}
}
return ipSupport
}
}
return UNKNOWN
}
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}

View File

@@ -25,13 +25,30 @@ const (
type osManagerType int
func (t osManagerType) String() string {
switch t {
case netbirdManager:
return "netbird"
case fileManager:
return "file"
case networkManager:
return "networkManager"
case systemdManager:
return "systemd"
case resolvConfManager:
return "resolvconf"
default:
return "unknown"
}
}
func newHostManager(wgInterface WGIface) (hostManager, error) {
osManager, err := getOSDNSManagerType()
if err != nil {
return nil, err
}
log.Debugf("discovered mode is: %d", osManager)
log.Debugf("discovered mode is: %s", osManager)
switch osManager {
case networkManager:
return newNetworkManagerDbusConfigurator(wgInterface)
@@ -65,7 +82,6 @@ func getOSDNSManagerType() (osManagerType, error) {
return netbirdManager, nil
}
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
log.Debugf("is nm running on supported v? %t", isNetworkManagerSupportedVersion())
return networkManager, nil
}
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {

View File

@@ -17,6 +17,8 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -115,6 +117,7 @@ type Engine struct {
statusRecorder *peer.Status
firewall manager.Manager
routeManager routemanager.Manager
acl acl.Manager
@@ -231,6 +234,19 @@ func (e *Engine) Start() error {
return err
}
e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface)
if err != nil {
log.Errorf("failed creating firewall manager: %s", err)
}
if e.firewall != nil && e.firewall.IsServerRouteSupported() {
err = e.routeManager.EnableServerRouter(e.firewall)
if err != nil {
e.close()
return err
}
}
err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort)
if err != nil {
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIFaceName, err.Error())
@@ -258,10 +274,8 @@ func (e *Engine) Start() error {
e.udpMux = mux
}
if acl, err := acl.Create(e.wgInterface); err != nil {
log.Errorf("failed to create ACL manager, policy will not work: %s", err.Error())
} else {
e.acl = acl
if e.firewall != nil {
e.acl = acl.NewDefaultManager(e.firewall)
}
err = e.dnsServer.Initialize()
@@ -1044,8 +1058,11 @@ func (e *Engine) close() {
e.dnsServer.Stop()
}
if e.acl != nil {
e.acl.Stop()
if e.firewall != nil {
err := e.firewall.Reset()
if err != nil {
log.Warnf("failed to reset firewall: %s", err)
}
}
}

View File

@@ -1,75 +0,0 @@
package routemanager
var insertRuleTestCases = []struct {
name string
inputPair routerPair
ipVersion string
}{
{
name: "Insert Forwarding IPV4 Rule",
inputPair: routerPair{
ID: "zxa",
source: "100.100.100.1/32",
destination: "100.100.200.0/24",
masquerade: false,
},
ipVersion: ipv4,
},
{
name: "Insert Forwarding And Nat IPV4 Rules",
inputPair: routerPair{
ID: "zxa",
source: "100.100.100.1/32",
destination: "100.100.200.0/24",
masquerade: true,
},
ipVersion: ipv4,
},
{
name: "Insert Forwarding IPV6 Rule",
inputPair: routerPair{
ID: "zxa",
source: "fc00::1/128",
destination: "fc12::/64",
masquerade: false,
},
ipVersion: ipv6,
},
{
name: "Insert Forwarding And Nat IPV6 Rules",
inputPair: routerPair{
ID: "zxa",
source: "fc00::1/128",
destination: "fc12::/64",
masquerade: true,
},
ipVersion: ipv6,
},
}
var removeRuleTestCases = []struct {
name string
inputPair routerPair
ipVersion string
}{
{
name: "Remove Forwarding And Nat IPV4 Rules",
inputPair: routerPair{
ID: "zxa",
source: "100.100.100.1/32",
destination: "100.100.200.0/24",
masquerade: true,
},
ipVersion: ipv4,
},
{
name: "Remove Forwarding And Nat IPV6 Rules",
inputPair: routerPair{
ID: "zxa",
source: "fc00::1/128",
destination: "fc12::/64",
masquerade: true,
},
ipVersion: ipv6,
},
}

View File

@@ -1,12 +0,0 @@
package routemanager
type firewallManager interface {
// RestoreOrCreateContainers restores or creates a firewall container set of rules, tables and default rules
RestoreOrCreateContainers() error
// InsertRoutingRules inserts a routing firewall rule
InsertRoutingRules(pair routerPair) error
// RemoveRoutingRules removes a routing firewall rule
RemoveRoutingRules(pair routerPair) error
// CleanRoutingRules cleans a firewall set of containers
CleanRoutingRules()
}

View File

@@ -1,55 +0,0 @@
//go:build !android
package routemanager
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/checkfw"
)
const (
ipv6Forwarding = "netbird-rt-ipv6-forwarding"
ipv4Forwarding = "netbird-rt-ipv4-forwarding"
ipv6Nat = "netbird-rt-ipv6-nat"
ipv4Nat = "netbird-rt-ipv4-nat"
natFormat = "netbird-nat-%s"
forwardingFormat = "netbird-fwd-%s"
inNatFormat = "netbird-nat-in-%s"
inForwardingFormat = "netbird-fwd-in-%s"
ipv6 = "ipv6"
ipv4 = "ipv4"
)
func genKey(format string, input string) string {
return fmt.Sprintf(format, input)
}
// newFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
func newFirewall(parentCTX context.Context) (firewallManager, error) {
checkResult := checkfw.Check()
switch checkResult {
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
log.Debug("creating an iptables firewall manager for route rules")
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
return newIptablesManager(parentCTX, ipv6Supported)
case checkfw.NFTABLES:
log.Info("creating an nftables firewall manager for route rules")
return newNFTablesManager(parentCTX), nil
}
return nil, fmt.Errorf("couldn't initialize nftables or iptables clients. Using a dummy firewall manager for route rules")
}
func getInPair(pair routerPair) routerPair {
return routerPair{
ID: pair.ID,
// invert source/destination
source: pair.destination,
destination: pair.source,
masquerade: pair.masquerade,
}
}

View File

@@ -1,15 +0,0 @@
//go:build !linux
// +build !linux
package routemanager
import (
"context"
"fmt"
"runtime"
)
// newFirewall returns a nil manager
func newFirewall(context.Context) (firewallManager, error) {
return nil, fmt.Errorf("firewall not supported on %s", runtime.GOOS)
}

View File

@@ -1,487 +0,0 @@
//go:build !android
package routemanager
import (
"context"
"fmt"
"net/netip"
"os/exec"
"strings"
"sync"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
)
func isIptablesSupported() bool {
_, err4 := exec.LookPath("iptables")
_, err6 := exec.LookPath("ip6tables")
return err4 == nil && err6 == nil
}
// constants needed to manage and create iptable rules
const (
iptablesFilterTable = "filter"
iptablesNatTable = "nat"
iptablesForwardChain = "FORWARD"
iptablesPostRoutingChain = "POSTROUTING"
iptablesRoutingNatChain = "NETBIRD-RT-NAT"
iptablesRoutingForwardingChain = "NETBIRD-RT-FWD"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
)
// some presets for building nftable rules
var (
iptablesDefaultForwardingRule = []string{"-j", iptablesRoutingForwardingChain, "-m", "comment", "--comment"}
iptablesDefaultNetbirdForwardingRule = []string{"-j", "RETURN"}
iptablesDefaultNatRule = []string{"-j", iptablesRoutingNatChain, "-m", "comment", "--comment"}
iptablesDefaultNetbirdNatRule = []string{"-j", "RETURN"}
)
type iptablesManager struct {
ctx context.Context
stop context.CancelFunc
ipv4Client *iptables.IPTables
ipv6Client *iptables.IPTables
rules map[string]map[string][]string
mux sync.Mutex
}
func newIptablesManager(parentCtx context.Context, ipv6Supported bool) (*iptablesManager, error) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return nil, fmt.Errorf("failed to initialize iptables for ipv4: %s", err)
}
ctx, cancel := context.WithCancel(parentCtx)
manager := &iptablesManager{
ctx: ctx,
stop: cancel,
ipv4Client: ipv4Client,
rules: make(map[string]map[string][]string),
}
if ipv6Supported {
manager.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Warnf("failed to initialize iptables for ipv6: %s. Routes for this protocol won't be applied.", err)
}
}
return manager, nil
}
// CleanRoutingRules cleans existing iptables resources that we created by the agent
func (i *iptablesManager) CleanRoutingRules() {
i.mux.Lock()
defer i.mux.Unlock()
err := i.cleanJumpRules()
if err != nil {
log.Error(err)
}
log.Debug("flushing tables")
errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v"
if i.ipv4Client != nil {
err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
if err != nil {
log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
}
err = i.ipv4Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
if err != nil {
log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
}
}
if i.ipv6Client != nil {
err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
if err != nil {
log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
}
err = i.ipv6Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
if err != nil {
log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
}
}
log.Info("done cleaning up iptables rules")
}
// RestoreOrCreateContainers restores existing iptables containers (chains and rules)
// if they don't exist, we create them
func (i *iptablesManager) RestoreOrCreateContainers() error {
i.mux.Lock()
defer i.mux.Unlock()
if i.rules[ipv4][ipv4Forwarding] != nil && i.rules[ipv6][ipv6Forwarding] != nil {
return nil
}
errMSGFormat := "iptables: failed creating %s chain %s,error: %v"
if i.ipv4Client != nil {
err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
}
err = createChain(i.ipv4Client, iptablesNatTable, iptablesRoutingNatChain)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
}
err = i.restoreRules(i.ipv4Client)
if err != nil {
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
}
}
if i.ipv6Client != nil {
err := createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
}
err = createChain(i.ipv6Client, iptablesNatTable, iptablesRoutingNatChain)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
}
err = i.restoreRules(i.ipv6Client)
if err != nil {
return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err)
}
}
err := i.addJumpRules()
if err != nil {
return fmt.Errorf("iptables: error while creating jump rules: %v", err)
}
return nil
}
// addJumpRules create jump rules to send packets to NetBird chains
func (i *iptablesManager) addJumpRules() error {
err := i.cleanJumpRules()
if err != nil {
return err
}
if i.ipv4Client != nil {
rule := append(iptablesDefaultForwardingRule, ipv4Forwarding) //nolint:gocritic
err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
if err != nil {
return err
}
i.rules[ipv4][ipv4Forwarding] = rule
rule = append(iptablesDefaultNatRule, ipv4Nat) //nolint:gocritic
err = i.ipv4Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
if err != nil {
return err
}
i.rules[ipv4][ipv4Nat] = rule
}
if i.ipv6Client != nil {
rule := append(iptablesDefaultForwardingRule, ipv6Forwarding) //nolint:gocritic
err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
if err != nil {
return err
}
i.rules[ipv6][ipv6Forwarding] = rule
rule = append(iptablesDefaultNatRule, ipv6Nat) //nolint:gocritic
err = i.ipv6Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
if err != nil {
return err
}
i.rules[ipv6][ipv6Nat] = rule
}
return nil
}
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
func (i *iptablesManager) cleanJumpRules() error {
var err error
errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v"
rule, found := i.rules[ipv4][ipv4Forwarding]
if i.ipv4Client != nil {
if found {
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding)
err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv4, iptablesForwardChain, err)
}
}
rule, found = i.rules[ipv4][ipv4Nat]
if found {
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Nat)
err = i.ipv4Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err)
}
}
}
if i.ipv6Client == nil {
rule, found = i.rules[ipv6][ipv6Forwarding]
if found {
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding)
err = i.ipv6Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesForwardChain, err)
}
}
rule, found = i.rules[ipv6][ipv6Nat]
if found {
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Nat)
err = i.ipv6Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err)
}
}
}
return nil
}
func iptablesProtoToString(proto iptables.Protocol) string {
if proto == iptables.ProtocolIPv6 {
return ipv6
}
return ipv4
}
// restoreRules restores existing NetBird rules
func (i *iptablesManager) restoreRules(iptablesClient *iptables.IPTables) error {
ipVersion := iptablesProtoToString(iptablesClient.Proto())
if i.rules[ipVersion] == nil {
i.rules[ipVersion] = make(map[string][]string)
}
table := iptablesFilterTable
for _, chain := range []string{iptablesForwardChain, iptablesRoutingForwardingChain} {
rules, err := iptablesClient.List(table, chain)
if err != nil {
return err
}
for _, ruleString := range rules {
rule := strings.Fields(ruleString)
id := getRuleRouteID(rule)
if id != "" {
i.rules[ipVersion][id] = rule[2:]
}
}
}
table = iptablesNatTable
for _, chain := range []string{iptablesPostRoutingChain, iptablesRoutingNatChain} {
rules, err := iptablesClient.List(table, chain)
if err != nil {
return err
}
for _, ruleString := range rules {
rule := strings.Fields(ruleString)
id := getRuleRouteID(rule)
if id != "" {
i.rules[ipVersion][id] = rule[2:]
}
}
}
return nil
}
// createChain create NetBird chains
func createChain(iptables *iptables.IPTables, table, newChain string) error {
chains, err := iptables.ListChains(table)
if err != nil {
return fmt.Errorf("couldn't get %s %s table chains, error: %v", iptablesProtoToString(iptables.Proto()), table, err)
}
shouldCreateChain := true
for _, chain := range chains {
if chain == newChain {
shouldCreateChain = false
}
}
if shouldCreateChain {
err = iptables.NewChain(table, newChain)
if err != nil {
return fmt.Errorf("couldn't create %s chain %s in %s table, error: %v", iptablesProtoToString(iptables.Proto()), newChain, table, err)
}
if table == iptablesNatTable {
err = iptables.Append(table, newChain, iptablesDefaultNetbirdNatRule...)
} else {
err = iptables.Append(table, newChain, iptablesDefaultNetbirdForwardingRule...)
}
if err != nil {
return fmt.Errorf("couldn't create %s chain %s default rule, error: %v", iptablesProtoToString(iptables.Proto()), newChain, err)
}
}
return nil
}
// genRuleSpec generates rule specification with comment identifier
func genRuleSpec(jump, id, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id}
}
// getRuleRouteID returns the rule ID if matches our prefix
func getRuleRouteID(rule []string) string {
for i, flag := range rule {
if flag == "--comment" {
id := rule[i+1]
if strings.HasPrefix(id, "netbird-") {
return id
}
}
}
return ""
}
// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
func (i *iptablesManager) InsertRoutingRules(pair routerPair) error {
i.mux.Lock()
defer i.mux.Unlock()
err := i.insertRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, getInPair(pair))
if err != nil {
return err
}
if !pair.masquerade {
return nil
}
err = i.insertRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, getInPair(pair))
if err != nil {
return err
}
return nil
}
// insertRoutingRule inserts an iptable rule
func (i *iptablesManager) insertRoutingRule(keyFormat, table, chain, jump string, pair routerPair) error {
var err error
prefix := netip.MustParsePrefix(pair.source)
ipVersion := ipv4
iptablesClient := i.ipv4Client
if prefix.Addr().Unmap().Is6() {
iptablesClient = i.ipv6Client
ipVersion = ipv6
}
if iptablesClient == nil {
return fmt.Errorf("unable to insert iptables routing rules. Iptables client is not initialized")
}
ruleKey := genKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, ruleKey, pair.source, pair.destination)
existingRule, found := i.rules[ipVersion][ruleKey]
if found {
err = iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
}
delete(i.rules[ipVersion], ruleKey)
}
err = iptablesClient.Insert(table, chain, 1, rule...)
if err != nil {
return fmt.Errorf("iptables: error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
}
i.rules[ipVersion][ruleKey] = rule
return nil
}
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error {
i.mux.Lock()
defer i.mux.Unlock()
err := i.removeRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, getInPair(pair))
if err != nil {
return err
}
if !pair.masquerade {
return nil
}
err = i.removeRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, getInPair(pair))
if err != nil {
return err
}
return nil
}
// removeRoutingRule removes an iptables rule
func (i *iptablesManager) removeRoutingRule(keyFormat, table, chain string, pair routerPair) error {
var err error
prefix := netip.MustParsePrefix(pair.source)
ipVersion := ipv4
iptablesClient := i.ipv4Client
if prefix.Addr().Unmap().Is6() {
iptablesClient = i.ipv6Client
ipVersion = ipv6
}
if iptablesClient == nil {
return fmt.Errorf("unable to remove iptables routing rules. Iptables client is not initialized")
}
ruleKey := genKey(keyFormat, pair.ID)
existingRule, found := i.rules[ipVersion][ruleKey]
if found {
err = iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
}
}
delete(i.rules[ipVersion], ruleKey)
return nil
}
func getIptablesRuleType(table string) string {
ruleType := "forwarding"
if table == iptablesNatTable {
ruleType = "nat"
}
return ruleType
}

View File

@@ -1,294 +0,0 @@
//go:build !android
package routemanager
import (
"context"
"testing"
"github.com/coreos/go-iptables/iptables"
"github.com/stretchr/testify/require"
)
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
manager, err := newIptablesManager(context.TODO(), true)
require.NoError(t, err, "should return a valid iptables manager")
defer manager.CleanRoutingRules()
err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6")
require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4")
exists, err := manager.ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain)
require.True(t, exists, "forwarding rule should exist")
exists, err = manager.ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain)
require.True(t, exists, "postrouting rule should exist")
require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6")
exists, err = manager.ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain)
require.True(t, exists, "forwarding rule should exist")
exists, err = manager.ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain)
require.True(t, exists, "postrouting rule should exist")
pair := routerPair{
ID: "abc",
source: "100.100.100.1/32",
destination: "100.100.100.0/24",
masquerade: true,
}
forward4RuleKey := genKey(forwardingFormat, pair.ID)
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination)
err = manager.ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error")
nat4RuleKey := genKey(natFormat, pair.ID)
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination)
err = manager.ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
pair = routerPair{
ID: "abc",
source: "fc00::1/128",
destination: "fc11::/64",
masquerade: true,
}
forward6RuleKey := genKey(forwardingFormat, pair.ID)
forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination)
err = manager.ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...)
require.NoError(t, err, "inserting rule should not return error")
nat6RuleKey := genKey(natFormat, pair.ID)
nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination)
err = manager.ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...)
require.NoError(t, err, "inserting rule should not return error")
delete(manager.rules, ipv4)
delete(manager.rules, ipv6)
err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.rules[ipv4], 4, "should have restored all rules for ipv4")
foundRule, found := manager.rules[ipv4][forward4RuleKey]
require.True(t, found, "forwarding rule should exist in the map")
require.Equal(t, forward4Rule[:4], foundRule[:4], "stored forwarding rule should match")
foundRule, found = manager.rules[ipv4][nat4RuleKey]
require.True(t, found, "nat rule should exist in the map")
require.Equal(t, nat4Rule[:4], foundRule[:4], "stored nat rule should match")
require.Len(t, manager.rules[ipv6], 4, "should have restored all rules for ipv6")
foundRule, found = manager.rules[ipv6][forward6RuleKey]
require.True(t, found, "forwarding rule should exist in the map")
require.Equal(t, forward6Rule[:4], foundRule[:4], "stored forward rule should match")
foundRule, found = manager.rules[ipv6][nat6RuleKey]
require.True(t, found, "nat rule should exist in the map")
require.Equal(t, nat6Rule[:4], foundRule[:4], "stored nat rule should match")
}
func TestIptablesManager_InsertRoutingRules(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
for _, testCase := range insertRuleTestCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
iptablesClient := ipv4Client
if testCase.ipVersion == ipv6 {
iptablesClient = ipv6Client
}
manager := &iptablesManager{
ctx: ctx,
stop: cancel,
ipv4Client: ipv4Client,
ipv6Client: ipv6Client,
rules: make(map[string]map[string][]string),
}
defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
err = manager.InsertRoutingRules(testCase.inputPair)
require.NoError(t, err, "forwarding pair should be inserted")
forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
exists, err := iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
require.True(t, exists, "forwarding rule should exist")
foundRule, found := manager.rules[testCase.ipVersion][forwardRuleKey]
require.True(t, found, "forwarding rule should exist in the manager map")
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
require.True(t, exists, "income forwarding rule should exist")
foundRule, found = manager.rules[testCase.ipVersion][inForwardRuleKey]
require.True(t, found, "income forwarding rule should exist in the manager map")
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
if testCase.inputPair.masquerade {
require.True(t, exists, "nat rule should be created")
foundNatRule, foundNat := manager.rules[testCase.ipVersion][natRuleKey]
require.True(t, foundNat, "nat rule should exist in the map")
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[testCase.ipVersion][natRuleKey]
require.False(t, foundNat, "nat rule should not exist in the map")
}
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
if testCase.inputPair.masquerade {
require.True(t, exists, "income nat rule should be created")
foundNatRule, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey]
require.True(t, foundNat, "income nat rule should exist in the map")
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey]
require.False(t, foundNat, "income nat rule should not exist in the map")
}
})
}
}
func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
for _, testCase := range removeRuleTestCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
iptablesClient := ipv4Client
if testCase.ipVersion == ipv6 {
iptablesClient = ipv6Client
}
manager := &iptablesManager{
ctx: ctx,
stop: cancel,
ipv4Client: ipv4Client,
ipv6Client: ipv6Client,
rules: make(map[string]map[string][]string),
}
defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
require.NoError(t, err, "inserting rule should not return error")
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, inForwardRule...)
require.NoError(t, err, "inserting rule should not return error")
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
delete(manager.rules, ipv4)
delete(manager.rules, ipv6)
err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveRoutingRules(testCase.inputPair)
require.NoError(t, err, "shouldn't return error")
exists, err := iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
require.False(t, exists, "forwarding rule should not exist")
_, found := manager.rules[testCase.ipVersion][forwardRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
require.False(t, exists, "income forwarding rule should not exist")
_, found = manager.rules[testCase.ipVersion][inForwardRuleKey]
require.False(t, found, "income forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
require.False(t, exists, "nat rule should not exist")
_, found = manager.rules[testCase.ipVersion][natRuleKey]
require.False(t, found, "nat rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
require.False(t, exists, "income nat rule should not exist")
_, found = manager.rules[testCase.ipVersion][inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map")
})
}
}

View File

@@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface"
@@ -19,6 +20,7 @@ type Manager interface {
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
EnableServerRouter(firewall firewall.Manager) error
Stop()
}
@@ -35,19 +37,12 @@ type DefaultManager struct {
notifier *notifier
}
// NewManager returns a new route manager
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
srvRouter, err := newServerRouter(ctx, wgInterface)
if err != nil {
log.Errorf("server router is not supported: %s", err)
}
mCTX, cancel := context.WithCancel(ctx)
dm := &DefaultManager{
ctx: mCTX,
stop: cancel,
clientNetworks: make(map[string]*clientNetwork),
serverRouter: srvRouter,
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
@@ -61,6 +56,15 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
return dm
}
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
var err error
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall)
if err != nil {
return err
}
return nil
}
// Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop() {
m.stop()

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
@@ -37,6 +38,10 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList
}
func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error {
panic("implement me")
}
// Stop mock implementation of Stop from Manager interface
func (m *MockManager) Stop() {
if m.StopFunc != nil {

View File

@@ -1,571 +0,0 @@
//go:build !android
package routemanager
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
)
const (
nftablesTable = "netbird-rt"
nftablesRoutingForwardingChain = "netbird-rt-fwd"
nftablesRoutingNatChain = "netbird-rt-nat"
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
)
// constants needed to create nftable rules
const (
ipv4Len = 4
ipv4SrcOffset = 12
ipv4DestOffset = 16
ipv6Len = 16
ipv6SrcOffset = 8
ipv6DestOffset = 24
exprDirectionSource = "source"
exprDirectionDestination = "destination"
)
// some presets for building nftable rules
var (
zeroXor = binaryutil.NativeEndian.PutUint32(0)
zeroXor6 = append(binaryutil.NativeEndian.PutUint64(0), binaryutil.NativeEndian.PutUint64(0)...)
exprAllowRelatedEstablished = []expr.Any{
&expr.Ct{
Register: 1,
SourceRegister: false,
Key: 0,
},
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Mask: []uint8{0x6, 0x0, 0x0, 0x0},
Xor: zeroXor,
},
&expr.Cmp{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
exprCounterAccept = []expr.Any{
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
)
type nftablesManager struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
tableIPv4 *nftables.Table
tableIPv6 *nftables.Table
chains map[string]map[string]*nftables.Chain
rules map[string]*nftables.Rule
filterTable *nftables.Table
defaultForwardRules []*nftables.Rule
mux sync.Mutex
}
func newNFTablesManager(parentCtx context.Context) *nftablesManager {
ctx, cancel := context.WithCancel(parentCtx)
return &nftablesManager{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
chains: make(map[string]map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
defaultForwardRules: make([]*nftables.Rule, 2),
}
}
// CleanRoutingRules cleans existing nftables rules from the system
func (n *nftablesManager) CleanRoutingRules() {
n.mux.Lock()
defer n.mux.Unlock()
log.Debug("flushing tables")
if n.tableIPv4 != nil && n.tableIPv6 != nil {
n.conn.FlushTable(n.tableIPv6)
n.conn.FlushTable(n.tableIPv4)
}
if n.defaultForwardRules[0] != nil {
err := n.eraseDefaultForwardRule()
if err != nil {
log.Errorf("failed to delete forward rule: %s", err)
}
}
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
}
// RestoreOrCreateContainers restores existing nftables containers (tables and chains)
// if they don't exist, we create them
func (n *nftablesManager) RestoreOrCreateContainers() error {
n.mux.Lock()
defer n.mux.Unlock()
if n.tableIPv6 != nil && n.tableIPv4 != nil {
log.Debugf("nftables: containers already restored, skipping")
return nil
}
tables, err := n.conn.ListTables()
if err != nil {
return fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
if table.Name == "filter" && table.Family == nftables.TableFamilyIPv4 {
log.Debugf("nftables: found filter table for ipv4")
n.filterTable = table
continue
}
if table.Name == nftablesTable {
if table.Family == nftables.TableFamilyIPv4 {
n.tableIPv4 = table
continue
}
n.tableIPv6 = table
}
}
if n.tableIPv4 == nil {
n.tableIPv4 = n.conn.AddTable(&nftables.Table{
Name: nftablesTable,
Family: nftables.TableFamilyIPv4,
})
}
if n.tableIPv6 == nil {
n.tableIPv6 = n.conn.AddTable(&nftables.Table{
Name: nftablesTable,
Family: nftables.TableFamilyIPv6,
})
}
chains, err := n.conn.ListChains()
if err != nil {
return fmt.Errorf("nftables: unable to list chains: %v", err)
}
n.chains[ipv4] = make(map[string]*nftables.Chain)
n.chains[ipv6] = make(map[string]*nftables.Chain)
for _, chain := range chains {
switch {
case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv4:
n.chains[ipv4][chain.Name] = chain
case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv6:
n.chains[ipv6][chain.Name] = chain
}
}
if _, found := n.chains[ipv4][nftablesRoutingForwardingChain]; !found {
n.chains[ipv4][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{
Name: nftablesRoutingForwardingChain,
Table: n.tableIPv4,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityNATDest + 1,
Type: nftables.ChainTypeFilter,
})
}
if _, found := n.chains[ipv4][nftablesRoutingNatChain]; !found {
n.chains[ipv4][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{
Name: nftablesRoutingNatChain,
Table: n.tableIPv4,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1,
Type: nftables.ChainTypeNAT,
})
}
if _, found := n.chains[ipv6][nftablesRoutingForwardingChain]; !found {
n.chains[ipv6][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{
Name: nftablesRoutingForwardingChain,
Table: n.tableIPv6,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityNATDest + 1,
Type: nftables.ChainTypeFilter,
})
}
if _, found := n.chains[ipv6][nftablesRoutingNatChain]; !found {
n.chains[ipv6][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{
Name: nftablesRoutingNatChain,
Table: n.tableIPv6,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1,
Type: nftables.ChainTypeNAT,
})
}
err = n.refreshRulesMap()
if err != nil {
return err
}
n.checkOrCreateDefaultForwardingRules()
err = n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (n *nftablesManager) refreshRulesMap() error {
for _, registeredChains := range n.chains {
for _, chain := range registeredChains {
rules, err := n.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
n.rules[string(rule.UserData)] = rule
}
}
}
}
return nil
}
func (n *nftablesManager) eraseDefaultForwardRule() error {
if n.defaultForwardRules[0] == nil {
return nil
}
err := n.refreshDefaultForwardRule()
if err != nil {
return err
}
for i, r := range n.defaultForwardRules {
err = n.conn.DelRule(r)
if err != nil {
log.Errorf("failed to delete forward rule (%d): %s", i, err)
}
n.defaultForwardRules[i] = nil
}
return nil
}
func (n *nftablesManager) refreshDefaultForwardRule() error {
rules, err := n.conn.GetRules(n.defaultForwardRules[0].Table, n.defaultForwardRules[0].Chain)
if err != nil {
return fmt.Errorf("unable to list rules in forward chain: %s", err)
}
found := false
for i, r := range n.defaultForwardRules {
for _, rule := range rules {
if string(rule.UserData) == string(r.UserData) {
n.defaultForwardRules[i] = rule
found = true
break
}
}
}
if !found {
return fmt.Errorf("unable to find forward accept rule")
}
return nil
}
func (n *nftablesManager) acceptForwardRule(sourceNetwork string) error {
src := generateCIDRMatcherExpressions("source", sourceNetwork)
dst := generateCIDRMatcherExpressions("destination", "0.0.0.0/0")
var exprs []expr.Any
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
Kind: expr.VerdictAccept,
})...)
r := &nftables.Rule{
Table: n.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: n.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleSrc),
}
n.defaultForwardRules[0] = n.conn.AddRule(r)
src = generateCIDRMatcherExpressions("source", "0.0.0.0/0")
dst = generateCIDRMatcherExpressions("destination", sourceNetwork)
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
Kind: expr.VerdictAccept,
})...)
r = &nftables.Rule{
Table: n.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: n.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleDst),
}
n.defaultForwardRules[1] = n.conn.AddRule(r)
return nil
}
// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
_, foundIPv4 := n.rules[ipv4Forwarding]
if !foundIPv4 {
n.rules[ipv4Forwarding] = n.conn.AddRule(&nftables.Rule{
Table: n.tableIPv4,
Chain: n.chains[ipv4][nftablesRoutingForwardingChain],
Exprs: exprAllowRelatedEstablished,
UserData: []byte(ipv4Forwarding),
})
}
_, foundIPv6 := n.rules[ipv6Forwarding]
if !foundIPv6 {
n.rules[ipv6Forwarding] = n.conn.AddRule(&nftables.Rule{
Table: n.tableIPv6,
Chain: n.chains[ipv6][nftablesRoutingForwardingChain],
Exprs: exprAllowRelatedEstablished,
UserData: []byte(ipv6Forwarding),
})
}
}
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
n.mux.Lock()
defer n.mux.Unlock()
err := n.refreshRulesMap()
if err != nil {
return err
}
err = n.insertRoutingRule(forwardingFormat, nftablesRoutingForwardingChain, pair, false)
if err != nil {
return err
}
err = n.insertRoutingRule(inForwardingFormat, nftablesRoutingForwardingChain, getInPair(pair), false)
if err != nil {
return err
}
if pair.masquerade {
err = n.insertRoutingRule(natFormat, nftablesRoutingNatChain, pair, true)
if err != nil {
return err
}
err = n.insertRoutingRule(inNatFormat, nftablesRoutingNatChain, getInPair(pair), true)
if err != nil {
return err
}
}
if n.defaultForwardRules[0] == nil && n.filterTable != nil {
err = n.acceptForwardRule(pair.source)
if err != nil {
log.Errorf("unable to create default forward rule: %s", err)
}
log.Debugf("default accept forward rule added")
}
err = n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
}
return nil
}
// insertRoutingRule inserts a nftable rule to the conn client flush queue
func (n *nftablesManager) insertRoutingRule(format, chain string, pair routerPair, isNat bool) error {
prefix := netip.MustParsePrefix(pair.source)
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
var expression []expr.Any
if isNat {
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
} else {
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
}
ruleKey := genKey(format, pair.ID)
_, exists := n.rules[ruleKey]
if exists {
err := n.removeRoutingRule(format, pair)
if err != nil {
return err
}
}
if prefix.Addr().Unmap().Is4() {
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv4,
Chain: n.chains[ipv4][chain],
Exprs: expression,
UserData: []byte(ruleKey),
})
} else {
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv6,
Chain: n.chains[ipv6][chain],
Exprs: expression,
UserData: []byte(ruleKey),
})
}
return nil
}
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
n.mux.Lock()
defer n.mux.Unlock()
err := n.refreshRulesMap()
if err != nil {
return err
}
err = n.removeRoutingRule(forwardingFormat, pair)
if err != nil {
return err
}
err = n.removeRoutingRule(inForwardingFormat, getInPair(pair))
if err != nil {
return err
}
err = n.removeRoutingRule(natFormat, pair)
if err != nil {
return err
}
err = n.removeRoutingRule(inNatFormat, getInPair(pair))
if err != nil {
return err
}
if len(n.rules) == 2 && n.defaultForwardRules[0] != nil {
err := n.eraseDefaultForwardRule()
if err != nil {
log.Errorf("failed to delete default fwd rule: %s", err)
}
}
err = n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
}
log.Debugf("nftables: removed rules for %s", pair.destination)
return nil
}
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) error {
ruleKey := genKey(format, pair.ID)
rule, found := n.rules[ruleKey]
if found {
ruleType := "forwarding"
if rule.Chain.Type == nftables.ChainTypeNAT {
ruleType = "nat"
}
err := n.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.destination, err)
}
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.destination)
delete(n.rules, ruleKey)
}
return nil
}
// getPayloadDirectives get expression directives based on ip version and direction
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
switch {
case direction == exprDirectionSource && isIPv4:
return ipv4SrcOffset, ipv4Len, zeroXor
case direction == exprDirectionDestination && isIPv4:
return ipv4DestOffset, ipv4Len, zeroXor
case direction == exprDirectionSource && isIPv6:
return ipv6SrcOffset, ipv6Len, zeroXor6
case direction == exprDirectionDestination && isIPv6:
return ipv6DestOffset, ipv6Len, zeroXor6
default:
panic("no matched payload directive")
}
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(direction string, cidr string) []expr.Any {
ip, network, _ := net.ParseCIDR(cidr)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
offSet, packetLen, zeroXor := getPayloadDirectives(direction, add.Is4(), add.Is6())
return []expr.Any{
// fetch src add
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offSet,
Len: packetLen,
},
// net mask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: packetLen,
Mask: network.Mask,
Xor: zeroXor,
},
// net address
&expr.Cmp{
Register: 1,
Data: add.AsSlice(),
},
}
}

View File

@@ -1,324 +0,0 @@
//go:build !android
package routemanager
import (
"context"
"testing"
"github.com/google/nftables"
"github.com/google/nftables/expr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/checkfw"
)
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
if checkfw.Check() != checkfw.NFTABLES {
t.Skip("nftables not supported on this OS")
}
manager := newNFTablesManager(context.TODO())
nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4")
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6")
require.Len(t, manager.rules, 2, "should have created rules for ipv4 and ipv6")
pair := routerPair{
ID: "abc",
source: "100.100.100.1/32",
destination: "100.100.100.0/24",
masquerade: true,
}
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
forward4Exp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
forward4RuleKey := genKey(forwardingFormat, pair.ID)
inserted4Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.tableIPv4,
Chain: manager.chains[ipv4][nftablesRoutingForwardingChain],
Exprs: forward4Exp,
UserData: []byte(forward4RuleKey),
})
nat4Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
nat4RuleKey := genKey(natFormat, pair.ID)
inserted4Nat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.tableIPv4,
Chain: manager.chains[ipv4][nftablesRoutingNatChain],
Exprs: nat4Exp,
UserData: []byte(nat4RuleKey),
})
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
pair = routerPair{
ID: "xyz",
source: "fc00::1/128",
destination: "fc11::/64",
masquerade: true,
}
sourceExp = generateCIDRMatcherExpressions("source", pair.source)
destExp = generateCIDRMatcherExpressions("destination", pair.destination)
forward6Exp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
forward6RuleKey := genKey(forwardingFormat, pair.ID)
inserted6Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.tableIPv6,
Chain: manager.chains[ipv6][nftablesRoutingForwardingChain],
Exprs: forward6Exp,
UserData: []byte(forward6RuleKey),
})
nat6Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
nat6RuleKey := genKey(natFormat, pair.ID)
inserted6Nat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.tableIPv6,
Chain: manager.chains[ipv6][nftablesRoutingNatChain],
Exprs: nat6Exp,
UserData: []byte(nat6RuleKey),
})
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
manager.tableIPv4 = nil
manager.tableIPv6 = nil
err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4")
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6")
require.Len(t, manager.rules, 6, "should have restored all rules for ipv4 and ipv6")
foundRule, found := manager.rules[forward4RuleKey]
require.True(t, found, "forwarding rule should exist in the map")
assert.Equal(t, inserted4Forwarding.Exprs, foundRule.Exprs, "stored forwarding rule expressions should match")
foundRule, found = manager.rules[nat4RuleKey]
require.True(t, found, "nat rule should exist in the map")
// match len of output as nftables client doesn't return expressions with masquerade expression
assert.ElementsMatch(t, inserted4Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule expressions should match")
foundRule, found = manager.rules[forward6RuleKey]
require.True(t, found, "forwarding rule should exist in the map")
assert.Equal(t, inserted6Forwarding.Exprs, foundRule.Exprs, "stored forward rule should match")
foundRule, found = manager.rules[nat6RuleKey]
require.True(t, found, "nat rule should exist in the map")
// match len of output as nftables client doesn't return expressions with masquerade expression
assert.ElementsMatch(t, inserted6Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule should match")
}
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
if checkfw.Check() != checkfw.NFTABLES {
t.Skip("nftables not supported on this OS")
}
for _, testCase := range insertRuleTestCases {
t.Run(testCase.name, func(t *testing.T) {
manager := newNFTablesManager(context.TODO())
nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
err = manager.InsertRoutingRules(testCase.inputPair)
require.NoError(t, err, "forwarding pair should be inserted")
sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source)
destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
fwdRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
found := 0
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.inputPair.masquerade {
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
found := 0
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
testingExpression = append(sourceExp, destExp...) //nolint:gocritic
inFwdRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
found = 0
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.inputPair.masquerade {
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
found := 0
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
})
}
}
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
if checkfw.Check() != checkfw.NFTABLES {
t.Skip("nftables not supported on this OS")
}
for _, testCase := range removeRuleTestCases {
t.Run(testCase.name, func(t *testing.T) {
manager := newNFTablesManager(context.TODO())
nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
table := manager.tableIPv4
if testCase.ipVersion == ipv6 {
table = manager.tableIPv6
}
sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source)
destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination)
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: table,
Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(forwardRuleKey),
})
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: table,
Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain],
Exprs: natExp,
UserData: []byte(natRuleKey),
})
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: table,
Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(inForwardRuleKey),
})
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: table,
Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain],
Exprs: natExp,
UserData: []byte(inNatRuleKey),
})
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
manager.tableIPv4 = nil
manager.tableIPv6 = nil
err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveRoutingRules(testCase.inputPair)
require.NoError(t, err, "shouldn't return error")
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 {
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
}
}
}
}
})
}
}

View File

@@ -1,24 +0,0 @@
package routemanager
import (
"net/netip"
"github.com/netbirdio/netbird/route"
)
type routerPair struct {
ID string
source string
destination string
masquerade bool
}
func routeToRouterPair(source string, route *route.Route) routerPair {
parsed := netip.MustParsePrefix(source).Masked()
return routerPair{
ID: route.ID,
source: parsed.String(),
destination: route.Network.Masked().String(),
masquerade: route.Masquerade,
}
}

View File

@@ -4,9 +4,10 @@ import (
"context"
"fmt"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface"
)
func newServerRouter(context.Context, *iface.WGIface) (serverRouter, error) {
func newServerRouter(context.Context, *iface.WGIface, firewall.Manager) (serverRouter, error) {
return nil, fmt.Errorf("server route not supported on this os")
}

View File

@@ -4,11 +4,12 @@ package routemanager
import (
"context"
"fmt"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
@@ -17,16 +18,11 @@ type defaultServerRouter struct {
mux sync.Mutex
ctx context.Context
routes map[string]*route.Route
firewall firewallManager
firewall firewall.Manager
wgInterface *iface.WGIface
}
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRouter, error) {
firewall, err := newFirewall(ctx)
if err != nil {
return nil, err
}
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager) (serverRouter, error) {
return &defaultServerRouter{
ctx: ctx,
routes: make(map[string]*route.Route),
@@ -38,13 +34,6 @@ func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRou
func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) error {
serverRoutesToRemove := make([]string, 0)
if len(routesMap) > 0 {
err := m.firewall.RestoreOrCreateContainers()
if err != nil {
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
}
}
for routeID := range m.routes {
update, found := routesMap[routeID]
if !found || !update.IsEqual(m.routes[routeID]) {
@@ -121,5 +110,22 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
}
func (m *defaultServerRouter) cleanUp() {
m.firewall.CleanRoutingRules()
m.mux.Lock()
defer m.mux.Unlock()
for _, r := range m.routes {
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), r))
if err != nil {
log.Warnf("failed to remove clean up route: %s", r.ID)
}
}
}
func routeToRouterPair(source string, route *route.Route) firewall.RouterPair {
parsed := netip.MustParsePrefix(source).Masked()
return firewall.RouterPair{
ID: route.ID,
Source: parsed.String(),
Destination: route.Network.Masked().String(),
Masquerade: route.Masquerade,
}
}

View File

@@ -1,3 +1,5 @@
//go:build !android
package routemanager
import (

3
go.mod
View File

@@ -30,6 +30,7 @@ require (
require (
fyne.io/fyne/v2 v2.1.4
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5
github.com/c-robinson/iplib v1.0.3
github.com/cilium/ebpf v0.10.0
github.com/coreos/go-iptables v0.7.0
@@ -109,6 +110,7 @@ require (
github.com/go-stack/stack v1.8.0 // indirect
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/btree v1.0.1 // indirect
github.com/google/s2a-go v0.1.4 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect
github.com/googleapis/gax-go/v2 v2.10.0 // indirect
@@ -154,6 +156,7 @@ require (
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 // indirect
honnef.co/go/tools v0.2.2 // indirect
k8s.io/apimachinery v0.23.5 // indirect
)

2
go.sum
View File

@@ -76,6 +76,8 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
github.com/bazelbuild/rules_go v0.30.0/go.mod h1:MC23Dc/wkXEyk3Wpq6lCqz0ZAYOZDw2DR5y3N1q2i7M=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=

View File

@@ -20,10 +20,10 @@ func NewWGIFace(iFaceName string, address string, mtu int, tunAdapter TunAdapter
return wgIFace, err
}
wgIFace.tun = newTunDevice(iFaceName, wgAddress, mtu, transportNet)
wgIFace.configurer = newWGConfigurer(iFaceName)
wgIFace.userspaceBind = !WireGuardModuleIsLoaded()
tun := newTunDevice(iFaceName, wgAddress, mtu, transportNet)
wgIFace.tun = tun
wgIFace.configurer = newWGConfigurer(tun)
wgIFace.userspaceBind = true
return wgIFace, nil
}

View File

@@ -3,7 +3,6 @@
package iface
import (
"fmt"
"os"
log "github.com/sirupsen/logrus"
@@ -11,14 +10,6 @@ import (
)
func (c *tunDevice) Create() error {
if WireGuardModuleIsLoaded() {
log.Infof("create tun interface with kernel WireGuard support: %s", c.DeviceName())
return c.createWithKernel()
}
if !tunModuleIsLoaded() {
return fmt.Errorf("couldn't check or load tun module")
}
log.Infof("create tun interface with userspace WireGuard support: %s", c.DeviceName())
var err error
c.netInterface, err = c.createWithUserspace()
@@ -27,7 +18,6 @@ func (c *tunDevice) Create() error {
}
return c.assignAddr()
}
// createWithKernel Creates a new WireGuard interface using kernel WireGuard module.
@@ -90,34 +80,7 @@ func (c *tunDevice) createWithKernel() error {
// assignAddr Adds IP address to the tunnel interface
func (c *tunDevice) assignAddr() error {
link := newWGLink(c.name)
//delete existing addresses
list, err := netlink.AddrList(link, 0)
if err != nil {
return err
}
if len(list) > 0 {
for _, a := range list {
addr := a
err = netlink.AddrDel(link, &addr)
if err != nil {
return err
}
}
}
log.Debugf("adding address %s to interface: %s", c.address.String(), c.name)
addr, _ := netlink.ParseAddr(c.address.String())
err = netlink.AddrAdd(link, addr)
if os.IsExist(err) {
log.Infof("interface %s already has the address: %s", c.name, c.address.String())
} else if err != nil {
return err
}
// On linux, the link must be brought up
err = netlink.LinkSetUp(link)
return err
return nil
}
type wgLink struct {

View File

@@ -10,10 +10,9 @@ import (
"golang.zx2c4.com/wireguard/ipc"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/iface/uspproxy"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
)
type tunDevice struct {
@@ -25,6 +24,7 @@ type tunDevice struct {
uapi net.Listener
wrapper *DeviceWrapper
close chan struct{}
tunDevice *device.Device
}
func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice {
@@ -46,6 +46,11 @@ func (c *tunDevice) WgAddress() WGAddress {
return c.address
}
func (t *tunDevice) Device() *device.Device {
// todo: potential nil pointer exception
return t.tunDevice
}
func (c *tunDevice) DeviceName() string {
return c.name
}
@@ -85,15 +90,14 @@ func (c *tunDevice) Close() error {
return err3
}
// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
func (c *tunDevice) createWithUserspace() (NetInterface, error) {
tunIface, err := tun.CreateTUN(c.name, c.mtu)
nsTun := uspproxy.NewNetStackTun(c.address.IP.String())
tunIface, err := nsTun.Create()
if err != nil {
return nil, err
}
c.wrapper = newDeviceWrapper(tunIface)
// We need to create a wireguard-go device and listen to configuration requests
tunDev := device.NewDevice(
c.wrapper,
c.iceBind,
@@ -105,33 +109,7 @@ func (c *tunDevice) createWithUserspace() (NetInterface, error) {
return nil, err
}
c.uapi, err = c.getUAPI(c.name)
if err != nil {
_ = tunIface.Close()
return nil, err
}
go func() {
for {
select {
case <-c.close:
log.Debugf("exit uapi.Accept()")
return
default:
}
uapiConn, uapiErr := c.uapi.Accept()
if uapiErr != nil {
log.Traceln("uapi Accept failed with error: ", uapiErr)
continue
}
go func() {
tunDev.IpcHandle(uapiConn)
log.Debugf("exit tunDevice.IpcHandle")
}()
}
}()
log.Debugln("UAPI listener started")
c.tunDevice = tunDev
return tunIface, nil
}

43
iface/uspproxy/proxy.go Normal file
View File

@@ -0,0 +1,43 @@
package uspproxy
import (
"context"
"net"
"github.com/armon/go-socks5"
log "github.com/sirupsen/logrus"
)
type Dialer interface {
Dial(ctx context.Context, network, addr string) (net.Conn, error)
}
// Proxy todo close server
type Proxy struct {
server *socks5.Server
}
func NewSocks5(dialer Dialer) (*Proxy, error) {
conf := &socks5.Config{
Dial: dialer.Dial,
}
server, err := socks5.New(conf)
if err != nil {
log.Debugf("failed to init socks5 proxy: %s", err)
return nil, err
}
return &Proxy{
server: server,
}, nil
}
func (s *Proxy) ListenAndServe(addr string) error {
go func() {
err := s.server.ListenAndServe("tcp", addr)
if err != nil {
log.Debugf("failed to start socks5 proxy: %s", err)
}
}()
return nil
}

55
iface/uspproxy/tun.go Normal file
View File

@@ -0,0 +1,55 @@
package uspproxy
import (
"context"
"net"
"net/netip"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
)
type NSDialer struct {
net *netstack.Net
}
func (d *NSDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err := d.net.Dial(network, addr)
if err != nil {
log.Debugf("failed to deal connection: %s", err)
}
return conn, err
}
type NetStackTun struct {
address string
proxy *Proxy
}
func NewNetStackTun(address string) *NetStackTun {
return &NetStackTun{
address: address,
}
}
func (t *NetStackTun) Create() (tun.Device, error) {
nsTunDev, tunNet, err := netstack.CreateNetTUN(
[]netip.Addr{netip.MustParseAddr(t.address)},
[]netip.Addr{netip.MustParseAddr("8.8.8.8")},
1420)
if err != nil {
return nil, err
}
dialer := &NSDialer{tunNet}
t.proxy, err = NewSocks5(dialer)
if err != nil {
// close nsTunDev
return nil, err
}
err = t.proxy.ListenAndServe("0.0.0.0:1234")
return nsTunDev, nil
}

View File

@@ -1,205 +0,0 @@
//go:build !android
package iface
import (
"fmt"
"net"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type wGConfigurer struct {
deviceName string
}
func newWGConfigurer(deviceName string) wGConfigurer {
return wGConfigurer{
deviceName: deviceName,
}
}
func (c *wGConfigurer) configureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)
if err != nil {
return err
}
fwmark := 0
config := wgtypes.Config{
PrivateKey: &key,
ReplacePeers: true,
FirewallMark: &fwmark,
ListenPort: &port,
}
err = c.configure(config)
if err != nil {
return fmt.Errorf(`received error "%w" while configuring interface %s with port %d`, err, c.deviceName, port)
}
return nil
}
func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
//parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,
Endpoint: endpoint,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = c.configure(config)
if err != nil {
return fmt.Errorf(`received error "%w" while updating peer on interface %s with settings: allowed ips %s, endpoint %s`, err, c.deviceName, allowedIps, endpoint.String())
}
return nil
}
func (c *wGConfigurer) removePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = c.configure(config)
if err != nil {
return fmt.Errorf(`received error "%w" while removing peer %s from interface %s`, err, peerKey, c.deviceName)
}
return nil
}
func (c *wGConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{*ipNet},
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = c.configure(config)
if err != nil {
return fmt.Errorf(`received error "%w" while adding allowed Ip to peer on interface %s with settings: allowed ips %s`, err, c.deviceName, allowedIP)
}
return nil
}
func (c *wGConfigurer) removeAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
existingPeer, err := c.getPeer(c.deviceName, peerKey)
if err != nil {
return err
}
newAllowedIPs := existingPeer.AllowedIPs
for i, existingAllowedIP := range existingPeer.AllowedIPs {
if existingAllowedIP.String() == ipNet.String() {
newAllowedIPs = append(existingPeer.AllowedIPs[:i], existingPeer.AllowedIPs[i+1:]...) //nolint:gocritic
break
}
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: true,
AllowedIPs: newAllowedIPs,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = c.configure(config)
if err != nil {
return fmt.Errorf(`received error "%w" while removing allowed IP from peer on interface %s with settings: allowed ips %s`, err, c.deviceName, allowedIP)
}
return nil
}
func (c *wGConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
wg, err := wgctrl.New()
if err != nil {
return wgtypes.Peer{}, err
}
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(ifaceName)
if err != nil {
return wgtypes.Peer{}, err
}
for _, peer := range wgDevice.Peers {
if peer.PublicKey.String() == peerPubKey {
return peer, nil
}
}
return wgtypes.Peer{}, fmt.Errorf("peer not found")
}
func (c *wGConfigurer) configure(config wgtypes.Config) error {
wg, err := wgctrl.New()
if err != nil {
return err
}
defer wg.Close()
// validate if device with name exists
_, err = wg.Device(c.deviceName)
if err != nil {
return err
}
log.Tracef("got Wireguard device %s", c.deviceName)
return wg.ConfigureDevice(c.deviceName, config)
}

View File

@@ -83,7 +83,10 @@ var (
if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
}
config.HttpConfig.IdpSignKeyRefreshEnabled = idpSignKeyRefreshEnabled
if cmd.Flag(idpSignKeyRefreshEnabledFlagName).Changed {
config.HttpConfig.IdpSignKeyRefreshEnabled = idpSignKeyRefreshEnabled
}
tlsEnabled := false
if mgmtLetsencryptDomain != "" || (config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "") {

View File

@@ -12,7 +12,8 @@ import (
const (
// ExitSetupFailed defines exit code
ExitSetupFailed = 1
ExitSetupFailed = 1
idpSignKeyRefreshEnabledFlagName = "idp-sign-key-refresh-enabled"
)
var (
@@ -62,7 +63,7 @@ func init() {
mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max length is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain))
mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.")
mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, idpSignKeyRefreshEnabledFlagName, false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.")
mgmtCmd.Flags().BoolVar(&userDeleteFromIDPEnabled, "user-delete-from-idp", false, "Allows to delete user from IDP when user is deleted from account")
rootCmd.MarkFlagRequired("config") //nolint

View File

@@ -164,6 +164,9 @@ type Settings struct {
// JWTGroupsClaimName from which we extract groups name to add it to account groups
JWTGroupsClaimName string
// JWTAllowGroups list of groups to which users are allowed access
JWTAllowGroups []string `gorm:"serializer:json"`
// Extra is a dictionary of Account settings
Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"`
}
@@ -176,6 +179,7 @@ func (s *Settings) Copy() *Settings {
JWTGroupsEnabled: s.JWTGroupsEnabled,
JWTGroupsClaimName: s.JWTGroupsClaimName,
GroupsPropagationEnabled: s.GroupsPropagationEnabled,
JWTAllowGroups: s.JWTAllowGroups,
}
if s.Extra != nil {
settings.Extra = s.Extra.Copy()

View File

@@ -91,6 +91,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
if req.Settings.JwtGroupsClaimName != nil {
settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
}
if req.Settings.JwtAllowGroups != nil {
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
}
updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings)
if err != nil {
@@ -128,12 +131,18 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
}
func toAccountResponse(account *server.Account) *api.Account {
jwtAllowGroups := account.Settings.JWTAllowGroups
if jwtAllowGroups == nil {
jwtAllowGroups = []string{}
}
settings := api.AccountSettings{
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled,
JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled,
JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName,
JwtAllowGroups: &jwtAllowGroups,
}
if account.Settings.Extra != nil {

View File

@@ -95,6 +95,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
GroupsPropagationEnabled: br(false),
JwtGroupsClaimName: sr(""),
JwtGroupsEnabled: br(false),
JwtAllowGroups: &[]string{},
},
expectedArray: true,
expectedID: accountID,
@@ -112,6 +113,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
GroupsPropagationEnabled: br(false),
JwtGroupsClaimName: sr(""),
JwtGroupsEnabled: br(false),
JwtAllowGroups: &[]string{},
},
expectedArray: false,
expectedID: accountID,
@@ -121,7 +123,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
expectedBody: true,
requestType: http.MethodPut,
requestPath: "/api/accounts/" + accountID,
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\"}}"),
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"]}}"),
expectedStatus: http.StatusOK,
expectedSettings: api.AccountSettings{
PeerLoginExpiration: 15552000,
@@ -129,6 +131,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
GroupsPropagationEnabled: br(false),
JwtGroupsClaimName: sr("roles"),
JwtGroupsEnabled: br(true),
JwtAllowGroups: &[]string{"test"},
},
expectedArray: false,
expectedID: accountID,
@@ -146,6 +149,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
GroupsPropagationEnabled: br(true),
JwtGroupsClaimName: sr("groups"),
JwtGroupsEnabled: br(true),
JwtAllowGroups: &[]string{},
},
expectedArray: false,
expectedID: accountID,

View File

@@ -66,6 +66,12 @@ components:
description: Name of the claim from which we extract groups names to add it to account groups.
type: string
example: "roles"
jwt_allow_groups:
description: List of groups to which users are allowed access
type: array
items:
type: string
example: Administrators
extra:
$ref: '#/components/schemas/AccountExtraSettings'
required:
@@ -115,11 +121,11 @@ components:
format: date-time
example: 2023-05-05T09:00:35.477782Z
auto_groups:
description: Groups to auto-assign to peers registered by this user
description: Group IDs to auto-assign to peers registered by this user
type: array
items:
type: string
example: devs
example: ch8i4ug6lnn4g9hqv7m0
is_current:
description: Is true if authenticated user is the same as this user
type: boolean
@@ -154,11 +160,11 @@ components:
type: string
example: admin
auto_groups:
description: Groups to auto-assign to peers registered by this user
description: Group IDs to auto-assign to peers registered by this user
type: array
items:
type: string
example: devs
example: ch8i4ug6lnn4g9hqv7m0
is_blocked:
description: If set to true then user is blocked and can't use the system
type: boolean
@@ -183,11 +189,11 @@ components:
type: string
example: admin
auto_groups:
description: Groups to auto-assign to peers registered by this user
description: Group IDs to auto-assign to peers registered by this user
type: array
items:
type: string
example: devs
example: ch8i4ug6lnn4g9hqv7m0
is_service_user:
description: Is true if this user is a service user
type: boolean
@@ -405,7 +411,7 @@ components:
type: array
items:
type: string
example: "devs"
example: "ch8i4ug6lnn4g9hqv7m0"
updated_at:
description: Setup key last update date
type: string
@@ -460,7 +466,7 @@ components:
type: array
items:
type: string
example: "devs"
example: "ch8i4ug6lnn4g9hqv7m0"
usage_limit:
description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
type: integer
@@ -621,13 +627,13 @@ components:
properties:
sources:
type: array
description: List of source groups
description: List of source group IDs
items:
type: string
example: "ch8i4ug6lnn4g9hqv7m1"
destinations:
type: array
description: List of destination groups
description: List of destination group IDs
items:
type: string
example: "ch8i4ug6lnn4g9hqv7m0"
@@ -645,12 +651,12 @@ components:
- type: object
properties:
sources:
description: Rule source groups
description: Rule source group IDs
type: array
items:
$ref: '#/components/schemas/GroupMinimum'
destinations:
description: Rule destination groups
description: Rule destination group IDs
type: array
items:
$ref: '#/components/schemas/GroupMinimum'
@@ -708,13 +714,13 @@ components:
- type: object
properties:
sources:
description: Policy rule source groups
description: Policy rule source group IDs
type: array
items:
type: string
example: "ch8i4ug6lnn4g9hqv797"
destinations:
description: Policy rule destination groups
description: Policy rule destination group IDs
type: array
items:
type: string
@@ -728,12 +734,12 @@ components:
- type: object
properties:
sources:
description: Policy rule source groups
description: Policy rule source group IDs
type: array
items:
$ref: '#/components/schemas/GroupMinimum'
destinations:
description: Policy rule destination groups
description: Policy rule destination group IDs
type: array
items:
$ref: '#/components/schemas/GroupMinimum'
@@ -834,7 +840,7 @@ components:
type: boolean
example: true
groups:
description: Route group tag groups
description: Group IDs containing routing peers
type: array
items:
type: string
@@ -891,17 +897,17 @@ components:
type: object
properties:
name:
description: Nameserver group name
description: Name of nameserver group name
type: string
maxLength: 40
minLength: 1
example: Google DNS
description:
description: Nameserver group description
description: Description of the nameserver group
type: string
example: Google DNS servers
nameservers:
description: Nameserver group
description: Nameserver list
minLength: 1
maxLength: 2
type: array
@@ -912,17 +918,17 @@ components:
type: boolean
example: true
groups:
description: Nameserver group tag groups
description: Distribution group IDs that defines group of peers that will use this nameserver group
type: array
items:
type: string
example: ch8i4ug6lnn4g9hqv7m0
primary:
description: Nameserver group primary status
description: Defines if a nameserver group is primary that resolves all domains. It should be true only if domains list is empty.
type: boolean
example: true
domains:
description: Nameserver group match domain list
description: Match domain list. It should be empty only if primary is true.
type: array
items:
type: string
@@ -930,7 +936,7 @@ components:
maxLength: 255
example: "example.com"
search_domains_enabled:
description: Nameserver group search domain status for match domains. It should be true only if domains list is not empty.
description: Search domain status for match domains. It should be true only if domains list is not empty.
type: boolean
example: true
required:

View File

@@ -160,6 +160,9 @@ type AccountSettings struct {
// GroupsPropagationEnabled Allows propagate the new user auto groups to peers that belongs to the user
GroupsPropagationEnabled *bool `json:"groups_propagation_enabled,omitempty"`
// JwtAllowGroups List of groups to which users are allowed access
JwtAllowGroups *[]string `json:"jwt_allow_groups,omitempty"`
// JwtGroupsClaimName Name of the claim from which we extract groups names to add it to account groups.
JwtGroupsClaimName *string `json:"jwt_groups_claim_name,omitempty"`
@@ -271,58 +274,58 @@ type NameserverNsType string
// NameserverGroup defines model for NameserverGroup.
type NameserverGroup struct {
// Description Nameserver group description
// Description Description of the nameserver group
Description string `json:"description"`
// Domains Nameserver group match domain list
// Domains Match domain list. It should be empty only if primary is true.
Domains []string `json:"domains"`
// Enabled Nameserver group status
Enabled bool `json:"enabled"`
// Groups Nameserver group tag groups
// Groups Distribution group IDs that defines group of peers that will use this nameserver group
Groups []string `json:"groups"`
// Id Nameserver group ID
Id string `json:"id"`
// Name Nameserver group name
// Name Name of nameserver group name
Name string `json:"name"`
// Nameservers Nameserver group
// Nameservers Nameserver list
Nameservers []Nameserver `json:"nameservers"`
// Primary Nameserver group primary status
// Primary Defines if a nameserver group is primary that resolves all domains. It should be true only if domains list is empty.
Primary bool `json:"primary"`
// SearchDomainsEnabled Nameserver group search domain status for match domains. It should be true only if domains list is not empty.
// SearchDomainsEnabled Search domain status for match domains. It should be true only if domains list is not empty.
SearchDomainsEnabled bool `json:"search_domains_enabled"`
}
// NameserverGroupRequest defines model for NameserverGroupRequest.
type NameserverGroupRequest struct {
// Description Nameserver group description
// Description Description of the nameserver group
Description string `json:"description"`
// Domains Nameserver group match domain list
// Domains Match domain list. It should be empty only if primary is true.
Domains []string `json:"domains"`
// Enabled Nameserver group status
Enabled bool `json:"enabled"`
// Groups Nameserver group tag groups
// Groups Distribution group IDs that defines group of peers that will use this nameserver group
Groups []string `json:"groups"`
// Name Nameserver group name
// Name Name of nameserver group name
Name string `json:"name"`
// Nameservers Nameserver group
// Nameservers Nameserver list
Nameservers []Nameserver `json:"nameservers"`
// Primary Nameserver group primary status
// Primary Defines if a nameserver group is primary that resolves all domains. It should be true only if domains list is empty.
Primary bool `json:"primary"`
// SearchDomainsEnabled Nameserver group search domain status for match domains. It should be true only if domains list is not empty.
// SearchDomainsEnabled Search domain status for match domains. It should be true only if domains list is not empty.
SearchDomainsEnabled bool `json:"search_domains_enabled"`
}
@@ -600,7 +603,7 @@ type PolicyRule struct {
// Description Policy rule friendly description
Description *string `json:"description,omitempty"`
// Destinations Policy rule destination groups
// Destinations Policy rule destination group IDs
Destinations []GroupMinimum `json:"destinations"`
// Enabled Policy rule status
@@ -618,7 +621,7 @@ type PolicyRule struct {
// Protocol Policy rule type of the traffic
Protocol PolicyRuleProtocol `json:"protocol"`
// Sources Policy rule source groups
// Sources Policy rule source group IDs
Sources []GroupMinimum `json:"sources"`
}
@@ -672,7 +675,7 @@ type PolicyRuleUpdate struct {
// Description Policy rule friendly description
Description *string `json:"description,omitempty"`
// Destinations Policy rule destination groups
// Destinations Policy rule destination group IDs
Destinations []string `json:"destinations"`
// Enabled Policy rule status
@@ -690,7 +693,7 @@ type PolicyRuleUpdate struct {
// Protocol Policy rule type of the traffic
Protocol PolicyRuleUpdateProtocol `json:"protocol"`
// Sources Policy rule source groups
// Sources Policy rule source group IDs
Sources []string `json:"sources"`
}
@@ -729,7 +732,7 @@ type Route struct {
// Enabled Route status
Enabled bool `json:"enabled"`
// Groups Route group tag groups
// Groups Group IDs containing routing peers
Groups []string `json:"groups"`
// Id Route Id
@@ -765,7 +768,7 @@ type RouteRequest struct {
// Enabled Route status
Enabled bool `json:"enabled"`
// Groups Route group tag groups
// Groups Group IDs containing routing peers
Groups []string `json:"groups"`
// Masquerade Indicate if peer should masquerade traffic to this route's prefix
@@ -792,7 +795,7 @@ type Rule struct {
// Description Rule friendly description
Description string `json:"description"`
// Destinations Rule destination groups
// Destinations Rule destination group IDs
Destinations []GroupMinimum `json:"destinations"`
// Disabled Rules status
@@ -807,7 +810,7 @@ type Rule struct {
// Name Rule name identifier
Name string `json:"name"`
// Sources Rule source groups
// Sources Rule source group IDs
Sources []GroupMinimum `json:"sources"`
}
@@ -831,7 +834,7 @@ type RuleRequest struct {
// Description Rule friendly description
Description string `json:"description"`
// Destinations List of destination groups
// Destinations List of destination group IDs
Destinations *[]string `json:"destinations,omitempty"`
// Disabled Rules status
@@ -843,7 +846,7 @@ type RuleRequest struct {
// Name Rule name identifier
Name string `json:"name"`
// Sources List of source groups
// Sources List of source group IDs
Sources *[]string `json:"sources,omitempty"`
}
@@ -918,7 +921,7 @@ type SetupKeyRequest struct {
// User defines model for User.
type User struct {
// AutoGroups Groups to auto-assign to peers registered by this user
// AutoGroups Group IDs to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"`
// Email User's email address
@@ -957,7 +960,7 @@ type UserStatus string
// UserCreateRequest defines model for UserCreateRequest.
type UserCreateRequest struct {
// AutoGroups Groups to auto-assign to peers registered by this user
// AutoGroups Group IDs to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"`
// Email User's Email to send invite to
@@ -975,7 +978,7 @@ type UserCreateRequest struct {
// UserRequest defines model for UserRequest.
type UserRequest struct {
// AutoGroups Groups to auto-assign to peers registered by this user
// AutoGroups Group IDs to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"`
// IsBlocked If set to true then user is blocked and can't use the system

View File

@@ -34,12 +34,20 @@ type emptyObject struct {
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
)
authMiddleware := middleware.NewAuthMiddleware(
accountManager.GetAccountFromPAT,
jwtValidator.ValidateAndParse,
accountManager.MarkPATUsed,
accountManager.GetAccountFromToken,
claimsExtractor,
authCfg.Audience,
authCfg.UserIDClaim)
authCfg.UserIDClaim,
)
corsMiddleware := cors.AllowAll()
@@ -60,11 +68,6 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
AuthCfg: authCfg,
}
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
)
integrations.RegisterHandlers(api.Router, accountManager, claimsExtractor)
api.addAccountsEndpoint()
api.addPeersEndpoint()

View File

@@ -26,11 +26,16 @@ type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
// MarkPATUsedFunc function
type MarkPATUsedFunc func(token string) error
// GetAccountFromTokenFunc function
type GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
getAccountFromPAT GetAccountFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc
getAccountFromToken GetAccountFromTokenFunc
claimsExtractor *jwtclaims.ClaimsExtractor
audience string
userIDClaim string
}
@@ -40,14 +45,19 @@ const (
)
// NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string, userIdClaim string) *AuthMiddleware {
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
markPATUsed MarkPATUsedFunc, getAccountFromToken GetAccountFromTokenFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
audience string, userIdClaim string) *AuthMiddleware {
if userIdClaim == "" {
userIdClaim = jwtclaims.UserIDClaim
}
return &AuthMiddleware{
getAccountFromPAT: getAccountFromPAT,
validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed,
getAccountFromToken: getAccountFromToken,
claimsExtractor: claimsExtractor,
audience: audience,
userIDClaim: userIdClaim,
}
@@ -107,6 +117,10 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
return nil
}
if err := m.verifyUserAccess(validatedToken); err != nil {
return err
}
// If we get here, everything worked and we can set the
// user property in context.
newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint
@@ -115,6 +129,41 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
return nil
}
// verifyUserAccess checks if a user, based on a validated JWT token,
// is allowed access, particularly in cases where the admin enabled JWT
// group propagation and designated certain groups with access permissions.
func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error {
authClaims := m.claimsExtractor.FromToken(validatedToken)
account, _, err := m.getAccountFromToken(authClaims)
if err != nil {
return fmt.Errorf("failed to get the account from token: %w", err)
}
// Ensures JWT group synchronization to the management is enabled before,
// filtering access based on the allowed groups.
if account.Settings != nil && account.Settings.JWTGroupsEnabled {
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
userJWTGroups := make([]string, 0)
if claim, ok := authClaims.Raw[account.Settings.JWTGroupsClaimName]; ok {
if claimGroups, ok := claim.([]interface{}); ok {
for _, g := range claimGroups {
if group, ok := g.(string); ok {
userJWTGroups = append(userJWTGroups, group)
}
}
}
}
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
}
}
}
return nil
}
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
token, err := getTokenFromPATRequest(auth)
@@ -168,3 +217,15 @@ func getTokenFromPATRequest(authHeaderParts []string) (string, error) {
return authHeaderParts[1], nil
}
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
for _, userGroup := range userGroups {
for _, allowedGroup := range allowedGroups {
if userGroup == allowedGroup {
return true
}
}
}
return false
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
)
const (
@@ -54,7 +55,13 @@ func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server
func mockValidateAndParseToken(token string) (*jwt.Token, error) {
if token == JWT {
return &jwt.Token{}, nil
return &jwt.Token{
Claims: jwt.MapClaims{
userIDClaim: userID,
audience + jwtclaims.AccountIDSuffix: accountID,
},
Valid: true,
}, nil
}
return nil, fmt.Errorf("JWT invalid")
}
@@ -66,6 +73,19 @@ func mockMarkPATUsed(token string) error {
return fmt.Errorf("Should never get reached")
}
func mockGetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
if testAccount.Id != claims.AccountId {
return nil, nil, fmt.Errorf("account with id %s does not exist", claims.AccountId)
}
user, ok := testAccount.Users[claims.UserId]
if !ok {
return nil, nil, fmt.Errorf("user with id %s does not exist", claims.UserId)
}
return testAccount, user, nil
}
func TestAuthMiddleware_Handler(t *testing.T) {
tt := []struct {
name string
@@ -108,7 +128,20 @@ func TestAuthMiddleware_Handler(t *testing.T) {
// do nothing
})
authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience, userIDClaim)
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(audience),
jwtclaims.WithUserIDClaim(userIDClaim),
)
authMiddleware := NewAuthMiddleware(
mockGetAccountFromPAT,
mockValidateAndParseToken,
mockMarkPATUsed,
mockGetAccountFromToken,
claimsExtractor,
audience,
userIDClaim,
)
handlerToTest := authMiddleware.Handler(nextHandler)

View File

@@ -108,6 +108,8 @@ func NewJWTValidator(issuer string, audienceList []string, keysLocation string,
refreshedKeys = keys
}
log.Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
keys = refreshedKeys
}
}
@@ -179,7 +181,7 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
// stillValid returns true if the JSONWebKey still valid and have enough time to be used
func (jwks *Jwks) stillValid() bool {
return jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
}
func getPemKeys(keysLocation string) (*Jwks, error) {

View File

@@ -366,7 +366,7 @@ func (am *MockAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*se
func (am *MockAccountManager) ListUsers(accountID string) ([]*server.User, error) {
if am.ListUsersFunc != nil {
return am.ListUsers(accountID)
return am.ListUsersFunc(accountID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented")
}
@@ -464,7 +464,7 @@ func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.Us
// SaveOrAddUser mocks SaveOrAddUser of the AccountManager interface
func (am *MockAccountManager) SaveOrAddUser(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) {
if am.SaveUserFunc != nil {
if am.SaveOrAddUserFunc != nil {
return am.SaveOrAddUserFunc(accountID, userID, user, addIfNotExists)
}
return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented")
@@ -545,7 +545,7 @@ func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.Authorization
// GetPeers mocks GetPeers of the AccountManager interface
func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) {
if am.GetAccountFromTokenFunc != nil {
if am.GetPeersFunc != nil {
return am.GetPeersFunc(accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeers is not implemented")

View File

@@ -15,36 +15,44 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
account := &Account{
Peers: map[string]*nbpeer.Peer{
"peerA": {
ID: "peerA",
IP: net.ParseIP("100.65.14.88"),
ID: "peerA",
IP: net.ParseIP("100.65.14.88"),
Status: &nbpeer.PeerStatus{},
},
"peerB": {
ID: "peerB",
IP: net.ParseIP("100.65.80.39"),
ID: "peerB",
IP: net.ParseIP("100.65.80.39"),
Status: &nbpeer.PeerStatus{},
},
"peerC": {
ID: "peerC",
IP: net.ParseIP("100.65.254.139"),
ID: "peerC",
IP: net.ParseIP("100.65.254.139"),
Status: &nbpeer.PeerStatus{},
},
"peerD": {
ID: "peerD",
IP: net.ParseIP("100.65.62.5"),
ID: "peerD",
IP: net.ParseIP("100.65.62.5"),
Status: &nbpeer.PeerStatus{},
},
"peerE": {
ID: "peerE",
IP: net.ParseIP("100.65.32.206"),
ID: "peerE",
IP: net.ParseIP("100.65.32.206"),
Status: &nbpeer.PeerStatus{},
},
"peerF": {
ID: "peerF",
IP: net.ParseIP("100.65.250.202"),
ID: "peerF",
IP: net.ParseIP("100.65.250.202"),
Status: &nbpeer.PeerStatus{},
},
"peerG": {
ID: "peerG",
IP: net.ParseIP("100.65.13.186"),
ID: "peerG",
IP: net.ParseIP("100.65.13.186"),
Status: &nbpeer.PeerStatus{},
},
"peerH": {
ID: "peerH",
IP: net.ParseIP("100.65.29.55"),
ID: "peerH",
IP: net.ParseIP("100.65.29.55"),
Status: &nbpeer.PeerStatus{},
},
},
Groups: map[string]*Group{
@@ -259,16 +267,19 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
account := &Account{
Peers: map[string]*nbpeer.Peer{
"peerA": {
ID: "peerA",
IP: net.ParseIP("100.65.14.88"),
ID: "peerA",
IP: net.ParseIP("100.65.14.88"),
Status: &nbpeer.PeerStatus{},
},
"peerB": {
ID: "peerB",
IP: net.ParseIP("100.65.80.39"),
ID: "peerB",
IP: net.ParseIP("100.65.80.39"),
Status: &nbpeer.PeerStatus{},
},
"peerC": {
ID: "peerC",
IP: net.ParseIP("100.65.254.139"),
ID: "peerC",
IP: net.ParseIP("100.65.254.139"),
Status: &nbpeer.PeerStatus{},
},
},
Groups: map[string]*Group{

View File

@@ -1062,6 +1062,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer1.ID] = peer1
@@ -1087,6 +1088,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer2.ID] = peer2
@@ -1112,6 +1114,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer3.ID] = peer3
@@ -1137,6 +1140,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer4.ID] = peer4
@@ -1162,6 +1166,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
WtVersion: "development",
UIVersion: "development",
},
Status: &nbpeer.PeerStatus{},
}
account.Peers[peer5.ID] = peer5