Compare commits

...

44 Commits

Author SHA1 Message Date
Zoltan Papp
0c4c2a8cb2 Listen on all address 2023-12-13 22:39:12 +01:00
Zoltan Papp
9c6c944e46 Add socks5 proxy 2023-12-13 16:58:31 +01:00
Zoltan Papp
cec4232860 Netstack poc 2023-12-12 17:35:16 +01:00
Maycon Santos
2d1dfa3ae7 Fix jwks validation and flag/config overriding (#1380)
Ensure the jwks expiresInTime is not zero and add a log indicating the new expiration time

Replace the configuration property only when the flag is being used
2023-12-12 14:56:27 +01:00
Yury Gargay
5961c8330e Fix SaveOrAddUser and GetPeers methods in MockAccountManager (#1374) 2023-12-11 17:32:10 +01:00
Bethuel Mmbaga
d275d411aa Enable JWT group-based user authorization (#1368)
* Extend management API to support list of allowed JWT groups (#1366)

* Add JWTAllowGroups settings to account management

* Return an empty group list if jwt allow groups is not set

* Add JwtAllowGroups to account settings in handler test

* Add JWT group-based user authorization (#1373)

* Add JWTAllowGroups settings to account management

* Return an empty group list if jwt allow groups is not set

* Add JwtAllowGroups to account settings in handler test

* Implement user access validation authentication based on JWT groups

* Remove the slices package import due to compatibility issues with the gitHub workflow(s) Go version

* Refactor auth middleware and test for extracted claim handling

* Optimize JWT group check in auth middleware to cover nil and empty allowed groups
2023-12-11 18:59:15 +03:00
Yury Gargay
5ecafef5d2 Fix ListUsers method in MockAccountManager (#1367) 2023-12-11 15:00:02 +01:00
Yury Gargay
d073a250cc Specify ref for sync tag workflow (#1365) 2023-12-08 14:18:49 +01:00
Yury Gargay
a1c48468ab Add Dev Container Support section in contributing guideline (#1363) 2023-12-08 11:54:50 +01:00
Maycon Santos
dd1e730454 Update API descriptions and examples (#1364) 2023-12-08 11:39:33 +01:00
Yury Gargay
050f140245 Add sync-tag.yml GitHub workflow (#1362) 2023-12-08 10:55:31 +01:00
Zoltan Papp
006ba32086 Fix/acl for forward (#1305)
Fix ACL on routed traffic and code refactor
2023-12-08 10:48:21 +01:00
Yury Gargay
b03343bc4d Add sync-main.yml GitHub workflow (#1359) 2023-12-06 17:51:11 +01:00
pascal-fischer
36d62f1844 Merge pull request #1358 from netbirdio/fix/tests-after-peer-validation
Fix tests after peer validation
2023-12-06 15:39:37 +01:00
Pascal Fischer
08733ed8d5 update tests 2023-12-06 15:02:10 +01:00
Yury Gargay
27ed88f918 Implement lightweight method to check is peer has update channel (#1351)
Instead of GetAllConnectedPeers that need to traverse the whole
connections map in order to find one channel there.
2023-12-05 14:17:56 +01:00
pascal-fischer
45fc89b2c9 Merge pull request #1355 from netbirdio/chore/update-integrations-branch-reference
Chore: clean gomod reference
2023-12-05 13:13:14 +01:00
Pascal Fischer
f822a58326 go mod tidy 2023-12-05 12:54:01 +01:00
Pascal Fischer
d1f13025d1 switch back to use netbird main 2023-12-05 12:39:15 +01:00
pascal-fischer
3f8b500f0b Merge pull request #1341 from netbirdio/feature/peer-approval
Add peer and settings validation
2023-12-05 12:11:14 +01:00
Maycon Santos
0d2db4b172 update API doc 2023-12-04 19:02:16 +01:00
Pascal Fischer
7a18dea766 go mod tidy 2023-12-04 17:35:56 +01:00
pascal-fischer
ae5f69562d Merge branch 'main' into feature/peer-approval 2023-12-04 17:34:53 +01:00
pascal-fischer
755ffcfc73 Merge pull request #1353 from netbirdio/feature/extend-add-peer-event-with-setup-key
Extend add peer event meta with setup key name
2023-12-04 17:33:50 +01:00
Pascal Fischer
dc8f55f23e remove dependency cycle from prepare peer 2023-12-04 16:26:34 +01:00
Pascal Fischer
89249b414f move peer validation into getPeerconnectionResources 2023-12-04 14:53:38 +01:00
Pascal Fischer
92adf57fea fix map assignment 2023-12-04 13:49:46 +01:00
Pascal Fischer
1cd5a66575 adding setup key name to the event meta for adding peers by setup key 2023-12-04 13:00:13 +01:00
Pascal Fischer
b9fc008542 extract peer preparation 2023-12-04 12:49:50 +01:00
pascal-fischer
d5bf79bc51 Merge branch 'main' into feature/peer-approval 2023-12-01 18:12:59 +01:00
Pascal Fischer
4bf574037f fix sql store 2023-11-30 11:51:35 +01:00
Pascal Fischer
47c44d4b87 fix imports in sqlite store test 2023-11-30 11:08:51 +01:00
Pascal Fischer
96f866fb68 add missing imports after refactor 2023-11-29 16:46:46 +01:00
pascal-fischer
141065f14e Merge branch 'main' into feature/peer-approval 2023-11-29 16:27:01 +01:00
Pascal Fischer
8e74fb1fa8 add account id to validating peer update 2023-11-29 15:57:56 +01:00
Pascal Fischer
ba96e102b4 settings nil check 2023-11-29 15:16:11 +01:00
Pascal Fischer
2129b23fe7 allow sync for and return empty map 2023-11-29 14:56:06 +01:00
Pascal Fischer
efd05ca023 fix api references 2023-11-28 15:15:51 +01:00
Pascal Fischer
c829ad930c use separate package for signatures 2023-11-28 15:09:04 +01:00
Pascal Fischer
ad1f18a52a replace with updated integrations 2023-11-28 14:55:20 +01:00
Pascal Fischer
bab420ca77 extract account into separate package 2023-11-28 14:34:57 +01:00
Pascal Fischer
a729c83b06 extract peer into seperate package 2023-11-28 13:45:26 +01:00
Pascal Fischer
a7e55cc5e3 add signatures and frame for peer approval 2023-11-28 11:44:08 +01:00
Pascal Fischer
b7c0eba1e5 add extra settings struct 2023-11-27 17:04:40 +01:00
104 changed files with 4955 additions and 4468 deletions

22
.github/workflows/sync-main.yml vendored Normal file
View File

@@ -0,0 +1,22 @@
name: sync main
on:
push:
branches:
- main
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
trigger_sync_main:
runs-on: ubuntu-latest
steps:
- name: Trigger main branch sync
uses: benc-uk/workflow-dispatch@v1
with:
workflow: sync-main.yml
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "sha": "${{ github.sha }}" }'

23
.github/workflows/sync-tag.yml vendored Normal file
View File

@@ -0,0 +1,23 @@
name: sync tag
on:
push:
tags:
- 'v*'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
trigger_sync_tag:
runs-on: ubuntu-latest
steps:
- name: Trigger release tag sync
uses: benc-uk/workflow-dispatch@v1
with:
workflow: sync-tag.yml
ref: main
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'

View File

@@ -19,6 +19,7 @@ If you haven't already, join our slack workspace [here](https://join.slack.com/t
- [Development setup](#development-setup) - [Development setup](#development-setup)
- [Requirements](#requirements) - [Requirements](#requirements)
- [Local NetBird setup](#local-netbird-setup) - [Local NetBird setup](#local-netbird-setup)
- [Dev Container Support](#dev-container-support)
- [Build and start](#build-and-start) - [Build and start](#build-and-start)
- [Test suite](#test-suite) - [Test suite](#test-suite)
- [Checklist before submitting a PR](#checklist-before-submitting-a-pr) - [Checklist before submitting a PR](#checklist-before-submitting-a-pr)
@@ -135,6 +136,48 @@ checked out and set up:
go mod tidy go mod tidy
``` ```
### Dev Container Support
If you prefer using a dev container for development, NetBird now includes support for dev containers.
Dev containers provide a consistent and isolated development environment, making it easier for contributors to get started quickly. Follow the steps below to set up NetBird in a dev container.
#### 1. Prerequisites:
* Install Docker on your machine: [Docker Installation Guide](https://docs.docker.com/get-docker/)
* Install Visual Studio Code: [VS Code Installation Guide](https://code.visualstudio.com/download)
* If you prefer JetBrains Goland please follow this [manual](https://www.jetbrains.com/help/go/connect-to-devcontainer.html)
#### 2. Clone the Repository:
Clone the repository following previous [Local NetBird setup](#local-netbird-setup).
#### 3. Open in project in IDE of your choice:
**VScode**:
Open the project folder in Visual Studio Code:
```bash
code .
```
When you open the project in VS Code, it will detect the presence of a dev container configuration.
Click on the green "Reopen in Container" button in the bottom-right corner of VS Code.
**Goland**:
Open GoLand and select `"File" > "Open"` to open the NetBird project folder.
GoLand will detect the dev container configuration and prompt you to open the project in the container. Accept the prompt.
#### 4. Wait for the Container to Build:
VsCode or GoLand will use the specified Docker image to build the dev container. This might take some time, depending on your internet connection.
#### 6. Development:
Once the container is built, you can start developing within the dev container. All the necessary dependencies and configurations are set up within the container.
### Build and start ### Build and start
#### Client #### Client

32
client/firewall/create.go Normal file
View File

@@ -0,0 +1,32 @@
//go:build !linux || android
package firewall
import (
"context"
"fmt"
"runtime"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
)
// NewFirewall creates a firewall manager instance
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}
// use userspace packet filtering firewall
fm, err := uspfilter.Create(iface)
if err != nil {
return nil, err
}
err = fm.AllowNetbird()
if err != nil {
log.Warnf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
}

View File

@@ -0,0 +1,107 @@
//go:build !android
package firewall
import (
"context"
"fmt"
"os"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
log "github.com/sirupsen/logrus"
nbiptables "github.com/netbirdio/netbird/client/firewall/iptables"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
)
const (
// UNKNOWN is the default value for the firewall type for unknown firewall type
UNKNOWN FWType = iota
// IPTABLES is the value for the iptables firewall type
IPTABLES
// NFTABLES is the value for the nftables firewall type
NFTABLES
)
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
var fm firewall.Manager
var errFw error
switch check() {
case IPTABLES:
log.Debug("creating an iptables firewall manager")
fm, errFw = nbiptables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create iptables manager: %s", errFw)
}
case NFTABLES:
log.Debug("creating an nftables firewall manager")
fm, errFw = nbnftables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create nftables manager: %s", errFw)
}
default:
errFw = fmt.Errorf("no firewall manager found")
log.Debug("no firewall manager found, try to use userspace packet filtering firewall")
}
if iface.IsUserspaceBind() {
var errUsp error
if errFw == nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else {
fm, errUsp = uspfilter.Create(iface)
}
if errUsp != nil {
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
return nil, errUsp
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
}
if errFw != nil {
return nil, errFw
}
return fm, nil
}
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() FWType {
nf := nftables.Conn{}
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
return NFTABLES
}
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return UNKNOWN
}
if isIptablesClientAvailable(ip) {
return IPTABLES
}
return UNKNOWN
}
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}

11
client/firewall/iface.go Normal file
View File

@@ -0,0 +1,11 @@
package firewall
import "github.com/netbirdio/netbird/iface"
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
SetFilter(iface.PacketFilter) error
}

View File

@@ -0,0 +1,473 @@
package iptables
import (
"fmt"
"net"
"strconv"
"github.com/coreos/go-iptables/iptables"
"github.com/google/uuid"
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
tableName = "filter"
// rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT"
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
postRoutingMark = "0x000007e4"
)
type aclManager struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
routeingFwChainName string
entries map[string][][]string
ipsetStore *ipsetStore
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) {
m := &aclManager{
iptablesClient: iptablesClient,
wgIface: wgIface,
routeingFwChainName: routeingFwChainName,
entries: make(map[string][][]string),
ipsetStore: newIpsetStore(),
}
err := ipset.Init()
if err != nil {
return nil, fmt.Errorf("failed to init ipset: %w", err)
}
m.seedInitialEntries()
err = m.cleanChains()
if err != nil {
return nil, err
}
err = m.createDefaultChains()
if err != nil {
return nil, err
}
return m, nil
}
func (m *aclManager) AddFiltering(
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
var dPortVal, sPortVal string
if dPort != nil && dPort.Values != nil {
// TODO: we support only one port per rule in current implementation of ACLs
dPortVal = strconv.Itoa(dPort.Values[0])
}
if sPort != nil && sPort.Values != nil {
sPortVal = strconv.Itoa(sPort.Values[0])
}
var chain string
if direction == firewall.RuleDirectionOUT {
chain = chainNameOutputRules
} else {
chain = chainNameInputRules
}
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
}
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
ipList.addIP(ip.String())
return []firewall.Rule{&Rule{
ruleID: uuid.New().String(),
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
specs: specs,
}}, nil
}
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %s before use it: %s", ipsetName, err)
}
if err := ipset.Create(ipsetName); err != nil {
return nil, fmt.Errorf("failed to create ipset: %w", err)
}
if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
}
ipList := newIpList(ip.String())
m.ipsetStore.addIpList(ipsetName, ipList)
}
ok, err := m.iptablesClient.Exists("filter", chain, specs...)
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
if ok {
return nil, fmt.Errorf("rule already exists")
}
if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil {
return nil, err
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
}
if !shouldAddToPrerouting(protocol, dPort, direction) {
return []firewall.Rule{rule}, nil
}
rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip)
if err != nil {
return []firewall.Rule{rule}, err
}
return []firewall.Rule{rule, rulePrerouting}, nil
}
// DeleteRule from the firewall by rule definition
func (m *aclManager) DeleteRule(rule firewall.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
if r.chain == "PREROUTING" {
goto DELETERULE
}
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok {
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
return fmt.Errorf("failed to delete ip from ipset: %w", err)
}
delete(ipsetList.ips, r.ip)
}
// if after delete, set still contains other IPs,
// no need to delete firewall rule and we should exit here
if len(ipsetList.ips) != 0 {
return nil
}
// we delete last IP from the set, that means we need to delete
// set itself and associated firewall rule too
m.ipsetStore.deleteIpset(r.ipsetName)
if err := ipset.Destroy(r.ipsetName); err != nil {
log.Errorf("delete empty ipset: %v", err)
}
}
DELETERULE:
var table string
if r.chain == "PREROUTING" {
table = "mangle"
} else {
table = "filter"
}
err := m.iptablesClient.Delete(table, r.chain, r.specs...)
if err != nil {
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
}
return err
}
func (m *aclManager) Reset() error {
return m.cleanChains()
}
func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
var src []string
if ipsetName != "" {
src = []string{"-m", "set", "--set", ipsetName, "src"}
} else {
src = []string{"-s", ip.String()}
}
specs := []string{
"-d", m.wgIface.Address().IP.String(),
"-p", protocol,
"--dport", port,
"-j", "MARK", "--set-mark", postRoutingMark,
}
specs = append(src, specs...)
ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...)
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
if ok {
return nil, fmt.Errorf("rule already exists")
}
if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
return nil, err
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
chain: "PREROUTING",
}
return rule, nil
}
// todo write less destructive cleanup mechanism
func (m *aclManager) cleanChains() error {
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
rules := m.entries["OUTPUT"]
for _, rule := range rules {
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
if err != nil {
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
return err
}
}
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
for _, rule := range m.entries["INPUT"] {
err := m.iptablesClient.DeleteIfExists(tableName, "INPUT", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
for _, rule := range m.entries["FORWARD"] {
err := m.iptablesClient.DeleteIfExists(tableName, "FORWARD", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err)
return err
}
}
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
for _, rule := range m.entries["PREROUTING"] {
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
if err != nil {
log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
return err
}
}
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
if err := ipset.Destroy(ipsetName); err != nil {
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
}
m.ipsetStore.deleteIpset(ipsetName)
}
return nil
}
func (m *aclManager) createDefaultChains() error {
// chain netbird-acl-input-rules
if err := m.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil {
log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err)
return err
}
// chain netbird-acl-output-rules
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
return err
}
for chainName, rules := range m.entries {
for _, rule := range rules {
if chainName == "FORWARD" {
// position 2 because we add it after router's, jump rule
if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
} else {
if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
}
}
}
return nil
}
func (m *aclManager) seedInitialEntries() {
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("FORWARD",
[]string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
m.appendToEntries("FORWARD",
[]string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
m.appendToEntries("PREROUTING",
[]string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark})
}
func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec)
}
// filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
) (specs []string) {
matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0
if ip.String() == "0.0.0.0" {
matchByIP = false
}
switch direction {
case firewall.RuleDirectionIN:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
}
}
case firewall.RuleDirectionOUT:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
} else {
specs = append(specs, "-d", ip.String())
}
}
}
if protocol != "all" {
specs = append(specs, "-p", protocol)
}
if sPort != "" {
specs = append(specs, "--sport", sPort)
}
if dPort != "" {
specs = append(specs, "--dport", dPort)
}
return append(specs, "-j", actionToStr(action))
}
func actionToStr(action firewall.Action) string {
if action == firewall.ActionAccept {
return "ACCEPT"
}
return "DROP"
}
func transformIPsetName(ipsetName string, sPort, dPort string) string {
switch {
case ipsetName == "":
return ""
case sPort != "" && dPort != "":
return ipsetName + "-sport-dport"
case sPort != "":
return ipsetName + "-sport"
case dPort != "":
return ipsetName + "-dport"
default:
return ipsetName
}
}
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
if proto == "all" {
return false
}
if direction != firewall.RuleDirectionIN {
return false
}
if dPort == nil {
return false
}
return true
}

View File

@@ -1,43 +1,27 @@
package iptables package iptables
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"strconv"
"sync" "sync"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/google/uuid"
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
fw "github.com/netbirdio/netbird/client/firewall" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
) )
const (
// ChainInputFilterName is the name of the chain that is used for filtering incoming packets
ChainInputFilterName = "NETBIRD-ACL-INPUT"
// ChainOutputFilterName is the name of the chain that is used for filtering outgoing packets
ChainOutputFilterName = "NETBIRD-ACL-OUTPUT"
)
// dropAllDefaultRule in the Netbird chain
var dropAllDefaultRule = []string{"-j", "DROP"}
// Manager of iptables firewall // Manager of iptables firewall
type Manager struct { type Manager struct {
mutex sync.Mutex mutex sync.Mutex
wgIface iFaceMapper
ipv4Client *iptables.IPTables ipv4Client *iptables.IPTables
ipv6Client *iptables.IPTables aclMgr *aclManager
router *routerManager
inputDefaultRuleSpecs []string
outputDefaultRuleSpecs []string
wgIface iFaceMapper
rulesets map[string]ruleset
} }
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
@@ -47,47 +31,29 @@ type iFaceMapper interface {
IsUserspaceBind() bool IsUserspaceBind() bool
} }
type ruleset struct {
rule *Rule
ips map[string]string
}
// Create iptables firewall manager // Create iptables firewall manager
func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) { func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
m := &Manager{ iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
wgIface: wgIface,
inputDefaultRuleSpecs: []string{
"-i", wgIface.Name(), "-j", ChainInputFilterName, "-s", wgIface.Address().String()},
outputDefaultRuleSpecs: []string{
"-o", wgIface.Name(), "-j", ChainOutputFilterName, "-d", wgIface.Address().String()},
rulesets: make(map[string]ruleset),
}
err := ipset.Init()
if err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
// init clients for booth ipv4 and ipv6
m.ipv4Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil { if err != nil {
return nil, fmt.Errorf("iptables is not installed in the system or not supported") return nil, fmt.Errorf("iptables is not installed in the system or not supported")
} }
if ipv6Supported { m := &Manager{
m.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6) wgIface: wgIface,
if err != nil { ipv4Client: iptablesClient,
log.Warnf("ip6tables is not installed in the system or not supported: %v. Access rules for this protocol won't be applied.", err)
}
} }
if m.ipv4Client == nil && m.ipv6Client == nil { m.router, err = newRouterManager(context, iptablesClient)
return nil, fmt.Errorf("iptables is not installed in the system or not enough permissions to use it") if err != nil {
log.Debugf("failed to initialize route related chains: %s", err)
return nil, err
}
m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
if err != nil {
log.Debugf("failed to initialize ACL manager: %s", err)
return nil, err
} }
if err := m.Reset(); err != nil {
return nil, fmt.Errorf("failed to reset firewall: %v", err)
}
return m, nil return m, nil
} }
@@ -96,159 +62,44 @@ func Create(wgIface iFaceMapper, ipv6Supported bool) (*Manager, error) {
// Comment will be ignored because some system this feature is not supported // Comment will be ignored because some system this feature is not supported
func (m *Manager) AddFiltering( func (m *Manager) AddFiltering(
ip net.IP, ip net.IP,
protocol fw.Protocol, protocol firewall.Protocol,
sPort *fw.Port, sPort *firewall.Port,
dPort *fw.Port, dPort *firewall.Port,
direction fw.RuleDirection, direction firewall.RuleDirection,
action fw.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string, comment string,
) (fw.Rule, error) { ) ([]firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
client, err := m.client(ip) return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
if err != nil {
return nil, err
}
var dPortVal, sPortVal string
if dPort != nil && dPort.Values != nil {
// TODO: we support only one port per rule in current implementation of ACLs
dPortVal = strconv.Itoa(dPort.Values[0])
}
if sPort != nil && sPort.Values != nil {
sPortVal = strconv.Itoa(sPort.Values[0])
}
ipsetName = m.transformIPsetName(ipsetName, sPortVal, dPortVal)
ruleID := uuid.New().String()
if ipsetName != "" {
rs, rsExists := m.rulesets[ipsetName]
if !rsExists {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q before use it: %v", ipsetName, err)
}
if err := ipset.Create(ipsetName); err != nil {
return nil, fmt.Errorf("failed to create ipset: %w", err)
}
}
if err := ipset.Add(ipsetName, ip.String()); err != nil {
return nil, fmt.Errorf("failed to add IP to ipset: %w", err)
}
if rsExists {
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
rs.ips[ip.String()] = ruleID
return &Rule{
ruleID: ruleID,
ipsetName: ipsetName,
ip: ip.String(),
dst: direction == fw.RuleDirectionOUT,
v6: ip.To4() == nil,
}, nil
}
// this is new ipset so we need to create firewall rule for it
}
specs := m.filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
if direction == fw.RuleDirectionOUT {
ok, err := client.Exists("filter", ChainOutputFilterName, specs...)
if err != nil {
return nil, fmt.Errorf("check is output rule already exists: %w", err)
}
if ok {
return nil, fmt.Errorf("input rule already exists")
}
if err := client.Insert("filter", ChainOutputFilterName, 1, specs...); err != nil {
return nil, err
}
} else {
ok, err := client.Exists("filter", ChainInputFilterName, specs...)
if err != nil {
return nil, fmt.Errorf("check is input rule already exists: %w", err)
}
if ok {
return nil, fmt.Errorf("input rule already exists")
}
if err := client.Insert("filter", ChainInputFilterName, 1, specs...); err != nil {
return nil, err
}
}
rule := &Rule{
ruleID: ruleID,
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
dst: direction == fw.RuleDirectionOUT,
v6: ip.To4() == nil,
}
if ipsetName != "" {
// ipset name is defined and it means that this rule was created
// for it, need to associate it with ruleset
m.rulesets[ipsetName] = ruleset{
rule: rule,
ips: map[string]string{rule.ip: ruleID},
}
}
return rule, nil
} }
// DeleteRule from the firewall by rule definition // DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule fw.Rule) error { func (m *Manager) DeleteRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
r, ok := rule.(*Rule) return m.aclMgr.DeleteRule(rule)
if !ok { }
return fmt.Errorf("invalid rule type")
}
client := m.ipv4Client func (m *Manager) IsServerRouteSupported() bool {
if r.v6 { return true
if m.ipv6Client == nil { }
return fmt.Errorf("ipv6 is not supported")
}
client = m.ipv6Client
}
if rs, ok := m.rulesets[r.ipsetName]; ok { func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
// delete IP from ruleset IPs list and ipset m.mutex.Lock()
if _, ok := rs.ips[r.ip]; ok { defer m.mutex.Unlock()
if err := ipset.Del(r.ipsetName, r.ip); err != nil {
return fmt.Errorf("failed to delete ip from ipset: %w", err)
}
delete(rs.ips, r.ip)
}
// if after delete, set still contains other IPs, return m.router.InsertRoutingRules(pair)
// no need to delete firewall rule and we should exit here }
if len(rs.ips) != 0 {
return nil
}
// we delete last IP from the set, that means we need to delete func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
// set itself and associated firewall rule too m.mutex.Lock()
delete(m.rulesets, r.ipsetName) defer m.mutex.Unlock()
if err := ipset.Destroy(r.ipsetName); err != nil { return m.router.RemoveRoutingRules(pair)
log.Errorf("delete empty ipset: %v", err)
}
r = rs.rule
}
if r.dst {
return client.Delete("filter", ChainOutputFilterName, r.specs...)
}
return client.Delete("filter", ChainInputFilterName, r.specs...)
} }
// Reset firewall to the default state // Reset firewall to the default state
@@ -256,223 +107,49 @@ func (m *Manager) Reset() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if err := m.reset(m.ipv4Client, "filter"); err != nil { errAcl := m.aclMgr.Reset()
return fmt.Errorf("clean ipv4 firewall ACL input chain: %w", err) if errAcl != nil {
log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl)
} }
if m.ipv6Client != nil { errMgr := m.router.Reset()
if err := m.reset(m.ipv6Client, "filter"); err != nil { if errMgr != nil {
return fmt.Errorf("clean ipv6 firewall ACL input chain: %w", err) log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
} return errMgr
} }
return errAcl
return nil
} }
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error { func (m *Manager) AllowNetbird() error {
if m.wgIface.IsUserspaceBind() { if !m.wgIface.IsUserspaceBind() {
_, err := m.AddFiltering( return nil
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"",
)
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
}
_, err = m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
fw.RuleDirectionOUT,
fw.ActionAccept,
"",
"",
)
return err
} }
return nil _, err := m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
firewall.RuleDirectionIN,
firewall.ActionAccept,
"",
"",
)
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
}
_, err = m.AddFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
firewall.RuleDirectionOUT,
firewall.ActionAccept,
"",
"",
)
return err
} }
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
// reset firewall chain, clear it and drop it
func (m *Manager) reset(client *iptables.IPTables, table string) error {
ok, err := client.ChainExists(table, ChainInputFilterName)
if err != nil {
return fmt.Errorf("failed to check if input chain exists: %w", err)
}
if ok {
if ok, err := client.Exists("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
return err
} else if ok {
if err := client.Delete("filter", "INPUT", m.inputDefaultRuleSpecs...); err != nil {
log.WithError(err).Errorf("failed to delete default input rule: %v", err)
}
}
}
ok, err = client.ChainExists(table, ChainOutputFilterName)
if err != nil {
return fmt.Errorf("failed to check if output chain exists: %w", err)
}
if ok {
if ok, err := client.Exists("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
return err
} else if ok {
if err := client.Delete("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
log.WithError(err).Errorf("failed to delete default output rule: %v", err)
}
}
}
if err := client.ClearAndDeleteChain(table, ChainInputFilterName); err != nil {
log.Errorf("failed to clear and delete input chain: %v", err)
return nil
}
if err := client.ClearAndDeleteChain(table, ChainOutputFilterName); err != nil {
log.Errorf("failed to clear and delete input chain: %v", err)
return nil
}
for ipsetName := range m.rulesets {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
if err := ipset.Destroy(ipsetName); err != nil {
log.Errorf("delete ipset %q during reset: %v", ipsetName, err)
}
delete(m.rulesets, ipsetName)
}
return nil
}
// filterRuleSpecs returns the specs of a filtering rule
func (m *Manager) filterRuleSpecs(
ip net.IP, protocol string, sPort, dPort string, direction fw.RuleDirection, action fw.Action, ipsetName string,
) (specs []string) {
matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0
if s := ip.String(); s == "0.0.0.0" || s == "::" {
matchByIP = false
}
switch direction {
case fw.RuleDirectionIN:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
}
}
case fw.RuleDirectionOUT:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
} else {
specs = append(specs, "-d", ip.String())
}
}
}
if protocol != "all" {
specs = append(specs, "-p", protocol)
}
if sPort != "" {
specs = append(specs, "--sport", sPort)
}
if dPort != "" {
specs = append(specs, "--dport", dPort)
}
return append(specs, "-j", m.actionToStr(action))
}
// rawClient returns corresponding iptables client for the given ip
func (m *Manager) rawClient(ip net.IP) (*iptables.IPTables, error) {
if ip.To4() != nil {
return m.ipv4Client, nil
}
if m.ipv6Client == nil {
return nil, fmt.Errorf("ipv6 is not supported")
}
return m.ipv6Client, nil
}
// client returns client with initialized chain and default rules
func (m *Manager) client(ip net.IP) (*iptables.IPTables, error) {
client, err := m.rawClient(ip)
if err != nil {
return nil, err
}
ok, err := client.ChainExists("filter", ChainInputFilterName)
if err != nil {
return nil, fmt.Errorf("failed to check if chain exists: %w", err)
}
if !ok {
if err := client.NewChain("filter", ChainInputFilterName); err != nil {
return nil, fmt.Errorf("failed to create input chain: %w", err)
}
if err := client.AppendUnique("filter", ChainInputFilterName, dropAllDefaultRule...); err != nil {
return nil, fmt.Errorf("failed to create default drop all in netbird input chain: %w", err)
}
if err := client.Insert("filter", "INPUT", 1, m.inputDefaultRuleSpecs...); err != nil {
return nil, fmt.Errorf("failed to create input chain jump rule: %w", err)
}
}
ok, err = client.ChainExists("filter", ChainOutputFilterName)
if err != nil {
return nil, fmt.Errorf("failed to check if chain exists: %w", err)
}
if !ok {
if err := client.NewChain("filter", ChainOutputFilterName); err != nil {
return nil, fmt.Errorf("failed to create output chain: %w", err)
}
if err := client.AppendUnique("filter", ChainOutputFilterName, dropAllDefaultRule...); err != nil {
return nil, fmt.Errorf("failed to create default drop all in netbird output chain: %w", err)
}
if err := client.AppendUnique("filter", "OUTPUT", m.outputDefaultRuleSpecs...); err != nil {
return nil, fmt.Errorf("failed to create output chain jump rule: %w", err)
}
}
return client, nil
}
func (m *Manager) actionToStr(action fw.Action) string {
if action == fw.ActionAccept {
return "ACCEPT"
}
return "DROP"
}
func (m *Manager) transformIPsetName(ipsetName string, sPort, dPort string) string {
switch {
case ipsetName == "":
return ""
case sPort != "" && dPort != "":
return ipsetName + "-sport-dport"
case sPort != "":
return ipsetName + "-sport"
case dPort != "":
return ipsetName + "-dport"
default:
return ipsetName
}
}

View File

@@ -1,6 +1,7 @@
package iptables package iptables
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"testing" "testing"
@@ -9,7 +10,7 @@ import (
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
) )
@@ -55,7 +56,7 @@ func TestIptablesManager(t *testing.T) {
} }
// just check on the local interface // just check on the local interface
manager, err := Create(mock, true) manager, err := Create(context.Background(), mock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -67,17 +68,20 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
var rule1 fw.Rule var rule1 []fw.Rule
t.Run("add first rule", func(t *testing.T) { t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2") ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}} port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...) for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
}
}) })
var rule2 fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
@@ -87,21 +91,28 @@ func TestIptablesManager(t *testing.T) {
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range") ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...) for _, r := range rule2 {
rr := r.(*Rule)
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
}
}) })
t.Run("delete first rule", func(t *testing.T) { t.Run("delete first rule", func(t *testing.T) {
err := manager.DeleteRule(rule1) for _, r := range rule1 {
require.NoError(t, err, "failed to delete rule") err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, false, rule1.(*Rule).specs...) checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
}
}) })
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
err := manager.DeleteRule(rule2) for _, r := range rule2 {
require.NoError(t, err, "failed to delete rule") err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
}
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty") require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
}) })
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
@@ -114,11 +125,11 @@ func TestIptablesManager(t *testing.T) {
err = manager.Reset() err = manager.Reset()
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", ChainInputFilterName) ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
require.NoError(t, err, "failed check chain exists") require.NoError(t, err, "failed check chain exists")
if ok { if ok {
require.NoErrorf(t, err, "chain '%v' still exists after Reset", ChainInputFilterName) require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules)
} }
}) })
} }
@@ -143,7 +154,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
} }
// just check on the local interface // just check on the local interface
manager, err := Create(mock, true) manager, err := Create(context.Background(), mock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -155,7 +166,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
var rule1 fw.Rule var rule1 []fw.Rule
t.Run("add first rule with set", func(t *testing.T) { t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2") ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}} port := &fw.Port{Values: []int{8080}}
@@ -165,12 +176,14 @@ func TestIptablesManagerIPSet(t *testing.T) {
) )
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...) for _, r := range rule1 {
require.Equal(t, rule1.(*Rule).ipsetName, "default-dport", "ipset name must be set") checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
require.Equal(t, rule1.(*Rule).ip, "10.20.0.2", "ipset IP must be set") require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
}
}) })
var rule2 fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
@@ -180,23 +193,29 @@ func TestIptablesManagerIPSet(t *testing.T) {
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range", "default", "accept HTTPS traffic from ports range",
) )
require.NoError(t, err, "failed to add rule") for _, r := range rule2 {
require.Equal(t, rule2.(*Rule).ipsetName, "default-sport", "ipset name must be set") require.NoError(t, err, "failed to add rule")
require.Equal(t, rule2.(*Rule).ip, "10.20.0.3", "ipset IP must be set") require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
require.Equal(t, r.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
}
}) })
t.Run("delete first rule", func(t *testing.T) { t.Run("delete first rule", func(t *testing.T) {
err := manager.DeleteRule(rule1) for _, r := range rule1 {
require.NoError(t, err, "failed to delete rule") err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.rulesets, rule1.(*Rule).ruleID, "rule must be removed form the ruleset index") require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
}
}) })
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
err := manager.DeleteRule(rule2) for _, r := range rule2 {
require.NoError(t, err, "failed to delete rule") err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
require.Empty(t, manager.rulesets, "rulesets index after removed second rule must be empty") require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
}
}) })
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
@@ -206,7 +225,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
} }
func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) { func checkRuleSpecs(t *testing.T, ipv4Client *iptables.IPTables, chainName string, mustExists bool, rulespec ...string) {
t.Helper() t.Helper()
exists, err := ipv4Client.Exists("filter", chainName, rulespec...) exists, err := ipv4Client.Exists("filter", chainName, rulespec...)
require.NoError(t, err, "failed to check rule") require.NoError(t, err, "failed to check rule")
require.Falsef(t, !exists && mustExists, "rule '%v' does not exist", rulespec) require.Falsef(t, !exists && mustExists, "rule '%v' does not exist", rulespec)
@@ -232,7 +251,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock, true) manager, err := Create(context.Background(), mock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -243,7 +262,6 @@ func TestIptablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
_, err = manager.client(net.ParseIP("10.20.0.100"))
require.NoError(t, err) require.NoError(t, err)
ip := net.ParseIP("10.20.0.100") ip := net.ParseIP("10.20.0.100")

View File

@@ -0,0 +1,340 @@
//go:build !android
package iptables
import (
"context"
"fmt"
"strings"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
Ipv4Forwarding = "netbird-rt-forwarding"
ipv4Nat = "netbird-rt-nat"
)
// constants needed to manage and create iptable rules
const (
tableFilter = "filter"
tableNat = "nat"
chainFORWARD = "FORWARD"
chainPOSTROUTING = "POSTROUTING"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
)
type routerManager struct {
ctx context.Context
stop context.CancelFunc
iptablesClient *iptables.IPTables
rules map[string][]string
}
func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) {
ctx, cancel := context.WithCancel(parentCtx)
m := &routerManager{
ctx: ctx,
stop: cancel,
iptablesClient: iptablesClient,
rules: make(map[string][]string),
}
err := m.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to cleanup routing rules: %s", err)
return nil, err
}
err = m.createContainers()
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
}
return m, err
}
// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair))
if err != nil {
return err
}
if !pair.Masquerade {
return nil
}
err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
if err != nil {
return err
}
return nil
}
// insertRoutingRule inserts an iptable rule
func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
var err error
ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, ruleKey, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
}
delete(i.rules, ruleKey)
}
err = i.iptablesClient.Insert(table, chain, 1, rule...)
if err != nil {
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
}
i.rules[ruleKey] = rule
return nil
}
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error {
err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair))
if err != nil {
return err
}
if !pair.Masquerade {
return nil
}
err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair))
if err != nil {
return err
}
return nil
}
func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error {
var err error
ruleKey := firewall.GenKey(keyFormat, pair.ID)
existingRule, found := i.rules[ruleKey]
if found {
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
}
}
delete(i.rules, ruleKey)
return nil
}
func (i *routerManager) RouteingFwChainName() string {
return chainRTFWD
}
func (i *routerManager) Reset() error {
err := i.cleanUpDefaultForwardRules()
if err != nil {
return err
}
i.rules = make(map[string][]string)
return nil
}
func (i *routerManager) cleanUpDefaultForwardRules() error {
err := i.cleanJumpRules()
if err != nil {
return err
}
log.Debug("flushing routing related tables")
ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD)
if err != nil {
log.Errorf("failed check chain %s,error: %v", chainRTFWD, err)
return err
} else if ok {
err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD)
if err != nil {
log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err)
return err
}
}
ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT)
if err != nil {
log.Errorf("failed check chain %s,error: %v", chainRTNAT, err)
return err
} else if ok {
err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT)
if err != nil {
log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err)
return err
}
}
return nil
}
func (i *routerManager) createContainers() error {
if i.rules[Ipv4Forwarding] != nil {
return nil
}
errMSGFormat := "failed creating chain %s,error: %v"
err := i.createChain(tableFilter, chainRTFWD)
if err != nil {
return fmt.Errorf(errMSGFormat, chainRTFWD, err)
}
err = i.createChain(tableNat, chainRTNAT)
if err != nil {
return fmt.Errorf(errMSGFormat, chainRTNAT, err)
}
err = i.addJumpRules()
if err != nil {
return fmt.Errorf("error while creating jump rules: %v", err)
}
return nil
}
// addJumpRules create jump rules to send packets to NetBird chains
func (i *routerManager) addJumpRules() error {
rule := []string{"-j", chainRTFWD}
err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...)
if err != nil {
return err
}
i.rules[Ipv4Forwarding] = rule
rule = []string{"-j", chainRTNAT}
err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
if err != nil {
return err
}
i.rules[ipv4Nat] = rule
return nil
}
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
func (i *routerManager) cleanJumpRules() error {
var err error
errMSGFormat := "failed cleaning rule from chain %s,err: %v"
rule, found := i.rules[Ipv4Forwarding]
if found {
err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, chainFORWARD, err)
}
}
rule, found = i.rules[ipv4Nat]
if found {
err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err)
}
}
rules, err := i.iptablesClient.List("nat", "POSTROUTING")
if err != nil {
return fmt.Errorf("failed to list rules: %s", err)
}
for _, ruleString := range rules {
if !strings.Contains(ruleString, "NETBIRD") {
continue
}
rule := strings.Fields(ruleString)
err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...)
if err != nil {
return fmt.Errorf("failed to delete postrouting jump rule: %s", err)
}
}
rules, err = i.iptablesClient.List(tableFilter, "FORWARD")
if err != nil {
return fmt.Errorf("failed to list rules in FORWARD chain: %s", err)
}
for _, ruleString := range rules {
if !strings.Contains(ruleString, "NETBIRD") {
continue
}
rule := strings.Fields(ruleString)
err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...)
if err != nil {
return fmt.Errorf("failed to delete FORWARD jump rule: %s", err)
}
}
return nil
}
func (i *routerManager) createChain(table, newChain string) error {
chains, err := i.iptablesClient.ListChains(table)
if err != nil {
return fmt.Errorf("couldn't get %s table chains, error: %v", table, err)
}
shouldCreateChain := true
for _, chain := range chains {
if chain == newChain {
shouldCreateChain = false
}
}
if shouldCreateChain {
err = i.iptablesClient.NewChain(table, newChain)
if err != nil {
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
}
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
if err != nil {
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
}
}
return nil
}
// genRuleSpec generates rule specification with comment identifier
func genRuleSpec(jump, id, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id}
}
func getIptablesRuleType(table string) string {
ruleType := "forwarding"
if table == tableNat {
ruleType = "nat"
}
return ruleType
}

