mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 07:46:38 +00:00
Compare commits
30 Commits
feature/ke
...
v0.21.9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a4d830ef83 | ||
|
|
9e540cd5b4 | ||
|
|
3027d8f27e | ||
|
|
e69ec6ab6a | ||
|
|
7ddde41c92 | ||
|
|
7ebe58f20a | ||
|
|
9c2c0e7934 | ||
|
|
c6af1037d9 | ||
|
|
5cb9a126f1 | ||
|
|
f40951cdf5 | ||
|
|
6e264d9de7 | ||
|
|
42db9773f4 | ||
|
|
bb9f6f6d0a | ||
|
|
829ce6573e | ||
|
|
a366d9e208 | ||
|
|
e074c24487 | ||
|
|
54fe05f6d8 | ||
|
|
33a155d9aa | ||
|
|
51878659f8 | ||
|
|
c000c05435 | ||
|
|
b39ffef22c | ||
|
|
d96f882acb | ||
|
|
d409219b51 | ||
|
|
8b619a8224 | ||
|
|
ed075bc9b9 | ||
|
|
8eb098d6fd | ||
|
|
68a8687c80 | ||
|
|
f7d97b02fd | ||
|
|
2691e729cd | ||
|
|
b524a9d49d |
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -116,7 +116,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-mingw-w64-x86-64
|
||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||
- name: Install rsrc
|
||||
run: go install github.com/akavel/rsrc@v0.10.2
|
||||
- name: Generate windows rsrc
|
||||
|
||||
@@ -53,6 +53,7 @@ jobs:
|
||||
CI_NETBIRD_MGMT_IDP: "none"
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
|
||||
|
||||
- name: check values
|
||||
working-directory: infrastructure_files
|
||||
|
||||
@@ -11,6 +11,8 @@ builds:
|
||||
- amd64
|
||||
ldflags:
|
||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||
tags:
|
||||
- legacy_appindicator
|
||||
mod_timestamp: '{{ .CommitTimestamp }}'
|
||||
|
||||
- id: netbird-ui-windows
|
||||
@@ -55,9 +57,6 @@ nfpms:
|
||||
- src: client/ui/disconnected.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- libayatana-appindicator3-1
|
||||
- libgtk-3-dev
|
||||
- libappindicator3-dev
|
||||
- netbird
|
||||
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
@@ -75,9 +74,6 @@ nfpms:
|
||||
- src: client/ui/disconnected.png
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- libayatana-appindicator3-1
|
||||
- libgtk-3-dev
|
||||
- libappindicator3-dev
|
||||
- netbird
|
||||
|
||||
uploads:
|
||||
|
||||
@@ -70,9 +70,9 @@ For stable versions, see [releases](https://github.com/netbirdio/netbird/release
|
||||
|
||||
### Start using NetBird
|
||||
- Hosted version: [https://app.netbird.io/](https://app.netbird.io/).
|
||||
- See our documentation for [Quickstart Guide](https://netbird.io/docs/getting-started/quickstart).
|
||||
- If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://netbird.io/docs/getting-started/self-hosting).
|
||||
- Step-by-step [Installation Guide](https://netbird.io/docs/getting-started/installation) for different platforms.
|
||||
- See our documentation for [Quickstart Guide](https://docs.netbird.io/how-to/getting-started).
|
||||
- If you are looking to self-host NetBird, check our [Self-Hosting Guide](https://docs.netbird.io/selfhosted/selfhosted-guide).
|
||||
- Step-by-step [Installation Guide](https://docs.netbird.io/how-to/getting-started#installation) for different platforms.
|
||||
- Web UI [repository](https://github.com/netbirdio/dashboard).
|
||||
- 5 min [demo video](https://youtu.be/Tu9tPsUWaY0) on YouTube.
|
||||
|
||||
@@ -91,7 +91,7 @@ For stable versions, see [releases](https://github.com/netbirdio/netbird/release
|
||||
<img src="https://netbird.io/docs/img/architecture/high-level-dia.png" width="700"/>
|
||||
</p>
|
||||
|
||||
See a complete [architecture overview](https://netbird.io/docs/overview/architecture) for details.
|
||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||
|
||||
### Roadmap
|
||||
- [Public Roadmap](https://github.com/netbirdio/netbird/projects/2)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
@@ -35,6 +36,11 @@ type RouteListener interface {
|
||||
routemanager.RouteListener
|
||||
}
|
||||
|
||||
// DnsReadyListener export internal dns ReadyListener for mobile
|
||||
type DnsReadyListener interface {
|
||||
dns.ReadyListener
|
||||
}
|
||||
|
||||
func init() {
|
||||
formatter.SetLogcatFormatter(log.StandardLogger())
|
||||
}
|
||||
@@ -49,6 +55,7 @@ type Client struct {
|
||||
ctxCancelLock *sync.Mutex
|
||||
deviceName string
|
||||
routeListener routemanager.RouteListener
|
||||
onHostDnsFn func([]string)
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
@@ -65,7 +72,7 @@ func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover
|
||||
}
|
||||
|
||||
// Run start the internal client. It is a blocker function
|
||||
func (c *Client) Run(urlOpener URLOpener) error {
|
||||
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
||||
cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
@@ -90,7 +97,8 @@ func (c *Client) Run(urlOpener URLOpener) error {
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
return internal.RunClient(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener)
|
||||
c.onHostDnsFn = func([]string) {}
|
||||
return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.routeListener, dns.items, dnsReadyListener)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
@@ -126,6 +134,17 @@ func (c *Client) PeersList() *PeerInfoArray {
|
||||
return &PeerInfoArray{items: peerInfos}
|
||||
}
|
||||
|
||||
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
||||
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
||||
dnsServer, err := dns.GetServerDns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnsServer.OnUpdatedHostDNSServer(list.items)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetConnectionListener set the network connection listener
|
||||
func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
||||
c.recorder.SetConnectionListener(listener)
|
||||
|
||||
26
client/android/dns_list.go
Normal file
26
client/android/dns_list.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package android
|
||||
|
||||
import "fmt"
|
||||
|
||||
// DNSList is a wrapper of []string
|
||||
type DNSList struct {
|
||||
items []string
|
||||
}
|
||||
|
||||
// Add new DNS address to the collection
|
||||
func (array *DNSList) Add(s string) {
|
||||
array.items = append(array.items, s)
|
||||
}
|
||||
|
||||
// Get return an element of the collection
|
||||
func (array *DNSList) Get(i int) (string, error) {
|
||||
if i >= len(array.items) || i < 0 {
|
||||
return "", fmt.Errorf("out of range")
|
||||
}
|
||||
return array.items[i], nil
|
||||
}
|
||||
|
||||
// Size return with the size of the collection
|
||||
func (array *DNSList) Size() int {
|
||||
return len(array.items)
|
||||
}
|
||||
24
client/android/dns_list_test.go
Normal file
24
client/android/dns_list_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package android
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestDNSList_Get(t *testing.T) {
|
||||
l := DNSList{
|
||||
items: make([]string, 1),
|
||||
}
|
||||
|
||||
_, err := l.Get(0)
|
||||
if err != nil {
|
||||
t.Errorf("invalid error: %s", err)
|
||||
}
|
||||
|
||||
_, err = l.Get(-1)
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got nil")
|
||||
}
|
||||
|
||||
_, err = l.Get(1)
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got nil")
|
||||
}
|
||||
}
|
||||
@@ -73,7 +73,8 @@ var sshCmd = &cobra.Command{
|
||||
go func() {
|
||||
// blocking
|
||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||
log.Print(err)
|
||||
log.Debug(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
@@ -92,12 +93,10 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
|
||||
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
|
||||
if err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
cmd.Printf("Couldn't connect. " +
|
||||
"You might be disconnected from the NetBird network, or the NetBird agent isn't running.\n" +
|
||||
"Run the status command: \n\n" +
|
||||
" netbird status\n\n" +
|
||||
"It might also be that the SSH server is disabled on the agent you are trying to connect to.\n")
|
||||
return nil
|
||||
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
||||
"You can verify the connection by running:\n\n" +
|
||||
" netbird status\n\n")
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
|
||||
@@ -104,7 +104,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
SetupCloseHandler(ctx, cancel)
|
||||
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()), nil, nil, nil)
|
||||
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
|
||||
}
|
||||
|
||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||
|
||||
@@ -51,6 +51,7 @@ type Manager interface {
|
||||
dPort *Port,
|
||||
direction RuleDirection,
|
||||
action Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (Rule, error)
|
||||
|
||||
@@ -60,5 +61,8 @@ type Manager interface {
|
||||
// Reset firewall to the default state
|
||||
Reset() error
|
||||
|
||||
// Flush the changes to firewall controller
|
||||
Flush() error
|
||||
|
||||
// TODO: migrate routemanager firewal actions to this interface
|
||||
}
|
||||
|
||||
@@ -58,13 +58,17 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("iptables is not installed in the system or not supported")
|
||||
}
|
||||
m.ipv4Client = ipv4Client
|
||||
if isIptablesClientAvailable(ipv4Client) {
|
||||
m.ipv4Client = ipv4Client
|
||||
}
|
||||
|
||||
ipv6Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if err != nil {
|
||||
log.Errorf("ip6tables is not installed in the system or not supported: %v", err)
|
||||
} else {
|
||||
m.ipv6Client = ipv6Client
|
||||
if isIptablesClientAvailable(ipv6Client) {
|
||||
m.ipv6Client = ipv6Client
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.Reset(); err != nil {
|
||||
@@ -73,6 +77,11 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||
_, err := client.ListChains("filter")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// AddFiltering rule to the firewall
|
||||
//
|
||||
// If comment is empty rule ID is used as comment
|
||||
@@ -83,6 +92,7 @@ func (m *Manager) AddFiltering(
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
@@ -193,6 +203,9 @@ func (m *Manager) Reset() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// reset firewall chain, clear it and drop it
|
||||
func (m *Manager) reset(client *iptables.IPTables, table string) error {
|
||||
ok, err := client.ChainExists(table, ChainInputFilterName)
|
||||
|
||||
@@ -68,7 +68,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
t.Run("add first rule", func(t *testing.T) {
|
||||
ip := net.ParseIP("10.20.0.2")
|
||||
port := &fw.Port{Values: []int{8080}}
|
||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
||||
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainOutputFilterName, true, rule1.(*Rule).specs...)
|
||||
@@ -81,7 +81,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
Values: []int{8043: 8046},
|
||||
}
|
||||
rule2, err = manager.AddFiltering(
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTPS traffic from ports range")
|
||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
checkRuleSpecs(t, ipv4Client, ChainInputFilterName, true, rule2.(*Rule).specs...)
|
||||
@@ -107,7 +107,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
// add second rule
|
||||
ip := net.ParseIP("10.20.0.3")
|
||||
port := &fw.Port{Values: []int{5353}}
|
||||
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept Fake DNS traffic")
|
||||
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Reset()
|
||||
@@ -167,9 +167,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
@@ -6,12 +6,14 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall"
|
||||
@@ -29,11 +31,14 @@ const (
|
||||
FilterOutputChainName = "netbird-acl-output-filter"
|
||||
)
|
||||
|
||||
var anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
|
||||
// Manager of iptables firewall
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
conn *nftables.Conn
|
||||
rConn *nftables.Conn
|
||||
sConn *nftables.Conn
|
||||
tableIPv4 *nftables.Table
|
||||
tableIPv6 *nftables.Table
|
||||
|
||||
@@ -43,6 +48,10 @@ type Manager struct {
|
||||
filterInputChainIPv6 *nftables.Chain
|
||||
filterOutputChainIPv6 *nftables.Chain
|
||||
|
||||
rulesetManager *rulesetManager
|
||||
setRemovedIPs map[string]struct{}
|
||||
setRemoved map[string]*nftables.Set
|
||||
|
||||
wgIface iFaceMapper
|
||||
}
|
||||
|
||||
@@ -54,8 +63,23 @@ type iFaceMapper interface {
|
||||
|
||||
// Create nftables firewall manager
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
// sConn is used for creating sets and adding/removing elements from them
|
||||
// it's differ then rConn (which does create new conn for each flush operation)
|
||||
// and is permanent. Using same connection for booth type of operations
|
||||
// overloads netlink with high amount of rules ( > 10000)
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := &Manager{
|
||||
conn: &nftables.Conn{},
|
||||
rConn: &nftables.Conn{},
|
||||
sConn: sConn,
|
||||
|
||||
rulesetManager: newRuleManager(),
|
||||
setRemovedIPs: map[string]struct{}{},
|
||||
setRemoved: map[string]*nftables.Set{},
|
||||
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
@@ -77,6 +101,7 @@ func (m *Manager) AddFiltering(
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
@@ -84,6 +109,7 @@ func (m *Manager) AddFiltering(
|
||||
|
||||
var (
|
||||
err error
|
||||
ipset *nftables.Set
|
||||
table *nftables.Table
|
||||
chain *nftables.Chain
|
||||
)
|
||||
@@ -107,6 +133,46 @@ func (m *Manager) AddFiltering(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawIP := ip.To4()
|
||||
if rawIP == nil {
|
||||
rawIP = ip.To16()
|
||||
}
|
||||
|
||||
rulesetID := m.getRulesetID(ip, proto, sPort, dPort, direction, action, ipsetName)
|
||||
|
||||
if ipsetName != "" {
|
||||
// if we already have set with given name, just add ip to the set
|
||||
// and return rule with new ID in other case let's create rule
|
||||
// with fresh created set and set element
|
||||
|
||||
var isSetNew bool
|
||||
ipset, err := m.rConn.GetSetByName(table, ipsetName)
|
||||
if err != nil {
|
||||
if ipset, err = m.createSet(table, rawIP, ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("get set name: %v", err)
|
||||
}
|
||||
isSetNew = true
|
||||
}
|
||||
|
||||
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
|
||||
return nil, fmt.Errorf("add set element for the first time: %v", err)
|
||||
}
|
||||
if err := m.sConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush add elements: %v", err)
|
||||
}
|
||||
|
||||
if !isSetNew {
|
||||
// if we already have nftables rules with set for given direction
|
||||
// just add new rule to the ruleset and return new fw.Rule object
|
||||
|
||||
if ruleset, ok := m.rulesetManager.getRuleset(rulesetID); ok {
|
||||
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||
}
|
||||
// if ipset exists but it is not linked to rule for given direction
|
||||
// create new rule for direction and bind ipset to it later
|
||||
}
|
||||
}
|
||||
|
||||
ifaceKey := expr.MetaKeyIIFNAME
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
ifaceKey = expr.MetaKeyOIFNAME
|
||||
@@ -146,39 +212,47 @@ func (m *Manager) AddFiltering(
|
||||
})
|
||||
}
|
||||
|
||||
// don't use IP matching if IP is ip 0.0.0.0
|
||||
if s := ip.String(); s != "0.0.0.0" && s != "::" {
|
||||
// check if rawIP contains zeroed IPv4 0.0.0.0 or same IPv6 value
|
||||
// in that case not add IP match expression into the rule definition
|
||||
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||
// source address position
|
||||
var adrLen, adrOffset uint32
|
||||
if ip.To4() == nil {
|
||||
adrLen = 16
|
||||
adrOffset = 8
|
||||
} else {
|
||||
adrLen = 4
|
||||
adrOffset = 12
|
||||
addrLen := uint32(len(rawIP))
|
||||
addrOffset := uint32(12)
|
||||
if addrLen == 16 {
|
||||
addrOffset = 8
|
||||
}
|
||||
|
||||
// change to destination address position if need
|
||||
if direction == fw.RuleDirectionOUT {
|
||||
adrOffset += adrLen
|
||||
addrOffset += addrLen
|
||||
}
|
||||
|
||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||
add := ipToAdd.Unmap()
|
||||
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: adrOffset,
|
||||
Len: adrLen,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: add.AsSlice(),
|
||||
Offset: addrOffset,
|
||||
Len: addrLen,
|
||||
},
|
||||
)
|
||||
// add individual IP for match if no ipset defined
|
||||
if ipset == nil {
|
||||
expressions = append(expressions,
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: rawIP,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
expressions = append(expressions,
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ipsetName,
|
||||
SetID: ipset.ID,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if sPort != nil && len(sPort.Values) != 0 {
|
||||
@@ -219,39 +293,76 @@ func (m *Manager) AddFiltering(
|
||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||
}
|
||||
|
||||
id := uuid.New().String()
|
||||
userData := []byte(strings.Join([]string{id, comment}, " "))
|
||||
userData := []byte(strings.Join([]string{rulesetID, comment}, " "))
|
||||
|
||||
_ = m.conn.InsertRule(&nftables.Rule{
|
||||
rule := m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Position: 0,
|
||||
Exprs: expressions,
|
||||
UserData: userData,
|
||||
})
|
||||
|
||||
if err := m.conn.Flush(); err != nil {
|
||||
return nil, err
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush insert rule: %v", err)
|
||||
}
|
||||
|
||||
list, err := m.conn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
ruleset := m.rulesetManager.createRuleset(rulesetID, rule, ipset)
|
||||
return m.rulesetManager.addRule(ruleset, rawIP)
|
||||
}
|
||||
|
||||
// getRulesetID returns ruleset ID based on given parameters
|
||||
func (m *Manager) getRulesetID(
|
||||
ip net.IP,
|
||||
proto fw.Protocol,
|
||||
sPort *fw.Port,
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
ipsetName string,
|
||||
) string {
|
||||
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
|
||||
if sPort != nil {
|
||||
rulesetID += sPort.String()
|
||||
}
|
||||
rulesetID += ":"
|
||||
if dPort != nil {
|
||||
rulesetID += dPort.String()
|
||||
}
|
||||
rulesetID += ":"
|
||||
rulesetID += strconv.Itoa(int(action))
|
||||
if ipsetName == "" {
|
||||
return "ip:" + ip.String() + rulesetID
|
||||
}
|
||||
return "set:" + ipsetName + rulesetID
|
||||
}
|
||||
|
||||
// createSet in given table by name
|
||||
func (m *Manager) createSet(
|
||||
table *nftables.Table,
|
||||
rawIP []byte,
|
||||
name string,
|
||||
) (*nftables.Set, error) {
|
||||
keyType := nftables.TypeIPAddr
|
||||
if len(rawIP) == 16 {
|
||||
keyType = nftables.TypeIP6Addr
|
||||
}
|
||||
// else we create new ipset and continue creating rule
|
||||
ipset := &nftables.Set{
|
||||
Name: name,
|
||||
Table: table,
|
||||
Dynamic: true,
|
||||
KeyType: keyType,
|
||||
}
|
||||
|
||||
// Add the rule to the chain
|
||||
rule := &Rule{id: id}
|
||||
for _, r := range list {
|
||||
if bytes.Equal(r.UserData, userData) {
|
||||
rule.Rule = r
|
||||
break
|
||||
}
|
||||
}
|
||||
if rule.Rule == nil {
|
||||
return nil, fmt.Errorf("rule not found")
|
||||
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
||||
return nil, fmt.Errorf("create set: %v", err)
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush created set: %v", err)
|
||||
}
|
||||
|
||||
return ipset, nil
|
||||
}
|
||||
|
||||
// chain returns the chain for the given IP address with specific settings
|
||||
@@ -315,7 +426,7 @@ func (m *Manager) table(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
}
|
||||
|
||||
func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables.Table, error) {
|
||||
tables, err := m.conn.ListTablesOfFamily(family)
|
||||
tables, err := m.rConn.ListTablesOfFamily(family)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
@@ -326,7 +437,11 @@ func (m *Manager) createTableIfNotExists(family nftables.TableFamily) (*nftables
|
||||
}
|
||||
}
|
||||
|
||||
return m.conn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4}), nil
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: FilterTableName, Family: nftables.TableFamilyIPv4})
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return table, nil
|
||||
}
|
||||
|
||||
func (m *Manager) createChainIfNotExists(
|
||||
@@ -341,7 +456,7 @@ func (m *Manager) createChainIfNotExists(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chains, err := m.conn.ListChainsOfTableFamily(family)
|
||||
chains, err := m.rConn.ListChainsOfTableFamily(family)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
@@ -362,7 +477,7 @@ func (m *Manager) createChainIfNotExists(
|
||||
Policy: &polAccept,
|
||||
}
|
||||
|
||||
chain = m.conn.AddChain(chain)
|
||||
chain = m.rConn.AddChain(chain)
|
||||
|
||||
ifaceKey := expr.MetaKeyIIFNAME
|
||||
shiftDSTAddr := 0
|
||||
@@ -429,7 +544,7 @@ func (m *Manager) createChainIfNotExists(
|
||||
)
|
||||
}
|
||||
|
||||
_ = m.conn.AddRule(&nftables.Rule{
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
@@ -444,12 +559,13 @@ func (m *Manager) createChainIfNotExists(
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||
}
|
||||
_ = m.conn.AddRule(&nftables.Rule{
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
if err := m.conn.Flush(); err != nil {
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -458,16 +574,58 @@ func (m *Manager) createChainIfNotExists(
|
||||
|
||||
// DeleteRule from the firewall by rule definition
|
||||
func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
nativeRule, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
if err := m.conn.DelRule(nativeRule.Rule); err != nil {
|
||||
return err
|
||||
if nativeRule.nftRule == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return m.conn.Flush()
|
||||
if nativeRule.nftSet != nil {
|
||||
// call twice of delete set element raises error
|
||||
// so we need to check if element is already removed
|
||||
key := fmt.Sprintf("%s:%v", nativeRule.nftSet.Name, nativeRule.ip)
|
||||
if _, ok := m.setRemovedIPs[key]; !ok {
|
||||
err := m.sConn.SetDeleteElements(nativeRule.nftSet, []nftables.SetElement{{Key: nativeRule.ip}})
|
||||
if err != nil {
|
||||
log.Errorf("delete elements for set %q: %v", nativeRule.nftSet.Name, err)
|
||||
}
|
||||
if err := m.sConn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
m.setRemovedIPs[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if m.rulesetManager.deleteRule(nativeRule) {
|
||||
// deleteRule indicates that we still have IP in the ruleset
|
||||
// it means we should not remove the nftables rule but need to update set
|
||||
// so we prepare IP to be removed from set on the next flush call
|
||||
return nil
|
||||
}
|
||||
|
||||
// ruleset doesn't contain IP anymore (or contains only one), remove nft rule
|
||||
if err := m.rConn.DelRule(nativeRule.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
nativeRule.nftRule = nil
|
||||
|
||||
if nativeRule.nftSet != nil {
|
||||
if _, ok := m.setRemoved[nativeRule.nftSet.Name]; !ok {
|
||||
m.setRemoved[nativeRule.nftSet.Name] = nativeRule.nftSet
|
||||
}
|
||||
nativeRule.nftSet = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset firewall to the default state
|
||||
@@ -475,27 +633,116 @@ func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
chains, err := m.conn.ListChains()
|
||||
chains, err := m.rConn.ListChains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
for _, c := range chains {
|
||||
if c.Name == FilterInputChainName || c.Name == FilterOutputChainName {
|
||||
m.conn.DelChain(c)
|
||||
m.rConn.DelChain(c)
|
||||
}
|
||||
}
|
||||
|
||||
tables, err := m.conn.ListTables()
|
||||
tables, err := m.rConn.ListTables()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
for _, t := range tables {
|
||||
if t.Name == FilterTableName {
|
||||
m.conn.DelTable(t)
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
return m.conn.Flush()
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
//
|
||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||
func (m *Manager) Flush() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.flushWithBackoff(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set must be removed after flush rule changes
|
||||
// otherwise we will get error
|
||||
for _, s := range m.setRemoved {
|
||||
m.rConn.FlushSet(s)
|
||||
m.rConn.DelSet(s)
|
||||
}
|
||||
|
||||
if len(m.setRemoved) > 0 {
|
||||
if err := m.flushWithBackoff(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
m.setRemovedIPs = map[string]struct{}{}
|
||||
m.setRemoved = map[string]*nftables.Set{}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv4, m.filterInputChainIPv4); err != nil {
|
||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv4, m.filterOutputChainIPv4); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv6, m.filterInputChainIPv6); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv6 input chain: %v", err)
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.tableIPv6, m.filterOutputChainIPv6); err != nil {
|
||||
log.Errorf("failed to refresh rule handles IPv6 output chain: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) flushWithBackoff() (err error) {
|
||||
backoff := 4
|
||||
backoffTime := 1000 * time.Millisecond
|
||||
for i := 0; ; i++ {
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), "busy") {
|
||||
return
|
||||
}
|
||||
log.Error("failed to flush nftables, retrying...")
|
||||
if i == backoff-1 {
|
||||
return err
|
||||
}
|
||||
time.Sleep(backoffTime)
|
||||
backoffTime = backoffTime * 2
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Manager) refreshRuleHandles(table *nftables.Table, chain *nftables.Chain) error {
|
||||
if table == nil || chain == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
list, err := m.rConn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, rule := range list {
|
||||
if len(rule.UserData) != 0 {
|
||||
if err := m.rulesetManager.setNftRuleHandle(rule); err != nil {
|
||||
log.Errorf("failed to set rule handle: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodePort(port fw.Port) []byte {
|
||||
|
||||
@@ -55,7 +55,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
err = manager.Reset()
|
||||
@@ -75,11 +75,16 @@ func TestNftablesManager(t *testing.T) {
|
||||
fw.RuleDirectionIN,
|
||||
fw.ActionDrop,
|
||||
"",
|
||||
"",
|
||||
)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err := testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
|
||||
// test expectations:
|
||||
// 1) regular rule
|
||||
// 2) "accept extra routed traffic rule" for the interface
|
||||
@@ -135,6 +140,9 @@ func TestNftablesManager(t *testing.T) {
|
||||
err = manager.DeleteRule(rule)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err = testClient.GetRules(manager.tableIPv4, manager.filterInputChainIPv4)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
// test expectations:
|
||||
@@ -167,7 +175,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
defer func() {
|
||||
if err := manager.Reset(); err != nil {
|
||||
@@ -181,13 +189,18 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if i%100 == 0 {
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("execution avg per rule: %s", time.Since(start)/time.Duration(testMax))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,11 +6,14 @@ import (
|
||||
|
||||
// Rule to handle management of rules
|
||||
type Rule struct {
|
||||
*nftables.Rule
|
||||
id string
|
||||
nftRule *nftables.Rule
|
||||
nftSet *nftables.Set
|
||||
|
||||
ruleID string
|
||||
ip []byte
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) GetRuleID() string {
|
||||
return r.id
|
||||
return r.ruleID
|
||||
}
|
||||
|
||||
115
client/firewall/nftables/ruleset_linux.go
Normal file
115
client/firewall/nftables/ruleset_linux.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
// nftRuleset links native firewall rule and ipset to ACL generated rules
|
||||
type nftRuleset struct {
|
||||
nftRule *nftables.Rule
|
||||
nftSet *nftables.Set
|
||||
issuedRules map[string]*Rule
|
||||
rulesetID string
|
||||
}
|
||||
|
||||
type rulesetManager struct {
|
||||
rulesets map[string]*nftRuleset
|
||||
|
||||
nftSetName2rulesetID map[string]string
|
||||
issuedRuleID2rulesetID map[string]string
|
||||
}
|
||||
|
||||
func newRuleManager() *rulesetManager {
|
||||
return &rulesetManager{
|
||||
rulesets: map[string]*nftRuleset{},
|
||||
|
||||
nftSetName2rulesetID: map[string]string{},
|
||||
issuedRuleID2rulesetID: map[string]string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rulesetManager) getRuleset(rulesetID string) (*nftRuleset, bool) {
|
||||
ruleset, ok := r.rulesets[rulesetID]
|
||||
return ruleset, ok
|
||||
}
|
||||
|
||||
func (r *rulesetManager) createRuleset(
|
||||
rulesetID string,
|
||||
nftRule *nftables.Rule,
|
||||
nftSet *nftables.Set,
|
||||
) *nftRuleset {
|
||||
ruleset := nftRuleset{
|
||||
rulesetID: rulesetID,
|
||||
nftRule: nftRule,
|
||||
nftSet: nftSet,
|
||||
issuedRules: map[string]*Rule{},
|
||||
}
|
||||
r.rulesets[ruleset.rulesetID] = &ruleset
|
||||
if nftSet != nil {
|
||||
r.nftSetName2rulesetID[nftSet.Name] = ruleset.rulesetID
|
||||
}
|
||||
return &ruleset
|
||||
}
|
||||
|
||||
func (r *rulesetManager) addRule(
|
||||
ruleset *nftRuleset,
|
||||
ip []byte,
|
||||
) (*Rule, error) {
|
||||
if _, ok := r.rulesets[ruleset.rulesetID]; !ok {
|
||||
return nil, fmt.Errorf("ruleset not found")
|
||||
}
|
||||
|
||||
rule := Rule{
|
||||
nftRule: ruleset.nftRule,
|
||||
nftSet: ruleset.nftSet,
|
||||
ruleID: xid.New().String(),
|
||||
ip: ip,
|
||||
}
|
||||
|
||||
ruleset.issuedRules[rule.ruleID] = &rule
|
||||
r.issuedRuleID2rulesetID[rule.ruleID] = ruleset.rulesetID
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
// deleteRule from ruleset and returns true if contains other rules
|
||||
func (r *rulesetManager) deleteRule(rule *Rule) bool {
|
||||
rulesetID, ok := r.issuedRuleID2rulesetID[rule.ruleID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
ruleset := r.rulesets[rulesetID]
|
||||
if ruleset.nftRule == nil {
|
||||
return false
|
||||
}
|
||||
delete(r.issuedRuleID2rulesetID, rule.ruleID)
|
||||
delete(ruleset.issuedRules, rule.ruleID)
|
||||
|
||||
if len(ruleset.issuedRules) == 0 {
|
||||
delete(r.rulesets, ruleset.rulesetID)
|
||||
if rule.nftSet != nil {
|
||||
delete(r.nftSetName2rulesetID, rule.nftSet.Name)
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// setNftRuleHandle finds rule by userdata which contains rulesetID and updates it's handle number
|
||||
//
|
||||
// This is important to do, because after we add rule to the nftables we can't update it until
|
||||
// we set correct handle value to it.
|
||||
func (r *rulesetManager) setNftRuleHandle(nftRule *nftables.Rule) error {
|
||||
split := bytes.Split(nftRule.UserData, []byte(" "))
|
||||
ruleset, ok := r.rulesets[string(split[0])]
|
||||
if !ok {
|
||||
return fmt.Errorf("ruleset not found")
|
||||
}
|
||||
*ruleset.nftRule = *nftRule
|
||||
return nil
|
||||
}
|
||||
122
client/firewall/nftables/ruleset_linux_test.go
Normal file
122
client/firewall/nftables/ruleset_linux_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRulesetManager_createRuleset(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{
|
||||
UserData: []byte(rulesetID),
|
||||
}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
require.NotNil(t, ruleset, "createRuleset() failed")
|
||||
require.Equal(t, ruleset.rulesetID, rulesetID, "rulesetID is incorrect")
|
||||
require.Equal(t, ruleset.nftRule, &nftRule, "nftRule is incorrect")
|
||||
}
|
||||
|
||||
func TestRulesetManager_addRule(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
|
||||
// Add a rule to the ruleset.
|
||||
ip := []byte("192.168.1.1")
|
||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule, "rule should not be nil")
|
||||
require.NotEqual(t, rule.ruleID, "ruleID is empty")
|
||||
require.EqualValues(t, rule.ip, ip, "ip is incorrect")
|
||||
require.Contains(t, ruleset.issuedRules, rule.ruleID, "ruleID already exists in ruleset")
|
||||
require.Contains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "ruleID already exists in ruleset manager")
|
||||
|
||||
ruleset2 := &nftRuleset{
|
||||
rulesetID: "ruleset-2",
|
||||
}
|
||||
_, err = rulesetManager.addRule(ruleset2, ip)
|
||||
require.Error(t, err, "addRule() should have failed")
|
||||
}
|
||||
|
||||
func TestRulesetManager_deleteRule(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
|
||||
// Add a rule to the ruleset.
|
||||
ip := []byte("192.168.1.1")
|
||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule, "rule should not be nil")
|
||||
|
||||
ip2 := []byte("192.168.1.1")
|
||||
rule2, err := rulesetManager.addRule(ruleset, ip2)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule2, "rule should not be nil")
|
||||
|
||||
hasNext := rulesetManager.deleteRule(rule)
|
||||
require.True(t, hasNext, "deleteRule() should have returned true")
|
||||
|
||||
// Check that the rule is no longer in the manager.
|
||||
require.NotContains(t, rulesetManager.issuedRuleID2rulesetID, rule.ruleID, "rule should have been deleted")
|
||||
|
||||
hasNext = rulesetManager.deleteRule(rule2)
|
||||
require.False(t, hasNext, "deleteRule() should have returned false")
|
||||
}
|
||||
|
||||
func TestRulesetManager_setNftRuleHandle(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, nil)
|
||||
// Add a rule to the ruleset.
|
||||
ip := []byte("192.168.0.1")
|
||||
|
||||
rule, err := rulesetManager.addRule(ruleset, ip)
|
||||
require.NoError(t, err, "addRule() failed")
|
||||
require.NotNil(t, rule, "rule should not be nil")
|
||||
|
||||
nftRuleCopy := nftRule
|
||||
nftRuleCopy.Handle = 2
|
||||
nftRuleCopy.UserData = []byte(rulesetID)
|
||||
err = rulesetManager.setNftRuleHandle(&nftRuleCopy)
|
||||
require.NoError(t, err, "setNftRuleHandle() failed")
|
||||
// check correct work with references
|
||||
require.Equal(t, nftRule.Handle, uint64(2), "nftRule.Handle is incorrect")
|
||||
}
|
||||
|
||||
func TestRulesetManager_getRuleset(t *testing.T) {
|
||||
// Create a ruleset manager.
|
||||
rulesetManager := newRuleManager()
|
||||
// Create a ruleset.
|
||||
rulesetID := "ruleset-1"
|
||||
nftRule := nftables.Rule{}
|
||||
nftSet := nftables.Set{
|
||||
ID: 2,
|
||||
}
|
||||
ruleset := rulesetManager.createRuleset(rulesetID, &nftRule, &nftSet)
|
||||
require.NotNil(t, ruleset, "createRuleset() failed")
|
||||
|
||||
find, ok := rulesetManager.getRuleset(rulesetID)
|
||||
require.True(t, ok, "getRuleset() failed")
|
||||
require.Equal(t, ruleset, find, "getRulesetBySetID() failed")
|
||||
|
||||
_, ok = rulesetManager.getRuleset("does-not-exist")
|
||||
require.False(t, ok, "getRuleset() failed")
|
||||
}
|
||||
@@ -21,11 +21,13 @@ type IFaceMapper interface {
|
||||
SetFilter(iface.PacketFilter) error
|
||||
}
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]Rule
|
||||
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
outgoingRules []Rule
|
||||
incomingRules []Rule
|
||||
rulesIndex map[string]int
|
||||
outgoingRules map[string]RuleSet
|
||||
incomingRules map[string]RuleSet
|
||||
wgNetwork *net.IPNet
|
||||
decoders sync.Pool
|
||||
|
||||
@@ -48,7 +50,6 @@ type decoder struct {
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface IFaceMapper) (*Manager, error) {
|
||||
m := &Manager{
|
||||
rulesIndex: make(map[string]int),
|
||||
decoders: sync.Pool{
|
||||
New: func() any {
|
||||
d := &decoder{
|
||||
@@ -62,6 +63,8 @@ func Create(iface IFaceMapper) (*Manager, error) {
|
||||
return d
|
||||
},
|
||||
},
|
||||
outgoingRules: make(map[string]RuleSet),
|
||||
incomingRules: make(map[string]RuleSet),
|
||||
}
|
||||
|
||||
if err := iface.SetFilter(m); err != nil {
|
||||
@@ -81,6 +84,7 @@ func (m *Manager) AddFiltering(
|
||||
dPort *fw.Port,
|
||||
direction fw.RuleDirection,
|
||||
action fw.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) (fw.Rule, error) {
|
||||
r := Rule{
|
||||
@@ -124,15 +128,17 @@ func (m *Manager) AddFiltering(
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
var p int
|
||||
if direction == fw.RuleDirectionIN {
|
||||
m.incomingRules = append(m.incomingRules, r)
|
||||
p = len(m.incomingRules) - 1
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
} else {
|
||||
m.outgoingRules = append(m.outgoingRules, r)
|
||||
p = len(m.outgoingRules) - 1
|
||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||
m.outgoingRules[r.ip.String()] = make(RuleSet)
|
||||
}
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
}
|
||||
m.rulesIndex[r.id] = p
|
||||
m.mutex.Unlock()
|
||||
|
||||
return &r, nil
|
||||
@@ -148,24 +154,20 @@ func (m *Manager) DeleteRule(rule fw.Rule) error {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
|
||||
p, ok := m.rulesIndex[r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.rulesIndex, r.id)
|
||||
|
||||
var toUpdate []Rule
|
||||
if r.direction == fw.RuleDirectionIN {
|
||||
m.incomingRules = append(m.incomingRules[:p], m.incomingRules[p+1:]...)
|
||||
toUpdate = m.incomingRules
|
||||
_, ok := m.incomingRules[r.ip.String()][r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.incomingRules[r.ip.String()], r.id)
|
||||
} else {
|
||||
m.outgoingRules = append(m.outgoingRules[:p], m.outgoingRules[p+1:]...)
|
||||
toUpdate = m.outgoingRules
|
||||
_, ok := m.outgoingRules[r.ip.String()][r.id]
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(m.outgoingRules[r.ip.String()], r.id)
|
||||
}
|
||||
|
||||
for i := 0; i < len(toUpdate); i++ {
|
||||
m.rulesIndex[toUpdate[i].id] = i
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -174,13 +176,15 @@ func (m *Manager) Reset() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
m.outgoingRules = m.outgoingRules[:0]
|
||||
m.incomingRules = m.incomingRules[:0]
|
||||
m.rulesIndex = make(map[string]int)
|
||||
m.outgoingRules = make(map[string]RuleSet)
|
||||
m.incomingRules = make(map[string]RuleSet)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush doesn't need to be implemented for this manager
|
||||
func (m *Manager) Flush() error { return nil }
|
||||
|
||||
// DropOutgoing filter outgoing packets
|
||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
||||
return m.dropFilter(packetData, m.outgoingRules, false)
|
||||
@@ -192,7 +196,7 @@ func (m *Manager) DropIncoming(packetData []byte) bool {
|
||||
}
|
||||
|
||||
// dropFilter imlements same logic for booth direction of the traffic
|
||||
func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket bool) bool {
|
||||
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
@@ -224,37 +228,49 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
|
||||
log.Errorf("unknown layer: %v", d.decoded[0])
|
||||
return true
|
||||
}
|
||||
payloadLayer := d.decoded[1]
|
||||
|
||||
// check if IP address match by IP
|
||||
var ip net.IP
|
||||
switch ipLayer {
|
||||
case layers.LayerTypeIPv4:
|
||||
if isIncomingPacket {
|
||||
ip = d.ip4.SrcIP
|
||||
} else {
|
||||
ip = d.ip4.DstIP
|
||||
}
|
||||
case layers.LayerTypeIPv6:
|
||||
if isIncomingPacket {
|
||||
ip = d.ip6.SrcIP
|
||||
} else {
|
||||
ip = d.ip6.DstIP
|
||||
}
|
||||
}
|
||||
|
||||
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
|
||||
if ok {
|
||||
return filter
|
||||
}
|
||||
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
|
||||
if ok {
|
||||
return filter
|
||||
}
|
||||
filter, ok = validateRule(ip, packetData, rules["::"], d)
|
||||
if ok {
|
||||
return filter
|
||||
}
|
||||
|
||||
// default policy is DROP ALL
|
||||
return true
|
||||
}
|
||||
|
||||
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
|
||||
payloadLayer := d.decoded[1]
|
||||
for _, rule := range rules {
|
||||
if rule.matchByIP {
|
||||
switch ipLayer {
|
||||
case layers.LayerTypeIPv4:
|
||||
if isIncomingPacket {
|
||||
if !d.ip4.SrcIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if !d.ip4.DstIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
case layers.LayerTypeIPv6:
|
||||
if isIncomingPacket {
|
||||
if !d.ip6.SrcIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if !d.ip6.DstIP.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
||||
continue
|
||||
}
|
||||
|
||||
if rule.protoLayer == layerTypeAll {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
|
||||
if payloadLayer != rule.protoLayer {
|
||||
@@ -264,38 +280,36 @@ func (m *Manager) dropFilter(packetData []byte, rules []Rule, isIncomingPacket b
|
||||
switch payloadLayer {
|
||||
case layers.LayerTypeTCP:
|
||||
if rule.sPort == 0 && rule.dPort == 0 {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
// if rule has UDP hook (and if we are here we match this rule)
|
||||
// we ignore rule.drop and call this hook
|
||||
if rule.udpHook != nil {
|
||||
return rule.udpHook(packetData)
|
||||
return rule.udpHook(packetData), true
|
||||
}
|
||||
|
||||
if rule.sPort == 0 && rule.dPort == 0 {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return rule.drop
|
||||
return rule.drop, true
|
||||
}
|
||||
}
|
||||
|
||||
// default policy is DROP ALL
|
||||
return true
|
||||
return false, false
|
||||
}
|
||||
|
||||
// SetNetwork of the wireguard interface to which filtering applied
|
||||
@@ -325,19 +339,19 @@ func (m *Manager) AddUDPPacketHook(
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
var toUpdate []Rule
|
||||
if in {
|
||||
r.direction = fw.RuleDirectionIN
|
||||
m.incomingRules = append([]Rule{r}, m.incomingRules...)
|
||||
toUpdate = m.incomingRules
|
||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
||||
m.incomingRules[r.ip.String()] = make(map[string]Rule)
|
||||
}
|
||||
m.incomingRules[r.ip.String()][r.id] = r
|
||||
} else {
|
||||
m.outgoingRules = append([]Rule{r}, m.outgoingRules...)
|
||||
toUpdate = m.outgoingRules
|
||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
||||
m.outgoingRules[r.ip.String()] = make(map[string]Rule)
|
||||
}
|
||||
m.outgoingRules[r.ip.String()][r.id] = r
|
||||
}
|
||||
|
||||
for i := range toUpdate {
|
||||
m.rulesIndex[toUpdate[i].id] = i
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
return r.id
|
||||
@@ -345,14 +359,18 @@ func (m *Manager) AddUDPPacketHook(
|
||||
|
||||
// RemovePacketHook removes packet hook by given ID
|
||||
func (m *Manager) RemovePacketHook(hookID string) error {
|
||||
for _, r := range m.incomingRules {
|
||||
if r.id == hookID {
|
||||
return m.DeleteRule(&r)
|
||||
for _, arr := range m.incomingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
return m.DeleteRule(&r)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, r := range m.outgoingRules {
|
||||
if r.id == hookID {
|
||||
return m.DeleteRule(&r)
|
||||
for _, arr := range m.outgoingRules {
|
||||
for _, r := range arr {
|
||||
if r.id == hookID {
|
||||
return m.DeleteRule(&r)
|
||||
}
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("hook with given id not found")
|
||||
|
||||
@@ -63,7 +63,7 @@ func TestManagerAddFiltering(t *testing.T) {
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -98,7 +98,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
||||
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -111,7 +111,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
action = fw.ActionDrop
|
||||
comment = "Test rule 2"
|
||||
|
||||
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
||||
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -123,8 +123,8 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if idx, ok := m.rulesIndex[rule2.GetRuleID()]; !ok || len(m.incomingRules) != 1 || idx != 0 {
|
||||
t.Errorf("rule2 is not in the rulesIndex")
|
||||
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; !ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
|
||||
err = m.DeleteRule(rule2)
|
||||
@@ -133,8 +133,8 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(m.rulesIndex) != 0 || len(m.incomingRules) != 0 {
|
||||
t.Errorf("rule1 still in the rulesIndex")
|
||||
if _, ok := m.incomingRules[ip.String()][rule2.GetRuleID()]; ok {
|
||||
t.Errorf("rule2 is not in the incomingRules")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,26 +169,29 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := &Manager{
|
||||
incomingRules: []Rule{},
|
||||
outgoingRules: []Rule{},
|
||||
rulesIndex: make(map[string]int),
|
||||
incomingRules: map[string]RuleSet{},
|
||||
outgoingRules: map[string]RuleSet{},
|
||||
}
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
|
||||
var addedRule Rule
|
||||
if tt.in {
|
||||
if len(manager.incomingRules) != 1 {
|
||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||
return
|
||||
}
|
||||
addedRule = manager.incomingRules[0]
|
||||
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
||||
addedRule = rule
|
||||
}
|
||||
} else {
|
||||
if len(manager.outgoingRules) != 1 {
|
||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||
return
|
||||
}
|
||||
addedRule = manager.outgoingRules[0]
|
||||
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
||||
addedRule = rule
|
||||
}
|
||||
}
|
||||
|
||||
if !tt.ip.Equal(addedRule.ip) {
|
||||
@@ -211,17 +214,6 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Errorf("expected udpHook to be set")
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure rulesIndex is correctly updated
|
||||
index, ok := manager.rulesIndex[addedRule.id]
|
||||
if !ok {
|
||||
t.Errorf("expected rule to be in rulesIndex")
|
||||
return
|
||||
}
|
||||
if index != 0 {
|
||||
t.Errorf("expected rule index to be 0, got %d", index)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -244,7 +236,7 @@ func TestManagerReset(t *testing.T) {
|
||||
action := fw.ActionDrop
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, comment)
|
||||
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -256,7 +248,7 @@ func TestManagerReset(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(m.rulesIndex) != 0 || len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
|
||||
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 {
|
||||
t.Errorf("rules is not empty")
|
||||
}
|
||||
}
|
||||
@@ -282,7 +274,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
action := fw.ActionAccept
|
||||
comment := "Test rule"
|
||||
|
||||
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, comment)
|
||||
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -346,10 +338,12 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
|
||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||
found := false
|
||||
for _, rule := range manager.outgoingRules {
|
||||
if rule.id == hookID {
|
||||
found = true
|
||||
break
|
||||
for _, arr := range manager.outgoingRules {
|
||||
for _, rule := range arr {
|
||||
if rule.id == hookID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -364,9 +358,11 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
}
|
||||
|
||||
// Assert the hook is removed by checking it in the manager's outgoing rules
|
||||
for _, rule := range manager.outgoingRules {
|
||||
if rule.id == hookID {
|
||||
t.Fatalf("The hook was not removed properly.")
|
||||
for _, arr := range manager.outgoingRules {
|
||||
for _, rule := range arr {
|
||||
if rule.id == hookID {
|
||||
t.Fatalf("The hook was not removed properly.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -394,9 +390,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []int{1000 + i}}
|
||||
if i%2 == 0 {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
} else {
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "accept HTTP traffic")
|
||||
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
||||
}
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
@@ -33,9 +33,22 @@ type Manager interface {
|
||||
|
||||
// DefaultManager uses firewall manager to handle
|
||||
type DefaultManager struct {
|
||||
manager firewall.Manager
|
||||
rulesPairs map[string][]firewall.Rule
|
||||
mutex sync.Mutex
|
||||
manager firewall.Manager
|
||||
ipsetCounter int
|
||||
rulesPairs map[string][]firewall.Rule
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
type ipsetInfo struct {
|
||||
name string
|
||||
ipCount int
|
||||
}
|
||||
|
||||
func newDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||
return &DefaultManager{
|
||||
manager: fm,
|
||||
rulesPairs: make(map[string][]firewall.Rule),
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
||||
@@ -61,6 +74,12 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := d.manager.Flush(); err != nil {
|
||||
log.Error("failed to flush firewall rules: ", err)
|
||||
}
|
||||
}()
|
||||
|
||||
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
||||
|
||||
enableSSH := (networkMap.PeerConfig != nil &&
|
||||
@@ -108,8 +127,32 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
||||
|
||||
applyFailed := false
|
||||
newRulePairs := make(map[string][]firewall.Rule)
|
||||
ipsetByRuleSelectors := make(map[string]*ipsetInfo)
|
||||
|
||||
// calculate which IP's can be grouped in by which ipset
|
||||
// to do that we use rule selector (which is just rule properties without IP's)
|
||||
for _, r := range rules {
|
||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r)
|
||||
selector := d.getRuleGroupingSelector(r)
|
||||
ipset, ok := ipsetByRuleSelectors[selector]
|
||||
if !ok {
|
||||
ipset = &ipsetInfo{}
|
||||
}
|
||||
|
||||
ipset.ipCount++
|
||||
ipsetByRuleSelectors[selector] = ipset
|
||||
}
|
||||
|
||||
for _, r := range rules {
|
||||
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||
ipset := ipsetByRuleSelectors[d.getRuleGroupingSelector(r)]
|
||||
ipsetName := ""
|
||||
if ipset.name == "" {
|
||||
d.ipsetCounter++
|
||||
ipset.name = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
||||
}
|
||||
ipsetName = ipset.name
|
||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||
if err != nil {
|
||||
log.Errorf("failed to apply firewall rule: %+v, %v", r, err)
|
||||
applyFailed = true
|
||||
@@ -154,7 +197,10 @@ func (d *DefaultManager) Stop() {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (string, []firewall.Rule, error) {
|
||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
r *mgmProto.FirewallRule,
|
||||
ipsetName string,
|
||||
) (string, []firewall.Rule, error) {
|
||||
ip := net.ParseIP(r.PeerIP)
|
||||
if ip == nil {
|
||||
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
@@ -190,9 +236,9 @@ func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (stri
|
||||
var err error
|
||||
switch r.Direction {
|
||||
case mgmProto.FirewallRule_IN:
|
||||
rules, err = d.addInRules(ip, protocol, port, action, "")
|
||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
||||
case mgmProto.FirewallRule_OUT:
|
||||
rules, err = d.addOutRules(ip, protocol, port, action, "")
|
||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
}
|
||||
@@ -205,9 +251,17 @@ func (d *DefaultManager) protoRuleToFirewallRule(r *mgmProto.FirewallRule) (stri
|
||||
return ruleID, rules, nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) {
|
||||
func (d *DefaultManager) addInRules(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionIN, action, comment)
|
||||
rule, err := d.manager.AddFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
@@ -217,7 +271,8 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionOUT, action, comment)
|
||||
rule, err = d.manager.AddFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
@@ -225,9 +280,17 @@ func (d *DefaultManager) addInRules(ip net.IP, protocol firewall.Protocol, port
|
||||
return append(rules, rule), nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port *firewall.Port, action firewall.Action, comment string) ([]firewall.Rule, error) {
|
||||
func (d *DefaultManager) addOutRules(
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
comment string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var rules []firewall.Rule
|
||||
rule, err := d.manager.AddFiltering(ip, protocol, nil, port, firewall.RuleDirectionOUT, action, comment)
|
||||
rule, err := d.manager.AddFiltering(
|
||||
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
@@ -237,7 +300,8 @@ func (d *DefaultManager) addOutRules(ip net.IP, protocol firewall.Protocol, port
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
rule, err = d.manager.AddFiltering(ip, protocol, port, nil, firewall.RuleDirectionIN, action, comment)
|
||||
rule, err = d.manager.AddFiltering(
|
||||
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
|
||||
}
|
||||
@@ -282,6 +346,10 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
in := protoMatch{}
|
||||
out := protoMatch{}
|
||||
|
||||
// trace which type of protocols was squashed
|
||||
squashedRules := []*mgmProto.FirewallRule{}
|
||||
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
||||
|
||||
// this function we use to do calculation, can we squash the rules by protocol or not.
|
||||
// We summ amount of Peers IP for given protocol we found in original rules list.
|
||||
// But we zeroed the IP's for protocol if:
|
||||
@@ -298,12 +366,22 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
if _, ok := protocols[r.Protocol]; !ok {
|
||||
protocols[r.Protocol] = map[string]int{}
|
||||
}
|
||||
match := protocols[r.Protocol]
|
||||
|
||||
if _, ok := match[r.PeerIP]; ok {
|
||||
// special case, when we recieve this all network IP address
|
||||
// it means that rules for that protocol was already optimized on the
|
||||
// management side
|
||||
if r.PeerIP == "0.0.0.0" {
|
||||
squashedRules = append(squashedRules, r)
|
||||
squashedProtocols[r.Protocol] = struct{}{}
|
||||
return
|
||||
}
|
||||
match[r.PeerIP] = i
|
||||
|
||||
ipset := protocols[r.Protocol]
|
||||
|
||||
if _, ok := ipset[r.PeerIP]; ok {
|
||||
return
|
||||
}
|
||||
ipset[r.PeerIP] = i
|
||||
}
|
||||
|
||||
for i, r := range networkMap.FirewallRules {
|
||||
@@ -324,9 +402,6 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
mgmProto.FirewallRule_UDP,
|
||||
}
|
||||
|
||||
// trace which type of protocols was squashed
|
||||
squashedRules := []*mgmProto.FirewallRule{}
|
||||
squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
|
||||
squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
|
||||
for _, protocol := range protocolOrders {
|
||||
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
||||
@@ -382,6 +457,11 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
return append(rules, squashedRules...), squashedProtocols
|
||||
}
|
||||
|
||||
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
|
||||
}
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) firewall.Protocol {
|
||||
switch protocol {
|
||||
case mgmProto.FirewallRule_TCP:
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
)
|
||||
|
||||
@@ -18,10 +17,7 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &DefaultManager{
|
||||
manager: fm,
|
||||
rulesPairs: make(map[string][]firewall.Rule),
|
||||
}, nil
|
||||
return newDefaultManager(fm), nil
|
||||
}
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
@@ -29,8 +29,5 @@ func Create(iface IFaceMapper) (manager *DefaultManager, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
return &DefaultManager{
|
||||
manager: fm,
|
||||
rulesPairs: make(map[string][]firewall.Rule),
|
||||
}, nil
|
||||
return newDefaultManager(fm), nil
|
||||
}
|
||||
|
||||
@@ -215,10 +215,12 @@ func update(input ConfigInput) (*Config, error) {
|
||||
}
|
||||
|
||||
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
|
||||
log.Infof("new pre-shared key provided, updated to %s (old value %s)",
|
||||
*input.PreSharedKey, config.PreSharedKey)
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
refresh = true
|
||||
if *input.PreSharedKey != "" {
|
||||
log.Infof("new pre-shared key provides, updated to %s (old value %s)",
|
||||
*input.PreSharedKey, config.PreSharedKey)
|
||||
config.PreSharedKey = *input.PreSharedKey
|
||||
refresh = true
|
||||
}
|
||||
}
|
||||
|
||||
if config.SSHKey == "" {
|
||||
|
||||
@@ -63,7 +63,22 @@ func TestGetConfig(t *testing.T) {
|
||||
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||
|
||||
// case 4: existing config, but new managementURL has been provided -> update config
|
||||
// case 4: new empty pre-shared key config -> fetch it
|
||||
newPreSharedKey := ""
|
||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||
ManagementURL: managementURL,
|
||||
AdminURL: adminURL,
|
||||
ConfigPath: path,
|
||||
PreSharedKey: &newPreSharedKey,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, config.ManagementURL.String(), managementURL)
|
||||
assert.Equal(t, config.PreSharedKey, preSharedKey)
|
||||
|
||||
// case 5: existing config, but new managementURL has been provided -> update config
|
||||
newManagementURL := "https://test.newManagement.url:33071"
|
||||
config, err = UpdateOrCreateConfig(ConfigInput{
|
||||
ManagementURL: newManagementURL,
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
@@ -24,7 +25,24 @@ import (
|
||||
)
|
||||
|
||||
// RunClient with main logic.
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener) error {
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{})
|
||||
}
|
||||
|
||||
// RunClientMobile with main logic on mobile system
|
||||
func RunClientMobile(ctx context.Context, config *Config, statusRecorder *peer.Status, tunAdapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, routeListener routemanager.RouteListener, dnsAddresses []string, dnsReadyListener dns.ReadyListener) error {
|
||||
// in case of non Android os these variables will be nil
|
||||
mobileDependency := MobileDependency{
|
||||
TunAdapter: tunAdapter,
|
||||
IFaceDiscover: iFaceDiscover,
|
||||
RouteListener: routeListener,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency)
|
||||
}
|
||||
|
||||
func runClient(ctx context.Context, config *Config, statusRecorder *peer.Status, mobileDependency MobileDependency) error {
|
||||
backOff := &backoff.ExponentialBackOff{
|
||||
InitialInterval: time.Second,
|
||||
RandomizationFactor: 1,
|
||||
@@ -151,14 +169,7 @@ func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status,
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
// in case of non Android os these variables will be nil
|
||||
md := MobileDependency{
|
||||
TunAdapter: tunAdapter,
|
||||
IFaceDiscover: iFaceDiscover,
|
||||
RouteListener: routeListener,
|
||||
}
|
||||
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, md, statusRecorder)
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder)
|
||||
err = engine.Start()
|
||||
if err != nil {
|
||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
type androidHostManager struct {
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
return &androidHostManager{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,8 +9,6 @@ import (
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -34,7 +32,7 @@ type systemConfigurator struct {
|
||||
createdKeys map[string]struct{}
|
||||
}
|
||||
|
||||
func newHostManager(_ *iface.WGIface) (hostManager, error) {
|
||||
func newHostManager(_ WGIface) (hostManager, error) {
|
||||
return &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}, nil
|
||||
|
||||
@@ -5,10 +5,10 @@ package dns
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -25,7 +25,7 @@ const (
|
||||
|
||||
type osManagerType int
|
||||
|
||||
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
osManager, err := getOSDNSManagerType()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -6,8 +6,6 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -33,7 +31,7 @@ type registryConfigurator struct {
|
||||
existingSearchDomains []string
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
func newHostManager(wgInterface WGIface) (hostManager, error) {
|
||||
guid, err := wgInterface.GetInterfaceGUIDString()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -31,6 +31,11 @@ func (m *MockServer) DnsIP() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *MockServer) OnUpdatedHostDNSServer(strings []string) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// UpdateDNSServer mock implementation of UpdateDNSServer from Server interface
|
||||
func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||
if m.UpdateDNSServerFunc != nil {
|
||||
|
||||
@@ -14,8 +14,6 @@ import (
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -72,7 +70,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
|
||||
}
|
||||
}
|
||||
|
||||
func newNetworkManagerDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
func newNetworkManagerDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const resolvconfCommand = "resolvconf"
|
||||
@@ -18,7 +16,7 @@ type resolvconf struct {
|
||||
ifaceName string
|
||||
}
|
||||
|
||||
func newResolvConfConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
func newResolvConfConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||
return &resolvconf{
|
||||
ifaceName: wgInterface.Name(),
|
||||
}, nil
|
||||
|
||||
@@ -3,29 +3,20 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/mitchellh/hashstructure/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPort = 53
|
||||
customPort = 5053
|
||||
defaultIP = "127.0.0.1"
|
||||
customIP = "127.0.0.153"
|
||||
)
|
||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||
type ReadyListener interface {
|
||||
OnReady()
|
||||
}
|
||||
|
||||
// Server is a dns server interface
|
||||
type Server interface {
|
||||
@@ -33,6 +24,7 @@ type Server interface {
|
||||
Stop()
|
||||
DnsIP() string
|
||||
UpdateDNSServer(serial uint64, update nbdns.Config) error
|
||||
OnUpdatedHostDNSServer(strings []string)
|
||||
}
|
||||
|
||||
type registeredHandlerMap map[string]handlerWithStop
|
||||
@@ -42,21 +34,19 @@ type DefaultServer struct {
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
mux sync.Mutex
|
||||
udpFilterHookID string
|
||||
server *dns.Server
|
||||
dnsMux *dns.ServeMux
|
||||
service service
|
||||
dnsMuxMap registeredHandlerMap
|
||||
localResolver *localResolver
|
||||
wgInterface *iface.WGIface
|
||||
wgInterface WGIface
|
||||
hostManager hostManager
|
||||
updateSerial uint64
|
||||
listenerIsRunning bool
|
||||
runtimePort int
|
||||
runtimeIP string
|
||||
previousConfigHash uint64
|
||||
currentConfig hostDNSConfig
|
||||
customAddress *netip.AddrPort
|
||||
enabled bool
|
||||
|
||||
// permanent related properties
|
||||
permanent bool
|
||||
hostsDnsList []string
|
||||
hostsDnsListLock sync.Mutex
|
||||
}
|
||||
|
||||
type handlerWithStop interface {
|
||||
@@ -70,9 +60,7 @@ type muxUpdate struct {
|
||||
}
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAddress string, initialDnsCfg *nbdns.Config) (*DefaultServer, error) {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string) (*DefaultServer, error) {
|
||||
var addrPort *netip.AddrPort
|
||||
if customAddress != "" {
|
||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||
@@ -82,37 +70,44 @@ func NewDefaultServer(ctx context.Context, wgInterface *iface.WGIface, customAdd
|
||||
addrPort = &parsedAddrPort
|
||||
}
|
||||
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
var dnsService service
|
||||
if wgInterface.IsUserspaceBind() {
|
||||
dnsService = newServiceViaMemory(wgInterface)
|
||||
} else {
|
||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||
}
|
||||
|
||||
return newDefaultServer(ctx, wgInterface, dnsService), nil
|
||||
}
|
||||
|
||||
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
||||
func NewDefaultServerPermanentUpstream(ctx context.Context, wgInterface WGIface, hostsDnsList []string) *DefaultServer {
|
||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||
ds := newDefaultServer(ctx, wgInterface, newServiceViaMemory(wgInterface))
|
||||
ds.permanent = true
|
||||
ds.hostsDnsList = hostsDnsList
|
||||
ds.addHostRootZone()
|
||||
setServerDns(ds)
|
||||
return ds
|
||||
}
|
||||
|
||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service) *DefaultServer {
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
server: &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
},
|
||||
dnsMux: mux,
|
||||
service: dnsService,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
wgInterface: wgInterface,
|
||||
customAddress: addrPort,
|
||||
wgInterface: wgInterface,
|
||||
}
|
||||
|
||||
if initialDnsCfg != nil {
|
||||
defaultServer.enabled = hasValidDnsServer(initialDnsCfg)
|
||||
}
|
||||
|
||||
if wgInterface.IsUserspaceBind() {
|
||||
defaultServer.evelRuntimeAddressForUserspace()
|
||||
}
|
||||
|
||||
return defaultServer, nil
|
||||
return defaultServer
|
||||
}
|
||||
|
||||
// Initialize instantiate host manager. It required to be initialized wginterface
|
||||
// Initialize instantiate host manager and the dns service
|
||||
func (s *DefaultServer) Initialize() (err error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
@@ -121,72 +116,23 @@ func (s *DefaultServer) Initialize() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !s.wgInterface.IsUserspaceBind() {
|
||||
s.evalRuntimeAddress()
|
||||
if s.permanent {
|
||||
err = s.service.Listen()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
s.hostManager, err = newHostManager(s.wgInterface)
|
||||
return
|
||||
}
|
||||
|
||||
// listen runs the listener in a go routine
|
||||
func (s *DefaultServer) listen() {
|
||||
// nil check required in unit tests
|
||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() {
|
||||
s.udpFilterHookID = s.filterDNSTraffic()
|
||||
s.setListenerStatus(true)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
|
||||
err := s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// DnsIP returns the DNS resolver server IP address
|
||||
//
|
||||
// When kernel space interface used it return real DNS server listener IP address
|
||||
// For bind interface, fake DNS resolver address returned (second last IP address from Nebird network)
|
||||
func (s *DefaultServer) DnsIP() string {
|
||||
if !s.enabled {
|
||||
return ""
|
||||
}
|
||||
return s.runtimeIP
|
||||
}
|
||||
|
||||
func (s *DefaultServer) getFirstListenerAvailable() (string, int, error) {
|
||||
ips := []string{defaultIP, customIP}
|
||||
if runtime.GOOS != "darwin" && s.wgInterface != nil {
|
||||
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
||||
}
|
||||
ports := []int{defaultPort, customPort}
|
||||
for _, port := range ports {
|
||||
for _, ip := range ips {
|
||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
if err == nil {
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||
}
|
||||
return ip, port, nil
|
||||
}
|
||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||
}
|
||||
}
|
||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) setListenerStatus(running bool) {
|
||||
s.listenerIsRunning = running
|
||||
return s.service.RuntimeIP()
|
||||
}
|
||||
|
||||
// Stop stops the server
|
||||
@@ -202,37 +148,23 @@ func (s *DefaultServer) Stop() {
|
||||
}
|
||||
}
|
||||
|
||||
err := s.stopListener()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
s.service.Stop()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) stopListener() error {
|
||||
if s.wgInterface != nil && s.wgInterface.IsUserspaceBind() && s.listenerIsRunning {
|
||||
// udpFilterHookID here empty only in the unit tests
|
||||
if filter := s.wgInterface.GetFilter(); filter != nil && s.udpFilterHookID != "" {
|
||||
if err := filter.RemovePacketHook(s.udpFilterHookID); err != nil {
|
||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||
}
|
||||
}
|
||||
s.udpFilterHookID = ""
|
||||
s.listenerIsRunning = false
|
||||
return nil
|
||||
}
|
||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||
// It will be applied if the mgm server do not enforce DNS settings for root zone
|
||||
func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
||||
s.hostsDnsListLock.Lock()
|
||||
defer s.hostsDnsListLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return nil
|
||||
s.hostsDnsList = hostsDnsList
|
||||
_, ok := s.dnsMuxMap[nbdns.RootZone]
|
||||
if ok {
|
||||
log.Debugf("on new host DNS config but skip to apply it")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
}
|
||||
return nil
|
||||
log.Debugf("update host DNS settings: %+v", hostsDnsList)
|
||||
s.addHostRootZone()
|
||||
}
|
||||
|
||||
// UpdateDNSServer processes an update received from the management service
|
||||
@@ -283,12 +215,10 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
|
||||
func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
// is the service should be disabled, we stop the listener or fake resolver
|
||||
// and proceed with a regular update to clean up the handlers and records
|
||||
if !update.ServiceEnable {
|
||||
if err := s.stopListener(); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
} else if !s.listenerIsRunning {
|
||||
s.listen()
|
||||
if update.ServiceEnable {
|
||||
_ = s.service.Listen()
|
||||
} else if !s.permanent {
|
||||
s.service.Stop()
|
||||
}
|
||||
|
||||
localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
|
||||
@@ -299,15 +229,14 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("not applying dns update, error: %v", err)
|
||||
}
|
||||
|
||||
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...)
|
||||
|
||||
s.updateMux(muxUpdates)
|
||||
s.updateLocalResolver(localRecords)
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.runtimeIP, s.runtimePort)
|
||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||
|
||||
hostUpdate := s.currentConfig
|
||||
if s.runtimePort != defaultPort && !s.hostManager.supportCustomPort() {
|
||||
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||
"Learn more at: https://netbird.io/docs/how-to-guides/nameservers#local-resolver")
|
||||
hostUpdate.routeAll = false
|
||||
@@ -412,19 +341,32 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
|
||||
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
|
||||
muxUpdateMap := make(registeredHandlerMap)
|
||||
|
||||
var isContainRootUpdate bool
|
||||
|
||||
for _, update := range muxUpdates {
|
||||
s.registerMux(update.domain, update.handler)
|
||||
s.service.RegisterMux(update.domain, update.handler)
|
||||
muxUpdateMap[update.domain] = update.handler
|
||||
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
|
||||
existingHandler.stop()
|
||||
}
|
||||
|
||||
if update.domain == nbdns.RootZone {
|
||||
isContainRootUpdate = true
|
||||
}
|
||||
}
|
||||
|
||||
for key, existingHandler := range s.dnsMuxMap {
|
||||
_, found := muxUpdateMap[key]
|
||||
if !found {
|
||||
existingHandler.stop()
|
||||
s.deregisterMux(key)
|
||||
if !isContainRootUpdate && key == nbdns.RootZone {
|
||||
s.hostsDnsListLock.Lock()
|
||||
s.addHostRootZone()
|
||||
s.hostsDnsListLock.Unlock()
|
||||
existingHandler.stop()
|
||||
} else {
|
||||
existingHandler.stop()
|
||||
s.service.DeregisterMux(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -455,14 +397,6 @@ func getNSHostPort(ns nbdns.NameServer) string {
|
||||
return fmt.Sprintf("%s:%d", ns.IP.String(), ns.Port)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) registerMux(pattern string, handler dns.Handler) {
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) deregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
||||
|
||||
// upstreamCallbacks returns two functions, the first one is used to deactivate
|
||||
// the upstream resolver from the configuration, the second one is used to
|
||||
// reactivate it. Not allowed to call reactivate before deactivate.
|
||||
@@ -490,7 +424,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
for i, item := range s.currentConfig.domains {
|
||||
if _, found := removeIndex[item.domain]; found {
|
||||
s.currentConfig.domains[i].disabled = true
|
||||
s.deregisterMux(item.domain)
|
||||
s.service.DeregisterMux(item.domain)
|
||||
removeIndex[item.domain] = i
|
||||
}
|
||||
}
|
||||
@@ -507,7 +441,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
continue
|
||||
}
|
||||
s.currentConfig.domains[i].disabled = false
|
||||
s.registerMux(domain, handler)
|
||||
s.service.RegisterMux(domain, handler)
|
||||
}
|
||||
|
||||
l := log.WithField("nameservers", nsGroup.NameServers)
|
||||
@@ -523,93 +457,13 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
return
|
||||
}
|
||||
|
||||
func (s *DefaultServer) filterDNSTraffic() string {
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter == nil {
|
||||
log.Error("can't set DNS filter, filter not initialized")
|
||||
return ""
|
||||
func (s *DefaultServer) addHostRootZone() {
|
||||
handler := newUpstreamResolver(s.ctx)
|
||||
handler.upstreamServers = make([]string, len(s.hostsDnsList))
|
||||
for n, ua := range s.hostsDnsList {
|
||||
handler.upstreamServers[n] = fmt.Sprintf("%s:53", ua)
|
||||
}
|
||||
|
||||
firstLayerDecoder := layers.LayerTypeIPv4
|
||||
if s.wgInterface.Address().Network.IP.To4() == nil {
|
||||
firstLayerDecoder = layers.LayerTypeIPv6
|
||||
}
|
||||
|
||||
hook := func(packetData []byte) bool {
|
||||
// Decode the packet
|
||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||
|
||||
// Get the UDP layer
|
||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||
udp := udpLayer.(*layers.UDP)
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(udp.Payload); err != nil {
|
||||
log.Tracef("parse DNS request: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
writer := responseWriter{
|
||||
packet: packet,
|
||||
device: s.wgInterface.GetDevice().Device,
|
||||
}
|
||||
go s.dnsMux.ServeDNS(&writer, msg)
|
||||
return true
|
||||
}
|
||||
|
||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) evelRuntimeAddressForUserspace() {
|
||||
s.runtimeIP = getLastIPFromNetwork(s.wgInterface.Address().Network, 1)
|
||||
s.runtimePort = defaultPort
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||
}
|
||||
|
||||
func (s *DefaultServer) evalRuntimeAddress() {
|
||||
defer func() {
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||
}()
|
||||
|
||||
if s.customAddress != nil {
|
||||
s.runtimeIP = s.customAddress.Addr().String()
|
||||
s.runtimePort = int(s.customAddress.Port())
|
||||
return
|
||||
}
|
||||
|
||||
ip, port, err := s.getFirstListenerAvailable()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
s.runtimeIP = ip
|
||||
s.runtimePort = port
|
||||
}
|
||||
|
||||
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
||||
// Calculate the last IP in the CIDR range
|
||||
var endIP net.IP
|
||||
for i := 0; i < len(network.IP); i++ {
|
||||
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
||||
}
|
||||
|
||||
// convert to big.Int
|
||||
endInt := big.NewInt(0)
|
||||
endInt.SetBytes(endIP)
|
||||
|
||||
// subtract fromEnd from the last ip
|
||||
fromEndBig := big.NewInt(int64(fromEnd))
|
||||
resultInt := big.NewInt(0)
|
||||
resultInt.Sub(endInt, fromEndBig)
|
||||
|
||||
return net.IP(resultInt.Bytes()).String()
|
||||
}
|
||||
|
||||
func hasValidDnsServer(cfg *nbdns.Config) bool {
|
||||
for _, c := range cfg.NameServerGroups {
|
||||
if c.Primary {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
handler.deactivate = func() {}
|
||||
handler.reactivate = func() {}
|
||||
s.service.RegisterMux(nbdns.RootZone, handler)
|
||||
}
|
||||
|
||||
29
client/internal/dns/server_export.go
Normal file
29
client/internal/dns/server_export.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
mutex sync.Mutex
|
||||
server Server
|
||||
)
|
||||
|
||||
// GetServerDns export the DNS server instance in static way. It used by the Mobile client
|
||||
func GetServerDns() (Server, error) {
|
||||
mutex.Lock()
|
||||
if server == nil {
|
||||
mutex.Unlock()
|
||||
return nil, fmt.Errorf("DNS server not instantiated yet")
|
||||
}
|
||||
s := server
|
||||
mutex.Unlock()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func setServerDns(newServerServer Server) {
|
||||
mutex.Lock()
|
||||
server = newServerServer
|
||||
defer mutex.Unlock()
|
||||
}
|
||||
24
client/internal/dns/server_export_test.go
Normal file
24
client/internal/dns/server_export_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetServerDns(t *testing.T) {
|
||||
_, err := GetServerDns()
|
||||
if err == nil {
|
||||
t.Errorf("invalid dns server instance")
|
||||
}
|
||||
|
||||
srv := &MockServer{}
|
||||
setServerDns(srv)
|
||||
|
||||
srvB, err := GetServerDns()
|
||||
if err != nil {
|
||||
t.Errorf("invalid dns server instance: %s", err)
|
||||
}
|
||||
|
||||
if srvB != srv {
|
||||
t.Errorf("missmatch dns instances")
|
||||
}
|
||||
}
|
||||
@@ -11,14 +11,53 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
pfmock "github.com/netbirdio/netbird/iface/mocks"
|
||||
)
|
||||
|
||||
type mocWGIface struct {
|
||||
filter iface.PacketFilter
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Name() string {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (w *mocWGIface) Address() iface.WGAddress {
|
||||
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
||||
return iface.WGAddress{
|
||||
IP: ip,
|
||||
Network: network,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetFilter() iface.PacketFilter {
|
||||
return w.filter
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetDevice() *iface.DeviceWrapper {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (w *mocWGIface) GetInterfaceGUIDString() (string, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (w *mocWGIface) IsUserspaceBind() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
|
||||
w.filter = filter
|
||||
return nil
|
||||
}
|
||||
|
||||
var zoneRecords = []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: "peera.netbird.cloud",
|
||||
@@ -29,6 +68,11 @@ var zoneRecords = []nbdns.SimpleRecord{
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
log.SetLevel(log.TraceLevel)
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
nameServers := []nbdns.NameServer{
|
||||
{
|
||||
@@ -224,7 +268,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -242,8 +286,6 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
dnsServer.dnsMuxMap = testCase.initUpstreamMap
|
||||
dnsServer.localResolver.registeredMap = testCase.initLocalMap
|
||||
dnsServer.updateSerial = testCase.initSerial
|
||||
// pretend we are running
|
||||
dnsServer.listenerIsRunning = true
|
||||
|
||||
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
|
||||
if err != nil {
|
||||
@@ -282,7 +324,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
os.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet(nil)
|
||||
if err != nil {
|
||||
t.Errorf("create stdnet: %v", err)
|
||||
@@ -316,17 +358,17 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
}
|
||||
|
||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any()).AnyTimes()
|
||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||
|
||||
if err := wgIface.SetFilter(packetfilter); err != nil {
|
||||
t.Errorf("set packet filter: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", nil)
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "")
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
@@ -421,21 +463,23 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
dnsServer := getDefaultServerWithNoHostManager(t, testCase.addrPort)
|
||||
|
||||
dnsServer.hostManager = newNoopHostMocker()
|
||||
dnsServer.listen()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if !dnsServer.listenerIsRunning {
|
||||
t.Fatal("dns server listener is not running")
|
||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort)
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
dnsServer.hostManager = newNoopHostMocker()
|
||||
err = dnsServer.service.Listen()
|
||||
if err != nil {
|
||||
t.Fatalf("dns server is not running: %s", err)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
defer dnsServer.Stop()
|
||||
err := dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||
err = dnsServer.localResolver.registerRecord(zoneRecords[0])
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
dnsServer.dnsMux.Handle("netbird.cloud", dnsServer.localResolver)
|
||||
dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver)
|
||||
|
||||
resolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
@@ -443,7 +487,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
d := net.Dialer{
|
||||
Timeout: time.Second * 5,
|
||||
}
|
||||
addr := fmt.Sprintf("%s:%d", dnsServer.runtimeIP, dnsServer.runtimePort)
|
||||
addr := fmt.Sprintf("%s:%d", dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||
conn, err := d.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
t.Log(err)
|
||||
@@ -478,7 +522,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
hostManager := &mockHostConfigurator{}
|
||||
server := DefaultServer{
|
||||
dnsMux: dns.DefaultServeMux,
|
||||
service: newServiceViaMemory(&mocWGIface{}),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
},
|
||||
@@ -541,62 +585,237 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func getDefaultServerWithNoHostManager(t *testing.T, addrPort string) *DefaultServer {
|
||||
mux := dns.NewServeMux()
|
||||
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
|
||||
var parsedAddrPort *netip.AddrPort
|
||||
if addrPort != "" {
|
||||
parsed, err := netip.ParseAddrPort(addrPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
parsedAddrPort = &parsed
|
||||
var dnsList []string
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList)
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer dnsServer.Stop()
|
||||
|
||||
dnsServer.OnUpdatedHostDNSServer([]string{"8.8.8.8"})
|
||||
|
||||
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer dnsServer.Stop()
|
||||
|
||||
// check initial state
|
||||
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
|
||||
dnsServer := &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
|
||||
ds := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: cancel,
|
||||
server: dnsServer,
|
||||
dnsMux: mux,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
localResolver: &localResolver{
|
||||
registeredMap: make(registrationMap),
|
||||
update := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Enabled: true,
|
||||
Primary: true,
|
||||
},
|
||||
},
|
||||
customAddress: parsedAddrPort,
|
||||
}
|
||||
ds.evalRuntimeAddress()
|
||||
return ds
|
||||
}
|
||||
|
||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||
tests := []struct {
|
||||
addr string
|
||||
ip string
|
||||
}{
|
||||
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
||||
{"192.168.0.0/30", "192.168.0.2"},
|
||||
{"192.168.0.0/16", "192.168.255.254"},
|
||||
{"192.168.0.0/24", "192.168.0.254"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
_, ipnet, err := net.ParseCIDR(tt.addr)
|
||||
if err != nil {
|
||||
t.Errorf("Error parsing CIDR: %v", err)
|
||||
return
|
||||
}
|
||||
err = dnsServer.UpdateDNSServer(1, update)
|
||||
if err != nil {
|
||||
t.Errorf("failed to update dns server: %s", err)
|
||||
}
|
||||
|
||||
lastIP := getLastIPFromNetwork(ipnet, 1)
|
||||
if lastIP != tt.ip {
|
||||
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
||||
}
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed resolve zone record: %v", err)
|
||||
}
|
||||
if ips[0] != zoneRecords[0].RData {
|
||||
t.Fatalf("invalid zone record: %v", err)
|
||||
}
|
||||
|
||||
update2 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{},
|
||||
}
|
||||
|
||||
err = dnsServer.UpdateDNSServer(2, update2)
|
||||
if err != nil {
|
||||
t.Errorf("failed to update dns server: %s", err)
|
||||
}
|
||||
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
|
||||
ips, err = resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed resolve zone record: %v", err)
|
||||
}
|
||||
if ips[0] != zoneRecords[0].RData {
|
||||
t.Fatalf("invalid zone record: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"})
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
return
|
||||
}
|
||||
defer dnsServer.Stop()
|
||||
|
||||
// check initial state
|
||||
resolver := newDnsResolver(dnsServer.service.RuntimeIP(), dnsServer.service.RuntimePort())
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
|
||||
update := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "netbird.cloud",
|
||||
Records: zoneRecords,
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.4.4"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Domains: []string{"customdomain.com"},
|
||||
Primary: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = dnsServer.UpdateDNSServer(1, update)
|
||||
if err != nil {
|
||||
t.Errorf("failed to update dns server: %s", err)
|
||||
}
|
||||
|
||||
_, err = resolver.LookupHost(context.Background(), "netbird.io")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
ips, err := resolver.LookupHost(context.Background(), zoneRecords[0].Name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed resolve zone record: %v", err)
|
||||
}
|
||||
if ips[0] != zoneRecords[0].RData {
|
||||
t.Fatalf("invalid zone record: %v", err)
|
||||
}
|
||||
_, err = resolver.LookupHost(context.Background(), "customdomain.com")
|
||||
if err != nil {
|
||||
t.Errorf("failed to resolve: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
defer os.Setenv("NB_WG_KERNEL_DISABLED", ov)
|
||||
|
||||
_ = os.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
newNet, err := stdnet.NewNet(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("create stdnet: %v", err)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", iface.DefaultMTU, nil, newNet)
|
||||
if err != nil {
|
||||
t.Fatalf("build interface wireguard: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = wgIface.Create()
|
||||
if err != nil {
|
||||
t.Fatalf("crate and init wireguard interface: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pf, err := uspfilter.Create(wgIface)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create uspfilter: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = wgIface.SetFilter(pf)
|
||||
if err != nil {
|
||||
t.Fatalf("set packet filter: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wgIface, nil
|
||||
}
|
||||
|
||||
func newDnsResolver(ip string, port int) *net.Resolver {
|
||||
return &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := net.Dialer{
|
||||
Timeout: time.Second * 3,
|
||||
}
|
||||
addr := fmt.Sprintf("%s:%d", ip, port)
|
||||
return d.DialContext(ctx, network, addr)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
18
client/internal/dns/service.go
Normal file
18
client/internal/dns/service.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPort = 53
|
||||
)
|
||||
|
||||
type service interface {
|
||||
Listen() error
|
||||
Stop()
|
||||
RegisterMux(domain string, handler dns.Handler)
|
||||
DeregisterMux(key string)
|
||||
RuntimePort() int
|
||||
RuntimeIP() string
|
||||
}
|
||||
145
client/internal/dns/service_listener.go
Normal file
145
client/internal/dns/service_listener.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
customPort = 5053
|
||||
defaultIP = "127.0.0.1"
|
||||
customIP = "127.0.0.153"
|
||||
)
|
||||
|
||||
type serviceViaListener struct {
|
||||
wgInterface WGIface
|
||||
dnsMux *dns.ServeMux
|
||||
customAddr *netip.AddrPort
|
||||
server *dns.Server
|
||||
runtimeIP string
|
||||
runtimePort int
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
}
|
||||
|
||||
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
|
||||
mux := dns.NewServeMux()
|
||||
|
||||
s := &serviceViaListener{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: mux,
|
||||
customAddr: customAddr,
|
||||
server: &dns.Server{
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
UDPSize: 65535,
|
||||
},
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) Listen() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if s.listenerIsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
s.runtimeIP, s.runtimePort, err = s.evalRuntimeAddress()
|
||||
if err != nil {
|
||||
log.Errorf("failed to eval runtime address: %s", err)
|
||||
return err
|
||||
}
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.runtimeIP, s.runtimePort)
|
||||
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
defer s.setListenerStatus(false)
|
||||
|
||||
err := s.server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.runtimePort, err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) Stop() {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.server.ShutdownContext(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("stopping dns server listener returned an error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) DeregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RuntimePort() int {
|
||||
return s.runtimePort
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) RuntimeIP() string {
|
||||
return s.runtimeIP
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) setListenerStatus(running bool) {
|
||||
s.listenerIsRunning = running
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) getFirstListenerAvailable() (string, int, error) {
|
||||
ips := []string{defaultIP, customIP}
|
||||
if runtime.GOOS != "darwin" {
|
||||
ips = append([]string{s.wgInterface.Address().IP.String()}, ips...)
|
||||
}
|
||||
ports := []int{defaultPort, customPort}
|
||||
for _, port := range ports {
|
||||
for _, ip := range ips {
|
||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
if err == nil {
|
||||
err = probeListener.Close()
|
||||
if err != nil {
|
||||
log.Errorf("got an error closing the probe listener, error: %s", err)
|
||||
}
|
||||
return ip, port, nil
|
||||
}
|
||||
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
|
||||
}
|
||||
}
|
||||
return "", 0, fmt.Errorf("unable to find an unused ip and port combination. IPs tested: %v and ports %v", ips, ports)
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) evalRuntimeAddress() (string, int, error) {
|
||||
if s.customAddr != nil {
|
||||
return s.customAddr.Addr().String(), int(s.customAddr.Port()), nil
|
||||
}
|
||||
|
||||
return s.getFirstListenerAvailable()
|
||||
}
|
||||
139
client/internal/dns/service_memory.go
Normal file
139
client/internal/dns/service_memory.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type serviceViaMemory struct {
|
||||
wgInterface WGIface
|
||||
dnsMux *dns.ServeMux
|
||||
runtimeIP string
|
||||
runtimePort int
|
||||
udpFilterHookID string
|
||||
listenerIsRunning bool
|
||||
listenerFlagLock sync.Mutex
|
||||
}
|
||||
|
||||
func newServiceViaMemory(wgIface WGIface) *serviceViaMemory {
|
||||
s := &serviceViaMemory{
|
||||
wgInterface: wgIface,
|
||||
dnsMux: dns.NewServeMux(),
|
||||
|
||||
runtimeIP: getLastIPFromNetwork(wgIface.Address().Network, 1),
|
||||
runtimePort: defaultPort,
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) Listen() error {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if s.listenerIsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
s.udpFilterHookID, err = s.filterDNSTraffic()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.listenerIsRunning = true
|
||||
|
||||
log.Debugf("dns service listening on: %s", s.RuntimeIP())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) Stop() {
|
||||
s.listenerFlagLock.Lock()
|
||||
defer s.listenerFlagLock.Unlock()
|
||||
|
||||
if !s.listenerIsRunning {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
|
||||
log.Errorf("unable to remove DNS packet hook: %s", err)
|
||||
}
|
||||
|
||||
s.listenerIsRunning = false
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
|
||||
s.dnsMux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) DeregisterMux(pattern string) {
|
||||
s.dnsMux.HandleRemove(pattern)
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) RuntimePort() int {
|
||||
return s.runtimePort
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) RuntimeIP() string {
|
||||
return s.runtimeIP
|
||||
}
|
||||
|
||||
func (s *serviceViaMemory) filterDNSTraffic() (string, error) {
|
||||
filter := s.wgInterface.GetFilter()
|
||||
if filter == nil {
|
||||
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
|
||||
}
|
||||
|
||||
firstLayerDecoder := layers.LayerTypeIPv4
|
||||
if s.wgInterface.Address().Network.IP.To4() == nil {
|
||||
firstLayerDecoder = layers.LayerTypeIPv6
|
||||
}
|
||||
|
||||
hook := func(packetData []byte) bool {
|
||||
// Decode the packet
|
||||
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
|
||||
|
||||
// Get the UDP layer
|
||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
||||
udp := udpLayer.(*layers.UDP)
|
||||
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(udp.Payload); err != nil {
|
||||
log.Tracef("parse DNS request: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
writer := responseWriter{
|
||||
packet: packet,
|
||||
device: s.wgInterface.GetDevice().Device,
|
||||
}
|
||||
go s.dnsMux.ServeDNS(&writer, msg)
|
||||
return true
|
||||
}
|
||||
|
||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
|
||||
}
|
||||
|
||||
func getLastIPFromNetwork(network *net.IPNet, fromEnd int) string {
|
||||
// Calculate the last IP in the CIDR range
|
||||
var endIP net.IP
|
||||
for i := 0; i < len(network.IP); i++ {
|
||||
endIP = append(endIP, network.IP[i]|^network.Mask[i])
|
||||
}
|
||||
|
||||
// convert to big.Int
|
||||
endInt := big.NewInt(0)
|
||||
endInt.SetBytes(endIP)
|
||||
|
||||
// subtract fromEnd from the last ip
|
||||
fromEndBig := big.NewInt(int64(fromEnd))
|
||||
resultInt := big.NewInt(0)
|
||||
resultInt.Sub(endInt, fromEndBig)
|
||||
|
||||
return net.IP(resultInt.Bytes()).String()
|
||||
}
|
||||
31
client/internal/dns/service_memory_test.go
Normal file
31
client/internal/dns/service_memory_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetLastIPFromNetwork(t *testing.T) {
|
||||
tests := []struct {
|
||||
addr string
|
||||
ip string
|
||||
}{
|
||||
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
|
||||
{"192.168.0.0/30", "192.168.0.2"},
|
||||
{"192.168.0.0/16", "192.168.255.254"},
|
||||
{"192.168.0.0/24", "192.168.0.254"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
_, ipnet, err := net.ParseCIDR(tt.addr)
|
||||
if err != nil {
|
||||
t.Errorf("Error parsing CIDR: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
lastIP := getLastIPFromNetwork(ipnet, 1)
|
||||
if lastIP != tt.ip {
|
||||
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -53,7 +52,7 @@ type systemdDbusLinkDomainsInput struct {
|
||||
MatchOnly bool
|
||||
}
|
||||
|
||||
func newSystemdDbusConfigurator(wgInterface *iface.WGIface) (hostManager, error) {
|
||||
func newSystemdDbusConfigurator(wgInterface WGIface) (hostManager, error) {
|
||||
iface, err := net.InterfaceByName(wgInterface.Name())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
14
client/internal/dns/wgiface.go
Normal file
14
client/internal/dns/wgiface.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build !windows
|
||||
|
||||
package dns
|
||||
|
||||
import "github.com/netbirdio/netbird/iface"
|
||||
|
||||
// WGIface defines subset methods of interface required for manager
|
||||
type WGIface interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() iface.PacketFilter
|
||||
GetDevice() *iface.DeviceWrapper
|
||||
}
|
||||
13
client/internal/dns/wgiface_windows.go
Normal file
13
client/internal/dns/wgiface_windows.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package dns
|
||||
|
||||
import "github.com/netbirdio/netbird/iface"
|
||||
|
||||
// WGIface defines subset methods of interface required for manager
|
||||
type WGIface interface {
|
||||
Name() string
|
||||
Address() iface.WGAddress
|
||||
IsUserspaceBind() bool
|
||||
GetFilter() iface.PacketFilter
|
||||
GetDevice() *iface.DeviceWrapper
|
||||
GetInterfaceGUIDString() (string, error)
|
||||
}
|
||||
@@ -190,23 +190,25 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
|
||||
var routes []*route.Route
|
||||
var dnsCfg *nbdns.Config
|
||||
|
||||
if runtime.GOOS == "android" {
|
||||
routes, dnsCfg, err = e.readInitialSettings()
|
||||
routes, err = e.readInitialSettings()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if e.dnsServer == nil {
|
||||
if e.dnsServer == nil {
|
||||
e.dnsServer = dns.NewDefaultServerPermanentUpstream(e.ctx, e.wgInterface, e.mobileDep.HostDNSAddresses)
|
||||
go e.mobileDep.DnsReadyListener.OnReady()
|
||||
}
|
||||
} else {
|
||||
// todo fix custom address
|
||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, dnsCfg)
|
||||
if err != nil {
|
||||
e.close()
|
||||
return err
|
||||
if e.dnsServer == nil {
|
||||
e.dnsServer, err = dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress)
|
||||
if err != nil {
|
||||
e.close()
|
||||
return err
|
||||
}
|
||||
}
|
||||
e.dnsServer = dnsServer
|
||||
}
|
||||
|
||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, routes)
|
||||
@@ -1045,14 +1047,13 @@ func (e *Engine) close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
|
||||
func (e *Engine) readInitialSettings() ([]*route.Route, error) {
|
||||
netMap, err := e.mgmClient.GetNetworkMap()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
routes := toRoutes(netMap.GetRoutes())
|
||||
dnsCfg := toDNSConfig(netMap.GetDNSConfig())
|
||||
return routes, &dnsCfg, nil
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
@@ -8,7 +9,9 @@ import (
|
||||
|
||||
// MobileDependency collect all dependencies for mobile platform
|
||||
type MobileDependency struct {
|
||||
TunAdapter iface.TunAdapter
|
||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
RouteListener routemanager.RouteListener
|
||||
TunAdapter iface.TunAdapter
|
||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
RouteListener routemanager.RouteListener
|
||||
HostDNSAddresses []string
|
||||
DnsReadyListener dns.ReadyListener
|
||||
}
|
||||
|
||||
@@ -6,8 +6,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -30,33 +28,13 @@ func genKey(format string, input string) string {
|
||||
|
||||
// NewFirewall if supported, returns an iptables manager, otherwise returns a nftables manager
|
||||
func NewFirewall(parentCTX context.Context) firewallManager {
|
||||
ctx, cancel := context.WithCancel(parentCTX)
|
||||
|
||||
if isIptablesSupported() {
|
||||
log.Debugf("iptables is supported")
|
||||
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
|
||||
return &iptablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
ipv4Client: ipv4Client,
|
||||
ipv6Client: ipv6Client,
|
||||
rules: make(map[string]map[string][]string),
|
||||
}
|
||||
manager, err := newNFTablesManager(parentCTX)
|
||||
if err == nil {
|
||||
log.Debugf("nftables firewall manager will be used")
|
||||
return manager
|
||||
}
|
||||
|
||||
log.Debugf("iptables is not supported, using nftables")
|
||||
|
||||
manager := &nftablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
chains: make(map[string]map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
}
|
||||
|
||||
return manager
|
||||
log.Debugf("fallback to iptables firewall manager: %s", err)
|
||||
return newIptablesManager(parentCTX)
|
||||
}
|
||||
|
||||
func getInPair(pair routerPair) routerPair {
|
||||
|
||||
@@ -49,6 +49,28 @@ type iptablesManager struct {
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newIptablesManager(parentCtx context.Context) *iptablesManager {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if !isIptablesClientAvailable(ipv4Client) {
|
||||
log.Infof("iptables is missing for ipv4")
|
||||
ipv4Client = nil
|
||||
}
|
||||
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
if !isIptablesClientAvailable(ipv6Client) {
|
||||
log.Infof("iptables is missing for ipv6")
|
||||
ipv6Client = nil
|
||||
}
|
||||
|
||||
return &iptablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
ipv4Client: ipv4Client,
|
||||
ipv6Client: ipv6Client,
|
||||
rules: make(map[string]map[string][]string),
|
||||
}
|
||||
}
|
||||
|
||||
// CleanRoutingRules cleans existing iptables resources that we created by the agent
|
||||
func (i *iptablesManager) CleanRoutingRules() {
|
||||
i.mux.Lock()
|
||||
@@ -61,24 +83,28 @@ func (i *iptablesManager) CleanRoutingRules() {
|
||||
|
||||
log.Debug("flushing tables")
|
||||
errMSGFormat := "iptables: failed cleaning %s chain %s,error: %v"
|
||||
err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
||||
if i.ipv4Client != nil {
|
||||
err = i.ipv4Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
||||
}
|
||||
|
||||
err = i.ipv4Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = i.ipv4Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
||||
}
|
||||
if i.ipv6Client != nil {
|
||||
err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
||||
}
|
||||
|
||||
err = i.ipv6Client.ClearAndDeleteChain(iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
||||
}
|
||||
|
||||
err = i.ipv6Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
||||
err = i.ipv6Client.ClearAndDeleteChain(iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
log.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("done cleaning up iptables rules")
|
||||
@@ -96,37 +122,41 @@ func (i *iptablesManager) RestoreOrCreateContainers() error {
|
||||
|
||||
errMSGFormat := "iptables: failed creating %s chain %s,error: %v"
|
||||
|
||||
err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
||||
if i.ipv4Client != nil {
|
||||
err := createChain(i.ipv4Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingForwardingChain, err)
|
||||
}
|
||||
|
||||
err = createChain(i.ipv4Client, iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
||||
}
|
||||
|
||||
err = i.restoreRules(i.ipv4Client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = createChain(i.ipv4Client, iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesRoutingNatChain, err)
|
||||
if i.ipv6Client != nil {
|
||||
err := createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
||||
}
|
||||
|
||||
err = createChain(i.ipv6Client, iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
||||
}
|
||||
|
||||
err = i.restoreRules(i.ipv6Client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = createChain(i.ipv6Client, iptablesFilterTable, iptablesRoutingForwardingChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingForwardingChain, err)
|
||||
}
|
||||
|
||||
err = createChain(i.ipv6Client, iptablesNatTable, iptablesRoutingNatChain)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesRoutingNatChain, err)
|
||||
}
|
||||
|
||||
err = i.restoreRules(i.ipv4Client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while restoring ipv4 rules: %v", err)
|
||||
}
|
||||
|
||||
err = i.restoreRules(i.ipv6Client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while restoring ipv6 rules: %v", err)
|
||||
}
|
||||
|
||||
err = i.addJumpRules()
|
||||
err := i.addJumpRules()
|
||||
if err != nil {
|
||||
return fmt.Errorf("iptables: error while creating jump rules: %v", err)
|
||||
}
|
||||
@@ -140,34 +170,38 @@ func (i *iptablesManager) addJumpRules() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rule := append(iptablesDefaultForwardingRule, ipv4Forwarding)
|
||||
err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
if i.ipv4Client != nil {
|
||||
rule := append(iptablesDefaultForwardingRule, ipv4Forwarding)
|
||||
|
||||
err = i.ipv4Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv4][ipv4Forwarding] = rule
|
||||
|
||||
rule = append(iptablesDefaultNatRule, ipv4Nat)
|
||||
err = i.ipv4Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv4][ipv4Nat] = rule
|
||||
}
|
||||
|
||||
i.rules[ipv4][ipv4Forwarding] = rule
|
||||
if i.ipv6Client != nil {
|
||||
rule := append(iptablesDefaultForwardingRule, ipv6Forwarding)
|
||||
err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv6][ipv6Forwarding] = rule
|
||||
|
||||
rule = append(iptablesDefaultNatRule, ipv4Nat)
|
||||
err = i.ipv4Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
rule = append(iptablesDefaultNatRule, ipv6Nat)
|
||||
err = i.ipv6Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv6][ipv6Nat] = rule
|
||||
}
|
||||
i.rules[ipv4][ipv4Nat] = rule
|
||||
|
||||
rule = append(iptablesDefaultForwardingRule, ipv6Forwarding)
|
||||
err = i.ipv6Client.Insert(iptablesFilterTable, iptablesForwardChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv6][ipv6Forwarding] = rule
|
||||
|
||||
rule = append(iptablesDefaultNatRule, ipv6Nat)
|
||||
err = i.ipv6Client.Insert(iptablesNatTable, iptablesPostRoutingChain, 1, rule...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i.rules[ipv6][ipv6Nat] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -177,35 +211,39 @@ func (i *iptablesManager) cleanJumpRules() error {
|
||||
var err error
|
||||
errMSGFormat := "iptables: failed cleaning rule from %s chain %s,err: %v"
|
||||
rule, found := i.rules[ipv4][ipv4Forwarding]
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding)
|
||||
err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesForwardChain, err)
|
||||
if i.ipv4Client != nil {
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Forwarding)
|
||||
err = i.ipv4Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesForwardChain, err)
|
||||
}
|
||||
}
|
||||
rule, found = i.rules[ipv4][ipv4Nat]
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Nat)
|
||||
err = i.ipv4Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
rule, found = i.rules[ipv4][ipv4Nat]
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv4, ipv4Nat)
|
||||
err = i.ipv4Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv4, iptablesPostRoutingChain, err)
|
||||
if i.ipv6Client == nil {
|
||||
rule, found = i.rules[ipv6][ipv6Forwarding]
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding)
|
||||
err = i.ipv6Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesForwardChain, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
rule, found = i.rules[ipv6][ipv6Forwarding]
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Forwarding)
|
||||
err = i.ipv6Client.DeleteIfExists(iptablesFilterTable, iptablesForwardChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesForwardChain, err)
|
||||
}
|
||||
}
|
||||
rule, found = i.rules[ipv6][ipv6Nat]
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Nat)
|
||||
err = i.ipv6Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err)
|
||||
rule, found = i.rules[ipv6][ipv6Nat]
|
||||
if found {
|
||||
log.Debugf("iptables: removing %s rule: %s ", ipv6, ipv6Nat)
|
||||
err = i.ipv6Client.DeleteIfExists(iptablesNatTable, iptablesPostRoutingChain, rule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf(errMSGFormat, ipv6, iptablesPostRoutingChain, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -437,3 +475,8 @@ func getIptablesRuleType(table string) string {
|
||||
}
|
||||
return ruleType
|
||||
}
|
||||
|
||||
func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
||||
_, err := client.ListChains("filter")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
@@ -16,17 +16,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
ipv4Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
ipv6Client, _ := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
||||
|
||||
manager := &iptablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
ipv4Client: ipv4Client,
|
||||
ipv6Client: ipv6Client,
|
||||
rules: make(map[string]map[string][]string),
|
||||
}
|
||||
manager := newIptablesManager(context.TODO())
|
||||
|
||||
defer manager.CleanRoutingRules()
|
||||
|
||||
@@ -37,21 +27,21 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
|
||||
require.Len(t, manager.rules[ipv4], 2, "should have created minimal rules for ipv4")
|
||||
|
||||
exists, err := ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...)
|
||||
exists, err := manager.ipv4Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv4][ipv4Forwarding]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesFilterTable, iptablesForwardChain)
|
||||
require.True(t, exists, "forwarding rule should exist")
|
||||
|
||||
exists, err = ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...)
|
||||
exists, err = manager.ipv4Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv4][ipv4Nat]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv4, iptablesNatTable, iptablesPostRoutingChain)
|
||||
require.True(t, exists, "postrouting rule should exist")
|
||||
|
||||
require.Len(t, manager.rules[ipv6], 2, "should have created minimal rules for ipv6")
|
||||
|
||||
exists, err = ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...)
|
||||
exists, err = manager.ipv6Client.Exists(iptablesFilterTable, iptablesForwardChain, manager.rules[ipv6][ipv6Forwarding]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesFilterTable, iptablesForwardChain)
|
||||
require.True(t, exists, "forwarding rule should exist")
|
||||
|
||||
exists, err = ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...)
|
||||
exists, err = manager.ipv6Client.Exists(iptablesNatTable, iptablesPostRoutingChain, manager.rules[ipv6][ipv6Nat]...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", ipv6, iptablesNatTable, iptablesPostRoutingChain)
|
||||
require.True(t, exists, "postrouting rule should exist")
|
||||
|
||||
@@ -64,13 +54,13 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
forward4RuleKey := genKey(forwardingFormat, pair.ID)
|
||||
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.source, pair.destination)
|
||||
|
||||
err = ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...)
|
||||
err = manager.ipv4Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
nat4RuleKey := genKey(natFormat, pair.ID)
|
||||
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.source, pair.destination)
|
||||
|
||||
err = ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...)
|
||||
err = manager.ipv4Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat4Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
pair = routerPair{
|
||||
@@ -83,13 +73,13 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
forward6RuleKey := genKey(forwardingFormat, pair.ID)
|
||||
forward6Rule := genRuleSpec(routingFinalForwardJump, forward6RuleKey, pair.source, pair.destination)
|
||||
|
||||
err = ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...)
|
||||
err = manager.ipv6Client.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forward6Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
nat6RuleKey := genKey(natFormat, pair.ID)
|
||||
nat6Rule := genRuleSpec(routingFinalNatJump, nat6RuleKey, pair.source, pair.destination)
|
||||
|
||||
err = ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...)
|
||||
err = manager.ipv6Client.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, nat6Rule...)
|
||||
require.NoError(t, err, "inserting rule should not return error")
|
||||
|
||||
delete(manager.rules, ipv4)
|
||||
|
||||
@@ -81,6 +81,25 @@ type nftablesManager struct {
|
||||
mux sync.Mutex
|
||||
}
|
||||
|
||||
func newNFTablesManager(parentCtx context.Context) (*nftablesManager, error) {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
|
||||
mgr := &nftablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
chains: make(map[string]map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
}
|
||||
|
||||
err := mgr.isSupported()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
// CleanRoutingRules cleans existing nftables rules from the system
|
||||
func (n *nftablesManager) CleanRoutingRules() {
|
||||
n.mux.Lock()
|
||||
@@ -386,6 +405,14 @@ func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *nftablesManager) isSupported() error {
|
||||
_, err := n.conn.ListChains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("nftables is not supported: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getPayloadDirectives get expression directives based on ip version and direction
|
||||
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
|
||||
switch {
|
||||
|
||||
@@ -14,21 +14,16 @@ import (
|
||||
|
||||
func TestNftablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
|
||||
manager := &nftablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
chains: make(map[string]map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
manager, err := newNFTablesManager(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create nftables manager: %s", err)
|
||||
}
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer manager.CleanRoutingRules()
|
||||
|
||||
err := manager.RestoreOrCreateContainers()
|
||||
err = manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
require.Len(t, manager.chains, 2, "should have created chains for ipv4 and ipv6")
|
||||
@@ -134,21 +129,16 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
|
||||
|
||||
for _, testCase := range insertRuleTestCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
|
||||
manager := &nftablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
chains: make(map[string]map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
manager, err := newNFTablesManager(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create nftables manager: %s", err)
|
||||
}
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer manager.CleanRoutingRules()
|
||||
|
||||
err := manager.RestoreOrCreateContainers()
|
||||
err = manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
err = manager.InsertRoutingRules(testCase.inputPair)
|
||||
@@ -239,21 +229,16 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
|
||||
|
||||
for _, testCase := range removeRuleTestCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.TODO())
|
||||
|
||||
manager := &nftablesManager{
|
||||
ctx: ctx,
|
||||
stop: cancel,
|
||||
conn: &nftables.Conn{},
|
||||
chains: make(map[string]map[string]*nftables.Chain),
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
manager, err := newNFTablesManager(context.TODO())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create nftables manager: %s", err)
|
||||
}
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
defer manager.CleanRoutingRules()
|
||||
|
||||
err := manager.RestoreOrCreateContainers()
|
||||
err = manager.RestoreOrCreateContainers()
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
table := manager.tableIPv4
|
||||
|
||||
@@ -102,7 +102,7 @@ func (s *Server) Start() error {
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := internal.RunClient(ctx, config, s.statusRecorder, nil, nil, nil); err != nil {
|
||||
if err := internal.RunClient(ctx, config, s.statusRecorder); err != nil {
|
||||
log.Errorf("init connections: %v", err)
|
||||
}
|
||||
}()
|
||||
@@ -391,7 +391,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := internal.RunClient(ctx, s.config, s.statusRecorder, nil, nil, nil); err != nil {
|
||||
if err := internal.RunClient(ctx, s.config, s.statusRecorder); err != nil {
|
||||
log.Errorf("run client connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,9 +2,6 @@ package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/creack/pty"
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
@@ -13,11 +10,22 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
|
||||
const DefaultSSHPort = 44338
|
||||
|
||||
// TerminalTimeout is the timeout for terminal session to be ready
|
||||
const TerminalTimeout = 10 * time.Second
|
||||
|
||||
// TerminalBackoffDelay is the delay between terminal session readiness checks
|
||||
const TerminalBackoffDelay = 500 * time.Millisecond
|
||||
|
||||
// DefaultSSHServer is a function that creates DefaultServer
|
||||
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
|
||||
return newDefaultServer(hostKeyPEM, addr)
|
||||
@@ -137,6 +145,8 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
||||
}
|
||||
}()
|
||||
|
||||
log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String())
|
||||
|
||||
localUser, err := userNameLookup(session.User())
|
||||
if err != nil {
|
||||
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
|
||||
@@ -172,6 +182,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("Login command: %s", cmd.String())
|
||||
file, err := pty.Start(cmd)
|
||||
if err != nil {
|
||||
log.Errorf("failed starting SSH server %v", err)
|
||||
@@ -199,6 +210,7 @@ func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
||||
return
|
||||
}
|
||||
}
|
||||
log.Debugf("SSH session ended")
|
||||
}
|
||||
|
||||
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
|
||||
@@ -206,17 +218,29 @@ func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
|
||||
// stdin
|
||||
_, err := io.Copy(file, session)
|
||||
if err != nil {
|
||||
_ = session.Exit(1)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
// stdout
|
||||
_, err := io.Copy(session, file)
|
||||
if err != nil {
|
||||
// AWS Linux 2 machines need some time to open the terminal so we need to wait for it
|
||||
timer := time.NewTimer(TerminalTimeout)
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
_, _ = session.Write([]byte("Reached timeout while opening connection\n"))
|
||||
_ = session.Exit(1)
|
||||
return
|
||||
default:
|
||||
// stdout
|
||||
writtenBytes, err := io.Copy(session, file)
|
||||
if err != nil && writtenBytes != 0 {
|
||||
_ = session.Exit(0)
|
||||
return
|
||||
}
|
||||
time.Sleep(TerminalBackoffDelay)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts SSH server. Blocking
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/pion/transport/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
@@ -34,7 +33,6 @@ func NewWGIFace(ifaceName string, address string, mtu int, tunAdapter TunAdapter
|
||||
func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
|
||||
return w.tun.Create(mIFaceArgs)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/pion/transport/v2"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// NewWGIFace Creates a new WireGuard interface instance
|
||||
@@ -38,6 +37,5 @@ func (w *WGIface) CreateOnMobile(mIFaceArgs MobileIFaceArguments) error {
|
||||
func (w *WGIface) Create() error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
log.Debugf("create WireGuard interface %s", w.tun.DeviceName())
|
||||
return w.tun.Create()
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ func newTunDevice(address WGAddress, mtu int, tunAdapter TunAdapter, transportNe
|
||||
}
|
||||
|
||||
func (t *tunDevice) Create(mIFaceArgs MobileIFaceArguments) error {
|
||||
log.Info("create tun interface")
|
||||
var err error
|
||||
routesString := t.routesToString(mIFaceArgs.Routes)
|
||||
t.fd, err = t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, mIFaceArgs.Dns, routesString)
|
||||
|
||||
@@ -12,14 +12,14 @@ import (
|
||||
|
||||
func (c *tunDevice) Create() error {
|
||||
if WireGuardModuleIsLoaded() {
|
||||
log.Info("using kernel WireGuard")
|
||||
log.Infof("create tun interface with kernel WireGuard support: %s", c.DeviceName())
|
||||
return c.createWithKernel()
|
||||
}
|
||||
|
||||
if !tunModuleIsLoaded() {
|
||||
return fmt.Errorf("couldn't check or load tun module")
|
||||
}
|
||||
log.Info("using userspace WireGuard")
|
||||
log.Infof("create tun interface with userspace WireGuard support: %s", c.DeviceName())
|
||||
var err error
|
||||
c.netInterface, err = c.createWithUserspace()
|
||||
if err != nil {
|
||||
|
||||
@@ -56,7 +56,7 @@ func (c *tunDevice) createWithUserspace() (NetInterface, error) {
|
||||
c.wrapper = newDeviceWrapper(tunIface)
|
||||
|
||||
// We need to create a wireguard-go device and listen to configuration requests
|
||||
tunDev := device.NewDevice(tunIface, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
|
||||
tunDev := device.NewDevice(c.wrapper, c.iceBind, device.NewLogger(device.LogLevelSilent, "[netbird] "))
|
||||
err = tunDev.Up()
|
||||
if err != nil {
|
||||
_ = tunIface.Close()
|
||||
|
||||
@@ -97,16 +97,9 @@ curl "${NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT}" -q -o openid-configuration.js
|
||||
|
||||
export NETBIRD_AUTH_AUTHORITY=$(jq -r '.issuer' openid-configuration.json)
|
||||
export NETBIRD_AUTH_JWT_CERTS=$(jq -r '.jwks_uri' openid-configuration.json)
|
||||
export NETBIRD_AUTH_SUPPORTED_SCOPES=$(jq -r '.scopes_supported | join(" ")' openid-configuration.json)
|
||||
export NETBIRD_AUTH_TOKEN_ENDPOINT=$(jq -r '.token_endpoint' openid-configuration.json)
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT=$(jq -r '.device_authorization_endpoint' openid-configuration.json)
|
||||
|
||||
if [ "$NETBIRD_USE_AUTH0" == "true" ]; then
|
||||
export NETBIRD_AUTH_SUPPORTED_SCOPES="openid profile email offline_access api email_verified"
|
||||
else
|
||||
export NETBIRD_AUTH_SUPPORTED_SCOPES="openid profile email offline_access api"
|
||||
fi
|
||||
|
||||
if [[ ! -z "${NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID}" ]]; then
|
||||
# user enabled Device Authorization Grant feature
|
||||
export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="hosted"
|
||||
|
||||
@@ -11,7 +11,10 @@ NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT=""
|
||||
NETBIRD_AUTH_AUDIENCE=""
|
||||
# e.g. netbird-client
|
||||
NETBIRD_AUTH_CLIENT_ID=""
|
||||
NETBIRD_AUTH_CLIENT_SECRET=""
|
||||
# indicates the scopes that will be requested to the IDP
|
||||
NETBIRD_AUTH_SUPPORTED_SCOPES=""
|
||||
# NETBIRD_AUTH_CLIENT_SECRET is required only by Google workspace.
|
||||
# NETBIRD_AUTH_CLIENT_SECRET=""
|
||||
# if you want to use a custom claim for the user ID instead of 'sub', set it here
|
||||
# NETBIRD_AUTH_USER_ID_CLAIM=""
|
||||
# indicates whether to use Auth0 or not: true or false
|
||||
|
||||
@@ -6,6 +6,7 @@ NETBIRD_DOMAIN=$CI_NETBIRD_DOMAIN
|
||||
NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT="https://example.eu.auth0.com/.well-known/openid-configuration"
|
||||
# e.g. netbird-client
|
||||
NETBIRD_AUTH_CLIENT_ID=$CI_NETBIRD_AUTH_CLIENT_ID
|
||||
NETBIRD_AUTH_SUPPORTED_SCOPES=$CI_NETBIRD_AUTH_SUPPORTED_SCOPES
|
||||
NETBIRD_AUTH_CLIENT_SECRET=$CI_NETBIRD_AUTH_CLIENT_SECRET
|
||||
# indicates whether to use Auth0 or not: true or false
|
||||
NETBIRD_USE_AUTH0=$CI_NETBIRD_USE_AUTH0
|
||||
|
||||
@@ -34,6 +34,8 @@ const (
|
||||
PublicCategory = "public"
|
||||
PrivateCategory = "private"
|
||||
UnknownCategory = "unknown"
|
||||
GroupIssuedAPI = "api"
|
||||
GroupIssuedJWT = "jwt"
|
||||
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||
@@ -51,6 +53,7 @@ type AccountManager interface {
|
||||
SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
||||
CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
|
||||
DeleteUser(accountID, initiatorUserID string, targetUserID string) error
|
||||
InviteUser(accountID string, initiatorUserID string, targetUserID string) error
|
||||
ListSetupKeys(accountID, userID string) ([]*SetupKey, error)
|
||||
SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error)
|
||||
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
|
||||
@@ -78,7 +81,7 @@ type AccountManager interface {
|
||||
GetGroup(accountId, groupID string) (*Group, error)
|
||||
SaveGroup(accountID, userID string, group *Group) error
|
||||
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
|
||||
DeleteGroup(accountId, groupID string) error
|
||||
DeleteGroup(accountId, userId, groupID string) error
|
||||
ListGroups(accountId string) ([]*Group, error)
|
||||
GroupAddPeer(accountId, groupID, peerID string) error
|
||||
GroupDeletePeer(accountId, groupID, peerKey string) error
|
||||
@@ -139,6 +142,13 @@ type Settings struct {
|
||||
// PeerLoginExpiration is a setting that indicates when peer login expires.
|
||||
// Applies to all peers that have Peer.LoginExpirationEnabled set to true.
|
||||
PeerLoginExpiration time.Duration
|
||||
|
||||
// JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName
|
||||
// and add it to account groups.
|
||||
JWTGroupsEnabled bool
|
||||
|
||||
// JWTGroupsClaimName from which we extract groups name to add it to account groups
|
||||
JWTGroupsClaimName string
|
||||
}
|
||||
|
||||
// Copy copies the Settings struct
|
||||
@@ -146,6 +156,8 @@ func (s *Settings) Copy() *Settings {
|
||||
return &Settings{
|
||||
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpiration: s.PeerLoginExpiration,
|
||||
JWTGroupsEnabled: s.JWTGroupsEnabled,
|
||||
JWTGroupsClaimName: s.JWTGroupsClaimName,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -612,6 +624,28 @@ func (a *Account) GetPeer(peerID string) *Peer {
|
||||
return a.Peers[peerID]
|
||||
}
|
||||
|
||||
// AddJWTGroups to existed groups if they does not exists
|
||||
func (a *Account) AddJWTGroups(groups []string) (int, error) {
|
||||
existedGroups := make(map[string]*Group)
|
||||
for _, g := range a.Groups {
|
||||
existedGroups[g.Name] = g
|
||||
}
|
||||
|
||||
var count int
|
||||
for _, name := range groups {
|
||||
if _, ok := existedGroups[name]; !ok {
|
||||
id := xid.New().String()
|
||||
a.Groups[id] = &Group{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Issued: GroupIssuedJWT,
|
||||
}
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
|
||||
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store,
|
||||
@@ -1241,6 +1275,38 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat
|
||||
}
|
||||
}
|
||||
|
||||
if account.Settings.JWTGroupsEnabled {
|
||||
if account.Settings.JWTGroupsClaimName == "" {
|
||||
log.Errorf("JWT groups are enabled but no claim name is set")
|
||||
return account, user, nil
|
||||
}
|
||||
if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
|
||||
if slice, ok := claim.([]interface{}); ok {
|
||||
var groups []string
|
||||
for _, item := range slice {
|
||||
if g, ok := item.(string); ok {
|
||||
groups = append(groups, g)
|
||||
} else {
|
||||
log.Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item)
|
||||
}
|
||||
}
|
||||
n, err := account.AddJWTGroups(groups)
|
||||
if err != nil {
|
||||
log.Errorf("failed to add JWT groups: %v", err)
|
||||
}
|
||||
if n > 0 {
|
||||
if err := am.Store.SaveAccount(account); err != nil {
|
||||
log.Errorf("failed to save account: %v", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName)
|
||||
}
|
||||
}
|
||||
|
||||
return account, user, nil
|
||||
}
|
||||
|
||||
@@ -1344,8 +1410,9 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
|
||||
func addAllGroup(account *Account) error {
|
||||
if len(account.Groups) == 0 {
|
||||
allGroup := &Group{
|
||||
ID: xid.New().String(),
|
||||
Name: "All",
|
||||
ID: xid.New().String(),
|
||||
Name: "All",
|
||||
Issued: GroupIssuedAPI,
|
||||
}
|
||||
for _, peer := range account.Peers {
|
||||
allGroup.Peers = append(allGroup.Peers, peer.ID)
|
||||
@@ -1373,33 +1440,28 @@ func addAllGroup(account *Account) error {
|
||||
}
|
||||
|
||||
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
||||
func newAccountWithId(accountId, userId, domain string) *Account {
|
||||
func newAccountWithId(accountID, userID, domain string) *Account {
|
||||
log.Debugf("creating new account")
|
||||
|
||||
setupKeys := make(map[string]*SetupKey)
|
||||
defaultKey := GenerateDefaultSetupKey()
|
||||
oneOffKey := GenerateSetupKey("One-off key", SetupKeyOneOff, DefaultSetupKeyDuration, []string{},
|
||||
SetupKeyUnlimitedUsage)
|
||||
setupKeys[defaultKey.Key] = defaultKey
|
||||
setupKeys[oneOffKey.Key] = oneOffKey
|
||||
network := NewNetwork()
|
||||
peers := make(map[string]*Peer)
|
||||
users := make(map[string]*User)
|
||||
routes := make(map[string]*route.Route)
|
||||
setupKeys := map[string]*SetupKey{}
|
||||
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
||||
users[userId] = NewAdminUser(userId)
|
||||
users[userID] = NewAdminUser(userID)
|
||||
dnsSettings := &DNSSettings{
|
||||
DisabledManagementGroups: make([]string, 0),
|
||||
}
|
||||
log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key)
|
||||
log.Debugf("created new account %s", accountID)
|
||||
|
||||
acc := &Account{
|
||||
Id: accountId,
|
||||
Id: accountID,
|
||||
SetupKeys: setupKeys,
|
||||
Network: network,
|
||||
Peers: peers,
|
||||
Users: users,
|
||||
CreatedBy: userId,
|
||||
CreatedBy: userID,
|
||||
Domain: domain,
|
||||
Routes: routes,
|
||||
NameServerGroups: nameServersGroups,
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -53,7 +54,7 @@ func verifyNewAccountHasDefaultFields(t *testing.T, account *Account, createdBy
|
||||
t.Errorf("expected account to have len(Peers) = %v, got %v", 0, len(account.Peers))
|
||||
}
|
||||
|
||||
if len(account.SetupKeys) != 2 {
|
||||
if len(account.SetupKeys) != 0 {
|
||||
t.Errorf("expected account to have len(SetupKeys) = %v, got %v", 2, len(account.SetupKeys))
|
||||
}
|
||||
|
||||
@@ -460,6 +461,69 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
|
||||
userId := "user-id"
|
||||
domain := "test.domain"
|
||||
|
||||
initAccount := newAccountWithId("", userId, domain)
|
||||
manager, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID := initAccount.Id
|
||||
_, err = manager.GetAccountByUserOrAccountID(userId, accountID, domain)
|
||||
require.NoError(t, err, "create init user failed")
|
||||
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
AccountId: accountID,
|
||||
Domain: domain,
|
||||
UserId: userId,
|
||||
DomainCategory: "test-category",
|
||||
Raw: jwt.MapClaims{"idp-groups": []interface{}{"group1", "group2"}},
|
||||
}
|
||||
|
||||
t.Run("JWT groups disabled", func(t *testing.T) {
|
||||
account, _, err := manager.GetAccountFromToken(claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
require.Len(t, account.Groups, 1, "only ALL group should exists")
|
||||
})
|
||||
|
||||
t.Run("JWT groups enabled without claim name", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
err := manager.Store.SaveAccount(initAccount)
|
||||
require.NoError(t, err, "save account failed")
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
|
||||
})
|
||||
|
||||
t.Run("JWT groups enabled", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
|
||||
err := manager.Store.SaveAccount(initAccount)
|
||||
require.NoError(t, err, "save account failed")
|
||||
|
||||
account, _, err := manager.GetAccountFromToken(claims)
|
||||
require.NoError(t, err, "get account by token failed")
|
||||
require.Len(t, account.Groups, 3, "groups should be added to the account")
|
||||
|
||||
groupsByNames := map[string]*Group{}
|
||||
for _, g := range account.Groups {
|
||||
groupsByNames[g.Name] = g
|
||||
}
|
||||
|
||||
g1, ok := groupsByNames["group1"]
|
||||
require.True(t, ok, "group1 should be added to the account")
|
||||
require.Equal(t, g1.Name, "group1", "group1 name should match")
|
||||
require.Equal(t, g1.Issued, GroupIssuedJWT, "group1 issued should match")
|
||||
|
||||
g2, ok := groupsByNames["group2"]
|
||||
require.True(t, ok, "group2 should be added to the account")
|
||||
require.Equal(t, g2.Name, "group2", "group2 name should match")
|
||||
require.Equal(t, g2.Issued, GroupIssuedJWT, "group2 issued should match")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccountManager_GetAccountFromPAT(t *testing.T) {
|
||||
store := newStore(t)
|
||||
account := newAccountWithId("account_id", "testuser", "")
|
||||
@@ -704,20 +768,21 @@ func TestAccountManager_AddPeer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
account, err := createAccount(manager, "test_account", "account_creator", "netbird.cloud")
|
||||
userID := "account_creator"
|
||||
account, err := createAccount(manager, "test_account", userID, "netbird.cloud")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
serial := account.Network.CurrentSerial() // should be 0
|
||||
|
||||
var setupKey *SetupKey
|
||||
for _, key := range account.SetupKeys {
|
||||
setupKey = key
|
||||
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if setupKey == nil {
|
||||
t.Errorf("expecting account to have a default setup key")
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -858,16 +923,13 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var setupKey *SetupKey
|
||||
for _, key := range account.SetupKeys {
|
||||
setupKey = key
|
||||
if setupKey.Type == SetupKeyReusable {
|
||||
break
|
||||
}
|
||||
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if setupKey == nil {
|
||||
t.Errorf("expecting account to have a default setup key")
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1021,7 +1083,10 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.DeleteGroup(account.Id, group.ID); err != nil {
|
||||
// clean policy is pre requirement for delete group
|
||||
_ = manager.DeletePolicy(account.Id, policy.ID, userID)
|
||||
|
||||
if err := manager.DeleteGroup(account.Id, "", group.ID); err != nil {
|
||||
t.Errorf("delete group: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1042,9 +1107,14 @@ func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var setupKey *SetupKey
|
||||
for _, key := range account.SetupKeys {
|
||||
setupKey = key
|
||||
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
}
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
|
||||
@@ -95,6 +95,8 @@ const (
|
||||
UserBlocked
|
||||
// UserUnblocked indicates that a user unblocked another user
|
||||
UserUnblocked
|
||||
// GroupDeleted indicates that a user deleted group
|
||||
GroupDeleted
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -192,6 +194,8 @@ const (
|
||||
UserBlockedMessage string = "User blocked"
|
||||
// UserUnblockedMessage is a human-readable text message of the UserUnblocked activity
|
||||
UserUnblockedMessage string = "User unblocked"
|
||||
// GroupDeletedMessage is a human-readable text message of the GroupDeleted activity
|
||||
GroupDeletedMessage string = "Group deleted"
|
||||
)
|
||||
|
||||
// Activity that triggered an Event
|
||||
@@ -294,6 +298,8 @@ func (a Activity) Message() string {
|
||||
return UserBlockedMessage
|
||||
case UserUnblocked:
|
||||
return UserUnblockedMessage
|
||||
case GroupDeleted:
|
||||
return GroupDeletedMessage
|
||||
default:
|
||||
return "UNKNOWN_ACTIVITY"
|
||||
}
|
||||
@@ -342,6 +348,8 @@ func (a Activity) StringCode() string {
|
||||
return "group.add"
|
||||
case GroupUpdated:
|
||||
return "group.update"
|
||||
case GroupDeleted:
|
||||
return "group.delete"
|
||||
case GroupRemovedFromPeer:
|
||||
return "peer.group.delete"
|
||||
case GroupAddedToPeer:
|
||||
|
||||
@@ -2,13 +2,15 @@ package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const defaultTTL = 300
|
||||
@@ -199,8 +201,10 @@ func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
|
||||
for _, gID := range nsGroup.Groups {
|
||||
_, found := groupList[gID]
|
||||
if found {
|
||||
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
|
||||
break
|
||||
if !peerIsNameserver(account.GetPeer(peerID), nsGroup) {
|
||||
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -208,6 +212,16 @@ func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
|
||||
return peerNSGroups
|
||||
}
|
||||
|
||||
// peerIsNameserver returns true if the peer is a nameserver for a nsGroup
|
||||
func peerIsNameserver(peer *Peer, nsGroup *nbdns.NameServerGroup) bool {
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
if peer.IP.Equal(ns.IP.AsSlice()) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func addPeerLabelsToAccount(account *Account, peerLabels lookupMap) {
|
||||
for _, peer := range account.Peers {
|
||||
label, err := getPeerHostLabel(peer.Name, peerLabels)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
@@ -17,6 +19,7 @@ const (
|
||||
dnsAccountID = "testingAcc"
|
||||
dnsAdminUserID = "testingAdminUser"
|
||||
dnsRegularUserID = "testingRegularUser"
|
||||
dnsNSGroup1 = "ns1"
|
||||
)
|
||||
|
||||
func TestGetDNSSettings(t *testing.T) {
|
||||
@@ -163,6 +166,7 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers")
|
||||
require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled")
|
||||
require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group")
|
||||
|
||||
dnsSettings := account.DNSSettings.Copy()
|
||||
dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID)
|
||||
@@ -174,11 +178,11 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group")
|
||||
require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group")
|
||||
|
||||
peer2AccountDNSConfig, err := am.GetNetworkMap(peer2.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group")
|
||||
require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group")
|
||||
require.Len(t, peer2AccountDNSConfig.DNSConfig.NameServerGroups, 1, "updated DNS config should have 1 nameserver groups since peer 2 is part of the group All")
|
||||
}
|
||||
|
||||
func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
@@ -246,7 +250,7 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, _, err = am.AddPeer("", dnsAdminUserID, peer1)
|
||||
savedPeer1, _, err := am.AddPeer("", dnsAdminUserID, peer1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -284,6 +288,24 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
||||
account.Groups[newGroup1.ID] = newGroup1
|
||||
account.Groups[newGroup2.ID] = newGroup2
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account.NameServerGroups[dnsNSGroup1] = &dns.NameServerGroup{
|
||||
ID: dnsNSGroup1,
|
||||
Name: "ns-group-1",
|
||||
NameServers: []dns.NameServer{{
|
||||
IP: netip.MustParseAddr(savedPeer1.IP.String()),
|
||||
NSType: dns.UDPNameServerType,
|
||||
Port: dns.DefaultDNSPort,
|
||||
}},
|
||||
Primary: true,
|
||||
Enabled: true,
|
||||
Groups: []string{allGroup.ID},
|
||||
}
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -157,6 +157,14 @@ func restore(file string) (*FileStore, error) {
|
||||
addPeerLabelsToAccount(account, existingLabels)
|
||||
}
|
||||
|
||||
// TODO: delete this block after migration
|
||||
// Set API as issuer for groups which has not this field
|
||||
for _, group := range account.Groups {
|
||||
if group.Issued == "" {
|
||||
group.Issued = GroupIssuedAPI
|
||||
}
|
||||
}
|
||||
|
||||
allGroup, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
log.Errorf("unable to find the All group, this should happen only when migrate from a version that didn't support groups. Error: %v", err)
|
||||
@@ -326,7 +334,7 @@ func (s *FileStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error {
|
||||
|
||||
delete(s.HashedPAT2TokenID, hashedToken)
|
||||
|
||||
return s.persist(s.storeFile)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteTokenID2UserIDIndex removes an entry from the indexing map TokenID2UserID
|
||||
@@ -336,7 +344,7 @@ func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
||||
|
||||
delete(s.TokenID2UserID, tokenID)
|
||||
|
||||
return s.persist(s.storeFile)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountByPrivateDomain returns account by private domain
|
||||
|
||||
@@ -262,6 +262,7 @@ func TestRestore(t *testing.T) {
|
||||
require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length")
|
||||
}
|
||||
|
||||
// TODO: outdated, delete this
|
||||
func TestRestorePolicies_Migration(t *testing.T) {
|
||||
storeDir := t.TempDir()
|
||||
|
||||
@@ -296,6 +297,40 @@ func TestRestorePolicies_Migration(t *testing.T) {
|
||||
"failed to restore a FileStore file - missing Account Policies Sources")
|
||||
}
|
||||
|
||||
func TestRestoreGroups_Migration(t *testing.T) {
|
||||
storeDir := t.TempDir()
|
||||
|
||||
err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
store, err := NewFileStore(storeDir, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// create default group
|
||||
account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
|
||||
account.Groups = map[string]*Group{
|
||||
"cfefqs706sqkneg59g3g": {
|
||||
ID: "cfefqs706sqkneg59g3g",
|
||||
Name: "All",
|
||||
},
|
||||
}
|
||||
err = store.SaveAccount(account)
|
||||
require.NoError(t, err, "failed to save account")
|
||||
|
||||
// restore account with default group with empty Issue field
|
||||
if store, err = NewFileStore(storeDir, nil); err != nil {
|
||||
return
|
||||
}
|
||||
account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"]
|
||||
|
||||
require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups")
|
||||
require.Equal(t, GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark")
|
||||
}
|
||||
|
||||
func TestGetAccountByPrivateDomain(t *testing.T) {
|
||||
storeDir := t.TempDir()
|
||||
|
||||
|
||||
@@ -1,11 +1,23 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type GroupLinkError struct {
|
||||
Resource string
|
||||
Name string
|
||||
}
|
||||
|
||||
func (e *GroupLinkError) Error() string {
|
||||
return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name)
|
||||
}
|
||||
|
||||
// Group of the peers for ACL
|
||||
type Group struct {
|
||||
// ID of the group
|
||||
@@ -14,6 +26,9 @@ type Group struct {
|
||||
// Name visible in the UI
|
||||
Name string
|
||||
|
||||
// Issued of the group
|
||||
Issued string
|
||||
|
||||
// Peers list of the group
|
||||
Peers []string
|
||||
}
|
||||
@@ -45,9 +60,10 @@ func (g *Group) EventMeta() map[string]any {
|
||||
|
||||
func (g *Group) Copy() *Group {
|
||||
return &Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Peers: g.Peers[:],
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Issued: g.Issued,
|
||||
Peers: g.Peers[:],
|
||||
}
|
||||
}
|
||||
|
||||
@@ -199,15 +215,80 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string,
|
||||
}
|
||||
|
||||
// DeleteGroup object of the peers
|
||||
func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountId)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
account, err := am.Store.GetAccount(accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
g, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// check route links
|
||||
for _, r := range account.Routes {
|
||||
for _, g := range r.Groups {
|
||||
if g == groupID {
|
||||
return &GroupLinkError{"route", r.NetID}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check DNS links
|
||||
for _, dns := range account.NameServerGroups {
|
||||
for _, g := range dns.Groups {
|
||||
if g == groupID {
|
||||
return &GroupLinkError{"name server groups", dns.Name}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check ACL links
|
||||
for _, policy := range account.Policies {
|
||||
for _, rule := range policy.Rules {
|
||||
for _, src := range rule.Sources {
|
||||
if src == groupID {
|
||||
return &GroupLinkError{"policy", policy.Name}
|
||||
}
|
||||
}
|
||||
|
||||
for _, dst := range rule.Destinations {
|
||||
if dst == groupID {
|
||||
return &GroupLinkError{"policy", policy.Name}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check setup key links
|
||||
for _, setupKey := range account.SetupKeys {
|
||||
for _, grp := range setupKey.AutoGroups {
|
||||
if grp == groupID {
|
||||
return &GroupLinkError{"setup key", setupKey.Name}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check user links
|
||||
for _, user := range account.Users {
|
||||
for _, grp := range user.AutoGroups {
|
||||
if grp == groupID {
|
||||
return &GroupLinkError{"user", user.Id}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check DisabledManagementGroups
|
||||
for _, disabledMgmGrp := range account.DNSSettings.DisabledManagementGroups {
|
||||
if disabledMgmGrp == groupID {
|
||||
return &GroupLinkError{"disabled DNS management groups", g.Name}
|
||||
}
|
||||
}
|
||||
|
||||
delete(account.Groups, groupID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
@@ -215,6 +296,8 @@ func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
am.storeEvent(userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||
|
||||
return am.updateAccountPeers(account)
|
||||
}
|
||||
|
||||
|
||||
164
management/server/group_test.go
Normal file
164
management/server/group_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const (
|
||||
groupAdminUserID = "testingAdminUser"
|
||||
)
|
||||
|
||||
func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
|
||||
am, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestGroupAccount(am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
groupID string
|
||||
expectedReason string
|
||||
}{
|
||||
{
|
||||
"route",
|
||||
"grp-for-route",
|
||||
"route",
|
||||
},
|
||||
{
|
||||
"name server groups",
|
||||
"grp-for-name-server-grp",
|
||||
"name server groups",
|
||||
},
|
||||
{
|
||||
"policy",
|
||||
"grp-for-policies",
|
||||
"policy",
|
||||
},
|
||||
{
|
||||
"setup keys",
|
||||
"grp-for-keys",
|
||||
"setup key",
|
||||
},
|
||||
{
|
||||
"users",
|
||||
"grp-for-users",
|
||||
"user",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
err = am.DeleteGroup(account.Id, "", testCase.groupID)
|
||||
if err == nil {
|
||||
t.Errorf("delete %s group successfully", testCase.groupID)
|
||||
return
|
||||
}
|
||||
|
||||
gErr, ok := err.(*GroupLinkError)
|
||||
if !ok {
|
||||
t.Error("invalid error type")
|
||||
return
|
||||
}
|
||||
if gErr.Resource != testCase.expectedReason {
|
||||
t.Errorf("invalid error case: %s, expected: %s", gErr.Resource, testCase.expectedReason)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
accountID := "testingAcc"
|
||||
domain := "example.com"
|
||||
|
||||
groupForRoute := &Group{
|
||||
"grp-for-route",
|
||||
"Group for route",
|
||||
GroupIssuedAPI,
|
||||
make([]string, 0),
|
||||
}
|
||||
|
||||
groupForNameServerGroups := &Group{
|
||||
"grp-for-name-server-grp",
|
||||
"Group for name server groups",
|
||||
GroupIssuedAPI,
|
||||
make([]string, 0),
|
||||
}
|
||||
|
||||
groupForPolicies := &Group{
|
||||
"grp-for-policies",
|
||||
"Group for policies",
|
||||
GroupIssuedAPI,
|
||||
make([]string, 0),
|
||||
}
|
||||
|
||||
groupForSetupKeys := &Group{
|
||||
"grp-for-keys",
|
||||
"Group for setup keys",
|
||||
GroupIssuedAPI,
|
||||
make([]string, 0),
|
||||
}
|
||||
|
||||
groupForUsers := &Group{
|
||||
"grp-for-users",
|
||||
"Group for users",
|
||||
GroupIssuedAPI,
|
||||
make([]string, 0),
|
||||
}
|
||||
|
||||
routeResource := &route.Route{
|
||||
ID: "example route",
|
||||
Groups: []string{groupForRoute.ID},
|
||||
}
|
||||
|
||||
nameServerGroup := &nbdns.NameServerGroup{
|
||||
ID: "example name server group",
|
||||
Groups: []string{groupForNameServerGroups.ID},
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
ID: "example policy",
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "example policy rule",
|
||||
Destinations: []string{groupForPolicies.ID},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
setupKey := &SetupKey{
|
||||
Id: "example setup key",
|
||||
AutoGroups: []string{groupForSetupKeys.ID},
|
||||
}
|
||||
|
||||
user := &User{
|
||||
Id: "example user",
|
||||
AutoGroups: []string{groupForUsers.ID},
|
||||
}
|
||||
account := newAccountWithId(accountID, groupAdminUserID, domain)
|
||||
account.Routes[routeResource.ID] = routeResource
|
||||
account.NameServerGroups[nameServerGroup.ID] = nameServerGroup
|
||||
account.Policies = append(account.Policies, policy)
|
||||
account.SetupKeys[setupKey.Id] = setupKey
|
||||
account.Users[user.Id] = user
|
||||
|
||||
err := am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForRoute)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForNameServerGroups)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForPolicies)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForSetupKeys)
|
||||
_ = am.SaveGroup(accountID, groupAdminUserID, groupForUsers)
|
||||
|
||||
return am.Store.GetAccount(account.Id)
|
||||
}
|
||||
@@ -72,10 +72,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, &server.Settings{
|
||||
settings := &server.Settings{
|
||||
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
||||
})
|
||||
}
|
||||
|
||||
if req.Settings.JwtGroupsEnabled != nil {
|
||||
settings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled
|
||||
}
|
||||
if req.Settings.JwtGroupsClaimName != nil {
|
||||
settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
|
||||
}
|
||||
|
||||
updatedAccount, err := h.accountManager.UpdateAccountSettings(accountID, user.Id, settings)
|
||||
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
@@ -93,6 +102,8 @@ func toAccountResponse(account *server.Account) *api.Account {
|
||||
Settings: api.AccountSettings{
|
||||
PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
|
||||
PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
|
||||
JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled,
|
||||
JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,6 +58,9 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
|
||||
sr := func(v string) *string { return &v }
|
||||
br := func(v bool) *bool { return &v }
|
||||
|
||||
handler := initAccountsTestData(&server.Account{
|
||||
Id: accountID,
|
||||
Domain: "hotmail.com",
|
||||
@@ -91,6 +94,8 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
expectedSettings: api.AccountSettings{
|
||||
PeerLoginExpiration: int(time.Hour.Seconds()),
|
||||
PeerLoginExpirationEnabled: false,
|
||||
JwtGroupsClaimName: sr(""),
|
||||
JwtGroupsEnabled: br(false),
|
||||
},
|
||||
expectedArray: true,
|
||||
expectedID: accountID,
|
||||
@@ -105,6 +110,24 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
expectedSettings: api.AccountSettings{
|
||||
PeerLoginExpiration: 15552000,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
JwtGroupsClaimName: sr(""),
|
||||
JwtGroupsEnabled: br(false),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
},
|
||||
{
|
||||
name: "PutAccount OK wiht JWT",
|
||||
expectedBody: true,
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/accounts/" + accountID,
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\"}}"),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedSettings: api.AccountSettings{
|
||||
PeerLoginExpiration: 15552000,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
JwtGroupsClaimName: sr("roles"),
|
||||
JwtGroupsEnabled: br(true),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
|
||||
@@ -54,6 +54,14 @@ components:
|
||||
description: Period of time after which peer login expires (seconds).
|
||||
type: integer
|
||||
example: 43200
|
||||
jwt_groups_enabled:
|
||||
description: Allows extract groups from JWT claim and add it to account groups.
|
||||
type: boolean
|
||||
example: true
|
||||
jwt_groups_claim_name:
|
||||
description: Name of the claim from which we extract groups names to add it to account groups.
|
||||
type: string
|
||||
example: "roles"
|
||||
required:
|
||||
- peer_login_expiration_enabled
|
||||
- peer_login_expiration
|
||||
@@ -462,6 +470,10 @@ components:
|
||||
description: Count of peers associated to the group
|
||||
type: integer
|
||||
example: 2
|
||||
issued:
|
||||
description: How group was issued by API or from JWT token
|
||||
type: string
|
||||
example: api
|
||||
required:
|
||||
- id
|
||||
- name
|
||||
@@ -1262,6 +1274,33 @@ paths:
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/users/{userId}/invite:
|
||||
post:
|
||||
summary: Resend user invitation
|
||||
description: Resend user invitation
|
||||
tags: [ Users ]
|
||||
security:
|
||||
- BearerAuth: [ ]
|
||||
- TokenAuth: [ ]
|
||||
parameters:
|
||||
- in: path
|
||||
name: userId
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
description: The unique identifier of a user
|
||||
responses:
|
||||
'200':
|
||||
description: Invite status code
|
||||
content: {}
|
||||
'400':
|
||||
"$ref": "#/components/responses/bad_request"
|
||||
'401':
|
||||
"$ref": "#/components/responses/requires_authentication"
|
||||
'403':
|
||||
"$ref": "#/components/responses/forbidden"
|
||||
'500':
|
||||
"$ref": "#/components/responses/internal_error"
|
||||
/api/peers:
|
||||
get:
|
||||
summary: List all Peers
|
||||
|
||||
@@ -129,6 +129,12 @@ type AccountRequest struct {
|
||||
|
||||
// AccountSettings defines model for AccountSettings.
|
||||
type AccountSettings struct {
|
||||
// JwtGroupsClaimName Name of the claim from which we extract groups names to add it to account groups.
|
||||
JwtGroupsClaimName *string `json:"jwt_groups_claim_name,omitempty"`
|
||||
|
||||
// JwtGroupsEnabled Allows extract groups from JWT claim and add it to account groups.
|
||||
JwtGroupsEnabled *bool `json:"jwt_groups_enabled,omitempty"`
|
||||
|
||||
// PeerLoginExpiration Period of time after which peer login expires (seconds).
|
||||
PeerLoginExpiration int `json:"peer_login_expiration"`
|
||||
|
||||
@@ -174,6 +180,9 @@ type Group struct {
|
||||
// Id Group ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// Issued How group was issued by API or from JWT token
|
||||
Issued *string `json:"issued,omitempty"`
|
||||
|
||||
// Name Group Name identifier
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -189,6 +198,9 @@ type GroupMinimum struct {
|
||||
// Id Group ID
|
||||
Id string `json:"id"`
|
||||
|
||||
// Issued How group was issued by API or from JWT token
|
||||
Issued *string `json:"issued,omitempty"`
|
||||
|
||||
// Name Group Name identifier
|
||||
Name string `json:"name"`
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
_, ok = account.Groups[groupID]
|
||||
eg, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
util.WriteError(status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
|
||||
return
|
||||
@@ -107,9 +107,10 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
peers = *req.Peers
|
||||
}
|
||||
group := server.Group{
|
||||
ID: groupID,
|
||||
Name: req.Name,
|
||||
Peers: peers,
|
||||
ID: groupID,
|
||||
Name: req.Name,
|
||||
Peers: peers,
|
||||
Issued: eg.Issued,
|
||||
}
|
||||
|
||||
if err := h.accountManager.SaveGroup(account.Id, user.Id, &group); err != nil {
|
||||
@@ -149,9 +150,10 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
peers = *req.Peers
|
||||
}
|
||||
group := server.Group{
|
||||
ID: xid.New().String(),
|
||||
Name: req.Name,
|
||||
Peers: peers,
|
||||
ID: xid.New().String(),
|
||||
Name: req.Name,
|
||||
Peers: peers,
|
||||
Issued: server.GroupIssuedAPI,
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveGroup(account.Id, user.Id, &group)
|
||||
@@ -166,7 +168,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
// DeleteGroup handles group deletion request
|
||||
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, _, err := h.accountManager.GetAccountFromToken(claims)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
@@ -190,8 +192,13 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteGroup(aID, groupID)
|
||||
err = h.accountManager.DeleteGroup(aID, user.Id, groupID)
|
||||
if err != nil {
|
||||
_, ok := err.(*server.GroupLinkError)
|
||||
if ok {
|
||||
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
@@ -237,6 +244,7 @@ func toGroupResponse(account *server.Account, group *server.Group) *api.Group {
|
||||
Id: group.ID,
|
||||
Name: group.Name,
|
||||
PeersCount: len(group.Peers),
|
||||
Issued: &group.Issued,
|
||||
}
|
||||
|
||||
for _, pid := range group.Peers {
|
||||
|
||||
@@ -11,17 +11,15 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
|
||||
"github.com/magiconair/properties/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
var TestPeers = map[string]*server.Peer{
|
||||
@@ -42,9 +40,17 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle
|
||||
if groupID != "idofthegroup" {
|
||||
return nil, status.Errorf(status.NotFound, "not found")
|
||||
}
|
||||
if groupID == "id-jwt-group" {
|
||||
return &server.Group{
|
||||
ID: "id-jwt-group",
|
||||
Name: "Default Group",
|
||||
Issued: server.GroupIssuedJWT,
|
||||
}, nil
|
||||
}
|
||||
return &server.Group{
|
||||
ID: "idofthegroup",
|
||||
Name: "Group",
|
||||
ID: "idofthegroup",
|
||||
Name: "Group",
|
||||
Issued: server.GroupIssuedAPI,
|
||||
}, nil
|
||||
},
|
||||
UpdateGroupFunc: func(_ string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) {
|
||||
@@ -80,11 +86,24 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle
|
||||
user.Id: user,
|
||||
},
|
||||
Groups: map[string]*server.Group{
|
||||
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
|
||||
"id-all": {ID: "id-all", Name: "All"},
|
||||
"id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: server.GroupIssuedJWT},
|
||||
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: server.GroupIssuedAPI},
|
||||
"id-all": {ID: "id-all", Name: "All", Issued: server.GroupIssuedAPI},
|
||||
},
|
||||
}, user, nil
|
||||
},
|
||||
DeleteGroupFunc: func(accountID, userId, groupID string) error {
|
||||
if groupID == "linked-grp" {
|
||||
return &server.GroupLinkError{
|
||||
Resource: "something",
|
||||
Name: "linked-grp",
|
||||
}
|
||||
}
|
||||
if groupID == "invalid-grp" {
|
||||
return fmt.Errorf("internal error")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
@@ -169,6 +188,8 @@ func TestGetGroup(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWriteGroup(t *testing.T) {
|
||||
groupIssuedAPI := "api"
|
||||
groupIssuedJWT := "jwt"
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
@@ -187,8 +208,9 @@ func TestWriteGroup(t *testing.T) {
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedGroup: &api.Group{
|
||||
Id: "id-was-set",
|
||||
Name: "Default POSTed Group",
|
||||
Id: "id-was-set",
|
||||
Name: "Default POSTed Group",
|
||||
Issued: &groupIssuedAPI,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -208,8 +230,9 @@ func TestWriteGroup(t *testing.T) {
|
||||
[]byte(`{"Name":"Default POSTed Group"}`)),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedGroup: &api.Group{
|
||||
Id: "id-existed",
|
||||
Name: "Default POSTed Group",
|
||||
Id: "id-existed",
|
||||
Name: "Default POSTed Group",
|
||||
Issued: &groupIssuedAPI,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -230,6 +253,19 @@ func TestWriteGroup(t *testing.T) {
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedBody: false,
|
||||
},
|
||||
{
|
||||
name: "Write Group PUT not not change Issue",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/groups/id-jwt-group",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(`{"Name":"changed","Issued":"api"}`)),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedGroup: &api.Group{
|
||||
Id: "id-jwt-group",
|
||||
Name: "changed",
|
||||
Issued: &groupIssuedJWT,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
@@ -271,3 +307,79 @@ func TestWriteGroup(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteGroup(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
expectedBody bool
|
||||
requestType string
|
||||
requestPath string
|
||||
}{
|
||||
{
|
||||
name: "Try to delete linked group",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/groups/linked-grp",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: true,
|
||||
},
|
||||
{
|
||||
name: "Try to cause internal error",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/groups/invalid-grp",
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: true,
|
||||
},
|
||||
{
|
||||
name: "Try to cause internal error",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/groups/invalid-grp",
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: true,
|
||||
},
|
||||
{
|
||||
name: "Delete group",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/groups/any-grp",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: false,
|
||||
},
|
||||
}
|
||||
|
||||
adminUser := server.NewAdminUser("test_user")
|
||||
p := initGroupTestData(adminUser)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/groups/{groupId}", p.DeleteGroup).Methods("DELETE")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
content, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("I don't know what I expected; %v", err)
|
||||
}
|
||||
|
||||
if status := recorder.Code; status != tc.expectedStatus {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
|
||||
status, tc.expectedStatus, string(content))
|
||||
return
|
||||
}
|
||||
|
||||
if tc.expectedBody {
|
||||
got := &util.ErrorResponse{}
|
||||
|
||||
if err = json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
assert.Equal(t, got.Code, tc.expectedStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,6 +113,7 @@ func (apiHandler *apiHandler) addUsersEndpoint() {
|
||||
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
|
||||
apiHandler.Router.HandleFunc("/users/{userId}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS")
|
||||
apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS")
|
||||
apiHandler.Router.HandleFunc("/users/{userId}/invite", userHandler.InviteUser).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
func (apiHandler *apiHandler) addUsersTokensEndpoint() {
|
||||
|
||||
@@ -208,6 +208,37 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(w, users)
|
||||
}
|
||||
|
||||
// InviteUser resend invitations to users who haven't activated their accounts,
|
||||
// prior to the expiration period.
|
||||
func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
account, user, err := h.accountManager.GetAccountFromToken(claims)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.InviteUser(account.Id, user.Id, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, emptyObject{})
|
||||
}
|
||||
|
||||
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
||||
autoGroups := user.AutoGroups
|
||||
if autoGroups == nil {
|
||||
|
||||
@@ -98,6 +98,17 @@ func initUsersTestData() *UsersHandler {
|
||||
}
|
||||
return info, nil
|
||||
},
|
||||
InviteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error {
|
||||
if initiatorUserID != existingUserID {
|
||||
return status.Errorf(status.NotFound, "user with ID %s does not exists", initiatorUserID)
|
||||
}
|
||||
|
||||
if targetUserID == notFoundUserID {
|
||||
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
@@ -340,6 +351,51 @@ func TestCreateUser(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInviteUser(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
requestType string
|
||||
requestPath string
|
||||
requestVars map[string]string
|
||||
}{
|
||||
{
|
||||
name: "Invite User with Existing User",
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/users/" + existingUserID + "/invite",
|
||||
expectedStatus: http.StatusOK,
|
||||
requestVars: map[string]string{"userId": existingUserID},
|
||||
},
|
||||
{
|
||||
name: "Invite User with missing user_id",
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/users/" + notFoundUserID + "/invite",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
requestVars: map[string]string{"userId": notFoundUserID},
|
||||
},
|
||||
}
|
||||
|
||||
userHandler := initUsersTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = mux.SetURLVars(req, tc.requestVars)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.InviteUser(rr, req)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if status := rr.Code; status != tc.expectedStatus {
|
||||
t.Fatalf("handler returned wrong status code: got %v want %v",
|
||||
status, tc.expectedStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUser(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
|
||||
@@ -13,6 +13,11 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
type ErrorResponse struct {
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code"`
|
||||
}
|
||||
|
||||
// WriteJSONObject simply writes object to the HTTP reponse in JSON format
|
||||
func WriteJSONObject(w http.ResponseWriter, obj interface{}) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -58,14 +63,9 @@ func (d *Duration) UnmarshalJSON(b []byte) error {
|
||||
|
||||
// WriteErrorResponse prepares and writes an error response i nJSON
|
||||
func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) {
|
||||
type errorResponse struct {
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code"`
|
||||
}
|
||||
|
||||
w.WriteHeader(httpStatus)
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
err := json.NewEncoder(w).Encode(&errorResponse{
|
||||
err := json.NewEncoder(w).Encode(&ErrorResponse{
|
||||
Message: errMsg,
|
||||
Code: httpStatus,
|
||||
})
|
||||
|
||||
@@ -98,6 +98,11 @@ type userExportJobStatusResponse struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// userVerificationJobRequest is a user verification request struct
|
||||
type userVerificationJobRequest struct {
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// auth0Profile represents an Auth0 user profile response
|
||||
type auth0Profile struct {
|
||||
AccountID string `json:"wt_account_id"`
|
||||
@@ -689,6 +694,48 @@ func (am *Auth0Manager) CreateUser(email string, name string, accountID string)
|
||||
return &createResp, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (am *Auth0Manager) InviteUserByID(userID string) error {
|
||||
userVerificationReq := userVerificationJobRequest{
|
||||
UserID: userID,
|
||||
}
|
||||
|
||||
payload, err := am.helper.Marshal(userVerificationReq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := am.createPostRequest("/api/v2/jobs/verification-email", string(payload))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Debugf("Couldn't get job response %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("error while closing invite user response body: %v", err)
|
||||
}
|
||||
}()
|
||||
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to invite user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
|
||||
// If the status is "completed", then return the downloadLink
|
||||
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) {
|
||||
|
||||
@@ -440,6 +440,12 @@ func (am *AuthentikManager) GetUserByEmail(email string) ([]*UserData, error) {
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (am *AuthentikManager) InviteUserByID(_ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
func (am *AuthentikManager) authenticationContext() (context.Context, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
if err != nil {
|
||||
|
||||
@@ -448,6 +448,12 @@ func (am *AzureManager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (am *AzureManager) InviteUserByID(_ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
|
||||
q := url.Values{}
|
||||
q.Add("$select", extensionFields)
|
||||
|
||||
@@ -247,6 +247,12 @@ func (gm *GoogleWorkspaceManager) GetUserByEmail(email string) ([]*UserData, err
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (gm *GoogleWorkspaceManager) InviteUserByID(_ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey.
|
||||
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
|
||||
// If that fails, it falls back to using the default Google credentials path.
|
||||
|
||||
@@ -17,6 +17,7 @@ type Manager interface {
|
||||
GetAllAccounts() (map[string][]*UserData, error)
|
||||
CreateUser(email string, name string, accountID string) (*UserData, error)
|
||||
GetUserByEmail(email string) ([]*UserData, error)
|
||||
InviteUserByID(userID string) error
|
||||
}
|
||||
|
||||
// ClientConfig defines common client configuration for all IdP manager
|
||||
|
||||
@@ -461,6 +461,12 @@ func (km *KeycloakManager) UpdateUserAppMetadata(userID string, appMetadata AppM
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (km *KeycloakManager) InviteUserByID(_ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) {
|
||||
attrs := keycloakUserAttributes{}
|
||||
attrs.Set(wtAccountID, appMetadata.WTAccountID)
|
||||
|
||||
@@ -270,21 +270,32 @@ func (om *OktaManager) GetAllAccounts() (map[string][]*UserData, error) {
|
||||
|
||||
// UpdateUserAppMetadata updates user app metadata based on userID and metadata map.
|
||||
func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetadata) error {
|
||||
var pendingInvite bool
|
||||
if appMetadata.WTPendingInvite != nil {
|
||||
pendingInvite = *appMetadata.WTPendingInvite
|
||||
user, resp, err := om.client.User.GetUser(context.Background(), userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, resp, err := om.client.User.UpdateUser(context.Background(), userID,
|
||||
okta.User{
|
||||
Profile: &okta.UserProfile{
|
||||
wtAccountID: appMetadata.WTAccountID,
|
||||
wtPendingInvite: pendingInvite,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to update user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
profile := *user.Profile
|
||||
|
||||
if appMetadata.WTPendingInvite != nil {
|
||||
profile[wtPendingInvite] = *appMetadata.WTPendingInvite
|
||||
}
|
||||
|
||||
if appMetadata.WTAccountID != "" {
|
||||
profile[wtAccountID] = appMetadata.WTAccountID
|
||||
}
|
||||
|
||||
user.Profile = &profile
|
||||
_, resp, err = om.client.User.UpdateUser(context.Background(), userID, *user, nil)
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -302,10 +313,18 @@ func (om *OktaManager) UpdateUserAppMetadata(userID string, appMetadata AppMetad
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (om *OktaManager) InviteUserByID(_ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// updateUserProfileSchema updates the Okta user schema to include custom fields,
|
||||
// wt_account_id and wt_pending_invite.
|
||||
func updateUserProfileSchema(client *okta.Client) error {
|
||||
required := true
|
||||
// Ensure Okta doesn't enforce user input for these fields, as they are solely used by Netbird
|
||||
userPermissions := []*okta.UserSchemaAttributePermission{{Action: "HIDE", Principal: "SELF"}}
|
||||
|
||||
_, resp, err := client.UserSchema.UpdateUserProfile(
|
||||
context.Background(),
|
||||
"default",
|
||||
@@ -316,18 +335,20 @@ func updateUserProfileSchema(client *okta.Client) error {
|
||||
Type: "object",
|
||||
Properties: map[string]*okta.UserSchemaAttribute{
|
||||
wtAccountID: {
|
||||
MaxLength: 100,
|
||||
MinLength: 1,
|
||||
Required: &required,
|
||||
Scope: "NONE",
|
||||
Title: "Wt Account Id",
|
||||
Type: "string",
|
||||
MaxLength: 100,
|
||||
MinLength: 1,
|
||||
Required: new(bool),
|
||||
Scope: "NONE",
|
||||
Title: "Wt Account Id",
|
||||
Type: "string",
|
||||
Permissions: userPermissions,
|
||||
},
|
||||
wtPendingInvite: {
|
||||
Required: new(bool),
|
||||
Scope: "NONE",
|
||||
Title: "Wt Pending Invite",
|
||||
Type: "boolean",
|
||||
Required: new(bool),
|
||||
Scope: "NONE",
|
||||
Title: "Wt Pending Invite",
|
||||
Type: "boolean",
|
||||
Permissions: userPermissions,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -441,6 +441,12 @@ func (zm *ZitadelManager) UpdateUserAppMetadata(userID string, appMetadata AppMe
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUserByID resend invitations to users who haven't activated,
|
||||
// their accounts prior to the expiration period.
|
||||
func (zm *ZitadelManager) InviteUserByID(_ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// getUserMetadata requests user metadata from zitadel via ID.
|
||||
func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) {
|
||||
resource := fmt.Sprintf("users/%s/metadata/_search", userID)
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
package jwtclaims
|
||||
|
||||
import (
|
||||
"github.com/golang-jwt/jwt"
|
||||
)
|
||||
|
||||
// AuthorizationClaims stores authorization information from JWTs
|
||||
type AuthorizationClaims struct {
|
||||
UserId string
|
||||
AccountId string
|
||||
Domain string
|
||||
DomainCategory string
|
||||
|
||||
Raw jwt.MapClaims
|
||||
}
|
||||
|
||||
@@ -73,7 +73,9 @@ func NewClaimsExtractor(options ...ClaimsExtractorOption) *ClaimsExtractor {
|
||||
// FromToken extracts claims from the token (after auth)
|
||||
func (c *ClaimsExtractor) FromToken(token *jwt.Token) AuthorizationClaims {
|
||||
claims := token.Claims.(jwt.MapClaims)
|
||||
jwtClaims := AuthorizationClaims{}
|
||||
jwtClaims := AuthorizationClaims{
|
||||
Raw: claims,
|
||||
}
|
||||
userID, ok := claims[c.userIDClaim].(string)
|
||||
if !ok {
|
||||
return jwtClaims
|
||||
|
||||
@@ -48,6 +48,12 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
Domain: "test.com",
|
||||
AccountId: "testAcc",
|
||||
DomainCategory: "public",
|
||||
Raw: jwt.MapClaims{
|
||||
"https://login/wt_account_domain": "test.com",
|
||||
"https://login/wt_account_domain_category": "public",
|
||||
"https://login/wt_account_id": "testAcc",
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
@@ -59,6 +65,10 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
AccountId: "testAcc",
|
||||
Raw: jwt.MapClaims{
|
||||
"https://login/wt_account_id": "testAcc",
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
@@ -70,6 +80,10 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
Domain: "test.com",
|
||||
Raw: jwt.MapClaims{
|
||||
"https://login/wt_account_domain": "test.com",
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
@@ -82,6 +96,11 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
UserId: "test",
|
||||
Domain: "test.com",
|
||||
AccountId: "testAcc",
|
||||
Raw: jwt.MapClaims{
|
||||
"https://login/wt_account_domain": "test.com",
|
||||
"https://login/wt_account_id": "testAcc",
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
@@ -92,6 +111,9 @@ func TestExtractClaimsFromRequestContext(t *testing.T) {
|
||||
inputAudiance: "https://login/",
|
||||
inputAuthorizationClaims: AuthorizationClaims{
|
||||
UserId: "test",
|
||||
Raw: jwt.MapClaims{
|
||||
"sub": "test",
|
||||
},
|
||||
},
|
||||
testingFunc: require.EqualValues,
|
||||
expectedMSG: "extracted claims should match input claims",
|
||||
|
||||
@@ -32,7 +32,7 @@ type MockAccountManager struct {
|
||||
GetGroupFunc func(accountID, groupID string) (*server.Group, error)
|
||||
SaveGroupFunc func(accountID, userID string, group *server.Group) error
|
||||
UpdateGroupFunc func(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error)
|
||||
DeleteGroupFunc func(accountID, groupID string) error
|
||||
DeleteGroupFunc func(accountID, userId, groupID string) error
|
||||
ListGroupsFunc func(accountID string) ([]*server.Group, error)
|
||||
GroupAddPeerFunc func(accountID, groupID, peerKey string) error
|
||||
GroupDeletePeerFunc func(accountID, groupID, peerKey string) error
|
||||
@@ -81,6 +81,7 @@ type MockAccountManager struct {
|
||||
UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error)
|
||||
LoginPeerFunc func(login server.PeerLogin) (*server.Peer, *server.NetworkMap, error)
|
||||
SyncPeerFunc func(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error)
|
||||
InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
}
|
||||
|
||||
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
|
||||
@@ -274,9 +275,9 @@ func (am *MockAccountManager) UpdateGroup(accountID string, groupID string, oper
|
||||
}
|
||||
|
||||
// DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeleteGroup(accountID, groupID string) error {
|
||||
func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
||||
if am.DeleteGroupFunc != nil {
|
||||
return am.DeleteGroupFunc(accountID, groupID)
|
||||
return am.DeleteGroupFunc(accountId, userId, groupID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteGroup is not implemented")
|
||||
}
|
||||
@@ -500,6 +501,13 @@ func (am *MockAccountManager) DeleteUser(accountID string, initiatorUserID strin
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error {
|
||||
if am.InviteUserFunc != nil {
|
||||
return am.InviteUserFunc(accountID, initiatorUserID, targetUserID)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method InviteUser is not implemented")
|
||||
}
|
||||
|
||||
// GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface
|
||||
func (am *MockAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
if am.GetNameServerGroupFunc != nil {
|
||||
|
||||
@@ -78,11 +78,14 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var setupKey *SetupKey
|
||||
for _, key := range account.SetupKeys {
|
||||
if key.Type == SetupKeyReusable {
|
||||
setupKey = key
|
||||
}
|
||||
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
}
|
||||
|
||||
peerKey1, err := wgtypes.GeneratePrivateKey()
|
||||
@@ -328,7 +331,15 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
setupKey := getSetupKey(account, SetupKeyReusable)
|
||||
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
}
|
||||
|
||||
peerKey1, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
@@ -394,7 +405,15 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
}
|
||||
|
||||
// two peers one added by a regular user and one with a setup key
|
||||
setupKey := getSetupKey(account, SetupKeyReusable)
|
||||
setupKey, err := manager.CreateSetupKey(account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
}
|
||||
peerKey1, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -470,13 +489,3 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) {
|
||||
}
|
||||
assert.NotNil(t, peer)
|
||||
}
|
||||
|
||||
func getSetupKey(account *Account, keyType SetupKeyType) *SetupKey {
|
||||
var setupKey *SetupKey
|
||||
for _, key := range account.SetupKeys {
|
||||
if key.Type == keyType {
|
||||
setupKey = key
|
||||
}
|
||||
}
|
||||
return setupKey
|
||||
}
|
||||
|
||||
@@ -2,9 +2,11 @@ package server
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
@@ -240,7 +242,15 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*Peer, int), fun
|
||||
peersExists := make(map[string]struct{})
|
||||
rules := make([]*FirewallRule, 0)
|
||||
peers := make([]*Peer, 0)
|
||||
|
||||
all, err := a.GetGroupAll()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get group all: %v", err)
|
||||
all = &Group{}
|
||||
}
|
||||
|
||||
return func(rule *PolicyRule, groupPeers []*Peer, direction int) {
|
||||
isAll := (len(all.Peers) - 1) == len(groupPeers)
|
||||
for _, peer := range groupPeers {
|
||||
if peer == nil {
|
||||
continue
|
||||
@@ -250,29 +260,33 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*Peer, int), fun
|
||||
peersExists[peer.ID] = struct{}{}
|
||||
}
|
||||
|
||||
fwRule := FirewallRule{
|
||||
fr := FirewallRule{
|
||||
PeerIP: peer.IP.String(),
|
||||
Direction: direction,
|
||||
Action: string(rule.Action),
|
||||
Protocol: string(rule.Protocol),
|
||||
}
|
||||
|
||||
ruleID := fmt.Sprintf("%s%d", peer.ID+peer.IP.String(), direction)
|
||||
ruleID += string(rule.Protocol) + string(rule.Action) + strings.Join(rule.Ports, ",")
|
||||
if isAll {
|
||||
fr.PeerIP = "0.0.0.0"
|
||||
}
|
||||
|
||||
ruleID := (rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||
fr.Protocol + fr.Action + strings.Join(rule.Ports, ","))
|
||||
if _, ok := rulesExists[ruleID]; ok {
|
||||
continue
|
||||
}
|
||||
rulesExists[ruleID] = struct{}{}
|
||||
|
||||
if len(rule.Ports) == 0 {
|
||||
rules = append(rules, &fwRule)
|
||||
rules = append(rules, &fr)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, port := range rule.Ports {
|
||||
addRule := fwRule
|
||||
addRule.Port = port
|
||||
rules = append(rules, &addRule)
|
||||
pr := fr // clone rule and add set new port
|
||||
pr.Port = port
|
||||
rules = append(rules, &pr)
|
||||
}
|
||||
}
|
||||
}, func() ([]*Peer, []*FirewallRule) {
|
||||
|
||||
@@ -126,6 +126,20 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
assert.Contains(t, peers, account.Peers["peerF"])
|
||||
|
||||
epectedFirewallRules := []*FirewallRule{
|
||||
{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: firewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
},
|
||||
{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: firewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.14.88",
|
||||
Direction: firewallRuleDirectionIN,
|
||||
|
||||
@@ -318,6 +318,46 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t
|
||||
return nil
|
||||
}
|
||||
|
||||
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
|
||||
func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
if am.idpManager == nil {
|
||||
return status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
|
||||
}
|
||||
|
||||
// check if the user is already registered with this ID
|
||||
user, err := am.lookupUserInCache(targetUserID, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return status.Errorf(status.NotFound, "user account %s doesn't exist", targetUserID)
|
||||
}
|
||||
|
||||
// check if user account is already invited and account is not activated
|
||||
pendingInvite := user.AppMetadata.WTPendingInvite
|
||||
if pendingInvite == nil || !*pendingInvite {
|
||||
return status.Errorf(status.PreconditionFailed, "can't invite a user with an activated NetBird account")
|
||||
}
|
||||
|
||||
err = am.idpManager.InviteUserByID(user.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.storeEvent(initiatorUserID, user.ID, accountID, activity.UserInvited, nil)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreatePAT creates a new PAT for the given user
|
||||
func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
|
||||
Reference in New Issue
Block a user