View File

@@ -0,0 +1,229 @@
//go:build !android
package iptables
import (
"context"
"os/exec"
"testing"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
)
func isIptablesSupported() bool {
_, err4 := exec.LookPath("iptables")
return err4 == nil
}
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "should return a valid iptables manager")
defer func() {
_ = manager.Reset()
}()
require.Len(t, manager.rules, 2, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
require.True(t, exists, "forwarding rule should exist")
exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting rule should exist")
pair := firewall.RouterPair{
ID: "abc",
Source: "100.100.100.1/32",
Destination: "100.100.100.0/24",
Masquerade: true,
}
forward4RuleKey := firewall.GenKey(firewall.ForwardingFormat, pair.ID)
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.Source, pair.Destination)
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error")
nat4RuleKey := firewall.GenKey(firewall.NatFormat, pair.ID)
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.Source, pair.Destination)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
}
func TestIptablesManager_InsertRoutingRules(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "shouldn't return error")
defer func() {
err := manager.Reset()
if err != nil {
log.Errorf("failed to reset iptables manager: %s", err)
}
}()
err = manager.InsertRoutingRules(testCase.InputPair)
require.NoError(t, err, "forwarding pair should be inserted")
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.True(t, exists, "forwarding rule should exist")
foundRule, found := manager.rules[forwardRuleKey]
require.True(t, found, "forwarding rule should exist in the manager map")
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.True(t, exists, "income forwarding rule should exist")
foundRule, found = manager.rules[inForwardRuleKey]
require.True(t, found, "income forwarding rule should exist in the manager map")
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade {
require.True(t, exists, "nat rule should be created")
foundNatRule, foundNat := manager.rules[natRuleKey]
require.True(t, foundNat, "nat rule should exist in the map")
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[natRuleKey]
require.False(t, foundNat, "nat rule should not exist in the map")
}
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade {
require.True(t, exists, "income nat rule should be created")
foundNatRule, foundNat := manager.rules[inNatRuleKey]
require.True(t, foundNat, "income nat rule should exist in the map")
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[inNatRuleKey]
require.False(t, foundNat, "income nat rule should not exist in the map")
}
})
}
}
func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "shouldn't return error")
defer func() {
_ = manager.Reset()
}()
require.NoError(t, err, "shouldn't return error")
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
require.NoError(t, err, "inserting rule should not return error")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
require.NoError(t, err, "inserting rule should not return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveRoutingRules(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.False(t, exists, "forwarding rule should not exist")
_, found := manager.rules[forwardRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.False(t, exists, "income forwarding rule should not exist")
_, found = manager.rules[inForwardRuleKey]
require.False(t, found, "income forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "nat rule should not exist")
_, found = manager.rules[natRuleKey]
require.False(t, found, "nat rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "income nat rule should not exist")
_, found = manager.rules[inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map")
})
}
}

View File

@@ -7,8 +7,7 @@ type Rule struct {
specs []string specs []string
ip string ip string
dst bool chain string
v6 bool
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id

View File

@@ -0,0 +1,50 @@
package iptables
type ipList struct {
ips map[string]struct{}
}
func newIpList(ip string) ipList {
ips := make(map[string]struct{})
ips[ip] = struct{}{}
return ipList{
ips: ips,
}
}
func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{}
}
type ipsetStore struct {
ipsets map[string]ipList // ipsetName -> ruleset
}
func newIpsetStore() *ipsetStore {
return &ipsetStore{
ipsets: make(map[string]ipList),
}
}
func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok
}
func (s *ipsetStore) addIpList(ipsetName string, list ipList) {
s.ipsets[ipsetName] = list
}
func (s *ipsetStore) deleteIpset(ipsetName string) {
s.ipsets[ipsetName] = ipList{}
delete(s.ipsets, ipsetName)
}
func (s *ipsetStore) ipsetNames() []string {
names := make([]string, 0, len(s.ipsets))
for name := range s.ipsets {
names = append(names, name)
}
return names
}

View File

@@ -1,9 +1,17 @@
package firewall package manager
import ( import (
"fmt"
"net" "net"
) )
const (
NatFormat = "netbird-nat-%s"
ForwardingFormat = "netbird-fwd-%s"
InNatFormat = "netbird-nat-in-%s"
InForwardingFormat = "netbird-fwd-in-%s"
)
// Rule abstraction should be implemented by each firewall manager // Rule abstraction should be implemented by each firewall manager
// //
// Each firewall type for different OS can use different type // Each firewall type for different OS can use different type
@@ -27,10 +35,8 @@ const (
type Action int type Action int
const ( const (
// ActionUnknown is a unknown action
ActionUnknown Action = iota
// ActionAccept is the action to accept a packet // ActionAccept is the action to accept a packet
ActionAccept ActionAccept Action = iota
// ActionDrop is the action to drop a packet // ActionDrop is the action to drop a packet
ActionDrop ActionDrop
) )
@@ -56,16 +62,27 @@ type Manager interface {
action Action, action Action,
ipsetName string, ipsetName string,
comment string, comment string,
) (Rule, error) ) ([]Rule, error)
// DeleteRule from the firewall by rule definition // DeleteRule from the firewall by rule definition
DeleteRule(rule Rule) error DeleteRule(rule Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
// InsertRoutingRules inserts a routing firewall rule
InsertRoutingRules(pair RouterPair) error
// RemoveRoutingRules removes a routing firewall rule
RemoveRoutingRules(pair RouterPair) error
// Reset firewall to the default state // Reset firewall to the default state
Reset() error Reset() error
// Flush the changes to firewall controller // Flush the changes to firewall controller
Flush() error Flush() error
}
// TODO: migrate routemanager firewal actions to this interface
func GenKey(format string, input string) string {
return fmt.Sprintf(format, input)
} }

View File

@@ -1,4 +1,4 @@
package firewall package manager
import ( import (
"strconv" "strconv"

View File

@@ -0,0 +1,18 @@
package manager
type RouterPair struct {
ID string
Source string
Destination string
Masquerade bool
}
func GetInPair(pair RouterPair) RouterPair {
return RouterPair{
ID: pair.ID,
// invert Source/Destination
Source: pair.Destination,
Destination: pair.Source,
Masquerade: pair.Masquerade,
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,85 @@
package nftables
import (
"net"
)
type ipsetStore struct {
ipsetReference map[string]int
ipsets map[string]map[string]struct{} // ipsetName -> list of ips
}
func newIpsetStore() *ipsetStore {
return &ipsetStore{
ipsetReference: make(map[string]int),
ipsets: make(map[string]map[string]struct{}),
}
}
func (s *ipsetStore) ips(ipsetName string) (map[string]struct{}, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok
}
func (s *ipsetStore) newIpset(ipsetName string) map[string]struct{} {
s.ipsetReference[ipsetName] = 0
ipList := make(map[string]struct{})
s.ipsets[ipsetName] = ipList
return ipList
}
func (s *ipsetStore) deleteIpset(ipsetName string) {
delete(s.ipsetReference, ipsetName)
delete(s.ipsets, ipsetName)
}
func (s *ipsetStore) DeleteIpFromSet(ipsetName string, ip net.IP) {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return
}
delete(ipList, ip.String())
}
func (s *ipsetStore) AddIpToSet(ipsetName string, ip net.IP) {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return
}
ipList[ip.String()] = struct{}{}
}
func (s *ipsetStore) IsIpInSet(ipsetName string, ip net.IP) bool {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return false
}
_, ok = ipList[ip.String()]
return ok
}
func (s *ipsetStore) AddReferenceToIpset(ipsetName string) {
s.ipsetReference[ipsetName]++
}
func (s *ipsetStore) DeleteReferenceFromIpSet(ipsetName string) {
r, ok := s.ipsetReference[ipsetName]
if !ok {
return
}
if r == 0 {
return
}
s.ipsetReference[ipsetName]--
}
func (s *ipsetStore) HasReferenceToSet(ipsetName string) bool {
if _, ok := s.ipsetReference[ipsetName]; !ok {
return false
}
if s.ipsetReference[ipsetName] == 0 {
return false
}
return true
}

View File

@@ -2,90 +2,52 @@ package nftables
import ( import (
"bytes" "bytes"
"encoding/binary" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"strconv"
"strings"
"sync" "sync"
"time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface"
) )
const ( const (
// FilterTableName is the name of the table that is used for filtering by the Netbird client // tableName is the name of the table that is used for filtering by the Netbird client
FilterTableName = "netbird-acl" tableName = "netbird"
// FilterInputChainName is the name of the chain that is used for filtering incoming packets
FilterInputChainName = "netbird-acl-input-filter"
// FilterOutputChainName is the name of the chain that is used for filtering outgoing packets
FilterOutputChainName = "netbird-acl-output-filter"
AllowNetbirdInputRuleID = "allow Netbird incoming traffic"
) )
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
// Manager of iptables firewall // Manager of iptables firewall
type Manager struct { type Manager struct {
mutex sync.Mutex mutex sync.Mutex
rConn *nftables.Conn
rConn *nftables.Conn
sConn *nftables.Conn
tableIPv4 *nftables.Table
tableIPv6 *nftables.Table
filterInputChainIPv4 *nftables.Chain
filterOutputChainIPv4 *nftables.Chain
filterInputChainIPv6 *nftables.Chain
filterOutputChainIPv6 *nftables.Chain
rulesetManager *rulesetManager
setRemovedIPs map[string]struct{}
setRemoved map[string]*nftables.Set
wgIface iFaceMapper wgIface iFaceMapper
}
// iFaceMapper defines subset methods of interface required for manager router *router
type iFaceMapper interface { aclManager *AclManager
Name() string
Address() iface.WGAddress
} }
// Create nftables firewall manager // Create nftables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) { func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
// sConn is used for creating sets and adding/removing elements from them m := &Manager{
// it's differ then rConn (which does create new conn for each flush operation) rConn: &nftables.Conn{},
// and is permanent. Using same connection for booth type of operations wgIface: wgIface,
// overloads netlink with high amount of rules ( > 10000) }
sConn, err := nftables.New(nftables.AsLasting())
workTable, err := m.createWorkTable()
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := &Manager{ m.router, err = newRouter(context, workTable)
rConn: &nftables.Conn{}, if err != nil {
sConn: sConn, return nil, err
rulesetManager: newRuleManager(),
setRemovedIPs: map[string]struct{}{},
setRemoved: map[string]*nftables.Set{},
wgIface: wgIface,
} }
if err := m.Reset(); err != nil { m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
if err != nil {
return nil, err return nil, err
} }
@@ -98,649 +60,58 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *Manager) AddFiltering( func (m *Manager) AddFiltering(
ip net.IP, ip net.IP,
proto fw.Protocol, proto firewall.Protocol,
sPort *fw.Port, sPort *firewall.Port,
dPort *fw.Port, dPort *firewall.Port,
direction fw.RuleDirection, direction firewall.RuleDirection,
action fw.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string, comment string,
) (fw.Rule, error) { ) ([]firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
var (
err error
ipset *nftables.Set
table *nftables.Table
chain *nftables.Chain
)
if direction == fw.RuleDirectionOUT {
table, chain, err = m.chain(
ip,
FilterOutputChainName,
nftables.ChainHookOutput,
nftables.ChainPriorityFilter,
nftables.ChainTypeFilter)
} else {
table, chain, err = m.chain(
ip,
FilterInputChainName,
nftables.ChainHookInput,
nftables.ChainPriorityFilter,
nftables.ChainTypeFilter)
}
if err != nil {
return nil, err
}
rawIP := ip.To4() rawIP := ip.To4()
if rawIP == nil { if rawIP == nil {
rawIP = ip.To16() return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
} }
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName) return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
if ipsetName != "" {
// if we already have set with given name, just add ip to the set
// and return rule with new ID in other case let's create rule
// with fresh created set and set element
var isSetNew bool
ipset, err = m.rConn.GetSetByName(table, ipsetName)
if err != nil {
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err)
}
isSetNew = true
}
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
return nil, fmt.Errorf("add set element for the first time: %v", err)
}
if err := m.sConn.Flush(); err != nil {
return nil, fmt.Errorf("flush add elements: %v", err)
}
if !isSetNew {
// if we already have nftables rules with set for given direction
// just add new rule to the ruleset and return new fw.Rule object
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
return m.rulesetManager.addRule(ruleset, rawIP)
}
// if ipset exists but it is not linked to rule for given direction
// create new rule for direction and bind ipset to it later
}
}
ifaceKey := expr.MetaKeyIIFNAME
if direction == fw.RuleDirectionOUT {
ifaceKey = expr.MetaKeyOIFNAME
}
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}
if proto != "all" {
expressions = append(expressions, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(9),
Len: uint32(1),
})
var protoData []byte
switch proto {
case fw.ProtocolTCP:
protoData = []byte{unix.IPPROTO_TCP}
case fw.ProtocolUDP:
protoData = []byte{unix.IPPROTO_UDP}
case fw.ProtocolICMP:
protoData = []byte{unix.IPPROTO_ICMP}
default:
return nil, fmt.Errorf("unsupported protocol: %s", proto)
}
expressions = append(expressions, &expr.Cmp{
Register: 1,
Op: expr.CmpOpEq,
Data: protoData,
})
}
// check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value
// in that case not add IP match expression into the rule definition
if !bytes.HasPrefix(anyIP, rawIP) {
// source address position
addrLen := uint32(len(rawIP))
addrOffset := uint32(12)
if addrLen == 16 {
addrOffset = 8
}
// change to destination address position if need
if direction == fw.RuleDirectionOUT {
addrOffset += addrLen
}
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: addrOffset,
Len: addrLen,
},
)
// add individual IP for match if no ipset defined
if ipset == nil {
expressions = append(expressions,
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rawIP,
},
)
} else {
expressions = append(expressions,
&expr.Lookup{
SourceRegister: 1,
SetName: ipsetName,
SetID: ipset.ID,
},
)
}
}
if sPort != nil && len(sPort.Values) != 0 {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 0,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*sPort),
},
)
}
if dPort != nil && len(dPort.Values) != 0 {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*dPort),
},
)
}
if action == fw.ActionAccept {
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
} else {
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
}
userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
rule := m.rConn.InsertRule(&nftables.Rule{
Table: table,
Chain: chain,
Position: 0,
Exprs: expressions,
UserData: userData,
})
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush insert rule: %v", err)
}
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
return m.rulesetManager.addRule(ruleset, rawIP)
}
// getRulesetID returns ruleset ID based on given parameters
func (m *Manager) getRulesetID(
ip net.IP,
proto fw.Protocol,
sPort *fw.Port,
dPort *fw.Port,
direction fw.RuleDirection,
action fw.Action,
ipsetName string,
) string {
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
if sPort != nil {
rulesetID += sPort.String()
}
rulesetID += ":"
if dPort != nil {
rulesetID += dPort.String()
}
rulesetID += ":"
rulesetID += strconv.Itoa(int(action))
if ipsetName == "" {
return "ip:" + ip.String() + rulesetID
}
return "set:" + ipsetName + rulesetID
}
// createSet in given table by name
func (m *Manager) createSet(
table *nftables.Table,
rawIP []byte,
name string,
) (*nftables.Set, error) {
keyType := nftables.TypeIPAddr
if len(rawIP) == 16 {
keyType = nftables.TypeIP6Addr
}
// else we create new ipset and continue creating rule
ipset := &nftables.Set{
Name: name,
Table: table,
Dynamic: true,
KeyType: keyType,
}
if err := m.rConn.AddSet(ipset, nil); err != nil {
return nil, fmt.Errorf("create set: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush created set: %v", err)
}
return ipset, nil
}
// chain returns the chain for the given IP address with specific settings
func (m *Manager) chain(
ip net.IP,
name string,
hook nftables.ChainHook,
priority nftables.ChainPriority,
cType nftables.ChainType,
) (*nftables.Table, *nftables.Chain, error) {
var err error
getChain := func(c *nftables.Chain, tf nftables.TableFamily) (*nftables.Chain, error) {
if c != nil {
return c, nil
}
return m.createChainIfNotExists(tf, FilterTableName, name, hook, priority, cType)
}
if ip.To4() != nil {
if name == FilterInputChainName {
m.filterInputChainIPv4, err = getChain(m.filterInputChainIPv4, nftables.TableFamilyIPv4)
return m.tableIPv4, m.filterInputChainIPv4, err
}
m.filterOutputChainIPv4, err = getChain(m.filterOutputChainIPv4, nftables.TableFamilyIPv4)
return m.tableIPv4, m.filterOutputChainIPv4, err
}
if name == FilterInputChainName {
m.filterInputChainIPv6, err = getChain(m.filterInputChainIPv6, nftables.TableFamilyIPv6)
return m.tableIPv4, m.filterInputChainIPv6, err
}
m.filterOutputChainIPv6, err = getChain(m.filterOutputChainIPv6, nftables.TableFamilyIPv6)
return m.tableIPv4, m.filterOutputChainIPv6, err
}
// table returns the table for the given family of the IP address
func (m *Manager) table(
family nftables.TableFamily, tableName string,
) (*nftables.Table, error) {
// we cache access to Netbird ACL table only
if tableName != FilterTableName {
return m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
}
if family == nftables.TableFamilyIPv4 {
if m.tableIPv4 != nil {
return m.tableIPv4, nil
}
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv4, tableName)
if err != nil {
return nil, err
}
m.tableIPv4 = table
return m.tableIPv4, nil
}
if m.tableIPv6 != nil {
return m.tableIPv6, nil
}
table, err := m.createTableIfNotExists(nftables.TableFamilyIPv6, tableName)
if err != nil {
return nil, err
}
m.tableIPv6 = table
return m.tableIPv6, nil
}
func (m *Manager) createTableIfNotExists(
family nftables.TableFamily, tableName string,
) (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(family)
if err != nil {
return nil, fmt.Errorf("list of tables: %w", err)
}
for _, t := range tables {
if t.Name == tableName {
return t, nil
}
}
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
if err := m.rConn.Flush(); err != nil {
return nil, err
}
return table, nil
}
func (m *Manager) createChainIfNotExists(
family nftables.TableFamily,
tableName string,
name string,
hooknum nftables.ChainHook,
priority nftables.ChainPriority,
chainType nftables.ChainType,
) (*nftables.Chain, error) {
table, err := m.table(family, tableName)
if err != nil {
return nil, err
}
chains, err := m.rConn.ListChainsOfTableFamily(family)
if err != nil {
return nil, fmt.Errorf("list of chains: %w", err)
}
for _, c := range chains {
if c.Name == name && c.Table.Name == table.Name {
return c, nil
}
}
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: name,
Table: table,
Hooknum: hooknum,
Priority: priority,
Type: chainType,
Policy: &polAccept,
}
chain = m.rConn.AddChain(chain)
ifaceKey := expr.MetaKeyIIFNAME
shiftDSTAddr := 0
if name == FilterOutputChainName {
ifaceKey = expr.MetaKeyOIFNAME
shiftDSTAddr = 1
}
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}
mask, _ := netip.AddrFromSlice(m.wgIface.Address().Network.Mask)
if m.wgIface.Address().IP.To4() == nil {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To16())
expressions = append(expressions,
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(8 + (16 * shiftDSTAddr)),
Len: 16,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 16,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: mask.Unmap().AsSlice(),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{Kind: expr.VerdictAccept},
)
} else {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
expressions = append(expressions,
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(12 + (4 * shiftDSTAddr)),
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{Kind: expr.VerdictAccept},
)
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: expressions,
})
expressions = []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: expressions,
})
if err := m.rConn.Flush(); err != nil {
return nil, err
}
return chain, nil
} }
// DeleteRule from the firewall by rule definition // DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule fw.Rule) error { func (m *Manager) DeleteRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
nativeRule, ok := rule.(*Rule) return m.aclManager.DeleteRule(rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
if nativeRule.nftRule == nil {
return nil
}
if nativeRule.nftSet != nil {
// call twice of delete set element raises error
// so we need to check if element is already removed
key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip)
if _, ok := m.setRemovedIPs[key]; !ok {
err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}})
if err != nil {
log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err)
}
if err := m.sConn.Flush(); err != nil {
return err
}
m.setRemovedIPs[key] = struct{}{}
}
}
if m.rulesetManager.deleteRule(nativeRule) {
// deleteRule indicates that we still have IP in the ruleset
// it means we should not remove the nftables rule but need to update set
// so we prepare IP to be removed from set on the next flush call
return nil
}
// ruleset doesn't contain IP anymore (or contains only one), remove nft rule
if err := m.rConn.DelRule(nativeRule.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return err
}
nativeRule.nftRule = nil
if nativeRule.nftSet != nil {
if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok {
m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet
}
nativeRule.nftSet = nil
}
return nil
} }
// Reset firewall to the default state func (m *Manager) IsServerRouteSupported() bool {
func (m *Manager) Reset() error { return true
m.mutex.Lock()
defer m.mutex.Unlock()
chains, err := m.rConn.ListChains()
if err != nil {
return fmt.Errorf("list of chains: %w", err)
}
for _, c := range chains {
// delete Netbird allow input traffic rule if it exists
if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(AllowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
}
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
m.rConn.DelChain(c)
}
}
tables, err := m.rConn.ListTables()
if err != nil {
return fmt.Errorf("list of tables: %w", err)
}
for _, t := range tables {
if t.Name == FilterTableName {
m.rConn.DelTable(t)
}
}
return m.rConn.Flush()
} }
// Flush rule/chain/set operations from the buffer func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
//
// Method also get all rules after flush and refreshes handle values in the rulesets
func (m *Manager) Flush() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if err := m.flushWithBackoff(); err != nil { return m.router.InsertRoutingRules(pair)
return err }
}
// set must be removed after flush rule changes func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
// otherwise we will get error m.mutex.Lock()
for _, s := range m.setRemoved { defer m.mutex.Unlock()
m.rConn.FlushSet(s)
m.rConn.DelSet(s)
}
if len(m.setRemoved) > 0 { return m.router.RemoveRoutingRules(pair)
if err := m.flushWithBackoff(); err != nil {
return err
}
}
m.setRemovedIPs = map[string]struct{}{}
m.setRemoved = map[string]*nftables.Set{}
if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil {
log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil {
log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err)
}
return nil
} }
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
// todo review this method usage
func (m *Manager) AllowNetbird() error { func (m *Manager) AllowNetbird() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
tf := nftables.TableFamilyIPv4 chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if m.wgIface.Address().IP.To4() == nil {
tf = nftables.TableFamilyIPv6
}
chains, err := m.rConn.ListChainsOfTableFamily(tf)
if err != nil { if err != nil {
return fmt.Errorf("list of chains: %w", err) return fmt.Errorf("list of chains: %w", err)
} }
@@ -777,47 +148,75 @@ func (m *Manager) AllowNetbird() error {
return nil return nil
} }
func (m *Manager) flushWithBackoff() (err error) { // Reset firewall to the default state
backoff := 4 func (m *Manager) Reset() error {
backoffTime := 1000 * time.Millisecond m.mutex.Lock()
for i := 0; ; i++ { defer m.mutex.Unlock()
err = m.rConn.Flush()
if err != nil { chains, err := m.rConn.ListChains()
if !strings.Contains(err.Error(), "busy") { if err != nil {
return return fmt.Errorf("list of chains: %w", err)
}
log.Error("failed to flush nftables, retrying...")
if i == backoff-1 {
return err
}
time.Sleep(backoffTime)
backoffTime *= 2
continue
}
break
} }
return
for _, c := range chains {
// delete Netbird allow input traffic rule if it exists
if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
}
}
m.router.ResetForwardRules()
tables, err := m.rConn.ListTables()
if err != nil {
return fmt.Errorf("list of tables: %w", err)
}
for _, t := range tables {
if t.Name == tableName {
m.rConn.DelTable(t)
}
}
return m.rConn.Flush()
} }
func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error { // Flush rule/chain/set operations from the buffer
if table == nil || chain == nil { //
return nil // Method also get all rules after flush and refreshes handle values in the rulesets
} // todo review this method usage
func (m *Manager) Flush() error {
m.mutex.Lock()
defer m.mutex.Unlock()
list, err := m.rConn.GetRules(table, chain) return m.aclManager.Flush()
}
func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("list of tables: %w", err)
} }
for _, rule := range list { for _, t := range tables {
if len(rule.UserData) != 0 { if t.Name == tableName {
if err := m.rulesetManager.setNftRuleHandle(rule); err != nil { m.rConn.DelTable(t)
log.Errorf("failed to set rule handle: %v", err)
}
} }
} }
return nil table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
err = m.rConn.Flush()
return table, err
} }
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
@@ -835,7 +234,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
Kind: expr.VerdictAccept, Kind: expr.VerdictAccept,
}, },
}, },
UserData: []byte(AllowNetbirdInputRuleID), UserData: []byte(allowNetbirdInputRuleID),
} }
_ = m.rConn.InsertRule(rule) _ = m.rConn.InsertRule(rule)
} }
@@ -857,15 +256,3 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable
} }
return nil return nil
} }
func encodePort(port fw.Port) []byte {
bs := make([]byte, 2)
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
return bs
}
func ifname(n string) []byte {
b := make([]byte, 16)
copy(b, []byte(n+"\x00"))
return b
}

View File

@@ -1,6 +1,7 @@
package nftables package nftables
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -12,7 +13,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
) )
@@ -53,7 +54,7 @@ func TestNftablesManager(t *testing.T) {
} }
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(context.Background(), mock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
@@ -82,14 +83,10 @@ func TestNftablesManager(t *testing.T) {
err = manager.Flush() err = manager.Flush()
require.NoError(t, err, "failed to flush") require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4) rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules") require.NoError(t, err, "failed to get rules")
// test expectations: require.Len(t, rules, 1, "expected 1 rules")
// 1) regular rule
// 2) "accept extra routed traffic rule" for the interface
// 3) "drop all rule" for the interface
require.Len(t, rules, 3, "expected 3 rules")
ipToAdd, _ := netip.AddrFromSlice(ip) ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap() add := ipToAdd.Unmap()
@@ -137,18 +134,17 @@ func TestNftablesManager(t *testing.T) {
} }
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions") require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
err = manager.DeleteRule(rule) for _, r := range rule {
require.NoError(t, err, "failed to delete rule") err = manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule")
}
err = manager.Flush() err = manager.Flush()
require.NoError(t, err, "failed to flush") require.NoError(t, err, "failed to flush")
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4) rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules") require.NoError(t, err, "failed to get rules")
// test expectations: require.Len(t, rules, 0, "expected 0 rules after deletion")
// 1) "accept extra routed traffic rule" for the interface
// 2) "drop all rule" for the interface
require.Len(t, rules, 2, "expected 2 rules after deletion")
err = manager.Reset() err = manager.Reset()
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
@@ -173,7 +169,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(context.Background(), mock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)

View File

@@ -0,0 +1,413 @@
package nftables
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/manager"
)
const (
chainNameRouteingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-nat"
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
)
// some presets for building nftable rules
var (
zeroXor = binaryutil.NativeEndian.PutUint32(0)
exprCounterAccept = []expr.Any{
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
)
type router struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule
isDefaultFwdRulesEnabled bool
}
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
r := &router{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
}
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, err
}
}
err = r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.createContainers()
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
}
return r, err
}
func (r *router) RouteingFwChainName() string {
return chainNameRouteingFw
}
// ResetForwardRules cleans existing nftables default forward rules from the system
func (r *router) ResetForwardRules() {
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to reset forward rules: %s", err)
}
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
if table.Name == "filter" {
return table, nil
}
}
return nil, errFilterTableNotFound
}
func (r *router) createContainers() error {
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRouteingFw,
Table: r.workTable,
})
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1,
Type: nftables.ChainTypeNAT,
})
err := r.refreshRulesMap()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
if err != nil {
return err
}
err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
if err != nil {
return err
}
if pair.Masquerade {
err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
if err != nil {
return err
}
err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
if err != nil {
return err
}
}
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
log.Debugf("add default accept forward rule")
r.acceptForwardRule(pair.Source)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
}
return nil
}
// insertRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
var expression []expr.Any
if isNat {
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
} else {
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
}
ruleKey := manager.GenKey(format, pair.ID)
_, exists := r.rules[ruleKey]
if exists {
err := r.removeRoutingRule(format, pair)
if err != nil {
return err
}
}
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainName],
Exprs: expression,
UserData: []byte(ruleKey),
})
return nil
}
func (r *router) acceptForwardRule(sourceNetwork string) {
src := generateCIDRMatcherExpressions(true, sourceNetwork)
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
var exprs []expr.Any
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
Kind: expr.VerdictAccept,
})...)
rule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleSrc),
}
r.conn.AddRule(rule)
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
Kind: expr.VerdictAccept,
})...)
rule = &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleDst),
}
r.conn.AddRule(rule)
r.isDefaultFwdRulesEnabled = true
}
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.removeRoutingRule(manager.ForwardingFormat, pair)
if err != nil {
return err
}
err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
if err != nil {
return err
}
err = r.removeRoutingRule(manager.NatFormat, pair)
if err != nil {
return err
}
err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
if err != nil {
return err
}
if len(r.rules) == 0 {
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed rules for %s", pair.Destination)
return nil
}
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
ruleKey := manager.GenKey(format, pair.ID)
rule, found := r.rules[ruleKey]
if found {
ruleType := "forwarding"
if rule.Chain.Type == nftables.ChainTypeNAT {
ruleType = "nat"
}
err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
}
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
delete(r.rules, ruleKey)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
}
}
}
return nil
}
func (r *router) cleanUpDefaultForwardRules() error {
if r.filterTable == nil {
r.isDefaultFwdRulesEnabled = false
return nil
}
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return err
}
var rules []*nftables.Rule
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name {
continue
}
if chain.Name != "FORWARD" {
continue
}
rules, err = r.conn.GetRules(r.filterTable, chain)
if err != nil {
return err
}
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
err := r.conn.DelRule(rule)
if err != nil {
return err
}
}
}
r.isDefaultFwdRulesEnabled = false
return r.conn.Flush()
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
ip, network, _ := net.ParseCIDR(cidr)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
var offSet uint32
if source {
offSet = 12 // src offset
} else {
offSet = 16 // dst offset
}
return []expr.Any{
// fetch src add
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offSet,
Len: 4,
},
// net mask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Mask: network.Mask,
Xor: zeroXor,
},
// net address
&expr.Cmp{
Register: 1,
Data: add.AsSlice(),
},
}
}

View File

@@ -0,0 +1,280 @@
//go:build !android
package nftables
import (
"context"
"testing"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
"github.com/google/nftables/expr"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
)
const (
// UNKNOWN is the default value for the firewall type for unknown firewall type
UNKNOWN = iota
// IPTABLES is the value for the iptables firewall type
IPTABLES
// NFTABLES is the value for the nftables firewall type
NFTABLES
)
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
if err != nil {
t.Fatal(err)
}
defer deleteWorkTable()
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table)
require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{}
defer manager.ResetForwardRules()
require.NoError(t, err, "shouldn't return error")
err = manager.InsertRoutingRules(testCase.InputPair)
defer func() {
_ = manager.RemoveRoutingRules(testCase.InputPair)
}()
require.NoError(t, err, "forwarding pair should be inserted")
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.InputPair.Masquerade {
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
testingExpression = append(sourceExp, destExp...) //nolint:gocritic
inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
found = 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.InputPair.Masquerade {
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
})
}
}
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
if err != nil {
t.Fatal(err)
}
defer deleteWorkTable()
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table)
require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{}
defer manager.ResetForwardRules()
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(forwardRuleKey),
})
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(natRuleKey),
})
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(inForwardRuleKey),
})
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(inNatRuleKey),
})
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
manager.ResetForwardRules()
err = manager.RemoveRoutingRules(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 {
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
}
}
}
})
}
}
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() int {
nf := nftables.Conn{}
if _, err := nf.ListChains(); err == nil {
return NFTABLES
}
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return UNKNOWN
}
if isIptablesClientAvailable(ip) {
return IPTABLES
}
return UNKNOWN
}
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}
func createWorkTable() (*nftables.Table, error) {
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, err
}
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, err
}
for _, t := range tables {
if t.Name == tableName {
sConn.DelTable(t)
}
}
table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
err = sConn.Flush()
return table, err
}
func deleteWorkTable() {
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return
}
tables, err := sConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return
}
for _, t := range tables {
if t.Name == tableName {
sConn.DelTable(t)
}
}
}

View File

@@ -1,6 +1,8 @@
package nftables package nftables
import ( import (
"net"
"github.com/google/nftables" "github.com/google/nftables"
) )
@@ -8,9 +10,8 @@ import (
type Rule struct { type Rule struct {
nftRule *nftables.Rule nftRule *nftables.Rule
nftSet *nftables.Set nftSet *nftables.Set
ruleID string
ruleID string ip net.IP
ip []byte
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id

View File

@@ -1,115 +0,0 @@
package nftables
import (
"bytes"
"fmt"
"github.com/google/nftables"
"github.com/rs/xid"
)
// nftRuleset links native firewall rule and ipset to ACL generated rules
type nftRuleset struct {
nftRule *nftables.Rule
nftSet *nftables.Set
issuedRules map[string]*Rule
rulesetID string
}
type rulesetManager struct {
rulesets map[string]*nftRuleset
nftSetName2rulesetID map[string]string
issuedRuleID2rulesetID map[string]string
}
func newRuleManager() *rulesetManager {
return &rulesetManager{
rulesets: map[string]*nftRuleset{},
nftSetName2rulesetID: map[string]string{},
issuedRuleID2rulesetID: map[string]string{},
}
}
func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) {
ruleset, ok := r.rulesets[rulesetID]
return ruleset, ok
}
func (r *rulesetManager) createRuleset(
rulesetID string,
nftRule *nftables.Rule,
nftSet *nftables.Set,
) *nftRuleset {
ruleset := nftRuleset{
rulesetID: rulesetID,
nftRule: nftRule,
nftSet: nftSet,
issuedRules: map[string]*Rule{},
}
r.rulesets[ruleset.rulesetID] = &ruleset
if nftSet != nil {
r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID
}
return &ruleset
}
func (r *rulesetManager) addRule(
ruleset *nftRuleset,
ip []byte,
) (*Rule, error) {
if _, ok := r.rulesets[ruleset.rulesetID]; !ok {
return nil, fmt.Errorf("ruleset not found")
}
rule := Rule{
nftRule: ruleset.nftRule,
nftSet: ruleset.nftSet,
ruleID: xid.New().String(),
ip: ip,
}
ruleset.issuedRules[rule.ruleID] = &rule
r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID
return &rule, nil
}
// deleteRule from ruleset and returns true if contains other rules
func (r *rulesetManager) deleteRule(rule *Rule) bool {
rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID]
if !ok {
return false
}
ruleset := r.rulesets[rulesetID]
if ruleset.nftRule == nil {
return false
}
delete(r.issuedRuleID2rulesetID, rule.ruleID)
delete(ruleset.issuedRules, rule.ruleID)
if len(ruleset.issuedRules) == 0 {
delete(r.rulesets, ruleset.rulesetID)
if rule.nftSet != nil {
delete(r.nftSetName2rulesetID, rule.nftSet.Name)
}
return false
}
return true
}
// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number
//
// This is important to do, because after we add rule to the nftables we can't update it until
// we set correct handle value to it.
func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error {
split := bytes.Split(nftRule.UserData, []byte(" "))
ruleset, ok := r.rulesets[string(split[0])]
if !ok {
return fmt.Errorf("ruleset not found")
}
*ruleset.nftRule = *nftRule
return nil
}

View File

@@ -1,122 +0,0 @@
package nftables
import (
"testing"
"github.com/google/nftables"
"github.com/stretchr/testify/require"
)
func TestRulesetManager_createRuleset(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{
UserData: []byte(rulesetID),
}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
require.NotNil(t, ruleset, "createRuleset() failed")
require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect")
require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect")
}
func TestRulesetManager_addRule(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.1.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
require.NotEqual(t, rule.ruleID, "ruleID is empty")
require.EqualValues(t, rule.ip, ip, "ip is incorrect")
require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset")
require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager")
ruleset2 := &nftRuleset{
rulesetID: "ruleset-2",
}
_, err = rulesetManager.addRule(ruleset2, ip)
require.Error(t, err, "addRule() should have failed")
}
func TestRulesetManager_deleteRule(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.1.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
ip2 := []byte("192.168.1.1")
rule2, err := rulesetManager.addRule(ruleset, ip2)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule2, "rule should not be nil")
hasNext := rulesetManager.deleteRule(rule)
require.True(t, hasNext, "deleteRule() should have returned true")
// Check that the rule is no longer in the manager.
require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted")
hasNext = rulesetManager.deleteRule(rule2)
require.False(t, hasNext, "deleteRule() should have returned false")
}
func TestRulesetManager_setNftRuleHandle(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
// Add a rule to the ruleset.
ip := []byte("192.168.0.1")
rule, err := rulesetManager.addRule(ruleset, ip)
require.NoError(t, err, "addRule() failed")
require.NotNil(t, rule, "rule should not be nil")
nftRuleCopy := nftRule
nftRuleCopy.Handle = 2
nftRuleCopy.UserData = []byte(rulesetID)
err = rulesetManager.setNftRuleHandle(&nftRuleCopy)
require.NoError(t, err, "setNftRuleHandle() failed")
// check correct work with references
require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect")
}
func TestRulesetManager_getRuleset(t *testing.T) {
// Create a ruleset manager.
rulesetManager := newRuleManager()
// Create a ruleset.
rulesetID := "ruleset-1"
nftRule := nftables.Rule{}
nftSet := nftables.Set{
ID: 2,
}
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet)
require.NotNil(t, ruleset, "createRuleset() failed")
find, ok := rulesetManager.getRuleset(rulesetID)
require.True(t, ok, "getRuleset() failed")
require.Equal(t, ruleset, find, "getRulesetBySetID() failed")
_, ok = rulesetManager.getRuleset("does-not-exist")
require.False(t, ok, "getRuleset() failed")
}

View File

@@ -0,0 +1,47 @@
//go:build !android
package test
import firewall "github.com/netbirdio/netbird/client/firewall/manager"
var (
InsertRuleTestCases = []struct {
Name string
InputPair firewall.RouterPair
}{
{
Name: "Insert Forwarding IPV4 Rule",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: "100.100.100.1/32",
Destination: "100.100.200.0/24",
Masquerade: false,
},
},
{
Name: "Insert Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: "100.100.100.1/32",
Destination: "100.100.200.0/24",
Masquerade: true,
},
},
}
RemoveRuleTestCases = []struct {
Name string
InputPair firewall.RouterPair
IpVersion string
}{
{
Name: "Remove Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
Source: "100.100.100.1/32",
Destination: "100.100.200.0/24",
Masquerade: true,
},
},
}
)

View File

@@ -1,4 +1,4 @@
//go:build !windows && !linux //go:build !windows
package uspfilter package uspfilter
@@ -10,10 +10,16 @@ func (m *Manager) Reset() error {
m.outgoingRules = make(map[string]RuleSet) m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet)
if m.nativeFirewall != nil {
return m.nativeFirewall.Reset()
}
return nil return nil
} }
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error { func (m *Manager) AllowNetbird() error {
if m.nativeFirewall != nil {
return m.nativeFirewall.AllowNetbird()
}
return nil return nil
} }

View File

@@ -1,21 +0,0 @@
package uspfilter
// AllowNetbird allows netbird interface traffic
func (m *Manager) AllowNetbird() error {
return nil
}
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.resetHook != nil {
return m.resetHook()
}
return nil
}

View File

@@ -5,7 +5,7 @@ import (
"github.com/google/gopacket" "github.com/google/gopacket"
fw "github.com/netbirdio/netbird/client/firewall" firewall "github.com/netbirdio/netbird/client/firewall/manager"
) )
// Rule to handle management of rules // Rule to handle management of rules
@@ -15,7 +15,7 @@ type Rule struct {
ipLayer gopacket.LayerType ipLayer gopacket.LayerType
matchByIP bool matchByIP bool
protoLayer gopacket.LayerType protoLayer gopacket.LayerType
direction fw.RuleDirection direction firewall.RuleDirection
sPort uint16 sPort uint16
dPort uint16 dPort uint16
drop bool drop bool

View File

@@ -10,12 +10,16 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
fw "github.com/netbirdio/netbird/client/firewall" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
) )
const layerTypeAll = 0 const layerTypeAll = 0
var (
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
)
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
SetFilter(iface.PacketFilter) error SetFilter(iface.PacketFilter) error
@@ -27,12 +31,12 @@ type RuleSet map[string]Rule
// Manager userspace firewall manager // Manager userspace firewall manager
type Manager struct { type Manager struct {
outgoingRules map[string]RuleSet outgoingRules map[string]RuleSet
incomingRules map[string]RuleSet incomingRules map[string]RuleSet
wgNetwork *net.IPNet wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
wgIface IFaceMapper wgIface IFaceMapper
resetHook func() error nativeFirewall firewall.Manager
mutex sync.RWMutex mutex sync.RWMutex
} }
@@ -52,6 +56,20 @@ type decoder struct {
// Create userspace firewall manager constructor // Create userspace firewall manager constructor
func Create(iface IFaceMapper) (*Manager, error) { func Create(iface IFaceMapper) (*Manager, error) {
return create(iface)
}
func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) {
mgr, err := create(iface)
if err != nil {
return nil, err
}
mgr.nativeFirewall = nativeFirewall
return mgr, nil
}
func create(iface IFaceMapper) (*Manager, error) {
m := &Manager{ m := &Manager{
decoders: sync.Pool{ decoders: sync.Pool{
New: func() any { New: func() any {
@@ -77,27 +95,50 @@ func Create(iface IFaceMapper) (*Manager, error) {
return m, nil return m, nil
} }
func (m *Manager) IsServerRouteSupported() bool {
if m.nativeFirewall == nil {
return false
} else {
return true
}
}
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
return m.nativeFirewall.InsertRoutingRules(pair)
}
// RemoveRoutingRules removes a routing firewall rule
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
return m.nativeFirewall.RemoveRoutingRules(pair)
}
// AddFiltering rule to the firewall // AddFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *Manager) AddFiltering( func (m *Manager) AddFiltering(
ip net.IP, ip net.IP,
proto fw.Protocol, proto firewall.Protocol,
sPort *fw.Port, sPort *firewall.Port,
dPort *fw.Port, dPort *firewall.Port,
direction fw.RuleDirection, direction firewall.RuleDirection,
action fw.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string, comment string,
) (fw.Rule, error) { ) ([]firewall.Rule, error) {
r := Rule{ r := Rule{
id: uuid.New().String(), id: uuid.New().String(),
ip: ip, ip: ip,
ipLayer: layers.LayerTypeIPv6, ipLayer: layers.LayerTypeIPv6,
matchByIP: true, matchByIP: true,
direction: direction, direction: direction,
drop: action == fw.ActionDrop, drop: action == firewall.ActionDrop,
comment: comment, comment: comment,
} }
if ipNormalized := ip.To4(); ipNormalized != nil { if ipNormalized := ip.To4(); ipNormalized != nil {
@@ -118,21 +159,21 @@ func (m *Manager) AddFiltering(
} }
switch proto { switch proto {
case fw.ProtocolTCP: case firewall.ProtocolTCP:
r.protoLayer = layers.LayerTypeTCP r.protoLayer = layers.LayerTypeTCP
case fw.ProtocolUDP: case firewall.ProtocolUDP:
r.protoLayer = layers.LayerTypeUDP r.protoLayer = layers.LayerTypeUDP
case fw.ProtocolICMP: case firewall.ProtocolICMP:
r.protoLayer = layers.LayerTypeICMPv4 r.protoLayer = layers.LayerTypeICMPv4
if r.ipLayer == layers.LayerTypeIPv6 { if r.ipLayer == layers.LayerTypeIPv6 {
r.protoLayer = layers.LayerTypeICMPv6 r.protoLayer = layers.LayerTypeICMPv6
} }
case fw.ProtocolALL: case firewall.ProtocolALL:
r.protoLayer = layerTypeAll r.protoLayer = layerTypeAll
} }
m.mutex.Lock() m.mutex.Lock()
if direction == fw.RuleDirectionIN { if direction == firewall.RuleDirectionIN {
if _, ok := m.incomingRules[r.ip.String()]; !ok { if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(RuleSet) m.incomingRules[r.ip.String()] = make(RuleSet)
} }
@@ -144,12 +185,11 @@ func (m *Manager) AddFiltering(
m.outgoingRules[r.ip.String()][r.id] = r m.outgoingRules[r.ip.String()][r.id] = r
} }
m.mutex.Unlock() m.mutex.Unlock()
return []firewall.Rule{&r}, nil
return &r, nil
} }
// DeleteRule from the firewall by rule definition // DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule fw.Rule) error { func (m *Manager) DeleteRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -158,7 +198,7 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule) return fmt.Errorf("delete rule: invalid rule type: %T", rule)
} }
if r.direction == fw.RuleDirectionIN { if r.direction == firewall.RuleDirectionIN {
_, ok := m.incomingRules[r.ip.String()][r.id] _, ok := m.incomingRules[r.ip.String()][r.id]
if !ok { if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id) return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
@@ -322,7 +362,7 @@ func (m *Manager) AddUDPPacketHook(
protoLayer: layers.LayerTypeUDP, protoLayer: layers.LayerTypeUDP,
dPort: dPort, dPort: dPort,
ipLayer: layers.LayerTypeIPv6, ipLayer: layers.LayerTypeIPv6,
direction: fw.RuleDirectionOUT, direction: firewall.RuleDirectionOUT,
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort), comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
udpHook: hook, udpHook: hook,
} }
@@ -333,7 +373,7 @@ func (m *Manager) AddUDPPacketHook(
m.mutex.Lock() m.mutex.Lock()
if in { if in {
r.direction = fw.RuleDirectionIN r.direction = firewall.RuleDirectionIN
if _, ok := m.incomingRules[r.ip.String()]; !ok { if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]Rule) m.incomingRules[r.ip.String()] = make(map[string]Rule)
} }
@@ -370,8 +410,3 @@ func (m *Manager) RemovePacketHook(hookID string) error {
} }
return fmt.Errorf("hook with given id not found") return fmt.Errorf("hook with given id not found")
} }
// SetResetHook which will be executed in the end of Reset method
func (m *Manager) SetResetHook(hook func() error) {
m.resetHook = hook
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
) )
@@ -125,24 +125,32 @@ func TestManagerDeleteRule(t *testing.T) {
return return
} }
err = m.DeleteRule(rule) for _, r := range rule {
if err != nil { err = m.DeleteRule(r)
t.Errorf("failed to delete rule: %v", err) if err != nil {
return t.Errorf("failed to delete rule: %v", err)
return
}
} }
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok { for _, r := range rule2 {
t.Errorf("rule2 is not in the incomingRules") if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
t.Errorf("rule2 is not in the incomingRules")
}
} }
err = m.DeleteRule(rule2) for _, r := range rule2 {
if err != nil { err = m.DeleteRule(r)
t.Errorf("failed to delete rule: %v", err) if err != nil {
return t.Errorf("failed to delete rule: %v", err)
return
}
} }
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok { for _, r := range rule2 {
t.Errorf("rule2 is not in the incomingRules") if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok {
t.Errorf("rule2 is not in the incomingRules")
}
} }
} }

View File

@@ -11,42 +11,27 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/iface"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
SetFilter(iface.PacketFilter) error
}
// Manager is a ACL rules manager // Manager is a ACL rules manager
type Manager interface { type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap) ApplyFiltering(networkMap *mgmProto.NetworkMap)
Stop()
} }
// DefaultManager uses firewall manager to handle // DefaultManager uses firewall manager to handle
type DefaultManager struct { type DefaultManager struct {
manager firewall.Manager firewall firewall.Manager
ipsetCounter int ipsetCounter int
rulesPairs map[string][]firewall.Rule rulesPairs map[string][]firewall.Rule
mutex sync.Mutex mutex sync.Mutex
} }
type ipsetInfo struct { func NewDefaultManager(fm firewall.Manager) *DefaultManager {
name string
ipCount int
}
func newDefaultManager(fm firewall.Manager) *DefaultManager {
return &DefaultManager{ return &DefaultManager{
manager: fm, firewall: fm,
rulesPairs: make(map[string][]firewall.Rule), rulesPairs: make(map[string][]firewall.Rule),
} }
} }
@@ -69,13 +54,13 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
time.Since(start), total) time.Since(start), total)
}() }()
if d.manager == nil { if d.firewall == nil {
log.Debug("firewall manager is not supported, skipping firewall rules") log.Debug("firewall manager is not supported, skipping firewall rules")
return return
} }
defer func() { defer func() {
if err := d.manager.Flush(); err != nil { if err := d.firewall.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err) log.Error("failed to flush firewall rules: ", err)
} }
}() }()
@@ -125,57 +110,35 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
) )
} }
applyFailed := false
newRulePairs := make(map[string][]firewall.Rule) newRulePairs := make(map[string][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]*ipsetInfo) ipsetByRuleSelectors := make(map[string]string)
// calculate which IP's can be grouped in by which ipset
// to do that we use rule selector (which is just rule properties without IP's)
for _, r := range rules {
selector := d.getRuleGroupingSelector(r)
ipset, ok := ipsetByRuleSelectors[selector]
if !ok {
ipset = &ipsetInfo{}
}
ipset.ipCount++
ipsetByRuleSelectors[selector] = ipset
}
for _, r := range rules { for _, r := range rules {
// if this rule is member of rule selection with more than DefaultIPsCountForSet // if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it // it's IP address can be used in the ipset for firewall manager which supports it
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)] selector := d.getRuleGroupingSelector(r)
if ipset.name == "" { ipsetName, ok := ipsetByRuleSelectors[selector]
if !ok {
d.ipsetCounter++ d.ipsetCounter++
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter) ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
ipsetByRuleSelectors[selector] = ipsetName
} }
ipsetName := ipset.name
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName) pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil { if err != nil {
log.Errorf("failed to apply firewall rule: %+v, %v", r, err) log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
applyFailed = true d.rollBack(newRulePairs)
break break
} }
newRulePairs[pairID] = rulePair if len(rules) > 0 {
} d.rulesPairs[pairID] = rulePair
if applyFailed { newRulePairs[pairID] = rulePair
log.Error("failed to apply firewall rules, rollback ACL to previous state")
for _, rules := range newRulePairs {
for _, rule := range rules {
if err := d.manager.DeleteRule(rule); err != nil {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
continue
}
}
} }
return
} }
for pairID, rules := range d.rulesPairs { for pairID, rules := range d.rulesPairs {
if _, ok := newRulePairs[pairID]; !ok { if _, ok := newRulePairs[pairID]; !ok {
for _, rule := range rules { for _, rule := range rules {
if err := d.manager.DeleteRule(rule); err != nil { if err := d.firewall.DeleteRule(rule); err != nil {
log.Errorf("failed to delete firewall rule: %v", err) log.Errorf("failed to delete firewall rule: %v", err)
continue continue
} }
@@ -186,16 +149,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
d.rulesPairs = newRulePairs d.rulesPairs = newRulePairs
} }
// Stop ACL controller and clear firewall state
func (d *DefaultManager) Stop() {
d.mutex.Lock()
defer d.mutex.Unlock()
if err := d.manager.Reset(); err != nil {
log.WithError(err).Error("reset firewall state")
}
}
func (d *DefaultManager) protoRuleToFirewallRule( func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule, r *mgmProto.FirewallRule,
ipsetName string, ipsetName string,
@@ -205,14 +158,14 @@ func (d *DefaultManager) protoRuleToFirewallRule(
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule") return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
} }
protocol := convertToFirewallProtocol(r.Protocol) protocol, err := convertToFirewallProtocol(r.Protocol)
if protocol == firewall.ProtocolUnknown { if err != nil {
return "", nil, fmt.Errorf("invalid protocol type: %d, skipping firewall rule", r.Protocol) return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
} }
action := convertFirewallAction(r.Action) action, err := convertFirewallAction(r.Action)
if action == firewall.ActionUnknown { if err != nil {
return "", nil, fmt.Errorf("invalid action type: %d, skipping firewall rule", r.Action) return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
} }
var port *firewall.Port var port *firewall.Port
@@ -232,7 +185,6 @@ func (d *DefaultManager) protoRuleToFirewallRule(
} }
var rules []firewall.Rule var rules []firewall.Rule
var err error
switch r.Direction { switch r.Direction {
case mgmProto.FirewallRule_IN: case mgmProto.FirewallRule_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "") rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
@@ -246,7 +198,6 @@ func (d *DefaultManager) protoRuleToFirewallRule(
return "", nil, err return "", nil, err
} }
d.rulesPairs[ruleID] = rules
return ruleID, rules, nil return ruleID, rules, nil
} }
@@ -259,24 +210,24 @@ func (d *DefaultManager) addInRules(
comment string, comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
var rules []firewall.Rule var rules []firewall.Rule
rule, err := d.manager.AddFiltering( rule, err := d.firewall.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment) ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
} }
rules = append(rules, rule) rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) { if shouldSkipInvertedRule(protocol, port) {
return rules, nil return rules, nil
} }
rule, err = d.manager.AddFiltering( rule, err = d.firewall.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment) ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
} }
return append(rules, rule), nil return append(rules, rule...), nil
} }
func (d *DefaultManager) addOutRules( func (d *DefaultManager) addOutRules(
@@ -288,24 +239,24 @@ func (d *DefaultManager) addOutRules(
comment string, comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
var rules []firewall.Rule var rules []firewall.Rule
rule, err := d.manager.AddFiltering( rule, err := d.firewall.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment) ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
} }
rules = append(rules, rule) rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) { if shouldSkipInvertedRule(protocol, port) {
return rules, nil return rules, nil
} }
rule, err = d.manager.AddFiltering( rule, err = d.firewall.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment) ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
} }
return append(rules, rule), nil return append(rules, rule...), nil
} }
// getRuleID() returns unique ID for the rule based on its parameters. // getRuleID() returns unique ID for the rule based on its parameters.
@@ -461,18 +412,29 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port) return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
} }
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol { func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) {
log.Debugf("rollback ACL to previous state")
for _, rules := range newRulePairs {
for _, rule := range rules {
if err := d.firewall.DeleteRule(rule); err != nil {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
}
}
}
}
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) {
switch protocol { switch protocol {
case mgmProto.FirewallRule_TCP: case mgmProto.FirewallRule_TCP:
return firewall.ProtocolTCP return firewall.ProtocolTCP, nil
case mgmProto.FirewallRule_UDP: case mgmProto.FirewallRule_UDP:
return firewall.ProtocolUDP return firewall.ProtocolUDP, nil
case mgmProto.FirewallRule_ICMP: case mgmProto.FirewallRule_ICMP:
return firewall.ProtocolICMP return firewall.ProtocolICMP, nil
case mgmProto.FirewallRule_ALL: case mgmProto.FirewallRule_ALL:
return firewall.ProtocolALL return firewall.ProtocolALL, nil
default: default:
return firewall.ProtocolUnknown return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
} }
} }
@@ -480,13 +442,13 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
} }
func convertFirewallAction(action mgmProto.FirewallRuleAction) firewall.Action { func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) {
switch action { switch action {
case mgmProto.FirewallRule_ACCEPT: case mgmProto.FirewallRule_ACCEPT:
return firewall.ActionAccept return firewall.ActionAccept, nil
case mgmProto.FirewallRule_DROP: case mgmProto.FirewallRule_DROP:
return firewall.ActionDrop return firewall.ActionDrop, nil
default: default:
return firewall.ActionUnknown return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
} }
} }

View File

@@ -1,28 +0,0 @@
//go:build !linux || android
package acl
import (
"fmt"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
)
// Create creates a firewall manager instance
func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
if iface.IsUserspaceBind() {
// use userspace packet filtering firewall
fm, err := uspfilter.Create(iface)
if err != nil {
return nil, err
}
if err := fm.AllowNetbird(); err != nil {
log.Warnf("failed to allow netbird interface traffic: %v", err)
}
return newDefaultManager(fm), nil
}
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}

View File

@@ -1,77 +0,0 @@
//go:build !android
package acl
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/iptables"
"github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/checkfw"
)
// Create creates a firewall manager instance for the Linux
func Create(iface IFaceMapper) (*DefaultManager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
var fm firewall.Manager
var err error
checkResult := checkfw.Check()
switch checkResult {
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
log.Debug("creating an iptables firewall manager for access control")
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
if fm, err = iptables.Create(iface, ipv6Supported); err != nil {
log.Infof("failed to create iptables manager for access control: %s", err)
}
case checkfw.NFTABLES:
log.Debug("creating an nftables firewall manager for access control")
if fm, err = nftables.Create(iface); err != nil {
log.Debugf("failed to create nftables manager for access control: %s", err)
}
}
var resetHookForUserspace func() error
if fm != nil && err == nil {
// err shadowing is used here, to ignore this error
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
resetHookForUserspace = fm.Reset
}
if iface.IsUserspaceBind() {
// use userspace packet filtering firewall
usfm, err := uspfilter.Create(iface)
if err != nil {
log.Debugf("failed to create userspace filtering firewall: %s", err)
return nil, err
}
// set kernel space firewall Reset as hook for userspace firewall
// manager Reset method, to clean up
if resetHookForUserspace != nil {
usfm.SetResetHook(resetHookForUserspace)
}
// to be consistent for any future extensions.
// ignore this error
if err := usfm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
fm = usfm
}
if fm == nil || err != nil {
log.Errorf("failed to create firewall manager: %s", err)
// no firewall manager found or initialized correctly
return nil, err
}
return newDefaultManager(fm), nil
}

View File

@@ -1,11 +1,14 @@
package acl package acl
import ( import (
"context"
"net" "net"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
@@ -49,12 +52,15 @@ func TestDefaultManager(t *testing.T) {
}).AnyTimes() }).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
acl, err := Create(ifaceMock) fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
if err != nil { if err != nil {
t.Errorf("create ACL manager: %v", err) t.Errorf("create firewall: %v", err)
return return
} }
defer acl.Stop() defer func(fw manager.Manager) {
_ = fw.Reset()
}(fw)
acl := NewDefaultManager(fw)
t.Run("apply firewall rules", func(t *testing.T) { t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap)
@@ -339,12 +345,15 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}).AnyTimes() }).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
acl, err := Create(ifaceMock) fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
if err != nil { if err != nil {
t.Errorf("create ACL manager: %v", err) t.Errorf("create firewall: %v", err)
return return
} }
defer acl.Stop() defer func(fw manager.Manager) {
_ = fw.Reset()
}(fw)
acl := NewDefaultManager(fw)
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap)

View File

@@ -1,3 +0,0 @@
//go:build !linux || android
package checkfw

View File

@@ -1,56 +0,0 @@
//go:build !android
package checkfw
import (
"os"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
)
const (
// UNKNOWN is the default value for the firewall type for unknown firewall type
UNKNOWN FWType = iota
// IPTABLES is the value for the iptables firewall type
IPTABLES
// IPTABLESWITHV6 is the value for the iptables firewall type with ipv6
IPTABLESWITHV6
// NFTABLES is the value for the nftables firewall type
NFTABLES
)
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
// Check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func Check() FWType {
nf := nftables.Conn{}
if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
return NFTABLES
}
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err == nil {
if isIptablesClientAvailable(ip) {
ipSupport := IPTABLES
ipv6, ip6Err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
if ip6Err == nil {
if isIptablesClientAvailable(ipv6) {
ipSupport = IPTABLESWITHV6
}
}
return ipSupport
}
}
return UNKNOWN
}
func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}

View File

@@ -25,13 +25,30 @@ const (
type osManagerType int type osManagerType int
func (t osManagerType) String() string {
switch t {
case netbirdManager:
return "netbird"
case fileManager:
return "file"
case networkManager:
return "networkManager"
case systemdManager:
return "systemd"
case resolvConfManager:
return "resolvconf"
default:
return "unknown"
}
}
func newHostManager(wgInterface WGIface) (hostManager, error) { func newHostManager(wgInterface WGIface) (hostManager, error) {
osManager, err := getOSDNSManagerType() osManager, err := getOSDNSManagerType()
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Debugf("discovered mode is: %d", osManager) log.Debugf("discovered mode is: %s", osManager)
switch osManager { switch osManager {
case networkManager: case networkManager:
return newNetworkManagerDbusConfigurator(wgInterface) return newNetworkManagerDbusConfigurator(wgInterface)
@@ -65,7 +82,6 @@ func getOSDNSManagerType() (osManagerType, error) {
return netbirdManager, nil return netbirdManager, nil
} }
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() { if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
log.Debugf("is nm running on supported v? %t", isNetworkManagerSupportedVersion())
return networkManager, nil return networkManager, nil
} }
if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) { if strings.Contains(text, "systemd-resolved") && isDbusListenerRunning(systemdResolvedDest, systemdDbusObjectNode) {

View File

@@ -17,6 +17,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -115,6 +117,7 @@ type Engine struct {
statusRecorder *peer.Status statusRecorder *peer.Status
firewall manager.Manager
routeManager routemanager.Manager routeManager routemanager.Manager
acl acl.Manager acl acl.Manager
@@ -231,6 +234,19 @@ func (e *Engine) Start() error {
return err return err
} }
e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface)
if err != nil {
log.Errorf("failed creating firewall manager: %s", err)
}
if e.firewall != nil && e.firewall.IsServerRouteSupported() {
err = e.routeManager.EnableServerRouter(e.firewall)
if err != nil {
e.close()
return err
}
}
err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort) err = e.wgInterface.Configure(myPrivateKey.String(), e.config.WgPort)
if err != nil { if err != nil {
log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIFaceName, err.Error()) log.Errorf("failed configuring Wireguard interface [%s]: %s", wgIFaceName, err.Error())
@@ -258,10 +274,8 @@ func (e *Engine) Start() error {
e.udpMux = mux e.udpMux = mux
} }
if acl, err := acl.Create(e.wgInterface); err != nil { if e.firewall != nil {
log.Errorf("failed to create ACL manager, policy will not work: %s", err.Error()) e.acl = acl.NewDefaultManager(e.firewall)
} else {
e.acl = acl
} }
err = e.dnsServer.Initialize() err = e.dnsServer.Initialize()
@@ -1044,8 +1058,11 @@ func (e *Engine) close() {
e.dnsServer.Stop() e.dnsServer.Stop()
} }
if e.acl != nil { if e.firewall != nil {
e.acl.Stop() err := e.firewall.Reset()
if err != nil {
log.Warnf("failed to reset firewall: %s", err)
}
} }
} }

View File

@@ -1,75 +0,0 @@
package routemanager
var insertRuleTestCases = []struct {
name string
inputPair routerPair
ipVersion string
}{
{
name: "Insert Forwarding IPV4 Rule",
inputPair: routerPair{
ID: "zxa",
source: "100.100.100.1/32",
destination: "100.100.200.0/24",
masquerade: false,
},
ipVersion: ipv4,
},
{
name: "Insert Forwarding And Nat IPV4 Rules",
inputPair: routerPair{
ID: "zxa",
source: "100.100.100.1/32",
destination: "100.100.200.0/24",
masquerade: true,
},
ipVersion: ipv4,
},
{
name: "Insert Forwarding IPV6 Rule",
inputPair: routerPair{
ID: "zxa",
source: "fc00::1/128",
destination: "fc12::/64",
masquerade: false,
},
ipVersion: ipv6,
},
{
name: "Insert Forwarding And Nat IPV6 Rules",
inputPair: routerPair{
ID: "zxa",
source: "fc00::1/128",
destination: "fc12::/64",
masquerade: true,
},
ipVersion: ipv6,
},
}
var removeRuleTestCases = []struct {
name string
inputPair routerPair
ipVersion string
}{
{
name: "Remove Forwarding And Nat IPV4 Rules",
inputPair: routerPair{
ID: "zxa",
source: "100.100.100.1/32",
destination: "100.100.200.0/24",
masquerade: true,
},
ipVersion: ipv4,
},
{
name: "Remove Forwarding And Nat IPV6 Rules",
inputPair: routerPair{
ID: "zxa",
source: "fc00::1/128",
destination: "fc12::/64",
masquerade: true,
},
ipVersion: ipv6,
},
}

View File

@@ -1,12 +0,0 @@
package routemanager
type firewallManager interface {
// RestoreOrCreateContainers restores or creates a firewall container set of rules, tables and default rules
RestoreOrCreateContainers() error
// InsertRoutingRules inserts a routing firewall rule
InsertRoutingRules(pair routerPair) error
// RemoveRoutingRules removes a routing firewall rule
RemoveRoutingRules(pair routerPair) error
// CleanRoutingRules cleans a firewall set of containers
CleanRoutingRules()
}

View File

@@ -1,55 +0,0 @@
//go:build !android
package routemanager
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/checkfw"
)
const (
ipv6Forwarding = "netbird-rt-ipv6-forwarding"
ipv4Forwarding = "netbird-rt-ipv4-forwarding"
ipv6Nat = "netbird-rt-ipv6-nat"
ipv4Nat = "netbird-rt-ipv4-nat"
natFormat = "netbird-nat-%s"
forwardingFormat = "netbird-fwd-%s"
inNatFormat = "netbird-nat-in-%s"
inForwardingFormat = "netbird-fwd-in-%s"
ipv6 = "ipv6"
ipv4 = "ipv4"
)
func genKey(format string, input string) string {
return fmt.Sprintf(format, input)
}
// newFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
func newFirewall(parentCTX context.Context) (firewallManager, error) {
checkResult := checkfw.Check()
switch checkResult {
case checkfw.IPTABLES, checkfw.IPTABLESWITHV6:
log.Debug("creating an iptables firewall manager for route rules")
ipv6Supported := checkResult == checkfw.IPTABLESWITHV6
return newIptablesManager(parentCTX, ipv6Supported)
case checkfw.NFTABLES:
log.Info("creating an nftables firewall manager for route rules")
return newNFTablesManager(parentCTX), nil
}
return nil, fmt.Errorf("couldn't initialize nftables or iptables clients. Using a dummy firewall manager for route rules")
}
func getInPair(pair routerPair) routerPair {
return routerPair{
ID: pair.ID,
// invert source/destination
source: pair.destination,
destination: pair.source,
masquerade: pair.masquerade,
}
}

View File

@@ -1,15 +0,0 @@
//go:build !linux
// +build !linux
package routemanager
import (
"context"
"fmt"
"runtime"
)
// newFirewall returns a nil manager
func newFirewall(context.Context) (firewallManager, error) {
return nil, fmt.Errorf("firewall not supported on %s", runtime.GOOS)
}

View File

@@ -1,487 +0,0 @@
//go:build !android
package routemanager
import (
"context"
"fmt"
"net/netip"
"os/exec"
"strings"
"sync"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
)
func isIptablesSupported() bool {
_, err4 := exec.LookPath("iptables")
_, err6 := exec.LookPath("ip6tables")
return err4 == nil && err6 == nil
}
// constants needed to manage and create iptable rules
const (
iptablesFilterTable = "filter"
iptablesNatTable = "nat"
iptablesForwardChain = "FORWARD"
iptablesPostRoutingChain = "POSTROUTING"
iptablesRoutingNatChain = "NETBIRD-RT-NAT"
iptablesRoutingForwardingChain = "NETBIRD-RT-FWD"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
)
// some presets for building nftable rules
var (
iptablesDefaultForwardingRule = []string{"-j", iptablesRoutingForwardingChain, "-m", "comment", "--comment"}
iptablesDefaultNetbirdForwardingRule = []string{"-j", "RETURN"}
iptablesDefaultNatRule = []string{"-j", iptablesRoutingNatChain, "-m", "comment", "--comment"}
iptablesDefaultNetbirdNatRule = []string{"-j", "RETURN"}
)
type iptablesManager struct {
ctx context.Context
stop context.CancelFunc
ipv4Client *iptables.IPTables
ipv6Client *iptables.IPTables
rules map[string]map[string][]string
mux sync.Mutex
}
func newIptablesManager(parentCtx context.Context, ipv6Supported bool) (*iptablesManager, error) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return nil, fmt.Errorf("failed to initialize iptables for ipv4: %s", err)
}
ctx, cancel := context.WithCancel(parentCtx)
manager := &iptablesManager{
ctx: ctx,
stop: cancel,
ipv4Client: ipv4Client,
rules: make(map[string]map[string][]string),
}
if ipv6Supported {
manager.ipv6Client, err = iptables.NewWithProtocol(iptables.ProtocolIPv6)
if err != nil {
log.Warnf("failed to initialize iptables for ipv6: %s. Routes for this protocol won't be applied.", err)
}
}
return manager, nil
}
// CleanRoutingRules cleans existing iptables resources that we created by the agent
func (i *iptablesManager) CleanRoutingRules() {
i.mux.Lock()
defer i.mux.Unlock()
err := i.cleanJumpRules()
if err != nil {
log.Error(err)
}
log.Debug("flushing tables")
errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v"
if i.ipv4Client != nil {
err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
if err != nil {
log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
}
err = i.ipv4Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
if err != nil {
log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
}
}
if i.ipv6Client != nil {
err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
if err != nil {
log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
}
err = i.ipv6Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
if err != nil {
log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
}
}
log.Info("done cleaning up iptables rules")
}
// RestoreOrCreateContainers restores existing iptables containers (chains and rules)
// if they don't exist, we create them
func (i *iptablesManager) RestoreOrCreateContainers() error {
i.mux.Lock()
defer i.mux.Unlock()
if i.rules[ipv4][ipv4Forwarding] != nil && i.rules[ipv6][ipv6Forwarding] != nil {
return nil
}
errMSGFormat := "iptables: failed creating %s chain %s,error: %v"
if i.ipv4Client != nil {
err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
}
err = createChain(i.ipv4Client, iptablesNatTable, iptablesRoutingNatChain)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
}
err = i.restoreRules(i.ipv4Client)
if err != nil {
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
}
}
if i.ipv6Client != nil {
err := createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
}
err = createChain(i.ipv6Client, iptablesNatTable, iptablesRoutingNatChain)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
}
err = i.restoreRules(i.ipv6Client)
if err != nil {
return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err)
}
}
err := i.addJumpRules()
if err != nil {
return fmt.Errorf("iptables: error while creating jump rules: %v", err)
}
return nil
}
// addJumpRules create jump rules to send packets to NetBird chains
func (i *iptablesManager) addJumpRules() error {
err := i.cleanJumpRules()
if err != nil {
return err
}
if i.ipv4Client != nil {
rule := append(iptablesDefaultForwardingRule, ipv4Forwarding) //nolint:gocritic
err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
if err != nil {
return err
}
i.rules[ipv4][ipv4Forwarding] = rule
rule = append(iptablesDefaultNatRule, ipv4Nat) //nolint:gocritic
err = i.ipv4Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
if err != nil {
return err
}
i.rules[ipv4][ipv4Nat] = rule
}
if i.ipv6Client != nil {
rule := append(iptablesDefaultForwardingRule, ipv6Forwarding) //nolint:gocritic
err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
if err != nil {
return err
}
i.rules[ipv6][ipv6Forwarding] = rule
rule = append(iptablesDefaultNatRule, ipv6Nat) //nolint:gocritic
err = i.ipv6Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
if err != nil {
return err
}
i.rules[ipv6][ipv6Nat] = rule
}
return nil
}
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
func (i *iptablesManager) cleanJumpRules() error {
var err error
errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v"
rule, found := i.rules[ipv4][ipv4Forwarding]
if i.ipv4Client != nil {
if found {
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding)
err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv4, iptablesForwardChain, err)
}
}
rule, found = i.rules[ipv4][ipv4Nat]
if found {
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Nat)
err = i.ipv4Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err)
}
}
}
if i.ipv6Client == nil {
rule, found = i.rules[ipv6][ipv6Forwarding]
if found {
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding)
err = i.ipv6Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesForwardChain, err)
}
}
rule, found = i.rules[ipv6][ipv6Nat]
if found {
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Nat)
err = i.ipv6Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err)
}
}
}
return nil
}
func iptablesProtoToString(proto iptables.Protocol) string {
if proto == iptables.ProtocolIPv6 {
return ipv6
}
return ipv4
}
// restoreRules restores existing NetBird rules
func (i *iptablesManager) restoreRules(iptablesClient *iptables.IPTables) error {
ipVersion := iptablesProtoToString(iptablesClient.Proto())
if i.rules[ipVersion] == nil {
i.rules[ipVersion] = make(map[string][]string)
}
table := iptablesFilterTable
for _, chain := range []string{iptablesForwardChain, iptablesRoutingForwardingChain} {
rules, err := iptablesClient.List(table, chain)
if err != nil {
return err
}
for _, ruleString := range rules {
rule := strings.Fields(ruleString)
id := getRuleRouteID(rule)
if id != "" {
i.rules[ipVersion][id] = rule[2:]
}
}
}
table = iptablesNatTable
for _, chain := range []string{iptablesPostRoutingChain, iptablesRoutingNatChain} {
rules, err := iptablesClient.List(table, chain)
if err != nil {
return err
}
for _, ruleString := range rules {
rule := strings.Fields(ruleString)
id := getRuleRouteID(rule)
if id != "" {
i.rules[ipVersion][id] = rule[2:]
}
}
}
return nil
}
// createChain create NetBird chains
func createChain(iptables *iptables.IPTables, table, newChain string) error {
chains, err := iptables.ListChains(table)
if err != nil {
return fmt.Errorf("couldn't get %s %s table chains, error: %v", iptablesProtoToString(iptables.Proto()), table, err)
}
shouldCreateChain := true
for _, chain := range chains {
if chain == newChain {
shouldCreateChain = false
}
}
if shouldCreateChain {
err = iptables.NewChain(table, newChain)
if err != nil {
return fmt.Errorf("couldn't create %s chain %s in %s table, error: %v", iptablesProtoToString(iptables.Proto()), newChain, table, err)
}
if table == iptablesNatTable {
err = iptables.Append(table, newChain, iptablesDefaultNetbirdNatRule...)
} else {
err = iptables.Append(table, newChain, iptablesDefaultNetbirdForwardingRule...)
}
if err != nil {
return fmt.Errorf("couldn't create %s chain %s default rule, error: %v", iptablesProtoToString(iptables.Proto()), newChain, err)
}
}
return nil
}
// genRuleSpec generates rule specification with comment identifier
func genRuleSpec(jump, id, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id}
}
// getRuleRouteID returns the rule ID if matches our prefix
func getRuleRouteID(rule []string) string {
for i, flag := range rule {
if flag == "--comment" {
id := rule[i+1]
if strings.HasPrefix(id, "netbird-") {
return id
}
}
}
return ""
}
// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
func (i *iptablesManager) InsertRoutingRules(pair routerPair) error {
i.mux.Lock()
defer i.mux.Unlock()
err := i.insertRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, getInPair(pair))
if err != nil {
return err
}
if !pair.masquerade {
return nil
}
err = i.insertRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, getInPair(pair))
if err != nil {
return err
}
return nil
}
// insertRoutingRule inserts an iptable rule
func (i *iptablesManager) insertRoutingRule(keyFormat, table, chain, jump string, pair routerPair) error {
var err error
prefix := netip.MustParsePrefix(pair.source)
ipVersion := ipv4
iptablesClient := i.ipv4Client
if prefix.Addr().Unmap().Is6() {
iptablesClient = i.ipv6Client
ipVersion = ipv6
}
if iptablesClient == nil {
return fmt.Errorf("unable to insert iptables routing rules. Iptables client is not initialized")
}
ruleKey := genKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, ruleKey, pair.source, pair.destination)
existingRule, found := i.rules[ipVersion][ruleKey]
if found {
err = iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
}
delete(i.rules[ipVersion], ruleKey)
}
err = iptablesClient.Insert(table, chain, 1, rule...)
if err != nil {
return fmt.Errorf("iptables: error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
}
i.rules[ipVersion][ruleKey] = rule
return nil
}
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error {
i.mux.Lock()
defer i.mux.Unlock()
err := i.removeRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, getInPair(pair))
if err != nil {
return err
}
if !pair.masquerade {
return nil
}
err = i.removeRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, getInPair(pair))
if err != nil {
return err
}
return nil
}
// removeRoutingRule removes an iptables rule
func (i *iptablesManager) removeRoutingRule(keyFormat, table, chain string, pair routerPair) error {
var err error
prefix := netip.MustParsePrefix(pair.source)
ipVersion := ipv4
iptablesClient := i.ipv4Client
if prefix.Addr().Unmap().Is6() {
iptablesClient = i.ipv6Client
ipVersion = ipv6
}
if iptablesClient == nil {
return fmt.Errorf("unable to remove iptables routing rules. Iptables client is not initialized")
}
ruleKey := genKey(keyFormat, pair.ID)
existingRule, found := i.rules[ipVersion][ruleKey]
if found {
err = iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
}
}
delete(i.rules[ipVersion], ruleKey)
return nil
}
func getIptablesRuleType(table string) string {
ruleType := "forwarding"
if table == iptablesNatTable {
ruleType = "nat"
}
return ruleType
}

View File

@@ -1,294 +0,0 @@
//go:build !android
package routemanager
import (
"context"
"testing"
"github.com/coreos/go-iptables/iptables"
"github.com/stretchr/testify/require"
)
func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
manager, err := newIptablesManager(context.TODO(), true)
require.NoError(t, err, "should return a valid iptables manager")
defer manager.CleanRoutingRules()
err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.rules, 2, "should have created maps for ipv4 and ipv6")
require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4")
exists, err := manager.ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain)
require.True(t, exists, "forwarding rule should exist")
exists, err = manager.ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain)
require.True(t, exists, "postrouting rule should exist")
require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6")
exists, err = manager.ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain)
require.True(t, exists, "forwarding rule should exist")
exists, err = manager.ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain)
require.True(t, exists, "postrouting rule should exist")
pair := routerPair{
ID: "abc",
source: "100.100.100.1/32",
destination: "100.100.100.0/24",
masquerade: true,
}
forward4RuleKey := genKey(forwardingFormat, pair.ID)
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination)
err = manager.ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error")
nat4RuleKey := genKey(natFormat, pair.ID)
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination)
err = manager.ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
pair = routerPair{
ID: "abc",
source: "fc00::1/128",
destination: "fc11::/64",
masquerade: true,
}
forward6RuleKey := genKey(forwardingFormat, pair.ID)
forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination)
err = manager.ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...)
require.NoError(t, err, "inserting rule should not return error")
nat6RuleKey := genKey(natFormat, pair.ID)
nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination)
err = manager.ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...)
require.NoError(t, err, "inserting rule should not return error")
delete(manager.rules, ipv4)
delete(manager.rules, ipv6)
err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.rules[ipv4], 4, "should have restored all rules for ipv4")
foundRule, found := manager.rules[ipv4][forward4RuleKey]
require.True(t, found, "forwarding rule should exist in the map")
require.Equal(t, forward4Rule[:4], foundRule[:4], "stored forwarding rule should match")
foundRule, found = manager.rules[ipv4][nat4RuleKey]
require.True(t, found, "nat rule should exist in the map")
require.Equal(t, nat4Rule[:4], foundRule[:4], "stored nat rule should match")
require.Len(t, manager.rules[ipv6], 4, "should have restored all rules for ipv6")
foundRule, found = manager.rules[ipv6][forward6RuleKey]
require.True(t, found, "forwarding rule should exist in the map")
require.Equal(t, forward6Rule[:4], foundRule[:4], "stored forward rule should match")
foundRule, found = manager.rules[ipv6][nat6RuleKey]
require.True(t, found, "nat rule should exist in the map")
require.Equal(t, nat6Rule[:4], foundRule[:4], "stored nat rule should match")
}
func TestIptablesManager_InsertRoutingRules(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
for _, testCase := range insertRuleTestCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
iptablesClient := ipv4Client
if testCase.ipVersion == ipv6 {
iptablesClient = ipv6Client
}
manager := &iptablesManager{
ctx: ctx,
stop: cancel,
ipv4Client: ipv4Client,
ipv6Client: ipv6Client,
rules: make(map[string]map[string][]string),
}
defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
err = manager.InsertRoutingRules(testCase.inputPair)
require.NoError(t, err, "forwarding pair should be inserted")
forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
exists, err := iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
require.True(t, exists, "forwarding rule should exist")
foundRule, found := manager.rules[testCase.ipVersion][forwardRuleKey]
require.True(t, found, "forwarding rule should exist in the manager map")
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
require.True(t, exists, "income forwarding rule should exist")
foundRule, found = manager.rules[testCase.ipVersion][inForwardRuleKey]
require.True(t, found, "income forwarding rule should exist in the manager map")
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
if testCase.inputPair.masquerade {
require.True(t, exists, "nat rule should be created")
foundNatRule, foundNat := manager.rules[testCase.ipVersion][natRuleKey]
require.True(t, foundNat, "nat rule should exist in the map")
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[testCase.ipVersion][natRuleKey]
require.False(t, foundNat, "nat rule should not exist in the map")
}
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
if testCase.inputPair.masquerade {
require.True(t, exists, "income nat rule should be created")
foundNatRule, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey]
require.True(t, foundNat, "income nat rule should exist in the map")
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey]
require.False(t, foundNat, "income nat rule should not exist in the map")
}
})
}
}
func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
}
for _, testCase := range removeRuleTestCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
iptablesClient := ipv4Client
if testCase.ipVersion == ipv6 {
iptablesClient = ipv6Client
}
manager := &iptablesManager{
ctx: ctx,
stop: cancel,
ipv4Client: ipv4Client,
ipv6Client: ipv6Client,
rules: make(map[string]map[string][]string),
}
defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
require.NoError(t, err, "inserting rule should not return error")
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, inForwardRule...)
require.NoError(t, err, "inserting rule should not return error")
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
delete(manager.rules, ipv4)
delete(manager.rules, ipv6)
err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveRoutingRules(testCase.inputPair)
require.NoError(t, err, "shouldn't return error")
exists, err := iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
require.False(t, exists, "forwarding rule should not exist")
_, found := manager.rules[testCase.ipVersion][forwardRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
require.False(t, exists, "income forwarding rule should not exist")
_, found = manager.rules[testCase.ipVersion][inForwardRuleKey]
require.False(t, found, "income forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
require.False(t, exists, "nat rule should not exist")
_, found = manager.rules[testCase.ipVersion][natRuleKey]
require.False(t, found, "nat rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
require.False(t, exists, "income nat rule should not exist")
_, found = manager.rules[testCase.ipVersion][inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map")
})
}
}

View File

@@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
@@ -19,6 +20,7 @@ type Manager interface {
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
SetRouteChangeListener(listener listener.NetworkChangeListener) SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string InitialRouteRange() []string
EnableServerRouter(firewall firewall.Manager) error
Stop() Stop()
} }
@@ -35,19 +37,12 @@ type DefaultManager struct {
notifier *notifier notifier *notifier
} }
// NewManager returns a new route manager
func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager { func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, statusRecorder *peer.Status, initialRoutes []*route.Route) *DefaultManager {
srvRouter, err := newServerRouter(ctx, wgInterface)
if err != nil {
log.Errorf("server router is not supported: %s", err)
}
mCTX, cancel := context.WithCancel(ctx) mCTX, cancel := context.WithCancel(ctx)
dm := &DefaultManager{ dm := &DefaultManager{
ctx: mCTX, ctx: mCTX,
stop: cancel, stop: cancel,
clientNetworks: make(map[string]*clientNetwork), clientNetworks: make(map[string]*clientNetwork),
serverRouter: srvRouter,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgInterface: wgInterface, wgInterface: wgInterface,
pubKey: pubKey, pubKey: pubKey,
@@ -61,6 +56,15 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
return dm return dm
} }
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
var err error
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall)
if err != nil {
return err
}
return nil
}
// Stop stops the manager watchers and clean firewall rules // Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop() { func (m *DefaultManager) Stop() {
m.stop() m.stop()

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -37,6 +38,10 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList
} }
func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error {
panic("implement me")
}
// Stop mock implementation of Stop from Manager interface // Stop mock implementation of Stop from Manager interface
func (m *MockManager) Stop() { func (m *MockManager) Stop() {
if m.StopFunc != nil { if m.StopFunc != nil {

View File

@@ -1,571 +0,0 @@
//go:build !android
package routemanager
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
)
const (
nftablesTable = "netbird-rt"
nftablesRoutingForwardingChain = "netbird-rt-fwd"
nftablesRoutingNatChain = "netbird-rt-nat"
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
)
// constants needed to create nftable rules
const (
ipv4Len = 4
ipv4SrcOffset = 12
ipv4DestOffset = 16
ipv6Len = 16
ipv6SrcOffset = 8
ipv6DestOffset = 24
exprDirectionSource = "source"
exprDirectionDestination = "destination"
)
// some presets for building nftable rules
var (
zeroXor = binaryutil.NativeEndian.PutUint32(0)
zeroXor6 = append(binaryutil.NativeEndian.PutUint64(0), binaryutil.NativeEndian.PutUint64(0)...)
exprAllowRelatedEstablished = []expr.Any{
&expr.Ct{
Register: 1,
SourceRegister: false,
Key: 0,
},
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Mask: []uint8{0x6, 0x0, 0x0, 0x0},
Xor: zeroXor,
},
&expr.Cmp{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
exprCounterAccept = []expr.Any{
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
)
type nftablesManager struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
tableIPv4 *nftables.Table
tableIPv6 *nftables.Table
chains map[string]map[string]*nftables.Chain
rules map[string]*nftables.Rule
filterTable *nftables.Table
defaultForwardRules []*nftables.Rule
mux sync.Mutex
}
func newNFTablesManager(parentCtx context.Context) *nftablesManager {
ctx, cancel := context.WithCancel(parentCtx)
return &nftablesManager{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
chains: make(map[string]map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
defaultForwardRules: make([]*nftables.Rule, 2),
}
}
// CleanRoutingRules cleans existing nftables rules from the system
func (n *nftablesManager) CleanRoutingRules() {
n.mux.Lock()
defer n.mux.Unlock()
log.Debug("flushing tables")
if n.tableIPv4 != nil && n.tableIPv6 != nil {
n.conn.FlushTable(n.tableIPv6)
n.conn.FlushTable(n.tableIPv4)
}
if n.defaultForwardRules[0] != nil {
err := n.eraseDefaultForwardRule()
if err != nil {
log.Errorf("failed to delete forward rule: %s", err)
}
}
log.Debugf("flushing tables result in: %v error", n.conn.Flush())
}
// RestoreOrCreateContainers restores existing nftables containers (tables and chains)
// if they don't exist, we create them
func (n *nftablesManager) RestoreOrCreateContainers() error {
n.mux.Lock()
defer n.mux.Unlock()
if n.tableIPv6 != nil && n.tableIPv4 != nil {
log.Debugf("nftables: containers already restored, skipping")
return nil
}
tables, err := n.conn.ListTables()
if err != nil {
return fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
if table.Name == "filter" && table.Family == nftables.TableFamilyIPv4 {
log.Debugf("nftables: found filter table for ipv4")
n.filterTable = table
continue
}
if table.Name == nftablesTable {
if table.Family == nftables.TableFamilyIPv4 {
n.tableIPv4 = table
continue
}
n.tableIPv6 = table
}
}
if n.tableIPv4 == nil {
n.tableIPv4 = n.conn.AddTable(&nftables.Table{
Name: nftablesTable,
Family: nftables.TableFamilyIPv4,
})
}
if n.tableIPv6 == nil {
n.tableIPv6 = n.conn.AddTable(&nftables.Table{
Name: nftablesTable,
Family: nftables.TableFamilyIPv6,
})
}
chains, err := n.conn.ListChains()
if err != nil {
return fmt.Errorf("nftables: unable to list chains: %v", err)
}
n.chains[ipv4] = make(map[string]*nftables.Chain)
n.chains[ipv6] = make(map[string]*nftables.Chain)
for _, chain := range chains {
switch {
case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv4:
n.chains[ipv4][chain.Name] = chain
case chain.Table.Name == nftablesTable && chain.Table.Family == nftables.TableFamilyIPv6:
n.chains[ipv6][chain.Name] = chain
}
}
if _, found := n.chains[ipv4][nftablesRoutingForwardingChain]; !found {
n.chains[ipv4][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{
Name: nftablesRoutingForwardingChain,
Table: n.tableIPv4,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityNATDest + 1,
Type: nftables.ChainTypeFilter,
})
}
if _, found := n.chains[ipv4][nftablesRoutingNatChain]; !found {
n.chains[ipv4][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{
Name: nftablesRoutingNatChain,
Table: n.tableIPv4,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1,
Type: nftables.ChainTypeNAT,
})
}
if _, found := n.chains[ipv6][nftablesRoutingForwardingChain]; !found {
n.chains[ipv6][nftablesRoutingForwardingChain] = n.conn.AddChain(&nftables.Chain{
Name: nftablesRoutingForwardingChain,
Table: n.tableIPv6,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityNATDest + 1,
Type: nftables.ChainTypeFilter,
})
}
if _, found := n.chains[ipv6][nftablesRoutingNatChain]; !found {
n.chains[ipv6][nftablesRoutingNatChain] = n.conn.AddChain(&nftables.Chain{
Name: nftablesRoutingNatChain,
Table: n.tableIPv6,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1,
Type: nftables.ChainTypeNAT,
})
}
err = n.refreshRulesMap()
if err != nil {
return err
}
n.checkOrCreateDefaultForwardingRules()
err = n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (n *nftablesManager) refreshRulesMap() error {
for _, registeredChains := range n.chains {
for _, chain := range registeredChains {
rules, err := n.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
n.rules[string(rule.UserData)] = rule
}
}
}
}
return nil
}
func (n *nftablesManager) eraseDefaultForwardRule() error {
if n.defaultForwardRules[0] == nil {
return nil
}
err := n.refreshDefaultForwardRule()
if err != nil {
return err
}
for i, r := range n.defaultForwardRules {
err = n.conn.DelRule(r)
if err != nil {
log.Errorf("failed to delete forward rule (%d): %s", i, err)
}
n.defaultForwardRules[i] = nil
}
return nil
}
func (n *nftablesManager) refreshDefaultForwardRule() error {
rules, err := n.conn.GetRules(n.defaultForwardRules[0].Table, n.defaultForwardRules[0].Chain)
if err != nil {
return fmt.Errorf("unable to list rules in forward chain: %s", err)
}
found := false
for i, r := range n.defaultForwardRules {
for _, rule := range rules {
if string(rule.UserData) == string(r.UserData) {
n.defaultForwardRules[i] = rule
found = true
break
}
}
}
if !found {
return fmt.Errorf("unable to find forward accept rule")
}
return nil
}
func (n *nftablesManager) acceptForwardRule(sourceNetwork string) error {
src := generateCIDRMatcherExpressions("source", sourceNetwork)
dst := generateCIDRMatcherExpressions("destination", "0.0.0.0/0")
var exprs []expr.Any
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
Kind: expr.VerdictAccept,
})...)
r := &nftables.Rule{
Table: n.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: n.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleSrc),
}
n.defaultForwardRules[0] = n.conn.AddRule(r)
src = generateCIDRMatcherExpressions("source", "0.0.0.0/0")
dst = generateCIDRMatcherExpressions("destination", sourceNetwork)
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
Kind: expr.VerdictAccept,
})...)
r = &nftables.Rule{
Table: n.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: n.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleDst),
}
n.defaultForwardRules[1] = n.conn.AddRule(r)
return nil
}
// checkOrCreateDefaultForwardingRules checks if the default forwarding rules are enabled
func (n *nftablesManager) checkOrCreateDefaultForwardingRules() {
_, foundIPv4 := n.rules[ipv4Forwarding]
if !foundIPv4 {
n.rules[ipv4Forwarding] = n.conn.AddRule(&nftables.Rule{
Table: n.tableIPv4,
Chain: n.chains[ipv4][nftablesRoutingForwardingChain],
Exprs: exprAllowRelatedEstablished,
UserData: []byte(ipv4Forwarding),
})
}
_, foundIPv6 := n.rules[ipv6Forwarding]
if !foundIPv6 {
n.rules[ipv6Forwarding] = n.conn.AddRule(&nftables.Rule{
Table: n.tableIPv6,
Chain: n.chains[ipv6][nftablesRoutingForwardingChain],
Exprs: exprAllowRelatedEstablished,
UserData: []byte(ipv6Forwarding),
})
}
}
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
n.mux.Lock()
defer n.mux.Unlock()
err := n.refreshRulesMap()
if err != nil {
return err
}
err = n.insertRoutingRule(forwardingFormat, nftablesRoutingForwardingChain, pair, false)
if err != nil {
return err
}
err = n.insertRoutingRule(inForwardingFormat, nftablesRoutingForwardingChain, getInPair(pair), false)
if err != nil {
return err
}
if pair.masquerade {
err = n.insertRoutingRule(natFormat, nftablesRoutingNatChain, pair, true)
if err != nil {
return err
}
err = n.insertRoutingRule(inNatFormat, nftablesRoutingNatChain, getInPair(pair), true)
if err != nil {
return err
}
}
if n.defaultForwardRules[0] == nil && n.filterTable != nil {
err = n.acceptForwardRule(pair.source)
if err != nil {
log.Errorf("unable to create default forward rule: %s", err)
}
log.Debugf("default accept forward rule added")
}
err = n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
}
return nil
}
// insertRoutingRule inserts a nftable rule to the conn client flush queue
func (n *nftablesManager) insertRoutingRule(format, chain string, pair routerPair, isNat bool) error {
prefix := netip.MustParsePrefix(pair.source)
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
var expression []expr.Any
if isNat {
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
} else {
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
}
ruleKey := genKey(format, pair.ID)
_, exists := n.rules[ruleKey]
if exists {
err := n.removeRoutingRule(format, pair)
if err != nil {
return err
}
}
if prefix.Addr().Unmap().Is4() {
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv4,
Chain: n.chains[ipv4][chain],
Exprs: expression,
UserData: []byte(ruleKey),
})
} else {
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv6,
Chain: n.chains[ipv6][chain],
Exprs: expression,
UserData: []byte(ruleKey),
})
}
return nil
}
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
n.mux.Lock()
defer n.mux.Unlock()
err := n.refreshRulesMap()
if err != nil {
return err
}
err = n.removeRoutingRule(forwardingFormat, pair)
if err != nil {
return err
}
err = n.removeRoutingRule(inForwardingFormat, getInPair(pair))
if err != nil {
return err
}
err = n.removeRoutingRule(natFormat, pair)
if err != nil {
return err
}
err = n.removeRoutingRule(inNatFormat, getInPair(pair))
if err != nil {
return err
}
if len(n.rules) == 2 && n.defaultForwardRules[0] != nil {
err := n.eraseDefaultForwardRule()
if err != nil {
log.Errorf("failed to delete default fwd rule: %s", err)
}
}
err = n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
}
log.Debugf("nftables: removed rules for %s", pair.destination)
return nil
}
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) error {
ruleKey := genKey(format, pair.ID)
rule, found := n.rules[ruleKey]
if found {
ruleType := "forwarding"
if rule.Chain.Type == nftables.ChainTypeNAT {
ruleType = "nat"
}
err := n.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.destination, err)
}
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.destination)
delete(n.rules, ruleKey)
}
return nil
}
// getPayloadDirectives get expression directives based on ip version and direction
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
switch {
case direction == exprDirectionSource && isIPv4:
return ipv4SrcOffset, ipv4Len, zeroXor
case direction == exprDirectionDestination && isIPv4:
return ipv4DestOffset, ipv4Len, zeroXor
case direction == exprDirectionSource && isIPv6:
return ipv6SrcOffset, ipv6Len, zeroXor6
case direction == exprDirectionDestination && isIPv6:
return ipv6DestOffset, ipv6Len, zeroXor6
default:
panic("no matched payload directive")
}
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(direction string, cidr string) []expr.Any {
ip, network, _ := net.ParseCIDR(cidr)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
offSet, packetLen, zeroXor := getPayloadDirectives(direction, add.Is4(), add.Is6())
return []expr.Any{
// fetch src add
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offSet,
Len: packetLen,
},
// net mask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: packetLen,
Mask: network.Mask,
Xor: zeroXor,
},
// net address
&expr.Cmp{
Register: 1,
Data: add.AsSlice(),
},
}
}

View File

@@ -1,324 +0,0 @@
//go:build !android
package routemanager
import (
"context"
"testing"
"github.com/google/nftables"
"github.com/google/nftables/expr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/checkfw"
)
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
if checkfw.Check() != checkfw.NFTABLES {
t.Skip("nftables not supported on this OS")
}
manager := newNFTablesManager(context.TODO())
nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4")
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6")
require.Len(t, manager.rules, 2, "should have created rules for ipv4 and ipv6")
pair := routerPair{
ID: "abc",
source: "100.100.100.1/32",
destination: "100.100.100.0/24",
masquerade: true,
}
sourceExp := generateCIDRMatcherExpressions("source", pair.source)
destExp := generateCIDRMatcherExpressions("destination", pair.destination)
forward4Exp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
forward4RuleKey := genKey(forwardingFormat, pair.ID)
inserted4Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.tableIPv4,
Chain: manager.chains[ipv4][nftablesRoutingForwardingChain],
Exprs: forward4Exp,
UserData: []byte(forward4RuleKey),
})
nat4Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
nat4RuleKey := genKey(natFormat, pair.ID)
inserted4Nat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.tableIPv4,
Chain: manager.chains[ipv4][nftablesRoutingNatChain],
Exprs: nat4Exp,
UserData: []byte(nat4RuleKey),
})
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
pair = routerPair{
ID: "xyz",
source: "fc00::1/128",
destination: "fc11::/64",
masquerade: true,
}
sourceExp = generateCIDRMatcherExpressions("source", pair.source)
destExp = generateCIDRMatcherExpressions("destination", pair.destination)
forward6Exp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
forward6RuleKey := genKey(forwardingFormat, pair.ID)
inserted6Forwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.tableIPv6,
Chain: manager.chains[ipv6][nftablesRoutingForwardingChain],
Exprs: forward6Exp,
UserData: []byte(forward6RuleKey),
})
nat6Exp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
nat6RuleKey := genKey(natFormat, pair.ID)
inserted6Nat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.tableIPv6,
Chain: manager.chains[ipv6][nftablesRoutingNatChain],
Exprs: nat6Exp,
UserData: []byte(nat6RuleKey),
})
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
manager.tableIPv4 = nil
manager.tableIPv6 = nil
err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv4")
require.Len(t, manager.chains[ipv4], 2, "should have created chains for ipv6")
require.Len(t, manager.rules, 6, "should have restored all rules for ipv4 and ipv6")
foundRule, found := manager.rules[forward4RuleKey]
require.True(t, found, "forwarding rule should exist in the map")
assert.Equal(t, inserted4Forwarding.Exprs, foundRule.Exprs, "stored forwarding rule expressions should match")
foundRule, found = manager.rules[nat4RuleKey]
require.True(t, found, "nat rule should exist in the map")
// match len of output as nftables client doesn't return expressions with masquerade expression
assert.ElementsMatch(t, inserted4Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule expressions should match")
foundRule, found = manager.rules[forward6RuleKey]
require.True(t, found, "forwarding rule should exist in the map")
assert.Equal(t, inserted6Forwarding.Exprs, foundRule.Exprs, "stored forward rule should match")
foundRule, found = manager.rules[nat6RuleKey]
require.True(t, found, "nat rule should exist in the map")
// match len of output as nftables client doesn't return expressions with masquerade expression
assert.ElementsMatch(t, inserted6Nat.Exprs[:len(foundRule.Exprs)], foundRule.Exprs, "stored nat rule should match")
}
func TestNftablesManager_InsertRoutingRules(t *testing.T) {
if checkfw.Check() != checkfw.NFTABLES {
t.Skip("nftables not supported on this OS")
}
for _, testCase := range insertRuleTestCases {
t.Run(testCase.name, func(t *testing.T) {
manager := newNFTablesManager(context.TODO())
nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
err = manager.InsertRoutingRules(testCase.inputPair)
require.NoError(t, err, "forwarding pair should be inserted")
sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source)
destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
fwdRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
found := 0
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.inputPair.masquerade {
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
found := 0
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
testingExpression = append(sourceExp, destExp...) //nolint:gocritic
inFwdRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
found = 0
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.inputPair.masquerade {
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
found := 0
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
})
}
}
func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
if checkfw.Check() != checkfw.NFTABLES {
t.Skip("nftables not supported on this OS")
}
for _, testCase := range removeRuleTestCases {
t.Run(testCase.name, func(t *testing.T) {
manager := newNFTablesManager(context.TODO())
nftablesTestingClient := &nftables.Conn{}
defer manager.CleanRoutingRules()
err := manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
table := manager.tableIPv4
if testCase.ipVersion == ipv6 {
table = manager.tableIPv6
}
sourceExp := generateCIDRMatcherExpressions("source", testCase.inputPair.source)
destExp := generateCIDRMatcherExpressions("destination", testCase.inputPair.destination)
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
forwardRuleKey := genKey(forwardingFormat, testCase.inputPair.ID)
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: table,
Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(forwardRuleKey),
})
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
natRuleKey := genKey(natFormat, testCase.inputPair.ID)
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: table,
Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain],
Exprs: natExp,
UserData: []byte(natRuleKey),
})
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: table,
Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(inForwardRuleKey),
})
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: table,
Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain],
Exprs: natExp,
UserData: []byte(inNatRuleKey),
})
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
manager.tableIPv4 = nil
manager.tableIPv6 = nil
err = manager.RestoreOrCreateContainers()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveRoutingRules(testCase.inputPair)
require.NoError(t, err, "shouldn't return error")
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 {
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
}
}
}
}
})
}
}

View File

@@ -1,24 +0,0 @@
package routemanager
import (
"net/netip"
"github.com/netbirdio/netbird/route"
)
type routerPair struct {
ID string
source string
destination string
masquerade bool
}
func routeToRouterPair(source string, route *route.Route) routerPair {
parsed := netip.MustParsePrefix(source).Masked()
return routerPair{
ID: route.ID,
source: parsed.String(),
destination: route.Network.Masked().String(),
masquerade: route.Masquerade,
}
}

View File

@@ -4,9 +4,10 @@ import (
"context" "context"
"fmt" "fmt"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
) )
func newServerRouter(context.Context, *iface.WGIface) (serverRouter, error) { func newServerRouter(context.Context, *iface.WGIface, firewall.Manager) (serverRouter, error) {
return nil, fmt.Errorf("server route not supported on this os") return nil, fmt.Errorf("server route not supported on this os")
} }

View File

@@ -4,11 +4,12 @@ package routemanager
import ( import (
"context" "context"
"fmt" "net/netip"
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -17,16 +18,11 @@ type defaultServerRouter struct {
mux sync.Mutex mux sync.Mutex
ctx context.Context ctx context.Context
routes map[string]*route.Route routes map[string]*route.Route
firewall firewallManager firewall firewall.Manager
wgInterface *iface.WGIface wgInterface *iface.WGIface
} }
func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRouter, error) { func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager) (serverRouter, error) {
firewall, err := newFirewall(ctx)
if err != nil {
return nil, err
}
return &defaultServerRouter{ return &defaultServerRouter{
ctx: ctx, ctx: ctx,
routes: make(map[string]*route.Route), routes: make(map[string]*route.Route),
@@ -38,13 +34,6 @@ func newServerRouter(ctx context.Context, wgInterface *iface.WGIface) (serverRou
func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) error { func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) error {
serverRoutesToRemove := make([]string, 0) serverRoutesToRemove := make([]string, 0)
if len(routesMap) > 0 {
err := m.firewall.RestoreOrCreateContainers()
if err != nil {
return fmt.Errorf("couldn't initialize firewall containers, got err: %v", err)
}
}
for routeID := range m.routes { for routeID := range m.routes {
update, found := routesMap[routeID] update, found := routesMap[routeID]
if !found || !update.IsEqual(m.routes[routeID]) { if !found || !update.IsEqual(m.routes[routeID]) {
@@ -121,5 +110,22 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
} }
func (m *defaultServerRouter) cleanUp() { func (m *defaultServerRouter) cleanUp() {
m.firewall.CleanRoutingRules() m.mux.Lock()
defer m.mux.Unlock()
for _, r := range m.routes {
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), r))
if err != nil {
log.Warnf("failed to remove clean up route: %s", r.ID)
}
}
}
func routeToRouterPair(source string, route *route.Route) firewall.RouterPair {
parsed := netip.MustParsePrefix(source).Masked()
return firewall.RouterPair{
ID: route.ID,
Source: parsed.String(),
Destination: route.Network.Masked().String(),
Masquerade: route.Masquerade,
}
} }

View File

@@ -1,3 +1,5 @@
//go:build !android
package routemanager package routemanager
import ( import (

6
go.mod
View File

@@ -30,6 +30,7 @@ require (
require ( require (
fyne.io/fyne/v2 v2.1.4 fyne.io/fyne/v2 v2.1.4
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5
github.com/c-robinson/iplib v1.0.3 github.com/c-robinson/iplib v1.0.3
github.com/cilium/ebpf v0.10.0 github.com/cilium/ebpf v0.10.0
github.com/coreos/go-iptables v0.7.0 github.com/coreos/go-iptables v0.7.0
@@ -51,7 +52,8 @@ require (
github.com/miekg/dns v1.1.43 github.com/miekg/dns v1.1.43
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20231027143200-a966bce7db88 github.com/netbirdio/management-integrations/additions v0.0.0-20231205113053-c462587ae695
github.com/netbirdio/management-integrations/integrations v0.0.0-20231205113053-c462587ae695
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pion/logging v0.2.2 github.com/pion/logging v0.2.2
@@ -108,6 +110,7 @@ require (
github.com/go-stack/stack v1.8.0 // indirect github.com/go-stack/stack v1.8.0 // indirect
github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/btree v1.0.1 // indirect
github.com/google/s2a-go v0.1.4 // indirect github.com/google/s2a-go v0.1.4 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect
github.com/googleapis/gax-go/v2 v2.10.0 // indirect github.com/googleapis/gax-go/v2 v2.10.0 // indirect
@@ -153,6 +156,7 @@ require (
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
gvisor.dev/gvisor v0.0.0-20221203005347-703fd9b7fbc0 // indirect
honnef.co/go/tools v0.2.2 // indirect honnef.co/go/tools v0.2.2 // indirect
k8s.io/apimachinery v0.23.5 // indirect k8s.io/apimachinery v0.23.5 // indirect
) )

8
go.sum
View File

@@ -76,6 +76,8 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
github.com/bazelbuild/rules_go v0.30.0/go.mod h1:MC23Dc/wkXEyk3Wpq6lCqz0ZAYOZDw2DR5y3N1q2i7M= github.com/bazelbuild/rules_go v0.30.0/go.mod h1:MC23Dc/wkXEyk3Wpq6lCqz0ZAYOZDw2DR5y3N1q2i7M=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
@@ -495,8 +497,10 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRW
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc= github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc=
github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
github.com/netbirdio/management-integrations/integrations v0.0.0-20231027143200-a966bce7db88 h1:zhe8qseauBuYOS910jpl5sv8Tb+36zxQPXrwYXqll0g= github.com/netbirdio/management-integrations/additions v0.0.0-20231205113053-c462587ae695 h1:c/Rvyy/mqbFoKo6FS8ihQ3/3y+TAl0qDEH0pO2tXayM=
github.com/netbirdio/management-integrations/integrations v0.0.0-20231027143200-a966bce7db88/go.mod h1:KSqjzHcqlodTWiuap5lRXxt5KT3vtYRoksL0KIrTK40= github.com/netbirdio/management-integrations/additions v0.0.0-20231205113053-c462587ae695/go.mod h1:31FhBNvQ+riHEIu6LSTmqr8IeuSIsGfQffqV4LFmbwA=
github.com/netbirdio/management-integrations/integrations v0.0.0-20231205113053-c462587ae695 h1:9HRnqSosRuKyOZgVN/hJW3DG2zVyt5AARmiQlSuDPIc=
github.com/netbirdio/management-integrations/integrations v0.0.0-20231205113053-c462587ae695/go.mod h1:B0nMS3es77gOvPYhc0K91fAzTkQLi/jRq5TffUN3klM=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM= github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM=

View File

@@ -20,10 +20,10 @@ func NewWGIFace(iFaceName string, address string, mtu int, tunAdapter TunAdapter
return wgIFace, err return wgIFace, err
} }
wgIFace.tun = newTunDevice(iFaceName, wgAddress, mtu, transportNet) tun := newTunDevice(iFaceName, wgAddress, mtu, transportNet)
wgIFace.tun = tun
wgIFace.configurer = newWGConfigurer(iFaceName) wgIFace.configurer = newWGConfigurer(tun)
wgIFace.userspaceBind = !WireGuardModuleIsLoaded() wgIFace.userspaceBind = true
return wgIFace, nil return wgIFace, nil
} }

View File

@@ -3,7 +3,6 @@
package iface package iface
import ( import (
"fmt"
"os" "os"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -11,14 +10,6 @@ import (
) )
func (c *tunDevice) Create() error { func (c *tunDevice) Create() error {
if WireGuardModuleIsLoaded() {
log.Infof("create tun interface with kernel WireGuard support: %s", c.DeviceName())
return c.createWithKernel()
}
if !tunModuleIsLoaded() {
return fmt.Errorf("couldn't check or load tun module")
}
log.Infof("create tun interface with userspace WireGuard support: %s", c.DeviceName()) log.Infof("create tun interface with userspace WireGuard support: %s", c.DeviceName())
var err error var err error
c.netInterface, err = c.createWithUserspace() c.netInterface, err = c.createWithUserspace()
@@ -27,7 +18,6 @@ func (c *tunDevice) Create() error {
} }
return c.assignAddr() return c.assignAddr()
} }
// createWithKernel Creates a new WireGuard interface using kernel WireGuard module. // createWithKernel Creates a new WireGuard interface using kernel WireGuard module.
@@ -90,34 +80,7 @@ func (c *tunDevice) createWithKernel() error {
// assignAddr Adds IP address to the tunnel interface // assignAddr Adds IP address to the tunnel interface
func (c *tunDevice) assignAddr() error { func (c *tunDevice) assignAddr() error {
link := newWGLink(c.name) return nil
//delete existing addresses
list, err := netlink.AddrList(link, 0)
if err != nil {
return err
}
if len(list) > 0 {
for _, a := range list {
addr := a
err = netlink.AddrDel(link, &addr)
if err != nil {
return err
}
}
}
log.Debugf("adding address %s to interface: %s", c.address.String(), c.name)
addr, _ := netlink.ParseAddr(c.address.String())
err = netlink.AddrAdd(link, addr)
if os.IsExist(err) {
log.Infof("interface %s already has the address: %s", c.name, c.address.String())
} else if err != nil {
return err
}
// On linux, the link must be brought up
err = netlink.LinkSetUp(link)
return err
} }
type wgLink struct { type wgLink struct {

View File

@@ -10,10 +10,9 @@ import (
"golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/ipc"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/iface/uspproxy"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
) )
type tunDevice struct { type tunDevice struct {
@@ -25,6 +24,7 @@ type tunDevice struct {
uapi net.Listener uapi net.Listener
wrapper *DeviceWrapper wrapper *DeviceWrapper
close chan struct{} close chan struct{}
tunDevice *device.Device
} }
func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice { func newTunDevice(name string, address WGAddress, mtu int, transportNet transport.Net) *tunDevice {
@@ -46,6 +46,11 @@ func (c *tunDevice) WgAddress() WGAddress {
return c.address return c.address
} }
func (t *tunDevice) Device() *device.Device {
// todo: potential nil pointer exception
return t.tunDevice
}
func (c *tunDevice) DeviceName() string { func (c *tunDevice) DeviceName() string {
return c.name return c.name
} }
@@ -85,15 +90,14 @@ func (c *tunDevice) Close() error {
return err3 return err3
} }
// createWithUserspace Creates a new Wireguard interface, using wireguard-go userspace implementation
func (c *tunDevice) createWithUserspace() (NetInterface, error) { func (c *tunDevice) createWithUserspace() (NetInterface, error) {
tunIface, err := tun.CreateTUN(c.name, c.mtu) nsTun := uspproxy.NewNetStackTun(c.address.IP.String())
tunIface, err := nsTun.Create()
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.wrapper = newDeviceWrapper(tunIface) c.wrapper = newDeviceWrapper(tunIface)
// We need to create a wireguard-go device and listen to configuration requests
tunDev := device.NewDevice( tunDev := device.NewDevice(
c.wrapper, c.wrapper,
c.iceBind, c.iceBind,
@@ -105,33 +109,7 @@ func (c *tunDevice) createWithUserspace() (NetInterface, error) {
return nil, err return nil, err
} }
c.uapi, err = c.getUAPI(c.name) c.tunDevice = tunDev
if err != nil {
_ = tunIface.Close()
return nil, err
}
go func() {
for {
select {
case <-c.close:
log.Debugf("exit uapi.Accept()")
return
default:
}
uapiConn, uapiErr := c.uapi.Accept()
if uapiErr != nil {
log.Traceln("uapi Accept failed with error: ", uapiErr)
continue
}
go func() {
tunDev.IpcHandle(uapiConn)
log.Debugf("exit tunDevice.IpcHandle")
}()
}
}()
log.Debugln("UAPI listener started")
return tunIface, nil return tunIface, nil
} }

43
iface/uspproxy/proxy.go Normal file
View File

@@ -0,0 +1,43 @@
package uspproxy
import (
"context"
"net"
"github.com/armon/go-socks5"
log "github.com/sirupsen/logrus"
)
type Dialer interface {
Dial(ctx context.Context, network, addr string) (net.Conn, error)
}
// Proxy todo close server
type Proxy struct {
server *socks5.Server
}
func NewSocks5(dialer Dialer) (*Proxy, error) {
conf := &socks5.Config{
Dial: dialer.Dial,
}
server, err := socks5.New(conf)
if err != nil {
log.Debugf("failed to init socks5 proxy: %s", err)
return nil, err
}
return &Proxy{
server: server,
}, nil
}
func (s *Proxy) ListenAndServe(addr string) error {
go func() {
err := s.server.ListenAndServe("tcp", addr)
if err != nil {
log.Debugf("failed to start socks5 proxy: %s", err)
}
}()
return nil
}

55
iface/uspproxy/tun.go Normal file
View File

@@ -0,0 +1,55 @@
package uspproxy
import (
"context"
"net"
"net/netip"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
)
type NSDialer struct {
net *netstack.Net
}
func (d *NSDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err := d.net.Dial(network, addr)
if err != nil {
log.Debugf("failed to deal connection: %s", err)
}
return conn, err
}
type NetStackTun struct {
address string
proxy *Proxy
}
func NewNetStackTun(address string) *NetStackTun {
return &NetStackTun{
address: address,
}
}
func (t *NetStackTun) Create() (tun.Device, error) {
nsTunDev, tunNet, err := netstack.CreateNetTUN(
[]netip.Addr{netip.MustParseAddr(t.address)},
[]netip.Addr{netip.MustParseAddr("8.8.8.8")},
1420)
if err != nil {
return nil, err
}
dialer := &NSDialer{tunNet}
t.proxy, err = NewSocks5(dialer)
if err != nil {
// close nsTunDev
return nil, err
}
err = t.proxy.ListenAndServe("0.0.0.0:1234")
return nsTunDev, nil
}

View File

@@ -1,205 +0,0 @@
//go:build !android
package iface
import (
"fmt"
"net"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type wGConfigurer struct {
deviceName string
}
func newWGConfigurer(deviceName string) wGConfigurer {
return wGConfigurer{
deviceName: deviceName,
}
}
func (c *wGConfigurer) configureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)
if err != nil {
return err
}
fwmark := 0
config := wgtypes.Config{
PrivateKey: &key,
ReplacePeers: true,
FirewallMark: &fwmark,
ListenPort: &port,
}
err = c.configure(config)
if err != nil {
return fmt.Errorf(`received error "%w" while configuring interface %s with port %d`, err, c.deviceName, port)
}
return nil
}
func (c *wGConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
//parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,
Endpoint: endpoint,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = c.configure(config)
if err != nil {
return fmt.Errorf(`received error "%w" while updating peer on interface %s with settings: allowed ips %s, endpoint %s`, err, c.deviceName, allowedIps, endpoint.String())
}
return nil
}
func (c *wGConfigurer) removePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = c.configure(config)
if err != nil {
return fmt.Errorf(`received error "%w" while removing peer %s from interface %s`, err, peerKey, c.deviceName)
}
return nil
}
func (c *wGConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{*ipNet},
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = c.configure(config)
if err != nil {
return fmt.Errorf(`received error "%w" while adding allowed Ip to peer on interface %s with settings: allowed ips %s`, err, c.deviceName, allowedIP)
}
return nil
}
func (c *wGConfigurer) removeAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
existingPeer, err := c.getPeer(c.deviceName, peerKey)
if err != nil {
return err
}
newAllowedIPs := existingPeer.AllowedIPs
for i, existingAllowedIP := range existingPeer.AllowedIPs {
if existingAllowedIP.String() == ipNet.String() {
newAllowedIPs = append(existingPeer.AllowedIPs[:i], existingPeer.AllowedIPs[i+1:]...) //nolint:gocritic
break
}
}
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
UpdateOnly: true,
ReplaceAllowedIPs: true,
AllowedIPs: newAllowedIPs,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
err = c.configure(config)
if err != nil {
return fmt.Errorf(`received error "%w" while removing allowed IP from peer on interface %s with settings: allowed ips %s`, err, c.deviceName, allowedIP)
}
return nil
}
func (c *wGConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
wg, err := wgctrl.New()
if err != nil {
return wgtypes.Peer{}, err
}
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(ifaceName)
if err != nil {
return wgtypes.Peer{}, err
}
for _, peer := range wgDevice.Peers {
if peer.PublicKey.String() == peerPubKey {
return peer, nil
}
}
return wgtypes.Peer{}, fmt.Errorf("peer not found")
}
func (c *wGConfigurer) configure(config wgtypes.Config) error {
wg, err := wgctrl.New()
if err != nil {
return err
}
defer wg.Close()
// validate if device with name exists
_, err = wg.Device(c.deviceName)
if err != nil {
return err
}
log.Tracef("got Wireguard device %s", c.deviceName)
return wg.ConfigureDevice(c.deviceName, config)
}

View File

@@ -83,7 +83,10 @@ var (
if err != nil { if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err) return fmt.Errorf("failed reading provided config file: %s: %v", mgmtConfig, err)
} }
config.HttpConfig.IdpSignKeyRefreshEnabled = idpSignKeyRefreshEnabled
if cmd.Flag(idpSignKeyRefreshEnabledFlagName).Changed {
config.HttpConfig.IdpSignKeyRefreshEnabled = idpSignKeyRefreshEnabled
}
tlsEnabled := false tlsEnabled := false
if mgmtLetsencryptDomain != "" || (config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "") { if mgmtLetsencryptDomain != "" || (config.HttpConfig.CertFile != "" && config.HttpConfig.CertKey != "") {

View File

@@ -12,7 +12,8 @@ import (
const ( const (
// ExitSetupFailed defines exit code // ExitSetupFailed defines exit code
ExitSetupFailed = 1 ExitSetupFailed = 1
idpSignKeyRefreshEnabledFlagName = "idp-sign-key-refresh-enabled"
) )
var ( var (
@@ -62,7 +63,7 @@ func init() {
mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird") mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max length is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain)) mgmtCmd.Flags().StringVar(&dnsDomain, "dns-domain", defaultSingleAccModeDomain, fmt.Sprintf("Domain used for peer resolution. This is appended to the peer's name, e.g. pi-server. %s. Max length is 192 characters to allow appending to a peer name with up to 63 characters.", defaultSingleAccModeDomain))
mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, "idp-sign-key-refresh-enabled", false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.") mgmtCmd.Flags().BoolVar(&idpSignKeyRefreshEnabled, idpSignKeyRefreshEnabledFlagName, false, "Enable cache headers evaluation to determine signing key rotation period. This will refresh the signing key upon expiry.")
mgmtCmd.Flags().BoolVar(&userDeleteFromIDPEnabled, "user-delete-from-idp", false, "Allows to delete user from IDP when user is deleted from account") mgmtCmd.Flags().BoolVar(&userDeleteFromIDPEnabled, "user-delete-from-idp", false, "Allows to delete user from IDP when user is deleted from account")
rootCmd.MarkFlagRequired("config") //nolint rootCmd.MarkFlagRequired("config") //nolint

View File

@@ -17,15 +17,18 @@ import (
"github.com/eko/gocache/v3/cache" "github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store" cacheStore "github.com/eko/gocache/v3/store"
"github.com/netbirdio/management-integrations/additions"
gocache "github.com/patrickmn/go-cache" gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/base62" "github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -68,13 +71,13 @@ type AccountManager interface {
MarkPATUsed(tokenID string) error MarkPATUsed(tokenID string) error
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(accountID string) ([]*User, error) ListUsers(accountID string) ([]*User, error)
GetPeers(accountID, userID string) ([]*Peer, error) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnected(peerKey string, connected bool) error MarkPeerConnected(peerKey string, connected bool) error
DeletePeer(accountID, peerID, userID string) error DeletePeer(accountID, peerID, userID string) error
UpdatePeer(accountID, userID string, peer *Peer) (*Peer, error) UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
GetNetworkMap(peerID string) (*NetworkMap, error) GetNetworkMap(peerID string) (*NetworkMap, error)
GetPeerNetwork(peerID string) (*Network, error) GetPeerNetwork(peerID string) (*Network, error)
AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error) AddPeer(setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, error)
CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error)
DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
@@ -106,11 +109,12 @@ type AccountManager interface {
GetEvents(accountID, userID string) ([]*activity.Event, error) GetEvents(accountID, userID string) ([]*activity.Event, error)
GetDNSSettings(accountID string, userID string) (*DNSSettings, error) GetDNSSettings(accountID string, userID string) (*DNSSettings, error)
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
GetPeer(accountID, peerID, userID string) (*Peer, error) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error) UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
LoginPeer(login PeerLogin) (*Peer, *NetworkMap, error) // used by peer gRPC API LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error) // used by peer gRPC API SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error) GetAllConnectedPeers() (map[string]struct{}, error)
HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager GetExternalCacheManager() ExternalCacheManager
} }
@@ -159,17 +163,28 @@ type Settings struct {
// JWTGroupsClaimName from which we extract groups name to add it to account groups // JWTGroupsClaimName from which we extract groups name to add it to account groups
JWTGroupsClaimName string JWTGroupsClaimName string
// JWTAllowGroups list of groups to which users are allowed access
JWTAllowGroups []string `gorm:"serializer:json"`
// Extra is a dictionary of Account settings
Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"`
} }
// Copy copies the Settings struct // Copy copies the Settings struct
func (s *Settings) Copy() *Settings { func (s *Settings) Copy() *Settings {
return &Settings{ settings := &Settings{
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled, PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
PeerLoginExpiration: s.PeerLoginExpiration, PeerLoginExpiration: s.PeerLoginExpiration,
JWTGroupsEnabled: s.JWTGroupsEnabled, JWTGroupsEnabled: s.JWTGroupsEnabled,
JWTGroupsClaimName: s.JWTGroupsClaimName, JWTGroupsClaimName: s.JWTGroupsClaimName,
GroupsPropagationEnabled: s.GroupsPropagationEnabled, GroupsPropagationEnabled: s.GroupsPropagationEnabled,
JWTAllowGroups: s.JWTAllowGroups,
} }
if s.Extra != nil {
settings.Extra = s.Extra.Copy()
}
return settings
} }
// Account represents a unique account of the system // Account represents a unique account of the system
@@ -185,8 +200,8 @@ type Account struct {
SetupKeys map[string]*SetupKey `gorm:"-"` SetupKeys map[string]*SetupKey `gorm:"-"`
SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"` SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"`
Network *Network `gorm:"embedded;embeddedPrefix:network_"` Network *Network `gorm:"embedded;embeddedPrefix:network_"`
Peers map[string]*Peer `gorm:"-"` Peers map[string]*nbpeer.Peer `gorm:"-"`
PeersG []Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"`
Users map[string]*User `gorm:"-"` Users map[string]*User `gorm:"-"`
UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"`
Groups map[string]*Group `gorm:"-"` Groups map[string]*Group `gorm:"-"`
@@ -221,7 +236,7 @@ type UserInfo struct {
// getRoutesToSync returns the enabled routes for the peer ID and the routes // getRoutesToSync returns the enabled routes for the peer ID and the routes
// from the ACL peers that have distribution groups associated with the peer ID. // from the ACL peers that have distribution groups associated with the peer ID.
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
func (a *Account) getRoutesToSync(peerID string, aclPeers []*Peer) []*route.Route { func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*route.Route {
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID) routes, peerDisabledRoutes := a.getRoutingPeerRoutes(peerID)
peerRoutesMembership := make(lookupMap) peerRoutesMembership := make(lookupMap)
for _, r := range append(routes, peerDisabledRoutes...) { for _, r := range append(routes, peerDisabledRoutes...) {
@@ -345,10 +360,22 @@ func (a *Account) GetGroup(groupID string) *Group {
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise // GetPeerNetworkMap returns a group by ID if exists, nil otherwise
func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap { func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap {
peer := a.Peers[peerID]
if peer == nil {
return &NetworkMap{
Network: a.Network.Copy(),
}
}
validatedPeers := additions.ValidatePeers([]*nbpeer.Peer{peer})
if len(validatedPeers) == 0 {
return &NetworkMap{
Network: a.Network.Copy(),
}
}
aclPeers, firewallRules := a.getPeerConnectionResources(peerID) aclPeers, firewallRules := a.getPeerConnectionResources(peerID)
// exclude expired peers // exclude expired peers
var peersToConnect []*Peer var peersToConnect []*nbpeer.Peer
var expiredPeers []*Peer var expiredPeers []*nbpeer.Peer
for _, p := range aclPeers { for _, p := range aclPeers {
expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration) expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
if a.Settings.PeerLoginExpirationEnabled && expired { if a.Settings.PeerLoginExpirationEnabled && expired {
@@ -386,8 +413,8 @@ func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap {
} }
// GetExpiredPeers returns peers that have been expired // GetExpiredPeers returns peers that have been expired
func (a *Account) GetExpiredPeers() []*Peer { func (a *Account) GetExpiredPeers() []*nbpeer.Peer {
var peers []*Peer var peers []*nbpeer.Peer
for _, peer := range a.GetPeersWithExpiration() { for _, peer := range a.GetPeersWithExpiration() {
expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration) expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration)
if expired { if expired {
@@ -426,8 +453,8 @@ func (a *Account) GetNextPeerExpiration() (time.Duration, bool) {
} }
// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user // GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user
func (a *Account) GetPeersWithExpiration() []*Peer { func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer {
peers := make([]*Peer, 0) peers := make([]*nbpeer.Peer, 0)
for _, peer := range a.Peers { for _, peer := range a.Peers {
if peer.LoginExpirationEnabled && peer.AddedWithSSOLogin() { if peer.LoginExpirationEnabled && peer.AddedWithSSOLogin() {
peers = append(peers, peer) peers = append(peers, peer)
@@ -437,8 +464,8 @@ func (a *Account) GetPeersWithExpiration() []*Peer {
} }
// GetPeers returns a list of all Account peers // GetPeers returns a list of all Account peers
func (a *Account) GetPeers() []*Peer { func (a *Account) GetPeers() []*nbpeer.Peer {
var peers []*Peer var peers []*nbpeer.Peer
for _, peer := range a.Peers { for _, peer := range a.Peers {
peers = append(peers, peer) peers = append(peers, peer)
} }
@@ -452,7 +479,7 @@ func (a *Account) UpdateSettings(update *Settings) *Account {
} }
// UpdatePeer saves new or replaces existing peer // UpdatePeer saves new or replaces existing peer
func (a *Account) UpdatePeer(update *Peer) { func (a *Account) UpdatePeer(update *nbpeer.Peer) {
a.Peers[update.ID] = update a.Peers[update.ID] = update
} }
@@ -481,7 +508,7 @@ func (a *Account) DeletePeer(peerID string) {
// FindPeerByPubKey looks for a Peer by provided WireGuard public key in the Account or returns error if it wasn't found. // FindPeerByPubKey looks for a Peer by provided WireGuard public key in the Account or returns error if it wasn't found.
// It will return an object copy of the peer. // It will return an object copy of the peer.
func (a *Account) FindPeerByPubKey(peerPubKey string) (*Peer, error) { func (a *Account) FindPeerByPubKey(peerPubKey string) (*nbpeer.Peer, error) {
for _, peer := range a.Peers { for _, peer := range a.Peers {
if peer.Key == peerPubKey { if peer.Key == peerPubKey {
return peer.Copy(), nil return peer.Copy(), nil
@@ -492,8 +519,8 @@ func (a *Account) FindPeerByPubKey(peerPubKey string) (*Peer, error) {
} }
// FindUserPeers returns a list of peers that user owns (created) // FindUserPeers returns a list of peers that user owns (created)
func (a *Account) FindUserPeers(userID string) ([]*Peer, error) { func (a *Account) FindUserPeers(userID string) ([]*nbpeer.Peer, error) {
peers := make([]*Peer, 0) peers := make([]*nbpeer.Peer, 0)
for _, peer := range a.Peers { for _, peer := range a.Peers {
if peer.UserID == userID { if peer.UserID == userID {
peers = append(peers, peer) peers = append(peers, peer)
@@ -585,7 +612,7 @@ func (a *Account) getPeerDNSLabels() lookupMap {
} }
func (a *Account) Copy() *Account { func (a *Account) Copy() *Account {
peers := map[string]*Peer{} peers := map[string]*nbpeer.Peer{}
for id, peer := range a.Peers { for id, peer := range a.Peers {
peers[id] = peer.Copy() peers[id] = peer.Copy()
} }
@@ -662,7 +689,7 @@ func (a *Account) GetGroupAll() (*Group, error) {
} }
// GetPeer looks up a Peer by ID // GetPeer looks up a Peer by ID
func (a *Account) GetPeer(peerID string) *Peer { func (a *Account) GetPeer(peerID string) *nbpeer.Peer {
return a.Peers[peerID] return a.Peers[peerID]
} }
@@ -872,6 +899,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
return nil, err return nil, err
} }
err = additions.ValidateExtraSettings(newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID, am.eventStore)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID) user, err := account.FindUser(userID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1649,6 +1681,11 @@ func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, er
return am.peersUpdateManager.GetAllConnectedPeers(), nil return am.peersUpdateManager.GetAllConnectedPeers(), nil
} }
// HasConnectedChannel returns true if peers has channel in update manager, otherwise false
func (am *DefaultAccountManager) HasConnectedChannel(peerID string) bool {
return am.peersUpdateManager.HasChannel(peerID)
}
var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
func isDomainValid(domain string) bool { func isDomainValid(domain string) bool {
@@ -1698,7 +1735,7 @@ func newAccountWithId(accountID, userID, domain string) *Account {
log.Debugf("creating new account") log.Debugf("creating new account")
network := NewNetwork() network := NewNetwork()
peers := make(map[string]*Peer) peers := make(map[string]*nbpeer.Peer)
users := make(map[string]*User) users := make(map[string]*User)
routes := make(map[string]*route.Route) routes := make(map[string]*route.Route)
setupKeys := map[string]*SetupKey{} setupKeys := map[string]*SetupKey{}

View File

@@ -0,0 +1,13 @@
package account
type ExtraSettings struct {
// PeerApprovalEnabled enables or disables the need for peers bo be approved by an administrator
PeerApprovalEnabled bool
}
// Copy copies the ExtraSettings struct
func (e *ExtraSettings) Copy() *ExtraSettings {
return &ExtraSettings{
PeerApprovalEnabled: e.PeerApprovalEnabled,
}
}

View File

@@ -15,6 +15,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -26,10 +27,10 @@ import (
func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) { func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) {
t.Helper() t.Helper()
peer := &Peer{ peer := &nbpeer.Peer{
Key: "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=", Key: "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=",
Name: "test-host@netbird.io", Name: "test-host@netbird.io",
Meta: PeerSystemMeta{ Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host@netbird.io", Hostname: "test-host@netbird.io",
GoOS: "linux", GoOS: "linux",
Kernel: "Linux", Kernel: "Linux",
@@ -110,13 +111,14 @@ func verifyNewAccountHasDefaultFields(t *testing.T, account *Account, createdBy
func TestAccount_GetPeerNetworkMap(t *testing.T) { func TestAccount_GetPeerNetworkMap(t *testing.T) {
peerID1 := "peer-1" peerID1 := "peer-1"
peerID2 := "peer-2" peerID2 := "peer-2"
// peerID3 := "peer-3"
tt := []struct { tt := []struct {
name string name string
accountSettings Settings accountSettings Settings
peerID string peerID string
expectedPeers []string expectedPeers []string
expectedOfflinePeers []string expectedOfflinePeers []string
peers map[string]*Peer peers map[string]*nbpeer.Peer
}{ }{
{ {
name: "Should return ALL peers when global peer login expiration disabled", name: "Should return ALL peers when global peer login expiration disabled",
@@ -124,14 +126,14 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
peerID: peerID1, peerID: peerID1,
expectedPeers: []string{peerID2}, expectedPeers: []string{peerID2},
expectedOfflinePeers: []string{}, expectedOfflinePeers: []string{},
peers: map[string]*Peer{ peers: map[string]*nbpeer.Peer{
"peer-1": { "peer-1": {
ID: peerID1, ID: peerID1,
Key: "peer-1-key", Key: "peer-1-key",
IP: net.IP{100, 64, 0, 1}, IP: net.IP{100, 64, 0, 1},
Name: peerID1, Name: peerID1,
DNSLabel: peerID1, DNSLabel: peerID1,
Status: &PeerStatus{ Status: &nbpeer.PeerStatus{
LastSeen: time.Now().UTC(), LastSeen: time.Now().UTC(),
Connected: false, Connected: false,
LoginExpired: true, LoginExpired: true,
@@ -145,7 +147,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
IP: net.IP{100, 64, 0, 1}, IP: net.IP{100, 64, 0, 1},
Name: peerID2, Name: peerID2,
DNSLabel: peerID2, DNSLabel: peerID2,
Status: &PeerStatus{ Status: &nbpeer.PeerStatus{
LastSeen: time.Now().UTC(), LastSeen: time.Now().UTC(),
Connected: false, Connected: false,
LoginExpired: false, LoginExpired: false,
@@ -162,14 +164,14 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
peerID: peerID1, peerID: peerID1,
expectedPeers: []string{}, expectedPeers: []string{},
expectedOfflinePeers: []string{peerID2}, expectedOfflinePeers: []string{peerID2},
peers: map[string]*Peer{ peers: map[string]*nbpeer.Peer{
"peer-1": { "peer-1": {
ID: peerID1, ID: peerID1,
Key: "peer-1-key", Key: "peer-1-key",
IP: net.IP{100, 64, 0, 1}, IP: net.IP{100, 64, 0, 1},
Name: peerID1, Name: peerID1,
DNSLabel: peerID1, DNSLabel: peerID1,
Status: &PeerStatus{ Status: &nbpeer.PeerStatus{
LastSeen: time.Now().UTC(), LastSeen: time.Now().UTC(),
Connected: false, Connected: false,
LoginExpired: true, LoginExpired: true,
@@ -184,7 +186,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
IP: net.IP{100, 64, 0, 1}, IP: net.IP{100, 64, 0, 1},
Name: peerID2, Name: peerID2,
DNSLabel: peerID2, DNSLabel: peerID2,
Status: &PeerStatus{ Status: &nbpeer.PeerStatus{
LastSeen: time.Now().UTC(), LastSeen: time.Now().UTC(),
Connected: false, Connected: false,
LoginExpired: true, LoginExpired: true,
@@ -195,6 +197,159 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}, },
}, },
}, },
// {
// name: "Should return only peers that are approved when peer approval is enabled",
// accountSettings: Settings{PeerApprovalEnabled: true},
// peerID: peerID1,
// expectedPeers: []string{peerID3},
// expectedOfflinePeers: []string{},
// peers: map[string]*Peer{
// "peer-1": {
// ID: peerID1,
// Key: "peer-1-key",
// IP: net.IP{100, 64, 0, 1},
// Name: peerID1,
// DNSLabel: peerID1,
// Status: &PeerStatus{
// LastSeen: time.Now().UTC(),
// Connected: false,
// Approved: true,
// },
// UserID: userID,
// LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30),
// },
// "peer-2": {
// ID: peerID2,
// Key: "peer-2-key",
// IP: net.IP{100, 64, 0, 1},
// Name: peerID2,
// DNSLabel: peerID2,
// Status: &PeerStatus{
// LastSeen: time.Now().UTC(),
// Connected: false,
// Approved: false,
// },
// UserID: userID,
// LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30),
// },
// "peer-3": {
// ID: peerID3,
// Key: "peer-3-key",
// IP: net.IP{100, 64, 0, 1},
// Name: peerID3,
// DNSLabel: peerID3,
// Status: &PeerStatus{
// LastSeen: time.Now().UTC(),
// Connected: false,
// Approved: true,
// },
// UserID: userID,
// LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30),
// },
// },
// },
// {
// name: "Should return all peers when peer approval is disabled",
// accountSettings: Settings{PeerApprovalEnabled: false},
// peerID: peerID1,
// expectedPeers: []string{peerID2, peerID3},
// expectedOfflinePeers: []string{},
// peers: map[string]*Peer{
// "peer-1": {
// ID: peerID1,
// Key: "peer-1-key",
// IP: net.IP{100, 64, 0, 1},
// Name: peerID1,
// DNSLabel: peerID1,
// Status: &PeerStatus{
// LastSeen: time.Now().UTC(),
// Connected: false,
// Approved: true,
// },
// UserID: userID,
// LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30),
// },
// "peer-2": {
// ID: peerID2,
// Key: "peer-2-key",
// IP: net.IP{100, 64, 0, 1},
// Name: peerID2,
// DNSLabel: peerID2,
// Status: &PeerStatus{
// LastSeen: time.Now().UTC(),
// Connected: false,
// Approved: false,
// },
// UserID: userID,
// LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30),
// },
// "peer-3": {
// ID: peerID3,
// Key: "peer-3-key",
// IP: net.IP{100, 64, 0, 1},
// Name: peerID3,
// DNSLabel: peerID3,
// Status: &PeerStatus{
// LastSeen: time.Now().UTC(),
// Connected: false,
// Approved: true,
// },
// UserID: userID,
// LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30),
// },
// },
// },
// {
// name: "Should return no peers when peer approval is enabled and the requesting peer is not approved",
// accountSettings: Settings{PeerApprovalEnabled: true},
// peerID: peerID1,
// expectedPeers: []string{},
// expectedOfflinePeers: []string{},
// peers: map[string]*Peer{
// "peer-1": {
// ID: peerID1,
// Key: "peer-1-key",
// IP: net.IP{100, 64, 0, 1},
// Name: peerID1,
// DNSLabel: peerID1,
// Status: &PeerStatus{
// LastSeen: time.Now().UTC(),
// Connected: false,
// Approved: false,
// },
// UserID: userID,
// LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30),
// },
// "peer-2": {
// ID: peerID2,
// Key: "peer-2-key",
// IP: net.IP{100, 64, 0, 1},
// Name: peerID2,
// DNSLabel: peerID2,
// Status: &PeerStatus{
// LastSeen: time.Now().UTC(),
// Connected: false,
// Approved: true,
// },
// UserID: userID,
// LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30),
// },
// "peer-3": {
// ID: peerID3,
// Key: "peer-3-key",
// IP: net.IP{100, 64, 0, 1},
// Name: peerID3,
// DNSLabel: peerID3,
// Status: &PeerStatus{
// LastSeen: time.Now().UTC(),
// Connected: false,
// Approved: true,
// },
// UserID: userID,
// LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30),
// },
// },
// },
} }
netIP := net.IP{100, 64, 0, 0} netIP := net.IP{100, 64, 0, 0}
@@ -209,6 +364,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
for _, testCase := range tt { for _, testCase := range tt {
account := newAccountWithId("account-1", userID, "netbird.io") account := newAccountWithId("account-1", userID, "netbird.io")
account.UpdateSettings(&testCase.accountSettings)
account.Network = network account.Network = network
account.Peers = testCase.peers account.Peers = testCase.peers
for _, peer := range account.Peers { for _, peer := range account.Peers {
@@ -805,9 +961,9 @@ func TestAccountManager_AddPeer(t *testing.T) {
expectedPeerKey := key.PublicKey().String() expectedPeerKey := key.PublicKey().String()
expectedSetupKey := setupKey.Key expectedSetupKey := setupKey.Key
peer, _, err := manager.AddPeer(setupKey.Key, "", &Peer{ peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey, Key: expectedPeerKey,
Meta: PeerSystemMeta{Hostname: expectedPeerKey}, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}) })
if err != nil { if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err) t.Errorf("expecting peer to be added, got failure %v", err)
@@ -873,9 +1029,9 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
expectedPeerKey := key.PublicKey().String() expectedPeerKey := key.PublicKey().String()
expectedUserID := userID expectedUserID := userID
peer, _, err := manager.AddPeer("", userID, &Peer{ peer, _, err := manager.AddPeer("", userID, &nbpeer.Peer{
Key: expectedPeerKey, Key: expectedPeerKey,
Meta: PeerSystemMeta{Hostname: expectedPeerKey}, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}) })
if err != nil { if err != nil {
t.Errorf("expecting peer to be added, got failure %v, account users: %v", err, account.CreatedBy) t.Errorf("expecting peer to be added, got failure %v, account users: %v", err, account.CreatedBy)
@@ -940,7 +1096,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
return return
} }
getPeer := func() *Peer { getPeer := func() *nbpeer.Peer {
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -948,9 +1104,9 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
} }
expectedPeerKey := key.PublicKey().String() expectedPeerKey := key.PublicKey().String()
peer, _, err := manager.AddPeer(setupKey.Key, "", &Peer{ peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey, Key: expectedPeerKey,
Meta: PeerSystemMeta{Hostname: expectedPeerKey}, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}) })
if err != nil { if err != nil {
t.Fatalf("expecting peer1 to be added, got failure %v", err) t.Fatalf("expecting peer1 to be added, got failure %v", err)
@@ -1122,9 +1278,9 @@ func TestAccountManager_DeletePeer(t *testing.T) {
peerKey := key.PublicKey().String() peerKey := key.PublicKey().String()
peer, _, err := manager.AddPeer(setupKey.Key, "", &Peer{ peer, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: peerKey, Key: peerKey,
Meta: PeerSystemMeta{Hostname: peerKey}, Meta: nbpeer.PeerSystemMeta{Hostname: peerKey},
}) })
if err != nil { if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err) t.Errorf("expecting peer to be added, got failure %v", err)
@@ -1263,8 +1419,8 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
account := &Account{ account := &Account{
Peers: map[string]*Peer{ Peers: map[string]*nbpeer.Peer{
"peer-1": {Key: "peer-1", Meta: PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: PeerSystemMeta{GoOS: "linux"}}, "peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
}, },
Groups: map[string]*Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, Groups: map[string]*Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
Routes: map[string]*route.Route{ Routes: map[string]*route.Route{
@@ -1307,7 +1463,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
}, },
} }
routes := account.getRoutesToSync("peer-2", []*Peer{{Key: "peer-1"}, {Key: "peer-3"}}) routes := account.getRoutesToSync("peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
assert.Len(t, routes, 2) assert.Len(t, routes, 2)
routeIDs := make(map[string]struct{}, 2) routeIDs := make(map[string]struct{}, 2)
@@ -1317,7 +1473,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
assert.Contains(t, routeIDs, "route-2") assert.Contains(t, routeIDs, "route-2")
assert.Contains(t, routeIDs, "route-3") assert.Contains(t, routeIDs, "route-3")
emptyRoutes := account.getRoutesToSync("peer-3", []*Peer{{Key: "peer-1"}, {Key: "peer-2"}}) emptyRoutes := account.getRoutesToSync("peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
assert.Len(t, emptyRoutes, 0) assert.Len(t, emptyRoutes, 0)
} }
@@ -1338,10 +1494,10 @@ func TestAccount_Copy(t *testing.T) {
Network: &Network{ Network: &Network{
Identifier: "net1", Identifier: "net1",
}, },
Peers: map[string]*Peer{ Peers: map[string]*nbpeer.Peer{
"peer1": { "peer1": {
Key: "key1", Key: "key1",
Status: &PeerStatus{ Status: &nbpeer.PeerStatus{
LastSeen: time.Now(), LastSeen: time.Now(),
Connected: true, Connected: true,
LoginExpired: false, LoginExpired: false,
@@ -1468,9 +1624,9 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key") require.NoError(t, err, "unable to generate WireGuard key")
peer, _, err := manager.AddPeer("", userID, &Peer{ peer, _, err := manager.AddPeer("", userID, &nbpeer.Peer{
Key: key.PublicKey().String(), Key: key.PublicKey().String(),
Meta: PeerSystemMeta{Hostname: "test-peer"}, Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
@@ -1517,9 +1673,9 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key") require.NoError(t, err, "unable to generate WireGuard key")
_, _, err = manager.AddPeer("", userID, &Peer{ _, _, err = manager.AddPeer("", userID, &nbpeer.Peer{
Key: key.PublicKey().String(), Key: key.PublicKey().String(),
Meta: PeerSystemMeta{Hostname: "test-peer"}, Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
@@ -1558,9 +1714,9 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key") require.NoError(t, err, "unable to generate WireGuard key")
_, _, err = manager.AddPeer("", userID, &Peer{ _, _, err = manager.AddPeer("", userID, &nbpeer.Peer{
Key: key.PublicKey().String(), Key: key.PublicKey().String(),
Meta: PeerSystemMeta{Hostname: "test-peer"}, Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
}) })
require.NoError(t, err, "unable to add peer") require.NoError(t, err, "unable to add peer")
@@ -1639,13 +1795,13 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
func TestAccount_GetExpiredPeers(t *testing.T) { func TestAccount_GetExpiredPeers(t *testing.T) {
type test struct { type test struct {
name string name string
peers map[string]*Peer peers map[string]*nbpeer.Peer
expectedPeers map[string]struct{} expectedPeers map[string]struct{}
} }
testCases := []test{ testCases := []test{
{ {
name: "Peers with login expiration disabled, no expired peers", name: "Peers with login expiration disabled, no expired peers",
peers: map[string]*Peer{ peers: map[string]*nbpeer.Peer{
"peer-1": { "peer-1": {
LoginExpirationEnabled: false, LoginExpirationEnabled: false,
}, },
@@ -1657,11 +1813,11 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
}, },
{ {
name: "Two peers expired", name: "Two peers expired",
peers: map[string]*Peer{ peers: map[string]*nbpeer.Peer{
"peer-1": { "peer-1": {
ID: "peer-1", ID: "peer-1",
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
Status: &PeerStatus{ Status: &nbpeer.PeerStatus{
LastSeen: time.Now().UTC(), LastSeen: time.Now().UTC(),
Connected: true, Connected: true,
LoginExpired: false, LoginExpired: false,
@@ -1672,7 +1828,7 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
"peer-2": { "peer-2": {
ID: "peer-2", ID: "peer-2",
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
Status: &PeerStatus{ Status: &nbpeer.PeerStatus{
LastSeen: time.Now().UTC(), LastSeen: time.Now().UTC(),
Connected: true, Connected: true,
LoginExpired: false, LoginExpired: false,
@@ -1684,7 +1840,7 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
"peer-3": { "peer-3": {
ID: "peer-3", ID: "peer-3",
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
Status: &PeerStatus{ Status: &nbpeer.PeerStatus{
LastSeen: time.Now().UTC(), LastSeen: time.Now().UTC(),
Connected: true, Connected: true,
LoginExpired: false, LoginExpired: false,
@@ -1724,19 +1880,19 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
func TestAccount_GetPeersWithExpiration(t *testing.T) { func TestAccount_GetPeersWithExpiration(t *testing.T) {
type test struct { type test struct {
name string name string
peers map[string]*Peer peers map[string]*nbpeer.Peer
expectedPeers map[string]struct{} expectedPeers map[string]struct{}
} }
testCases := []test{ testCases := []test{
{ {
name: "No account peers, no peers with expiration", name: "No account peers, no peers with expiration",
peers: map[string]*Peer{}, peers: map[string]*nbpeer.Peer{},
expectedPeers: map[string]struct{}{}, expectedPeers: map[string]struct{}{},
}, },
{ {
name: "Peers with login expiration disabled, no peers with expiration", name: "Peers with login expiration disabled, no peers with expiration",
peers: map[string]*Peer{ peers: map[string]*nbpeer.Peer{
"peer-1": { "peer-1": {
LoginExpirationEnabled: false, LoginExpirationEnabled: false,
UserID: userID, UserID: userID,
@@ -1750,7 +1906,7 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) {
}, },
{ {
name: "Peers with login expiration enabled, return peers with expiration", name: "Peers with login expiration enabled, return peers with expiration",
peers: map[string]*Peer{ peers: map[string]*nbpeer.Peer{
"peer-1": { "peer-1": {
ID: "peer-1", ID: "peer-1",
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
@@ -1793,7 +1949,7 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) {
func TestAccount_GetNextPeerExpiration(t *testing.T) { func TestAccount_GetNextPeerExpiration(t *testing.T) {
type test struct { type test struct {
name string name string
peers map[string]*Peer peers map[string]*nbpeer.Peer
expiration time.Duration expiration time.Duration
expirationEnabled bool expirationEnabled bool
expectedNextRun bool expectedNextRun bool
@@ -1804,7 +1960,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
testCases := []test{ testCases := []test{
{ {
name: "No peers, no expiration", name: "No peers, no expiration",
peers: map[string]*Peer{}, peers: map[string]*nbpeer.Peer{},
expiration: time.Second, expiration: time.Second,
expirationEnabled: false, expirationEnabled: false,
expectedNextRun: false, expectedNextRun: false,
@@ -1812,16 +1968,16 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
}, },
{ {
name: "No connected peers, no expiration", name: "No connected peers, no expiration",
peers: map[string]*Peer{ peers: map[string]*nbpeer.Peer{
"peer-1": { "peer-1": {
Status: &PeerStatus{ Status: &nbpeer.PeerStatus{
Connected: false, Connected: false,
}, },
LoginExpirationEnabled: true, LoginExpirationEnabled: true,
UserID: userID, UserID: userID,
}, },
"peer-2": { "peer-2": {
Status: &PeerStatus{ Status: &nbpeer.PeerStatus{
Connected: true, Connected: true,
}, },
LoginExpirationEnabled: false, LoginExpirationEnabled: false,
@@ -1835,16 +1991,16 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
}, },
{ {
name: "Connected peers with disabled expiration, no expiration", name: "Connected peers with disabled expiration, no expiration",
peers: map[string]*Peer{ peers: map[string]*nbpeer.Peer{
"peer-1": { "peer-1": {
Status: &PeerStatus{ Status: &nbpeer.PeerStatus{
Connected: true, Connected: true,
}, },
LoginExpirationEnabled: false, LoginExpirationEnabled: false,
UserID: userID, UserID: userID,
}, },
"peer-2": { "peer-2": {
Status: &PeerStatus{ Status: &nbpeer.PeerStatus{
Connected: true, Connected: true,
}, },
LoginExpirationEnabled: false, LoginExpirationEnabled: false,
@@ -1858,9 +2014,9 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
}, },
{ {
name: "Expired peers, no expiration", name: "Expired peers, no expiration",
peers: map[string]*Peer{ peers: map[string]*nbpeer.Peer{
"peer-1": { "peer-1": {
Status: &PeerStatus{ Status: &nbpeer.PeerStatus{
Connected: true, Connected: true,
LoginExpired: true, LoginExpired: true,
}, },
@@ -1868,7 +2024,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
UserID: userID, UserID: userID,
}, },
"peer-2": { "peer-2": {
Status: &PeerStatus{ Status: &nbpeer.PeerStatus{
Connected: true, Connected: true,
LoginExpired: true, LoginExpired: true,
}, },
@@ -1883,9 +2039,9 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
}, },
{ {
name: "To be expired peer, return expiration", name: "To be expired peer, return expiration",
peers: map[string]*Peer{ peers: map[string]*nbpeer.Peer{
"peer-1": { "peer-1": {
Status: &PeerStatus{ Status: &nbpeer.PeerStatus{
Connected: true, Connected: true,
LoginExpired: false, LoginExpired: false,
}, },
@@ -1894,7 +2050,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
UserID: userID, UserID: userID,
}, },
"peer-2": { "peer-2": {
Status: &PeerStatus{ Status: &nbpeer.PeerStatus{
Connected: true, Connected: true,
LoginExpired: true, LoginExpired: true,
}, },
@@ -1909,9 +2065,9 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
}, },
{ {
name: "Peers added with setup keys, no expiration", name: "Peers added with setup keys, no expiration",
peers: map[string]*Peer{ peers: map[string]*nbpeer.Peer{
"peer-1": { "peer-1": {
Status: &PeerStatus{ Status: &nbpeer.PeerStatus{
Connected: true, Connected: true,
LoginExpired: false, LoginExpired: false,
}, },
@@ -1919,7 +2075,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
SetupKey: "key", SetupKey: "key",
}, },
"peer-2": { "peer-2": {
Status: &PeerStatus{ Status: &nbpeer.PeerStatus{
Connected: true, Connected: true,
LoginExpired: false, LoginExpired: false,
}, },
@@ -1954,7 +2110,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) {
func TestAccount_SetJWTGroups(t *testing.T) { func TestAccount_SetJWTGroups(t *testing.T) {
// create a new account // create a new account
account := &Account{ account := &Account{
Peers: map[string]*Peer{ Peers: map[string]*nbpeer.Peer{
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
"peer3": {ID: "peer3", Key: "key3", UserID: "user1"}, "peer3": {ID: "peer3", Key: "key3", UserID: "user1"},
@@ -2002,7 +2158,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
func TestAccount_UserGroupsAddToPeers(t *testing.T) { func TestAccount_UserGroupsAddToPeers(t *testing.T) {
account := &Account{ account := &Account{
Peers: map[string]*Peer{ Peers: map[string]*nbpeer.Peer{
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
"peer3": {ID: "peer3", Key: "key3", UserID: "user1"}, "peer3": {ID: "peer3", Key: "key3", UserID: "user1"},
@@ -2038,7 +2194,7 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) {
func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
account := &Account{ account := &Account{
Peers: map[string]*Peer{ Peers: map[string]*nbpeer.Peer{
"peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer1": {ID: "peer1", Key: "key1", UserID: "user1"},
"peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"},
"peer3": {ID: "peer3", Key: "key3", UserID: "user1"}, "peer3": {ID: "peer3", Key: "key3", UserID: "user1"},

View File

@@ -120,6 +120,14 @@ const (
IntegrationUpdated IntegrationUpdated
// IntegrationDeleted indicates that the user deleted an integration // IntegrationDeleted indicates that the user deleted an integration
IntegrationDeleted IntegrationDeleted
// AccountPeerApprovalEnabled indicates that the user enabled peer approval for the account
AccountPeerApprovalEnabled
// AccountPeerApprovalDisabled indicates that the user disabled peer approval for the account
AccountPeerApprovalDisabled
// PeerApproved indicates that the peer has been approved
PeerApproved
// PeerApprovalRevoked indicates that the peer approval has been revoked
PeerApprovalRevoked
// TransferredOwnerRole indicates that the user transferred the owner role of the account // TransferredOwnerRole indicates that the user transferred the owner role of the account
TransferredOwnerRole TransferredOwnerRole
) )
@@ -180,6 +188,10 @@ var activityMap = map[Activity]Code{
IntegrationCreated: {"Integration created", "integration.create"}, IntegrationCreated: {"Integration created", "integration.create"},
IntegrationUpdated: {"Integration updated", "integration.update"}, IntegrationUpdated: {"Integration updated", "integration.update"},
IntegrationDeleted: {"Integration deleted", "integration.delete"}, IntegrationDeleted: {"Integration deleted", "integration.delete"},
AccountPeerApprovalEnabled: {"Account peer approval enabled", "account.setting.peer.approval.enable"},
AccountPeerApprovalDisabled: {"Account peer approval disabled", "account.setting.peer.approval.disable"},
PeerApproved: {"Peer approved", "peer.approve"},
PeerApprovalRevoked: {"Peer approval revoked", "peer.approval.revoke"},
TransferredOwnerRole: {"Transferred owner role", "transferred.owner.role"}, TransferredOwnerRole: {"Transferred owner role", "transferred.owner.role"},
} }

View File

@@ -10,6 +10,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
@@ -200,7 +201,7 @@ func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
} }
// peerIsNameserver returns true if the peer is a nameserver for a nsGroup // peerIsNameserver returns true if the peer is a nameserver for a nsGroup
func peerIsNameserver(peer *Peer, nsGroup *nbdns.NameServerGroup) bool { func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
for _, ns := range nsGroup.NameServers { for _, ns := range nsGroup.NameServers {
if peer.IP.Equal(ns.IP.AsSlice()) { if peer.IP.Equal(ns.IP.AsSlice()) {
return true return true

View File

@@ -8,6 +8,7 @@ import (
"github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
@@ -208,10 +209,10 @@ func createDNSStore(t *testing.T) (Store, error) {
func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
t.Helper() t.Helper()
peer1 := &Peer{ peer1 := &nbpeer.Peer{
Key: dnsPeer1Key, Key: dnsPeer1Key,
Name: "test-host1@netbird.io", Name: "test-host1@netbird.io",
Meta: PeerSystemMeta{ Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host1@netbird.io", Hostname: "test-host1@netbird.io",
GoOS: "linux", GoOS: "linux",
Kernel: "Linux", Kernel: "Linux",
@@ -223,10 +224,10 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
}, },
DNSLabel: dnsPeer1Key, DNSLabel: dnsPeer1Key,
} }
peer2 := &Peer{ peer2 := &nbpeer.Peer{
Key: dnsPeer2Key, Key: dnsPeer2Key,
Name: "test-host2@netbird.io", Name: "test-host2@netbird.io",
Meta: PeerSystemMeta{ Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host2@netbird.io", Hostname: "test-host2@netbird.io",
GoOS: "linux", GoOS: "linux",
Kernel: "Linux", Kernel: "Linux",

View File

@@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
) )
const ( const (
@@ -72,7 +73,7 @@ func (e *EphemeralManager) Stop() {
// OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer // OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer
// is active the manager will not delete it while it is active. // is active the manager will not delete it while it is active.
func (e *EphemeralManager) OnPeerConnected(peer *Peer) { func (e *EphemeralManager) OnPeerConnected(peer *nbpeer.Peer) {
if !peer.Ephemeral { if !peer.Ephemeral {
return return
} }
@@ -93,7 +94,7 @@ func (e *EphemeralManager) OnPeerConnected(peer *Peer) {
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer // OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
// is inactive it will be deleted after the ephemeralLifeTime period. // is inactive it will be deleted after the ephemeralLifeTime period.
func (e *EphemeralManager) OnPeerDisconnected(peer *Peer) { func (e *EphemeralManager) OnPeerDisconnected(peer *nbpeer.Peer) {
if !peer.Ephemeral { if !peer.Ephemeral {
return return
} }

View File

@@ -4,6 +4,8 @@ import (
"fmt" "fmt"
"testing" "testing"
"time" "time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
) )
type MockStore struct { type MockStore struct {
@@ -124,7 +126,7 @@ func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int)
for i := 0; i < numberOfPeers; i++ { for i := 0; i < numberOfPeers; i++ {
peerId := fmt.Sprintf("peer_%d", i) peerId := fmt.Sprintf("peer_%d", i)
p := &Peer{ p := &nbpeer.Peer{
ID: peerId, ID: peerId,
Ephemeral: false, Ephemeral: false,
} }
@@ -133,7 +135,7 @@ func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int)
for i := 0; i < numberOfEphemeralPeers; i++ { for i := 0; i < numberOfEphemeralPeers; i++ {
peerId := fmt.Sprintf("ephemeral_peer_%d", i) peerId := fmt.Sprintf("ephemeral_peer_%d", i)
p := &Peer{ p := &nbpeer.Peer{
ID: peerId, ID: peerId,
Ephemeral: true, Ephemeral: true,
} }

View File

@@ -10,6 +10,7 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
@@ -204,7 +205,7 @@ func restore(file string) (*FileStore, error) {
// Set the Peer.ID to the newly generated value. // Set the Peer.ID to the newly generated value.
// Replace all the mentions of Peer.Key as ID (groups and routes). // Replace all the mentions of Peer.Key as ID (groups and routes).
// Swap Peer.Key with Peer.ID in the Account.Peers map. // Swap Peer.Key with Peer.ID in the Account.Peers map.
migrationPeers := make(map[string]*Peer) // key to Peer migrationPeers := make(map[string]*nbpeer.Peer) // key to Peer
for key, peer := range account.Peers { for key, peer := range account.Peers {
// set LastLogin for the peers that were onboarded before the peer login expiration feature // set LastLogin for the peers that were onboarded before the peer login expiration feature
if peer.LastLogin.IsZero() { if peer.LastLogin.IsZero() {
@@ -606,7 +607,7 @@ func (s *FileStore) SaveInstallationID(ID string) error {
// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things. // SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things.
// PeerStatus will be saved eventually when some other changes occur. // PeerStatus will be saved eventually when some other changes occur.
func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus PeerStatus) error { func (s *FileStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()

View File

@@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -35,7 +36,7 @@ func TestStalePeerIndices(t *testing.T) {
peerID := "some_peer" peerID := "some_peer"
peerKey := "some_peer_key" peerKey := "some_peer_key"
account.Peers[peerID] = &Peer{ account.Peers[peerID] = &nbpeer.Peer{
ID: peerID, ID: peerID,
Key: peerKey, Key: peerKey,
} }
@@ -89,13 +90,13 @@ func TestSaveAccount(t *testing.T) {
account := newAccountWithId("account_id", "testuser", "") account := newAccountWithId("account_id", "testuser", "")
setupKey := GenerateDefaultSetupKey() setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &Peer{ account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey", Key: "peerkey",
SetupKey: "peerkeysetupkey", SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
Status: &PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
} }
// SaveAccount should trigger persist // SaveAccount should trigger persist
@@ -179,13 +180,13 @@ func TestStore(t *testing.T) {
store := newStore(t) store := newStore(t)
account := newAccountWithId("account_id", "testuser", "") account := newAccountWithId("account_id", "testuser", "")
account.Peers["testpeer"] = &Peer{ account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey", Key: "peerkey",
SetupKey: "peerkeysetupkey", SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
Status: &PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
} }
account.Groups["all"] = &Group{ account.Groups["all"] = &Group{
ID: "all", ID: "all",
@@ -600,19 +601,19 @@ func TestFileStore_SavePeerStatus(t *testing.T) {
} }
// save status of non-existing peer // save status of non-existing peer
newStatus := PeerStatus{Connected: true, LastSeen: time.Now().UTC()} newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
assert.Error(t, err) assert.Error(t, err)
// save new status of existing peer // save new status of existing peer
account.Peers["testpeer"] = &Peer{ account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey", Key: "peerkey",
ID: "testpeer", ID: "testpeer",
SetupKey: "peerkeysetupkey", SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
Status: &PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
} }
err = store.SaveAccount(account) err = store.SaveAccount(account)

View File

@@ -17,6 +17,7 @@ import (
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
internalStatus "github.com/netbirdio/netbird/management/server/status" internalStatus "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
) )
@@ -196,7 +197,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
} }
} }
func (s *GRPCServer) cancelPeerRoutines(peer *Peer) { func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(peer.ID) s.peersUpdateManager.CloseChannel(peer.ID)
s.turnCredentialsManager.CancelRefresh(peer.ID) s.turnCredentialsManager.CancelRefresh(peer.ID)
_ = s.accountManager.MarkPeerConnected(peer.Key, false) _ = s.accountManager.MarkPeerConnected(peer.Key, false)
@@ -243,8 +244,8 @@ func mapError(err error) error {
return status.Errorf(codes.Internal, "failed handling request") return status.Errorf(codes.Internal, "failed handling request")
} }
func extractPeerMeta(loginReq *proto.LoginRequest) PeerSystemMeta { func extractPeerMeta(loginReq *proto.LoginRequest) nbpeer.PeerSystemMeta {
return PeerSystemMeta{ return nbpeer.PeerSystemMeta{
Hostname: loginReq.GetMeta().GetHostname(), Hostname: loginReq.GetMeta().GetHostname(),
GoOS: loginReq.GetMeta().GetGoOS(), GoOS: loginReq.GetMeta().GetGoOS(),
Kernel: loginReq.GetMeta().GetKernel(), Kernel: loginReq.GetMeta().GetKernel(),
@@ -413,7 +414,7 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
} }
} }
func toPeerConfig(peer *Peer, network *Network, dnsName string) *proto.PeerConfig { func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.PeerConfig {
netmask, _ := network.Net.Mask.Size() netmask, _ := network.Net.Mask.Size()
fqdn := peer.FQDN(dnsName) fqdn := peer.FQDN(dnsName)
return &proto.PeerConfig{ return &proto.PeerConfig{
@@ -423,7 +424,7 @@ func toPeerConfig(peer *Peer, network *Network, dnsName string) *proto.PeerConfi
} }
} }
func toRemotePeerConfig(peers []*Peer, dnsName string) []*proto.RemotePeerConfig { func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
remotePeers := []*proto.RemotePeerConfig{} remotePeers := []*proto.RemotePeerConfig{}
for _, rPeer := range peers { for _, rPeer := range peers {
fqdn := rPeer.FQDN(dnsName) fqdn := rPeer.FQDN(dnsName)
@@ -437,7 +438,7 @@ func toRemotePeerConfig(peers []*Peer, dnsName string) []*proto.RemotePeerConfig
return remotePeers return remotePeers
} }
func toSyncResponse(config *Config, peer *Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse { func toSyncResponse(config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials) wtConfig := toWiretrusteeConfig(config, turnCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName) pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
@@ -477,7 +478,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
} }
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error { func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error {
// make secret time based TURN credentials optional // make secret time based TURN credentials optional
var turnCredentials *TURNCredentials var turnCredentials *TURNCredentials
if s.config.TURNConfig.TimeBasedCredentials { if s.config.TURNConfig.TimeBasedCredentials {

View File

@@ -8,6 +8,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -77,6 +78,10 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
} }
if req.Settings.Extra != nil {
settings.Extra = &account.ExtraSettings{PeerApprovalEnabled: *req.Settings.Extra.PeerApprovalEnabled}
}
if req.Settings.JwtGroupsEnabled != nil { if req.Settings.JwtGroupsEnabled != nil {
settings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled settings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled
} }
@@ -86,6 +91,9 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
if req.Settings.JwtGroupsClaimName != nil { if req.Settings.JwtGroupsClaimName != nil {
settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
} }
if req.Settings.JwtAllowGroups != nil {
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
}
updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings) updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings)
if err != nil { if err != nil {
@@ -123,14 +131,26 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
} }
func toAccountResponse(account *server.Account) *api.Account { func toAccountResponse(account *server.Account) *api.Account {
jwtAllowGroups := account.Settings.JWTAllowGroups
if jwtAllowGroups == nil {
jwtAllowGroups = []string{}
}
settings := api.AccountSettings{
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled,
JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled,
JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName,
JwtAllowGroups: &jwtAllowGroups,
}
if account.Settings.Extra != nil {
settings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &account.Settings.Extra.PeerApprovalEnabled}
}
return &api.Account{ return &api.Account{
Id: account.Id, Id: account.Id,
Settings: api.AccountSettings{ Settings: settings,
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled,
JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled,
JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName,
},
} }
} }

View File

@@ -95,6 +95,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
GroupsPropagationEnabled: br(false), GroupsPropagationEnabled: br(false),
JwtGroupsClaimName: sr(""), JwtGroupsClaimName: sr(""),
JwtGroupsEnabled: br(false), JwtGroupsEnabled: br(false),
JwtAllowGroups: &[]string{},
}, },
expectedArray: true, expectedArray: true,
expectedID: accountID, expectedID: accountID,
@@ -112,6 +113,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
GroupsPropagationEnabled: br(false), GroupsPropagationEnabled: br(false),
JwtGroupsClaimName: sr(""), JwtGroupsClaimName: sr(""),
JwtGroupsEnabled: br(false), JwtGroupsEnabled: br(false),
JwtAllowGroups: &[]string{},
}, },
expectedArray: false, expectedArray: false,
expectedID: accountID, expectedID: accountID,
@@ -121,7 +123,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
expectedBody: true, expectedBody: true,
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/accounts/" + accountID, requestPath: "/api/accounts/" + accountID,
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\"}}"), requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"]}}"),
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedSettings: api.AccountSettings{ expectedSettings: api.AccountSettings{
PeerLoginExpiration: 15552000, PeerLoginExpiration: 15552000,
@@ -129,6 +131,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
GroupsPropagationEnabled: br(false), GroupsPropagationEnabled: br(false),
JwtGroupsClaimName: sr("roles"), JwtGroupsClaimName: sr("roles"),
JwtGroupsEnabled: br(true), JwtGroupsEnabled: br(true),
JwtAllowGroups: &[]string{"test"},
}, },
expectedArray: false, expectedArray: false,
expectedID: accountID, expectedID: accountID,
@@ -146,6 +149,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
GroupsPropagationEnabled: br(true), GroupsPropagationEnabled: br(true),
JwtGroupsClaimName: sr("groups"), JwtGroupsClaimName: sr("groups"),
JwtGroupsEnabled: br(true), JwtGroupsEnabled: br(true),
JwtAllowGroups: &[]string{},
}, },
expectedArray: false, expectedArray: false,
expectedID: accountID, expectedID: accountID,

View File

@@ -66,9 +66,24 @@ components:
description: Name of the claim from which we extract groups names to add it to account groups. description: Name of the claim from which we extract groups names to add it to account groups.
type: string type: string
example: "roles" example: "roles"
jwt_allow_groups:
description: List of groups to which users are allowed access
type: array
items:
type: string
example: Administrators
extra:
$ref: '#/components/schemas/AccountExtraSettings'
required: required:
- peer_login_expiration_enabled - peer_login_expiration_enabled
- peer_login_expiration - peer_login_expiration
AccountExtraSettings:
type: object
properties:
peer_approval_enabled:
description: (Cloud only) Enables or disables peer approval globally. If enabled, all peers added will be in pending state until approved by an admin.
type: boolean
example: true
AccountRequest: AccountRequest:
type: object type: object
properties: properties:
@@ -106,11 +121,11 @@ components:
format: date-time format: date-time
example: 2023-05-05T09:00:35.477782Z example: 2023-05-05T09:00:35.477782Z
auto_groups: auto_groups:
description: Groups to auto-assign to peers registered by this user description: Group IDs to auto-assign to peers registered by this user
type: array type: array
items: items:
type: string type: string
example: devs example: ch8i4ug6lnn4g9hqv7m0
is_current: is_current:
description: Is true if authenticated user is the same as this user description: Is true if authenticated user is the same as this user
type: boolean type: boolean
@@ -145,11 +160,11 @@ components:
type: string type: string
example: admin example: admin
auto_groups: auto_groups:
description: Groups to auto-assign to peers registered by this user description: Group IDs to auto-assign to peers registered by this user
type: array type: array
items: items:
type: string type: string
example: devs example: ch8i4ug6lnn4g9hqv7m0
is_blocked: is_blocked:
description: If set to true then user is blocked and can't use the system description: If set to true then user is blocked and can't use the system
type: boolean type: boolean
@@ -174,11 +189,11 @@ components:
type: string type: string
example: admin example: admin
auto_groups: auto_groups:
description: Groups to auto-assign to peers registered by this user description: Group IDs to auto-assign to peers registered by this user
type: array type: array
items: items:
type: string type: string
example: devs example: ch8i4ug6lnn4g9hqv7m0
is_service_user: is_service_user:
description: Is true if this user is a service user description: Is true if this user is a service user
type: boolean type: boolean
@@ -213,6 +228,10 @@ components:
login_expiration_enabled: login_expiration_enabled:
type: boolean type: boolean
example: false example: false
approval_required:
description: (Cloud only) Indicates whether peer needs approval
type: boolean
example: true
required: required:
- name - name
- ssh_enabled - ssh_enabled
@@ -281,6 +300,10 @@ components:
type: string type: string
format: date-time format: date-time
example: 2023-05-05T09:00:35.477782Z example: 2023-05-05T09:00:35.477782Z
approval_required:
description: (Cloud only) Indicates whether peer needs approval
type: boolean
example: true
required: required:
- ip - ip
- connected - connected
@@ -388,7 +411,7 @@ components:
type: array type: array
items: items:
type: string type: string
example: "devs" example: "ch8i4ug6lnn4g9hqv7m0"
updated_at: updated_at:
description: Setup key last update date description: Setup key last update date
type: string type: string
@@ -443,7 +466,7 @@ components:
type: array type: array
items: items:
type: string type: string
example: "devs" example: "ch8i4ug6lnn4g9hqv7m0"
usage_limit: usage_limit:
description: A number of times this key can be used. The value of 0 indicates the unlimited usage. description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
type: integer type: integer
@@ -604,13 +627,13 @@ components:
properties: properties:
sources: sources:
type: array type: array
description: List of source groups description: List of source group IDs
items: items:
type: string type: string
example: "ch8i4ug6lnn4g9hqv7m1" example: "ch8i4ug6lnn4g9hqv7m1"
destinations: destinations:
type: array type: array
description: List of destination groups description: List of destination group IDs
items: items:
type: string type: string
example: "ch8i4ug6lnn4g9hqv7m0" example: "ch8i4ug6lnn4g9hqv7m0"
@@ -628,12 +651,12 @@ components:
- type: object - type: object
properties: properties:
sources: sources:
description: Rule source groups description: Rule source group IDs
type: array type: array
items: items:
$ref: '#/components/schemas/GroupMinimum' $ref: '#/components/schemas/GroupMinimum'
destinations: destinations:
description: Rule destination groups description: Rule destination group IDs
type: array type: array
items: items:
$ref: '#/components/schemas/GroupMinimum' $ref: '#/components/schemas/GroupMinimum'
@@ -691,13 +714,13 @@ components:
- type: object - type: object
properties: properties:
sources: sources:
description: Policy rule source groups description: Policy rule source group IDs
type: array type: array
items: items:
type: string type: string
example: "ch8i4ug6lnn4g9hqv797" example: "ch8i4ug6lnn4g9hqv797"
destinations: destinations:
description: Policy rule destination groups description: Policy rule destination group IDs
type: array type: array
items: items:
type: string type: string
@@ -711,12 +734,12 @@ components:
- type: object - type: object
properties: properties:
sources: sources:
description: Policy rule source groups description: Policy rule source group IDs
type: array type: array
items: items:
$ref: '#/components/schemas/GroupMinimum' $ref: '#/components/schemas/GroupMinimum'
destinations: destinations:
description: Policy rule destination groups description: Policy rule destination group IDs
type: array type: array
items: items:
$ref: '#/components/schemas/GroupMinimum' $ref: '#/components/schemas/GroupMinimum'
@@ -817,7 +840,7 @@ components:
type: boolean type: boolean
example: true example: true
groups: groups:
description: Route group tag groups description: Group IDs containing routing peers
type: array type: array
items: items:
type: string type: string
@@ -874,17 +897,17 @@ components:
type: object type: object
properties: properties:
name: name:
description: Nameserver group name description: Name of nameserver group name
type: string type: string
maxLength: 40 maxLength: 40
minLength: 1 minLength: 1
example: Google DNS example: Google DNS
description: description:
description: Nameserver group description description: Description of the nameserver group
type: string type: string
example: Google DNS servers example: Google DNS servers
nameservers: nameservers:
description: Nameserver group description: Nameserver list
minLength: 1 minLength: 1
maxLength: 2 maxLength: 2
type: array type: array
@@ -895,17 +918,17 @@ components:
type: boolean type: boolean
example: true example: true
groups: groups:
description: Nameserver group tag groups description: Distribution group IDs that defines group of peers that will use this nameserver group
type: array type: array
items: items:
type: string type: string
example: ch8i4ug6lnn4g9hqv7m0 example: ch8i4ug6lnn4g9hqv7m0
primary: primary:
description: Nameserver group primary status description: Defines if a nameserver group is primary that resolves all domains. It should be true only if domains list is empty.
type: boolean type: boolean
example: true example: true
domains: domains:
description: Nameserver group match domain list description: Match domain list. It should be empty only if primary is true.
type: array type: array
items: items:
type: string type: string
@@ -913,7 +936,7 @@ components:
maxLength: 255 maxLength: 255
example: "example.com" example: "example.com"
search_domains_enabled: search_domains_enabled:
description: Nameserver group search domain status for match domains. It should be true only if domains list is not empty. description: Search domain status for match domains. It should be true only if domains list is not empty.
type: boolean type: boolean
example: true example: true
required: required:

View File

@@ -142,6 +142,12 @@ type Account struct {
Settings AccountSettings `json:"settings"` Settings AccountSettings `json:"settings"`
} }
// AccountExtraSettings defines model for AccountExtraSettings.
type AccountExtraSettings struct {
// PeerApprovalEnabled (Cloud only) Enables or disables peer approval globally. If enabled, all peers added will be in pending state until approved by an admin.
PeerApprovalEnabled *bool `json:"peer_approval_enabled,omitempty"`
}
// AccountRequest defines model for AccountRequest. // AccountRequest defines model for AccountRequest.
type AccountRequest struct { type AccountRequest struct {
Settings AccountSettings `json:"settings"` Settings AccountSettings `json:"settings"`
@@ -149,9 +155,14 @@ type AccountRequest struct {
// AccountSettings defines model for AccountSettings. // AccountSettings defines model for AccountSettings.
type AccountSettings struct { type AccountSettings struct {
Extra *AccountExtraSettings `json:"extra,omitempty"`
// GroupsPropagationEnabled Allows propagate the new user auto groups to peers that belongs to the user // GroupsPropagationEnabled Allows propagate the new user auto groups to peers that belongs to the user
GroupsPropagationEnabled *bool `json:"groups_propagation_enabled,omitempty"` GroupsPropagationEnabled *bool `json:"groups_propagation_enabled,omitempty"`
// JwtAllowGroups List of groups to which users are allowed access
JwtAllowGroups *[]string `json:"jwt_allow_groups,omitempty"`
// JwtGroupsClaimName Name of the claim from which we extract groups names to add it to account groups. // JwtGroupsClaimName Name of the claim from which we extract groups names to add it to account groups.
JwtGroupsClaimName *string `json:"jwt_groups_claim_name,omitempty"` JwtGroupsClaimName *string `json:"jwt_groups_claim_name,omitempty"`
@@ -263,58 +274,58 @@ type NameserverNsType string
// NameserverGroup defines model for NameserverGroup. // NameserverGroup defines model for NameserverGroup.
type NameserverGroup struct { type NameserverGroup struct {
// Description Nameserver group description // Description Description of the nameserver group
Description string `json:"description"` Description string `json:"description"`
// Domains Nameserver group match domain list // Domains Match domain list. It should be empty only if primary is true.
Domains []string `json:"domains"` Domains []string `json:"domains"`
// Enabled Nameserver group status // Enabled Nameserver group status
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
// Groups Nameserver group tag groups // Groups Distribution group IDs that defines group of peers that will use this nameserver group
Groups []string `json:"groups"` Groups []string `json:"groups"`
// Id Nameserver group ID // Id Nameserver group ID
Id string `json:"id"` Id string `json:"id"`
// Name Nameserver group name // Name Name of nameserver group name
Name string `json:"name"` Name string `json:"name"`
// Nameservers Nameserver group // Nameservers Nameserver list
Nameservers []Nameserver `json:"nameservers"` Nameservers []Nameserver `json:"nameservers"`
// Primary Nameserver group primary status // Primary Defines if a nameserver group is primary that resolves all domains. It should be true only if domains list is empty.
Primary bool `json:"primary"` Primary bool `json:"primary"`
// SearchDomainsEnabled Nameserver group search domain status for match domains. It should be true only if domains list is not empty. // SearchDomainsEnabled Search domain status for match domains. It should be true only if domains list is not empty.
SearchDomainsEnabled bool `json:"search_domains_enabled"` SearchDomainsEnabled bool `json:"search_domains_enabled"`
} }
// NameserverGroupRequest defines model for NameserverGroupRequest. // NameserverGroupRequest defines model for NameserverGroupRequest.
type NameserverGroupRequest struct { type NameserverGroupRequest struct {
// Description Nameserver group description // Description Description of the nameserver group
Description string `json:"description"` Description string `json:"description"`
// Domains Nameserver group match domain list // Domains Match domain list. It should be empty only if primary is true.
Domains []string `json:"domains"` Domains []string `json:"domains"`
// Enabled Nameserver group status // Enabled Nameserver group status
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
// Groups Nameserver group tag groups // Groups Distribution group IDs that defines group of peers that will use this nameserver group
Groups []string `json:"groups"` Groups []string `json:"groups"`
// Name Nameserver group name // Name Name of nameserver group name
Name string `json:"name"` Name string `json:"name"`
// Nameservers Nameserver group // Nameservers Nameserver list
Nameservers []Nameserver `json:"nameservers"` Nameservers []Nameserver `json:"nameservers"`
// Primary Nameserver group primary status // Primary Defines if a nameserver group is primary that resolves all domains. It should be true only if domains list is empty.
Primary bool `json:"primary"` Primary bool `json:"primary"`
// SearchDomainsEnabled Nameserver group search domain status for match domains. It should be true only if domains list is not empty. // SearchDomainsEnabled Search domain status for match domains. It should be true only if domains list is not empty.
SearchDomainsEnabled bool `json:"search_domains_enabled"` SearchDomainsEnabled bool `json:"search_domains_enabled"`
} }
@@ -323,6 +334,9 @@ type Peer struct {
// AccessiblePeers List of accessible peers // AccessiblePeers List of accessible peers
AccessiblePeers []AccessiblePeer `json:"accessible_peers"` AccessiblePeers []AccessiblePeer `json:"accessible_peers"`
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired *bool `json:"approval_required,omitempty"`
// Connected Peer to Management connection status // Connected Peer to Management connection status
Connected bool `json:"connected"` Connected bool `json:"connected"`
@@ -374,6 +388,9 @@ type Peer struct {
// PeerBase defines model for PeerBase. // PeerBase defines model for PeerBase.
type PeerBase struct { type PeerBase struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired *bool `json:"approval_required,omitempty"`
// Connected Peer to Management connection status // Connected Peer to Management connection status
Connected bool `json:"connected"` Connected bool `json:"connected"`
@@ -428,6 +445,9 @@ type PeerBatch struct {
// AccessiblePeersCount Number of accessible peers // AccessiblePeersCount Number of accessible peers
AccessiblePeersCount int `json:"accessible_peers_count"` AccessiblePeersCount int `json:"accessible_peers_count"`
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired *bool `json:"approval_required,omitempty"`
// Connected Peer to Management connection status // Connected Peer to Management connection status
Connected bool `json:"connected"` Connected bool `json:"connected"`
@@ -488,6 +508,8 @@ type PeerMinimum struct {
// PeerRequest defines model for PeerRequest. // PeerRequest defines model for PeerRequest.
type PeerRequest struct { type PeerRequest struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired *bool `json:"approval_required,omitempty"`
LoginExpirationEnabled bool `json:"login_expiration_enabled"` LoginExpirationEnabled bool `json:"login_expiration_enabled"`
Name string `json:"name"` Name string `json:"name"`
SshEnabled bool `json:"ssh_enabled"` SshEnabled bool `json:"ssh_enabled"`
@@ -581,7 +603,7 @@ type PolicyRule struct {
// Description Policy rule friendly description // Description Policy rule friendly description
Description *string `json:"description,omitempty"` Description *string `json:"description,omitempty"`
// Destinations Policy rule destination groups // Destinations Policy rule destination group IDs
Destinations []GroupMinimum `json:"destinations"` Destinations []GroupMinimum `json:"destinations"`
// Enabled Policy rule status // Enabled Policy rule status
@@ -599,7 +621,7 @@ type PolicyRule struct {
// Protocol Policy rule type of the traffic // Protocol Policy rule type of the traffic
Protocol PolicyRuleProtocol `json:"protocol"` Protocol PolicyRuleProtocol `json:"protocol"`
// Sources Policy rule source groups // Sources Policy rule source group IDs
Sources []GroupMinimum `json:"sources"` Sources []GroupMinimum `json:"sources"`
} }
@@ -653,7 +675,7 @@ type PolicyRuleUpdate struct {
// Description Policy rule friendly description // Description Policy rule friendly description
Description *string `json:"description,omitempty"` Description *string `json:"description,omitempty"`
// Destinations Policy rule destination groups // Destinations Policy rule destination group IDs
Destinations []string `json:"destinations"` Destinations []string `json:"destinations"`
// Enabled Policy rule status // Enabled Policy rule status
@@ -671,7 +693,7 @@ type PolicyRuleUpdate struct {
// Protocol Policy rule type of the traffic // Protocol Policy rule type of the traffic
Protocol PolicyRuleUpdateProtocol `json:"protocol"` Protocol PolicyRuleUpdateProtocol `json:"protocol"`
// Sources Policy rule source groups // Sources Policy rule source group IDs
Sources []string `json:"sources"` Sources []string `json:"sources"`
} }
@@ -710,7 +732,7 @@ type Route struct {
// Enabled Route status // Enabled Route status
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
// Groups Route group tag groups // Groups Group IDs containing routing peers
Groups []string `json:"groups"` Groups []string `json:"groups"`
// Id Route Id // Id Route Id
@@ -746,7 +768,7 @@ type RouteRequest struct {
// Enabled Route status // Enabled Route status
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
// Groups Route group tag groups // Groups Group IDs containing routing peers
Groups []string `json:"groups"` Groups []string `json:"groups"`
// Masquerade Indicate if peer should masquerade traffic to this route's prefix // Masquerade Indicate if peer should masquerade traffic to this route's prefix
@@ -773,7 +795,7 @@ type Rule struct {
// Description Rule friendly description // Description Rule friendly description
Description string `json:"description"` Description string `json:"description"`
// Destinations Rule destination groups // Destinations Rule destination group IDs
Destinations []GroupMinimum `json:"destinations"` Destinations []GroupMinimum `json:"destinations"`
// Disabled Rules status // Disabled Rules status
@@ -788,7 +810,7 @@ type Rule struct {
// Name Rule name identifier // Name Rule name identifier
Name string `json:"name"` Name string `json:"name"`
// Sources Rule source groups // Sources Rule source group IDs
Sources []GroupMinimum `json:"sources"` Sources []GroupMinimum `json:"sources"`
} }
@@ -812,7 +834,7 @@ type RuleRequest struct {
// Description Rule friendly description // Description Rule friendly description
Description string `json:"description"` Description string `json:"description"`
// Destinations List of destination groups // Destinations List of destination group IDs
Destinations *[]string `json:"destinations,omitempty"` Destinations *[]string `json:"destinations,omitempty"`
// Disabled Rules status // Disabled Rules status
@@ -824,7 +846,7 @@ type RuleRequest struct {
// Name Rule name identifier // Name Rule name identifier
Name string `json:"name"` Name string `json:"name"`
// Sources List of source groups // Sources List of source group IDs
Sources *[]string `json:"sources,omitempty"` Sources *[]string `json:"sources,omitempty"`
} }
@@ -899,7 +921,7 @@ type SetupKeyRequest struct {
// User defines model for User. // User defines model for User.
type User struct { type User struct {
// AutoGroups Groups to auto-assign to peers registered by this user // AutoGroups Group IDs to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"` AutoGroups []string `json:"auto_groups"`
// Email User's email address // Email User's email address
@@ -938,7 +960,7 @@ type UserStatus string
// UserCreateRequest defines model for UserCreateRequest. // UserCreateRequest defines model for UserCreateRequest.
type UserCreateRequest struct { type UserCreateRequest struct {
// AutoGroups Groups to auto-assign to peers registered by this user // AutoGroups Group IDs to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"` AutoGroups []string `json:"auto_groups"`
// Email User's Email to send invite to // Email User's Email to send invite to
@@ -956,7 +978,7 @@ type UserCreateRequest struct {
// UserRequest defines model for UserRequest. // UserRequest defines model for UserRequest.
type UserRequest struct { type UserRequest struct {
// AutoGroups Groups to auto-assign to peers registered by this user // AutoGroups Group IDs to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"` AutoGroups []string `json:"auto_groups"`
// IsBlocked If set to true then user is blocked and can't use the system // IsBlocked If set to true then user is blocked and can't use the system

View File

@@ -19,10 +19,11 @@ import (
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
var TestPeers = map[string]*server.Peer{ var TestPeers = map[string]*nbpeer.Peer{
"A": {Key: "A", ID: "peer-A-ID", IP: net.ParseIP("100.100.100.100")}, "A": {Key: "A", ID: "peer-A-ID", IP: net.ParseIP("100.100.100.100")},
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
} }

View File

@@ -34,12 +34,20 @@ type emptyObject struct {
// APIHandler creates the Management service HTTP API handler registering all the available endpoints. // APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
)
authMiddleware := middleware.NewAuthMiddleware( authMiddleware := middleware.NewAuthMiddleware(
accountManager.GetAccountFromPAT, accountManager.GetAccountFromPAT,
jwtValidator.ValidateAndParse, jwtValidator.ValidateAndParse,
accountManager.MarkPATUsed, accountManager.MarkPATUsed,
accountManager.GetAccountFromToken,
claimsExtractor,
authCfg.Audience, authCfg.Audience,
authCfg.UserIDClaim) authCfg.UserIDClaim,
)
corsMiddleware := cors.AllowAll() corsMiddleware := cors.AllowAll()
@@ -60,11 +68,6 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid
AuthCfg: authCfg, AuthCfg: authCfg,
} }
claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
)
integrations.RegisterHandlers(api.Router, accountManager, claimsExtractor) integrations.RegisterHandlers(api.Router, accountManager, claimsExtractor)
api.addAccountsEndpoint() api.addAccountsEndpoint()
api.addPeersEndpoint() api.addPeersEndpoint()

View File

@@ -26,11 +26,16 @@ type ValidateAndParseTokenFunc func(token string) (*jwt.Token, error)
// MarkPATUsedFunc function // MarkPATUsedFunc function
type MarkPATUsedFunc func(token string) error type MarkPATUsedFunc func(token string) error
// GetAccountFromTokenFunc function
type GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct { type AuthMiddleware struct {
getAccountFromPAT GetAccountFromPATFunc getAccountFromPAT GetAccountFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc markPATUsed MarkPATUsedFunc
getAccountFromToken GetAccountFromTokenFunc
claimsExtractor *jwtclaims.ClaimsExtractor
audience string audience string
userIDClaim string userIDClaim string
} }
@@ -40,14 +45,19 @@ const (
) )
// NewAuthMiddleware instance constructor // NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, audience string, userIdClaim string) *AuthMiddleware { func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
markPATUsed MarkPATUsedFunc, getAccountFromToken GetAccountFromTokenFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
audience string, userIdClaim string) *AuthMiddleware {
if userIdClaim == "" { if userIdClaim == "" {
userIdClaim = jwtclaims.UserIDClaim userIdClaim = jwtclaims.UserIDClaim
} }
return &AuthMiddleware{ return &AuthMiddleware{
getAccountFromPAT: getAccountFromPAT, getAccountFromPAT: getAccountFromPAT,
validateAndParseToken: validateAndParseToken, validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed, markPATUsed: markPATUsed,
getAccountFromToken: getAccountFromToken,
claimsExtractor: claimsExtractor,
audience: audience, audience: audience,
userIDClaim: userIdClaim, userIDClaim: userIdClaim,
} }
@@ -107,6 +117,10 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
return nil return nil
} }
if err := m.verifyUserAccess(validatedToken); err != nil {
return err
}
// If we get here, everything worked and we can set the // If we get here, everything worked and we can set the
// user property in context. // user property in context.
newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint newRequest := r.WithContext(context.WithValue(r.Context(), userProperty, validatedToken)) //nolint
@@ -115,6 +129,41 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ
return nil return nil
} }
// verifyUserAccess checks if a user, based on a validated JWT token,
// is allowed access, particularly in cases where the admin enabled JWT
// group propagation and designated certain groups with access permissions.
func (m *AuthMiddleware) verifyUserAccess(validatedToken *jwt.Token) error {
authClaims := m.claimsExtractor.FromToken(validatedToken)
account, _, err := m.getAccountFromToken(authClaims)
if err != nil {
return fmt.Errorf("failed to get the account from token: %w", err)
}
// Ensures JWT group synchronization to the management is enabled before,
// filtering access based on the allowed groups.
if account.Settings != nil && account.Settings.JWTGroupsEnabled {
if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
userJWTGroups := make([]string, 0)
if claim, ok := authClaims.Raw[account.Settings.JWTGroupsClaimName]; ok {
if claimGroups, ok := claim.([]interface{}); ok {
for _, g := range claimGroups {
if group, ok := g.(string); ok {
userJWTGroups = append(userJWTGroups, group)
}
}
}
}
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
}
}
}
return nil
}
// CheckPATFromRequest checks if the PAT is valid // CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error { func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
token, err := getTokenFromPATRequest(auth) token, err := getTokenFromPATRequest(auth)
@@ -168,3 +217,15 @@ func getTokenFromPATRequest(authHeaderParts []string) (string, error) {
return authHeaderParts[1], nil return authHeaderParts[1], nil
} }
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
for _, userGroup := range userGroups {
for _, allowedGroup := range allowedGroups {
if userGroup == allowedGroup {
return true
}
}
}
return false
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
) )
const ( const (
@@ -54,7 +55,13 @@ func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server
func mockValidateAndParseToken(token string) (*jwt.Token, error) { func mockValidateAndParseToken(token string) (*jwt.Token, error) {
if token == JWT { if token == JWT {
return &jwt.Token{}, nil return &jwt.Token{
Claims: jwt.MapClaims{
userIDClaim: userID,
audience + jwtclaims.AccountIDSuffix: accountID,
},
Valid: true,
}, nil
} }
return nil, fmt.Errorf("JWT invalid") return nil, fmt.Errorf("JWT invalid")
} }
@@ -66,6 +73,19 @@ func mockMarkPATUsed(token string) error {
return fmt.Errorf("Should never get reached") return fmt.Errorf("Should never get reached")
} }
func mockGetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
if testAccount.Id != claims.AccountId {
return nil, nil, fmt.Errorf("account with id %s does not exist", claims.AccountId)
}
user, ok := testAccount.Users[claims.UserId]
if !ok {
return nil, nil, fmt.Errorf("user with id %s does not exist", claims.UserId)
}
return testAccount, user, nil
}
func TestAuthMiddleware_Handler(t *testing.T) { func TestAuthMiddleware_Handler(t *testing.T) {
tt := []struct { tt := []struct {
name string name string
@@ -108,7 +128,20 @@ func TestAuthMiddleware_Handler(t *testing.T) {
// do nothing // do nothing
}) })
authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience, userIDClaim) claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(audience),
jwtclaims.WithUserIDClaim(userIDClaim),
)
authMiddleware := NewAuthMiddleware(
mockGetAccountFromPAT,
mockValidateAndParseToken,
mockMarkPATUsed,
mockGetAccountFromToken,
claimsExtractor,
audience,
userIDClaim,
)
handlerToTest := authMiddleware.Handler(nextHandler) handlerToTest := authMiddleware.Handler(nextHandler)

View File

@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
@@ -31,17 +32,12 @@ func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Pee
} }
} }
func (h *PeersHandler) checkPeerStatus(peer *server.Peer) (*server.Peer, error) { func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) {
peerToReturn := peer.Copy() peerToReturn := peer.Copy()
if peer.Status.Connected { if peer.Status.Connected {
statuses, err := h.accountManager.GetAllConnectedPeers()
if err != nil {
return peerToReturn, err
}
// Although we have online status in store we do not yet have an updated channel so have to show it as disconnected // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected
// This may happen after server restart when not all peers are yet connected // This may happen after server restart when not all peers are yet connected
if _, connected := statuses[peerToReturn.ID]; !connected { if !h.accountManager.HasConnectedChannel(peer.ID) {
peerToReturn.Status.Connected = false peerToReturn.Status.Connected = false
} }
} }
@@ -79,8 +75,13 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe
return return
} }
update := &server.Peer{ID: peerID, SSHEnabled: req.SshEnabled, Name: req.Name, update := &nbpeer.Peer{ID: peerID, SSHEnabled: req.SshEnabled, Name: req.Name,
LoginExpirationEnabled: req.LoginExpirationEnabled} LoginExpirationEnabled: req.LoginExpirationEnabled}
if req.ApprovalRequired != nil {
update.Status = &nbpeer.PeerStatus{RequiresApproval: *req.ApprovalRequired}
}
peer, err := h.accountManager.UpdatePeer(account.Id, user.Id, update) peer, err := h.accountManager.UpdatePeer(account.Id, user.Id, update)
if err != nil { if err != nil {
util.WriteError(err, w) util.WriteError(err, w)
@@ -229,7 +230,7 @@ func toGroupsInfo(groups map[string]*server.Group, peerID string) []api.GroupMin
return groupsInfo return groupsInfo
} }
func toSinglePeerResponse(peer *server.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer) *api.Peer { func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer) *api.Peer {
return &api.Peer{ return &api.Peer{
Id: peer.ID, Id: peer.ID,
Name: peer.Name, Name: peer.Name,
@@ -248,10 +249,11 @@ func toSinglePeerResponse(peer *server.Peer, groupsInfo []api.GroupMinimum, dnsD
LastLogin: peer.LastLogin, LastLogin: peer.LastLogin,
LoginExpired: peer.Status.LoginExpired, LoginExpired: peer.Status.LoginExpired,
AccessiblePeers: accessiblePeer, AccessiblePeers: accessiblePeer,
ApprovalRequired: &peer.Status.RequiresApproval,
} }
} }
func toPeerListItemResponse(peer *server.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch { func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch {
return &api.PeerBatch{ return &api.PeerBatch{
Id: peer.ID, Id: peer.ID,
Name: peer.Name, Name: peer.Name,
@@ -270,10 +272,11 @@ func toPeerListItemResponse(peer *server.Peer, groupsInfo []api.GroupMinimum, dn
LastLogin: peer.LastLogin, LastLogin: peer.LastLogin,
LoginExpired: peer.Status.LoginExpired, LoginExpired: peer.Status.LoginExpired,
AccessiblePeersCount: accessiblePeersCount, AccessiblePeersCount: accessiblePeersCount,
ApprovalRequired: &peer.Status.RequiresApproval,
} }
} }
func fqdn(peer *server.Peer, dnsDomain string) string { func fqdn(peer *nbpeer.Peer, dnsDomain string) string {
fqdn := peer.FQDN(dnsDomain) fqdn := peer.FQDN(dnsDomain)
if fqdn == "" { if fqdn == "" {
return peer.DNSLabel return peer.DNSLabel

View File

@@ -13,6 +13,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -25,11 +26,11 @@ import (
const testPeerID = "test_peer" const testPeerID = "test_peer"
const noUpdateChannelTestPeerID = "no-update-channel" const noUpdateChannelTestPeerID = "no-update-channel"
func initTestMetaData(peers ...*server.Peer) *PeersHandler { func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return &PeersHandler{ return &PeersHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(accountID, userID string, update *server.Peer) (*server.Peer, error) { UpdatePeerFunc: func(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
var p *server.Peer var p *nbpeer.Peer
for _, peer := range peers { for _, peer := range peers {
if update.ID == peer.ID { if update.ID == peer.ID {
p = peer.Copy() p = peer.Copy()
@@ -41,8 +42,8 @@ func initTestMetaData(peers ...*server.Peer) *PeersHandler {
p.Name = update.Name p.Name = update.Name
return p, nil return p, nil
}, },
GetPeerFunc: func(accountID, peerID, userID string) (*server.Peer, error) { GetPeerFunc: func(accountID, peerID, userID string) (*nbpeer.Peer, error) {
var p *server.Peer var p *nbpeer.Peer
for _, peer := range peers { for _, peer := range peers {
if peerID == peer.ID { if peerID == peer.ID {
p = peer.Copy() p = peer.Copy()
@@ -51,7 +52,7 @@ func initTestMetaData(peers ...*server.Peer) *PeersHandler {
} }
return p, nil return p, nil
}, },
GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) { GetPeersFunc: func(accountID, userID string) ([]*nbpeer.Peer, error) {
return peers, nil return peers, nil
}, },
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
@@ -59,8 +60,9 @@ func initTestMetaData(peers ...*server.Peer) *PeersHandler {
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Domain: "hotmail.com", Domain: "hotmail.com",
Peers: map[string]*server.Peer{ Peers: map[string]*nbpeer.Peer{
peers[0].ID: peers[0], peers[0].ID: peers[0],
peers[1].ID: peers[1],
}, },
Users: map[string]*server.User{ Users: map[string]*server.User{
"test_user": user, "test_user": user,
@@ -79,8 +81,7 @@ func initTestMetaData(peers ...*server.Peer) *PeersHandler {
}, },
}, user, nil }, user, nil
}, },
HasConnectedChannelFunc: func(peerID string) bool {
GetAllConnectedPeersFunc: func() (map[string]struct{}, error) {
statuses := make(map[string]struct{}) statuses := make(map[string]struct{})
for _, peer := range peers { for _, peer := range peers {
if peer.ID == noUpdateChannelTestPeerID { if peer.ID == noUpdateChannelTestPeerID {
@@ -88,7 +89,8 @@ func initTestMetaData(peers ...*server.Peer) *PeersHandler {
} }
statuses[peer.ID] = struct{}{} statuses[peer.ID] = struct{}{}
} }
return statuses, nil _, ok := statuses[peerID]
return ok
}, },
}, },
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
@@ -107,15 +109,15 @@ func initTestMetaData(peers ...*server.Peer) *PeersHandler {
// Use the metadata generated by initTestMetaData() to check for values // Use the metadata generated by initTestMetaData() to check for values
func TestGetPeers(t *testing.T) { func TestGetPeers(t *testing.T) {
peer := &server.Peer{ peer := &nbpeer.Peer{
ID: testPeerID, ID: testPeerID,
Key: "key", Key: "key",
SetupKey: "setupkey", SetupKey: "setupkey",
IP: net.ParseIP("100.64.0.1"), IP: net.ParseIP("100.64.0.1"),
Status: &server.PeerStatus{Connected: true}, Status: &nbpeer.PeerStatus{Connected: true},
Name: "PeerName", Name: "PeerName",
LoginExpirationEnabled: false, LoginExpirationEnabled: false,
Meta: server.PeerSystemMeta{ Meta: nbpeer.PeerSystemMeta{
Hostname: "hostname", Hostname: "hostname",
GoOS: "GoOS", GoOS: "GoOS",
Kernel: "kernel", Kernel: "kernel",
@@ -144,7 +146,7 @@ func TestGetPeers(t *testing.T) {
requestPath string requestPath string
requestBody io.Reader requestBody io.Reader
expectedArray bool expectedArray bool
expectedPeer *server.Peer expectedPeer *nbpeer.Peer
}{ }{
{ {
name: "GetPeersMetaData", name: "GetPeersMetaData",

View File

@@ -11,6 +11,7 @@ import (
"testing" "testing"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -55,12 +56,12 @@ var baseExistingRoute = &route.Route{
var testingAccount = &server.Account{ var testingAccount = &server.Account{
Id: testAccountID, Id: testAccountID,
Domain: "hotmail.com", Domain: "hotmail.com",
Peers: map[string]*server.Peer{ Peers: map[string]*nbpeer.Peer{
existingPeerID: { existingPeerID: {
Key: existingPeerKey, Key: existingPeerKey,
IP: netip.MustParseAddr(existingPeerIP1).AsSlice(), IP: netip.MustParseAddr(existingPeerIP1).AsSlice(),
ID: existingPeerID, ID: existingPeerID,
Meta: server.PeerSystemMeta{ Meta: nbpeer.PeerSystemMeta{
GoOS: "linux", GoOS: "linux",
}, },
}, },
@@ -68,7 +69,7 @@ var testingAccount = &server.Account{
Key: nonLinuxExistingPeerID, Key: nonLinuxExistingPeerID,
IP: netip.MustParseAddr(existingPeerIP2).AsSlice(), IP: netip.MustParseAddr(existingPeerIP2).AsSlice(),
ID: nonLinuxExistingPeerID, ID: nonLinuxExistingPeerID,
Meta: server.PeerSystemMeta{ Meta: nbpeer.PeerSystemMeta{
GoOS: "darwin", GoOS: "darwin",
}, },
}, },

View File

@@ -108,6 +108,8 @@ func NewJWTValidator(issuer string, audienceList []string, keysLocation string,
refreshedKeys = keys refreshedKeys = keys
} }
log.Debugf("keys refreshed, new UTC expiration time: %s", refreshedKeys.expiresInTime.UTC())
keys = refreshedKeys keys = refreshedKeys
} }
} }
@@ -179,7 +181,7 @@ func (m *JWTValidator) ValidateAndParse(token string) (*jwt.Token, error) {
// stillValid returns true if the JSONWebKey still valid and have enough time to be used // stillValid returns true if the JSONWebKey still valid and have enough time to be used
func (jwks *Jwks) stillValid() bool { func (jwks *Jwks) stillValid() bool {
return jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime) return !jwks.expiresInTime.IsZero() && time.Now().Add(5*time.Second).Before(jwks.expiresInTime)
} }
func getPemKeys(keysLocation string) (*Jwks, error) { func getPemKeys(keysLocation string) (*Jwks, error) {

View File

@@ -5,6 +5,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -37,12 +38,12 @@ func (mockDatasource) GetAllAccounts() []*server.Account {
NameServerGroups: map[string]*nbdns.NameServerGroup{ NameServerGroups: map[string]*nbdns.NameServerGroup{
"1": {}, "1": {},
}, },
Peers: map[string]*server.Peer{ Peers: map[string]*nbpeer.Peer{
"1": { "1": {
ID: "1", ID: "1",
UserID: "test", UserID: "test",
SSHEnabled: true, SSHEnabled: true,
Meta: server.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"}, Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"},
}, },
}, },
Policies: []*server.Policy{ Policies: []*server.Policy{
@@ -101,12 +102,12 @@ func (mockDatasource) GetAllAccounts() []*server.Account {
NameServerGroups: map[string]*nbdns.NameServerGroup{ NameServerGroups: map[string]*nbdns.NameServerGroup{
"1": {}, "1": {},
}, },
Peers: map[string]*server.Peer{ Peers: map[string]*nbpeer.Peer{
"1": { "1": {
ID: "1", ID: "1",
UserID: "test", UserID: "test",
SSHEnabled: true, SSHEnabled: true,
Meta: server.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"}, Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"},
}, },
}, },
Policies: []*server.Policy{ Policies: []*server.Policy{

View File

@@ -10,6 +10,7 @@ import (
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -21,12 +22,12 @@ type MockAccountManager struct {
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error) GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(accountID string) ([]*server.User, error) ListUsersFunc func(accountID string) ([]*server.User, error)
GetPeersFunc func(accountID, userID string) ([]*server.Peer, error) GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(peerKey string, connected bool) error MarkPeerConnectedFunc func(peerKey string, connected bool) error
DeletePeerFunc func(accountID, peerKey, userID string) error DeletePeerFunc func(accountID, peerKey, userID string) error
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
GetPeerNetworkFunc func(peerKey string) (*server.Network, error) GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, *server.NetworkMap, error) AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error)
GetGroupFunc func(accountID, groupID string) (*server.Group, error) GetGroupFunc func(accountID, groupID string) (*server.Group, error)
SaveGroupFunc func(accountID, userID string, group *server.Group) error SaveGroupFunc func(accountID, userID string, group *server.Group) error
DeleteGroupFunc func(accountID, userId, groupID string) error DeleteGroupFunc func(accountID, userId, groupID string) error
@@ -44,9 +45,9 @@ type MockAccountManager struct {
GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error) GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error)
GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
MarkPATUsedFunc func(pat string) error MarkPATUsedFunc func(pat string) error
UpdatePeerMetaFunc func(peerID string, meta server.PeerSystemMeta) error UpdatePeerMetaFunc func(peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error
UpdatePeerFunc func(accountID, userID string, peer *server.Peer) (*server.Peer, error) UpdatePeerFunc func(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) CreateRouteFunc func(accountID, prefix, peer string, peerGroups []string, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error) GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
SaveRouteFunc func(accountID, userID string, route *route.Route) error SaveRouteFunc func(accountID, userID string, route *route.Route) error
@@ -74,12 +75,13 @@ type MockAccountManager struct {
GetEventsFunc func(accountID, userID string) ([]*activity.Event, error) GetEventsFunc func(accountID, userID string) ([]*activity.Event, error)
GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error) GetDNSSettingsFunc func(accountID, userID string) (*server.DNSSettings, error)
SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error SaveDNSSettingsFunc func(accountID, userID string, dnsSettingsToSave *server.DNSSettings) error
GetPeerFunc func(accountID, peerID, userID string) (*server.Peer, error) GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
LoginPeerFunc func(login server.PeerLogin) (*server.Peer, *server.NetworkMap, error) LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error)
SyncPeerFunc func(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error) SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error)
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error) GetAllConnectedPeersFunc func() (map[string]struct{}, error)
HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() server.ExternalCacheManager GetExternalCacheManagerFunc func() server.ExternalCacheManager
} }
@@ -226,8 +228,8 @@ func (am *MockAccountManager) GetPeerNetwork(peerKey string) (*server.Network, e
func (am *MockAccountManager) AddPeer( func (am *MockAccountManager) AddPeer(
setupKey string, setupKey string,
userId string, userId string,
peer *server.Peer, peer *nbpeer.Peer,
) (*server.Peer, *server.NetworkMap, error) { ) (*nbpeer.Peer, *server.NetworkMap, error) {
if am.AddPeerFunc != nil { if am.AddPeerFunc != nil {
return am.AddPeerFunc(setupKey, userId, peer) return am.AddPeerFunc(setupKey, userId, peer)
} }
@@ -347,7 +349,7 @@ func (am *MockAccountManager) ListPolicies(accountID, userID string) ([]*server.
} }
// UpdatePeerMeta mock implementation of UpdatePeerMeta from server.AccountManager interface // UpdatePeerMeta mock implementation of UpdatePeerMeta from server.AccountManager interface
func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta server.PeerSystemMeta) error { func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta nbpeer.PeerSystemMeta) error {
if am.UpdatePeerMetaFunc != nil { if am.UpdatePeerMetaFunc != nil {
return am.UpdatePeerMetaFunc(peerID, meta) return am.UpdatePeerMetaFunc(peerID, meta)
} }
@@ -364,7 +366,7 @@ func (am *MockAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*se
func (am *MockAccountManager) ListUsers(accountID string) ([]*server.User, error) { func (am *MockAccountManager) ListUsers(accountID string) ([]*server.User, error) {
if am.ListUsersFunc != nil { if am.ListUsersFunc != nil {
return am.ListUsers(accountID) return am.ListUsersFunc(accountID)
} }
return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented") return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented")
} }
@@ -378,7 +380,7 @@ func (am *MockAccountManager) UpdatePeerSSHKey(peerID string, sshKey string) err
} }
// UpdatePeer mocks UpdatePeerFunc function of the account manager // UpdatePeer mocks UpdatePeerFunc function of the account manager
func (am *MockAccountManager) UpdatePeer(accountID, userID string, peer *server.Peer) (*server.Peer, error) { func (am *MockAccountManager) UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) {
if am.UpdatePeerFunc != nil { if am.UpdatePeerFunc != nil {
return am.UpdatePeerFunc(accountID, userID, peer) return am.UpdatePeerFunc(accountID, userID, peer)
} }
@@ -462,7 +464,7 @@ func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.Us
// SaveOrAddUser mocks SaveOrAddUser of the AccountManager interface // SaveOrAddUser mocks SaveOrAddUser of the AccountManager interface
func (am *MockAccountManager) SaveOrAddUser(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) { func (am *MockAccountManager) SaveOrAddUser(accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) {
if am.SaveUserFunc != nil { if am.SaveOrAddUserFunc != nil {
return am.SaveOrAddUserFunc(accountID, userID, user, addIfNotExists) return am.SaveOrAddUserFunc(accountID, userID, user, addIfNotExists)
} }
return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented") return nil, status.Errorf(codes.Unimplemented, "method SaveOrAddUser is not implemented")
@@ -542,8 +544,8 @@ func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.Authorization
} }
// GetPeers mocks GetPeers of the AccountManager interface // GetPeers mocks GetPeers of the AccountManager interface
func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*server.Peer, error) { func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) {
if am.GetAccountFromTokenFunc != nil { if am.GetPeersFunc != nil {
return am.GetPeersFunc(accountID, userID) return am.GetPeersFunc(accountID, userID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetPeers is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetPeers is not implemented")
@@ -582,7 +584,7 @@ func (am *MockAccountManager) SaveDNSSettings(accountID string, userID string, d
} }
// GetPeer mocks GetPeer of the AccountManager interface // GetPeer mocks GetPeer of the AccountManager interface
func (am *MockAccountManager) GetPeer(accountID, peerID, userID string) (*server.Peer, error) { func (am *MockAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) {
if am.GetPeerFunc != nil { if am.GetPeerFunc != nil {
return am.GetPeerFunc(accountID, peerID, userID) return am.GetPeerFunc(accountID, peerID, userID)
} }
@@ -598,7 +600,7 @@ func (am *MockAccountManager) UpdateAccountSettings(accountID, userID string, ne
} }
// LoginPeer mocks LoginPeer of the AccountManager interface // LoginPeer mocks LoginPeer of the AccountManager interface
func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*server.Peer, *server.NetworkMap, error) { func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error) {
if am.LoginPeerFunc != nil { if am.LoginPeerFunc != nil {
return am.LoginPeerFunc(login) return am.LoginPeerFunc(login)
} }
@@ -606,7 +608,7 @@ func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*server.Peer, *
} }
// SyncPeer mocks SyncPeer of the AccountManager interface // SyncPeer mocks SyncPeer of the AccountManager interface
func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error) { func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error) {
if am.SyncPeerFunc != nil { if am.SyncPeerFunc != nil {
return am.SyncPeerFunc(sync) return am.SyncPeerFunc(sync)
} }
@@ -621,6 +623,14 @@ func (am *MockAccountManager) GetAllConnectedPeers() (map[string]struct{}, error
return nil, status.Errorf(codes.Unimplemented, "method GetAllConnectedPeers is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetAllConnectedPeers is not implemented")
} }
// HasconnectedChannel mocks HasConnectedChannel of the AccountManager interface
func (am *MockAccountManager) HasConnectedChannel(peerID string) bool {
if am.HasConnectedChannelFunc != nil {
return am.HasConnectedChannelFunc(peerID)
}
return false
}
// StoreEvent mocks StoreEvent of the AccountManager interface // StoreEvent mocks StoreEvent of the AccountManager interface
func (am *MockAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.Activity, meta map[string]any) { func (am *MockAccountManager) StoreEvent(initiatorID, targetID, accountID string, activityID activity.Activity, meta map[string]any) {
if am.StoreEventFunc != nil { if am.StoreEventFunc != nil {

View File

@@ -8,6 +8,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
) )
const ( const (
@@ -763,10 +764,10 @@ func createNSStore(t *testing.T) (Store, error) {
func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
t.Helper() t.Helper()
peer1 := &Peer{ peer1 := &nbpeer.Peer{
Key: nsGroupPeer1Key, Key: nsGroupPeer1Key,
Name: "test-host1@netbird.io", Name: "test-host1@netbird.io",
Meta: PeerSystemMeta{ Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host1@netbird.io", Hostname: "test-host1@netbird.io",
GoOS: "linux", GoOS: "linux",
Kernel: "Linux", Kernel: "Linux",
@@ -777,10 +778,10 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error
UIVersion: "development", UIVersion: "development",
}, },
} }
peer2 := &Peer{ peer2 := &nbpeer.Peer{
Key: nsGroupPeer2Key, Key: nsGroupPeer2Key,
Name: "test-host2@netbird.io", Name: "test-host2@netbird.io",
Meta: PeerSystemMeta{ Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host2@netbird.io", Hostname: "test-host2@netbird.io",
GoOS: "linux", GoOS: "linux",
Kernel: "Linux", Kernel: "Linux",

View File

@@ -10,6 +10,7 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -25,11 +26,11 @@ const (
) )
type NetworkMap struct { type NetworkMap struct {
Peers []*Peer Peers []*nbpeer.Peer
Network *Network Network *Network
Routes []*route.Route Routes []*route.Route
DNSConfig nbdns.Config DNSConfig nbdns.Config
OfflinePeers []*Peer OfflinePeers []*nbpeer.Peer
FirewallRules []*FirewallRule FirewallRules []*FirewallRule
} }

View File

@@ -2,13 +2,14 @@ package server
import ( import (
"fmt" "fmt"
"net"
"strings" "strings"
"time" "time"
"github.com/netbirdio/management-integrations/additions"
"github.com/rs/xid" "github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -16,38 +17,6 @@ import (
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
) )
// PeerSystemMeta is a metadata of a Peer machine system
type PeerSystemMeta struct {
Hostname string
GoOS string
Kernel string
Core string
Platform string
OS string
WtVersion string
UIVersion string
}
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
return p.Hostname == other.Hostname &&
p.GoOS == other.GoOS &&
p.Kernel == other.Kernel &&
p.Core == other.Core &&
p.Platform == other.Platform &&
p.OS == other.OS &&
p.WtVersion == other.WtVersion &&
p.UIVersion == other.UIVersion
}
type PeerStatus struct {
// LastSeen is the last time peer was connected to the management service
LastSeen time.Time
// Connected indicates whether peer is connected to the management service or not
Connected bool
// LoginExpired
LoginExpired bool
}
// PeerSync used as a data object between the gRPC API and AccountManager on Sync request. // PeerSync used as a data object between the gRPC API and AccountManager on Sync request.
type PeerSync struct { type PeerSync struct {
// WireGuardPubKey is a peers WireGuard public key // WireGuardPubKey is a peers WireGuard public key
@@ -61,146 +30,16 @@ type PeerLogin struct {
// SSHKey is a peer's ssh key. Can be empty (e.g., old version do not provide it, or this feature is disabled) // SSHKey is a peer's ssh key. Can be empty (e.g., old version do not provide it, or this feature is disabled)
SSHKey string SSHKey string
// Meta is the system information passed by peer, must be always present. // Meta is the system information passed by peer, must be always present.
Meta PeerSystemMeta Meta nbpeer.PeerSystemMeta
// UserID indicates that JWT was used to log in, and it was valid. Can be empty when SetupKey is used or auth is not required. // UserID indicates that JWT was used to log in, and it was valid. Can be empty when SetupKey is used or auth is not required.
UserID string UserID string
// SetupKey references to a server.SetupKey to log in. Can be empty when UserID is used or auth is not required. // SetupKey references to a server.SetupKey to log in. Can be empty when UserID is used or auth is not required.
SetupKey string SetupKey string
} }
// Peer represents a machine connected to the network.
// The Peer is a WireGuard peer identified by a public key
type Peer struct {
// ID is an internal ID of the peer
ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index;uniqueIndex:idx_peers_account_id_ip"`
// WireGuard public key
Key string `gorm:"index"`
// A setup key this peer was registered with
SetupKey string
// IP address of the Peer
IP net.IP `gorm:"uniqueIndex:idx_peers_account_id_ip"`
// Meta is a Peer system meta data
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
// Name is peer's name (machine name)
Name string
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
DNSLabel string
// Status peer's management connection status
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
// The user ID that registered the peer
UserID string
// SSHKey is a public SSH key of the peer
SSHKey string
// SSHEnabled indicates whether SSH server is enabled on the peer
SSHEnabled bool
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
// Works with LastLogin
LoginExpirationEnabled bool
// LastLogin the time when peer performed last login operation
LastLogin time.Time
// Indicate ephemeral peer attribute
Ephemeral bool
}
// AddedWithSSOLogin indicates whether this peer has been added with an SSO login by a user.
func (p *Peer) AddedWithSSOLogin() bool {
return p.UserID != ""
}
// Copy copies Peer object
func (p *Peer) Copy() *Peer {
peerStatus := p.Status
if peerStatus != nil {
peerStatus = p.Status.Copy()
}
return &Peer{
ID: p.ID,
AccountID: p.AccountID,
Key: p.Key,
SetupKey: p.SetupKey,
IP: p.IP,
Meta: p.Meta,
Name: p.Name,
DNSLabel: p.DNSLabel,
Status: peerStatus,
UserID: p.UserID,
SSHKey: p.SSHKey,
SSHEnabled: p.SSHEnabled,
LoginExpirationEnabled: p.LoginExpirationEnabled,
LastLogin: p.LastLogin,
Ephemeral: p.Ephemeral,
}
}
// UpdateMetaIfNew updates peer's system metadata if new information is provided
// returns true if meta was updated, false otherwise
func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) bool {
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
if meta.UIVersion == "" {
meta.UIVersion = p.Meta.UIVersion
}
if p.Meta.isEqual(meta) {
return false
}
p.Meta = meta
return true
}
// MarkLoginExpired marks peer's status expired or not
func (p *Peer) MarkLoginExpired(expired bool) {
newStatus := p.Status.Copy()
newStatus.LoginExpired = expired
if expired {
newStatus.Connected = false
}
p.Status = newStatus
}
// LoginExpired indicates whether the peer's login has expired or not.
// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired.
// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired).
// Login expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property.
// Login expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled.
// Only peers added by interactive SSO login can be expired.
func (p *Peer) LoginExpired(expiresIn time.Duration) (bool, time.Duration) {
if !p.AddedWithSSOLogin() || !p.LoginExpirationEnabled {
return false, 0
}
expiresAt := p.LastLogin.Add(expiresIn)
now := time.Now()
timeLeft := expiresAt.Sub(now)
return timeLeft <= 0, timeLeft
}
// FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain
func (p *Peer) FQDN(dnsDomain string) string {
if dnsDomain == "" {
return ""
}
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
}
// EventMeta returns activity event meta related to the peer
func (p *Peer) EventMeta(dnsDomain string) map[string]any {
return map[string]any{"name": p.Name, "fqdn": p.FQDN(dnsDomain), "ip": p.IP}
}
// Copy PeerStatus
func (p *PeerStatus) Copy() *PeerStatus {
return &PeerStatus{
LastSeen: p.LastSeen,
Connected: p.Connected,
LoginExpired: p.LoginExpired,
}
}
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin. // the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, error) { func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) {
account, err := am.Store.GetAccount(accountID) account, err := am.Store.GetAccount(accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -211,8 +50,8 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, er
return nil, err return nil, err
} }
peers := make([]*Peer, 0) peers := make([]*nbpeer.Peer, 0)
peersMap := make(map[string]*Peer) peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range account.Peers { for _, peer := range account.Peers {
if !user.HasAdminPower() && user.Id != peer.UserID { if !user.HasAdminPower() && user.Id != peer.UserID {
// only display peers that belong to the current user if the current user is not an admin // only display peers that belong to the current user if the current user is not an admin
@@ -231,7 +70,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, er
} }
} }
peers = make([]*Peer, 0, len(peersMap)) peers = make([]*nbpeer.Peer, 0, len(peersMap))
for _, peer := range peersMap { for _, peer := range peersMap {
peers = append(peers, peer) peers = append(peers, peer)
} }
@@ -290,7 +129,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected
} }
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated. // UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *Peer) (*Peer, error) { func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@@ -304,6 +143,11 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *Pe
return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID) return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID)
} }
update, err = additions.ValidatePeersUpdateRequest(update, peer, userID, accountID, am.eventStore, am.GetDNSDomain())
if err != nil {
return nil, err
}
if peer.SSHEnabled != update.SSHEnabled { if peer.SSHEnabled != update.SSHEnabled {
peer.SSHEnabled = update.SSHEnabled peer.SSHEnabled = update.SSHEnabled
event := activity.PeerSSHEnabled event := activity.PeerSSHEnabled
@@ -364,7 +208,7 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string,
// the first loop is needed to ensure all peers present under the account before modifying, otherwise // the first loop is needed to ensure all peers present under the account before modifying, otherwise
// we might have some inconsistencies // we might have some inconsistencies
peers := make([]*Peer, 0, len(peerIDs)) peers := make([]*nbpeer.Peer, 0, len(peerIDs))
for _, peerID := range peerIDs { for _, peerID := range peerIDs {
peer := account.GetPeer(peerID) peer := account.GetPeer(peerID)
@@ -456,7 +300,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(peerID string) (*Network, error)
// to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied // to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused). // Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
// The peer property is just a placeholder for the Peer properties to pass further // The peer property is just a placeholder for the Peer properties to pass further
func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error) { func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, error) {
if setupKey == "" && userID == "" { if setupKey == "" && userID == "" {
// no auth method provided => reject access // no auth method provided => reject access
return nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login") return nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
@@ -510,6 +354,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
} }
var ephemeral bool var ephemeral bool
setupKeyName := ""
if !addedByUser { if !addedByUser {
// validate the setup key if adding with a key // validate the setup key if adding with a key
sk, err := account.FindSetupKey(upperKey) sk, err := account.FindSetupKey(upperKey)
@@ -525,6 +370,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
opEvent.InitiatorID = sk.Id opEvent.InitiatorID = sk.Id
opEvent.Activity = activity.PeerAddedWithSetupKey opEvent.Activity = activity.PeerAddedWithSetupKey
ephemeral = sk.Ephemeral ephemeral = sk.Ephemeral
setupKeyName = sk.Name
} else { } else {
opEvent.InitiatorID = userID opEvent.InitiatorID = userID
opEvent.Activity = activity.PeerAddedByUser opEvent.Activity = activity.PeerAddedByUser
@@ -545,7 +391,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
return nil, nil, err return nil, nil, err
} }
newPeer := &Peer{ newPeer := &nbpeer.Peer{
ID: xid.New().String(), ID: xid.New().String(),
Key: peer.Key, Key: peer.Key,
SetupKey: upperKey, SetupKey: upperKey,
@@ -554,7 +400,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
Name: peer.Meta.Hostname, Name: peer.Meta.Hostname,
DNSLabel: newLabel, DNSLabel: newLabel,
UserID: userID, UserID: userID,
Status: &PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
SSHEnabled: false, SSHEnabled: false,
SSHKey: peer.SSHKey, SSHKey: peer.SSHKey,
LastLogin: time.Now().UTC(), LastLogin: time.Now().UTC(),
@@ -562,6 +408,10 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
Ephemeral: ephemeral, Ephemeral: ephemeral,
} }
if account.Settings.Extra != nil {
newPeer = additions.PreparePeer(newPeer, account.Settings.Extra)
}
// add peer to 'All' group // add peer to 'All' group
group, err := account.GetGroupAll() group, err := account.GetGroupAll()
if err != nil { if err != nil {
@@ -599,6 +449,10 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
opEvent.TargetID = newPeer.ID opEvent.TargetID = newPeer.ID
opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
if !addedByUser {
opEvent.Meta["setup_key_name"] = setupKeyName
}
am.StoreEvent(opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) am.StoreEvent(opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
am.updateAccountPeers(account) am.updateAccountPeers(account)
@@ -608,7 +462,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (*
} }
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error) { func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) {
account, err := am.Store.GetAccountByPeerPubKey(sync.WireGuardPubKey) account, err := am.Store.GetAccountByPeerPubKey(sync.WireGuardPubKey)
if err != nil { if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
@@ -645,14 +499,14 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*Peer, *NetworkMap, er
// LoginPeer logs in or registers a peer. // LoginPeer logs in or registers a peer.
// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so.
func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap, error) { func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) {
account, err := am.Store.GetAccountByPeerPubKey(login.WireGuardPubKey) account, err := am.Store.GetAccountByPeerPubKey(login.WireGuardPubKey)
if err != nil { if err != nil {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
// we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet.
// Try registering it. // Try registering it.
return am.AddPeer(login.SetupKey, login.UserID, &Peer{ return am.AddPeer(login.SetupKey, login.UserID, &nbpeer.Peer{
Key: login.WireGuardPubKey, Key: login.WireGuardPubKey,
Meta: login.Meta, Meta: login.Meta,
SSHKey: login.SSHKey, SSHKey: login.SSHKey,
@@ -722,7 +576,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap,
return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil
} }
func checkIfPeerOwnerIsBlocked(peer *Peer, account *Account) error { func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
if peer.AddedWithSSOLogin() { if peer.AddedWithSSOLogin() {
user, err := account.FindUser(peer.UserID) user, err := account.FindUser(peer.UserID)
if err != nil { if err != nil {
@@ -735,7 +589,7 @@ func checkIfPeerOwnerIsBlocked(peer *Peer, account *Account) error {
return nil return nil
} }
func checkAuth(loginUserID string, peer *Peer) error { func checkAuth(loginUserID string, peer *nbpeer.Peer) error {
if loginUserID == "" { if loginUserID == "" {
// absence of a user ID indicates that JWT wasn't provided. // absence of a user ID indicates that JWT wasn't provided.
return status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") return status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more")
@@ -747,7 +601,7 @@ func checkAuth(loginUserID string, peer *Peer) error {
return nil return nil
} }
func peerLoginExpired(peer *Peer, account *Account) bool { func peerLoginExpired(peer *nbpeer.Peer, account *Account) bool {
expired, expiresIn := peer.LoginExpired(account.Settings.PeerLoginExpiration) expired, expiresIn := peer.LoginExpired(account.Settings.PeerLoginExpiration)
expired = account.Settings.PeerLoginExpirationEnabled && expired expired = account.Settings.PeerLoginExpirationEnabled && expired
if expired || peer.Status.LoginExpired { if expired || peer.Status.LoginExpired {
@@ -757,21 +611,12 @@ func peerLoginExpired(peer *Peer, account *Account) bool {
return false return false
} }
func updatePeerLastLogin(peer *Peer, account *Account) { func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) {
peer.UpdateLastLogin() peer.UpdateLastLogin()
account.UpdatePeer(peer) account.UpdatePeer(peer)
} }
// UpdateLastLogin and set login expired false func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(peer *nbpeer.Peer, account *Account, newSSHKey string) (*nbpeer.Peer, error) {
func (p *Peer) UpdateLastLogin() *Peer {
p.LastLogin = time.Now().UTC()
newStatus := p.Status.Copy()
newStatus.LoginExpired = false
p.Status = newStatus
return p
}
func (am *DefaultAccountManager) checkAndUpdatePeerSSHKey(peer *Peer, account *Account, newSSHKey string) (*Peer, error) {
if len(newSSHKey) == 0 { if len(newSSHKey) == 0 {
log.Debugf("no new SSH key provided for peer %s, skipping update", peer.ID) log.Debugf("no new SSH key provided for peer %s, skipping update", peer.ID)
return peer, nil return peer, nil
@@ -842,7 +687,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string)
} }
// GetPeer for a given accountID, peerID and userID error if not found. // GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*Peer, error) { func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
@@ -885,7 +730,7 @@ func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*Pee
return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peerID, accountID) return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peerID, accountID)
} }
func updatePeerMeta(peer *Peer, meta PeerSystemMeta, account *Account) (*Peer, bool) { func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Account) (*nbpeer.Peer, bool) {
if peer.UpdateMetaIfNew(meta) { if peer.UpdateMetaIfNew(meta) {
account.UpdatePeer(peer) account.UpdatePeer(peer)
return peer, true return peer, true

View File

@@ -0,0 +1,181 @@
package peer
import (
"fmt"
"net"
"time"
)
// Peer represents a machine connected to the network.
// The Peer is a WireGuard peer identified by a public key
type Peer struct {
// ID is an internal ID of the peer
ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index;uniqueIndex:idx_peers_account_id_ip"`
// WireGuard public key
Key string `gorm:"index"`
// A setup key this peer was registered with
SetupKey string
// IP address of the Peer
IP net.IP `gorm:"uniqueIndex:idx_peers_account_id_ip"`
// Meta is a Peer system meta data
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
// Name is peer's name (machine name)
Name string
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
DNSLabel string
// Status peer's management connection status
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
// The user ID that registered the peer
UserID string
// SSHKey is a public SSH key of the peer
SSHKey string
// SSHEnabled indicates whether SSH server is enabled on the peer
SSHEnabled bool
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
// Works with LastLogin
LoginExpirationEnabled bool
// LastLogin the time when peer performed last login operation
LastLogin time.Time
// Indicate ephemeral peer attribute
Ephemeral bool
}
type PeerStatus struct {
// LastSeen is the last time peer was connected to the management service
LastSeen time.Time
// Connected indicates whether peer is connected to the management service or not
Connected bool
// LoginExpired
LoginExpired bool
// RequiresApproval indicates whether peer requires approval or not
RequiresApproval bool
}
// PeerSystemMeta is a metadata of a Peer machine system
type PeerSystemMeta struct {
Hostname string
GoOS string
Kernel string
Core string
Platform string
OS string
WtVersion string
UIVersion string
}
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
return p.Hostname == other.Hostname &&
p.GoOS == other.GoOS &&
p.Kernel == other.Kernel &&
p.Core == other.Core &&
p.Platform == other.Platform &&
p.OS == other.OS &&
p.WtVersion == other.WtVersion &&
p.UIVersion == other.UIVersion
}
// AddedWithSSOLogin indicates whether this peer has been added with an SSO login by a user.
func (p *Peer) AddedWithSSOLogin() bool {
return p.UserID != ""
}
// Copy copies Peer object
func (p *Peer) Copy() *Peer {
peerStatus := p.Status
if peerStatus != nil {
peerStatus = p.Status.Copy()
}
return &Peer{
ID: p.ID,
AccountID: p.AccountID,
Key: p.Key,
SetupKey: p.SetupKey,
IP: p.IP,
Meta: p.Meta,
Name: p.Name,
DNSLabel: p.DNSLabel,
Status: peerStatus,
UserID: p.UserID,
SSHKey: p.SSHKey,
SSHEnabled: p.SSHEnabled,
LoginExpirationEnabled: p.LoginExpirationEnabled,
LastLogin: p.LastLogin,
Ephemeral: p.Ephemeral,
}
}
// UpdateMetaIfNew updates peer's system metadata if new information is provided
// returns true if meta was updated, false otherwise
func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) bool {
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
if meta.UIVersion == "" {
meta.UIVersion = p.Meta.UIVersion
}
if p.Meta.isEqual(meta) {
return false
}
p.Meta = meta
return true
}
// MarkLoginExpired marks peer's status expired or not
func (p *Peer) MarkLoginExpired(expired bool) {
newStatus := p.Status.Copy()
newStatus.LoginExpired = expired
if expired {
newStatus.Connected = false
}
p.Status = newStatus
}
// LoginExpired indicates whether the peer's login has expired or not.
// If Peer.LastLogin plus the expiresIn duration has happened already; then login has expired.
// Return true if a login has expired, false otherwise, and time left to expiration (negative when expired).
// Login expiration can be disabled/enabled on a Peer level via Peer.LoginExpirationEnabled property.
// Login expiration can also be disabled/enabled globally on the Account level via Settings.PeerLoginExpirationEnabled.
// Only peers added by interactive SSO login can be expired.
func (p *Peer) LoginExpired(expiresIn time.Duration) (bool, time.Duration) {
if !p.AddedWithSSOLogin() || !p.LoginExpirationEnabled {
return false, 0
}
expiresAt := p.LastLogin.Add(expiresIn)
now := time.Now()
timeLeft := expiresAt.Sub(now)
return timeLeft <= 0, timeLeft
}
// FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain
func (p *Peer) FQDN(dnsDomain string) string {
if dnsDomain == "" {
return ""
}
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
}
// EventMeta returns activity event meta related to the peer
func (p *Peer) EventMeta(dnsDomain string) map[string]any {
return map[string]any{"name": p.Name, "fqdn": p.FQDN(dnsDomain), "ip": p.IP}
}
// Copy PeerStatus
func (p *PeerStatus) Copy() *PeerStatus {
return &PeerStatus{
LastSeen: p.LastSeen,
Connected: p.Connected,
LoginExpired: p.LoginExpired,
RequiresApproval: p.RequiresApproval,
}
}
// UpdateLastLogin and set login expired false
func (p *Peer) UpdateLastLogin() *Peer {
p.LastLogin = time.Now().UTC()
newStatus := p.Status.Copy()
newStatus.LoginExpired = false
p.Status = newStatus
return p
}

View File

@@ -8,6 +8,8 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
) )
func TestPeer_LoginExpired(t *testing.T) { func TestPeer_LoginExpired(t *testing.T) {
@@ -52,7 +54,7 @@ func TestPeer_LoginExpired(t *testing.T) {
for _, c := range tt { for _, c := range tt {
t.Run(c.name, func(t *testing.T) { t.Run(c.name, func(t *testing.T) {
peer := &Peer{ peer := &nbpeer.Peer{
LoginExpirationEnabled: c.expirationEnabled, LoginExpirationEnabled: c.expirationEnabled,
LastLogin: c.lastLogin, LastLogin: c.lastLogin,
UserID: userID, UserID: userID,
@@ -90,9 +92,9 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
return return
} }
peer1, _, err := manager.AddPeer(setupKey.Key, "", &Peer{ peer1, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: peerKey1.PublicKey().String(), Key: peerKey1.PublicKey().String(),
Meta: PeerSystemMeta{Hostname: "test-peer-1"}, Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
}) })
if err != nil { if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err) t.Errorf("expecting peer to be added, got failure %v", err)
@@ -104,9 +106,9 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
t.Fatal(err) t.Fatal(err)
return return
} }
_, _, err = manager.AddPeer(setupKey.Key, "", &Peer{ _, _, err = manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(), Key: peerKey2.PublicKey().String(),
Meta: PeerSystemMeta{Hostname: "test-peer-2"}, Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
}) })
if err != nil { if err != nil {
@@ -163,9 +165,9 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
return return
} }
peer1, _, err := manager.AddPeer(setupKey.Key, "", &Peer{ peer1, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: peerKey1.PublicKey().String(), Key: peerKey1.PublicKey().String(),
Meta: PeerSystemMeta{Hostname: "test-peer-1"}, Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
}) })
if err != nil { if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err) t.Errorf("expecting peer to be added, got failure %v", err)
@@ -177,9 +179,9 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
t.Fatal(err) t.Fatal(err)
return return
} }
peer2, _, err := manager.AddPeer(setupKey.Key, "", &Peer{ peer2, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(), Key: peerKey2.PublicKey().String(),
Meta: PeerSystemMeta{Hostname: "test-peer-2"}, Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
}) })
if err != nil { if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err) t.Errorf("expecting peer to be added, got failure %v", err)
@@ -339,9 +341,9 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
return return
} }
peer1, _, err := manager.AddPeer(setupKey.Key, "", &Peer{ peer1, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: peerKey1.PublicKey().String(), Key: peerKey1.PublicKey().String(),
Meta: PeerSystemMeta{Hostname: "test-peer-1"}, Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
}) })
if err != nil { if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err) t.Errorf("expecting peer to be added, got failure %v", err)
@@ -353,9 +355,9 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
t.Fatal(err) t.Fatal(err)
return return
} }
_, _, err = manager.AddPeer(setupKey.Key, "", &Peer{ _, _, err = manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(), Key: peerKey2.PublicKey().String(),
Meta: PeerSystemMeta{Hostname: "test-peer-2"}, Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
}) })
if err != nil { if err != nil {
@@ -409,9 +411,9 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
return return
} }
peer1, _, err := manager.AddPeer("", someUser, &Peer{ peer1, _, err := manager.AddPeer("", someUser, &nbpeer.Peer{
Key: peerKey1.PublicKey().String(), Key: peerKey1.PublicKey().String(),
Meta: PeerSystemMeta{Hostname: "test-peer-2"}, Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
}) })
if err != nil { if err != nil {
t.Errorf("expecting peer to be added, got failure %v", err) t.Errorf("expecting peer to be added, got failure %v", err)
@@ -425,9 +427,9 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
} }
// the second peer added with a setup key // the second peer added with a setup key
peer2, _, err := manager.AddPeer(setupKey.Key, "", &Peer{ peer2, _, err := manager.AddPeer(setupKey.Key, "", &nbpeer.Peer{
Key: peerKey2.PublicKey().String(), Key: peerKey2.PublicKey().String(),
Meta: PeerSystemMeta{Hostname: "test-peer-2"}, Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@@ -5,10 +5,12 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/netbirdio/management-integrations/additions"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
) )
@@ -205,7 +207,7 @@ type FirewallRule struct {
// getPeerConnectionResources for a given peer // getPeerConnectionResources for a given peer
// //
// This function returns the list of peers and firewall rules that are applicable to a given peer. // This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) getPeerConnectionResources(peerID string) ([]*Peer, []*FirewallRule) { func (a *Account) getPeerConnectionResources(peerID string) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator() generateResources, getAccumulatedResources := a.connResourcesGenerator()
for _, policy := range a.Policies { for _, policy := range a.Policies {
if !policy.Enabled { if !policy.Enabled {
@@ -219,6 +221,8 @@ func (a *Account) getPeerConnectionResources(peerID string) ([]*Peer, []*Firewal
sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID) sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID)
destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID) destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID)
sourcePeers = additions.ValidatePeers(sourcePeers)
destinationPeers = additions.ValidatePeers(destinationPeers)
if rule.Bidirectional { if rule.Bidirectional {
if peerInSources { if peerInSources {
@@ -247,11 +251,11 @@ func (a *Account) getPeerConnectionResources(peerID string) ([]*Peer, []*Firewal
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer. // The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be // It safe to call the generator function multiple times for same peer and different rules no duplicates will be
// generated. The accumulator function returns the result of all the generator calls. // generated. The accumulator function returns the result of all the generator calls.
func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*Peer, int), func() ([]*Peer, []*FirewallRule)) { func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
rulesExists := make(map[string]struct{}) rulesExists := make(map[string]struct{})
peersExists := make(map[string]struct{}) peersExists := make(map[string]struct{})
rules := make([]*FirewallRule, 0) rules := make([]*FirewallRule, 0)
peers := make([]*Peer, 0) peers := make([]*nbpeer.Peer, 0)
all, err := a.GetGroupAll() all, err := a.GetGroupAll()
if err != nil { if err != nil {
@@ -259,7 +263,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*Peer, int), fun
all = &Group{} all = &Group{}
} }
return func(rule *PolicyRule, groupPeers []*Peer, direction int) { return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
isAll := (len(all.Peers) - 1) == len(groupPeers) isAll := (len(all.Peers) - 1) == len(groupPeers)
for _, peer := range groupPeers { for _, peer := range groupPeers {
if peer == nil { if peer == nil {
@@ -299,7 +303,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*Peer, int), fun
rules = append(rules, &pr) rules = append(rules, &pr)
} }
} }
}, func() ([]*Peer, []*FirewallRule) { }, func() ([]*nbpeer.Peer, []*FirewallRule) {
return peers, rules return peers, rules
} }
} }
@@ -478,9 +482,9 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
// getAllPeersFromGroups for given peer ID and list of groups // getAllPeersFromGroups for given peer ID and list of groups
// //
// Returns list of peers and boolean indicating if peer is in any of the groups // Returns list of peers and boolean indicating if peer is in any of the groups
func getAllPeersFromGroups(account *Account, groups []string, peerID string) ([]*Peer, bool) { func getAllPeersFromGroups(account *Account, groups []string, peerID string) ([]*nbpeer.Peer, bool) {
peerInGroups := false peerInGroups := false
filteredPeers := make([]*Peer, 0, len(groups)) filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
for _, g := range groups { for _, g := range groups {
group, ok := account.Groups[g] group, ok := account.Groups[g]
if !ok { if !ok {

View File

@@ -7,42 +7,52 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
) )
func TestAccount_getPeersByPolicy(t *testing.T) { func TestAccount_getPeersByPolicy(t *testing.T) {
account := &Account{ account := &Account{
Peers: map[string]*Peer{ Peers: map[string]*nbpeer.Peer{
"peerA": { "peerA": {
ID: "peerA", ID: "peerA",
IP: net.ParseIP("100.65.14.88"), IP: net.ParseIP("100.65.14.88"),
Status: &nbpeer.PeerStatus{},
}, },
"peerB": { "peerB": {
ID: "peerB", ID: "peerB",
IP: net.ParseIP("100.65.80.39"), IP: net.ParseIP("100.65.80.39"),
Status: &nbpeer.PeerStatus{},
}, },
"peerC": { "peerC": {
ID: "peerC", ID: "peerC",
IP: net.ParseIP("100.65.254.139"), IP: net.ParseIP("100.65.254.139"),
Status: &nbpeer.PeerStatus{},
}, },
"peerD": { "peerD": {
ID: "peerD", ID: "peerD",
IP: net.ParseIP("100.65.62.5"), IP: net.ParseIP("100.65.62.5"),
Status: &nbpeer.PeerStatus{},
}, },
"peerE": { "peerE": {
ID: "peerE", ID: "peerE",
IP: net.ParseIP("100.65.32.206"), IP: net.ParseIP("100.65.32.206"),
Status: &nbpeer.PeerStatus{},
}, },
"peerF": { "peerF": {
ID: "peerF", ID: "peerF",
IP: net.ParseIP("100.65.250.202"), IP: net.ParseIP("100.65.250.202"),
Status: &nbpeer.PeerStatus{},
}, },
"peerG": { "peerG": {
ID: "peerG", ID: "peerG",
IP: net.ParseIP("100.65.13.186"), IP: net.ParseIP("100.65.13.186"),
Status: &nbpeer.PeerStatus{},
}, },
"peerH": { "peerH": {
ID: "peerH", ID: "peerH",
IP: net.ParseIP("100.65.29.55"), IP: net.ParseIP("100.65.29.55"),
Status: &nbpeer.PeerStatus{},
}, },
}, },
Groups: map[string]*Group{ Groups: map[string]*Group{
@@ -255,18 +265,21 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
func TestAccount_getPeersByPolicyDirect(t *testing.T) { func TestAccount_getPeersByPolicyDirect(t *testing.T) {
account := &Account{ account := &Account{
Peers: map[string]*Peer{ Peers: map[string]*nbpeer.Peer{
"peerA": { "peerA": {
ID: "peerA", ID: "peerA",
IP: net.ParseIP("100.65.14.88"), IP: net.ParseIP("100.65.14.88"),
Status: &nbpeer.PeerStatus{},
}, },
"peerB": { "peerB": {
ID: "peerB", ID: "peerB",
IP: net.ParseIP("100.65.80.39"), IP: net.ParseIP("100.65.80.39"),
Status: &nbpeer.PeerStatus{},
}, },
"peerC": { "peerC": {
ID: "peerC", ID: "peerC",
IP: net.ParseIP("100.65.254.139"), IP: net.ParseIP("100.65.254.139"),
Status: &nbpeer.PeerStatus{},
}, },
}, },
Groups: map[string]*Group{ Groups: map[string]*Group{

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -1045,13 +1046,13 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
return nil, err return nil, err
} }
peer1 := &Peer{ peer1 := &nbpeer.Peer{
IP: peer1IP, IP: peer1IP,
ID: peer1ID, ID: peer1ID,
Key: peer1Key, Key: peer1Key,
Name: "test-host1@netbird.io", Name: "test-host1@netbird.io",
UserID: userID, UserID: userID,
Meta: PeerSystemMeta{ Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host1@netbird.io", Hostname: "test-host1@netbird.io",
GoOS: "linux", GoOS: "linux",
Kernel: "Linux", Kernel: "Linux",
@@ -1061,6 +1062,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
WtVersion: "development", WtVersion: "development",
UIVersion: "development", UIVersion: "development",
}, },
Status: &nbpeer.PeerStatus{},
} }
account.Peers[peer1.ID] = peer1 account.Peers[peer1.ID] = peer1
@@ -1070,13 +1072,13 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
return nil, err return nil, err
} }
peer2 := &Peer{ peer2 := &nbpeer.Peer{
IP: peer2IP, IP: peer2IP,
ID: peer2ID, ID: peer2ID,
Key: peer2Key, Key: peer2Key,
Name: "test-host2@netbird.io", Name: "test-host2@netbird.io",
UserID: userID, UserID: userID,
Meta: PeerSystemMeta{ Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host2@netbird.io", Hostname: "test-host2@netbird.io",
GoOS: "linux", GoOS: "linux",
Kernel: "Linux", Kernel: "Linux",
@@ -1086,6 +1088,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
WtVersion: "development", WtVersion: "development",
UIVersion: "development", UIVersion: "development",
}, },
Status: &nbpeer.PeerStatus{},
} }
account.Peers[peer2.ID] = peer2 account.Peers[peer2.ID] = peer2
@@ -1095,13 +1098,13 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
return nil, err return nil, err
} }
peer3 := &Peer{ peer3 := &nbpeer.Peer{
IP: peer3IP, IP: peer3IP,
ID: peer3ID, ID: peer3ID,
Key: peer3Key, Key: peer3Key,
Name: "test-host3@netbird.io", Name: "test-host3@netbird.io",
UserID: userID, UserID: userID,
Meta: PeerSystemMeta{ Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host3@netbird.io", Hostname: "test-host3@netbird.io",
GoOS: "darwin", GoOS: "darwin",
Kernel: "Darwin", Kernel: "Darwin",
@@ -1111,6 +1114,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
WtVersion: "development", WtVersion: "development",
UIVersion: "development", UIVersion: "development",
}, },
Status: &nbpeer.PeerStatus{},
} }
account.Peers[peer3.ID] = peer3 account.Peers[peer3.ID] = peer3
@@ -1120,13 +1124,13 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
return nil, err return nil, err
} }
peer4 := &Peer{ peer4 := &nbpeer.Peer{
IP: peer4IP, IP: peer4IP,
ID: peer4ID, ID: peer4ID,
Key: peer4Key, Key: peer4Key,
Name: "test-host4@netbird.io", Name: "test-host4@netbird.io",
UserID: userID, UserID: userID,
Meta: PeerSystemMeta{ Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host4@netbird.io", Hostname: "test-host4@netbird.io",
GoOS: "linux", GoOS: "linux",
Kernel: "Linux", Kernel: "Linux",
@@ -1136,6 +1140,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
WtVersion: "development", WtVersion: "development",
UIVersion: "development", UIVersion: "development",
}, },
Status: &nbpeer.PeerStatus{},
} }
account.Peers[peer4.ID] = peer4 account.Peers[peer4.ID] = peer4
@@ -1145,13 +1150,13 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
return nil, err return nil, err
} }
peer5 := &Peer{ peer5 := &nbpeer.Peer{
IP: peer5IP, IP: peer5IP,
ID: peer5ID, ID: peer5ID,
Key: peer5Key, Key: peer5Key,
Name: "test-host4@netbird.io", Name: "test-host4@netbird.io",
UserID: userID, UserID: userID,
Meta: PeerSystemMeta{ Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host4@netbird.io", Hostname: "test-host4@netbird.io",
GoOS: "linux", GoOS: "linux",
Kernel: "Linux", Kernel: "Linux",
@@ -1161,6 +1166,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
WtVersion: "development", WtVersion: "development",
UIVersion: "development", UIVersion: "development",
}, },
Status: &nbpeer.PeerStatus{},
} }
account.Peers[peer5.ID] = peer5 account.Peers[peer5.ID] = peer5

View File

@@ -14,6 +14,8 @@ import (
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -59,9 +61,9 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore,
sql.SetMaxOpenConns(conns) // TODO: make it configurable sql.SetMaxOpenConns(conns) // TODO: make it configurable
err = db.AutoMigrate( err = db.AutoMigrate(
&SetupKey{}, &Peer{}, &User{}, &PersonalAccessToken{}, &Group{}, &Rule{}, &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &Group{}, &Rule{},
&Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &installation{}, &account.ExtraSettings{},
) )
if err != nil { if err != nil {
return nil, err return nil, err
@@ -251,8 +253,8 @@ func (s *SqliteStore) GetInstallationID() string {
return installation.InstallationIDValue return installation.InstallationIDValue
} }
func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus PeerStatus) error { func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peer Peer var peer nbpeer.Peer
result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID) result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID)
if result.Error != nil { if result.Error != nil {
@@ -379,7 +381,7 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) {
} }
account.SetupKeysG = nil account.SetupKeysG = nil
account.Peers = make(map[string]*Peer, len(account.PeersG)) account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
for _, peer := range account.PeersG { for _, peer := range account.PeersG {
account.Peers[peer.ID] = peer.Copy() account.Peers[peer.ID] = peer.Copy()
} }
@@ -437,7 +439,7 @@ func (s *SqliteStore) GetAccountByUser(userID string) (*Account, error) {
} }
func (s *SqliteStore) GetAccountByPeerID(peerID string) (*Account, error) { func (s *SqliteStore) GetAccountByPeerID(peerID string) (*Account, error) {
var peer Peer var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, "id = ?", peerID) result := s.db.Select("account_id").First(&peer, "id = ?", peerID)
if result.Error != nil { if result.Error != nil {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -451,7 +453,7 @@ func (s *SqliteStore) GetAccountByPeerID(peerID string) (*Account, error) {
} }
func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
var peer Peer var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, "key = ?", peerKey) result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
if result.Error != nil { if result.Error != nil {

View File

@@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -37,13 +38,13 @@ func TestSqlite_SaveAccount(t *testing.T) {
account := newAccountWithId("account_id", "testuser", "") account := newAccountWithId("account_id", "testuser", "")
setupKey := GenerateDefaultSetupKey() setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &Peer{ account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey", Key: "peerkey",
SetupKey: "peerkeysetupkey", SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
Status: &PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
} }
err := store.SaveAccount(account) err := store.SaveAccount(account)
@@ -52,13 +53,13 @@ func TestSqlite_SaveAccount(t *testing.T) {
account2 := newAccountWithId("account_id2", "testuser2", "") account2 := newAccountWithId("account_id2", "testuser2", "")
setupKey = GenerateDefaultSetupKey() setupKey = GenerateDefaultSetupKey()
account2.SetupKeys[setupKey.Key] = setupKey account2.SetupKeys[setupKey.Key] = setupKey
account2.Peers["testpeer2"] = &Peer{ account2.Peers["testpeer2"] = &nbpeer.Peer{
Key: "peerkey2", Key: "peerkey2",
SetupKey: "peerkeysetupkey2", SetupKey: "peerkeysetupkey2",
IP: net.IP{127, 0, 0, 2}, IP: net.IP{127, 0, 0, 2},
Meta: PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name 2", Name: "peer name 2",
Status: &PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
} }
err = store.SaveAccount(account2) err = store.SaveAccount(account2)
@@ -116,13 +117,13 @@ func TestSqlite_DeleteAccount(t *testing.T) {
account := newAccountWithId("account_id", testUserID, "") account := newAccountWithId("account_id", testUserID, "")
setupKey := GenerateDefaultSetupKey() setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey account.SetupKeys[setupKey.Key] = setupKey
account.Peers["testpeer"] = &Peer{ account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey", Key: "peerkey",
SetupKey: "peerkeysetupkey", SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
Status: &PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
} }
account.Users[testUserID] = user account.Users[testUserID] = user
@@ -184,19 +185,19 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// save status of non-existing peer // save status of non-existing peer
newStatus := PeerStatus{Connected: true, LastSeen: time.Now().UTC()} newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
assert.Error(t, err) assert.Error(t, err)
// save new status of existing peer // save new status of existing peer
account.Peers["testpeer"] = &Peer{ account.Peers["testpeer"] = &nbpeer.Peer{
Key: "peerkey", Key: "peerkey",
ID: "testpeer", ID: "testpeer",
SetupKey: "peerkeysetupkey", SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
Status: &PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
} }
err = store.SaveAccount(account) err = store.SaveAccount(account)
@@ -291,13 +292,13 @@ func newAccount(store Store, id int) error {
account := newAccountWithId(str, str+"-testuser", "example.com") account := newAccountWithId(str, str+"-testuser", "example.com")
setupKey := GenerateDefaultSetupKey() setupKey := GenerateDefaultSetupKey()
account.SetupKeys[setupKey.Key] = setupKey account.SetupKeys[setupKey.Key] = setupKey
account.Peers["p"+str] = &Peer{ account.Peers["p"+str] = &nbpeer.Peer{
Key: "peerkey" + str, Key: "peerkey" + str,
SetupKey: "peerkeysetupkey", SetupKey: "peerkeysetupkey",
IP: net.IP{127, 0, 0, 1}, IP: net.IP{127, 0, 0, 1},
Meta: PeerSystemMeta{}, Meta: nbpeer.PeerSystemMeta{},
Name: "peer name", Name: "peer name",
Status: &PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
} }
return store.SaveAccount(account) return store.SaveAccount(account)

Some files were not shown because too many files have changed in this diff Show More